TOKEN={eos_token}")
95 | try:
96 | fail = acceptor.advance_token(eos_token)
97 | except acceptor.TokenRejected:
98 | fail = True
99 | if fail:
100 | print("[FAIL]")
101 | result = 1
102 | else:
103 | print("[SUCCESS]")
104 | result = 0
105 | if debug:
106 | debug("\n".join(repr(c) for c in acceptor.cursors))
107 | if args.paths:
108 | print(json.dumps(values_by_path, indent=2))
109 | return result
110 |
111 |
112 | if __name__ == "__main__":
113 | sys.exit(main())
114 |
--------------------------------------------------------------------------------
/src/examples/llm_schema.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-class-docstring,missing-function-docstring
2 | """
3 | Example of JSON schema decoding with MLX.
4 | """
5 | import argparse
6 | import json
7 | import time
8 | from math import inf
9 | from operator import itemgetter
10 | from typing import Iterable, Optional, Union
11 |
12 | import mlx.core as mx
13 | import mlx.nn as nn
14 |
15 | from mlx_lm.utils import load
16 |
17 | from llm_structured_output import JsonSchemaAcceptorDriver
18 | from llm_structured_output.util.bitmap import (
19 | bias_logits,
20 | count_set_bits,
21 | enumerate_set_bits,
22 | )
23 | from llm_structured_output.util.output import info, bold, bolddim, debug
24 | from llm_structured_output.util.tokenization import HuggingfaceTokenizerHelper
25 |
26 | from .reusable_kv_cache import ReusableKVCache
27 |
28 |
29 | class RejectedCompletion(Exception):
30 | """
31 | It's rare, but sometimes we reach a state where it's not possible to
32 | advance the acceptor. For example, when closing a JSON string we get
33 | a higher probability for slanted quotes than straight ones and select
34 | the wrong token. At that point, the LLM will continue generating with
35 | the prior that the string is closed, but our acceptor will remain in
36 | the string-accepting state. This can indicate an issue with the
37 | tokenizer vocabulary passed to the acceptor, or a bug in the code
38 | used to decode tokens from the LLM. If none of these apply, check that
39 | the LLM is actually able to generate JSON, although most are.
40 | """
41 |
42 |
43 | class Model:
44 | def __init__(self):
45 | mx.random.seed(0)
46 | self.model = None
47 | self.tokenizer = None
48 | self.vocabulary = None
49 | self.eos_id = None
50 | self.json_schema_acceptor_driver_factory = None
51 | self._cached_prompt = None
52 | self._cached_cache = None
53 |
54 | def load(self, model_path: str):
55 | """
56 | Load locally or download from Huggingface hub.
57 | """
58 | self.model, tokenizer = load(model_path)
59 | self.tokenizer = HuggingfaceTokenizerHelper(tokenizer)
60 | self.vocabulary, self.eos_id = self.tokenizer.extract_vocabulary()
61 | self.json_schema_acceptor_driver_factory = (
62 | JsonSchemaAcceptorDriver.driver_factory_for_model(
63 | self.vocabulary, self.eos_id
64 | )
65 | )
66 |
67 | def get_driver_for_json_schema(self, schema, encapsulated: bool = False):
68 | return self.json_schema_acceptor_driver_factory(
69 | schema, is_encapsulated_json=encapsulated
70 | )
71 |
72 | def _evaluate_prompt(
73 | self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
74 | ):
75 | if prior_prompt:
76 | i = 0
77 | for i, t in enumerate(prior_prompt):
78 | # We need to leave at least one token to evaluate because we don't
79 | # save the past logits.
80 | if i >= len(prompt) - 1 or prompt[i] != t:
81 | break
82 | cache = prior_cache
83 | for layer_cache in cache:
84 | layer_cache.reuse(len(prompt), i)
85 | tokens = prompt[i:]
86 | else:
87 | cache = ReusableKVCache.for_model(self.model)
88 | tokens = prompt
89 |
90 | logits = self.model(mx.array(tokens)[None], cache=cache)
91 | return logits, cache
92 |
93 | def _decode(self, tokens):
94 | return self.tokenizer.no_strip_decode(tokens)
95 |
96 | def _debug_top_tokens(self, logits, count=10):
97 | token_logits = sorted(
98 | enumerate(logits.tolist()), key=itemgetter(1), reverse=True
99 | )
100 | top_tokens = [
101 | (self._decode([t]), p) for t, p in token_logits[:count] if p != -inf
102 | ]
103 | debug("TOP TOKENS:", top_tokens)
104 |
105 | def _sample(self, logits, temp: float = 0):
106 | if temp == 0:
107 | result = mx.argmax(logits, axis=-1)
108 | else:
109 | result = mx.random.categorical(logits * (1 / temp))
110 | return result.item()
111 |
112 | def _sample_with_bias(
113 | self, logits, temp: float = 0, token_acceptor=None, lazy_bias: bool = True
114 | ):
115 | if token_acceptor is None:
116 | return self._sample(logits, temp)
117 |
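# Lazy bias: first sample from the unconstrained logits and only compute the
# (comparatively expensive) bitmap of schema-valid tokens if the acceptor
# rejects the sampled token.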
118 | if lazy_bias:
119 | token = self._sample(logits, temp)
120 | try:
121 | token_acceptor.advance_token(token)
122 | return token
123 | except JsonSchemaAcceptorDriver.TokenRejected:
124 | pass
125 |
126 | accepted_token_bitmap = token_acceptor.select_valid_tokens()
127 | if not accepted_token_bitmap:
128 | debug(token_acceptor.cursors)
129 | self._debug_top_tokens(logits)
130 | raise RejectedCompletion()
131 | token = self._sample(bias_logits(mx, logits, accepted_token_bitmap), temp)
132 | token_acceptor.advance_token(token)
133 | return token
134 |
135 | def generate_without_schema(self, logits, cache, temp: Optional[float] = 0.0):
136 | """
137 | For testing / comparison purposes.
138 | """
139 | while True:
140 | tokens = [self._sample(logits[0, -1, :], temp)]
141 | yield tokens
142 | if tokens[-1] == self.eos_id:
143 | break
144 | logits = self.model(mx.array(tokens)[None], cache=cache)
145 |
146 | def generate_with_schema(
147 | self, logits, cache, token_acceptor, temp: Optional[float] = 0.0
148 | ):
149 | while True:
150 | tokens = [self._sample_with_bias(logits[0, -1, :], temp, token_acceptor)]
151 | yield tokens
152 | if tokens[-1] == self.eos_id:
153 | break
154 | logits = self.model(mx.array(tokens)[None], cache=cache)
155 |
156 | def generate_with_preemptive_decoding(
157 | self,
158 | logits,
159 | cache,
160 | token_acceptor,
161 | temp: Optional[float] = 0.0,
162 | max_batch_size=5,
163 | ):
164 | """
165 | Try to generate faster by precomputing two tokens at a time when possible.
166 | If we know that the acceptor will only accept a small set of tokens after
167 | the current one, we can evaluate a batch with one entry per possible
168 | future token. Each entry in the batch contains the current sampled token,
169 | which we have to evaluate anyway, plus a second token corresponding to one
170 | of the tokens that could be sampled from the output of the first one.
171 | We get back logits for both positions of every item in the batch: the
172 | logits for the first position are identical across items (as long as the
173 | model applies a causal mask), so we sample them once and use the result to
174 | select which batch item's second token to keep.
175 | In practice, this only seems to accelerate things for unquantized models.
176 | """
177 | # Sample token from prompt evaluation
178 | accepted_token_bitmap = token_acceptor.select_valid_tokens()
179 | first_token_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap)
180 | first_token = self._sample(first_token_logits, temp)
181 | tokens = [first_token]
182 | yield tokens
183 | token_acceptor.advance_token(first_token)
184 | accepted_token_bitmap = token_acceptor.select_valid_tokens()
185 |
186 | while True:
187 | last_token = tokens[-1]
188 | if count_set_bits(accepted_token_bitmap) in range(1, max_batch_size + 1):
189 | # If the number of possible follow-up tokens is small, submit for
190 | # evaluation a batch of 2-token continuations.
191 | batch = []
192 | for followup_token in enumerate_set_bits(accepted_token_bitmap):
193 | batch.append([last_token, followup_token])
194 | # Re-shape the cache to match the input.
195 | for layer_cache in cache:
196 | layer_cache.keys = mx.concatenate([layer_cache.keys] * len(batch))
197 | layer_cache.values = mx.concatenate(
198 | [layer_cache.values] * len(batch)
199 | )
200 | else: # Otherwise, submit the normal one-token continuation.
201 | batch = [[last_token]]
202 |
203 | logits = self.model(mx.array(batch), cache=cache)
204 | mx.eval(logits)
205 |
206 | first_token_logits = bias_logits(mx, logits[0, 0, :], accepted_token_bitmap)
207 | first_token = self._sample(first_token_logits, temp)
208 | tokens = [first_token]
209 |
210 | if first_token == self.eos_id:
211 | yield tokens
212 | break
213 |
214 | token_acceptor.advance_token(first_token)
215 | accepted_token_bitmap = token_acceptor.select_valid_tokens()
216 | if not accepted_token_bitmap:
217 | raise RejectedCompletion()
218 |
219 | # If we had submitted 2-token continuations, we can decode a second token
220 | if len(batch[0]) > 1:
221 | index = next( # Find which of the second tokens was selected
222 | i
223 | for i, batch_item in enumerate(batch)
224 | if batch_item[1] == first_token
225 | )
226 | second_token_logits = bias_logits(
227 | mx, logits[index, 1, :], accepted_token_bitmap
228 | )
229 | second_token = self._sample(second_token_logits, temp)
230 | tokens.append(second_token)
231 |
232 | token_acceptor.advance_token(second_token)
233 | accepted_token_bitmap = token_acceptor.select_valid_tokens()
234 |
235 | # Select the accepted generation in the cache, restoring it to batch dimension 1.
236 | for layer_cache in cache:
237 | layer_cache.keys = layer_cache.keys.split([index, index + 1])[1]
238 | layer_cache.values = layer_cache.values.split([index, index + 1])[1]
239 |
240 | yield tokens
241 |
242 | def _generate_tokens(
243 | self,
244 | generator: Iterable,
245 | max_tokens: int = 1000,
246 | ) -> Iterable:
247 | start_time = time.time_ns()
248 | token_count = 0
249 |
250 | for tokens in generator:
251 | token_count += len(tokens)
252 |
253 | try:
254 | eos_index = tokens.index(self.eos_id)
255 | tokens = tokens[0:eos_index]
256 | except ValueError:
257 | eos_index = -1
258 |
259 | if tokens:
260 | text = self._decode(tokens)
261 | yield {
262 | "op": "generatedTokens",
263 | "text": text,
264 | "token_count": len(tokens),
265 | "time_ms": (time.time_ns() - start_time) / 1e6,
266 | }
267 |
268 | if eos_index >= 0:
269 | yield {"op": "stop", "reason": "end"}
270 | return
271 |
272 | if token_count >= max_tokens:
273 | yield {"op": "stop", "reason": "max_tokens"}
274 | return
275 |
276 | start_time = time.time_ns()
277 |
278 | assert False
279 |
280 | def completion(
281 | self,
282 | prompt: Union[str, Iterable[dict[str, str]]],
283 | schema: dict,
284 | encapsulated: bool = False,
285 | max_tokens: int = 1000,
286 | temp: float = 0.0,
287 | seed: int = None,
288 | preemptive_batch_size: int = 0,
289 | cache_prompt: bool = False,
290 | ):
291 | if seed is not None:
292 | mx.random.seed(seed)
293 |
294 | start_time = time.time_ns()
295 | prompt_tokens = self.tokenizer.encode_prompt(prompt)
296 | logits, cache = self._evaluate_prompt(
297 | prompt_tokens, self._cached_prompt, self._cached_cache
298 | )
299 | if cache_prompt:
300 | self._cached_prompt = prompt_tokens
301 | self._cached_cache = cache
302 | # Eager eval to more accurately reflect the prompt evaluation time.
303 | mx.eval(logits)
304 | prompt_time = time.time_ns() - start_time
305 | yield {
306 | "op": "evaluatedPrompt",
307 | "prompt": prompt,
308 | "token_count": len(prompt_tokens),
309 | "time_ms": prompt_time / 1e6,
310 | "prompt_tps": len(prompt_tokens) / (prompt_time / 1e9),
311 | }
312 |
313 | if schema:
314 | token_acceptor = self.get_driver_for_json_schema(schema, encapsulated)
315 | if preemptive_batch_size > 0:
316 | generator = self.generate_with_preemptive_decoding(
317 | logits,
318 | cache,
319 | token_acceptor,
320 | temp,
321 | max_batch_size=preemptive_batch_size,
322 | )
323 | else:
324 | generator = self.generate_with_schema(
325 | logits, cache, token_acceptor, temp
326 | )
327 | else:
328 | generator = self.generate_without_schema(logits, cache, temp)
329 |
330 | token_count = 0
331 | generation_time = 0
332 | for generation_result in self._generate_tokens(generator, max_tokens):
333 | if generation_result["op"] == "generatedTokens":
334 | token_count += generation_result["token_count"]
335 | generation_time += generation_result["time_ms"]
336 | elif generation_result["op"] == "stop":
337 | generation_result["token_count"] = token_count
338 | generation_result["time_ms"] = generation_time
339 | # This is slightly incorrect, because the first token is generated
340 | # from the prompt evaluation.
341 | generation_result["generation_tps"] = token_count / (
342 | generation_time / 1e3
343 | )
344 | yield generation_result
345 |
346 |
347 | def main():
348 | parser = argparse.ArgumentParser(
349 | description="LLM inference script with schema-constrained sampling"
350 | )
351 | parser.add_argument(
352 | "--model-path",
353 | type=str,
354 | default="mlx_model",
355 | help="The path to the model weights and tokenizer",
356 | )
357 | parser.add_argument(
358 | "--prompt",
359 | default="Once upon a midnight dreary",
360 | help="The message to be processed by the model",
361 | )
362 | parser.add_argument(
363 | "--max-tokens",
364 | "-m",
365 | type=int,
366 | default=100,
367 | help="Maximum number of tokens to generate",
368 | )
369 | parser.add_argument(
370 | "--temp",
371 | help="The sampling temperature.",
372 | type=float,
373 | default=0.0,
374 | )
375 | parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
376 | parser.add_argument(
377 | "--repeat-prompt",
378 | action=argparse.BooleanOptionalAction,
379 | help="Print prompt before start of generation",
380 | )
381 | parser.add_argument(
382 | "--schema",
383 | help="A JSON schema to constrain the output.",
384 | type=str,
385 | default=None,
386 | )
387 | parser.add_argument(
388 | "--encapsulated",
389 | action=argparse.BooleanOptionalAction,
390 | help="Whether the LLM is expected to encapsulate the JSON within ```json and ```.",
391 | )
392 | parser.add_argument(
393 | "--preemptive",
394 | type=int,
395 | default=0,
396 | help="If greater than zero, the maximum size of the batch for pre-emptive decoding",
397 | )
398 |
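# Illustrative invocation (paths and the schema are placeholders); run as a module
# so the relative import of reusable_kv_cache resolves, e.g. from the src directory:
#   python -m examples.llm_schema --model-path mlx_model \
#       --prompt "Describe a person in JSON" \
#       --schema '{"type": "object", "properties": {"name": {"type": "string"}}}'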
399 | args = parser.parse_args()
400 |
401 | info("Loading model from disk.")
402 | model = Model()
403 | model.load(args.model_path)
404 |
405 | if args.schema is not None:
406 | schema = json.loads(args.schema)
407 | info("Using schema")
408 | else:
409 | schema = None
410 | info("Starting generation...")
411 |
412 | for result in model.completion(
413 | prompt=args.prompt,
414 | schema=schema,
415 | encapsulated=args.encapsulated,
416 | max_tokens=args.max_tokens,
417 | temp=args.temp,
418 | seed=args.seed,
419 | preemptive_batch_size=args.preemptive,
420 | ):
421 | if result["op"] == "evaluatedPrompt":
422 | prompt_token_count = result["token_count"]
423 | prompt_time = result["time_ms"]
424 | prompt_tps = result["prompt_tps"]
425 | if args.repeat_prompt:
426 | bolddim(result["prompt"], flush=True)
427 | elif result["op"] == "generatedTokens":
428 | bold(result["text"], end="", flush=True)
429 | elif result["op"] == "stop":
430 | end_reason = result["reason"]
431 | generated_token_count = result["token_count"]
432 | generation_time = result["time_ms"]
433 | generation_tps = result["generation_tps"]
434 | else:
435 | assert False
436 |
437 | print()
438 | info(f"End reason: {end_reason}")
439 | info(f"Tokens: prompt {prompt_token_count}, generation {generated_token_count}")
440 | info(f"Tokens per second: prompt {prompt_tps:.2f}, generation {generation_tps:.2f}")
441 | info(f"Total time: prompt {prompt_time:.2f}ms, generation {generation_time:.2f}ms")
442 |
443 |
444 | if __name__ == "__main__":
445 | main()
446 |
--------------------------------------------------------------------------------
/src/examples/reluctance.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-class-docstring,missing-function-docstring
2 | """
3 | Example of JSON schema decoding with MLX.
4 | """
5 | import argparse
6 | import json
7 |
8 | import mlx.core as mx
9 | import mlx.nn as nn
10 | from mlx_lm.utils import load
11 |
12 | from llm_structured_output import (
13 | JsonSchemaAcceptorDriver,
14 | HuggingfaceTokenizerHelper,
15 | bias_logits,
16 | )
17 | from llm_structured_output.util.output import info, setbg, setfg, clear
18 |
19 | from .reusable_kv_cache import ReusableKVCache
20 |
21 |
22 | def compute_reluctance(logits, accepted_token_bitmap) -> float:
23 | """
24 | Sum the probabilities of each token that has higher probability than
25 | the highest-probability token selected by the schema. This gives an
26 | idea of the model's preference for tokens that don't follow the schema.
27 | """
28 | p = nn.softmax(logits)
29 | indices = mx.argsort(p)[::-1]
30 | r = 0
31 | for i in indices.tolist():
32 | if (1 << i) & accepted_token_bitmap:
33 | break
34 | r += p[i].item()
35 | return r
36 |
37 |
38 | def main():
39 | parser = argparse.ArgumentParser(
40 | description="Visualize LLM reluctance to generate according to the schema."
41 | )
42 | parser.add_argument(
43 | "--model-path",
44 | type=str,
45 | default="mlx_model",
46 | help="The path to the model weights and tokenizer",
47 | )
48 | parser.add_argument(
49 | "--schema",
50 | help="A JSON schema to constrain the output.",
51 | type=str,
52 | )
53 | parser.add_argument(
54 | "--prompt",
55 | help="The message to be processed by the model",
56 | )
57 | parser.add_argument(
58 | "--max-tokens",
59 | "-m",
60 | type=int,
61 | default=1000,
62 | help="Maximum number of tokens to generate",
63 | )
64 |
65 | args = parser.parse_args()
66 |
67 | info("Loading model from disk...")
68 | model, tokenizer = load(args.model_path)
69 | schema = json.loads(args.schema)
70 |
71 | tokenizer_helper = HuggingfaceTokenizerHelper(tokenizer)
72 | vocabulary, eos_id = tokenizer_helper.extract_vocabulary()
73 | token_acceptor_factory = JsonSchemaAcceptorDriver.driver_factory_for_model(vocabulary, eos_id)
74 | token_acceptor = token_acceptor_factory(schema)
75 |
76 |
77 | info("Starting generation...")
78 | tokens = tokenizer_helper.encode_prompt(args.prompt)
79 | cache = ReusableKVCache.for_model(model)
80 | while tokens[-1] != eos_id:
81 | logits = model(mx.array(tokens)[None], cache)
82 | accepted_token_bitmap = token_acceptor.select_valid_tokens()
83 | reluctance = compute_reluctance(logits[0, -1, :], accepted_token_bitmap)
84 | biased_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap)
85 | token = mx.argmax(biased_logits, axis=-1).item()
86 | if token == eos_id:
87 | break
88 | tokens = [token]
89 | text = tokenizer_helper.no_strip_decode(tokens)
90 | setbg(reluctance, 0.8 * (1 - reluctance), 0)
91 | setfg(1, 1, 1)
92 | print(text, end="")
93 | token_acceptor.advance_token(token)
94 | clear()
95 | print()
96 |
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
/src/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | mlx >= 0.19.1
2 | mlx-lm >= 0.19.2
3 | tokenizers >= 0.20.1
4 | sentencepiece
5 | fastapi
6 | pydantic
7 | uvicorn
8 |
--------------------------------------------------------------------------------
/src/examples/reusable_kv_cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper with improvements over mlx-lm's KVCache.
3 | """
4 |
5 | import mlx.core as mx
6 | from mlx_lm.models.cache import KVCache
7 |
8 |
9 | class ReusableKVCache(KVCache):
10 | """
11 | Usability improvements over KVCache.
12 | """
13 |
14 | @classmethod
15 | def for_model(cls, model):
16 | return [cls() for _ in model.layers]
17 |
18 | def reuse(self, new_prompt_length, common_prefix_length):
19 | """
20 | Reuse (part of) this cache for a new prompt that shares a prefix with it.
21 | """
22 | if self.keys is None:
23 | return
24 | # Clip the cache to the common length.
25 | self.offset = common_prefix_length
26 | # Make sure the cache can fit the whole prompt. Because the offset is
27 | # (very likely) not a multiple of the step size, update_and_fetch()
28 | # won't resize the cache when evaluating the rest of the prompt as it
29 | # would if it were an empty cache.
30 | current_size = self.keys.shape[2]
31 | if current_size < new_prompt_length:
32 | _, n_kv_heads, _, k_head_dim = self.keys.shape
33 | v_head_dim = self.values.shape[3]
34 | n_steps = (self.step + new_prompt_length - 1) // self.step
35 | k_add_shape = (1, n_kv_heads, n_steps * self.step - current_size, k_head_dim)
36 | v_add_shape = (1, n_kv_heads, n_steps * self.step - current_size, v_head_dim)
37 | k_zeros = mx.zeros(k_add_shape, self.keys.dtype)
38 | v_zeros = mx.zeros(v_add_shape, self.values.dtype)
39 | self.keys = mx.concatenate([self.keys, k_zeros], axis=2)
40 | self.values = mx.concatenate([self.values, v_zeros], axis=2)
41 |
42 | def update_and_fetch(self, keys, values):
43 | """
44 | Override the base class method to allow the cache to be used with batches of
45 | size greater than 1.
46 | This is just a tiny change in the line that determines the shape.
47 | """
48 | prev = self.offset
49 | if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
50 | B, n_kv_heads, _, k_head_dim = keys.shape
51 | v_head_dim = values.shape[3]
52 | n_steps = (self.step + keys.shape[2] - 1) // self.step
53 | k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
54 | v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
55 | new_k = mx.zeros(k_shape, keys.dtype)
56 | new_v = mx.zeros(v_shape, values.dtype)
57 | if self.keys is not None:
58 | if prev % self.step != 0:
59 | self.keys = self.keys[..., :prev, :]
60 | self.values = self.values[..., :prev, :]
61 | self.keys = mx.concatenate([self.keys, new_k], axis=2)
62 | self.values = mx.concatenate([self.values, new_v], axis=2)
63 | else:
64 | self.keys, self.values = new_k, new_v
65 |
66 | self.offset += keys.shape[2]
67 | self.keys[..., prev : self.offset, :] = keys
68 | self.values[..., prev : self.offset, :] = values
69 | return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
70 |
71 |
--------------------------------------------------------------------------------
/src/examples/server.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring,missing-class-docstring
2 | """
3 | Example model server with OpenAI-like API, including function calls / tools.
4 | """
5 | import json
6 | import time
7 | import os
8 | from enum import Enum
9 | from traceback import format_exc
10 | from typing import Literal, List, Optional, Union
11 |
12 | from fastapi import FastAPI, Request, status
13 | from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
14 | from fastapi.exceptions import RequestValidationError
15 | from pydantic import BaseModel
16 |
17 | from examples.llm_schema import Model
18 | from llm_structured_output.util.output import info, warning, debug
19 |
20 |
21 | app = FastAPI()
22 |
23 | model = Model()
24 | info("Loading model...")
25 | try:
26 | model_path = os.environ["MODEL_PATH"]
27 | model.load(model_path)
28 | except KeyError:
29 | warning("Need to specify MODEL_PATH environment variable")
30 |
31 |
32 | @app.exception_handler(RequestValidationError)
33 | # pylint: disable-next=unused-argument
34 | async def validation_exception_handler(request: Request, exc: RequestValidationError):
35 | exc_str = f"{exc}"
36 | warning(f"RequestValidationError: {exc_str}")
37 | content = {"error": exc_str}
38 | return JSONResponse(
39 | content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
40 | )
41 |
42 |
43 | @app.get("/status")
44 | def get_status():
45 | return {"status": "OK"}
46 |
47 |
48 | @app.get("/")
49 | def get_root():
50 | return FileResponse(f"{os.path.dirname(os.path.realpath(__file__))}/static/ui.html")
51 |
52 |
53 | class V1ChatMessageRole(str, Enum):
54 | SYSTEM = "system"
55 | USER = "user"
56 | ASSISTANT = "assistant"
57 |
58 |
59 | class V1ChatMessage(BaseModel):
60 | role: V1ChatMessageRole
61 | content: str
62 |
63 |
64 | class V1Function(BaseModel):
65 | name: str
66 | description: str = ""
67 | parameters: dict = {}
68 |
69 |
70 | class V1ToolFunction(BaseModel):
71 | type: Literal["function"]
72 | function: V1Function
73 |
74 |
75 | class V1ToolChoiceKeyword(str, Enum):
76 | AUTO = "auto"
77 | NONE = "none"
78 |
79 |
80 | class V1ToolChoiceFunction(BaseModel):
81 | type: Optional[Literal["function"]] = None
82 | name: str
83 |
84 |
85 | class V1ToolOptions(BaseModel): # Non-standard, our addition.
86 | # We automatically add instructions with the JSON schema
87 | # for the tool calls to the prompt. This option disables
88 | # it and is useful when the user prompt already includes
89 | # the schema and relevant instructions.
90 | no_prompt_steering: bool = False
91 |
92 |
93 | class V1ResponseFormatType(str, Enum):
94 | JSON_OBJECT = "json_object"
95 |
96 |
97 | class V1ResponseFormat(BaseModel):
98 | type: V1ResponseFormatType
99 | # schema is our addition, not an OpenAI API parameter
100 | schema: str = None
101 |
102 |
103 | class V1StreamOptions(BaseModel):
104 | include_usage: bool = False
105 |
106 |
107 | class V1ChatCompletionsRequest(
108 | BaseModel
109 | ): # pylint: disable=too-many-instance-attributes
110 | model: str = "default"
111 | max_tokens: int = 1000
112 | temperature: float = 0.0
113 | messages: List[V1ChatMessage]
114 | # The 'functions' and 'function_call' fields have been deprecated and
115 | # replaced with 'tools' and 'tool_choice', which work similarly but allow
116 | # for multiple functions to be invoked.
117 | functions: List[V1Function] = None
118 | function_call: Union[V1ToolChoiceKeyword, V1ToolChoiceFunction] = None
119 | tools: List[V1ToolFunction] = None
120 | tool_choice: Union[V1ToolChoiceKeyword, V1ToolChoiceFunction] = None
121 | tool_options: V1ToolOptions = None
122 | response_format: V1ResponseFormat = None
123 | stream: bool = False
124 | stream_options: V1StreamOptions = None
125 |
126 |
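# Illustrative request body (values are placeholders) accepted by the endpoint below;
# the "schema" field inside response_format is our non-standard extension:
#   {
#     "messages": [{"role": "user", "content": "Give me the date as JSON"}],
#     "response_format": {"type": "json_object", "schema": "{\"type\": \"object\"}"},
#     "max_tokens": 200,
#     "stream": false
#   }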
127 | @app.post("/v1/chat/completions")
128 | async def post_v1_chat_completions(request: V1ChatCompletionsRequest):
129 | debug("REQUEST", request)
130 | if request.stream:
131 | async def get_content():
132 | try:
133 | async for message in post_v1_chat_completions_impl(request):
134 | yield message
135 | # pylint: disable-next=broad-exception-caught
136 | except Exception as e:
137 | warning(format_exc())
138 | yield 'data: {"choices": [{"index": 0, "finish_reason": "error: ' + str(e) + '"}]}'
139 | return StreamingResponse(
140 | content=get_content(),
141 | media_type="text/event-stream",
142 | )
143 | else:
144 | # FUTURE: Python 3.10 can use `await anext(x)` instead of `await x.__anext__()`.
145 | try:
146 | response = await post_v1_chat_completions_impl(request).__anext__()
147 | # pylint: disable-next=broad-exception-caught
148 | except Exception as e:
149 | warning(format_exc())
150 | content = {"error": str(e)}
151 | response = JSONResponse(
152 | content=content, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
153 | )
154 | debug("RESPONSE", response)
155 | return response
156 |
157 |
158 | async def post_v1_chat_completions_impl(request: V1ChatCompletionsRequest):
159 | messages = request.messages[:]
160 |
161 | # Extract valid functions from the request.
162 | functions = []
163 | is_legacy_function_call = False
164 | if request.tool_choice == "none":
165 | pass
166 | elif request.tool_choice == "auto":
167 | functions = [tool.function for tool in request.tools if tool.type == "function"]
168 | elif request.tool_choice is not None:
169 | functions = [
170 | next(
171 | tool.function
172 | for tool in request.tools
173 | if tool.type == "function"
174 | and tool.function.name == request.tool_choice.name
175 | )
176 | ]
177 | elif request.function_call == "none":
178 | pass
179 | elif request.function_call == "auto":
180 | functions = request.functions
181 | is_legacy_function_call = True
182 | elif request.function_call is not None:
183 | functions = [
184 | next(
185 | fn for fn in request.functions if fn.name == request.function_call.name
186 | )
187 | ]
188 | is_legacy_function_call = True
189 |
190 | model_name = model_path
191 | schema = None
192 | if functions:
193 | # If the request includes functions, create a system prompt to instruct the LLM
194 | # to use tools, and assemble a JSON schema to steer the LLM output.
195 | if request.stream:
196 | responder = ToolCallStreamingResponder(
197 | model_name,
198 | functions,
199 | is_legacy_function_call,
200 | model,
201 | )
202 | else:
203 | responder = ToolCallResponder(
204 | model_name, functions, is_legacy_function_call
205 | )
206 | if not (request.tool_options and request.tool_options.no_prompt_steering):
207 | messages.insert(
208 | 0,
209 | V1ChatMessage(
210 | role="system",
211 | content=responder.tool_prompt,
212 | ),
213 | )
214 | schema = responder.schema
215 | else:
216 | if request.response_format:
217 | assert request.response_format.type == V1ResponseFormatType.JSON_OBJECT
218 | # The request may specify a JSON schema (this option is not in the OpenAI API)
219 | if request.response_format.schema:
220 | schema = json.loads(request.response_format.schema)
221 | else:
222 | schema = {"type": "object"}
223 | if request.stream:
224 | responder = ChatCompletionStreamingResponder(model_name, schema, model)
225 | else:
226 | responder = ChatCompletionResponder(model_name)
227 |
228 | if schema is not None:
229 | debug("Using schema:", schema)
230 |
231 | info("Starting generation...")
232 |
233 | prompt_tokens = None
234 |
235 | for result in model.completion(
236 | messages,
237 | schema=schema,
238 | max_tokens=request.max_tokens,
239 | temp=request.temperature,
240 | cache_prompt=True,
241 | ):
242 | if result["op"] == "evaluatedPrompt":
243 | prompt_tokens = result["token_count"]
244 | elif result["op"] == "generatedTokens":
245 | message = responder.generated_tokens(result["text"])
246 | if message:
247 | yield message
248 | elif result["op"] == "stop":
249 | completion_tokens = result["token_count"]
250 | yield responder.generation_stopped(
251 | result["reason"], prompt_tokens, completion_tokens
252 | )
253 | else:
254 | assert False
255 |
256 |
257 | class ChatCompletionResponder:
258 | def __init__(self, model_name: str):
259 | self.object_type = "chat.completion"
260 | self.model_name = model_name
261 | self.created = int(time.time())
262 | self.id = f"{id(self)}_{self.created}"
263 | self.content = ""
264 |
265 | def message_properties(self):
266 | return {
267 | "object": self.object_type,
268 | "id": f"chatcmpl-{self.id}",
269 | "created": self.created,
270 | "model": self.model_name,
271 | }
272 |
273 | def translate_reason(self, reason):
274 | """
275 | Translate our reason codes to OpenAI ones.
276 | """
277 | if reason == "end":
278 | return "stop"
279 | if reason == "max_tokens":
280 | return "length"
281 | return f"error: {reason}" # Not a standard OpenAI API reason
282 |
283 | def format_usage(self, prompt_tokens: int, completion_tokens: int):
284 | return {
285 | "usage": {
286 | "completion_tokens": completion_tokens,
287 | "prompt_tokens": prompt_tokens,
288 | "total_tokens": completion_tokens + prompt_tokens,
289 | },
290 | }
291 |
292 | def generated_tokens(
293 | self,
294 | text: str,
295 | ):
296 | self.content += text
297 | return None
298 |
299 | def generation_stopped(
300 | self,
301 | stop_reason: str,
302 | prompt_tokens: int,
303 | completion_tokens: int,
304 | ):
305 | finish_reason = self.translate_reason(stop_reason)
306 | message = {"role": "assistant", "content": self.content}
307 | return {
308 | "choices": [
309 | {"index": 0, "message": message, "finish_reason": finish_reason}
310 | ],
311 | **self.format_usage(prompt_tokens, completion_tokens),
312 | **self.message_properties(),
313 | }
314 |
315 |
316 | class ChatCompletionStreamingResponder(ChatCompletionResponder):
317 | def __init__(self, model_name: str, schema: dict = None, _model = None):
318 | super().__init__(model_name)
319 | self.object_type = "chat.completion.chunk"
320 | if schema:
321 | assert _model
322 | self.schema_parser = _model.get_driver_for_json_schema(schema)
323 | else:
324 | self.schema_parser = None
325 |
326 | def generated_tokens(
327 | self,
328 | text: str,
329 | ):
330 | delta = {"role": "assistant", "content": text}
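# Non-standard extension: when constrained by a schema, we also stream a "values"
# field mapping the path of each JSON value currently being generated (as reported
# by get_current_value_paths()) to the characters produced for it in this chunk.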
331 | if self.schema_parser:
332 | values = {}
333 | for char in text:
334 | self.schema_parser.advance_char(char)
335 | for path in self.schema_parser.get_current_value_paths():
336 | values[path] = values.get(path, "") + char
337 | delta["values"] = values
338 | message = {
339 | "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
340 | **self.message_properties(),
341 | }
342 | return f"data: {json.dumps(message)}\n"
343 |
344 | def generation_stopped(
345 | self,
346 | stop_reason: str,
347 | prompt_tokens: int,
348 | completion_tokens: int,
349 | ):
350 | finish_reason = self.translate_reason(stop_reason)
351 | delta = {"role": "assistant", "content": ""}
352 | message = {
353 | "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
354 | # Usage field notes:
355 | # - OpenAI only sends usage in streaming if the option
356 | # stream_options.include_usage is true, but we send it always.
357 | **self.format_usage(prompt_tokens, completion_tokens),
358 | **self.message_properties(),
359 | }
360 | return f"data: {json.dumps(message)}\ndata: [DONE]\n"
361 |
362 |
363 | class ToolCallResponder(ChatCompletionResponder):
364 | def __init__(
365 | self, model_name: str, functions: list[dict], is_legacy_function_call: bool
366 | ):
367 | super().__init__(model_name)
368 |
369 | self.is_legacy_function_call = is_legacy_function_call
370 |
371 | function_schemas = [
372 | {
373 | "type": "object",
374 | "properties": {
375 | "name": {"type": "const", "const": fn.name},
376 | "arguments": fn.parameters,
377 | },
378 | "required": ["name", "arguments"],
379 | }
380 | for fn in functions
381 | ]
382 | if len(function_schemas) == 1:
383 | self.schema = function_schemas[0]
384 | self.tool_prompt = self._one_tool_prompt(functions[0], function_schemas[0])
385 | elif is_legacy_function_call: # Only allows one function to be called.
386 | self.schema = {"oneOf": function_schemas}
387 | self.tool_prompt = self._select_tool_prompt(functions, function_schemas)
388 | else:
389 | self.schema = {"type": "array", "items": {"anyOf": function_schemas}}
390 | self.tool_prompt = self._multiple_tool_prompt(functions, function_schemas)
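# For illustration (function names are hypothetical): with tools "get_weather" and
# "get_time" and tool_choice "auto", self.schema becomes roughly
#   {"type": "array", "items": {"anyOf": [
#     {"type": "object",
#      "properties": {"name": {"type": "const", "const": "get_weather"},
#                     "arguments": {...}},
#      "required": ["name", "arguments"]},
#     {... same shape for "get_time" ...}]}}
# so the model must emit a JSON array of well-formed tool invocations.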
391 |
392 | def translate_reason(self, reason):
393 | if reason == "end":
394 | if self.is_legacy_function_call:
395 | return "function_call"
396 | return "tool_calls"
397 | return super().translate_reason(reason)
398 |
399 | def generation_stopped(
400 | self,
401 | stop_reason: str,
402 | prompt_tokens: int,
403 | completion_tokens: int,
404 | ):
405 | finish_reason = self.translate_reason(stop_reason)
406 | if finish_reason == "tool_calls":
407 | tool_calls = json.loads(self.content)
408 | if not isinstance(tool_calls, list):
409 | # len(functions) == 1 was special cased
410 | tool_calls = [tool_calls]
411 | message = {
412 | "role": "assistant",
413 | "tool_calls": [
414 | {
415 | "id": f"call_{self.id}_{i}",
416 | "type": "function",
417 | "function": {
418 | "name": function_call["name"],
419 | "arguments": json.dumps(function_call["arguments"]),
420 | },
421 | }
422 | for i, function_call in enumerate(tool_calls)
423 | ],
424 | }
425 | elif finish_reason == "function_call":
426 | function_call = json.loads(self.content)
427 | message = {
428 | "role": "assistant",
429 | "function_call": {
430 | "name": function_call["name"],
431 | "arguments": json.dumps(function_call["arguments"]),
432 | },
433 | }
434 | else:
435 | message = None
436 | return {
437 | "choices": [
438 | {"index": 0, "message": message, "finish_reason": finish_reason}
439 | ],
440 | **self.format_usage(prompt_tokens, completion_tokens),
441 | **self.message_properties(),
442 | }
443 |
444 | def _one_tool_prompt(self, tool, tool_schema):
445 | return f"""
446 | You are a helpful assistant with access to a tool that you must invoke to answer the user's request.
447 | The tool is:
448 | Tool {tool.name}: {tool.description}
449 | Invocation schema: {json.dumps(tool_schema)}
450 | Your answer is a JSON object according to the invocation schema in order to answer the user request below.
451 | """
452 |
453 | def _multiple_tool_prompt(self, tools, tool_schemas, separator="\n"):
454 | return f"""
455 | You are a helpful assistant with access to tools that you must invoke to answer the user's request.
456 | The following tools are available:
457 | {separator.join([ f'''
458 | Tool {tool.name}: {tool.description}
459 | Invocation schema: {json.dumps(tool_schema)}
460 | ''' for tool, tool_schema in zip(tools, tool_schemas) ])}
461 | Your answer is a JSON array with one or more tool invocations according to the appropriate schema(s)
462 | in order to answer the user request below.
463 | """
464 |
465 | def _select_tool_prompt(self, tools, tool_schemas, separator="\n"):
466 | return f"""
467 | You are a helpful assistant with access to tools that you must invoke to answer the user's request.
468 | The following tools are available:
469 | {separator.join([ f'''
470 | Function {tool.name}: {tool.description}
471 | Tool schema: {json.dumps(tool_schema)}
472 | ''' for tool, tool_schema in zip(tools, tool_schemas) ])}
473 | Your answer is a JSON object according to the invocation schema of the most appropriate tool to use
474 | to answer the user request below.
475 | """
476 |
477 |
478 | class ToolCallStreamingResponder(ToolCallResponder):
479 | def __init__(
480 | self,
481 | model_name: str,
482 | functions: list[dict],
483 | is_legacy_function_call: bool,
484 | _model,
485 | ):
486 | super().__init__(model_name, functions, is_legacy_function_call)
487 | self.object_type = "chat.completion.chunk"
488 |
489 | # We need to parse the output as it's being generated in order to send
490 | # streaming messages that contain the name and arguments of the function
491 | # being called.
492 |
493 | self.current_function_index = -1
494 | self.current_function_name = None
495 | self.in_function_arguments = False
496 |
497 | def set_function_name(_prop_name: str, prop_value):
498 | self.current_function_index += 1
499 | self.current_function_name = prop_value
500 |
501 | def start_function_arguments(_prop_name: str):
502 | self.in_function_arguments = True
503 |
504 | def end_function_arguments(_prop_name: str, _prop_value: str):
505 | self.in_function_arguments = False
506 |
507 | hooked_function_schemas = [
508 | {
509 | "type": "object",
510 | "properties": {
511 | "name": {
512 | "type": "const",
513 | "const": fn.name,
514 | "__hooks": {
515 | "value_end": set_function_name,
516 | },
517 | },
518 | "arguments": {
519 | **fn.parameters,
520 | "__hooks": {
521 | "value_start": start_function_arguments,
522 | "value_end": end_function_arguments,
523 | },
524 | },
525 | },
526 | "required": ["name", "arguments"],
527 | }
528 | for fn in functions
529 | ]
530 | if len(hooked_function_schemas) == 1:
531 | hooked_schema = hooked_function_schemas[0]
532 | elif is_legacy_function_call:
533 | hooked_schema = {"oneOf": hooked_function_schemas}
534 | else:
535 | hooked_schema = {
536 | "type": "array",
537 | "items": {"anyOf": hooked_function_schemas},
538 | }
539 | self.tool_call_parser = _model.get_driver_for_json_schema(hooked_schema)
540 |
541 | def generated_tokens(
542 | self,
543 | text: str,
544 | ):
545 | argument_text = ""
546 | for char in text:
547 | if self.in_function_arguments:
548 | argument_text += char
549 | # Update state. This is certain to parse, no need to check for rejections.
550 | self.tool_call_parser.advance_char(char)
551 | if not argument_text:
552 | return None
553 | assert self.current_function_name
554 | if self.is_legacy_function_call:
555 | delta = {
556 | "function_call": {
557 | "name": self.current_function_name,
558 | "arguments": argument_text,
559 | }
560 | }
561 | else:
562 | delta = {
563 | "tool_calls": [
564 | {
565 | "index": self.current_function_index,
566 | "id": f"call_{self.id}_{self.current_function_index}",
567 | "type": "function",
568 | "function": {
569 | # We send the name on every update, but OpenAI only sends it on
570 | # the first one for each call, with empty arguments (""). Further
571 | # updates only have the arguments field. This is something we may
572 | # want to emulate if client code depends on this behavior.
573 | "name": self.current_function_name,
574 | "arguments": argument_text,
575 | },
576 | }
577 | ]
578 | }
579 | message = {
580 | "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
581 | **self.message_properties(),
582 | }
583 | return f"data: {json.dumps(message)}\n"
584 |
585 | def generation_stopped(
586 | self,
587 | stop_reason: str,
588 | prompt_tokens: int,
589 | completion_tokens: int,
590 | ):
591 | finish_reason = self.translate_reason(stop_reason)
592 | message = {
593 | "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}],
594 | # Usage field notes:
595 | # - OpenAI only sends usage in streaming if the option
596 | # stream_options.include_usage is true, but we send it always.
597 | # - OpenAI sends two separate messages: one with the finish_reason and no
598 | # usage field, and one with an empty choices array and the usage field.
599 | **self.format_usage(prompt_tokens, completion_tokens),
600 | **self.message_properties(),
601 | }
602 | return f"data: {json.dumps(message)}\ndata: [DONE]\n"
603 |
--------------------------------------------------------------------------------
/src/examples/static/attention.html:
--------------------------------------------------------------------------------
[HTML page; markup, styles, and scripts elided in this dump. Recoverable text: title "Attention"; sections "Prompt", "Prompt perplexity", "Attention".]
--------------------------------------------------------------------------------
/src/examples/static/ui.html:
--------------------------------------------------------------------------------
[HTML page; markup, styles, and scripts elided in this dump. Recoverable text: title "LLM"; sections "Prompt", "Output".]
--------------------------------------------------------------------------------
/src/llm_structured_output/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | LLM structured output: constrain generation to a JSON schema.
3 | """
4 | from .json_schema_acceptor import JsonSchemaAcceptor, JsonSchemaAcceptorDriver
5 | from .json_acceptor import JsonAcceptor
6 | from .util.bitmap import bias_logits
7 | from .util.tokenization import HuggingfaceTokenizerHelper
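# Minimal end-to-end sketch (model and tokenizer loading not shown; mirrors the
# examples under src/examples, with `mx` standing in for the logits' array module):
#   helper = HuggingfaceTokenizerHelper(tokenizer)
#   vocabulary, eos_id = helper.extract_vocabulary()
#   driver_factory = JsonSchemaAcceptorDriver.driver_factory_for_model(vocabulary, eos_id)
#   acceptor = driver_factory(schema)
#   valid_token_bitmap = acceptor.select_valid_tokens()
#   biased = bias_logits(mx, logits, valid_token_bitmap)  # then sample and advance_token()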
8 |
--------------------------------------------------------------------------------
/src/llm_structured_output/acceptor.py:
--------------------------------------------------------------------------------
1 | """
2 | Base token acceptors.
3 |
4 | A token acceptor constrains the tokens that are acceptable at this point in
5 | the parsing or generation of a text.
6 |
7 | Since multiple parses of a given input may be possible (or multiple generations
8 | valid according to e.g. a schema), the acceptor creates multiple "cursors", one
9 | for each valid current state of the acceptor. This is akin to a chart parser,
10 | where all possible parses of the input are carried forward in parallel, which
11 | minimizes backtracking, an operation that is expensive with an LLM.
12 |
13 | The basic flow is:
14 | - First, the vocabulary (list of possible tokens for the LLM) is prepared into
15 | a trie for logarithmic traversal. Subclasses may also perform their own
16 | vocabulary preparation.
17 | - The acceptor's get_cursors() method is called, and the acceptor issues one or
18 | more cursors with initial state(s).
19 | - The trie is traversed to find which tokens are a valid match in the current
20 | state of the active cursor(s). For each cursor:
21 | - The select() method is called to narrow down the next character(s) that the
22 | cursor can accept in its current state.
23 | - For each selected character, we advance() the cursor, obtaining one or more
24 | follow-up cursors that represent the next state(s) of the cursor.
25 | - We descend down the trie branch corresponding to the selected character, and
26 | perform the same select(), advance() operation on the new cursor(s).
27 | - We traverse until the cursor(s) have reached an accepted state or we reach a
28 | leaf node.
29 | - As we traverse the trie recursively, we collect the token ids for each node.
30 | This creates a set of valid tokens that will be accepted.
31 |
32 | For example: if we have a TextAcceptor that will accept the word "true", the
33 | initial cursor's select() method will return "t" as the set of acceptable
34 | characters. We will then advance the cursor and obtain a cursor that accepts the
35 | word "rue", and our current trie node will become the "t" child branch of the
36 | prior trie node. We will then match the new trie node against the new cursor, etc.
37 |
38 | Acceptors can be chained with e.g. a StateMachineAcceptor. In this case, when a
39 | cursor reaches a final state, the parent acceptor moves its own cursor forward,
40 | potentially issuing more cursors that can be matched with the remainder of the
41 | trie.
42 |
43 | Some methods have been added to help prevent combinatorial explosions during
44 | the search, which can have a big effect on performance. For example, an acceptor
45 | for a quoted string can select() a very large number of characters after the
46 | opening quote. Descending into every branch of the trie is unnecessary, since
47 | those characters are all essentially equally valid. To avoid this, we allow the
48 | acceptor to prune the trie so that all equivalent characters are collapsed into
49 | one branch. In such a collapsed trie, each node keeps a set with the ids of all
50 | valid tokens of the same length, which are interchangeable from the point of
51 | view of the acceptor.
52 | """
53 |
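# A minimal usage sketch (illustrative only: the toy vocabulary below is made up,
# and assumes the (token_id, token_string) pair format of prepare_vocabulary()):
#
#   vocabulary = [(0, "t"), (1, "tr"), (2, "true"), (3, "false")]
#   trie = TokenAcceptor.prepare_vocabulary(vocabulary)
#   cursors = TextAcceptor("true").get_cursors()
#   bitmap = TokenAcceptor.match_all(cursors, trie)  # expected: bits 0, 1 and 2 set
#
# The resulting bitmap has one bit per token id and can be used to mask out the
# logits of tokens that no cursor can accept (see util.bitmap.bias_logits).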
54 | from __future__ import annotations
55 | from copy import copy as shallowcopy
56 | from time import time_ns
57 | from typing import Iterable, Tuple
58 |
59 | from .util.tokentrie import TokenTrie
60 |
61 |
62 | class TokenAcceptor:
63 | """
64 | Base class for token acceptors.
65 | """
66 |
67 | @classmethod
68 | def prepare_vocabulary(cls, vocabulary: Iterable[Tuple[int, str]]) -> TokenTrie:
69 | """
70 | Given a list of tokens (typically the vocabulary of an LLM), create
71 | a trie that will be used to select the tokens accepted by the current
72 | set of cursors.
73 | """
74 | vocabulary_trie = TokenTrie()
75 | vocabulary_trie.insert_all(vocabulary)
76 | return vocabulary_trie
77 |
78 | @classmethod
79 | def match_all(cls, cursors: Iterable[TokenAcceptor.Cursor], trie: TokenTrie) -> int:
80 | """
81 | Find which tokens in the vocabulary move any of the cursors towards an
82 | acceptance state from their current state.
83 | """
84 | if any(cursor.matches_all() for cursor in cursors):
85 | return trie.collect_ids()
86 | bitmap = 0
87 | for cursor in cursors:
88 | bitmap |= cursor.match(trie)
89 | return bitmap
90 |
91 | @classmethod
92 | def debug_match_all(
93 | cls,
94 | cursors: Iterable[TokenAcceptor.Cursor],
95 | trie: TokenTrie,
96 | debug_output_fn=print,
97 | ) -> int:
98 | """
99 | Same as match_all() but outputs debug information.
100 | """
101 | if any(cursor.matches_all() for cursor in cursors):
102 | return trie.collect_ids()
103 | debug_output_fn("MATCH ALL")
104 | bitmap = 0
105 | for cursor in cursors:
106 | start = time_ns()
107 | cursor_matches = cursor.debug_match(trie, debug_output_fn)
108 | dt_ns = time_ns() - start
109 | match_count = bin(cursor_matches).count("1")
110 | debug_output_fn(f"t={dt_ns/1e6:.02f}ms {match_count=} {repr(cursor)}")
111 | bitmap |= cursor_matches
112 | return bitmap
113 |
114 | @classmethod
115 | def advance_all(
116 | cls, cursors: Iterable[TokenAcceptor.Cursor], char: str
117 | ) -> list[TokenAcceptor.Cursor]:
118 | """
119 | Advance multiple cursors in parallel.
120 | """
121 | return [
122 | new_cursor
123 | for cursor in cursors
124 | if char in cursor.select(set(char))
125 | for new_cursor in cursor.advance(char)
126 | ]
127 |
128 | def get_cursors(self) -> Iterable[TokenAcceptor.Cursor]:
129 | """
130 | Get one or more cursors to traverse the acceptor.
131 | Override.
132 | """
133 | return [self.__class__.Cursor(self)]
134 |
135 | class Cursor:
136 | """
137 | A cursor encapsulates a valid current state of a token acceptor.
138 | """
139 |
140 | def __init__(self, acceptor: TokenAcceptor):
141 | pass
142 |
143 | def clone(self):
144 | """
145 | Cursors are never mutated, they are cloned as they advance.
146 | They should also be lightweight: think twice before overriding this
147 | to e.g. a deepcopy.
148 | """
149 | return shallowcopy(self)
150 |
151 | def matches_all(self) -> bool:
152 | """
153 | The acceptor accepts all the tokens (i.e. free text).
154 | This is an optimization and only useful for acceptors that don't constrain
155 | the input, such as WaitForAcceptor.
156 | """
157 | return False
158 |
159 | def select(self, candidate_chars: set[str]) -> Iterable[str]:
160 | """
161 | Narrow down the characters that are offered to the cursor for advancement.
162 | This is a crucial performance improvement for cursors in a state where they'll
163 | accept only a small set of characters, since they will be tested against that
164 | set instead of the whole range of characters available.
165 | Override.
166 | """
167 | return candidate_chars
168 |
169 | # pylint: disable-next=unused-argument
170 | def advance(self, char: str) -> Iterable[TokenAcceptor.Cursor]:
171 | """
172 | If the character can be consumed, return new cursor(s) for the possible
173 | continuation(s). IMPORTANT: Cursors should not mutate their state, only
174 | return mutated copies of the object, as the advance method is called
175 | multiple times with different inputs. See clone() method above.
176 | Override.
177 | """
178 | return []
179 |
180 | def in_accepted_state(self) -> bool:
181 | """
182 | Returns True if the cursor has reached a final state.
183 | Typically, rather than overriding this, return an AcceptedState object from
184 | the advance() method when the final state is reached after consuming input.
185 | """
186 | return False
187 |
188 | def get_value(self):
189 | """
190 | Returns the current value of the cursor as defined by itself. This can be
191 | either the ongoing representation of its temporary state, or its final value
192 | usable for the application once it reaches accepted state. At that point,
193 | cursors that return the same value are considered identical and duplicates
194 | may be discarded for performance.
195 | Override.
196 | """
197 | return None
198 |
199 | def get_value_path(self):
200 | """
201 | Returns the path of the value being pointed at by the cursor as defined by the
202 | application. This can be for example a JSON path in the case of a JSON acceptor.
203 | For higher-level application purposes only, not required for accepting.
204 | Override.
205 | """
206 | return ""
207 |
208 | def is_in_value(self):
209 | """
210 | Returns true if the cursor is accepting a value as opposed to syntactic elements.
211 | Used in conjunction with get_value_path().
212 | Override.
213 | """
214 | return False
215 |
216 | def prune(self, trie: TokenTrie) -> Iterable[(str, TokenTrie)]:
217 | """
218 | Select the children of the trie to search for matches. See match() below.
219 | This can be overridden in order to e.g. use a collapsed trie.
220 | """
221 | if trie.children:
222 | chars = set(trie.children.keys())
223 | selected_chars = chars & set(self.select(chars))
224 | for char in selected_chars:
225 | yield (char, trie.children[char])
226 |
227 | def match(self, trie: TokenTrie) -> int:
228 | """
229 | Find which tokens in the vocabulary move the acceptor towards an acceptance
230 | state from the current state held by this cursor.
231 | Returns a bitmap with the bits corresponding to the indices of the matched
232 | tokens set to 1.
233 | """
234 | if self.matches_all():
235 | return trie.collect_ids()
236 | bitmap = 0
237 | for char, child in self.prune(trie):
238 | followup_cursors = self.advance(char)
239 | if followup_cursors:
240 | bitmap |= child.ids
241 | for followup_cursor in followup_cursors:
242 | bitmap |= followup_cursor.match(child)
243 | return bitmap
244 |
245 | def debug_match(
246 | self, trie: TokenTrie, debug_output_fn=print, debug_indent=1
247 | ) -> int:
248 | """
249 | Same as match() but outputs debug information.
250 | """
251 | debug_start = time_ns()
252 | if self.matches_all():
253 | return trie.collect_ids()
254 | bitmap = 0
255 | debug_label = type(self).__qualname__
256 | if isinstance(self, StateMachineAcceptor.Cursor):
257 | debug_label += f"({type(self.transition_cursor).__qualname__})"
258 | debug_prefix = " " * debug_indent + debug_label
259 | debug_prune_start = time_ns()
260 | for char, child in self.prune(trie):
261 | debug_advance_start = time_ns()
262 | followup_cursors = self.advance(char)
263 | debug_advance_end = time_ns()
264 | prune_time = (debug_advance_start - debug_prune_start) / 1e6
265 | advance_time = (debug_advance_end - debug_advance_start) / 1e6
266 | debug_output_fn(
267 | f"{debug_prefix} >>> "
268 | f"{prune_time=:.02f}ms {advance_time=:.02f}ms char={repr(char)}"
269 | )
270 | debug_followup_start = time_ns()
271 | if followup_cursors:
272 | bitmap |= child.ids
273 | for followup_cursor in followup_cursors:
274 | bitmap |= followup_cursor.debug_match(
275 | child, debug_output_fn, debug_indent + 1
276 | )
277 | debug_followup_end = time_ns()
278 | followup_time = (debug_followup_end - debug_followup_start) / 1e6
279 | followup_count = len(followup_cursors)
280 | match_count = bin(bitmap).count("1")
281 | debug_output_fn(
282 | f"{debug_prefix} <<< {followup_count=} {followup_time=:.02f}ms {match_count=}"
283 | )
284 | debug_prune_start = time_ns()
285 | total_time = (time_ns() - debug_start) / 1e6
286 | debug_output_fn(f"{debug_prefix} {total_time=:.02f}ms")
287 | return bitmap
288 |
289 | def __repr__(self):
290 | return f"{type(self).__qualname__}(value={repr(self.get_value())})"
291 |
292 |
293 | class AcceptedState(TokenAcceptor.Cursor):
294 | """
295 | Holds a cursor that has reached the accepted state.
296 | """
297 |
298 | def __init__(self, cursor: TokenAcceptor.Cursor):
299 | self.cursor = cursor
300 |
301 | def in_accepted_state(self):
302 | return True
303 |
304 | def get_value(self):
305 | return self.cursor.get_value()
306 |
307 | def __repr__(self):
308 | return f"✅{repr(self.cursor)}"
309 |
310 |
311 | class CharAcceptor(TokenAcceptor):
312 | """
313 | Accept one character iff it is in the set of expected characters.
314 | """
315 |
316 | def __init__(self, charset: Iterable[str]):
317 | self.charset = charset
318 |
319 | class Cursor(TokenAcceptor.Cursor):
320 | """
321 | Cursor for CharAcceptor
322 | """
323 |
324 | def __init__(self, acceptor, value=None):
325 | self.acceptor = acceptor
326 | self.value = value
327 |
328 | def select(self, candidate_chars):
329 | return self.acceptor.charset
330 |
331 | def advance(self, char):
332 | # Because we implemented the select method, we are guaranteed that the
333 | # char is in our accepted set.
334 | return [AcceptedState(self.__class__(self.acceptor, char))]
335 |
336 | def get_value(self):
337 | return self.value
338 |
339 | def __repr__(self):
340 | return f"charset={repr(self.acceptor.charset)} value={repr(self.value)}"
341 |
342 |
343 | class TextAcceptor(TokenAcceptor):
344 | """
345 | Accept a pre-determined string of characters.
346 | """
347 |
348 | def __init__(self, text: str):
349 | assert len(text) > 0
350 | self.text = text
351 |
352 | class Cursor(TokenAcceptor.Cursor):
353 | """
354 | Cursor for TextAcceptor
355 | """
356 |
357 | def __init__(self, acceptor, pos=0):
358 | self.acceptor = acceptor
359 | self.pos = pos
360 |
361 | def select(self, candidate_chars):
362 | return self.acceptor.text[self.pos]
363 |
364 | def advance(self, char):
365 | next_cursor = self.__class__(self.acceptor, self.pos + 1)
366 | if next_cursor.pos == len(self.acceptor.text):
367 | return [AcceptedState(next_cursor)]
368 | return [next_cursor]
369 |
370 | def get_value(self) -> str:
371 | head = self.acceptor.text[0 : self.pos]
372 | tail = self.acceptor.text[self.pos :]
373 | if len(tail):
374 | return f"{head}👉{tail}"
375 | else:
376 | return f"{head}"
377 |
378 |
379 | class StateMachineAcceptor(TokenAcceptor):
380 | """
381 | Token acceptor that follows a state graph that defines edges to transition
382 | from state to state. Each state can have multiple edges, defined by the
383 | target state and a TokenAcceptor that, when it reaches its accepted state,
384 | causes the state machine acceptor to move to the target state. This is repeated
385 | until the state machine reaches a final state. Multiple transition paths
386 | are explored in parallel.
387 | """
388 |
389 | def __init__(self, graph=None, initial_state=None, end_states=None):
390 | self.graph = graph or []
391 | self.initial_state = initial_state or 0
392 | self.end_states = set(end_states or ["$"])
393 |
394 | def get_edges(self, state):
395 | """
396 | Retrieve the graph edges for transitions out of this state.
397 | Can be overridden for dynamic graphs.
398 | """
399 | return self.graph[state]
400 |
401 | def get_cursors(self):
402 | initial_cursor = self.Cursor(self)
403 | initial_cursor.current_state = self.initial_state
404 | return self._find_transitions(initial_cursor, [], set())
405 |
406 | def _find_transitions(self, cursor, visited_states, traversed_edges):
407 | try:
408 | edges = self.get_edges(cursor.current_state)
409 | except (KeyError, IndexError, TypeError):
410 | assert cursor.current_state in self.end_states
411 | return []
412 | cursors = []
413 | for transition_acceptor, target_state in edges:
414 | if cursor.start_transition(transition_acceptor, target_state):
415 | for transition_cursor in transition_acceptor.get_cursors():
416 | copy = cursor.clone()
417 | copy.transition_cursor = transition_cursor
418 | copy.target_state = target_state
419 | # Handle cursors that start in an accepted state,
420 | # e.g. EmptyTransition, WhitespaceAcceptor
421 | if transition_cursor.in_accepted_state():
422 | new_visited_states = visited_states + [cursor.current_state]
423 | assert target_state not in new_visited_states # Infinite loop
424 | cursors += self._cascade_transition(
425 | copy, new_visited_states, traversed_edges
426 | )
427 | else:
428 | cursors.append(copy)
429 | return cursors
430 |
431 | def _cascade_transition(self, cursor, visited_states, traversed_edges):
432 | assert cursor.transition_cursor.in_accepted_state()
433 | # Copy before validation to allow for cursor mutation, e.g. storing the transition_value
434 | cursors = []
435 | copy: StateMachineAcceptor.Cursor = cursor.clone()
436 | if copy.complete_transition(
437 | copy.transition_cursor.get_value(),
438 | copy.target_state,
439 | copy.target_state in copy.acceptor.end_states,
440 | ):
441 | copy.current_state = copy.target_state
442 | copy.target_state = None
443 | copy.accept_history = copy.accept_history + [copy.transition_cursor.cursor]
444 | copy.transition_cursor = None
445 | copy.consumed_character_count = 0
446 | # De-duplicate cursors that have reached the same state with the same value.
447 | # This prevents combinatorial explosion because of e.g. empty transitions.
448 | state_value = (copy.current_state, repr(copy.get_value()))
449 | if state_value not in traversed_edges:
450 | traversed_edges.add(state_value)
451 | if copy.current_state in self.end_states:
452 | cursors.append(AcceptedState(copy))
453 | cursors += self._find_transitions(copy, visited_states, traversed_edges)
454 | return cursors
455 |
456 | def advance_cursor(self, cursor, char):
457 | """
458 | Advance a cursor, and if it reaches accepted state, cause the state machine to transition.
459 | """
460 | next_cursors = []
461 | traversed_edges = set()
462 | for followup_cursor in cursor.transition_cursor.advance(char):
463 | copy = cursor.clone()
464 | copy.transition_cursor = followup_cursor
465 | copy.consumed_character_count += 1
466 | if followup_cursor.in_accepted_state():
467 | next_cursors += self._cascade_transition(
468 | copy, [], traversed_edges
469 | )
470 | else:
471 | next_cursors.append(copy)
472 | return next_cursors
473 |
474 | class Cursor(TokenAcceptor.Cursor):
475 | """
476 | Cursor for StateMachineAcceptor
477 | """
478 |
479 | def __init__(self, acceptor):
480 | self.acceptor = acceptor
481 | self.accept_history = []
482 | self.current_state = None
483 | self.transition_cursor = None
484 | self.target_state = None
485 | self.consumed_character_count = 0
486 |
487 | def matches_all(self):
488 | if self.transition_cursor is None:
489 | return False
490 | return self.transition_cursor.matches_all()
491 |
492 | def select(self, candidate_chars):
493 | if self.transition_cursor is None:
494 | return set()
495 | return self.transition_cursor.select(candidate_chars)
496 |
497 | def prune(self, trie):
498 | if self.transition_cursor is None:
499 | return []
500 | return self.transition_cursor.prune(trie)
501 |
502 | def advance(self, char):
503 | return self.acceptor.advance_cursor(self, char)
504 |
505 | # pylint: disable-next=unused-argument
506 | def start_transition(self, transition_acceptor, target_state) -> bool:
507 | """
508 | Override to prevent an edge from being traversed.
509 | """
510 | return True
511 |
512 | def complete_transition( # pylint: disable-next=unused-argument
513 | self, transition_value, target_state, is_end_state
514 | ) -> bool:
515 | """
516 | Override to perform additional checks on the acceptee and mutate the cursor
517 | with the transition_value as appropriate.
518 | """
519 | return True
520 |
521 | def get_value(self):
522 | value = [
523 | accepted_transition_cursor.get_value()
524 | for accepted_transition_cursor in self.accept_history
525 | ]
526 | if self.transition_cursor is not None:
527 | value.append(self.transition_cursor.get_value())
528 | return value
529 |
530 | def is_in_value(self):
531 | if self.consumed_character_count > 0:
532 | return self.transition_cursor.is_in_value()
533 | return self.accept_history[-1].is_in_value() if self.accept_history else None
534 |
535 | def get_value_path(self):
536 | if self.consumed_character_count > 0:
537 | return self.transition_cursor.get_value_path()
538 | return self.accept_history[-1].get_value_path() if self.accept_history else ""
539 |
540 | def __repr__(self) -> str:
541 | if self.transition_cursor is not None:
542 | transition_cursor = repr(self.transition_cursor)
543 | target_state = self.target_state
544 | else:
545 | transition_cursor = "None"
546 | target_state = "None"
547 | if self.accept_history:
548 | accept_history = []
549 | for accepted_transition_cursor in self.accept_history:
550 | if isinstance(
551 | accepted_transition_cursor, StateMachineAcceptor.Cursor
552 | ):
553 | accept_history += accepted_transition_cursor.accept_history
554 | else:
555 | accept_history.append(accepted_transition_cursor)
556 | history = repr(
557 | "".join(
558 | [
559 | str(accepted_transition_cursor.get_value())
560 | for accepted_transition_cursor in accept_history
561 | ]
562 | )
563 | )
564 | else:
565 | history = ""
566 | state = (
567 | f"{history} {self.current_state}⇒{target_state} {transition_cursor}"
568 | )
569 | return f"{type(self).__qualname__}({state})"
570 |
571 | class EmptyTransitionAcceptor(TokenAcceptor):
572 | """
573 | Faux acceptor that allows creating empty transition edges in a state
574 | machine graph for convenience in expressing complex graphs.
575 | An empty edge skips the current state altogether, without the need to
576 | consume input.
577 | """
578 |
579 | def get_cursors(self):
580 | return [AcceptedState(self.Cursor(self))]
581 |
582 | class Cursor(TokenAcceptor.Cursor):
583 | """
584 | Cursor for EmptyTransitionAcceptor
585 | """
586 |
587 | def get_value(self):
588 | return ""
589 |
590 | # Singleton EmptyTransitionAcceptor
591 | EmptyTransition = EmptyTransitionAcceptor()
592 |
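# --- Editorial sketch, not part of the original file ---
# A minimal, hedged example of wiring a StateMachineAcceptor graph with the
# classes above: an optional "-" sign followed by one or more digits, in the
# same style as the JSON acceptors in json_acceptor.py. The acceptor name and
# input string are made up; cursors are driven by hand via select()/advance().
#
# signed_digits = StateMachineAcceptor(
#     graph=[
#         [(CharAcceptor("-"), 1), (StateMachineAcceptor.EmptyTransition, 1)],  # 0: optional sign
#         [(CharAcceptor("0123456789"), 2)],  # 1: first digit
#         [(CharAcceptor("0123456789"), 2)],  # 2: more digits
#     ],
#     initial_state=0,
#     end_states=[2],
# )
# cursors = signed_digits.get_cursors()
# for char in "-42":
#     cursors = [
#         followup
#         for cursor in cursors
#         if not cursor.in_accepted_state() and char in cursor.select({char})
#         for followup in cursor.advance(char)
#     ]
# assert any(cursor.in_accepted_state() for cursor in cursors)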
593 |
594 | class SequenceAcceptor(StateMachineAcceptor):
595 | """
596 | Chain acceptors in sequence
597 | """
598 |
599 | def __init__(self, acceptors):
600 | graph = [[(acceptor, i + 1)] for i, acceptor in enumerate(acceptors)]
601 | super().__init__(graph, initial_state=0, end_states=[len(acceptors)])
602 |
603 | class Cursor(StateMachineAcceptor.Cursor):
604 | """
605 | Cursor for SequenceAcceptor. Defined for inspectability.
606 | """
607 |
608 |
609 | class WaitForAcceptor(TokenAcceptor):
610 | """
611 | Accept all text until finding a segment that triggers another acceptor.
612 | This is useful to allow for free text until a delimiter is found, e.g.
613 | when the output of an LLM includes JSON that is encapsulated within a
614 | ```json ... ``` block.
615 | """
616 |
617 | def __init__(self, wait_for_acceptor: TokenAcceptor):
618 | self.wait_for_acceptor = wait_for_acceptor
619 |
620 | class Cursor(TokenAcceptor.Cursor):
621 | """
622 | Cursor for WaitForAcceptor
623 | """
624 |
625 | def __init__(self, acceptor, cursors=None):
626 | self.acceptor = acceptor
627 | if cursors:
628 | self.cursors = cursors
629 | else:
630 | self.cursors = acceptor.wait_for_acceptor.get_cursors()
631 |
632 | def matches_all(self):
633 | return True
634 |
635 | def advance(self, char):
636 | cursors = TokenAcceptor.advance_all(self.cursors, char)
637 | accepted_cursors = [
638 | cursor for cursor in cursors if cursor.in_accepted_state()
639 | ]
640 | if accepted_cursors:
641 | return accepted_cursors
642 | return [self.__class__(self.acceptor, cursors)]
643 |
644 | def get_value(self):
645 | return f"Waiting for {repr(self.cursors)}"
646 |
--------------------------------------------------------------------------------
/src/llm_structured_output/json_acceptor.py:
--------------------------------------------------------------------------------
1 | """
2 | Acceptors for JSON parsing or constraining LLM generation to JSON outputs.
3 | """
4 |
5 | import json
6 |
7 | from .acceptor import (
8 | TokenAcceptor,
9 | AcceptedState,
10 | CharAcceptor,
11 | StateMachineAcceptor,
12 | SequenceAcceptor,
13 | TextAcceptor,
14 | )
15 | from .util.tokentrie import TokenTrie
16 |
17 |
18 | class WhitespaceTokenTrie(TokenTrie):
19 | """
20 | Create a smaller trie by collapsing all whitespace to a single symbol.
21 | Since all whitespace is equivalent in JSON, tokens that only differ in
22 | the type of whitespace are equivalent from a semantic point of view.
23 |
24 | For example, the tokens "\n\n\n", "\t\t\t" and " " are all mapped to the same
25 | node root -> " " -> " " -> " ", which now contains the token ids of all three
26 | tokens in its set of ids.
27 |
28 | This allows us to reduce the number of equivalent branches we explore when
29 | finding valid tokens. Note that this doesn't limit the possible output of
30 | an LLM, since the token ids are kept in the trie and thus matched as valid,
31 | and are accepted by the acceptor.
32 | """
33 |
34 | @classmethod
35 | def from_trie(cls, trie, whitespace_charset):
36 | """
37 | Create a WhitespaceTokenTrie given a full vocabulary trie.
38 | """
39 | if isinstance(trie, WhitespaceTokenTrie):
40 | return trie
41 |
42 | def _whitespace_collapse_fn(char, level):
43 | if char in whitespace_charset:
44 | return " "
45 | if level == 0:
46 | # The trie doesn't need to contain tokens that don't start with whitespace,
47 | # since they won't be selected by the WhitespaceAcceptor.
48 | return None
49 | return True
50 |
51 | # pylint: disable-next=protected-access
52 | return trie._map(_whitespace_collapse_fn, WhitespaceTokenTrie())
53 |
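# --- Editorial sketch, not part of the original file ---
# A hedged illustration of the collapsing described above, using a made-up
# four-token vocabulary: the three whitespace-only tokens end up under the
# same " " -> " " node with their ids merged, and the non-whitespace token is
# pruned at the root.
#
# full_trie = TokenTrie()
# full_trie.insert_all([(0, "  "), (1, "\t\t"), (2, "\n\n"), (3, "foo")])
# ws_trie = WhitespaceTokenTrie.from_trie(full_trie, " \n\r\t")
# assert list(ws_trie.children) == [" "]
# assert ws_trie.children[" "].children[" "].ids == 0b111  # ids 0, 1 and 2 merged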
54 |
55 | class WhitespaceAcceptor(TokenAcceptor):
56 | """
57 | Optional whitespace
58 | """
59 |
60 | WHITESPACE = " \n\r\t"
61 |
62 | _cached_tries = {}
63 |
64 | @classmethod
65 | def prepare_trie(cls, trie: TokenTrie):
66 | """
67 | Build a collapsed trie that reduces the search space for valid tokens.
68 | """
69 | trie_id = id(trie)
70 | if trie_id in cls._cached_tries:
71 | return cls._cached_tries[trie_id]
72 | collapsed_trie = WhitespaceTokenTrie.from_trie(
73 | trie, WhitespaceAcceptor.WHITESPACE
74 | )
75 | cls._cached_tries[trie_id] = collapsed_trie
76 | return collapsed_trie
77 |
78 | def __init__(self, max_whitespace: int = 40):
79 | self.max_whitespace = max_whitespace
80 |
81 | def get_cursors(self):
82 | # Whitespace is optional
83 | cursor = WhitespaceAcceptor.Cursor(self)
84 | return [cursor, AcceptedState(cursor)]
85 |
86 | class Cursor(TokenAcceptor.Cursor):
87 | """
88 | Cursor for WhitespaceAcceptor
89 | """
90 |
91 | def __init__(self, acceptor, text=""):
92 | self.acceptor = acceptor
93 | self.text = text
94 | self.length_exceeded = len(text) > self.acceptor.max_whitespace
95 |
96 | def select(self, candidate_chars):
97 | if self.length_exceeded:
98 | return set()
99 | return WhitespaceAcceptor.WHITESPACE
100 |
101 | def prune(self, trie):
102 | """
103 | Use a custom matching trie to collapse all equivalent whitespace
104 | into one, saving time when selecting valid tokens.
105 | """
106 | collapsed_trie = WhitespaceAcceptor.prepare_trie(trie)
107 | return super().prune(collapsed_trie)
108 |
109 | def advance(self, char):
110 | # Sometimes, LLMs try to run away with spaces when they don't know how to continue.
111 | # If the LLM triggers this often, consider whether the LLM is suitable for emitting
112 | # JSON and/or whether the task is achievable and makes sense with the information
113 | # provided in the prompt.
114 | if self.length_exceeded:
115 | return []
116 | next_cursor = WhitespaceAcceptor.Cursor(self.acceptor, self.text + char)
117 | # More whitespace is optional
118 | return [next_cursor, AcceptedState(next_cursor)]
119 |
120 | def get_value(self):
121 | return self.text
122 |
123 |
124 | class BooleanAcceptor(StateMachineAcceptor):
125 | """
126 | Accepts a JSON boolean value: true, false
127 | """
128 |
129 | def __init__(self):
130 | super().__init__([[(TextAcceptor("true"), "$"), (TextAcceptor("false"), "$")]])
131 |
132 | class Cursor(StateMachineAcceptor.Cursor):
133 | """
134 | Cursor for BooleanAcceptor
135 | """
136 |
137 | def __init__(self, acceptor):
138 | super().__init__(acceptor)
139 | self.value = None
140 |
141 | def complete_transition(self, transition_value, target_state, is_end_state):
142 | if is_end_state:
143 | if transition_value == "true":
144 | self.value = True
145 | else:
146 | assert transition_value == "false"
147 | self.value = False
148 | return True
149 |
150 | def get_value(self):
151 | return self.value
152 |
153 | def is_in_value(self):
154 | return True
155 |
156 |
157 | class NullAcceptor(TextAcceptor):
158 | """
159 | Accepts the JSON null value
160 | """
161 |
162 | def __init__(self):
163 | super().__init__("null")
164 |
165 | class Cursor(TextAcceptor.Cursor):
166 | """
167 | Cursor for NullAcceptor
168 | """
169 |
170 | def is_in_value(self):
171 | return True
172 |
173 |
174 | DigitAcceptor = CharAcceptor("0123456789")
175 | HexDigitAcceptor = CharAcceptor("0123456789ABCDEFabcdef")
176 |
177 |
178 | class StringCharTokenTrie(TokenTrie):
179 | """
180 | Create a smaller trie by collapsing all unescaped valid string characters
181 | to a single one while keeping the token ids. This is useful to reduce
182 | combinatorial explosion in string acceptance when all strings of equal
183 | length are equally acceptable.
184 | """
185 |
186 | @classmethod
187 | def from_trie(cls, trie):
188 | """
189 | Create a StringCharTokenTrie given a full trie.
190 | """
191 | if isinstance(trie, StringCharTokenTrie):
192 | return trie
193 |
194 | def _string_char_acceptor_collapse_fn(char, _level):
195 | if char in ['"', "\\"]:
196 | return True
197 | if char in StringCharAcceptor.INVALID_CHARS:
198 | return None
199 | return "."
200 |
201 | # pylint: disable-next=protected-access
202 | return trie._map(_string_char_acceptor_collapse_fn, StringCharTokenTrie())
203 |
204 |
205 | class StringCharAcceptor(TokenAcceptor):
206 | """
207 | Accepts a valid JSON unescaped string character
208 | """
209 |
210 | INVALID_CHARS = set(chr(c) for c in range(0, 0x20)) | set(['"', "\\"])
211 | _cached_tries = {}
212 |
213 | @classmethod
214 | def prepare_trie(cls, trie: TokenTrie):
215 | """
216 | Build a collapsed trie that reduces the search space for valid tokens.
217 | Note that while there is only one main vocabulary trie, we may need
218 | several collapsed tries because sometimes string matching will start
219 | in the middle of the main trie. I.e. we are halfway through the main
220 | trie with another acceptor; that acceptor reaches an end state and then
221 | we transition to the string acceptor; thus we start string matching in
222 | the middle of the main trie instead of the root. This can happen e.g.
223 | if there are tokens in the vocabulary that contain a quote and then
224 | additional characters afterwards.
225 | """
226 | trie_id = id(trie)
227 | if trie_id in cls._cached_tries:
228 | return cls._cached_tries[trie_id]
229 | collapsed_trie = StringCharTokenTrie.from_trie(trie)
230 | cls._cached_tries[trie_id] = collapsed_trie
231 | return collapsed_trie
232 |
233 | class Cursor(TokenAcceptor.Cursor):
234 | """
235 | Cursor for StringCharAcceptor
236 | """
237 |
238 | def __init__(self, acceptor, value=None):
239 | self.acceptor = acceptor
240 | self.value = value
241 |
242 | def select(self, candidate_chars):
243 | return candidate_chars - StringCharAcceptor.INVALID_CHARS
244 |
245 | def prune(self, trie):
246 | """
247 | Use a custom matching trie to avoid an explosion of valid options that
248 | are equivalent from the point of view of token matching.
249 | """
250 | return super().prune(StringCharAcceptor.prepare_trie(trie))
251 |
252 | def advance(self, char):
253 | return [AcceptedState(StringCharAcceptor.Cursor(self.acceptor, char))]
254 |
255 | def get_value(self):
256 | return self.value
257 |
258 |
259 | class StringAcceptor(StateMachineAcceptor):
260 | """
261 | Accepts a well-formed JSON string
262 | """
263 |
264 | STATES = [
265 | [(CharAcceptor('"'), 1)],
266 | [(CharAcceptor('"'), "$"), (CharAcceptor("\\"), 2), (StringCharAcceptor(), 1)],
267 | [
268 | (CharAcceptor('"\\/bfnrt'), 1),
269 | (CharAcceptor("u"), 3),
270 | ],
271 | [(HexDigitAcceptor, 4)],
272 | [(HexDigitAcceptor, 5)],
273 | [(HexDigitAcceptor, 6)],
274 | [(HexDigitAcceptor, 1)],
275 | ]
276 |
277 | def __init__(self):
278 | super().__init__(StringAcceptor.STATES)
279 |
280 | class Cursor(StateMachineAcceptor.Cursor):
281 | """
282 | Cursor for StringAcceptor
283 | """
284 |
285 | def __init__(self, acceptor):
286 | super().__init__(acceptor)
287 | self.text = ""
288 | self.length = 0
289 | self.value = None
290 |
291 | def complete_transition(self, transition_value, target_state, is_end_state):
292 | self.text += transition_value
293 | if target_state == 1 and self.current_state != 0:
294 | self.length += 1
295 | if is_end_state:
296 | self.value = json.loads(self.text)
297 | return True
298 |
299 | def get_value(self):
300 | if self.value is not None:
301 | return self.value
302 | else:
303 | return f"{self.text}👉"
304 |
305 | def is_in_value(self):
306 | return True
307 |
308 |
309 | class StringConstantAcceptor(TextAcceptor):
310 | """
311 | Accept a constant string, quoted and escaped.
312 | """
313 |
314 | def __init__(self, string: str):
315 | self.string = string
316 | super().__init__(json.dumps(string))
317 |
318 | class Cursor(TextAcceptor.Cursor):
319 | """
320 | Cursor for StringConstantAcceptor
321 | """
322 |
323 | def get_value(self) -> str:
324 | if self.pos == len(self.acceptor.text):
325 | return self.acceptor.string
326 | return super().get_value()
327 |
328 | def is_in_value(self):
329 | return True
330 |
331 |
332 | class NumberTokenTrie(TokenTrie):
333 | """
334 | Create a smaller trie by collapsing digit sequences.
335 | """
336 |
337 | @classmethod
338 | def from_trie(cls, trie):
339 | """
340 | Create a NumberTokenTrie given a full trie.
341 | """
342 | if isinstance(trie, NumberTokenTrie):
343 | return trie
344 |
345 | def _number_acceptor_collapse_fn(char, level):
346 | if char in "0123456789":
347 | return "9"
348 | # Only store branches that start with a digit.
349 | return level > 0
350 |
351 | # pylint: disable-next=protected-access
352 | return trie._map(_number_acceptor_collapse_fn, NumberTokenTrie())
353 |
354 |
355 | class NumberAcceptor(StateMachineAcceptor):
356 | """
357 | Accepts a well-formed JSON number
358 | """
359 |
360 | STATES = {
361 | 0: [(CharAcceptor("-"), 1), (StateMachineAcceptor.EmptyTransition, 1)], # Sign
362 | 1: [(CharAcceptor("123456789"), 2), (CharAcceptor("0"), 3)], # First digit
363 | 2: [
364 | (DigitAcceptor, 2),
365 | (StateMachineAcceptor.EmptyTransition, 3),
366 | ], # More digits
367 | 3: [(CharAcceptor("."), 4), (StateMachineAcceptor.EmptyTransition, 6)],
368 | 4: [(DigitAcceptor, 5)], # First decimal
369 | 5: [
370 | (DigitAcceptor, 5),
371 | (StateMachineAcceptor.EmptyTransition, 6),
372 | ], # More decimals
373 | 6: [(CharAcceptor("eE"), 7)],
374 | 7: [(CharAcceptor("+-"), 8), (StateMachineAcceptor.EmptyTransition, 8)],
375 | 8: [(DigitAcceptor, 9)], # Exponential, first digit
376 | 9: [(DigitAcceptor, 9)], # Exponential, more digits
377 | "$": [2, 3, 5, 9],
378 | }
379 | _cached_tries = {}
380 |
381 | @classmethod
382 | def prepare_trie(cls, trie: TokenTrie):
383 | """
384 | Build a collapsed trie that reduces the search space for valid tokens.
385 | """
386 | trie_id = id(trie)
387 | if trie_id in cls._cached_tries:
388 | return cls._cached_tries[trie_id]
389 | collapsed_trie = NumberTokenTrie.from_trie(trie)
390 | cls._cached_tries[trie_id] = collapsed_trie
391 | return collapsed_trie
392 |
393 | def __init__(self):
394 | super().__init__(self.STATES, 0, self.STATES["$"])
395 |
396 | class Cursor(StateMachineAcceptor.Cursor):
397 | """
398 | Cursor for NumberAcceptor
399 | """
400 |
401 | def __init__(self, acceptor):
402 | super().__init__(acceptor)
403 | self.text = ""
404 | self.value = None
405 |
406 | def prune(self, trie):
407 | """
408 | Use a custom matching trie to avoid an explosion of valid options that
409 | are equivalent from the point of view of token matching.
410 | """
411 | return super().prune(NumberAcceptor.prepare_trie(trie))
412 |
413 | def complete_transition(self, transition_value, target_state, is_end_state):
414 | self.text += transition_value
415 | if is_end_state:
416 | self.value = json.loads(self.text)
417 | return True
418 |
419 | def get_value(self):
420 | if self.value is None:
421 | return f"{self.text}👉"
422 | return self.value
423 |
424 | def is_in_value(self):
425 | return True
426 |
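# --- Editorial sketch, not part of the original file ---
# A hedged example of NumberAcceptor on a made-up input, driving the cursors
# by hand with select()/advance() and dropping already-accepted cursors before
# each new character. The accepted cursor's value is json.loads("-12.5e3").
#
# cursors = NumberAcceptor().get_cursors()
# for char in "-12.5e3":
#     cursors = [
#         followup
#         for cursor in cursors
#         if not cursor.in_accepted_state() and char in cursor.select({char})
#         for followup in cursor.advance(char)
#     ]
# accepted_values = [c.get_value() for c in cursors if c.in_accepted_state()]
# assert -12500.0 in accepted_values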
427 |
428 | class ArrayAcceptor(StateMachineAcceptor):
429 | """
430 | Accepts a well-formed JSON array
431 | """
432 |
433 | def __init__(self):
434 | super().__init__()
435 |
436 | def get_edges(self, state):
437 | return {
438 | 0: [(TextAcceptor("["), 1)],
439 | 1: [(WhitespaceAcceptor(), 2), (TextAcceptor("]"), "$")],
440 | 2: [(JsonAcceptor(), 3)],
441 | 3: [(WhitespaceAcceptor(), 4)],
442 | 4: [
443 | (SequenceAcceptor([TextAcceptor(","), WhitespaceAcceptor()]), 2),
444 | (TextAcceptor("]"), "$"),
445 | ],
446 | }[state]
447 |
448 | class Cursor(StateMachineAcceptor.Cursor):
449 | """
450 | Cursor for ArrayAcceptor
451 | """
452 |
453 | def __init__(self, acceptor):
454 | super().__init__(acceptor)
455 | self.value = []
456 |
457 | def clone(self):
458 | c = super().clone()
459 | c.value = self.value[:]
460 | return c
461 |
462 | def complete_transition(
463 | self, transition_value, target_state, is_end_state
464 | ) -> bool:
465 | if self.current_state == 2:
466 | self.value.append(transition_value)
467 | return True
468 |
469 | def get_value_path(self):
470 | index = len(self.value)
471 | if self.current_state > 2:
472 | index -= 1
473 | return f"[{index}]{super().get_value_path()}"
474 |
475 |
476 | class ObjectAcceptor(StateMachineAcceptor):
477 | """
478 | Accepts a well-formed JSON object
479 | """
480 |
481 | def __init__(self):
482 | super().__init__()
483 |
484 | def get_edges(self, state):
485 | return {
486 | 0: [(TextAcceptor("{"), 1)],
487 | 1: [(self.EmptyTransition, 2), (self.EmptyTransition, 6)],
488 | 2: [(WhitespaceAcceptor(), 3)],
489 | 3: [(ObjectAcceptor.PropertyAcceptor(), 4)],
490 | 4: [(WhitespaceAcceptor(), 5)],
491 | 5: [(TextAcceptor(","), 2), (self.EmptyTransition, 7)],
492 | 6: [(WhitespaceAcceptor(), 7)],
493 | 7: [(TextAcceptor("}"), "$")],
494 | }[state]
495 |
496 | class Cursor(StateMachineAcceptor.Cursor):
497 | """
498 | Cursor for ObjectAcceptor
499 | """
500 |
501 | def __init__(self, acceptor):
502 | super().__init__(acceptor)
503 | self.value = {}
504 |
505 | def complete_transition(
506 | self, transition_value, target_state, is_end_state
507 | ) -> bool:
508 | if self.current_state == 3:
509 | prop_name, prop_value = transition_value
510 | self.value[prop_name] = prop_value
511 | return True
512 |
513 | def get_value(self):
514 | return self.value
515 |
516 | class PropertyAcceptor(SequenceAcceptor):
517 | """
518 | JSON object property acceptor
519 | """
520 |
521 | def __init__(self, graph=None):
522 | if graph is None:
523 | graph = [
524 | StringAcceptor(),
525 | WhitespaceAcceptor(),
526 | TextAcceptor(":"),
527 | WhitespaceAcceptor(),
528 | JsonAcceptor(),
529 | ]
530 | super().__init__(graph)
531 |
532 | class Cursor(SequenceAcceptor.Cursor):
533 | """
534 | Cursor for ObjectAcceptor.PropertyAcceptor
535 | """
536 |
537 | def __init__(self, acceptor):
538 | super().__init__(acceptor)
539 | self.prop_name = None
540 | self.prop_value = None
541 |
542 | def complete_transition(
543 | self, transition_value, target_state, is_end_state
544 | ) -> bool:
545 | if target_state == 1:
546 | self.prop_name = transition_value
547 | elif is_end_state:
548 | self.prop_value = transition_value
549 | return True
550 |
551 | def get_value(self):
552 | return (self.prop_name, self.prop_value)
553 |
554 | def is_in_value(self):
555 | return self.current_state >= 4 and super().is_in_value()
556 |
557 | def get_value_path(self):
558 | return f".{self.prop_name}{super().get_value_path()}"
559 |
560 |
561 | class JsonAcceptor(StateMachineAcceptor):
562 | """
563 | Acceptor for a JSON value
564 | """
565 |
566 | def get_edges(self, state):
567 | if state == 0:
568 | return [
569 | (BooleanAcceptor(), "$"),
570 | (NumberAcceptor(), "$"),
571 | (StringAcceptor(), "$"),
572 | (NullAcceptor(), "$"),
573 | (ObjectAcceptor(), "$"),
574 | (ArrayAcceptor(), "$"),
575 | ]
576 | return []
577 |
578 |
579 | def prepare_json_acceptor_tries(trie: TokenTrie):
580 | """
581 | Pre-cache custom acceptor tries.
582 | """
583 | WhitespaceAcceptor.prepare_trie(trie)
584 | NumberAcceptor.prepare_trie(trie)
585 | StringCharAcceptor.prepare_trie(trie)
586 | if '"' in trie.children:
587 | StringCharAcceptor.prepare_trie(trie.children['"'])
588 |
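# --- Editorial sketch, not part of the original file ---
# A hedged end-to-end example: feed a made-up JSON document to JsonAcceptor
# one character at a time, keeping only cursors that can consume each
# character, and check that a cursor ends in the accepted state. The parsed
# value is available through the accepted cursor's accept history.
#
# cursors = JsonAcceptor().get_cursors()
# for char in '{"pi": 3.14, "ok": true}':
#     cursors = [
#         followup
#         for cursor in cursors
#         if not cursor.in_accepted_state() and char in cursor.select({char})
#         for followup in cursor.advance(char)
#     ]
# assert any(cursor.in_accepted_state() for cursor in cursors)
# parsed = [c.get_value() for c in cursors if c.in_accepted_state()]
# # parsed holds the accept history of the accepted cursor(s), wrapping the
# # dict {"pi": 3.14, "ok": True}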
--------------------------------------------------------------------------------
/src/llm_structured_output/util/__init__.py:
--------------------------------------------------------------------------------
1 | from . import bitmap
2 | from . import output
3 | from . import tokentrie
4 | from . import tokenization
5 |
--------------------------------------------------------------------------------
/src/llm_structured_output/util/bitmap.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities to use the bitmap of accepted token ids returned by TokenAcceptor.
3 | """
4 |
5 | from math import inf
6 | from typing import Iterable
7 |
8 |
9 | def count_set_bits(bitmap: int) -> int:
10 | """
11 | Count the number of bits set to one.
12 | """
13 | # FUTURE: bitmap.bit_count(), available from Python 3.10, is said to be 6x faster
14 | return bin(bitmap).count("1")
15 |
16 |
17 | def highest_bit_set(bitmap: int) -> int:
18 | """
19 | Return the index of the highest bit set in the bitmap.
20 | """
21 | return bitmap.bit_length() - 1
22 |
23 |
24 | def bitmap_complement(bitmap: int, set_size: int = None) -> int:
25 | """
26 | Negate the bits in the bitmap.
27 | Since the bitmap is encoded as a Python int, it can be of arbitrary length.
28 | I.e. we don't know how many zeros are above the top set bit. The set_size
29 | parameter can be passed to indicate the number of bits in the bitmap (which
30 | is akin to the number of members in the set it represents). If unspecified,
31 | the top set bit in the bitmap is used as its set size.
32 | """
33 | if not set_size:
34 | set_size = bitmap.bit_length()
35 | return (1 << set_size) - 1 - bitmap
36 |
37 |
38 | def enumerate_set_bits(bitmap: int) -> Iterable[int]:
39 | """
40 | Generator that yields the indices of the set bits in the bitmap.
41 | Note that it does so from highest to lowest.
42 | """
43 | while bitmap:
44 | highest_bit = highest_bit_set(bitmap)
45 | yield highest_bit
46 | bitmap -= 1 << highest_bit
47 |
48 |
49 | def bias_logits(np, logits, accepted_token_bitmap):
50 | """
51 | Apply a -inf bias to tokens that will not be accepted.
52 | Rather than importing it here, the np parameter is numpy or a compatible
53 | library, such as mlx.core, passed in by the caller.
54 | """
55 | vocab_size = logits.shape[0]
56 | highest_token_accepted = highest_bit_set(accepted_token_bitmap)
57 | accepted_token_count = count_set_bits(accepted_token_bitmap)
58 | # Check whether there's more tokens to be rejected or to be allowed, then do what's less work.
59 | if accepted_token_count <= highest_token_accepted / 2:
60 | bias = np.full(vocab_size, -inf)
61 | indices = np.array([*enumerate_set_bits(accepted_token_bitmap)])
62 | bias[indices] = 0
63 | else:
64 | bias = np.concatenate(
65 | [
66 | np.full(highest_token_accepted + 1, 0),
67 | # All tokens above the highest accepted token are rejected.
68 | np.full(vocab_size - highest_token_accepted - 1, -inf),
69 | ]
70 | )
71 | rejected_token_bitmap = bitmap_complement(accepted_token_bitmap)
72 | indices = np.array([*enumerate_set_bits(rejected_token_bitmap)])
73 | bias[indices] = -inf
74 | return np.add(logits, bias)
75 |
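# --- Editorial sketch, not part of the original file ---
# A hedged example of bias_logits() with numpy and a tiny 8-token vocabulary:
# only token ids 1, 2 and 5 are allowed, so every other logit becomes -inf.
#
# import numpy as np
# accepted_token_bitmap = (1 << 1) | (1 << 2) | (1 << 5)
# logits = np.zeros(8)
# biased = bias_logits(np, logits, accepted_token_bitmap)
# # biased == [-inf, 0, 0, -inf, -inf, 0, -inf, -inf]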
--------------------------------------------------------------------------------
/src/llm_structured_output/util/output.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | """
3 | Terminal colored output
4 | """
5 |
6 |
7 | def info(*args, **kwargs):
8 | print("\033[34mℹ ", end="")
9 | print(*args, **kwargs)
10 | print("\033[0m", end="")
11 |
12 |
13 | def warning(*args, **kwargs):
14 | print("\033[43;37m", end="")
15 | print(*args, **kwargs)
16 | print("\033[0m", end="")
17 |
18 |
19 | def debug(*args, **kwargs):
20 | print("\033[33m", end="")
21 | print(*args, **kwargs)
22 | print("\033[0m", end="")
23 |
24 |
25 | def debugbold(*args, **kwargs):
26 | print("\033[1;33m", end="")
27 | print(*args, **kwargs)
28 | print("\033[0m", end="")
29 |
30 |
31 | def bold(*args, **kwargs):
32 | print("\033[1;30m", end="")
33 | print(*args, **kwargs)
34 | print("\033[0m", end="")
35 |
36 |
37 | def bolddim(*args, **kwargs):
38 | print("\033[1;2;30m", end="")
39 | print(*args, **kwargs)
40 | print("\033[0m", end="")
41 |
42 |
43 | def boldalt(*args, **kwargs):
44 | print("\033[1;36m", end="")
45 | print(*args, **kwargs)
46 | print("\033[0m", end="")
47 |
48 |
49 | def underline(*args, **kwargs):
50 | print("\033[4m", end="")
51 | print(*args, **kwargs)
52 | print("\033[0m", end="")
53 |
54 |
55 | def inverse(*args, **kwargs):
56 | print("\033[7m", end="")
57 | print(*args, **kwargs)
58 | print("\033[0m", end="")
59 |
60 |
61 | def setfg(r: float, g: float, b: float):
62 | """Each of r,g,b must be between 0 and 1"""
63 | color = 16 + 36 * round(5 * r) + 6 * round(5 * g) + round(5 * b)
64 | print(f"\033[38;5;{color}m", end="")
65 |
66 |
67 | def setbg(r: float, g: float, b: float):
68 | """Each of r,g,b must be between 0 and 1"""
69 | color = 16 + 36 * round(5 * r) + 6 * round(5 * g) + round(5 * b)
70 | print(f"\033[48;5;{color}m", end="")
71 |
72 |
73 | def clear():
74 | print("\033[0m", end="")
75 |
--------------------------------------------------------------------------------
/src/llm_structured_output/util/tokenization.py:
--------------------------------------------------------------------------------
1 | """
2 | Tokenizer utils.
3 | """
4 |
5 | from typing import Union
6 |
7 | SPIECE_UNDERLINE = "▁"
8 |
9 |
10 | class HuggingfaceTokenizerHelper:
11 | """
12 | Helper to use Huggingface tokenizers effectively.
13 | """
14 |
15 | def __init__(self, tokenizer):
16 | """
17 | tokenizer is expected to be a Huggingface PreTrainedTokenizer[Fast]
18 | """
19 | self.tokenizer = tokenizer
20 | self.token_has_space_prefix = dict(
21 | [
22 | (i, fragment[0] == SPIECE_UNDERLINE)
23 | for fragment, i in tokenizer.vocab.items()
24 | ]
25 | )
26 |
27 | def encode_prompt(self, prompt: Union[str, list[dict[str, str]]]) -> list[int]:
28 | """
29 | Encode the prompt, applying the tokenizer template first if the prompt
30 | is a series of messages instead of a straight string.
31 | """
32 | if isinstance(prompt, str):
33 | return self.tokenizer.encode(prompt)
34 | if not self.tokenizer.chat_template:
35 | return self.tokenizer.encode("\n\n".join(
36 | f"{message['role']}: {message['content']}"
37 | for message in prompt
38 | ))
39 | return self.tokenizer.apply_chat_template(prompt)
40 |
41 | def no_strip_decode(self, tokens):
42 | """
43 | Allows decoding single tokens without removing the initial space.
44 | The Huggingface tokenizer doesn't seem to have an easy way to do this.
45 | """
46 | fragment = self.tokenizer.decode(tokens)
47 | if self.token_has_space_prefix[tokens[0]]:
48 | return f" {fragment}"
49 | else:
50 | return fragment
51 |
52 | def extract_vocabulary(self) -> tuple[list[tuple[int, str]], int]:
53 | """
54 | Extract the vocabulary and eos_token_id from a Huggingface PreTrainedTokenizer.
55 | """
56 | return (
57 | [(i, self.no_strip_decode([i])) for _, i in self.tokenizer.vocab.items()],
58 | self.tokenizer.eos_token_id,
59 | )
60 |
--------------------------------------------------------------------------------
/src/llm_structured_output/util/tokentrie.py:
--------------------------------------------------------------------------------
1 | """
2 | TokenTrie: hold the LLM token vocabulary in a prefix tree in order to perform
3 | operations over the whole vocabulary or parts of it in logarithmic time instead
4 | of linear.
5 | """
6 |
7 | from __future__ import annotations
8 | from collections import namedtuple
9 | from typing import Callable, Iterable, Tuple
10 |
11 |
12 | TokenTrieStats = namedtuple(
13 | "TokenTrieStats", ["tokenids", "trienodes", "trieleaves", "triedepth"]
14 | )
15 |
16 |
17 | class TokenTrie:
18 | """
19 | Access the tokens in a vocabulary hierarchically by prefix.
20 | Ids are stored as a bitmap where a bit set to one means that id is present.
21 | """
22 |
23 | def __init__(self):
24 | self.children: dict[str, TokenTrie] = {}
25 | self.ids: int = 0
26 |
27 | def insert_all(self, vocabulary: Iterable[Tuple[int, str]]):
28 | """
29 | Insert all the tokens in the vocabulary in the trie, with the id of
30 | each token being its index in the vocabulary.
31 | """
32 | for _id, token in vocabulary:
33 | if len(token) > 0:
34 | self.insert(token, _id)
35 |
36 | def insert(self, token, _id):
37 | """
38 | Insert one token in the trie, with the given id.
39 | """
40 | if len(token) == 0:
41 | self.ids |= 1 << _id
42 | else:
43 | head, tail = token[0], token[1:]
44 | child = self.children.get(head, self.__class__())
45 | child.insert(tail, _id)
46 | self.children[head] = child
47 |
48 | def insert_ids(self, token, ids):
49 | """
50 | Insert a token in the trie, with the given id set.
51 | This is useful e.g. when collapsing multiple branches into one.
52 | """
53 | if len(token) == 0:
54 | self.ids |= ids
55 | else:
56 | head, tail = token[0], token[1:]
57 | child = self.children.get(head, self.__class__())
58 | child.insert_ids(tail, ids)
59 | self.children[head] = child
60 |
61 | def collect_ids(self) -> int:
62 | """
63 | Returns a bitmap with the ids of the token(s) in this node and all the
64 | nodes below it.
65 | """
66 | ids = self.ids
67 | for child in self.children.values():
68 | ids |= child.collect_ids()
69 | return ids
70 |
71 | def dfs(self, prefix="") -> Iterable[tuple[str, int]]:
72 | """
73 | Traverse the trie depth-first, yielding (token, ids) tuples.
74 | """
75 | if self.ids:
76 | yield (prefix, self.ids)
77 | for char, child in self.children.items():
78 | yield from child.dfs(prefix + char)
79 |
80 | def map(self, map_fn: Callable[[str, int], str]) -> TokenTrie:
81 | """
82 | Return a trie where the characters are mapped to other characters using a
83 | function. This is useful for example to collapse a tree into a smaller one
84 | by pruning or merging branches where the characters are equivalent for a
85 | particular use case. The mapping function is passed a character to map, and
86 | the recursion level in the tree, and it can return True to preserve the
87 | branch of the tree as is, None to prune it, or a replacement character.
88 | If the latter, the branch will be recursed into and stored under the
89 | replacement character.
90 | """
91 | return self._map(map_fn, self.__class__())
92 |
93 | def _map(
94 | self, map_fn: Callable[[str, int], str], mapped_trie: TokenTrie, level: int = 0
95 | ) -> TokenTrie:
96 | """
97 | Internal implementation of map()
98 | """
99 | mapped_trie.ids |= self.ids
100 | for char, child in self.children.items():
101 | mapped_char = map_fn(char, level)
102 | if mapped_char is True:
103 | # If the mapping function returns True, preserve the original branch
104 | mapped_trie.children[char] = child
105 | elif mapped_char is None:
106 | # If the mapping function returns None, prune the original branch
107 | pass
108 | else:
109 | # Map the branch to a new character, e.g. merge several chars into one
110 | mapped_child = mapped_trie.children.get(
111 | mapped_char, mapped_trie.__class__()
112 | )
113 | # pylint: disable-next=protected-access
114 | mapped_trie.children[mapped_char] = child._map(
115 | map_fn, mapped_child, level + 1
116 | )
117 | return mapped_trie
118 |
119 | def _id_count(self) -> int:
120 | """
121 | Returns the number of ids in this node
122 | """
123 | # FUTURE: self.ids.bit_count() available from Python 3.10 is said to be 6x faster
124 | return bin(self.ids).count("1")
125 |
126 | def max_depth(self) -> int:
127 | """
128 | Return the max depth of any branch on the trie, i.e. the length of the longest token.
129 | """
130 | return max((child.max_depth() for child in self.children.values()), default=0) + 1
131 |
132 | def stats(self) -> TokenTrieStats:
133 | """
134 | Compute and return statistics on the trie, for debugging purposes.
135 | """
136 | ids = self._id_count()
137 | nodes = 1
138 | leaves = 0
139 | depth = 0
140 | if len(self.children) == 0:
141 | leaves = 1
142 | else:
143 | for branch in self.children.values():
144 | branch_ids, branch_nodes, branch_leaves, branch_depth = branch.stats()
145 | ids += branch_ids
146 | nodes += branch_nodes
147 | leaves += branch_leaves
148 | depth = max(depth, branch_depth)
149 | return TokenTrieStats(
150 | tokenids=ids, trienodes=nodes, trieleaves=leaves, triedepth=depth + 1
151 | )
152 |
153 | def __repr__(self):
154 | id_count = self._id_count()
155 | child_count = len(self.children)
156 | return f"{super().__repr__()}({id_count=}, {child_count=})"
157 |
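# --- Editorial sketch, not part of the original file ---
# A hedged example of the trie API above, with a made-up three-token
# vocabulary: ids are stored as bitmaps (1 << id), and map() can merge
# equivalent branches, which is how the collapsed tries in json_acceptor.py
# are built.
#
# trie = TokenTrie()
# trie.insert_all([(0, "ab"), (1, "ac"), (2, "b")])
# assert list(trie.dfs()) == [("ab", 0b001), ("ac", 0b010), ("b", 0b100)]
#
# def collapse(char, _level):
#     return "." if char in "bc" else char  # merge the "b" and "c" branches
#
# small = trie.map(collapse)
# assert small.children["a"].children["."].ids == 0b011  # "ab" and "ac" merged
# assert small.children["."].ids == 0b100  # the standalone "b" token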
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/otriscon/llm-structured-output/037e8eb7447005fda06e7d811b041efcb94b0cef/src/tests/__init__.py
--------------------------------------------------------------------------------
/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/multi_turn-00000-of-00001.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/otriscon/llm-structured-output/037e8eb7447005fda06e7d811b041efcb94b0cef/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/multi_turn-00000-of-00001.parquet
--------------------------------------------------------------------------------
/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/parquet_to_jsonl.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert a fireworks function calling dataset from parquet to jsonl that can be
3 | used by the evaluation scripts.
4 |
5 | https://huggingface.co/datasets/fireworks-ai/function-calling-eval-dataset-v0
6 | """
7 |
8 | import sys
9 | import json
10 | import pyarrow.parquet as pq
11 |
12 | if len(sys.argv) < 2:
13 | print("Need path to parquet file.")
14 | sys.exit(1)
15 | input_file = sys.argv[1]
16 | data = pq.read_table(input_file).to_pydict()
17 | prompts = data["prompt"]
18 | completions = data["completion"]
19 | tools = data["tools"]
20 |
21 | output_file = input_file.replace(".parquet", ".jsonl")
22 | if output_file == input_file:
23 | output_file += ".jsonl"
24 |
25 | with open(output_file, mode="w", encoding="utf-8") as f:
26 | for i, prompt in enumerate(prompts):
27 | json.dump(
28 | {
29 | "prompt": prompt,
30 | "tools": json.loads(tools[i]),
31 | # The source dataset contains one gold completion per case, but we output an array
32 | # to support multiple gold answers down the line.
33 | "gold": [
34 | {
35 | "type": "function",
36 | "function": json.loads(
37 | completions[i].partition("")[2]
38 | ),
39 | }
40 | ],
41 | "options": {
42 | "prompt_includes_schema": True,
43 | # This dataset has only cases where one tool is invoked, and the prompt includes
44 | # an example in which the output is not an array but a single tool call.
45 | "single_tool": True,
46 | },
47 | },
48 | f,
49 | )
50 | f.write("\n")
51 |
--------------------------------------------------------------------------------
/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/report-multi_turn.md:
--------------------------------------------------------------------------------
1 | case | mlx-community/Meta-Llama-3-8B-Instruct-4bit | gpt-4o-2024-05-13
2 | --- | --- | ---
3 | 0 | ✅ | ✅
4 | 1 | ✅ | ✅
5 | 2 | ✅ | ✅
6 | 3 | ✅ | ✅
7 | 4 | ✅ | ✅
8 | 5 | ✅ | ✅
9 | 6 | ✅ | _function_call[0].url_ ~~'www.mywebsite.com'~~ 'http://www.mywebsite.com'
10 | 7 | ✅ | ✅
11 | 8 | ✅ | ✅
12 | 9 | ✅ | ✅
13 | 10 | ✅ | ✅
14 | 11 | ✅ | ✅
15 | 12 | ✅ | ✅
16 | 13 | ✅ | ✅
17 | 14 | ✅ | ✅
18 | 15 | ✅ | ✅
19 | 16 | ✅ | ✅
20 | 17 | ✅ | ✅
21 | 18 | ➕ _function_call[0].genre_ 'action' | ✅
22 | 19 | ✅ | ✅
23 | 20 | ✅ | ✅
24 | 21 | ✅ | ✅
25 | 22 | ✅ | ✅
26 | 23 | ✅ | ✅
27 | 24 | ✅ | ✅
28 | 25 | ➕ _function_call[0].genre_ 'pop' ⸱ _function_call[0].keyword_ ~~'pop'~~ 'Taylor Swift' | ✅
29 | 26 | ✅ | _function_call[0].country_ ~~'US'~~ 'us'
30 | 27 | ✅ | ✅
31 | 28 | ✅ | ✅
32 | 29 | _function_call[0].amount_ ~~1500 [int]~~ 1500.0 [float] | ✅
33 | 30 | ✅ | ✅
34 | 31 | ✅ | ✅
35 | 32 | ✅ | ✅
36 | 33 | ✅ | ✅
37 | 34 | ✅ | ✅
38 | 35 | ✅ | ✅
39 | 36 | _function_call[0].rating_ ~~7 [int]~~ 7.0 [float] | ✅
40 | 37 | _function_call[0].date_range['start_date']_ ~~'2022-02-01'~~ '2023-02-20' ⸱ _function_call[0].date_range['end_date']_ ~~'2022-02-08'~~ '2023-02-27' | _function_call[0].date_range['start_date']_ ~~'2022-02-01'~~ '2023-09-30' ⸱ _function_call[0].date_range['end_date']_ ~~'2022-02-08'~~ '2023-10-07'
41 | 38 | ✅ | ✅
42 | 39 | ✅ | ✅
43 | 40 | _function_call[0].event_date_ ~~'2022-04-15'~~ 'today' | _function_call[0].event_date_ ~~'2022-04-15'~~ '2023-10-06'
44 | 41 | ✅ | ✅
45 | 42 | ✅ | ✅
46 | 43 | ✅ | ✅
47 | 44 | ✅ | ✅
48 | 45 | ✅ | ✅
49 | 46 | ✅ | ✅
50 | 47 | ✅ | ✅
51 | 48 | ✅ | ✅
52 | 49 | _function_call[0].query_ ~~''~~ 'comedy' | ➖ ~~_function_call[0].query_ ''~~
53 | 50 | ✅ | _function_call[0].url_ ~~'www.example.com'~~ 'http://www.example.com'
54 | 51 | ✅ | _function_call[0].username_ ~~'@JohnDoe'~~ 'JohnDoe'
55 | 52 | ✅ | ✅
56 | 53 | _function_call[0].source_language_ ~~'fr'~~ 'French' ⸱ _function_call[0].target_language_ ~~'en'~~ 'English' | ✅
57 | 54 | ✅ | ✅
58 | 55 | ✅ | ✅
59 | 56 | ✅ | ✅
60 | 57 | ✅ | _function_call[0].country_ ~~'United States'~~ 'us'
61 | 58 | _function_call[0].event_location_ ~~'conference room in our office'~~ 'Conference room in our office' | _function_call[0].event_date_ ~~'15th of next month'~~ '2023-11-15' ⸱ _function_call[0].event_time_ ~~'10 AM'~~ '10:00 AM' ⸱ _function_call[0].event_location_ ~~'conference room in our office'~~ 'Conference Room, Office'
62 | 59 | ✅ | ✅
63 | 60 | ✅ | ✅
64 | 61 | ✅ | _function_call[0].locations[0]_ ~~'Brooklyn'~~ 'Brooklyn, NY' ⸱ _function_call[0].locations[1]_ ~~'Manhattan'~~ 'Manhattan, NY' ⸱ _function_call[0].locations[2]_ ~~'Queens'~~ 'Queens, NY' ⸱ _function_call[0].locations[3]_ ~~'Brooklyn'~~ 'Brooklyn, NY'
65 | 62 | _function_call[0].image_ ~~'user_image'~~ 'The image you sent' | ➕ _tool_call[0]['error']_ {'error': "Parsing tool_calls: KeyError('tool_calls')", 'completion_message': {'role': 'assistant', 'content': 'Please provide the image of the barcode so I can proceed with scanning it.'}} ⸱ ➖ ~~_function_call[0]._ {'name': 'scan_barcode', 'arguments': {'image': 'user_image'}}~~ ⸱ _tool_call[0]['type']_ ~~'function'~~ 'error'
66 | 63 | ✅ | ✅
67 | 64 | ✅ | ✅
68 | 65 | ✅ | ✅
69 | 66 | _function_call[0].meal_ ~~'pizza'~~ 'lunch' ⸱ _function_call[0].date_ ~~'2022-03-01'~~ 'today' | _function_call[0].date_ ~~'2022-03-01'~~ '2023-10-10'
70 | 67 | ✅ | ✅
71 | 68 | ✅ | ✅
72 | 69 | _function_call[0].language_ ~~'English'~~ 'en' | _function_call[0].language_ ~~'English'~~ 'en'
73 | 70 | ✅ | ✅
74 | 71 | _function_call[0].order_items[0]['product_name']_ ~~'laptop'~~ 'Laptop' | ✅
75 | 72 | ✅ | ✅
76 | 73 | ✅ | _function_call[0].background_color_ ~~'white'~~ '#FFFFFF' ⸱ _function_call[0].foreground_color_ ~~'black'~~ '#000000'
77 | 74 | _function_call[0].items[0]['name']_ ~~'apple'~~ 'apples' ⸱ _function_call[0].items[0]['price']_ ~~0.5~~ 1.0 ⸱ _function_call[0].items[1]['name']_ ~~'orange'~~ 'oranges' ⸱ _function_call[0].items[1]['price']_ ~~0.75~~ 0.5 | _function_call[0].items[0]['price']_ ~~0.5~~ 1.0 ⸱ _function_call[0].items[1]['price']_ ~~0.75~~ 0.5
78 | 75 | ✅ | ✅
79 | 76 | ✅ | ✅
80 | 77 | ✅ | ✅
81 | 78 | ✅ | ✅
82 | 79 | ✅ | _function_call[0].start_date_ ~~'1st June'~~ '2023-06-01' ⸱ _function_call[0].end_date_ ~~'10th June'~~ '2023-06-10'
83 | 80 | ✅ | ✅
84 | 81 | ✅ | ✅
85 | 82 | ✅ | ✅
86 | 83 | ➕ _function_call[0].source_currency_ 'USD' ⸱ ➖ ~~_function_call[0].base_currency_ 'USD'~~ ⸱ _function_call[0].['name']_ ~~'get_currency_conversion_rate'~~ 'convert_currency' | ➕ _function_call[0].source_currency_ 'USD' ⸱ ➖ ~~_function_call[0].base_currency_ 'USD'~~ ⸱ _function_call[0].['name']_ ~~'get_currency_conversion_rate'~~ 'convert_currency'
87 | 84 | ✅ | ✅
88 | 85 | _function_call[0].location_ ~~'main office'~~ 'our main office' | ✅
89 | 86 | ✅ | ✅
90 | 87 | ✅ | ✅
91 | 88 | ✅ | ✅
92 | 89 | ✅ | _tool_call[1]_ ➕ {'type': 'function', 'function': {'name': 'get_news', 'arguments': {'interests': ['sports'], 'location': 'New York'}}} ⸱ _function_call[0].interests[1]_ ➖ ~~'sports'~~
93 | 90 | ✅ | ✅
94 | 91 | ✅ | ✅
95 | 92 | ✅ | ✅
96 | 93 | ✅ | ✅
97 | 94 | ✅ | ✅
98 | 95 | ✅ | ✅
99 | 96 | ✅ | ✅
100 | 97 | ✅ | ✅
101 | 98 | ✅ | ✅
102 | 99 | ✅ | ✅
103 | pass | 84 (84.0%) | 82 (82.0%)
104 |
--------------------------------------------------------------------------------
/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/report-single_turn.md:
--------------------------------------------------------------------------------
1 | case | mlx-community/Meta-Llama-3-8B-Instruct-4bit | gpt-4o-2024-05-13
2 | --- | --- | ---
3 | 0 | ✅ | ✅
4 | 1 | ✅ | ✅
5 | 2 | ✅ | ✅
6 | 3 | ✅ | ✅
7 | 4 | ➕ _function_call[0].limit_ 100 ⸱ ➖ ~~_function_call[0].relationship_ 'siblings'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | _function_call[0].relationship_ ~~'siblings'~~ 'sibling_domains'
8 | 5 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'referrer_files'~~ 'includes' | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'referrer_files'~~ 'files'
9 | 6 | _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅
10 | 7 | ✅ | ✅
11 | 8 | ✅ | ➕ _function_call[0].ip_ '192.0.2.1' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '192.0.2.1'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address'
12 | 9 | _function_call[0].relationship_ ~~'referrer_files'~~ 'has_file' | _function_call[0].relationship_ ~~'referrer_files'~~ 'files'
13 | 10 | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain'
14 | 11 | ✅ | ➕ _function_call[0].ip_ '203.0.113.0' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '203.0.113.0'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address'
15 | 12 | ✅ | ✅
16 | 13 | ✅ | ✅
17 | 14 | ✅ | _function_call[0].ip_ ~~'http://www.example.org'~~ '93.184.216.34'
18 | 15 | ✅ | ✅
19 | 16 | _function_call[0].ip_ ~~'12.234.56.126'~~ '22.242.75.136' | _function_call[0].ip_ ~~'12.234.56.126'~~ '22.242.75.136'
20 | 17 | _function_call[0].relationship_ ~~'urls'~~ 'related_to' | ✅
21 | 18 | ✅ | ✅
22 | 19 | _function_call[0].ip_ ~~'explorerweb.org'~~ 'http://explorerweb.org' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'explorerweb.org'~~ 'http://explorerweb.org'
23 | 20 | ✅ | ✅
24 | 21 | ✅ | ✅
25 | 22 | _function_call[0].relationship_ ~~'referrer_files'~~ 'contains' | _function_call[0].relationship_ ~~'referrer_files'~~ 'communicating_files'
26 | 23 | ➖ ~~_function_call[0].relationship_ 'downloaded_files'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ✅
27 | 24 | _function_call[0].relationship_ ~~'siblings'~~ 'sibling' | _function_call[0].relationship_ ~~'siblings'~~ 'sibling_domains'
28 | 25 | ➕ _function_call[0].limit_ 100 ⸱ ➕ _function_call[0].cursor_ '' ⸱ _function_call[0].x-apikey_ ~~'delta_key'~~ 'your_delta_key' | ✅
29 | 26 | ✅ | ✅
30 | 27 | ✅ | ➕ _function_call[0].ip_ '44.55.66.77' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '44.55.66.77'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address'
31 | 28 | ➖ ~~_function_call[0].relationship_ 'graphs'~~ ⸱ ➖ ~~_function_call[0].x-apikey_ 'sec_key2'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_votes_on_ip_address' | ✅
32 | 29 | ✅ | ✅
33 | 30 | ✅ | ✅
34 | 31 | ✅ | ✅
35 | 32 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅
36 | 33 | _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅
37 | 34 | ✅ | ➕ _function_call[0].ip_ '10.0.0.1' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '10.0.0.1'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address'
38 | 35 | ✅ | ✅
39 | 36 | ✅ | ✅
40 | 37 | ✅ | ✅
41 | 38 | ✅ | ✅
42 | 39 | ✅ | ✅
43 | 40 | ➕ _function_call[0].limit_ 100 | ✅
44 | 41 | ✅ | ➕ _function_call[0].domain_ 'mysite.io' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'mysite.io'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report'
45 | 42 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_objects_related_to_domain' ⸱ _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅
46 | 43 | ✅ | ✅
47 | 44 | ✅ | ✅
48 | 45 | ✅ | ✅
49 | 46 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ ➖ ~~_function_call[0].x-apikey_ 'gamma_key'~~ ⸱ ➖ ~~_function_call[0].cursor_ 'next_page'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_votes_on_ip_address' | ✅
50 | 47 | ✅ | ➕ _function_call[0].domain_ 'samplepage.net' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'samplepage.net'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report'
51 | 48 | ✅ | ✅
52 | 49 | ➖ ~~_function_call[0].cursor_ 'start_cursor'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'downloaded_files'~~ 'downloaded_from' | ➖ ~~_function_call[0].cursor_ 'start_cursor'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address'
53 | 50 | ➖ ~~_function_call[0].relationship_ 'parent'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ➕ _function_call[0].limit_ 1
54 | 51 | ✅ | ✅
55 | 52 | _function_call[0].relationship_ ~~'caa_records'~~ 'CAA' | ✅
56 | 53 | ✅ | ✅
57 | 54 | ✅ | ✅
58 | 55 | ✅ | ✅
59 | 56 | ✅ | ✅
60 | 57 | ✅ | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report'
61 | 58 | ➕ _function_call[0].limit_ 0 ⸱ ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅
62 | 59 | ✅ | _function_call[0].ip_ ~~'https://www.example.org'~~ '93.184.216.34'
63 | 60 | ➖ ~~_function_call[0].relationship_ 'siblings'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | _function_call[0].relationship_ ~~'siblings'~~ 'sibling_domains'
64 | 61 | _function_call[0].ip_ ~~'viewpage.net'~~ 'http://viewpage.net' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'viewpage.net'~~ 'http://viewpage.net'
65 | 62 | ✅ | ✅
66 | 63 | ✅ | ✅
67 | 64 | ➖ ~~_function_call[0].relationship_ 'caa_records'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ✅
68 | 65 | ✅ | _tool_call[1]_ ➕ {'type': 'function', 'function': {'name': 'vt_get_comments_on_domain', 'arguments': {'domain': 'reddit.com', 'x-apikey': 'reddit_api_key'}}}
69 | 66 | ✅ | ✅
70 | 67 | ✅ | ✅
71 | 68 | _function_call[0].relationship_ ~~'historical_whois'~~ 'whois' ⸱ _function_call[0].x-apikey_ ~~'elite_api'~~ 'your_api_key' | _function_call[0].relationship_ ~~'historical_whois'~~ 'whois'
72 | 69 | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain'
73 | 70 | ✅ | ✅
74 | 71 | ➕ _function_call[0].limit_ 100 | ✅
75 | 72 | ✅ | ✅
76 | 73 | ✅ | ✅
77 | 74 | ✅ | ✅
78 | 75 | ✅ | ✅
79 | 76 | _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_object_descriptors_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'related_threat_actors'~~ 'threat_actor' | _function_call[0].relationship_ ~~'related_threat_actors'~~ 'threat_actors'
80 | 77 | ➖ ~~_function_call[0].relationship_ 'subdomains'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅
81 | 78 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅
82 | 79 | ➕ _function_call[0].id_ 'dns_resolution_object_id' ⸱ ➖ ~~_function_call[0].domain_ 'site5.info'~~ ⸱ ➖ ~~_function_call[0].relationship_ 'resolutions'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_dns_resolution_object' | ✅
83 | 80 | ✅ | ✅
84 | 81 | ✅ | ✅
85 | 82 | ✅ | ✅
86 | 83 | ✅ | ✅
87 | 84 | _function_call[0].relationship_ ~~'historical_whois'~~ 'whois' | ✅
88 | 85 | ➕ _function_call[0].id_ 'yahoo.com' ⸱ ➖ ~~_function_call[0].domain_ 'yahoo.com'~~ ⸱ ➖ ~~_function_call[0].relationship_ 'resolutions'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_dns_resolution_object' | ✅
89 | 86 | _function_call[0].relationship_ ~~'referrer_files'~~ 'contains' | _function_call[0].relationship_ ~~'referrer_files'~~ 'files'
90 | 87 | _function_call[0].ip_ ~~'digdeep.io'~~ 'http://digdeep.io' | ➕ _function_call[0].domain_ 'digdeep.io' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'digdeep.io'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report'
91 | 88 | ✅ | ✅
92 | 89 | _function_call[0].ip_ ~~'surfthis.net'~~ 'http://surfthis.net' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'surfthis.net'~~ 'http://surfthis.net'
93 | 90 | _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅
94 | 91 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅
95 | 92 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl_certificate' | _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl_certificates'
96 | 93 | _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_object_descriptors_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl-certificate' | _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl_certificates'
97 | 94 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'referrer_files'~~ 'REFERENCES' | ✅
98 | 95 | _function_call[0].id_ ~~'10.10.10.10linked.site'~~ '10.10.10.10_linked.site' | ✅
99 | 96 | _function_call[0].ip_ ~~'checkthisout.net'~~ 'http://checkthisout.net' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'checkthisout.net'~~ 'http://checkthisout.net'
100 | 97 | _function_call[0].domain_ ~~'sample.org'~~ 'sample.com' ⸱ _function_call[0].relationship_ ~~'cname_records'~~ 'dns_resolution' | _function_call[0].domain_ ~~'sample.org'~~ 'sample.com'
101 | 98 | _function_call[0].relationship_ ~~'communicating_files'~~ 'file' | ✅
102 | 99 | _function_call[0].x-apikey_ ~~'eta_key'~~ 'your_api_key' | ✅
103 | 100 | ➖ ~~_function_call[0].relationship_ 'historical_whois'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_ip_address_report' | ➖ ~~_function_call[0].relationship_ 'historical_whois'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_ip_address_report'
104 | 101 | ✅ | ✅
105 | 102 | ➖ ~~_function_call[0].x-apikey_ 'KEY123'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_ip_address_report'~~ 'vt_get_votes_on_ip_address' | ✅
106 | 103 | ✅ | ✅
107 | 104 | ✅ | ✅
108 | 105 | ✅ | ✅
109 | 106 | _function_call[0].ip_ ~~'inspectlink.com'~~ 'http://inspectlink.com' | ➕ _function_call[0].domain_ 'inspectlink.com' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'inspectlink.com'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report'
110 | 107 | ✅ | ✅
111 | 108 | ➖ ~~_function_call[0].relationship_ 'historical_whois'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ✅
112 | 109 | ✅ | ✅
113 | 110 | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain'
114 | 111 | ➕ _function_call[0].limit_ 100 ⸱ ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅
115 | pass | 59 (52.68%) | 77 (68.75%)
116 |
--------------------------------------------------------------------------------
/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/requirements.txt:
--------------------------------------------------------------------------------
1 | pyarrow
2 |
--------------------------------------------------------------------------------
/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/single_turn-00000-of-00001.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/otriscon/llm-structured-output/037e8eb7447005fda06e7d811b041efcb94b0cef/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/single_turn-00000-of-00001.parquet
--------------------------------------------------------------------------------
/src/tests/eval_api.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | """
3 | Run a tool use evaluation using an LLM with an OpenAI-like API.
4 | """
5 | import argparse
6 | import json
7 | import time
8 | import requests
9 |
10 | from llm_structured_output.util.output import info, inverse, debug, warning
11 |
12 | from .eval_report import eval_completion
13 |
14 |
15 | def run_eval_case(
16 | api_url,
17 | api_key,
18 | model_name,
19 | case,
20 | header,
21 | temp=0,
22 | seed=0,
23 | stream=False,
24 | out=None,
25 | ):
26 | options = case.get("options", {})
27 | prompt_includes_schema = options.get("prompt_includes_schema", False)
28 |
29 | payload = {
30 | "model": model_name,
31 | "messages": case["prompt"],
32 | "tools": case["tools"],
33 | "tool_choice": "auto",
34 | "temperature": temp,
35 | "seed": seed,
36 | }
37 | if stream:
38 | payload["stream"] = True
39 | payload["stream_options"] = {"include_usage": True}
40 | if prompt_includes_schema and "api.openai.com" not in api_url:
41 | # Non-standard option, should not be set for OpenAI API.
42 | payload["tool_options"] = {
43 | # Do not dump the schema again, since it's already in the prompt
44 | "no_prompt_steering": True,
45 | }
46 |
47 | info(f"{header} Sending API request...")
48 | start_time = time.time_ns()
49 |
50 | r = requests.post(
51 | f"{api_url}/v1/chat/completions",
52 | json=payload,
53 | headers={"Authorization": f"Bearer {api_key}"},
54 | timeout=60,
55 | stream=stream,
56 | )
57 | if stream:
58 | response = None
59 | tool_calls = []
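        # Accumulate streamed tool calls: the first delta for a given index introduces the
        # call (carrying the function name), and subsequent deltas append fragments to that
        # call's "arguments" string.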
60 | for line in r.iter_lines(decode_unicode=True):
61 | if not line:
62 | continue
63 | if not line.startswith("data:"):
64 | warning("Expected all server-sent events to start with 'data:'")
65 | line = line[5:].strip()
66 | if line == "[DONE]":
67 | break
68 | message = json.loads(line)
69 | if response is None:
70 | response = message
71 | elif "usage" in message:
72 | response["usage"] = message["usage"]
73 | if not message["choices"]:
74 | continue
75 | tool_deltas = message["choices"][0]["delta"].get("tool_calls", [])
76 | if len(tool_deltas) > 1:
77 | warning(
78 | f"Expected updates for one tool_call at a time, got multiple: {tool_deltas=}"
79 | )
80 | if tool_deltas:
81 | tool_delta = tool_deltas[0]
82 | index = tool_delta["index"]
83 | argument_delta = tool_delta["function"]["arguments"]
84 | if index == len(tool_calls):
85 | tool_calls.append(tool_delta)
86 | tool_name = tool_delta["function"][
87 | "name"
88 | ] # name may not be present in additional updates
89 | debug(
90 | f"[call #{index}]\nname: {tool_name}\narguments: {argument_delta}",
91 | end="",
92 | )
93 | elif index == len(tool_calls) - 1:
94 | tool_calls[index]["function"]["arguments"] += argument_delta
95 | debug(argument_delta, end="")
96 | else:
97 | warning(
98 | f"Unexpected tool_delta out of sequence: "
99 | f"current_index={len(tool_calls)-1} {tool_delta=}"
100 | )
101 | response["choices"] = [
102 | {"message": {"role": "assistant", "tool_calls": tool_calls}}
103 | ]
104 | debug()
105 | else:
106 | response = r.json()
107 | debug(response)
108 |
109 | total_time = (time.time_ns() - start_time) / 1e6
110 | prompt_tokens = response["usage"]["prompt_tokens"]
111 | completion_tokens = response["usage"]["completion_tokens"]
112 | info(f"{header} {prompt_tokens=} {completion_tokens=} {total_time=:.02f}")
113 |
114 | if out:
115 | json.dump(response, out)
116 | out.write("\n")
117 |
118 | diff = eval_completion(case, response)
119 | if diff:
120 | inverse(f"{header} DIFF:", diff)
121 | return False
122 | else:
123 | info(f"{header} PASS")
124 | return True
125 |
126 |
127 | def main():
128 | parser = argparse.ArgumentParser(
129 | description="Run a function calling evaluation with the Fireworks AI dataset or similar"
130 | )
131 | parser.add_argument(
132 | "--api-url",
133 | type=str,
134 | default="https://api.openai.com",
135 | help="The URL of the API server",
136 | )
137 | parser.add_argument(
138 | "--api-key",
139 | type=str,
140 | default=None,
141 | help="The URL of the API server",
142 | )
143 | parser.add_argument(
144 | "--model-name",
145 | type=str,
146 | default="gpt-4o",
147 | help="The name of the model to use",
148 | )
149 | parser.add_argument(
150 | "--dataset-path",
151 | required=True,
152 | type=str,
153 | help="The path to the evaluation dataset (JSONL)",
154 | )
155 | parser.add_argument(
156 | "--skip",
157 | type=int,
158 | default=0,
159 | help="Start at the given evaluation case number",
160 | )
161 | parser.add_argument(
162 | "--count",
163 | type=int,
164 | default=None,
165 | help="Limit the number of cases to run",
166 | )
167 | parser.add_argument(
168 | "--temp",
169 | help="The sampling temperature.",
170 | type=float,
171 | default=0.0,
172 | )
173 | parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
174 | parser.add_argument(
175 | "--stream",
176 | help="Use streaming API.",
177 | action=argparse.BooleanOptionalAction,
178 | default=False,
179 | )
180 | parser.add_argument(
181 | "--output-file",
182 | help="Write completions to JSONL file.",
183 | type=str,
184 | default=None,
185 | )
186 | args = parser.parse_args()
187 |
188 | out = None
189 | if args.output_file:
190 | out = open(args.output_file, mode="w", encoding="utf-8")
191 |
192 | with open(args.dataset_path, encoding="utf-8") as dataset:
193 | if args.count:
194 | end_index = args.skip + args.count
195 | else:
196 | end_index = None
197 | pass_count = 0
198 | fail_count = 0
199 | t0 = time.time_ns()
200 | for i, line in enumerate(dataset.readlines()):
201 | if i < args.skip:
202 | continue
203 | if end_index is not None and i == end_index:
204 | break
205 | case = json.loads(line)
206 | if run_eval_case(
207 | args.api_url,
208 | args.api_key,
209 | args.model_name,
210 | case,
211 | f"[{i}]",
212 | temp=args.temp,
213 | seed=args.seed,
214 | stream=args.stream,
215 | out=out,
216 | ):
217 | pass_count += 1
218 | else:
219 | fail_count += 1
220 | average_time = (time.time_ns() - t0) / 1e9 / (pass_count + fail_count)
221 | info(f"Totals: {pass_count=} {fail_count=} {average_time=:.02}s")
222 |
223 | if out:
224 | out.close()
225 |
226 |
227 | if __name__ == "__main__":
228 |     main()
--------------------------------------------------------------------------------
/src/tests/eval_local.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | """
3 | Run a tool use evaluation using a local LLM.
4 | """
5 | import argparse
6 | import json
7 | import time
8 |
9 | from examples.llm_schema import Model
10 | from llm_structured_output.util.output import info, bold, inverse, debug
11 |
12 | from .eval_report import eval_tool_calls
13 |
14 |
15 | def run_eval_case(model, case, header, temp=None, seed=None, preemptive_batch_size=0):
16 | messages = case["prompt"]
17 | tools = case["tools"]
18 | options = case.get("options", {})
19 | prompt_includes_schema = options.get("prompt_includes_schema", False)
20 | single_tool = options.get("single_tool", False)
21 |
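    # Build one JSON schema per tool: the tool's name is pinned with a "const" schema, and
    # its OpenAI-style "parameters" object is reused as the schema for the "arguments" field.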
22 | tool_schemas = [
23 | {
24 | "type": "object",
25 | "properties": {
26 | "name": {
27 | "type": "const",
28 | "const": tool["function"]["name"],
29 | },
30 | "arguments": tool["function"]["parameters"],
31 | },
32 | "required": ["name", "arguments"],
33 | }
34 | for tool in tools
35 | ]
36 |
37 | separator = "\n"
38 | if single_tool:
39 | schema = {"anyOf": tool_schemas}
40 | if not prompt_includes_schema:
41 | schema_message = f"""
42 | You are a helpful assistant with access to tools that you must invoke to answer the user's request.
43 | The following tools are available:
44 | {separator.join([ f'''
45 | Tool {repr(tool[tool["type"]]["name"])}: {tool[tool["type"]]["description"]}
46 | Invocation schema: {json.dumps(tool_schema)}
47 | ''' for tool, tool_schema in zip(tools, tool_schemas) ])}
48 | Your answer is a JSON object according to the invocation schema of the most appropriate tool to use
49 | to answer the user request below.
50 | """
51 | print(json.dumps(schema, indent=2)) ###
52 | print(schema_message) ###
53 |             messages.insert(0, {"role": "system", "content": schema_message})
54 | else:
55 | tool_call_schemas = [
56 | {
57 | "type": "object",
58 | "properties": {
59 | "type": {
60 | "type": "const",
61 | "const": tool["type"],
62 | },
63 | tool["type"]: tool_schema,
64 | },
65 | "required": ["type", tool["type"]],
66 | }
67 | for tool, tool_schema in zip(tools, tool_schemas)
68 | ]
69 | schema = {
70 | "type": "array",
71 | "items": {"anyOf": tool_call_schemas},
72 | }
73 | if not prompt_includes_schema:
74 | schema_message = f"""
75 | You are a helpful assistant with access to tools that you must invoke to answer the user's request.
76 | The following tools are available:
77 | {separator.join([ f'''
78 | Tool {repr(tool[tool["type"]]["name"])}: {tool[tool["type"]]["description"]}
79 | Invocation schema: {json.dumps(tool_call_schema)}
80 | ''' for tool, tool_call_schema in zip(tools, tool_call_schemas) ])}
81 | Your answer is a JSON array with one or more tool invocations according to the appropriate schema(s)
82 | in order to answer the user request below.
83 | """
84 | print(json.dumps(schema, indent=2)) ###
85 | print(schema_message) ###
86 |             messages.insert(0, {"role": "system", "content": schema_message})
87 |
88 | info(f"{header} Starting generation...")
89 | content = ""
90 | prompt_tokens = 0
91 | completion_tokens = 0
92 | completion_time = 0
93 | start_time = time.time_ns()
94 |
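    # model.completion yields progress events: "evaluatedPrompt" once after the prompt is
    # processed, "generatedTokens" for each decoded chunk, and "stop" when generation ends.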
95 | for result in model.completion(
96 | messages,
97 | schema=schema,
98 | max_tokens=4000,
99 | temp=temp,
100 | seed=seed,
101 | preemptive_batch_size=preemptive_batch_size,
102 | cache_prompt=True,
103 | ):
104 | if result["op"] == "evaluatedPrompt":
105 | prompt_tokens += result["token_count"]
106 | prompt_time = result["time_ms"]
107 | elif result["op"] == "generatedTokens":
108 | completion_tokens += result["token_count"]
109 | completion_time += result["time_ms"]
110 | content += result["text"]
111 | bold(result["text"], end="", flush=True)
112 | elif result["op"] == "stop":
113 | print()
114 | else:
115 | debug(f"{result=}")
116 | assert False
117 |
118 | total_time = (time.time_ns() - start_time) / 1e6
119 | prompt_tps = prompt_tokens / prompt_time * 1e3
120 | completion_tps = completion_tokens / completion_time * 1e3
121 | info(
122 | f"{header} {prompt_tokens=} {prompt_tps=:.02f} {completion_tokens=} {completion_tps=:.02f}"
123 | f" {prompt_time=:.02f} {completion_time=:.02f} {total_time=:.02f}"
124 | )
125 |
126 | tool_calls = json.loads(content)
127 | if single_tool:
128 | tool_calls = [{"type": "function", "function": tool_calls}]
129 |
130 | diff = eval_tool_calls(case, tool_calls)
131 | if diff:
132 | inverse(f"{header} DIFF:", diff)
133 | return False
134 | else:
135 | info(f"{header} PASS")
136 | return True
137 |
138 |
139 | def main():
140 | parser = argparse.ArgumentParser(
141 | description="Run a function calling evaluation with the Fireworks AI dataset or similar"
142 | )
143 | parser.add_argument(
144 | "--model-path",
145 | type=str,
146 | default="mlx_model",
147 | help="The path to the model weights and tokenizer",
148 | )
149 | parser.add_argument(
150 | "--dataset-path",
151 | required=True,
152 | type=str,
153 | help="The path to the evaluation dataset (JSONL)",
154 | )
155 | parser.add_argument(
156 | "--skip",
157 | type=int,
158 | default=0,
159 | help="Start at the given evaluation case number",
160 | )
161 | parser.add_argument(
162 | "--count",
163 | type=int,
164 | default=None,
165 | help="Limit the number of cases to run",
166 | )
167 | parser.add_argument(
168 | "--temp",
169 | help="The sampling temperature.",
170 | type=float,
171 | default=0.0,
172 | )
173 | parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
174 | parser.add_argument(
175 | "--preemptive",
176 | type=int,
177 | default=0,
178 | help="If greater than zero, the maximum size of the batch for pre-emptive decoding",
179 | )
180 | args = parser.parse_args()
181 |
182 | info("Loading model...")
183 | model = Model()
184 | model.load(args.model_path)
185 |
186 | with open(args.dataset_path, encoding="utf-8") as dataset:
187 | if args.count:
188 | end_index = args.skip + args.count
189 | else:
190 | end_index = None
191 | pass_count = 0
192 | fail_count = 0
193 | t0 = time.time_ns()
194 | for i, line in enumerate(dataset.readlines()):
195 | if i < args.skip:
196 | continue
197 | if end_index is not None and i == end_index:
198 | break
199 | case = json.loads(line)
200 | if run_eval_case(
201 | model,
202 | case,
203 | f"[{i}]",
204 | temp=args.temp,
205 | seed=args.seed,
206 | preemptive_batch_size=args.preemptive,
207 | ):
208 | pass_count += 1
209 | else:
210 | fail_count += 1
211 | average_time = (time.time_ns() - t0) / 1e9 / (pass_count + fail_count)
212 | info(f"Totals: {pass_count=} {fail_count=} {average_time=:.02}s")
213 |
214 |
215 | if __name__ == "__main__":
216 |     main()
--------------------------------------------------------------------------------
/src/tests/eval_report.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | """
3 | Create a markdown report from an evaluation dataset and one or more completions.
4 | """
5 | import argparse
6 | import json
7 | import re
8 | import sys
9 |
10 | from deepdiff import DeepDiff
11 |
12 |
13 | def eval_tool_calls(case, tool_calls):
14 | single_tool = case.get("options", {}).get("single_tool", False)
15 |
16 | best_diff_count = 1e10
17 | for gold_tool_calls in case["gold"]:
18 | if single_tool:
19 | # The gold set in the source dataset is a single tool invocation instead of an array.
20 | # We could use the legacy function_call method to force a single function call, but
21 | # we think it's better to evaluate the model for non-legacy tool use. If the model
22 | # comes up with multi-tool solutions that are deemed acceptable, we can then:
23 | # - Remove this flag for this evaluation case,
24 | # - Wrap each existing gold value for this case in an array,
25 | # - Add the new solution that has multiple invocations to the gold set for the case.
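            # e.g. a gold entry that is a single tool invocation is wrapped into a
            # one-element list so that both sides of the DeepDiff comparison are arrays
            # of tool calls.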
26 | gold_tool_calls = [gold_tool_calls]
27 | diff = DeepDiff(gold_tool_calls, tool_calls, verbose_level=2)
28 |         if not diff:  # DeepDiff returns an empty (falsy) diff when there are no differences
29 | best_diff = None
30 | best_diff_count = 0
31 | break
32 | else:
33 | diff_count = diff.get_stats()["DIFF COUNT"]
34 | if diff_count < best_diff_count:
35 | best_diff_count = diff_count
36 | best_diff = diff
37 | return best_diff
38 |
39 |
40 | def eval_completion(case, completion):
41 | try:
42 | completion_tool_calls = completion["choices"][0]["message"]["tool_calls"]
43 | except (KeyError, TypeError) as e:
44 | sys.stderr.write(
45 | f"Completion object doesn't match expected format: {completion=}\n"
46 | )
47 | completion_tool_calls = [
48 | {
49 | "type": "error",
50 | "error": {
51 | "error": f"Parsing tool_calls: {repr(e)}",
52 | "completion_message": completion["choices"][0]["message"],
53 | },
54 | }
55 | ]
56 |
57 | # Remove call metadata (currently only id) to compare with gold.
58 |     # Note that we expect the gold set in the evaluation dataset to contain
59 |     # deserialized function arguments rather than a JSON-encoded string.
60 | tool_calls = [
61 | (
62 | {
63 | "type": "function",
64 | "function": {
65 | "name": tool_call["function"]["name"],
66 | "arguments": json.loads(tool_call["function"]["arguments"]),
67 | },
68 | }
69 | if tool_call["type"] == "function"
70 | else {
71 | "type": tool_call["type"],
72 | tool_call["type"]: tool_call[tool_call["type"]],
73 | }
74 | )
75 | for tool_call in completion_tool_calls
76 | ]
77 |
78 | return eval_tool_calls(case, tool_calls)
79 |
80 |
81 | CHANGE_FORMATTERS = {
82 | "type_changes": lambda path, change: f"_{path}_ ~~{repr(change['old_value'])} [{change['old_type'].__name__}]~~ {repr(change['new_value'])} [{change['new_type'].__name__}]",
83 | "values_changed": lambda path, change: f"_{path}_ ~~{repr(change['old_value'])}~~ {repr(change['new_value'])}",
84 | "dictionary_item_added": lambda path, change: f"➕ _{path}_ {repr(change)}",
85 | "dictionary_item_removed": lambda path, change: f"➖ ~~_{path}_ {repr(change)}~~",
86 | "iterable_item_added": lambda path, change: f"_{path}_ ➕ {repr(change)}",
87 | "iterable_item_removed": lambda path, change: f"_{path}_ ➖ ~~{repr(change)}~~",
88 | "set_item_added": lambda path, change: f"_{path}_ ➕ {repr(change)}",
89 | "set_item_removed": lambda path, change: f"_{path}_ ➖ ~~{repr(change)}~~",
90 | }
91 |
92 |
93 | def diff_to_md(diff):
94 | if not diff:
95 | return "✅"
96 | md_changes = []
97 | for change_type, changes in diff.items():
98 | formatter = CHANGE_FORMATTERS[change_type]
99 | for path, change in changes.items():
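            # Rewrite DeepDiff paths into the compact notation used in the report, e.g.
            # "root[0]['function']['arguments']['ip']" becomes "function_call[0].ip" and
            # "root[1]" becomes "tool_call[1]".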
100 | path = re.sub(r"root\[(\d*)\]\['function'\]", "function_call[\\1].", path)
101 | path = re.sub(r"root\[(\d*)\]", "tool_call[\\1]", path)
102 | path = re.sub(r"\['arguments'\]\['([^']*)']", "\\1", path)
103 | md_changes.append(formatter(path, change))
104 | return " ⸱ ".join(md_changes)
105 |
106 |
107 | def report_eval_case(
108 | case,
109 | completions,
110 | index,
111 | out,
112 | ):
113 | eval_diffs = [eval_completion(case, completion) for completion in completions]
114 | columns = [diff_to_md(diff) for diff in eval_diffs]
115 | out.write(f"{index} | {' | '.join(columns)}\n")
116 | results = [not diff for diff in eval_diffs]
117 | return results
118 |
119 |
120 | def main():
121 | parser = argparse.ArgumentParser(
122 | description="Run a function calling evaluation with the Fireworks AI dataset or similar"
123 | )
124 | parser.add_argument(
125 | "--dataset-path",
126 | required=True,
127 | type=str,
128 | help="The path to the evaluation dataset (JSONL)",
129 | )
130 | parser.add_argument(
131 | "completions",
132 | metavar="completion_files",
133 | type=str,
134 | nargs="+",
135 | help="One or more jsonl files with completions for the evaluation dataset",
136 | )
137 | parser.add_argument(
138 | "--output-file",
139 | help="Write report to a file instead of stdout",
140 | type=str,
141 | default=None,
142 | )
143 | args = parser.parse_args()
144 |
145 | input_files = [open(filename, encoding="utf-8") for filename in args.completions]
146 |
147 | out = sys.stdout
148 | if args.output_file:
149 | out = open(args.output_file, mode="w", encoding="utf-8")
150 |
151 | i = 0
152 | with open(args.dataset_path, encoding="utf-8") as dataset:
153 | for i, line in enumerate(dataset.readlines()):
154 | case = json.loads(line)
155 | completions = [
156 | json.loads(input_file.readline()) for input_file in input_files
157 | ]
158 | if i == 0:
159 | sum_results = [0 for completion in completions]
160 | models = [completion["model"] for completion in completions]
161 | out.write(f"case | {' | '.join(models)}\n")
162 | out.write(f"--- | {' | '.join(['---'] * len(models))}\n")
163 | results = report_eval_case(case, completions, i, out)
164 | sum_results = [sum_results[i] + result for i, result in enumerate(results)]
165 | total = i + 1
166 | out.write(
167 | f"pass | {' | '.join([f'{r} ({round(100*r/total, 2)}%)' for r in sum_results])}\n"
168 | )
169 |
170 | for input_file in input_files:
171 | input_file.close()
172 | if out:
173 | out.close()
174 |
175 |
176 | if __name__ == "__main__":
177 | main()
178 |
--------------------------------------------------------------------------------
/src/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | mlx-lm >= 0.19.2
2 | tokenizers >= 0.20.1
3 | sentencepiece
4 | deepdiff
5 | requests
6 |
--------------------------------------------------------------------------------