├── .artefacts └── cerebras-api-notes.txt ├── .gitignore ├── LICENSE ├── README.md ├── llm_cerebras ├── __init__.py └── cerebras.py ├── pyproject.toml ├── setup.py ├── test_results ├── basic_schema.json ├── complex_schema.json ├── multi_schema.json ├── person_schema.json ├── schema_by_id.json ├── schema_tests_summary.md ├── schema_with_descriptions.json ├── template_schema.json ├── test_cerebras_schema.py └── test_cerebras_version.py └── tests ├── test_automated_user.py ├── test_cerebras.py ├── test_integration.py └── test_schema_support.py /.artefacts/cerebras-api-notes.txt: -------------------------------------------------------------------------------- 1 | Notes for cerebras-api 2 | 3 | ... 4 | # Writing a plugin to support a new model 5 | 6 | This tutorial will walk you through developing a new plugin for LLM that adds support for a new Large Language Model. 7 | 8 | We will be developing a plugin that implements a simple [Markov chain](https://en.wikipedia.org/wiki/Markov_chain) to generate words based on an input string. Markov chains are not technically large language models, but they provide a useful exercise for demonstrating how the LLM tool can be extended through plugins. 9 | 10 | ## The initial structure of the plugin 11 | 12 | First create a new directory with the name of your plugin - it should be called something like `llm-markov`. 13 | ```bash 14 | mkdir llm-markov 15 | cd llm-markov 16 | ``` 17 | In that directory create a file called `llm_markov.py` containing this: 18 | 19 | ```python 20 | import llm 21 | 22 | @llm.hookimpl 23 | def register_models(register): 24 | register(Markov()) 25 | 26 | class Markov(llm.Model): 27 | model_id = "markov" 28 | 29 | def execute(self, prompt, stream, response, conversation): 30 | return ["hello world"] 31 | ``` 32 | 33 | The `def register_models()` function here is called by the plugin system (thanks to the `@hookimpl` decorator). It uses the `register()` function passed to it to register an instance of the new model. 34 | 35 | The `Markov` class implements the model. It sets a `model_id` - an identifier that can be passed to `llm -m` in order to identify the model to be executed. 36 | 37 | The logic for executing the model goes in the `execute()` method. We'll extend this to do something more useful in a later step. 38 | 39 | Next, create a `pyproject.toml` file. This is necessary to tell LLM how to load your plugin: 40 | 41 | ```toml 42 | [project] 43 | name = "llm-markov" 44 | version = "0.1" 45 | 46 | [project.entry-points.llm] 47 | markov = "llm_markov" 48 | ``` 49 | 50 | This is the simplest possible configuration. It defines a plugin name and provides an [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html) for `llm` telling it how to load the plugin. 51 | 52 | If you are comfortable with Python virtual environments you can create one now for your project, activate it and run `pip install llm` before the next step. 53 | 54 | If you aren't familiar with virtual environments, don't worry: you can develop plugins without them. You'll need to have LLM installed using Homebrew or `pipx` or one of the [other installation options](https://llm.datasette.io/en/latest/setup.html#installation). 55 | 56 | ## Installing your plugin to try it out 57 | 58 | Having created a directory with a `pyproject.toml` file and an `llm_markov.py` file, you can install your plugin into LLM by running this from inside your `llm-markov` directory: 59 | 60 | ```bash 61 | llm install -e . 
62 | ``` 63 | 64 | The `-e` stands for "editable" - it means you'll be able to make further changes to the `llm_markov.py` file that will be reflected without you having to reinstall the plugin. 65 | 66 | The `.` means the current directory. You can also install editable plugins by passing a path to their directory this: 67 | ```bash 68 | llm install -e path/to/llm-markov 69 | ``` 70 | To confirm that your plugin has installed correctly, run this command: 71 | ```bash 72 | llm plugins 73 | ``` 74 | The output should look like this: 75 | ```json 76 | [ 77 | { 78 | "name": "llm-markov", 79 | "hooks": [ 80 | "register_models" 81 | ], 82 | "version": "0.1" 83 | }, 84 | { 85 | "name": "llm.default_plugins.openai_models", 86 | "hooks": [ 87 | "register_commands", 88 | "register_models" 89 | ] 90 | } 91 | ] 92 | ``` 93 | This command lists default plugins that are included with LLM as well as new plugins that have been installed. 94 | 95 | Now let's try the plugin by running a prompt through it: 96 | ```bash 97 | llm -m markov "the cat sat on the mat" 98 | ``` 99 | It outputs: 100 | ``` 101 | hello world 102 | ``` 103 | Next, we'll make it execute and return the results of a Markov chain. 104 | 105 | ## Building the Markov chain 106 | 107 | Markov chains can be thought of as the simplest possible example of a generative language model. They work by building an index of words that have been seen following other words. 108 | 109 | Here's what that index looks like for the phrase "the cat sat on the mat" 110 | ```json 111 | { 112 | "the": ["cat", "mat"], 113 | "cat": ["sat"], 114 | "sat": ["on"], 115 | "on": ["the"] 116 | } 117 | ``` 118 | Here's a Python function that builds that data structure from a text input: 119 | ```python 120 | def build_markov_table(text): 121 | words = text.split() 122 | transitions = {} 123 | # Loop through all but the last word 124 | for i in range(len(words) - 1): 125 | word = words[i] 126 | next_word = words[i + 1] 127 | transitions.setdefault(word, []).append(next_word) 128 | return transitions 129 | ``` 130 | We can try that out by pasting it into the interactive Python interpreter and running this: 131 | ```pycon 132 | >>> transitions = build_markov_table("the cat sat on the mat") 133 | >>> transitions 134 | {'the': ['cat', 'mat'], 'cat': ['sat'], 'sat': ['on'], 'on': ['the']} 135 | ``` 136 | ## Executing the Markov chain 137 | 138 | To execute the model, we start with a word. We look at the options for words that might come next and pick one of those at random. Then we repeat that process until we have produced the desired number of output words. 139 | 140 | Some words might not have any following words from our training sentence. For our implementation we will fall back on picking a random word from our collection. 
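Concretely, the last word of our training sentence never appears before another word, so it has no entry in the transitions table - when the chain lands on it, that fallback is what keeps generation going:

```pycon
>>> transitions = build_markov_table("the cat sat on the mat")
>>> "mat" in transitions
False
```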
141 | 142 | We will implement this as a [Python generator](https://realpython.com/introduction-to-python-generators/), using the yield keyword to produce each token: 143 | ```python 144 | def generate(transitions, length, start_word=None): 145 | all_words = list(transitions.keys()) 146 | next_word = start_word or random.choice(all_words) 147 | for i in range(length): 148 | yield next_word 149 | options = transitions.get(next_word) or all_words 150 | next_word = random.choice(options) 151 | ``` 152 | If you aren't familiar with generators, the above code could also be implemented like this - creating a Python list and returning it at the end of the function: 153 | ```python 154 | def generate_list(transitions, length, start_word=None): 155 | all_words = list(transitions.keys()) 156 | next_word = start_word or random.choice(all_words) 157 | output = [] 158 | for i in range(length): 159 | output.append(next_word) 160 | options = transitions.get(next_word) or all_words 161 | next_word = random.choice(options) 162 | return output 163 | ``` 164 | You can try out the `generate()` function like this: 165 | ```python 166 | lookup = build_markov_table("the cat sat on the mat") 167 | for word in generate(transitions, 20): 168 | print(word) 169 | ``` 170 | Or you can generate a full string sentence with it like this: 171 | ```python 172 | sentence = " ".join(generate(transitions, 20)) 173 | ``` 174 | ## Adding that to the plugin 175 | 176 | Our `execute()` method from earlier currently returns the list `["hello world"]`. 177 | 178 | Update that to use our new Markov chain generator instead. Here's the full text of the new `llm_markov.py` file: 179 | 180 | ```python 181 | import llm 182 | import random 183 | 184 | @llm.hookimpl 185 | def register_models(register): 186 | register(Markov()) 187 | 188 | def build_markov_table(text): 189 | words = text.split() 190 | transitions = {} 191 | # Loop through all but the last word 192 | for i in range(len(words) - 1): 193 | word = words[i] 194 | next_word = words[i + 1] 195 | transitions.setdefault(word, []).append(next_word) 196 | return transitions 197 | 198 | def generate(transitions, length, start_word=None): 199 | all_words = list(transitions.keys()) 200 | next_word = start_word or random.choice(all_words) 201 | for i in range(length): 202 | yield next_word 203 | options = transitions.get(next_word) or all_words 204 | next_word = random.choice(options) 205 | 206 | class Markov(llm.Model): 207 | model_id = "markov" 208 | 209 | def execute(self, prompt, stream, response, conversation): 210 | text = prompt.prompt 211 | transitions = build_markov_table(text) 212 | for word in generate(transitions, 20): 213 | yield word + ' ' 214 | ``` 215 | The `execute()` method can access the text prompt that the user provided using` prompt.prompt` - `prompt` is a `Prompt` object that might include other more advanced input details as well. 216 | 217 | Now when you run this you should see the output of the Markov chain! 218 | ```bash 219 | llm -m markov "the cat sat on the mat" 220 | ``` 221 | ``` 222 | the mat the cat sat on the cat sat on the mat cat sat on the mat cat sat on 223 | ``` 224 | 225 | ## Understanding execute() 226 | 227 | The full signature of the `execute()` method is: 228 | ```python 229 | def execute(self, prompt, stream, response, conversation): 230 | ``` 231 | The `prompt` argument is a `Prompt` object that contains the text that the user provided, the system prompt and the provided options. 
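For example, a plugin that forwards prompts to a chat-style API might turn those fields into a message list. A rough sketch (it assumes the system prompt is exposed as `prompt.system`; the llm-cerebras plugin later in these notes builds its message list from `prompt.prompt` in much the same way):

```python
def build_messages(prompt):
    """Turn an LLM Prompt object into chat-API style messages (sketch)."""
    messages = []
    if prompt.system:  # assumption: the system prompt, when given, is available as prompt.system
        messages.append({"role": "system", "content": prompt.system})
    messages.append({"role": "user", "content": prompt.prompt})
    return messages
```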
232 | 233 | `stream` is a boolean that says if the model is being run in streaming mode. 234 | 235 | `response` is the `Response` object that is being created by the model. This is provided so you can write additional information to `response.response_json`, which may be logged to the database. 236 | 237 | `conversation` is the `Conversation` that the prompt is a part of - or `None` if no conversation was provided. Some models may use `conversation.responses` to access previous prompts and responses in the conversation and use them to construct a call to the LLM that includes previous context. 238 | 239 | ## Prompts and responses are logged to the database 240 | 241 | The prompt and the response will be logged to a SQLite database automatically by LLM. You can see the single most recent addition to the logs using: 242 | ``` 243 | llm logs -n 1 244 | ``` 245 | The output should look something like this: 246 | ```json 247 | [ 248 | { 249 | "id": "01h52s4yez2bd1qk2deq49wk8h", 250 | "model": "markov", 251 | "prompt": "the cat sat on the mat", 252 | "system": null, 253 | "prompt_json": null, 254 | "options_json": {}, 255 | "response": "on the cat sat on the cat sat on the mat cat sat on the cat sat on the cat ", 256 | "response_json": null, 257 | "conversation_id": "01h52s4yey7zc5rjmczy3ft75g", 258 | "duration_ms": 0, 259 | "datetime_utc": "2023-07-11T15:29:34.685868", 260 | "conversation_name": "the cat sat on the mat", 261 | "conversation_model": "markov" 262 | } 263 | ] 264 | ``` 265 | Plugins can log additional information to the database by assigning a dictionary to the `response.response_json` property during the `execute()` method. 266 | 267 | Here's how to include that full `transitions` table in the `response_json` in the log: 268 | ```python 269 | def execute(self, prompt, stream, response, conversation): 270 | text = self.prompt.prompt 271 | transitions = build_markov_table(text) 272 | for word in generate(transitions, 20): 273 | yield word + ' ' 274 | response.response_json = {"transitions": transitions} 275 | ``` 276 | 277 | Now when you run the logs command you'll see that too: 278 | ```bash 279 | llm logs -n 1 280 | ``` 281 | ```json 282 | [ 283 | { 284 | "id": 623, 285 | "model": "markov", 286 | "prompt": "the cat sat on the mat", 287 | "system": null, 288 | "prompt_json": null, 289 | "options_json": {}, 290 | "response": "on the mat the cat sat on the cat sat on the mat sat on the cat sat on the ", 291 | "response_json": { 292 | "transitions": { 293 | "the": [ 294 | "cat", 295 | "mat" 296 | ], 297 | "cat": [ 298 | "sat" 299 | ], 300 | "sat": [ 301 | "on" 302 | ], 303 | "on": [ 304 | "the" 305 | ] 306 | } 307 | }, 308 | "reply_to_id": null, 309 | "chat_id": null, 310 | "duration_ms": 0, 311 | "datetime_utc": "2023-07-06T01:34:45.376637" 312 | } 313 | ] 314 | ``` 315 | In this particular case this isn't a great idea here though: the `transitions` table is duplicate information, since it can be reproduced from the input data - and it can get really large for longer prompts. 316 | 317 | ## Adding options 318 | 319 | LLM models can take options. For large language models these can be things like `temperature` or `top_k`. 
320 |
321 | Options are passed using the `-o/--option` command line parameters, for example:
322 | ```bash
323 | llm -m gpt4 "ten pet pelican names" -o temperature 1.5
324 | ```
325 | We're going to add two options to our Markov chain model:
326 |
327 | - `length`: Number of words to generate
328 | - `delay`: a floating point number of seconds to delay between each output token
329 |
330 | The `delay` option will let us simulate a streaming language model, where tokens take time to generate and are returned by the `execute()` function as they become ready.
331 |
332 | Options are defined using an inner class on the model, called `Options`. It should extend the `llm.Options` class.
333 |
334 | First, add this import to the top of your `llm_markov.py` file:
335 | ```python
336 | from typing import Optional
337 | ```
338 | Then add this `Options` class to your model:
339 | ```python
340 | class Markov(llm.Model):
341 |     model_id = "markov"
342 |
343 |     class Options(llm.Options):
344 |         length: Optional[int] = None
345 |         delay: Optional[float] = None
346 | ```
347 | Let's add extra validation rules to our options. Length must be at least 2. Delay must be between 0 and 10.
348 |
349 | The `Options` class uses [Pydantic 2](https://pydantic.org/), which can support all sorts of advanced validation rules.
350 |
351 | We can also add inline documentation, which can then be displayed by the `llm models --options` command.
352 |
353 | Add these imports to the top of `llm_markov.py`:
354 | ```python
355 | from pydantic import field_validator, Field
356 | ```
357 |
358 | We can now add Pydantic field validators for our two new rules, plus inline documentation:
359 |
360 | ```python
361 | class Options(llm.Options):
362 |     length: Optional[int] = Field(
363 |         description="Number of words to generate",
364 |         default=None
365 |     )
366 |     delay: Optional[float] = Field(
367 |         description="Seconds to delay between each token",
368 |         default=None
369 |     )
370 |
371 |     @field_validator("length")
372 |     def validate_length(cls, length):
373 |         if length is None:
374 |             return None
375 |         if length < 2:
376 |             raise ValueError("length must be >= 2")
377 |         return length
378 |
379 |     @field_validator("delay")
380 |     def validate_delay(cls, delay):
381 |         if delay is None:
382 |             return None
383 |         if not 0 <= delay <= 10:
384 |             raise ValueError("delay must be between 0 and 10")
385 |         return delay
386 | ```
387 | Let's test our options validation:
388 | ```bash
389 | llm -m markov "the cat sat on the mat" -o length -1
390 | ```
391 | ```
392 | Error: length
393 | Value error, length must be >= 2
394 | ```
395 |
396 | Next, we will modify our `execute()` method to handle those options. Add this to the beginning of `llm_markov.py`:
397 | ```python
398 | import time
399 | ```
400 | Then replace the `execute()` method with this one:
401 | ```python
402 | def execute(self, prompt, stream, response, conversation):
403 |     text = prompt.prompt
404 |     transitions = build_markov_table(text)
405 |     length = prompt.options.length or 20
406 |     for word in generate(transitions, length):
407 |         yield word + ' '
408 |         if prompt.options.delay:
409 |             time.sleep(prompt.options.delay)
410 | ```
411 | Add `can_stream = True` to the top of the `Markov` model class, on the line below `model_id = "markov"`. This tells LLM that the model is able to stream content to the console.
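Assembled, the whole file might end up looking roughly like this - a sketch pieced together from the snippets above (the canonical version is what the `literalinclude` directive just below points at):

```python
import llm
import random
import time
from typing import Optional
from pydantic import field_validator, Field


@llm.hookimpl
def register_models(register):
    register(Markov())


def build_markov_table(text):
    words = text.split()
    transitions = {}
    # Loop through all but the last word
    for i in range(len(words) - 1):
        word = words[i]
        next_word = words[i + 1]
        transitions.setdefault(word, []).append(next_word)
    return transitions


def generate(transitions, length, start_word=None):
    all_words = list(transitions.keys())
    next_word = start_word or random.choice(all_words)
    for i in range(length):
        yield next_word
        # Fall back on a random word when the current one has no successors
        options = transitions.get(next_word) or all_words
        next_word = random.choice(options)


class Markov(llm.Model):
    model_id = "markov"
    can_stream = True

    class Options(llm.Options):
        length: Optional[int] = Field(
            description="Number of words to generate", default=None
        )
        delay: Optional[float] = Field(
            description="Seconds to delay between each token", default=None
        )

        @field_validator("length")
        def validate_length(cls, length):
            if length is None:
                return None
            if length < 2:
                raise ValueError("length must be >= 2")
            return length

        @field_validator("delay")
        def validate_delay(cls, delay):
            if delay is None:
                return None
            if not 0 <= delay <= 10:
                raise ValueError("delay must be between 0 and 10")
            return delay

    def execute(self, prompt, stream, response, conversation):
        text = prompt.prompt
        transitions = build_markov_table(text)
        length = prompt.options.length or 20
        for word in generate(transitions, length):
            yield word + ' '
            if prompt.options.delay:
                time.sleep(prompt.options.delay)
```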
412 | 413 | The full `llm_markov.py` file should now look like this: 414 | 415 | ```{literalinclude} llm-markov/llm_markov.py 416 | :language: python 417 | ``` 418 | 419 | Now we can request a 20 word completion with a 0.1s delay between tokens like this: 420 | ```bash 421 | llm -m markov "the cat sat on the mat" \ 422 | -o length 20 -o delay 0.1 423 | ``` 424 | LLM provides a `--no-stream` option users can use to turn off streaming. Using that option causes LLM to gather the response from the stream and then return it to the console in one block. You can try that like this: 425 | ```bash 426 | llm -m markov "the cat sat on the mat" \ 427 | -o length 20 -o delay 0.1 --no-stream 428 | ``` 429 | In this case it will still delay for 2s total while it gathers the tokens, then output them all at once. 430 | 431 | That `--no-stream` option causes the `stream` argument passed to `execute()` to be false. Your `execute()` method can then behave differently depending on whether it is streaming or not. 432 | 433 | Options are also logged to the database. You can see those here: 434 | ```bash 435 | llm logs -n 1 436 | ``` 437 | ```json 438 | [ 439 | { 440 | "id": 636, 441 | "model": "markov", 442 | "prompt": "the cat sat on the mat", 443 | "system": null, 444 | "prompt_json": null, 445 | "options_json": { 446 | "length": 20, 447 | "delay": 0.1 448 | }, 449 | "response": "the mat on the mat on the cat sat on the mat sat on the mat cat sat on the ", 450 | "response_json": null, 451 | "reply_to_id": null, 452 | "chat_id": null, 453 | "duration_ms": 2063, 454 | "datetime_utc": "2023-07-07T03:02:28.232970" 455 | } 456 | ] 457 | ``` 458 | 459 | ## Distributing your plugin 460 | 461 | There are many different options for distributing your new plugin so other people can try it out. 462 | 463 | You can create a downloadable wheel or `.zip` or `.tar.gz` files, or share the plugin through GitHub Gists or repositories. 464 | 465 | You can also publish your plugin to PyPI, the Python Package Index. 466 | 467 | ### Wheels and sdist packages 468 | 469 | The easiest option is to produce a distributable package is to use the `build` command. First, install the `build` package by running this: 470 | ```bash 471 | python -m pip install build 472 | ``` 473 | Then run `build` in your plugin directory to create the packages: 474 | ```bash 475 | python -m build 476 | ``` 477 | This will create two files: `dist/llm-markov-0.1.tar.gz` and `dist/llm-markov-0.1-py3-none-any.whl`. 478 | 479 | Either of these files can be used to install the plugin: 480 | 481 | ```bash 482 | llm install dist/llm_markov-0.1-py3-none-any.whl 483 | ``` 484 | If you host this file somewhere online other people will be able to install it using `pip install` against the URL to your package: 485 | ```bash 486 | llm install 'https://.../llm_markov-0.1-py3-none-any.whl' 487 | ``` 488 | You can run the following command at any time to uninstall your plugin, which is useful for testing out different installation methods: 489 | ```bash 490 | llm uninstall llm-markov -y 491 | ``` 492 | 493 | ### GitHub Gists 494 | 495 | A neat quick option for distributing a simple plugin is to host it in a GitHub Gist. These are available for free with a GitHub account, and can be public or private. Gists can contain multiple files but don't support directory structures - which is OK, because our plugin is just two files, `pyproject.toml` and `llm_markov.py`. 
496 |
497 | Here's an example Gist I created for this tutorial:
498 |
499 | [https://gist.github.com/simonw/6e56d48dc2599bffba963cef0db27b6d](https://gist.github.com/simonw/6e56d48dc2599bffba963cef0db27b6d)
500 |
501 | You can turn a Gist into an installable `.zip` URL by right-clicking on the "Download ZIP" button and selecting "Copy Link". Here's that link for my example Gist:
502 |
503 | `https://gist.github.com/simonw/6e56d48dc2599bffba963cef0db27b6d/archive/cc50c854414cb4deab3e3ab17e7e1e07d45cba0c.zip`
504 |
505 | The plugin can be installed using the `llm install` command like this:
506 | ```bash
507 | llm install 'https://gist.github.com/simonw/6e56d48dc2599bffba963cef0db27b6d/archive/cc50c854414cb4deab3e3ab17e7e1e07d45cba0c.zip'
508 | ```
509 |
510 | ### GitHub repositories
511 |
512 | The same trick works for regular GitHub repositories as well: the "Download ZIP" button can be found by clicking the green "Code" button at the top of the repository. The URL it provides can then be used to install the plugin that lives in that repository.
513 |
514 | ## Publishing plugins to PyPI
515 |
516 | The [Python Package Index (PyPI)](https://pypi.org/) is the official repository for Python packages. You can upload your plugin to PyPI and reserve a name for it - once you have done that, anyone will be able to install your plugin using `llm install` followed by the name you reserved.
517 |
518 | Follow [these instructions](https://packaging.python.org/en/latest/tutorials/packaging-projects/#uploading-the-distribution-archives) to publish a package to PyPI. The short version:
519 | ```bash
520 | python -m pip install twine
521 | python -m twine upload dist/*
522 | ```
523 | You will need an account on PyPI, then you can enter your username and password - or create a token in the PyPI settings and use `__token__` as the username and the token as the password.
524 |
525 | ## Adding metadata
526 |
527 | Before uploading a package to PyPI it's a good idea to add documentation and expand `pyproject.toml` with additional metadata.
528 |
529 | Create a `README.md` file in the root of your plugin directory with instructions about how to install, configure and use your plugin.
530 |
531 | You can then replace `pyproject.toml` with something like this:
532 |
533 | ```toml
534 | [project]
535 | name = "llm-markov"
536 | version = "0.1"
537 | description = "Plugin for LLM adding a Markov chain generating model"
538 | readme = "README.md"
539 | authors = [{name = "Simon Willison"}]
540 | license = {text = "Apache-2.0"}
541 | classifiers = [
542 |     "License :: OSI Approved :: Apache Software License"
543 | ]
544 | dependencies = [
545 |     "llm"
546 | ]
547 | requires-python = ">3.7"
548 |
549 | [project.urls]
550 | Homepage = "https://github.com/simonw/llm-markov"
551 | Changelog = "https://github.com/simonw/llm-markov/releases"
552 | Issues = "https://github.com/simonw/llm-markov/issues"
553 |
554 | [project.entry-points.llm]
555 | markov = "llm_markov"
556 | ```
557 | This will pull in your README to be displayed as part of your project's listing page on PyPI.
558 |
559 | It adds `llm` as a dependency, ensuring it will be installed if someone tries to install your plugin package without it.
560 |
561 | It adds some links to useful pages (you can drop the `project.urls` section if those links are not useful for your project).
562 |
563 | You should drop a `LICENSE` file into the GitHub repository for your package as well. I like to use the Apache 2 license [like this](https://github.com/simonw/llm/blob/main/LICENSE).
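One quick sanity check before publishing: you can confirm that the entry point declared in `pyproject.toml` is actually visible once the package is installed. A small sketch using only the standard library (the `group=` argument requires Python 3.10 or newer):

```python
from importlib.metadata import entry_points

# Every installed package that declares [project.entry-points.llm] shows up in
# the "llm" group - an installed llm-markov should appear as markov -> llm_markov.
for ep in entry_points(group="llm"):
    print(ep.name, "->", ep.value)
```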
564 | 565 | ## What to do if it breaks 566 | 567 | Sometimes you may make a change to your plugin that causes it to break, preventing `llm` from starting. For example you may see an error like this one: 568 | 569 | ``` 570 | $ llm 'hi' 571 | Traceback (most recent call last): 572 | ... 573 | File llm-markov/llm_markov.py", line 10 574 | register(Markov()): 575 | ^ 576 | SyntaxError: invalid syntax 577 | ``` 578 | You may find that you are unable to uninstall the plugin using `llm uninstall llm-markov` because the command itself fails with the same error. 579 | 580 | Should this happen, you can uninstall the plugin after first disabling it using the {ref}`LLM_LOAD_PLUGINS ` environment variable like this: 581 | ```bash 582 | LLM_LOAD_PLUGINS='' llm uninstall llm-markov 583 | ``` 584 | 585 | ... 586 | curl --location 'https://api.cerebras.ai/v1/chat/completions' \ 587 | --header 'Content-Type: application/json' \ 588 | --header "Authorization: Bearer ${CEREBRAS_API_KEY}" \ 589 | --data '{ 590 | "model": "llama3.1-8b", 591 | "stream": false, 592 | "messages": [{"content": "why is fast inference important?", "role": "user"}], 593 | "temperature": 0, 594 | "max_tokens": -1, 595 | "seed": 0, 596 | "top_p": 1 597 | }' 598 | 599 | ... 600 | 601 | Starter Kit home pagelight logo 602 | 603 | Search or ask... 604 | Ctrl K 605 | Documentation 606 | API Reference 607 | Endpoints 608 | 609 | Chat Completions 610 | Models 611 | Endpoints 612 | Chat Completions 613 | 614 | Python 615 | 616 | Node.js 617 | 618 | cURL 619 | 620 | curl --location 'https://api.cerebras.ai/v1/chat/completions' \ 621 | --header 'Content-Type: application/json' \ 622 | --header "Authorization: Bearer ${CEREBRAS_API_KEY}" \ 623 | --data '{ 624 | "model": "llama3.1-8b", 625 | "stream": false, 626 | "messages": [{"content": "Hello!", "role": "user"}], 627 | "temperature": 0, 628 | "max_tokens": -1, 629 | "seed": 0, 630 | "top_p": 1 631 | }' 632 | 633 | Response 634 | 635 | { 636 | "id": "chatcmpl-292e278f-514e-4186-9010-91ce6a14168b", 637 | "choices": [ 638 | { 639 | "finish_reason": "stop", 640 | "index": 0, 641 | "message": { 642 | "content": "Hello! How can I assist you today?", 643 | "role": "assistant" 644 | } 645 | } 646 | ], 647 | "created": 1723733419, 648 | "model": "llama3.1-8b", 649 | "system_fingerprint": "fp_70185065a4", 650 | "object": "chat.completion", 651 | "usage": { 652 | "prompt_tokens": 12, 653 | "completion_tokens": 10, 654 | "total_tokens": 22 655 | }, 656 | "time_info": { 657 | "queue_time": 0.000073161, 658 | "prompt_time": 0.0010744798888888889, 659 | "completion_time": 0.005658071111111111, 660 | "total_time": 0.022224903106689453, 661 | "created": 1723733419 662 | } 663 | } 664 | messages 665 | object[] 666 | required 667 | A list of messages comprising the conversation so far. 668 | model 669 | string 670 | required 671 | Available options: llama3.1-8b, llama3.1-70b 672 | max_tokens 673 | integer | null 674 | The maximum number of tokens that can be generated in the completion. The total length of input tokens and generated tokens is limited by the model’s context length. 675 | response_format 676 | object | null 677 | Setting to { "type": "json_object" } enables JSON mode, which ensures that the response is either a valid JSON object or an error response. 678 | 679 | Note that enabling JSON mode does not guarantee that the model will successfully generate valid JSON. 
The model may fail to generate valid JSON due to various reasons such as incorrect formatting, missing or mismatched brackets, or exceeding the length limit. 680 | 681 | In cases where the model fails to generate valid JSON, the error response will be a valid JSON object with a key failed_generation containing the string representing the invalid JSON. This allows you to re-submit the request with additional prompting to correct the issue. The error response will have a 400 server error status code. 682 | 683 | Note that JSON mode is not compatible with streaming. "stream" must be set to false. 684 | 685 | Important: When using JSON mode, you need to explicitly instruct the model to generate JSON through a system or user message. 686 | seed 687 | integer | null 688 | If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed. 689 | stop 690 | string | null 691 | Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. 692 | stream 693 | boolean | null 694 | If set, partial message deltas will be sent. 695 | temperature 696 | number | null 697 | What sampling temperature to use, between 0 and 1.5. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. 698 | top_p 699 | number | null 700 | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. 701 | tool_choice 702 | string | object 703 | Controls which (if any) tool is called by the model. none means the model will not call any tool and instead generates a message. auto means the model can pick between generating a message or calling one or more tools. required means the model must call one or more tools. Specifying a particular tool via {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool. 704 | 705 | none is the default when no tools are present. auto is the default if tools are present. 706 | tools 707 | object | null 708 | A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. 709 | 710 | Specifying tools consumes prompt tokens in the context. If too many are given, the model may perform poorly or you may hit context length limitations 711 | 712 | 713 | Show properties 714 | user 715 | string | null 716 | A unique identifier representing your end-user, which can help to monitor and detect abuse. 717 | 718 | ... 719 | 720 | Starter Kit home pagelight logo 721 | 722 | Search or ask... 
723 | Ctrl K 724 | Documentation 725 | API Reference 726 | Endpoints 727 | 728 | Chat Completions 729 | Models 730 | Endpoints 731 | Models 732 | 733 | Python 734 | 735 | import os 736 | from cerebras.cloud.sdk import Cerebras 737 | 738 | client = Cerebras(api_key=os.environ.get("CEREBRAS_API_KEY"),) 739 | 740 | client.models.list() 741 | client.models.retrieve("llama3.1-8b") 742 | 743 | list Response 744 | 745 | retrieve Response 746 | 747 | { 748 | "object": "list", 749 | "data": [ 750 | { 751 | "id": "llama3.1-8b", 752 | "object": "model", 753 | "created": 1721692800, 754 | "owned_by": "Meta" 755 | }, 756 | { 757 | "id": "llama3.1-70b", 758 | "object": "model", 759 | "created": 1721692800, 760 | "owned_by": "Meta" 761 | } 762 | ] 763 | } 764 | list 765 | GET https://api.cerebras.ai/v1/models 766 | 767 | Lists the currently available models and provides essential details about each, including the owner and availability. 768 | retrieve 769 | GET https://api.cerebras.ai/v1/models/{model} 770 | 771 | Fetches a model instance, offering key details about the model, including its owner and permissions. 772 | 773 | Accepts model IDs as arguments. Available options: llama3.1-8b, llama3.1-70b 774 | 775 | ... 776 | mixtral-8x7b-32768 777 | 778 | ... 779 | verview 780 | The Cerebras API offers developers a low-latency solution for AI model inference powered by Cerebras Wafer-Scale Engines and CS-3 systems. We invite developers to explore the new possibilities that our high-speed inferencing solution unlocks. 781 | 782 | Currently, the Cerebras API provides access to two models: Meta’s Llama 3.1 8B and 70B models. Both models are instruction-tuned and can be used for conversational applications. 783 | 784 | Llama 3.1 8B 785 | Model ID: llama3.1-8b 786 | Parameters: 8 billion 787 | Knowledge cutoff: March 2023 788 | Context Length: 8192 789 | Training Tokens: 15 trillion 790 | Llama 3.1 70B 791 | Model ID: llama3.1-70b 792 | Parameters: 70 billion 793 | Knowledge cutoff: December 2023 794 | Context Length: 8192 795 | Training Tokens: 15 trillion 796 | Due to high demand in our early launch phase, we are temporarily limiting Llama 3.1 models to a context window of 8192 in our Free Tier. If your use case or application would benefit from longer context windows, please let us know! 797 | QuickStart Guide 798 | Get started by building your first application using our QuickStart guide. 799 | 800 | 801 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__/ 3 | *.py[cod] 4 | *.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | 24 | venv/ 25 | ENV/ 26 | 27 | 28 | .idea/ 29 | 30 | 31 | .vscode/ 32 | 33 | 34 | .ipynb_checkpoints 35 | 36 | 37 | .pytest_cache/ 38 | 39 | 40 | htmlcov/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | 45 | 46 | .DS_Store 47 | *.log 48 | .artefacts/.agent 49 | .agent 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | ... 
8 | 9 | END OF TERMS AND CONDITIONS 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llm-cerebras 2 | 3 | This is a plugin for [LLM](https://llm.datasette.io/) that adds support for the Cerebras inference API. 4 | 5 | ## Installation 6 | 7 | Install this plugin in the same environment as LLM. 8 | 9 | ```bash 10 | pip install llm-cerebras 11 | ``` 12 | 13 | ## Configuration 14 | 15 | You'll need to provide an API key for Cerebras. 16 | 17 | ```bash 18 | llm keys set cerebras 19 | ``` 20 | 21 | ## Listing available models 22 | 23 | ```bash 24 | llm models list | grep cerebras 25 | # CerebrasModel: cerebras-llama3.1-8b 26 | # CerebrasModel: cerebras-llama3.3-70b 27 | # CerebrasModel: cerebras-llama-4-scout-17b-16e-instruct 28 | # CerebrasModel: cerebras-deepseek-r1-distill-llama-70b 29 | ``` 30 | 31 | ## Schema Support 32 | 33 | The llm-cerebras plugin supports schemas for structured output. You can use either compact schema syntax or full JSON Schema: 34 | 35 | ```bash 36 | # Using compact schema syntax 37 | llm -m cerebras-llama3.3-70b 'invent a dog' --schema 'name, age int, breed' 38 | 39 | # Using multi-item schema for lists 40 | llm -m cerebras-llama3.3-70b 'invent three dogs' --schema-multi 'name, age int, breed' 41 | 42 | # Using full JSON Schema 43 | llm -m cerebras-llama3.3-70b 'invent a dog' --schema '{ 44 | "type": "object", 45 | "properties": { 46 | "name": {"type": "string"}, 47 | "age": {"type": "integer"}, 48 | "breed": {"type": "string"} 49 | }, 50 | "required": ["name", "age", "breed"] 51 | }' 52 | ``` 53 | 54 | ### Schema with Descriptions 55 | 56 | You can add descriptions to your schema fields to guide the model: 57 | 58 | ```bash 59 | llm -m cerebras-llama3.3-70b 'invent a famous scientist' --schema ' 60 | name: the full name including any titles 61 | field: their primary field of study 62 | year_born int: year of birth 63 | year_died int: year of death, can be null if still alive 64 | achievements: a list of their major achievements 65 | ' 66 | ``` 67 | 68 | ### Creating Schema Templates 69 | 70 | You can save schemas as templates for reuse: 71 | 72 | ```bash 73 | # Create a template 74 | llm -m cerebras-llama3.3-70b --schema 'title, director, year int, genre' --save movie_template 75 | 76 | # Use the template 77 | llm -t movie_template 'suggest a sci-fi movie from the 1980s' 78 | ``` 79 | 80 | ## Development 81 | 82 | To set up this plugin locally, first checkout the code. 
Then create a new virtual environment: 83 | 84 | ```bash 85 | cd llm-cerebras 86 | python -m venv venv 87 | source venv/bin/activate 88 | ``` 89 | 90 | Now install the dependencies and test dependencies: 91 | 92 | ```bash 93 | pip install -e '.[test]' 94 | ``` 95 | 96 | ### Running Tests 97 | 98 | To run the unit tests: 99 | 100 | ```bash 101 | pytest tests/test_cerebras.py tests/test_schema_support.py 102 | ``` 103 | 104 | To run integration tests (requires a valid API key): 105 | 106 | ```bash 107 | pytest tests/test_integration.py 108 | ``` 109 | 110 | To run automated user workflow tests: 111 | 112 | ```bash 113 | pytest tests/test_automated_user.py 114 | ``` 115 | 116 | You can run specific test types using markers: 117 | 118 | ```bash 119 | pytest -m "integration" # Run only integration tests 120 | pytest -m "user" # Run only user workflow tests 121 | ``` 122 | 123 | ## License 124 | 125 | Apache 2.0 126 | -------------------------------------------------------------------------------- /llm_cerebras/__init__.py: -------------------------------------------------------------------------------- 1 | from .cerebras import register_models 2 | 3 | -------------------------------------------------------------------------------- /llm_cerebras/cerebras.py: -------------------------------------------------------------------------------- 1 | import llm 2 | import httpx 3 | import json 4 | from pydantic import Field 5 | from typing import Optional, List, Dict, Any, Union 6 | import logging 7 | 8 | # Try to import jsonschema for validation 9 | try: 10 | import jsonschema 11 | HAVE_JSONSCHEMA = True 12 | except ImportError: 13 | HAVE_JSONSCHEMA = False 14 | logging.warning("jsonschema not installed, schema validation will be limited") 15 | 16 | @llm.hookimpl 17 | def register_models(register): 18 | for model_id in CerebrasModel.model_map.keys(): 19 | aliases = tuple() 20 | register(CerebrasModel(model_id), aliases=aliases) 21 | 22 | class CerebrasModel(llm.Model): 23 | can_stream = True 24 | model_id: str 25 | api_base = "https://api.cerebras.ai/v1" 26 | supports_schema = True # Enable schema support 27 | 28 | model_map = { 29 | "cerebras-llama3.1-8b": "llama3.1-8b", 30 | "cerebras-llama3.3-70b": "llama-3.3-70b", 31 | "cerebras-llama-4-scout-17b-16e-instruct": "llama-4-scout-17b-16e-instruct", 32 | "cerebras-deepseek-r1-distill-llama-70b": "DeepSeek-R1-Distill-Llama-70B", 33 | } 34 | 35 | class Options(llm.Options): 36 | temperature: Optional[float] = Field( 37 | description="What sampling temperature to use, between 0 and 1.5.", 38 | ge=0, 39 | le=1.5, 40 | default=0.7, 41 | ) 42 | max_tokens: Optional[int] = Field( 43 | description="The maximum number of tokens to generate.", 44 | default=None, 45 | ) 46 | top_p: Optional[float] = Field( 47 | description="An alternative to sampling with temperature, called nucleus sampling.", 48 | ge=0, 49 | le=1, 50 | default=1, 51 | ) 52 | seed: Optional[int] = Field( 53 | description="If specified, our system will make a best effort to sample deterministically.", 54 | default=None, 55 | ) 56 | 57 | def __init__(self, model_id): 58 | self.model_id = model_id 59 | 60 | def execute(self, prompt, stream, response, conversation): 61 | messages = self._build_messages(prompt, conversation) 62 | api_key = llm.get_key("", "cerebras", "CEREBRAS_API_KEY") 63 | 64 | headers = { 65 | "Content-Type": "application/json", 66 | "Authorization": f"Bearer {api_key}" 67 | } 68 | 69 | data = { 70 | "model": self.model_map[self.model_id], 71 | "messages": messages, 72 | "stream": 
stream, 73 | "temperature": prompt.options.temperature, 74 | "max_tokens": prompt.options.max_tokens, 75 | "top_p": prompt.options.top_p, 76 | "seed": prompt.options.seed, 77 | } 78 | 79 | # Handle schema using json_object mode 80 | if hasattr(prompt, 'schema') and prompt.schema: 81 | # Convert llm's concise schema format to JSON Schema if needed 82 | schema = self._process_schema(prompt.schema) 83 | 84 | # First try the native json_schema approach 85 | try_native_schema = False # Set to True to try native schema first 86 | 87 | if try_native_schema and not stream: # json_schema doesn't support streaming 88 | try: 89 | json_schema_data = data.copy() 90 | json_schema_data["response_format"] = { 91 | "type": "json_schema", 92 | "json_schema": { 93 | "strict": True, 94 | "schema": schema 95 | } 96 | } 97 | 98 | # Try the API with json_schema format 99 | url = f"{self.api_base}/chat/completions" 100 | r = httpx.post(url, json=json_schema_data, headers=headers, timeout=None) 101 | r.raise_for_status() 102 | content = r.json()["choices"][0]["message"]["content"] 103 | yield content 104 | return 105 | except httpx.HTTPStatusError: 106 | # If json_schema fails, fall back to json_object with instructions 107 | logging.info("json_schema format not supported yet, falling back to json_object with instructions") 108 | 109 | # Use json_object mode with schema in system message 110 | data["response_format"] = {"type": "json_object"} 111 | 112 | # Add schema instructions via system message if not already present 113 | schema_instructions = self._build_schema_instructions(schema) 114 | has_system = any(msg.get("role") == "system" for msg in messages) 115 | 116 | if not has_system: 117 | # Insert system message at the beginning 118 | messages.insert(0, {"role": "system", "content": schema_instructions}) 119 | data["messages"] = messages 120 | else: 121 | # Append schema instructions to existing system message 122 | for msg in messages: 123 | if msg.get("role") == "system": 124 | msg["content"] = msg["content"] + "\n\n" + schema_instructions 125 | break 126 | data["messages"] = messages 127 | 128 | url = f"{self.api_base}/chat/completions" 129 | 130 | if stream: 131 | with httpx.stream("POST", url, json=data, headers=headers, timeout=None) as r: 132 | for line in r.iter_lines(): 133 | if line.startswith("data: "): 134 | chunk = line[6:] 135 | if chunk != "[DONE]": 136 | content = json.loads(chunk)["choices"][0]["delta"].get("content") 137 | if content: 138 | yield content 139 | else: 140 | r = httpx.post(url, json=data, headers=headers, timeout=None) 141 | r.raise_for_status() 142 | content = r.json()["choices"][0]["message"]["content"] 143 | 144 | # If we have a schema, validate the response 145 | if hasattr(prompt, 'schema') and prompt.schema and not stream: 146 | try: 147 | # Parse the JSON content 148 | json_content = json.loads(content) 149 | # Validate against the schema 150 | schema = self._process_schema(prompt.schema) 151 | self._validate_schema(json_content, schema) 152 | # Return the validated JSON as a string 153 | content = json.dumps(json_content) 154 | except (json.JSONDecodeError, ValueError, jsonschema.exceptions.ValidationError) as e: 155 | logging.warning(f"Schema validation failed: {str(e)}") 156 | # Continue with the original content 157 | 158 | yield content 159 | 160 | def _build_messages(self, prompt, conversation) -> List[dict]: 161 | messages = [] 162 | if conversation: 163 | for response in conversation.responses: 164 | messages.extend([ 165 | {"role": "user", "content": 
response.prompt.prompt}, 166 | {"role": "assistant", "content": response.text()}, 167 | ]) 168 | messages.append({"role": "user", "content": prompt.prompt}) 169 | return messages 170 | 171 | def _process_schema(self, schema) -> Dict[str, Any]: 172 | """ 173 | Process schema from llm's format to a proper JSON Schema. 174 | """ 175 | if isinstance(schema, dict): 176 | return schema 177 | 178 | # If it's a string, check if it's a JSON string 179 | if isinstance(schema, str): 180 | try: 181 | return json.loads(schema) 182 | except json.JSONDecodeError: 183 | # This might be using llm's concise schema format 184 | # For now, convert it to a basic JSON schema 185 | properties = {} 186 | required = [] 187 | 188 | # Handle both comma-separated and newline-separated formats 189 | if "," in schema and "\n" not in schema: 190 | parts = [p.strip() for p in schema.split(",")] 191 | else: 192 | parts = [p.strip() for p in schema.split("\n") if p.strip()] 193 | 194 | for part in parts: 195 | # Handle field description format: name: description 196 | if ":" in part: 197 | field_def, description = part.split(":", 1) 198 | else: 199 | field_def, description = part, "" 200 | 201 | # Handle type annotations: name int, name float, etc. 202 | if " " in field_def: 203 | field_name, field_type = field_def.split(" ", 1) 204 | else: 205 | field_name, field_type = field_def, "string" 206 | 207 | # Map to JSON schema types 208 | type_mapping = { 209 | "int": "integer", 210 | "float": "number", 211 | "str": "string", 212 | "string": "string", 213 | "bool": "boolean", 214 | } 215 | json_type = type_mapping.get(field_type.lower(), "string") 216 | 217 | # Add to properties 218 | properties[field_name] = {"type": json_type} 219 | if description: 220 | properties[field_name]["description"] = description.strip() 221 | 222 | # All fields are required by default in llm's schema format 223 | required.append(field_name) 224 | 225 | return { 226 | "type": "object", 227 | "properties": properties, 228 | "required": required 229 | } 230 | 231 | # Default empty schema 232 | return {"type": "object", "properties": {}} 233 | 234 | def _build_schema_instructions(self, schema: Dict[str, Any]) -> str: 235 | """ 236 | Generate instructions for the model to follow the schema. 237 | """ 238 | instructions = "You are a helpful assistant that returns responses in JSON format. " 239 | instructions += "Your response must follow this schema exactly:\n" 240 | 241 | # Format the schema as a readable instruction 242 | if schema.get("type") == "object" and "properties" in schema: 243 | properties = schema.get("properties", {}) 244 | required = schema.get("required", []) 245 | 246 | instructions += "{\n" 247 | for prop_name, prop_details in properties.items(): 248 | prop_type = prop_details.get("type", "string") 249 | prop_desc = prop_details.get("description", "") 250 | is_required = prop_name in required 251 | 252 | instructions += f' "{prop_name}": {prop_type}' 253 | if prop_desc: 254 | instructions += f" // {prop_desc}" 255 | if is_required: 256 | instructions += " (required)" 257 | instructions += ",\n" 258 | instructions += "}\n" 259 | else: 260 | # Fallback to JSON representation 261 | instructions += json.dumps(schema, indent=2) 262 | 263 | instructions += "\nYour response must be valid JSON and follow this schema exactly. Do not include any explanations or text outside of the JSON structure." 
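        # Note: execute() merges this instruction text into the request's system
        # message, since the json_object response format relies on prompting rather
        # than server-side schema enforcement.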
264 | return instructions 265 | 266 | def _validate_schema(self, data: Any, schema: Dict[str, Any]) -> bool: 267 | """ 268 | Validate the response against the schema. 269 | """ 270 | if HAVE_JSONSCHEMA: 271 | try: 272 | jsonschema.validate(instance=data, schema=schema) 273 | return True 274 | except jsonschema.exceptions.ValidationError as e: 275 | raise ValueError(f"Schema validation failed: {str(e)}") 276 | else: 277 | # Basic validation if jsonschema is not available 278 | if schema.get("type") == "object" and "properties" in schema: 279 | properties = schema.get("properties", {}) 280 | required = schema.get("required", []) 281 | 282 | # Check required fields 283 | for field in required: 284 | if field not in data: 285 | raise ValueError(f"Required field '{field}' is missing from response") 286 | 287 | # Check field types (simplified) 288 | for field, value in data.items(): 289 | if field in properties: 290 | prop_type = properties[field].get("type") 291 | if prop_type == "string" and not isinstance(value, str): 292 | raise ValueError(f"Field '{field}' should be a string") 293 | elif prop_type == "integer" and not isinstance(value, int): 294 | raise ValueError(f"Field '{field}' should be an integer") 295 | elif prop_type == "number" and not isinstance(value, (int, float)): 296 | raise ValueError(f"Field '{field}' should be a number") 297 | elif prop_type == "boolean" and not isinstance(value, bool): 298 | raise ValueError(f"Field '{field}' should be a boolean") 299 | elif prop_type == "array" and not isinstance(value, list): 300 | raise ValueError(f"Field '{field}' should be an array") 301 | elif prop_type == "object" and not isinstance(value, dict): 302 | raise ValueError(f"Field '{field}' should be an object") 303 | 304 | return True 305 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "llm-cerebras" 3 | version = "0.1.7" 4 | description = "Plugin for LLM adding fast Cerebras inference API support" 5 | readme = "README.md" 6 | authors = [{name = "Thomas (Thomasthomas) Hughes"}] 7 | license = {text = "Apache-2.0"} 8 | classifiers = [ 9 | "License :: OSI Approved :: Apache Software License" 10 | ] 11 | dependencies = [ 12 | "llm", 13 | "httpx", 14 | ] 15 | requires-python = ">3.7" 16 | 17 | [project.urls] 18 | Homepage = "https://github.com/irthomasthomas/llm-cerebras" 19 | Changelog = "https://github.com/irthomasthomas/llm-cerebras/releases" 20 | Issues = "https://github.com/irthomasthomas/llm-cerebras/issues" 21 | 22 | [project.entry-points.llm] 23 | cerebras = "llm_cerebras.cerebras" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="llm-cerebras", 5 | version="0.1.1", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "llm", 9 | "httpx", 10 | ], 11 | entry_points={ 12 | "llm": [ 13 | "cerebras=llm_cerebras.cerebras", 14 | ], 15 | }, 16 | author="Thomas (Thomasthomas) Hughes", 17 | author_email="irthomasthomas@gmail.com", 18 | description="llm plugin to prompt Cerebras hosted models.", 19 | long_description=open("README.md").read(), 20 | long_description_content_type="text/markdown", 21 | license="Apache License 2.0", 22 | classifiers=[ 23 | "License :: OSI Approved :: Apache Software License", 24 | "Programming Language :: Python 
:: 3", 25 | "Programming Language :: Python :: 3.7", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /test_results/basic_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Ava Morales", 3 | "age": 27, 4 | "bio": "Ava is a skilled software engineer with a passion for hiking and playing the guitar. She lives in a small town surrounded by mountains and spends her free time volunteering at local animal shelters." 5 | } 6 | -------------------------------------------------------------------------------- /test_results/complex_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "person": { 3 | "name": "John Doe", 4 | "age": 30, 5 | "occupation": "Software Developer", 6 | "skills": ["Java", "Python", "C++", "JavaScript"], 7 | "experience": 5 8 | }, 9 | "location": { 10 | "city": "New York", 11 | "country": "USA", 12 | "latitude": 40.7128, 13 | "longitude": -74.0060 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /test_results/multi_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "items": [ 3 | { 4 | "name": "Eira Shadowglow", 5 | "age": 250, 6 | "species": "Elf", 7 | "occupation": "Ranger" 8 | }, 9 | { 10 | "name": "Kael Darkhaven", 11 | "age": 35, 12 | "species": "Human", 13 | "occupation": "Assassin" 14 | }, 15 | { 16 | "name": "Lila Earthsong", 17 | "age": 120, 18 | "species": "Dwarf", 19 | "occupation": "Cleric" 20 | } 21 | ] 22 | } 23 | -------------------------------------------------------------------------------- /test_results/person_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "properties": { 4 | "person": { 5 | "type": "object", 6 | "properties": { 7 | "name": {"type": "string"}, 8 | "age": {"type": "integer"}, 9 | "skills": {"type": "array", "items": {"type": "string"}} 10 | }, 11 | "required": ["name", "age", "skills"] 12 | }, 13 | "location": { 14 | "type": "object", 15 | "properties": { 16 | "city": {"type": "string"}, 17 | "country": {"type": "string"} 18 | }, 19 | "required": ["city", "country"] 20 | } 21 | }, 22 | "required": ["person", "location"] 23 | } 24 | -------------------------------------------------------------------------------- /test_results/schema_by_id.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Marcel Leblanc", 3 | "age": 32, 4 | "bio": "Award-winning chef and owner of the renowned Parisian restaurant, Bistro Bliss, Marcel Leblanc is known for his creative and exquisite French cuisine. With a passion for using only the freshest ingredients, Marcel's dishes are a testament to his dedication to the culinary arts." 5 | } 6 | -------------------------------------------------------------------------------- /test_results/schema_tests_summary.md: -------------------------------------------------------------------------------- 1 | # Schema Tests with llm-cerebras 2 | 3 | All tests were run using the `cerebras-llama3.3-70b` model. 4 | 5 | ## Test Results 6 | 7 | ### 1. 
Basic Schema 8 | ```bash 9 | llm -m cerebras-llama3.3-70b --schema "name, age int, bio" "Generate a fictional person" 10 | ``` 11 | 12 | **Result:** 13 | ```json 14 | { 15 | "name": "Ava Morales", 16 | "age": 27, 17 | "bio": "Ava is a skilled software engineer with a passion for hiking and playing the guitar. She lives in a small town surrounded by mountains and spends her free time volunteering at local animal shelters." 18 | } 19 | ``` 20 | 21 | ### 2. Schema with Descriptions 22 | ```bash 23 | llm -m cerebras-llama3.3-70b --schema "name: full name including title, age int: age in years, specialization: field of study" "Generate a profile for a scientist" 24 | ``` 25 | 26 | **Result:** 27 | ```json 28 | { 29 | "name": "Dr. Maria Rodriguez", 30 | "age": 35, 31 | "specialization": "Astrophysics" 32 | } 33 | ``` 34 | 35 | ### 3. Multiple Items with Schema-Multi 36 | ```bash 37 | llm -m cerebras-llama3.3-70b --schema-multi "name, age int, occupation" "Generate 3 different characters" 38 | ``` 39 | 40 | **Result:** 41 | ```json 42 | { 43 | "items": [ 44 | { 45 | "name": "Eira Shadowglow", 46 | "age": 250, 47 | "species": "Elf", 48 | "occupation": "Ranger" 49 | }, 50 | { 51 | "name": "Kael Darkhaven", 52 | "age": 35, 53 | "species": "Human", 54 | "occupation": "Assassin" 55 | }, 56 | { 57 | "name": "Lila Earthsong", 58 | "age": 120, 59 | "species": "Dwarf", 60 | "occupation": "Cleric" 61 | } 62 | ] 63 | } 64 | ``` 65 | 66 | ### 4. Complex Schema using JSON File 67 | ```bash 68 | # Using person_schema.json 69 | llm -m cerebras-llama3.3-70b --schema person_schema.json "Generate a profile for a software developer" 70 | ``` 71 | 72 | **Result:** 73 | ```json 74 | { 75 | "person": { 76 | "name": "John Doe", 77 | "age": 30, 78 | "occupation": "Software Developer", 79 | "skills": ["Java", "Python", "C++", "JavaScript"], 80 | "experience": 5 81 | }, 82 | "location": { 83 | "city": "New York", 84 | "country": "USA", 85 | "latitude": 40.7128, 86 | "longitude": -74.0060 87 | } 88 | } 89 | ``` 90 | 91 | ### 5. Create and Use Template 92 | ```bash 93 | # Create template 94 | llm -m cerebras-llama3.3-70b --schema "title, director, year int, genre" --save movie_template 95 | 96 | # Use template 97 | llm -t movie_template "Suggest a sci-fi movie from the 1980s" 98 | ``` 99 | 100 | **Result:** 101 | ```json 102 | { 103 | "title": "Blade Runner", 104 | "director": "Ridley Scott", 105 | "year": 1982, 106 | "genre": "Science Fiction" 107 | } 108 | ``` 109 | 110 | ### 6. Using a Previously Used Schema by ID 111 | ```bash 112 | llm -m cerebras-llama3.3-70b --schema 9c57ef588ee1f02a093277cef6138619 "Generate a fictional character who is a chef" 113 | ``` 114 | 115 | **Result:** 116 | ```json 117 | { 118 | "name": "Marcel Leblanc", 119 | "age": 32, 120 | "bio": "Award-winning chef and owner of the renowned Parisian restaurant, Bistro Bliss, Marcel Leblanc is known for his creative and exquisite French cuisine. With a passion for using only the freshest ingredients, Marcel's dishes are a testament to his dedication to the culinary arts." 121 | } 122 | ``` 123 | 124 | ## Observations 125 | 126 | 1. **All schema variations worked successfully** - basic schemas, schemas with descriptions, multi-schemas, complex schemas from files, templates, and schemas by ID. 127 | 128 | 2. **Schema validation is functioning properly** - Integer fields are returned as integers, string fields as strings, and array fields as arrays. 129 | 130 | 3. 
**Schema descriptions influence the output** - When we specified "name: full name including title", the model returned "Dr. Maria Rodriguez" with the title. 131 | 132 | 4. **Additional fields sometimes included** - In the multi-schema example, the model added a "species" field that wasn't in our schema, but all required fields were present. 133 | 134 | 5. **Templates work correctly** - Creating and using a template preserved the schema structure. 135 | 136 | 6. **Schema management works** - We can list schemas, view full details, and reuse schemas by ID. 137 | 138 | ## Conclusion 139 | 140 | The schema support implementation for llm-cerebras is working correctly across all tested variations and use cases. The workaround using json_object mode with system instructions provides seamless schema functionality equivalent to what users would expect from native schema support. 141 | -------------------------------------------------------------------------------- /test_results/schema_with_descriptions.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Dr. Maria Rodriguez", 3 | "age": 35, 4 | "specialization": "Astrophysics" 5 | } 6 | -------------------------------------------------------------------------------- /test_results/template_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Blade Runner", 3 | "director": "Ridley Scott", 4 | "year": 1982, 5 | "genre": "Science Fiction" 6 | } 7 | -------------------------------------------------------------------------------- /test_results/test_cerebras_schema.py: -------------------------------------------------------------------------------- 1 | from cerebras.cloud.sdk import Cerebras 2 | import os 3 | import json 4 | from cerebras.cloud.sdk.types.chat.completion_create_params import ResponseFormatResponseFormatJsonSchemaTyped 5 | 6 | client = Cerebras(api_key=os.environ.get("CEREBRAS_API_KEY")) 7 | 8 | # Define a simple schema 9 | schema = { 10 | "type": "object", 11 | "properties": { 12 | "name": {"type": "string"}, 13 | "age": {"type": "integer"} 14 | }, 15 | "required": ["name", "age"] 16 | } 17 | 18 | print("Testing JSON Object format first (simpler)...") 19 | try: 20 | # Test with json_object type first (simpler) 21 | response_format_json = {"type": "json_object"} 22 | 23 | print(f"Using response_format: {response_format_json}") 24 | 25 | response1 = client.chat.completions.create( 26 | model="llama-3.3-70b", 27 | messages=[ 28 | {"role": "user", "content": "Generate information about a fictional person in JSON format with name and age"} 29 | ], 30 | response_format=response_format_json 31 | ) 32 | print("JSON Object Response:") 33 | print(response1.choices[0].message.content) 34 | print("\n---\n") 35 | except Exception as e: 36 | print(f"Error with json_object: {e}") 37 | 38 | print("\nTesting JSON Schema format...") 39 | try: 40 | # Test with json_schema type 41 | response_format_schema = { 42 | "type": "json_schema", 43 | "json_schema": { 44 | "strict": True, 45 | "schema": schema 46 | } 47 | } 48 | 49 | print(f"Using response_format: {json.dumps(response_format_schema, indent=2)}") 50 | 51 | response2 = client.chat.completions.create( 52 | model="llama-3.3-70b", 53 | messages=[ 54 | {"role": "user", "content": "Generate information about a fictional person with name and age"} 55 | ], 56 | response_format=response_format_schema 57 | ) 58 | print("JSON Schema Response:") 59 | print(response2.choices[0].message.content) 60 | 
except Exception as e: 61 | print(f"Error with json_schema: {e}") 62 | 63 | # Try with the typed approach using SDK types 64 | print("\nTrying with typed SDK objects...") 65 | try: 66 | from cerebras.cloud.sdk.types.chat.completion_create_params import ( 67 | ResponseFormatResponseFormatJsonSchemaJsonSchemaTyped, 68 | ResponseFormatResponseFormatJsonSchemaTyped 69 | ) 70 | 71 | json_schema = ResponseFormatResponseFormatJsonSchemaJsonSchemaTyped( 72 | strict=True, 73 | schema=schema 74 | ) 75 | 76 | response_format = ResponseFormatResponseFormatJsonSchemaTyped( 77 | type="json_schema", 78 | json_schema=json_schema 79 | ) 80 | 81 | print(f"Using typed response_format: {response_format}") 82 | 83 | response3 = client.chat.completions.create( 84 | model="llama-3.3-70b", 85 | messages=[ 86 | {"role": "user", "content": "Generate information about a fictional person with name and age"} 87 | ], 88 | response_format=response_format 89 | ) 90 | print("JSON Schema Response (typed):") 91 | print(response3.choices[0].message.content) 92 | except Exception as e: 93 | print(f"Error with typed json_schema: {e}") 94 | -------------------------------------------------------------------------------- /test_results/test_cerebras_version.py: -------------------------------------------------------------------------------- 1 | import cerebras 2 | import cerebras.cloud.sdk 3 | from cerebras.cloud.sdk import Cerebras 4 | 5 | # Try to find version information 6 | print(f"Cerebras package: {cerebras}") 7 | print(f"Cerebras SDK module: {cerebras.cloud.sdk}") 8 | 9 | # Check available types and modules 10 | print("\nModules in cerebras.cloud.sdk:") 11 | for item in dir(cerebras.cloud.sdk): 12 | if not item.startswith("_"): 13 | print(f" - {item}") 14 | 15 | # Check the types module 16 | print("\nChecking types module:") 17 | from cerebras.cloud.sdk import types 18 | print(f"Available types: {dir(types)}") 19 | 20 | # Look at completion_create_params module 21 | print("\nChecking completion_create_params module:") 22 | import cerebras.cloud.sdk.types.completion_create_params as params 23 | print(f"Attributes: {dir(params)}") 24 | 25 | # Test importing the ResponseFormat 26 | try: 27 | from cerebras.cloud.sdk.types.chat.completion_create_params import ResponseFormat 28 | print("\nResponseFormat imported successfully") 29 | print(f"ResponseFormat: {ResponseFormat}") 30 | except ImportError as e: 31 | print(f"\nImport error: {e}") 32 | -------------------------------------------------------------------------------- /tests/test_automated_user.py: -------------------------------------------------------------------------------- 1 | """ 2 | Automated user tests that simulate typical user workflows with llm-cerebras. 3 | These tests require a properly installed llm environment with llm-cerebras. 
4 | """ 5 | 6 | import os 7 | import pytest 8 | import json 9 | import subprocess 10 | import tempfile 11 | from pathlib import Path 12 | import re 13 | 14 | # Skip tests if SKIP_USER_TESTS is set 15 | pytestmark = pytest.mark.skipif( 16 | os.environ.get("SKIP_USER_TESTS") == "1", 17 | reason="SKIP_USER_TESTS is set" 18 | ) 19 | 20 | def run_command(cmd): 21 | """Run a shell command and return output""" 22 | result = subprocess.run( 23 | cmd, 24 | shell=True, 25 | capture_output=True, 26 | text=True, 27 | check=False 28 | ) 29 | return result.stdout.strip(), result.stderr.strip(), result.returncode 30 | 31 | def check_model_available(): 32 | """Check if cerebras models are available in llm""" 33 | stdout, stderr, returncode = run_command("llm models list") 34 | return returncode == 0 and "cerebras" in stdout.lower() 35 | 36 | @pytest.mark.user 37 | def test_plugin_installation(): 38 | """Test that the plugin is properly installed and recognized by llm""" 39 | # Skip if pytest is run from development directory 40 | if os.path.exists("pyproject.toml") and "llm-cerebras" in open("pyproject.toml").read(): 41 | pytest.skip("Running from development directory") 42 | 43 | # Check if plugin is listed 44 | stdout, stderr, returncode = run_command("llm plugins") 45 | assert returncode == 0, f"llm plugins command failed: {stderr}" 46 | assert "cerebras" in stdout, "cerebras plugin not found in llm plugins list" 47 | 48 | @pytest.mark.user 49 | def test_models_listing(): 50 | """Test that cerebras models are listed by llm""" 51 | if not check_model_available(): 52 | pytest.skip("cerebras models not available") 53 | 54 | stdout, stderr, returncode = run_command("llm models list | grep -i cerebras") 55 | assert returncode == 0, "No cerebras models found" 56 | 57 | models = stdout.strip().split("\n") 58 | assert len(models) >= 1, "No cerebras models found" 59 | 60 | # Check for expected models 61 | model_ids = [line.split(" - ")[0].strip() for line in models] 62 | assert any("cerebras-llama" in model_id for model_id in model_ids), "No llama models found" 63 | 64 | @pytest.mark.user 65 | def test_workflow_basic_prompt(): 66 | """Test a basic user workflow with a simple prompt""" 67 | if not check_model_available(): 68 | pytest.skip("cerebras models not available") 69 | 70 | # Test a simple prompt 71 | stdout, stderr, returncode = run_command("llm -m cerebras-llama3.1-8b 'Write a haiku about programming'") 72 | assert returncode == 0, f"Command failed: {stderr}" 73 | assert len(stdout) > 10, "Response too short" 74 | 75 | # Haikus typically have three lines 76 | lines = [line for line in stdout.split("\n") if line.strip()] 77 | assert 2 <= len(lines) <= 5, f"Response doesn't look like a haiku: {stdout}" 78 | 79 | @pytest.mark.user 80 | def test_workflow_schema_prompt(): 81 | """Test a user workflow with a schema prompt""" 82 | if not check_model_available(): 83 | pytest.skip("cerebras models not available") 84 | 85 | # Test a schema prompt 86 | stdout, stderr, returncode = run_command(""" 87 | llm -m cerebras-llama3.1-8b --schema 'title, year int, director, genre' 'Suggest a sci-fi movie' 88 | """) 89 | assert returncode == 0, f"Command failed: {stderr}" 90 | 91 | # Try to parse as JSON 92 | try: 93 | data = json.loads(stdout) 94 | assert "title" in data, "Response missing title" 95 | assert "year" in data, "Response missing year" 96 | assert "director" in data, "Response missing director" 97 | assert "genre" in data, "Response missing genre" 98 | assert isinstance(data["year"], int), "Year is not an 
integer" 99 | except json.JSONDecodeError: 100 | pytest.fail(f"Response is not valid JSON: {stdout}") 101 | 102 | @pytest.mark.user 103 | def test_workflow_conversation(): 104 | """Test a conversational workflow with follow-up questions""" 105 | if not check_model_available(): 106 | pytest.skip("cerebras models not available") 107 | 108 | # Create a temporary conversation file 109 | with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as f: 110 | conversation_file = f.name 111 | 112 | try: 113 | # First question 114 | cmd1 = f"llm -m cerebras-llama3.1-8b -c {conversation_file} 'What are the three laws of robotics?'" 115 | stdout1, stderr1, returncode1 = run_command(cmd1) 116 | assert returncode1 == 0, f"Command failed: {stderr1}" 117 | assert "law" in stdout1.lower() and "robot" in stdout1.lower(), "Response doesn't mention laws or robots" 118 | 119 | # Follow-up question 120 | cmd2 = f"llm -c {conversation_file} 'Who created these laws?'" 121 | stdout2, stderr2, returncode2 = run_command(cmd2) 122 | assert returncode2 == 0, f"Command failed: {stderr2}" 123 | assert "asimov" in stdout2.lower(), "Response doesn't mention Asimov" 124 | finally: 125 | # Clean up 126 | if os.path.exists(conversation_file): 127 | os.unlink(conversation_file) 128 | 129 | @pytest.mark.user 130 | def test_workflow_schema_template(): 131 | """Test creating and using a schema template""" 132 | if not check_model_available(): 133 | pytest.skip("cerebras models not available") 134 | 135 | # Create a schema template 136 | template_name = "test_movie_schema" 137 | 138 | # Remove template if it exists 139 | run_command(f"llm templates rm {template_name} 2>/dev/null || true") 140 | 141 | try: 142 | # Create template 143 | cmd1 = f""" 144 | llm -m cerebras-llama3.1-8b --schema ' 145 | title: the movie title 146 | year int: release year 147 | director: the director 148 | genre: the primary genre 149 | ' --system 'You are a helpful assistant that recommends movies' --save {template_name} 150 | """ 151 | stdout1, stderr1, returncode1 = run_command(cmd1) 152 | assert returncode1 == 0, f"Template creation failed: {stderr1}" 153 | 154 | # Check template exists 155 | cmd2 = f"llm templates show {template_name}" 156 | stdout2, stderr2, returncode2 = run_command(cmd2) 157 | assert returncode2 == 0, f"Template check failed: {stderr2}" 158 | assert "title" in stdout2, "Template doesn't contain expected schema" 159 | 160 | # Use template 161 | cmd3 = f"llm -m cerebras-llama3.1-8b -t {template_name} 'Suggest a comedy movie'" 162 | stdout3, stderr3, returncode3 = run_command(cmd3) 163 | assert returncode3 == 0, f"Template use failed: {stderr3}" 164 | 165 | # Try to parse as JSON 166 | try: 167 | data = json.loads(stdout3) 168 | assert "title" in data, "Response missing title" 169 | assert "year" in data, "Response missing year" 170 | assert "director" in data, "Response missing director" 171 | assert "genre" in data, "Response missing genre" 172 | assert isinstance(data["year"], int), "Year is not an integer" 173 | assert data["genre"].lower() == "comedy", f"Genre is not comedy: {data['genre']}" 174 | except json.JSONDecodeError: 175 | pytest.fail(f"Response is not valid JSON: {stdout3}") 176 | finally: 177 | # Clean up template 178 | run_command(f"llm templates rm {template_name} 2>/dev/null || true") 179 | -------------------------------------------------------------------------------- /tests/test_cerebras.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 
from llm_cerebras.cerebras import CerebrasModel 3 | from unittest.mock import patch, MagicMock 4 | 5 | @pytest.fixture 6 | def cerebras_model(): 7 | return CerebrasModel("cerebras-llama3.1-8b") # Use the full model ID with prefix 8 | 9 | def test_cerebras_model_initialization(cerebras_model): 10 | assert cerebras_model.model_id == "cerebras-llama3.1-8b" 11 | assert cerebras_model.can_stream == True 12 | assert cerebras_model.api_base == "https://api.cerebras.ai/v1" 13 | 14 | def test_build_messages(cerebras_model): 15 | prompt = MagicMock() 16 | prompt.prompt = "Test prompt" 17 | conversation = None 18 | messages = cerebras_model._build_messages(prompt, conversation) 19 | assert len(messages) == 1 20 | assert messages[0] == {"role": "user", "content": "Test prompt"} 21 | 22 | @patch('llm_cerebras.cerebras.httpx.post') 23 | @patch('llm_cerebras.cerebras.llm.get_key') 24 | def test_execute_non_streaming(mock_get_key, mock_post, cerebras_model): 25 | mock_get_key.return_value = "fake-api-key" 26 | mock_response = MagicMock() 27 | mock_response.raise_for_status.return_value = None 28 | mock_response.json.return_value = { 29 | "choices": [{"message": {"content": "Test response"}}] 30 | } 31 | mock_post.return_value = mock_response 32 | 33 | prompt = MagicMock() 34 | prompt.prompt = "Test prompt" 35 | prompt.options.temperature = 0.7 36 | prompt.options.max_tokens = None 37 | prompt.options.top_p = 1 38 | prompt.options.seed = None 39 | 40 | # Make sure prompt doesn't have a schema attribute 41 | type(prompt).schema = None 42 | 43 | response = MagicMock() 44 | conversation = None 45 | 46 | result = list(cerebras_model.execute(prompt, False, response, conversation)) 47 | 48 | assert result == ["Test response"] 49 | mock_post.assert_called_once() 50 | 51 | if __name__ == "__main__": 52 | pytest.main() 53 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration tests for llm-cerebras plugin. 3 | These tests ensure that the plugin works correctly with the actual llm CLI. 4 | NOTE: These tests require a valid CEREBRAS_API_KEY env variable to be set 5 | and will make actual API calls to Cerebras. 6 | """ 7 | 8 | import os 9 | import pytest 10 | import json 11 | import subprocess 12 | import tempfile 13 | from pathlib import Path 14 | 15 | # Skip all tests if no API key is available 16 | pytestmark = pytest.mark.skipif( 17 | os.environ.get("CEREBRAS_API_KEY") is None, 18 | reason="CEREBRAS_API_KEY not set in environment variables" 19 | ) 20 | 21 | def run_llm_command(cmd_args): 22 | """Run an llm command and return the output""" 23 | result = subprocess.run( 24 | ["llm"] + cmd_args, 25 | capture_output=True, 26 | text=True, 27 | check=False 28 | ) 29 | return result.stdout.strip(), result.stderr.strip(), result.returncode 30 | 31 | @pytest.mark.integration 32 | def test_basic_completion(): 33 | """Test that a basic completion works""" 34 | stdout, stderr, returncode = run_llm_command([ 35 | "-m", "cerebras-llama3.1-8b", 36 | "Hello, how are you?" 
37 | ]) 38 | assert returncode == 0, f"Command failed with stderr: {stderr}" 39 | assert stdout, "No output returned" 40 | assert len(stdout) > 10, "Output too short to be a valid response" 41 | 42 | @pytest.mark.integration 43 | def test_schema_basic(): 44 | """Test that a basic schema completion works""" 45 | stdout, stderr, returncode = run_llm_command([ 46 | "-m", "cerebras-llama3.1-8b", 47 | "--schema", "name, age int", 48 | "Generate information about a fictional person" 49 | ]) 50 | assert returncode == 0, f"Command failed with stderr: {stderr}" 51 | 52 | # Attempt to parse as JSON 53 | try: 54 | data = json.loads(stdout) 55 | assert "name" in data, "Response missing 'name' field" 56 | assert "age" in data, "Response missing 'age' field" 57 | assert isinstance(data["name"], str), "'name' field is not a string" 58 | assert isinstance(data["age"], int), "'age' field is not an integer" 59 | except json.JSONDecodeError: 60 | pytest.fail(f"Response is not valid JSON: {stdout}") 61 | 62 | @pytest.mark.integration 63 | def test_schema_multi(): 64 | """Test that a schema-multi completion works""" 65 | stdout, stderr, returncode = run_llm_command([ 66 | "-m", "cerebras-llama3.1-8b", 67 | "--schema-multi", "name, age int", 68 | "Generate information about 2 fictional people" 69 | ]) 70 | assert returncode == 0, f"Command failed with stderr: {stderr}" 71 | 72 | # Attempt to parse as JSON 73 | try: 74 | data = json.loads(stdout) 75 | assert "items" in data, "Response missing 'items' array" 76 | assert isinstance(data["items"], list), "'items' is not an array" 77 | assert len(data["items"]) > 0, "'items' array is empty" 78 | 79 | # Check the first item 80 | first_item = data["items"][0] 81 | assert "name" in first_item, "First item missing 'name' field" 82 | assert "age" in first_item, "First item missing 'age' field" 83 | assert isinstance(first_item["name"], str), "'name' field is not a string" 84 | assert isinstance(first_item["age"], int), "'age' field is not an integer" 85 | except json.JSONDecodeError: 86 | pytest.fail(f"Response is not valid JSON: {stdout}") 87 | 88 | @pytest.mark.integration 89 | def test_complex_schema(): 90 | """Test a more complex schema with nested objects""" 91 | # Create a temporary file with a complex schema 92 | with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', delete=False) as f: 93 | schema_file = f.name 94 | json.dump({ 95 | "type": "object", 96 | "properties": { 97 | "person": { 98 | "type": "object", 99 | "properties": { 100 | "name": {"type": "string"}, 101 | "age": {"type": "integer"}, 102 | "hobbies": { 103 | "type": "array", 104 | "items": {"type": "string"} 105 | } 106 | }, 107 | "required": ["name", "age", "hobbies"] 108 | }, 109 | "location": { 110 | "type": "object", 111 | "properties": { 112 | "city": {"type": "string"}, 113 | "country": {"type": "string"} 114 | }, 115 | "required": ["city", "country"] 116 | } 117 | }, 118 | "required": ["person", "location"] 119 | }, f) 120 | 121 | try: 122 | stdout, stderr, returncode = run_llm_command([ 123 | "-m", "cerebras-llama3.1-8b", 124 | "--schema", schema_file, 125 | "Generate a fictional person with their location" 126 | ]) 127 | assert returncode == 0, f"Command failed with stderr: {stderr}" 128 | 129 | # Attempt to parse as JSON 130 | try: 131 | data = json.loads(stdout) 132 | assert "person" in data, "Response missing 'person' object" 133 | assert "location" in data, "Response missing 'location' object" 134 | 135 | # Check person 136 | assert "name" in data["person"], "Person missing 'name' 
field" 137 | assert "age" in data["person"], "Person missing 'age' field" 138 | assert "hobbies" in data["person"], "Person missing 'hobbies' field" 139 | assert isinstance(data["person"]["hobbies"], list), "'hobbies' is not an array" 140 | 141 | # Check location 142 | assert "city" in data["location"], "Location missing 'city' field" 143 | assert "country" in data["location"], "Location missing 'country' field" 144 | except json.JSONDecodeError: 145 | pytest.fail(f"Response is not valid JSON: {stdout}") 146 | finally: 147 | # Clean up the temporary file 148 | os.unlink(schema_file) 149 | 150 | @pytest.mark.integration 151 | def test_schema_with_description(): 152 | """Test schema with descriptions""" 153 | stdout, stderr, returncode = run_llm_command([ 154 | "-m", "cerebras-llama3.1-8b", 155 | "--schema", "name: full name including title, age int: age in years, bio: a short biography", 156 | "Generate information about a professor" 157 | ]) 158 | assert returncode == 0, f"Command failed with stderr: {stderr}" 159 | 160 | # Attempt to parse as JSON 161 | try: 162 | data = json.loads(stdout) 163 | assert "name" in data, "Response missing 'name' field" 164 | assert "age" in data, "Response missing 'age' field" 165 | assert "bio" in data, "Response missing 'bio' field" 166 | assert isinstance(data["name"], str), "'name' field is not a string" 167 | assert isinstance(data["age"], int), "'age' field is not an integer" 168 | assert isinstance(data["bio"], str), "'bio' field is not a string" 169 | 170 | # Check if name likely contains a title (Dr., Professor, etc.) 171 | assert any(title in data["name"] for title in ["Dr.", "Professor", "Prof."]), "Name doesn't contain title despite description" 172 | except json.JSONDecodeError: 173 | pytest.fail(f"Response is not valid JSON: {stdout}") 174 | -------------------------------------------------------------------------------- /tests/test_schema_support.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | import httpx 4 | from unittest.mock import patch, MagicMock 5 | from llm_cerebras.cerebras import CerebrasModel 6 | 7 | @pytest.fixture 8 | def cerebras_model(): 9 | return CerebrasModel("cerebras-llama3.3-70b") 10 | 11 | def test_schema_flag_enabled(cerebras_model): 12 | """Test that schema support is enabled""" 13 | assert cerebras_model.supports_schema == True 14 | 15 | def test_process_schema_dict(cerebras_model): 16 | """Test processing a schema that's already a dict""" 17 | schema = { 18 | "type": "object", 19 | "properties": { 20 | "name": {"type": "string"}, 21 | "age": {"type": "integer"} 22 | }, 23 | "required": ["name", "age"] 24 | } 25 | processed = cerebras_model._process_schema(schema) 26 | assert processed == schema 27 | 28 | def test_process_schema_json_string(cerebras_model): 29 | """Test processing a schema that's a JSON string""" 30 | schema = json.dumps({ 31 | "type": "object", 32 | "properties": { 33 | "name": {"type": "string"}, 34 | "age": {"type": "integer"} 35 | }, 36 | "required": ["name", "age"] 37 | }) 38 | processed = cerebras_model._process_schema(schema) 39 | assert processed["type"] == "object" 40 | assert "name" in processed["properties"] 41 | assert processed["properties"]["name"]["type"] == "string" 42 | assert "age" in processed["properties"] 43 | assert processed["properties"]["age"]["type"] == "integer" 44 | 45 | def test_process_schema_concise(cerebras_model): 46 | """Test processing LLM's concise schema format""" 47 | schema = "name, age 
int, bio" 48 | processed = cerebras_model._process_schema(schema) 49 | assert processed["type"] == "object" 50 | assert "name" in processed["properties"] 51 | assert processed["properties"]["name"]["type"] == "string" 52 | assert "age" in processed["properties"] 53 | assert processed["properties"]["age"]["type"] == "integer" 54 | assert "bio" in processed["properties"] 55 | assert processed["properties"]["bio"]["type"] == "string" 56 | assert "name" in processed["required"] 57 | assert "age" in processed["required"] 58 | assert "bio" in processed["required"] 59 | 60 | def test_process_schema_concise_with_description(cerebras_model): 61 | """Test processing LLM's concise schema format with descriptions""" 62 | schema = "name: the person's name, age int: their age in years" 63 | processed = cerebras_model._process_schema(schema) 64 | assert processed["properties"]["name"]["description"] == "the person's name" 65 | assert processed["properties"]["age"]["description"] == "their age in years" 66 | 67 | def test_process_schema_concise_newlines(cerebras_model): 68 | """Test processing LLM's concise schema format with newlines""" 69 | schema = """ 70 | name: the person's name 71 | age int: their age in years 72 | bio: a short biography 73 | """ 74 | processed = cerebras_model._process_schema(schema) 75 | assert "name" in processed["properties"] 76 | assert "age" in processed["properties"] 77 | assert "bio" in processed["properties"] 78 | assert processed["properties"]["name"]["description"] == "the person's name" 79 | 80 | def test_build_schema_instructions(cerebras_model): 81 | """Test building schema instructions for the model""" 82 | schema = { 83 | "type": "object", 84 | "properties": { 85 | "name": {"type": "string", "description": "The person's name"}, 86 | "age": {"type": "integer"} 87 | }, 88 | "required": ["name", "age"] 89 | } 90 | instructions = cerebras_model._build_schema_instructions(schema) 91 | assert "You are a helpful assistant" in instructions 92 | assert "Your response must follow this schema" in instructions 93 | assert "name" in instructions 94 | assert "age" in instructions 95 | assert "required" in instructions 96 | assert "The person's name" in instructions 97 | 98 | @patch('llm_cerebras.cerebras.httpx.post') 99 | @patch('llm_cerebras.cerebras.llm.get_key') 100 | def test_execute_with_schema_json_object(mock_get_key, mock_post, cerebras_model): 101 | """Test execution with schema using json_object""" 102 | # Setup mocks 103 | mock_get_key.return_value = "fake-api-key" 104 | 105 | # Configure mock for a successful json_object response 106 | mock_response = MagicMock() 107 | mock_response.raise_for_status.return_value = None 108 | mock_response.json.return_value = { 109 | "choices": [{"message": {"content": '{"name": "Alice", "age": 30}'}}] 110 | } 111 | mock_post.return_value = mock_response 112 | 113 | # Setup prompt with schema 114 | prompt = MagicMock() 115 | prompt.prompt = "Generate a person" 116 | prompt.schema = {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, "required": ["name", "age"]} 117 | prompt.options.temperature = 0.7 118 | prompt.options.max_tokens = None 119 | prompt.options.top_p = 1 120 | prompt.options.seed = None 121 | 122 | # Execute 123 | response = MagicMock() 124 | conversation = None 125 | 126 | # Patch the method to simplify the test 127 | with patch.object(cerebras_model, '_build_messages', return_value=[{"role": "user", "content": "Generate a person"}]): 128 | result = 
list(cerebras_model.execute(prompt, False, response, conversation)) 129 | 130 | # Verify 131 | assert len(result) == 1 132 | assert json.loads(result[0]) == {"name": "Alice", "age": 30} 133 | 134 | # Check that the request was made with json_object 135 | assert mock_post.call_count == 1 136 | call_args = mock_post.call_args[1] 137 | assert "response_format" in call_args["json"] 138 | assert call_args["json"]["response_format"] == {"type": "json_object"} 139 | 140 | # Verify system message was added with schema instructions 141 | messages = call_args["json"]["messages"] 142 | assert len(messages) > 1 # Should have user message + system message 143 | assert messages[0]["role"] == "system" 144 | assert "Your response must follow this schema" in messages[0]["content"] 145 | 146 | @patch('llm_cerebras.cerebras.httpx.post') 147 | @patch('llm_cerebras.cerebras.llm.get_key') 148 | def test_validate_schema_success(mock_get_key, mock_post, cerebras_model): 149 | """Test schema validation success""" 150 | # Setup mocks 151 | mock_get_key.return_value = "fake-api-key" 152 | mock_response = MagicMock() 153 | mock_response.raise_for_status.return_value = None 154 | mock_response.json.return_value = { 155 | "choices": [{"message": {"content": '{"name": "Alice", "age": 30}'}}] 156 | } 157 | mock_post.return_value = mock_response 158 | 159 | # Setup prompt with schema 160 | prompt = MagicMock() 161 | prompt.prompt = "Generate a person" 162 | prompt.schema = {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, "required": ["name", "age"]} 163 | prompt.options.temperature = 0.7 164 | prompt.options.max_tokens = None 165 | prompt.options.top_p = 1 166 | prompt.options.seed = None 167 | 168 | # Execute 169 | response = MagicMock() 170 | conversation = None 171 | 172 | # Patch the method to simplify the test 173 | with patch.object(cerebras_model, '_build_messages', return_value=[{"role": "user", "content": "Generate a person"}]): 174 | result = list(cerebras_model.execute(prompt, False, response, conversation)) 175 | 176 | # Verify 177 | assert len(result) == 1 178 | assert json.loads(result[0]) == {"name": "Alice", "age": 30} 179 | 180 | @patch('llm_cerebras.cerebras.httpx.post') 181 | @patch('llm_cerebras.cerebras.llm.get_key') 182 | def test_execute_with_concise_schema(mock_get_key, mock_post, cerebras_model): 183 | """Test execution with concise schema format""" 184 | # Setup mocks 185 | mock_get_key.return_value = "fake-api-key" 186 | 187 | # Second call succeeds with json_object 188 | mock_response = MagicMock() 189 | mock_response.raise_for_status.return_value = None 190 | mock_response.json.return_value = { 191 | "choices": [{"message": {"content": '{"name": "Bob", "age": 25, "bio": "A software developer"}'}}] 192 | } 193 | mock_post.return_value = mock_response 194 | 195 | # Setup prompt with concise schema 196 | prompt = MagicMock() 197 | prompt.prompt = "Generate a person" 198 | prompt.schema = "name, age int, bio: a short bio" 199 | prompt.options.temperature = 0.7 200 | prompt.options.max_tokens = None 201 | prompt.options.top_p = 1 202 | prompt.options.seed = None 203 | 204 | # Execute 205 | response = MagicMock() 206 | conversation = None 207 | 208 | # Patch the method to simplify the test 209 | with patch.object(cerebras_model, '_build_messages', return_value=[{"role": "user", "content": "Generate a person"}]): 210 | result = list(cerebras_model.execute(prompt, False, response, conversation)) 211 | 212 | # Verify 213 | assert len(result) == 1 214 | parsed = 
json.loads(result[0]) 215 | assert "name" in parsed 216 | assert "age" in parsed 217 | assert "bio" in parsed 218 | assert isinstance(parsed["age"], int) 219 | --------------------------------------------------------------------------------
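
The plugin code that these tests exercise lives in `llm_cerebras/cerebras.py` and is not reproduced in this dump. For orientation, below is a minimal sketch of how the three behaviours pinned down by `tests/test_schema_support.py` could fit together: parsing llm's concise schema format into JSON Schema, turning that schema into system-message instructions, and sending the request in `json_object` mode. Only the method names referenced by the tests (`_process_schema`, `_build_schema_instructions`) and the API base URL asserted in `tests/test_cerebras.py` are taken from the files above; everything else (standalone function signatures, the `/chat/completions` path, error handling, exact instruction wording) is an assumption for illustration, not the plugin's actual implementation.

```python
import json
import httpx

API_BASE = "https://api.cerebras.ai/v1"  # value asserted in tests/test_cerebras.py

TYPE_MAP = {"int": "integer", "float": "number", "bool": "boolean", "str": "string"}


def process_schema(schema):
    """Normalise a schema argument into a JSON Schema dict (sketch of _process_schema).

    Accepts a dict, a JSON string, or llm's concise format such as
    "name, age int, bio: a short bio" (comma- or newline-separated).
    Descriptions containing commas are not handled by this sketch.
    """
    if isinstance(schema, dict):
        return schema
    text = schema.strip()
    if text.startswith("{"):
        return json.loads(text)
    properties, required = {}, []
    entries = [e.strip() for e in text.replace("\n", ",").split(",") if e.strip()]
    for entry in entries:
        field, _, description = (part.strip() for part in entry.partition(":"))
        name, _, type_hint = field.partition(" ")
        prop = {"type": TYPE_MAP.get(type_hint.strip(), "string")}
        if description:
            prop["description"] = description
        properties[name] = prop
        required.append(name)
    return {"type": "object", "properties": properties, "required": required}


def build_schema_instructions(schema):
    """Render the schema as a system message (sketch of _build_schema_instructions)."""
    return (
        "You are a helpful assistant that replies with valid JSON only.\n"
        "Your response must follow this schema:\n"
        + json.dumps(schema, indent=2)
    )


def complete_with_schema(api_key, model, user_prompt, schema):
    """Call the chat completions endpoint using the json_object workaround."""
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": build_schema_instructions(process_schema(schema))},
            {"role": "user", "content": user_prompt},
        ],
        # json_object mode plus schema instructions in the system message,
        # as described in the schema test summary above.
        "response_format": {"type": "json_object"},
    }
    resp = httpx.post(
        f"{API_BASE}/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json=payload,
        timeout=60.0,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
```

Under these assumptions, `complete_with_schema(key, "llama-3.3-70b", "Generate a fictional person", "name, age int, bio")` would return a JSON string shaped like the Basic Schema result shown in the summary above.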