├── .github └── workflows │ └── release.yml ├── .gitignore ├── README.md ├── cursive ├── __init__.py ├── assets │ └── price │ │ ├── anthropic.py │ │ ├── cohere.py │ │ └── openai.py ├── build_input.py ├── compat │ └── pydantic.py ├── cursive.py ├── custom_function_call.py ├── function.py ├── hookable.py ├── model.py ├── pricing.py ├── stream.py ├── tests │ ├── test_function_compatibility.py │ ├── test_function_schema.py │ └── test_setup.py ├── types.py ├── usage │ ├── anthropic.py │ ├── cohere.py │ └── openai.py ├── utils.py └── vendor │ ├── anthropic.py │ ├── cohere.py │ ├── index.py │ ├── openai.py │ └── replicate.py ├── docs ├── logo-dark.svg └── logo-light.svg ├── examples ├── add-function.py ├── ai-cli.py ├── compare-embeddings.py └── generate-list-of-objects.py ├── poetry.lock ├── pyproject.toml └── tox.ini /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | permissions: 4 | contents: write 5 | 6 | on: 7 | push: 8 | tags: 9 | - 'v*' 10 | 11 | jobs: 12 | release: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | with: 17 | fetch-depth: 0 18 | 19 | # Only Node/npx is needed to run changelogithub; this Python repo has no package.json or pnpm lockfile, so no pnpm setup or pnpm cache. 20 | - name: Set node 21 | uses: actions/setup-node@v3 22 | with: 23 | node-version: 18.x 24 | 25 | - run: npx changelogithub 26 | env: 27 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/*/test.py 3 | dist 4 | **/*/__pycache__ 5 | playground.py 6 | .pytest_cache 7 | ipython.py 8 | .tox -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Logo](/docs/logo-dark.svg#gh-dark-mode-only) 2 |
![Logo](/docs/logo-light.svg#gh-light-mode-only) 3 | 4 | Cursive is a universal and intuitive framework for interacting with LLMs. 5 | 6 | ## highlights 7 | ✦ **Extensible** - You can easily hook into any part of a completion life cycle, whether to log, cache, or modify the results. 8 | 9 | ✦ **Functions** - Easily describe functions that the LLM can use, along with their definitions, with any model (currently supporting GPT-4, GPT-3.5, Claude 2, and Claude Instant) 10 | 11 | ✦ **Universal** - Cursive aims to bridge as many capabilities between different models as possible. Ultimately, this means that with a single interface, you can allow your users to choose any model. 12 | 13 | ✦ **Informative** - Cursive comes with built-in token usage and cost calculations, as accurate as possible. 14 | 15 | ✦ **Reliable** - Cursive comes with automatic retry and model expanding upon exceeding context length, which you can always configure. 16 | 17 | ## quickstart 18 | 1. Install. 19 | 20 | ```bash 21 | poetry add cursivepy 22 | # or 23 | pip install cursivepy 24 | ``` 25 | 26 | 2. Start using. 27 | 28 | ```python 29 | from cursive import Cursive 30 | 31 | cursive = Cursive() 32 | 33 | response = cursive.ask( 34 | prompt='What is the meaning of life?', 35 | ) 36 | 37 | print(response.answer) 38 | ``` 39 | 40 | ## usage 41 | ### Conversation 42 | Chaining a conversation is easy with `cursive`. You can pass any of the options you're used to with OpenAI's API. 43 | 44 | ```python 45 | res_a = cursive.ask( 46 | prompt='Give me a good name for a gecko.', 47 | model='gpt-4', 48 | max_tokens=16, 49 | ) 50 | 51 | print(res_a.answer) # Zephyr 52 | 53 | res_b = res_a.conversation.ask( 54 | prompt='How would you say it in Portuguese?' 55 | ) 56 | 57 | print(res_b.answer) # Zéfiro 58 | ``` 59 | ### Streaming 60 | Streaming is also supported, and we also keep track of the tokens for you!
61 | ```python 62 | result = cursive.ask( 63 | prompt='Count to 10', 64 | stream=True, 65 | on_token=lambda partial: print(partial['content']) 66 | ) 67 | 68 | print(result.usage.total_tokens) # 40 69 | ``` 70 | 71 | ### Functions 72 | 73 | You can easily define and describe functions, alongside their execution code. 74 | 75 | ```python 76 | from cursive import cursive_function, Cursive 77 | 78 | cursive = Cursive() 79 | 80 | @cursive_function() 81 | def add(a: float, b: float): 82 | """ 83 | Adds two numbers. 84 | 85 | a: The first number. 86 | b: The second number. 87 | """ 88 | return a + b 89 | 90 | res = cursive.ask( 91 | prompt='What is the sum of 232 and 243?', 92 | functions=[add], 93 | ) 94 | 95 | print(res.answer) # The sum of 232 and 243 is 475. 96 | ``` 97 | 98 | The function's result will automatically be fed back into the conversation, and another completion will be made. If you want to prevent this, pass `pause=True` to the `cursive_function` decorator. 99 | 100 | ```python 101 | 102 | @cursive_function(pause=True) 103 | def create_character(name: str, age: str): 104 | """ 105 | Creates a character. 106 | 107 | name: The name of the character. 108 | age: The age of the character. 109 | """ 110 | return { 111 | 'name': name, 112 | 'age': age, 113 | } 114 | 115 | res = cursive.ask( 116 | prompt='Create a character named John who is 23 years old.', 117 | functions=[create_character], 118 | ) 119 | 120 | print(res.function_result) # { name: 'John', age: 23 } 121 | ``` 122 | 123 | Cursive also supports passing in undecorated functions!
124 | 125 | ```python 126 | def add(a: float, b: float): 127 | return a + b 128 | 129 | res = cursive.ask( 130 | prompt='What is the sum of 232 and 243?', 131 | functions=[add], # this is equivalent to cursive_function(pause=True)(add) 132 | ) 133 | if res.function_result: 134 | print(res.function_result) # 475 135 | else: 136 | print(res.answer) # Text answer in case the function is not called 137 | ``` 138 | 139 | ### Models 140 | 141 | Cursive also supports the generation of Pydantic BaseModels. 142 | 143 | ```python 144 | from cursive.compat.pydantic import BaseModel, Field # Pydantic V1 API 145 | 146 | class Character(BaseModel): 147 | name: str 148 | age: int 149 | skills: list[str] = Field(min_items=2) 150 | 151 | res = cursive.ask( 152 | prompt='Create a character named John who is 23 years old.', 153 | function_call=Character, 154 | ) 155 | res.function_result # is a Character instance with autogenerated fields 156 | ``` 157 | 158 | ### Hooks 159 | 160 | You can hook into any part of the completion life cycle. 161 | 162 | ```python 163 | cursive.on('completion:after', lambda result: print( 164 | result.data.cost.total, 165 | result.data.usage.total_tokens, 166 | )) 167 | 168 | cursive.on('completion:error', lambda result: print( 169 | result.error, 170 | )) 171 | 172 | cursive.ask( 173 | prompt='Can androids dream of electric sheep?', 174 | ) 175 | 176 | # 0.0002185 177 | # 113 178 | ``` 179 | 180 | ### Embedding 181 | You can create embeddings easily with `cursive`. 182 | ```python 183 | embedding = cursive.embed('This should be a document.') 184 | ``` 185 | Support for more document types and integrations is planned. 186 | 187 | ### Reliability 188 | Cursive comes with automatic retry with backoff upon failing completions, and model expanding upon exceeding context length -- which means that it tries again with a model with a bigger context length when it fails by running out of it.
189 | 190 | You can configure this behavior by passing the `max_retries` and `expand` options to the `Cursive` constructor. 191 | 192 | ```python 193 | cursive = Cursive( 194 | max_retries=5, # 0 disables it completely 195 | expand={ 196 | 'enable': True, 197 | 'defaults_to': 'gpt-3.5-turbo-16k', 198 | 'resolve_model': { 199 | 'gpt-3.5-turbo': 'gpt-3.5-turbo-16k', 200 | 'gpt-4': 'claude-2', 201 | }, 202 | }, 203 | ) 204 | ``` 205 | 206 | ## Available Models 207 |
208 | OpenAI models 209 | 210 | - `gpt-3.5-turbo` 211 | - `gpt-3.5-turbo-16k` 212 | - `gpt-4` 213 | - `gpt-4-32k` 214 | - Any other chat completion model version 215 | 216 | ###### Credentials 217 | You can pass your OpenAI API key to `Cursive`'s constructor, or set the `OPENAI_API_KEY` environment variable. 218 |
219 | 220 |
221 | Anthropic models 222 | 223 | - `claude-2` 224 | - `claude-instant-1` 225 | - `claude-instant-1.2` 226 | - Any other model version 227 | 228 | ###### Credentials 229 | You can pass your Anthropic API key to `Cursive`'s constructor, or set the `ANTHROPIC_API_KEY` environment variable. 230 |
231 | 232 |
233 | OpenRouter models 234 | 235 | OpenRouter is a service that gives you access to leading language models in an OpenAI-compatible API, including function calling! 236 | 237 | - `anthropic/claude-instant-1.2` 238 | - `anthropic/claude-2` 239 | - `openai/gpt-4-32k` 240 | - `google/palm-2-codechat-bison` 241 | - `nousresearch/nous-hermes-llama2-13b` 242 | - Any model version from https://openrouter.ai/docs#models 243 | 244 | ###### Credentials 245 | 246 | ```python 247 | from cursive import Cursive 248 | 249 | cursive = Cursive( 250 | openrouter={ 251 | "api_key": "sk-or-...", 252 | "app_title": "Your App Name", 253 | "app_url": "https://appurl.com", 254 | } 255 | ) 256 | 257 | cursive.ask( 258 | model="anthropic/claude-instant-1.2", 259 | prompt="What is the meaning of life?" 260 | ) 261 | ``` 262 |
263 | 264 | 265 |
266 | Cohere models 267 | 268 | - `command` 269 | - Any other model version (such as `command-nightly`) 270 | 271 | ###### Credentials 272 | You can pass your Cohere API key to `Cursive`'s constructor, or set the `CO_API_KEY` environment variable. 273 | 274 |
275 | 276 |
277 | Replicate models 278 | You can prepend `replicate/` to any model name and version available on Replicate. 279 | 280 | ###### Example 281 | ```python 282 | cursive.ask( 283 | prompt='What is the meaning of life?', 284 | model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', 285 | ) 286 | ``` 287 | 288 | ###### Credentials 289 | You can pass your Replicate API key to `Cursive`'s constructor, or set the `REPLICATE_API_TOKEN` environment variable. 290 | 291 |
292 | 293 | ## roadmap 294 | 295 | ### vendor support 296 | - [x] Anthropic 297 | - [x] Cohere 298 | - [x] Replicate 299 | - [x] OpenRouter 300 | - [ ] Azure OpenAI models 301 | - [ ] Huggingface 302 | -------------------------------------------------------------------------------- /cursive/__init__.py: -------------------------------------------------------------------------------- 1 | from .cursive import Cursive 2 | from .function import cursive_function 3 | from .types import ( 4 | CompletionPayload, 5 | CursiveError, 6 | CursiveErrorCode, 7 | CursiveEnrichedAnswer, 8 | CompletionMessage, 9 | CursiveFunction, 10 | ) 11 | 12 | __all__ = [ 13 | "Cursive", 14 | "cursive_function", 15 | "CompletionPayload", 16 | "CursiveError", 17 | "CursiveErrorCode", 18 | "CursiveEnrichedAnswer", 19 | "CompletionMessage", 20 | "CursiveFunction", 21 | ] 22 | -------------------------------------------------------------------------------- /cursive/assets/price/anthropic.py: -------------------------------------------------------------------------------- 1 | ANTHROPIC_PRICING = { # NOTE(review): figures look like USD per 1K tokens — confirm against cursive/pricing.py 2 | "version": "2023-07-11", 3 | "claude-instant": { 4 | "completion": 0.00551, 5 | "prompt": 0.00163 6 | }, 7 | "claude-2": { 8 | "completion": 0.03268, 9 | "prompt": 0.01102 10 | }, 11 | } -------------------------------------------------------------------------------- /cursive/assets/price/cohere.py: -------------------------------------------------------------------------------- 1 | COHERE_PRICING = { # NOTE(review): figures look like USD per 1K tokens — confirm against cursive/pricing.py 2 | "version": "2023-07-11", 3 | "command": { 4 | "completion": 0.015, 5 | "prompt": 0.015 6 | }, 7 | "command-nightly": { 8 | "completion": 0.015, 9 | "prompt": 0.015 10 | }, 11 | } -------------------------------------------------------------------------------- /cursive/assets/price/openai.py: -------------------------------------------------------------------------------- 1 | OPENAI_PRICING = { # NOTE(review): figures look like USD per 1K tokens — confirm against cursive/pricing.py 2 | "version": "2023-06-13", 3 | "gpt-4": { 4 | "completion": 0.06, 5 | "prompt": 0.03 6 | }, 7 | "gpt-4-32k":
{ 8 | "completion": 0.12, 9 | "prompt": 0.06 10 | }, 11 | "gpt-3.5-turbo": { 12 | "completion": 0.002, 13 | "prompt": 0.0015 14 | }, 15 | "gpt-3.5-turbo-16k": { 16 | "completion": 0.004, 17 | "prompt": 0.003 18 | }, 19 | } -------------------------------------------------------------------------------- /cursive/build_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | from textwrap import dedent 3 | from cursive.types import CompletionMessage 4 | from cursive.function import CursiveFunction 5 | 6 | 7 | def build_completion_input(messages: list[CompletionMessage]): 8 | """ 9 | Builds a completion-style prompt string from a list of chat messages: each message becomes a "Human:"/"Assistant:" turn and turns are separated by a blank line. System messages become a Human turn followed by "Assistant: Ok.", function results become Human turns, assistant function calls are serialized as JSON, and an assistant turn with content " " is appended so the prompt ends on an Assistant turn. Any other role raises KeyError. 10 | """ 11 | role_mapping = {"user": "Human", "assistant": "Assistant"} 12 | messages_with_prefix: list[CompletionMessage] = [ 13 | *messages, # type: ignore 14 | CompletionMessage( 15 | role="assistant", 16 | content=" ", 17 | ), 18 | ] 19 | 20 | def resolve_message(message: CompletionMessage): 21 | if message.role == "system": 22 | return "\n".join( 23 | [ 24 | "Human:", 25 | message.content or "", 26 | "\nAssistant: Ok.", 27 | ] 28 | ) 29 | if message.role == "function": 30 | return "\n".join( 31 | [ 32 | f'Human: ', # NOTE(review): f-string with no placeholders, and the closing entry below is an empty string — wrapper text naming the function/result may be missing here; confirm against the repository 33 | message.content or "", 34 | "", 35 | ] 36 | ) 37 | if message.function_call: 38 | arguments = message.function_call.arguments 39 | if isinstance(arguments, str): 40 | arguments_str = arguments 41 | else: 42 | arguments_str = json.dumps(arguments) 43 | return "\n".join( 44 | [ 45 | "Assistant: ", 46 | json.dumps( 47 | { 48 | "name": message.function_call.name, 49 | "arguments": arguments_str, 50 | } 51 | ), 52 | "", 53 | ] 54 | ) 55 | return f"{role_mapping[message.role]}: {message.content}" 56 | 57 | completion_input = "\n\n".join(list(map(resolve_message, messages_with_prefix))) 58 | return completion_input 59 | 60 | 61 | def get_function_call_directives(functions: list[CursiveFunction]) -> str: # Builds the prompt that tells the model how to emit function calls as JSON and lists every function's JSON schema. 62 | return dedent( 63 | f"""\ 64 | # Function Calling Guide 65
| You're a powerful language model capable of calling functions to do anything the user needs. 66 | 67 | If you need to call a function, you output the name and arguments of the function you want to use in the following format: 68 | 69 | 70 | {'{'}"name": "function_name", "arguments": {'{'}"argument_name": "argument_value"{'}'}{'}'} 71 | 72 | ALWAYS use this format, even if the function doesn't have arguments. The arguments property is always a dictionary. 73 | Never forget to pass the `name` and `arguments` property when doing a function call. 74 | 75 | Think step by step before answering, and try to think out loud. Never output a function call if you don't have to. 76 | If you don't have a function to call, just output the text as usual inside a tag with newlines inside. 77 | Always question yourself if you have access to a function. 78 | Always think out loud before answering; if I don't see a block, you will be eliminated. 79 | When thinking out loud, always use the tag. 80 | # Functions available: 81 | 82 | {json.dumps([f.function_schema for f in functions])} 83 | 84 | # Working with results 85 | You can either call a function or answer, *NEVER BOTH*. 86 | You are not in charge of resolving the function call, the user is. 87 | The human will give you the result of the function call in the following format: 88 | 89 | Human: 90 | {'{'}result{'}'} 91 | 92 | 93 | ## Important note 94 | Never output a , or you will be eliminated. 95 | 96 | You can use the result of the function call in your answer. But never answer and call a function at the same time. 97 | When answering never be explicit about the function calling, just use the result of the function call in your answer.
98 | 99 | """ 100 | ) 101 | -------------------------------------------------------------------------------- /cursive/compat/pydantic.py: -------------------------------------------------------------------------------- 1 | try: # Prefer the v1-compatible `pydantic.v1` namespace (bundled with Pydantic 2); fall back to a plain Pydantic 1 install. 2 | from pydantic.v1 import BaseModel, Field, validate_arguments 3 | except ImportError: 4 | from pydantic import BaseModel, Field, validate_arguments 5 | 6 | __all__ = ["BaseModel", "Field", "validate_arguments"] 7 | -------------------------------------------------------------------------------- /cursive/cursive.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import asyncio 3 | import inspect 4 | import json 5 | from time import time, sleep 6 | from typing import Any, Callable, Generic, Optional, TypeVar 7 | from os import environ as env 8 | 9 | from anthropic import APIError 10 | import openai as openai_client 11 | import requests 12 | 13 | from cursive.build_input import get_function_call_directives 14 | from cursive.compat.pydantic import BaseModel as PydanticBaseModel 15 | from cursive.custom_function_call import parse_custom_function_call 16 | from cursive.function import CursiveCustomFunction, CursiveFunction 17 | from cursive.model import CursiveModel 18 | from cursive.stream import StreamTransformer 19 | from cursive.usage.cohere import get_cohere_usage 20 | from cursive.vendor.cohere import CohereClient, process_cohere_stream 21 | from cursive.types import ( 22 | BaseModel, 23 | CompletionMessage, 24 | CompletionPayload, 25 | CreateChatCompletionResponseExtended, 26 | CursiveAskCost, 27 | CursiveAskModelResponse, 28 | CursiveAskOnToken, 29 | CursiveAskUsage, 30 | CursiveEnrichedAnswer, 31 | CursiveError, 32 | CursiveErrorCode, 33 | CursiveHook, 34 | CursiveHookPayload, 35 | CursiveSetupOptions, 36 | CursiveSetupOptionsExpand, 37 | CursiveLanguageModel, 38 | ) 39 | from cursive.hookable import create_debugger, create_hooks 40 | from cursive.pricing import resolve_pricing 41
| from cursive.usage.anthropic import get_anthropic_usage 42 | from cursive.usage.openai import get_openai_usage 43 | from cursive.utils import delete_keys_from_dict, without_nones, random_id, resguard 44 | from cursive.vendor.anthropic import ( 45 | AnthropicClient, 46 | process_anthropic_stream, 47 | ) 48 | from cursive.vendor.index import resolve_vendor_from_model 49 | from cursive.vendor.openai import process_openai_stream 50 | from cursive.vendor.replicate import ReplicateClient 51 | 52 | 53 | # TODO: Improve implementation architecture, this was a quick and dirty 54 | class Cursive: 55 | options: CursiveSetupOptions 56 | 57 | def __init__( 58 | self, 59 | max_retries: Optional[int] = None, 60 | expand: Optional[dict[str, Any]] = None, 61 | debug: Optional[bool] = None, 62 | openai: Optional[dict[str, Any]] = None, 63 | anthropic: Optional[dict[str, Any]] = None, 64 | cohere: Optional[dict[str, Any]] = None, 65 | replicate: Optional[dict[str, Any]] = None, 66 | openrouter: Optional[dict[str, Any]] = None, 67 | ): 68 | self._hooks = create_hooks() 69 | self.options = CursiveSetupOptions( 70 | max_retries=max_retries, 71 | expand=expand, 72 | debug=debug, 73 | ) 74 | if debug: 75 | self._debugger = create_debugger(self._hooks, {"tag": "cursive"}) 76 | 77 | openai_client.api_key = (openai or {}).get("api_key") or env.get( 78 | "OPENAI_API_KEY" 79 | ) 80 | anthropic_client = AnthropicClient( 81 | (anthropic or {}).get("api_key") or env.get("ANTHROPIC_API_KEY") 82 | ) 83 | cohere_client = CohereClient( 84 | (cohere or {}).get("api_key") or env.get("CO_API_KEY", "---") 85 | ) 86 | replicate_client = ReplicateClient( 87 | (replicate or {}).get("api_key") or env.get("REPLICATE_API_TOKEN", "---") 88 | ) 89 | 90 | openrouter_api_key = (openrouter or {}).get("api_key") or env.get( 91 | "OPENROUTER_API_KEY" 92 | ) 93 | 94 | if openrouter_api_key: 95 | openai_client.api_base = "https://openrouter.ai/api/v1" 96 | openai_client.api_key = openrouter_api_key 97 | 
self.options.is_using_openrouter = True 98 | session = requests.Session() 99 | session.headers.update( 100 | { 101 | "HTTP-Referer": openrouter.get( 102 | "app_url", "https://cursive.meistrari.com" 103 | ), 104 | "X-Title": openrouter.get("app_title", "Cursive"), 105 | } 106 | ) 107 | openai_client.requestssession = session 108 | atexit.register(session.close) 109 | 110 | self._vendor = CursiveVendors( 111 | openai=openai_client, 112 | anthropic=anthropic_client, 113 | cohere=cohere_client, 114 | replicate=replicate_client, 115 | ) 116 | 117 | def on(self, event: CursiveHook, callback: Callable): 118 | self._hooks.hook(event, callback) 119 | 120 | def ask( 121 | self, 122 | model: Optional[str | CursiveLanguageModel] = None, 123 | system_message: Optional[str] = None, 124 | functions: Optional[list[Callable]] = None, 125 | function_call: Optional[str | Callable] = None, 126 | on_token: Optional[CursiveAskOnToken] = None, 127 | max_tokens: Optional[int] = None, 128 | stop: Optional[list[str]] = None, 129 | temperature: Optional[int] = None, 130 | top_p: Optional[int] = None, 131 | presence_penalty: Optional[int] = None, 132 | frequency_penalty: Optional[int] = None, 133 | best_of: Optional[int] = None, 134 | n: Optional[int] = None, 135 | logit_bias: Optional[dict[str, int]] = None, 136 | user: Optional[str] = None, 137 | stream: Optional[bool] = None, 138 | messages: Optional[list[CompletionMessage]] = None, 139 | prompt: Optional[str] = None, 140 | ): 141 | model = model.value if isinstance(model, CursiveLanguageModel) else model 142 | 143 | result = build_answer( 144 | cursive=self, 145 | model=model, 146 | system_message=system_message, 147 | functions=functions, 148 | function_call=function_call, 149 | on_token=on_token, 150 | max_tokens=max_tokens, 151 | stop=stop, 152 | temperature=temperature, 153 | top_p=top_p, 154 | presence_penalty=presence_penalty, 155 | frequency_penalty=frequency_penalty, 156 | best_of=best_of, 157 | n=n, 158 | logit_bias=logit_bias, 
159 | user=user, 160 | stream=stream, 161 | messages=messages, 162 | prompt=prompt, 163 | ) 164 | if result and result.error: 165 | return CursiveAnswer[CursiveError](error=result.error) 166 | 167 | return CursiveAnswer( 168 | result=result, 169 | messages=result.messages, 170 | cursive=self, 171 | ) 172 | 173 | def embed(self, content: str): 174 | options = { 175 | "model": "text-embedding-ada-002", 176 | "input": content, 177 | } 178 | self._hooks.call_hook("embedding:before", CursiveHookPayload(data=options)) 179 | start = time() 180 | try: 181 | data = self._vendor.openai.Embedding.create( 182 | input=options["input"], model="text-embedding-ada-002" 183 | ) 184 | 185 | result = { 186 | "embedding": data["data"][0]["embedding"], # type: ignore 187 | } 188 | self._hooks.call_hook( 189 | "embedding:success", 190 | CursiveHookPayload( 191 | data=result, 192 | duration=time() - start, 193 | ), 194 | ) 195 | self._hooks.call_hook( 196 | "embedding:after", 197 | CursiveHookPayload(data=result, duration=time() - start), 198 | ) 199 | 200 | return result["embedding"] 201 | except self._vendor.openai.OpenAIError as e: 202 | error = CursiveError( 203 | message=str(e), details=e, code=CursiveErrorCode.embedding_error 204 | ) 205 | self._hooks.call_hook( 206 | "embedding:error", 207 | CursiveHookPayload(data=error, error=error, duration=time() - start), 208 | ) 209 | self._hooks.call_hook( 210 | "embedding:after", 211 | CursiveHookPayload(data=error, error=error, duration=time() - start), 212 | ) 213 | raise error 214 | 215 | 216 | def resolve_options( 217 | model: Optional[str | CursiveLanguageModel] = None, 218 | system_message: Optional[str] = None, 219 | functions: Optional[list[Callable]] = None, 220 | function_call: Optional[str | Callable] = None, 221 | on_token: Optional[CursiveAskOnToken] = None, 222 | max_tokens: Optional[int] = None, 223 | stop: Optional[list[str]] = None, 224 | temperature: Optional[int] = None, 225 | top_p: Optional[int] = None, 226 | 
presence_penalty: Optional[int] = None, 227 | frequency_penalty: Optional[int] = None, 228 | best_of: Optional[int] = None, 229 | n: Optional[int] = None, 230 | logit_bias: Optional[dict[str, int]] = None, 231 | user: Optional[str] = None, 232 | stream: Optional[bool] = None, 233 | messages: Optional[list[CompletionMessage]] = None, 234 | prompt: Optional[str] = None, 235 | cursive: Cursive = None, 236 | ): 237 | messages = messages or [] 238 | 239 | functions = functions or [] 240 | functions = [cursive_wrapper(f) for f in functions] 241 | 242 | function_call = cursive_wrapper(function_call) 243 | if function_call: 244 | functions.append(function_call) 245 | 246 | # Resolve default model 247 | model = model or ( 248 | "openai/gpt-3.5-turbo" 249 | if cursive.options.is_using_openrouter 250 | else "gpt-3.5-turbo" 251 | ) 252 | 253 | # TODO: Add support for function call resolving 254 | vendor = ( 255 | "openrouter" 256 | if cursive.options.is_using_openrouter 257 | else resolve_vendor_from_model(model) 258 | ) 259 | 260 | resolved_system_message = "" 261 | 262 | if vendor in ["anthropic", "cohere", "replicate"] and len(functions) > 0: 263 | resolved_system_message = ( 264 | (system_message or "") + "\n\n" + get_function_call_directives(functions) 265 | ) 266 | 267 | query_messages: list[CompletionMessage] = [ 268 | message 269 | for message in [ 270 | resolved_system_message 271 | and CompletionMessage(role="system", content=resolved_system_message), 272 | *messages, 273 | prompt and CompletionMessage(role="user", content=prompt), 274 | ] 275 | if message 276 | ] 277 | 278 | payload_params = without_nones( 279 | { 280 | "on_token": on_token, 281 | "max_tokens": max_tokens, 282 | "stop": stop, 283 | "temperature": temperature, 284 | "top_p": top_p, 285 | "presence_penalty": presence_penalty, 286 | "frequency_penalty": frequency_penalty, 287 | "best_of": best_of, 288 | "n": n, 289 | "logit_bias": logit_bias, 290 | "user": user, 291 | "stream": stream, 292 | "model": 
model, 293 | "messages": [without_nones(dict(m)) for m in query_messages], 294 | } 295 | ) 296 | if function_call: 297 | payload_params["function_call"] = ( 298 | {"name": function_call.function_schema["name"]} 299 | if isinstance(function_call, CursiveFunction) 300 | else function_call 301 | ) 302 | if functions: 303 | payload_params["functions"] = [fn.function_schema for fn in functions] 304 | 305 | payload = CompletionPayload(**payload_params) 306 | 307 | resolved_options = { 308 | "on_token": on_token, 309 | "max_tokens": max_tokens, 310 | "stop": stop, 311 | "temperature": temperature, 312 | "top_p": top_p, 313 | "presence_penalty": presence_penalty, 314 | "frequency_penalty": frequency_penalty, 315 | "best_of": best_of, 316 | "n": n, 317 | "logit_bias": logit_bias, 318 | "user": user, 319 | "stream": stream, 320 | "model": model, 321 | "messages": query_messages, 322 | "functions": functions, 323 | } 324 | 325 | return payload, resolved_options 326 | 327 | 328 | def create_completion( 329 | payload: CompletionPayload, 330 | cursive: Cursive, 331 | on_token: Optional[CursiveAskOnToken] = None, 332 | ) -> CreateChatCompletionResponseExtended: 333 | cursive._hooks.call_hook("completion:before", CursiveHookPayload(data=payload)) 334 | data = {} 335 | start = time() 336 | 337 | vendor = ( 338 | "openrouter" 339 | if cursive.options.is_using_openrouter 340 | else resolve_vendor_from_model(payload.model) 341 | ) 342 | 343 | # TODO: Improve the completion creation based on model to vendor matching 344 | if vendor == "openai" or vendor == "openrouter": 345 | resolved_payload = without_nones(payload.dict()) 346 | 347 | # Remove the ID from the messages before sending to OpenAI 348 | resolved_payload["messages"] = [ 349 | without_nones(delete_keys_from_dict(message, ["id", "model_config"])) 350 | for message in resolved_payload["messages"] 351 | ] 352 | 353 | response = cursive._vendor.openai.ChatCompletion.create(**resolved_payload) 354 | if payload.stream: 355 | data 
= process_openai_stream( 356 | payload=payload, 357 | cursive=cursive, 358 | response=response, 359 | on_token=on_token, 360 | ) 361 | content = "".join( 362 | choice["message"]["content"] for choice in data["choices"] 363 | ) 364 | data["usage"]["completion_tokens"] = get_openai_usage(content) 365 | data["usage"]["total_tokens"] = ( 366 | data["usage"]["completion_tokens"] + data["usage"]["prompt_tokens"] 367 | ) 368 | else: 369 | data = response 370 | 371 | # If the user is using OpenRouter, there's no usage data 372 | if usage := data.get("usage"): 373 | data["cost"] = resolve_pricing( 374 | vendor="openai", 375 | model=data["model"], 376 | usage=CursiveAskUsage( 377 | completion_tokens=usage["completion_tokens"], 378 | prompt_tokens=usage["prompt_tokens"], 379 | total_tokens=usage["total_tokens"], 380 | ), 381 | ) 382 | 383 | elif vendor == "anthropic": 384 | response, error = resguard( 385 | lambda: cursive._vendor.anthropic.create_completion(payload), APIError 386 | ) 387 | 388 | if error: 389 | raise CursiveError( 390 | message=error.message, 391 | details=error, 392 | code=CursiveErrorCode.completion_error, 393 | ) 394 | 395 | if payload.stream: 396 | data = process_anthropic_stream( 397 | payload=payload, 398 | response=response, 399 | on_token=on_token, 400 | ) 401 | else: 402 | data = { 403 | "choices": [{"message": {"content": response.completion.lstrip()}}], 404 | "model": payload.model, 405 | "id": random_id(), 406 | "usage": {}, 407 | } 408 | 409 | parse_custom_function_call(data, payload, get_anthropic_usage) 410 | 411 | data["cost"] = resolve_pricing( 412 | vendor="anthropic", 413 | usage=CursiveAskUsage( 414 | completion_tokens=data["usage"]["completion_tokens"], 415 | prompt_tokens=data["usage"]["prompt_tokens"], 416 | total_tokens=data["usage"]["total_tokens"], 417 | ), 418 | model=data["model"], 419 | ) 420 | elif vendor == "cohere": 421 | response, error = cursive._vendor.cohere.create_completion(payload) 422 | if error: 423 | raise 
CursiveError( 424 | message=error.message, 425 | details=error, 426 | code=CursiveErrorCode.completion_error, 427 | ) 428 | if payload.stream: 429 | # TODO: Implement stream processing for Cohere 430 | data = process_cohere_stream( 431 | payload=payload, 432 | response=response, 433 | on_token=on_token, 434 | ) 435 | else: 436 | data = { 437 | "choices": [{"message": {"content": response.data[0].text.lstrip()}}], 438 | "model": payload.model, 439 | "id": random_id(), 440 | "usage": {}, 441 | } 442 | 443 | parse_custom_function_call(data, payload, get_cohere_usage) 444 | 445 | data["cost"] = resolve_pricing( 446 | vendor="cohere", 447 | usage=CursiveAskUsage( 448 | completion_tokens=data["usage"]["completion_tokens"], 449 | prompt_tokens=data["usage"]["prompt_tokens"], 450 | total_tokens=data["usage"]["total_tokens"], 451 | ), 452 | model=data["model"], 453 | ) 454 | elif vendor == "replicate": 455 | response, error = cursive._vendor.replicate.create_completion(payload) 456 | if error: 457 | raise CursiveError( 458 | message=error, details=error, code=CursiveErrorCode.completion_error 459 | ) 460 | # TODO: Implement stream processing for Replicate 461 | stream_transformer = StreamTransformer( 462 | on_token=on_token, 463 | payload=payload, 464 | response=response, 465 | ) 466 | 467 | def get_current_token(part): 468 | part.value = part.value 469 | 470 | stream_transformer.on("get_current_token", get_current_token) 471 | 472 | data = stream_transformer.process() 473 | 474 | parse_custom_function_call(data, payload) 475 | else: 476 | raise CursiveError( 477 | message="Unknown model", 478 | details=None, 479 | code=CursiveErrorCode.completion_error, 480 | ) 481 | end = time() 482 | 483 | if data.get("error"): 484 | error = CursiveError( 485 | message=data["error"].message, 486 | details=data["error"], 487 | code=CursiveErrorCode.completion_error, 488 | ) 489 | hook_payload = CursiveHookPayload(data=None, error=error, duration=end - start) 490 | 
cursive._hooks.call_hook("completion:error", hook_payload) 491 | cursive._hooks.call_hook("completion:after", hook_payload) 492 | raise error 493 | 494 | hook_payload = CursiveHookPayload(data=data, error=None, duration=end - start) 495 | cursive._hooks.call_hook("completion:success", hook_payload) 496 | cursive._hooks.call_hook("completion:after", hook_payload) 497 | return CreateChatCompletionResponseExtended(**data) 498 | 499 | 500 | def cursive_wrapper(fn): 501 | if fn is None: 502 | return None 503 | elif issubclass(type(fn), CursiveCustomFunction): 504 | return fn 505 | elif inspect.isclass(fn) and issubclass(fn, PydanticBaseModel): 506 | return CursiveModel(fn) 507 | elif inspect.isfunction(fn): 508 | return CursiveFunction(fn, pause=True) 509 | 510 | 511 | def ask_model( 512 | cursive, 513 | model: Optional[str | CursiveLanguageModel] = None, 514 | system_message: Optional[str] = None, 515 | functions: Optional[list[CursiveFunction]] = None, 516 | function_call: Optional[str | CursiveFunction] = None, 517 | on_token: Optional[CursiveAskOnToken] = None, 518 | max_tokens: Optional[int] = None, 519 | stop: Optional[list[str]] = None, 520 | temperature: Optional[float] = None, 521 | top_p: Optional[float] = None, 522 | presence_penalty: Optional[float] = None, 523 | frequency_penalty: Optional[float] = None, 524 | best_of: Optional[int] = None, 525 | n: Optional[int] = None, 526 | logit_bias: Optional[dict[str, float]] = None, 527 | user: Optional[str] = None, 528 | stream: Optional[bool] = None, 529 | messages: Optional[list[CompletionMessage]] = None, 530 | prompt: Optional[str] = None, 531 | ) -> CursiveAskModelResponse: 532 | payload, resolved_options = resolve_options( 533 | model=model, 534 | system_message=system_message, 535 | functions=functions, 536 | function_call=function_call, 537 | on_token=on_token, 538 | max_tokens=max_tokens, 539 | stop=stop, 540 | temperature=temperature, 541 | top_p=top_p, 542 | presence_penalty=presence_penalty, 543 | 
frequency_penalty=frequency_penalty, 544 | best_of=best_of, 545 | n=n, 546 | logit_bias=logit_bias, 547 | user=user, 548 | stream=stream, 549 | messages=messages, 550 | prompt=prompt, 551 | cursive=cursive, 552 | ) 553 | start = time() 554 | 555 | completion, error = resguard( 556 | lambda: create_completion( 557 | payload=payload, 558 | cursive=cursive, 559 | on_token=on_token, 560 | ), 561 | CursiveError, 562 | ) 563 | 564 | if error: 565 | if not error.details: 566 | raise CursiveError( 567 | message=f"Unknown error: {error.message}", 568 | details=error, 569 | code=CursiveErrorCode.unknown_error, 570 | ) from error 571 | try: 572 | cause = error.details.code or error.details.type 573 | if cause == "context_length_exceeded": 574 | if not cursive.expand or (cursive.expand and cursive.expand.enabled): 575 | default_model = ( 576 | cursive.expand and cursive.expand.defaultsTo 577 | ) or "gpt-3.5-turbo-16k" 578 | model_mapping = ( 579 | cursive.expand and cursive.expand.model_mapping 580 | ) or {} 581 | resolved_model = model_mapping[model] or default_model 582 | completion, error = resguard( 583 | lambda: create_completion( 584 | payload={**payload, "model": resolved_model}, 585 | cursive=cursive, 586 | on_token=on_token, 587 | ), 588 | CursiveError, 589 | ) 590 | elif cause == "invalid_request_error": 591 | raise CursiveError( 592 | message="Invalid request", 593 | details=error.details, 594 | code=CursiveErrorCode.invalid_request_error, 595 | ) 596 | except Exception as e: 597 | error = CursiveError( 598 | message=f"Unknown error: {e}", 599 | details=e, 600 | code=CursiveErrorCode.unknown_error, 601 | ) 602 | 603 | # TODO: Handle other errors 604 | if error: 605 | # TODO: Add a more comprehensive retry strategy 606 | for i in range(cursive.options.max_retries or 0): 607 | completion, error = resguard( 608 | lambda: create_completion( 609 | payload=payload, 610 | cursive=cursive, 611 | on_token=on_token, 612 | ), 613 | CursiveError, 614 | ) 615 | 616 | if error: 
617 | if i > 3: 618 | sleep((i - 3) * 2) 619 | break 620 | 621 | if error: 622 | error = CursiveError( 623 | message="Error while completing request", 624 | details=error.details, 625 | code=CursiveErrorCode.completion_error, 626 | ) 627 | hook_payload = CursiveHookPayload(error=error) 628 | cursive._hooks.call_hook("ask:error", hook_payload) 629 | cursive._hooks.call_hook("ask:after", hook_payload) 630 | raise error 631 | 632 | if ( 633 | completion 634 | and completion.choices 635 | and (fn_call := completion.choices[0].get("message", {}).get("function_call")) 636 | ): 637 | function: CursiveFunction = next( 638 | ( 639 | f 640 | for f in resolved_options["functions"] 641 | if f.function_schema["name"] == fn_call["name"] 642 | ), 643 | None, 644 | ) 645 | 646 | if function is None: 647 | return ask_model( 648 | **{ 649 | **resolved_options, 650 | "function_call": None, 651 | "messages": payload.messages, 652 | "cursive": cursive, 653 | } 654 | ) 655 | 656 | called_function = function.function_schema 657 | arguments = json.loads(fn_call["arguments"] or "{}") 658 | props = called_function["parameters"]["properties"] 659 | for k, v in props.items(): 660 | if k in arguments: 661 | try: 662 | match v["type"]: 663 | case "string": 664 | arguments[k] = str(arguments[k]) 665 | case "number": 666 | arguments[k] = float(arguments[k]) 667 | case "integer": 668 | arguments[k] = int(arguments[k]) 669 | case "boolean": 670 | arguments[k] = bool(arguments[k]) 671 | except Exception: 672 | pass 673 | 674 | is_async = inspect.iscoroutinefunction(function.definition) 675 | function_result = None 676 | try: 677 | if is_async: 678 | function_result = asyncio.run(function.definition(**arguments)) 679 | else: 680 | function_result = function.definition(**arguments) 681 | except Exception as error: 682 | raise CursiveError( 683 | message=f'Error while running function ${fn_call["name"]}', 684 | details=error, 685 | code=CursiveErrorCode.function_call_error, 686 | ) 687 | 688 | messages 
= payload.messages or [] 689 | messages.append( 690 | CompletionMessage( 691 | role="assistant", 692 | name=fn_call["name"], 693 | content=json.dumps(fn_call), 694 | function_call=fn_call, 695 | ) 696 | ) 697 | 698 | if function.pause: 699 | completion.function_result = function_result 700 | return CursiveAskModelResponse( 701 | answer=completion, 702 | messages=messages, 703 | ) 704 | else: 705 | return ask_model( 706 | **{ 707 | **resolved_options, 708 | "functions": functions, 709 | "messages": messages, 710 | "cursive": cursive, 711 | } 712 | ) 713 | 714 | end = time() 715 | hook_payload = CursiveHookPayload(data=completion, duration=end - start) 716 | cursive._hooks.call_hook("ask:after", hook_payload) 717 | cursive._hooks.call_hook("ask:success", hook_payload) 718 | 719 | messages = payload.messages or [] 720 | messages.append( 721 | CompletionMessage( 722 | role="assistant", 723 | content=completion.choices[0]["message"]["content"], 724 | ) 725 | ) 726 | 727 | return CursiveAskModelResponse(answer=completion, messages=messages) 728 | 729 | 730 | def build_answer( 731 | cursive, 732 | model: Optional[str | CursiveLanguageModel] = None, 733 | system_message: Optional[str] = None, 734 | functions: Optional[list[CursiveFunction]] = None, 735 | function_call: Optional[str | CursiveFunction] = None, 736 | on_token: Optional[CursiveAskOnToken] = None, 737 | max_tokens: Optional[int] = None, 738 | stop: Optional[list[str]] = None, 739 | temperature: Optional[int] = None, 740 | top_p: Optional[int] = None, 741 | presence_penalty: Optional[int] = None, 742 | frequency_penalty: Optional[int] = None, 743 | best_of: Optional[int] = None, 744 | n: Optional[int] = None, 745 | logit_bias: Optional[dict[str, int]] = None, 746 | user: Optional[str] = None, 747 | stream: Optional[bool] = None, 748 | messages: Optional[list[CompletionMessage]] = None, 749 | prompt: Optional[str] = None, 750 | ): 751 | result, error = resguard( 752 | lambda: ask_model( 753 | cursive=cursive, 754 
| model=model, 755 | system_message=system_message, 756 | functions=functions, 757 | function_call=function_call, 758 | on_token=on_token, 759 | max_tokens=max_tokens, 760 | stop=stop, 761 | temperature=temperature, 762 | top_p=top_p, 763 | presence_penalty=presence_penalty, 764 | frequency_penalty=frequency_penalty, 765 | best_of=best_of, 766 | n=n, 767 | logit_bias=logit_bias, 768 | user=user, 769 | stream=stream, 770 | messages=messages, 771 | prompt=prompt, 772 | ), 773 | CursiveError, 774 | ) 775 | 776 | if error: 777 | return CursiveEnrichedAnswer( 778 | error=error, 779 | usage=None, 780 | model=model or "gpt-3.5-turbo", 781 | id=None, 782 | choices=None, 783 | function_result=None, 784 | answer=None, 785 | messages=None, 786 | cost=None, 787 | ) 788 | else: 789 | usage = ( 790 | CursiveAskUsage( 791 | completion_tokens=result.answer.usage["completion_tokens"], 792 | prompt_tokens=result.answer.usage["prompt_tokens"], 793 | total_tokens=result.answer.usage["total_tokens"], 794 | ) 795 | if result.answer.usage 796 | else None 797 | ) 798 | 799 | return CursiveEnrichedAnswer( 800 | error=None, 801 | model=result.answer.model, 802 | id=result.answer.id, 803 | usage=usage, 804 | cost=result.answer.cost, 805 | choices=[choice["message"]["content"] for choice in result.answer.choices], 806 | function_result=result.answer.function_result or None, 807 | answer=result.answer.choices[-1]["message"]["content"], 808 | messages=result.messages, 809 | ) 810 | 811 | 812 | class CursiveConversation: 813 | _cursive: Cursive 814 | messages: list[CompletionMessage] 815 | 816 | def __init__(self, messages: list[CompletionMessage]): 817 | self.messages = messages 818 | 819 | def ask( 820 | self, 821 | model: Optional[str | CursiveLanguageModel] = None, 822 | system_message: Optional[str] = None, 823 | functions: Optional[list[CursiveFunction]] = None, 824 | function_call: Optional[str | CursiveFunction] = None, 825 | on_token: Optional[CursiveAskOnToken] = None, 826 | 
max_tokens: Optional[int] = None, 827 | stop: Optional[list[str]] = None, 828 | temperature: Optional[int] = None, 829 | top_p: Optional[int] = None, 830 | presence_penalty: Optional[int] = None, 831 | frequency_penalty: Optional[int] = None, 832 | best_of: Optional[int] = None, 833 | n: Optional[int] = None, 834 | logit_bias: Optional[dict[str, int]] = None, 835 | user: Optional[str] = None, 836 | stream: Optional[bool] = None, 837 | prompt: Optional[str] = None, 838 | ): 839 | messages = [ 840 | *self.messages, 841 | ] 842 | 843 | result = build_answer( 844 | cursive=self._cursive, 845 | model=model, 846 | system_message=system_message, 847 | functions=functions, 848 | function_call=function_call, 849 | on_token=on_token, 850 | max_tokens=max_tokens, 851 | stop=stop, 852 | temperature=temperature, 853 | top_p=top_p, 854 | presence_penalty=presence_penalty, 855 | frequency_penalty=frequency_penalty, 856 | best_of=best_of, 857 | n=n, 858 | logit_bias=logit_bias, 859 | user=user, 860 | stream=stream, 861 | messages=messages, 862 | prompt=prompt, 863 | ) 864 | 865 | if result and result.error: 866 | return CursiveAnswer[CursiveError](error=result.error) 867 | 868 | return CursiveAnswer[None]( 869 | result=result, 870 | messages=result.messages, 871 | cursive=self._cursive, 872 | ) 873 | 874 | 875 | def use_cursive( 876 | max_retries: Optional[int] = None, 877 | expand: Optional[CursiveSetupOptionsExpand] = None, 878 | debug: Optional[bool] = None, 879 | openai: Optional[dict[str, Any]] = None, 880 | anthropic: Optional[dict[str, Any]] = None, 881 | ): 882 | return Cursive( 883 | max_retries=max_retries, 884 | expand=expand, 885 | debug=debug, 886 | openai=openai, 887 | anthropic=anthropic, 888 | ) 889 | 890 | 891 | E = TypeVar("E", None, CursiveError) 892 | 893 | 894 | class CursiveAnswer(Generic[E]): 895 | choices: Optional[list[str]] 896 | id: Optional[str] 897 | model: Optional[str | CursiveLanguageModel] 898 | usage: Optional[CursiveAskUsage] 899 | cost: 
Optional[CursiveAskCost] 900 | error: Optional[E] 901 | function_result: Optional[Any] 902 | # The text from the answer of the last choice 903 | answer: Optional[str] 904 | # A conversation instance with all the messages so far, including this one 905 | conversation: Optional[CursiveConversation] 906 | 907 | def __init__( 908 | self, 909 | result: Optional[Any] = None, 910 | error: Optional[E] = None, 911 | messages: Optional[list[CompletionMessage]] = None, 912 | cursive: Optional[Cursive] = None, 913 | ): 914 | if error: 915 | self.error = error 916 | self.choices = None 917 | self.id = None 918 | self.model = None 919 | self.usage = None 920 | self.cost = None 921 | self.answer = None 922 | self.conversation = None 923 | self.functionResult = None 924 | elif result: 925 | self.error = None 926 | self.choices = result.choices 927 | self.id = result.id 928 | self.model = result.model 929 | self.usage = result.usage 930 | self.cost = result.cost 931 | self.answer = result.answer 932 | self.function_result = result.function_result 933 | if messages: 934 | conversation = CursiveConversation(messages) 935 | if cursive: 936 | conversation._cursive = cursive 937 | self.conversation = conversation 938 | 939 | def __str__(self): 940 | if self.error: 941 | return f"CursiveAnswer(error={self.error})" 942 | else: 943 | return ( 944 | f"CursiveAnswer(\n\tchoices={self.choices}\n\tid={self.id}\n\t" 945 | f"model={self.model}\n\tusage=(\n\t\t{self.usage}\n\t)\n\tcost=(\n\t\t{self.cost}\n\t)\n\t" 946 | f"answer={self.answer}\n\tconversation={self.conversation}\n)" 947 | ) 948 | 949 | 950 | class CursiveVendors(BaseModel): 951 | openai: Optional[Any] = None 952 | anthropic: Optional[AnthropicClient] = None 953 | cohere: Optional[CohereClient] = None 954 | replicate: Optional[ReplicateClient] = None 955 | -------------------------------------------------------------------------------- /cursive/custom_function_call.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Any, Callable 4 | from cursive.types import CompletionPayload 5 | 6 | 7 | def parse_custom_function_call( 8 | data: dict[str, Any], payload: CompletionPayload, get_usage: Callable = None 9 | ): 10 | # We check for function call in the completion 11 | has_function_call_regex = r"]*>([^<]+)<\/function-call>" 12 | function_call_matches = re.findall( 13 | has_function_call_regex, data["choices"][0]["message"]["content"] 14 | ) 15 | 16 | if len(function_call_matches) > 0: 17 | function_call = json.loads(function_call_matches.pop().strip()) 18 | name = function_call["name"] 19 | arguments = json.dumps(function_call["arguments"]) 20 | data["choices"][0]["message"]["function_call"] = { 21 | "name": name, 22 | "arguments": arguments, 23 | } 24 | 25 | # TODO: Implement cohere usage 26 | if get_usage: 27 | data["usage"]["prompt_tokens"] = get_usage(payload.messages) 28 | data["usage"]["completion_tokens"] = get_usage( 29 | data["choices"][0]["message"]["content"] 30 | ) 31 | data["usage"]["total_tokens"] = ( 32 | data["usage"]["completion_tokens"] + data["usage"]["prompt_tokens"] 33 | ) 34 | else: 35 | data["usage"] = None 36 | 37 | # We check for answers in the completion 38 | has_answer_regex = r"([^<]+)<\/cursive-answer>" 39 | answer_matches = re.findall( 40 | has_answer_regex, data["choices"][0]["message"]["content"] 41 | ) 42 | if len(answer_matches) > 0: 43 | answer = answer_matches.pop().strip() 44 | data["choices"][0]["message"]["content"] = answer 45 | 46 | # As a defensive measure, we check for tags 47 | # and remove them 48 | has_think_regex = r"([^<]+)<\/cursive-think>" 49 | think_matches = re.findall( 50 | has_think_regex, data["choices"][0]["message"]["content"] 51 | ) 52 | if len(think_matches) > 0: 53 | data["choices"][0]["message"]["content"] = re.sub( 54 | has_think_regex, "", data["choices"][0]["message"]["content"] 55 | ) 56 
| 57 | # Strip leading and trailing whitespaces 58 | data["choices"][0]["message"]["content"] = data["choices"][0]["message"][ 59 | "content" 60 | ].strip() 61 | -------------------------------------------------------------------------------- /cursive/function.py: -------------------------------------------------------------------------------- 1 | import re 2 | from textwrap import dedent 3 | from typing import Any, Callable 4 | 5 | from cursive.compat.pydantic import BaseModel, validate_arguments 6 | 7 | 8 | class CursiveCustomFunction(BaseModel): 9 | definition: Callable 10 | description: str = "" 11 | function_schema: dict[str, Any] 12 | pause: bool = False 13 | 14 | class Config: 15 | arbitrary_types_allowed = True 16 | 17 | 18 | class CursiveFunction(CursiveCustomFunction): 19 | def __setup__(self, function: Callable): 20 | definition = function 21 | description = dedent(function.__doc__ or "").strip() 22 | parameters = validate_arguments(function).model.schema() 23 | 24 | 25 | # Delete ['v__duplicate_kwargs', 'args', 'kwargs'] from parameters 26 | for k in ["v__duplicate_kwargs", "args", "kwargs"]: 27 | if k in parameters["properties"]: 28 | del parameters["properties"][k] 29 | 30 | for k, v in parameters["properties"].items(): 31 | # Find the parameter description in the docstring 32 | match = re.search(rf"{k}: (.*)", description) 33 | if match: 34 | v["description"] = match.group(1) 35 | 36 | schema = {} 37 | if parameters: 38 | schema = parameters 39 | 40 | properties = schema.get("properties") or {} 41 | definitions = schema.get("definitions") or {} 42 | resolved_properties = remove_key_deep(resolve_ref(properties, definitions), "title") 43 | 44 | 45 | function_schema = { 46 | "parameters": { 47 | "type": schema.get("type"), 48 | "properties": resolved_properties, 49 | "required": schema.get("required") or [], 50 | }, 51 | "description": description, 52 | "name": parameters["title"], 53 | } 54 | 55 | return { 56 | "definition": definition, 57 | 
"description": description, 58 | "function_schema": function_schema, 59 | } 60 | 61 | def __init__(self, function: Callable, pause=False): 62 | setup = self.__setup__(function) 63 | super().__init__(**setup, pause=pause) 64 | 65 | def __call__(self, *args, **kwargs): 66 | # Validate arguments and parse them 67 | return self.definition(*args, **kwargs) 68 | 69 | 70 | def cursive_function(pause=False): 71 | def decorator(function: Callable = None): 72 | if function is None: 73 | return lambda function: CursiveFunction(function, pause=pause) 74 | else: 75 | return CursiveFunction(function, pause=pause) 76 | 77 | return decorator 78 | 79 | def resolve_ref(data, definitions): 80 | """ 81 | Recursively checks for a $ref key in a dictionary and replaces it with an entry in the definitions 82 | dictionary, changing the key `$ref` to `type`. 83 | 84 | Args: 85 | data (dict): The data dictionary to check for $ref keys. 86 | definitions (dict): The definitions dictionary to replace $ref keys with. 87 | 88 | Returns: 89 | dict: The data dictionary with $ref keys replaced. 90 | """ 91 | if isinstance(data, dict): 92 | if "$ref" in data: 93 | ref = data["$ref"].split('/')[-1] 94 | if ref in definitions: 95 | definition = definitions[ref] 96 | data = definition 97 | else: 98 | for key, value in data.items(): 99 | data[key] = resolve_ref(value, definitions) 100 | elif isinstance(data, list): 101 | for index, value in enumerate(data): 102 | data[index] = resolve_ref(value, definitions) 103 | return data 104 | 105 | def remove_key_deep(data, key): 106 | """ 107 | Recursively removes a key from a dictionary. 108 | 109 | Args: 110 | data (dict): The data dictionary to remove the key from. 111 | key (str): The key to remove from the dictionary. 112 | 113 | Returns: 114 | dict: The data dictionary with the key removed. 
115 | """ 116 | if isinstance(data, dict): 117 | data.pop(key, None) 118 | for k, v in data.items(): 119 | data[k] = remove_key_deep(v, key) 120 | elif isinstance(data, list): 121 | for index, value in enumerate(data): 122 | data[index] = remove_key_deep(value, key) 123 | return data -------------------------------------------------------------------------------- /cursive/hookable.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | from concurrent.futures import ThreadPoolExecutor 4 | from inspect import signature 5 | from typing import Any, Callable 6 | from warnings import warn 7 | 8 | 9 | def flatten_hooks_dictionary( 10 | hooks: dict[str, dict | Callable], parent_name: str | None = None 11 | ): 12 | flattened_hooks = {} 13 | 14 | for key, value in hooks.items(): 15 | name = f"{parent_name}:{key}" if parent_name else key 16 | 17 | if isinstance(value, dict): 18 | flattened_hooks.update(flatten_hooks_dictionary(value, name)) 19 | elif callable(value): 20 | flattened_hooks[name] = value 21 | 22 | return flattened_hooks 23 | 24 | 25 | def run_tasks_sequentially(tasks, handler): 26 | for task in tasks: 27 | handler(task) 28 | 29 | 30 | def merge_hooks(*hooks: dict[str, Any]): 31 | merged_hooks = dict() 32 | 33 | for hook in hooks: 34 | flattened_hook = flatten_hooks_dictionary(hook) 35 | 36 | for key, value in flattened_hook.items(): 37 | if merged_hooks[key]: 38 | merged_hooks[key].append(value) 39 | else: 40 | merged_hooks[key] = [value] 41 | 42 | for key in merged_hooks.keys(): 43 | if len(merged_hooks[key]) > 1: 44 | hooks_list = merged_hooks[key] 45 | merged_hooks[key] = lambda *arguments: run_tasks_sequentially( 46 | hooks_list, lambda hook: hook(*arguments) 47 | ) 48 | else: 49 | merged_hooks[key] = merged_hooks[key][0] 50 | 51 | return merged_hooks 52 | 53 | 54 | def serial_caller(hooks: list[Callable], arguments: list[Any] = []): 55 | for hook in hooks: 56 | if 
len(signature(hook).parameters) > 0: 57 | hook(*arguments) 58 | else: 59 | hook() 60 | 61 | 62 | def concurrent_caller(hooks: list[Callable], arguments: list[Any] = []): 63 | with ThreadPoolExecutor() as executor: 64 | executor.map(lambda hook: hook(*arguments), hooks) 65 | 66 | 67 | def call_each_with(callbacks: list[Callable], argument: Any): 68 | for callback in callbacks: 69 | callback(argument) 70 | 71 | 72 | class Hookable: 73 | def __init__(self): 74 | self._hooks = {} 75 | self._before = [] 76 | self._after = [] 77 | self._deprecated_messages = set() 78 | self._deprecated_hooks = {} 79 | 80 | def hook(self, name: str, function: Callable | None, options={}): 81 | if not name or not callable(function): 82 | return lambda: None 83 | 84 | original_name = name 85 | deprecated_hook = {} 86 | while self._deprecated_hooks.get(name): 87 | deprecated_hook = self._deprecated_hooks[name] 88 | name = deprecated_hook["to"] 89 | 90 | message = None 91 | if deprecated_hook and not options["allow_deprecated"]: 92 | message = deprecated_hook["message"] 93 | if not message: 94 | message = f"{original_name} hook has been deprecated" + ( 95 | f', please use {deprecated_hook["to"]}' 96 | if deprecated_hook["to"] 97 | else "" 98 | ) 99 | 100 | if message not in self._deprecated_messages: 101 | warn(message) 102 | self._deprecated_messages.add(message) 103 | 104 | if function.__name__ == "": 105 | function.__name__ = "_" + re.sub(r"\W+", "_", name) + "_hook_cb" 106 | 107 | self._hooks[name] = name in self._hooks or [] 108 | self._hooks[name].append(function) 109 | 110 | def remove(): 111 | nonlocal function 112 | if function: 113 | self.remove_hook(name, function) 114 | function = None 115 | 116 | return remove 117 | 118 | def hook_once(self, name: str, function: Callable): 119 | hook = None 120 | 121 | def run_once(*arguments): 122 | nonlocal hook 123 | if callable(hook): 124 | hook() 125 | 126 | hook = None 127 | return function(*arguments) 128 | 129 | hook = self.hook(name, 
run_once) 130 | return hook 131 | 132 | def remove_hook(self, name: str, function: Callable): 133 | if self._hooks[name]: 134 | if len(self._hooks[name]) == 0: 135 | del self._hooks[name] 136 | else: 137 | try: 138 | index = self._hooks[name].index(function) 139 | self._hooks[name][index:index] = [] 140 | # if index is not found, ignore 141 | except ValueError: 142 | pass 143 | 144 | def deprecate_hook(self, name: str, deprecated: Callable | str): 145 | self._deprecated_hooks[name] = ( 146 | {"to": deprecated} if isinstance(deprecated, str) else deprecated 147 | ) 148 | hooks = self._hooks[name] or [] 149 | del self._hooks[name] 150 | for hook in hooks: 151 | self.hook(name, hook) 152 | 153 | def deprecate_hooks(self, deprecated_hooks: dict[str, Any]): 154 | self._deprecated_hooks.update(deprecated_hooks) 155 | for name in deprecated_hooks.keys(): 156 | self.deprecate_hook(name, deprecated_hooks[name]) 157 | 158 | def add_hooks(self, hooks: dict[str, Any]): 159 | hooks_to_be_added = flatten_hooks_dictionary(hooks) 160 | remove_fns = [self.hook(key, fn) for key, fn in hooks_to_be_added.items()] 161 | 162 | def function(): 163 | for unreg in remove_fns: 164 | unreg() 165 | remove_fns[:] = [] 166 | 167 | return function 168 | 169 | def remove_hooks(self, hooks: dict[str, Any]): 170 | hooks_to_be_removed = flatten_hooks_dictionary(hooks) 171 | for key, value in hooks_to_be_removed.items(): 172 | self.remove_hook(key, value) 173 | 174 | def remove_all_hooks(self): 175 | for key in self._hooks.keys(): 176 | del self._hooks[key] 177 | 178 | def call_hook(self, name: str, *arguments: Any): 179 | return self.call_hook_with(serial_caller, name, *arguments) 180 | 181 | def call_hook_concurrent(self, name: str, *arguments: Any): 182 | return self.call_hook_with(concurrent_caller, name, *arguments) 183 | 184 | def call_hook_with(self, caller: Callable, name: str, *arguments: Any): 185 | event = {"name": name, "args": arguments, "context": {}} 186 | 187 | 
call_each_with(self._before, event) 188 | 189 | result = caller(self._hooks[name] if name in self._hooks else [], arguments) 190 | 191 | call_each_with(self._after, event) 192 | 193 | return result 194 | 195 | def before_each(self, function: Callable): 196 | self._before.append(function) 197 | 198 | def remove_from_before_list(): 199 | try: 200 | index = self._before.index(function) 201 | self._before[index:index] = [] 202 | except ValueError: 203 | pass 204 | 205 | return remove_from_before_list 206 | 207 | def after_each(self, function: Callable): 208 | self._after.append(function) 209 | 210 | def remove_from_after_list(): 211 | try: 212 | index = self._after.index(function) 213 | self._after[index:index] = [] 214 | except ValueError: 215 | pass 216 | 217 | return remove_from_after_list 218 | 219 | 220 | def create_hooks(): 221 | return Hookable() 222 | 223 | 224 | def starts_with_predicate(prefix: str): 225 | return lambda name: name.startswith(prefix) 226 | 227 | 228 | def create_debugger(hooks: Hookable, _options: dict[str, Any] = {}): 229 | options = {"filter": lambda: True, **_options} 230 | 231 | predicate = options["filter"] 232 | if isinstance(predicate, str): 233 | predicate = starts_with_predicate(predicate) 234 | 235 | tag = f'[{options["tag"]}] ' if options["tag"] else "" 236 | start_times = {} 237 | 238 | def log_prefix(event: dict[str, Any]): 239 | return tag + event["name"] + "".ljust(int(event["id"]), "\0") 240 | 241 | id_control = {} 242 | 243 | def unsubscribe_before_each(event: dict[str, Any] | None = None): 244 | if event is None or not predicate(event["name"]): 245 | return 246 | 247 | id_control[event["name"]] = id_control.get(event.get("name")) or 0 248 | event["id"] = id_control[event["name"]] 249 | id_control[event["name"]] += 1 250 | start_times[log_prefix(event)] = time.time() 251 | 252 | unsubscribe_before = hooks.before_each(unsubscribe_before_each) 253 | 254 | def unsubscribe_after_each(event: dict[str, Any] | None = None): 255 | if 
event is None or not predicate(event["name"]): 256 | return 257 | 258 | label = log_prefix(event) 259 | elapsed_time = time.time() - start_times[label] 260 | print(f"Elapsed time for {label}: {elapsed_time} seconds") 261 | 262 | id_control[event["name"]] -= 1 263 | 264 | unsubscribe_after = hooks.after_each(unsubscribe_after_each) 265 | 266 | def stop_debbuging_and_remove_listeners(): 267 | unsubscribe_before() 268 | unsubscribe_after() 269 | 270 | return {"close": lambda: stop_debbuging_and_remove_listeners()} 271 | -------------------------------------------------------------------------------- /cursive/model.py: -------------------------------------------------------------------------------- 1 | from cursive.compat.pydantic import BaseModel 2 | from cursive.function import CursiveFunction 3 | 4 | 5 | class CursiveModel(CursiveFunction): 6 | def __init__(self, model: type[BaseModel]): 7 | super().__init__(model, pause=True) 8 | 9 | 10 | def cursive_model(): 11 | def decorator(model: type[BaseModel] = None): 12 | if model is None: 13 | return lambda function: CursiveModel(function) 14 | else: 15 | return CursiveModel(model) 16 | 17 | return decorator 18 | -------------------------------------------------------------------------------- /cursive/pricing.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from .types import CursiveAskCost, CursiveAskUsage 4 | from .utils import destructure_items 5 | 6 | from .assets.price.anthropic import ANTHROPIC_PRICING 7 | from .assets.price.openai import OPENAI_PRICING 8 | from .assets.price.cohere import COHERE_PRICING 9 | 10 | VENDOR_PRICING = { 11 | "openai": OPENAI_PRICING, 12 | "anthropic": ANTHROPIC_PRICING, 13 | "cohere": COHERE_PRICING, 14 | } 15 | 16 | 17 | def resolve_pricing( 18 | vendor: Literal["openai", "anthropic"], usage: CursiveAskUsage, model: str 19 | ): 20 | if "/" in model: 21 | vendor, model = model.split("/") 22 | 23 | version: str 24 
| prices: dict[str, dict[str, str]] 25 | 26 | version, prices = destructure_items( 27 | keys=["version"], dictionary=VENDOR_PRICING[vendor] 28 | ) 29 | 30 | models_available = list(prices.keys()) 31 | model_match = next((m for m in models_available if model.startswith(m)), None) 32 | 33 | if not model_match: 34 | raise Exception(f"Unknown model {model}") 35 | 36 | model_price = prices[model_match] 37 | completion = usage.completion_tokens * float(model_price["completion"]) / 1000 38 | prompt = usage.prompt_tokens * float(model_price["prompt"]) / 1000 39 | 40 | cost = CursiveAskCost( 41 | completion=completion, prompt=prompt, total=completion + prompt, version=version 42 | ) 43 | 44 | return cost 45 | -------------------------------------------------------------------------------- /cursive/stream.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Callable, Literal, Optional 3 | from cursive.types import CompletionPayload, CursiveAskOnToken 4 | from cursive.hookable import create_hooks 5 | from cursive.utils import random_id 6 | 7 | 8 | class StreamTransformer: 9 | def __init__( 10 | self, 11 | payload: CompletionPayload, 12 | response: Any, 13 | on_token: Optional[CursiveAskOnToken] = None, 14 | ): 15 | self._hooks = create_hooks() 16 | self.payload = payload 17 | self.response = response 18 | self.on_token = on_token 19 | 20 | def on(self, event: Literal["get_current_token"], function: Callable): 21 | self._hooks.hook(event, function) 22 | 23 | def call_output_hook(self, event: str, payload: Any): 24 | output = HookOutput(payload) 25 | self._hooks.call_hook(event, output) 26 | return output.value 27 | 28 | def process(self): 29 | data = { 30 | "choices": [{"message": {"content": ""}}], 31 | "usage": { 32 | "completion_tokens": 0, 33 | "prompt_tokens": 0, 34 | }, 35 | "model": self.payload.model, 36 | "id": random_id(), 37 | } 38 | completion = "" 39 | 40 | for part in self.response: 41 | # 
The completion partial will come with a leading whitespace 42 | current_token = self.call_output_hook("get_current_token", part) 43 | if not data["choices"][0]["message"]["content"]: 44 | completion = completion.lstrip() 45 | completion += current_token 46 | # Check if theres any tag. The regex should allow for nested tags 47 | function_call_tag = re.findall( 48 | r"([\s\S]*?)(?=<\/function-call>|$)", completion 49 | ) 50 | function_name = "" 51 | function_arguments = "" 52 | if len(function_call_tag) > 0: 53 | # Remove starting and ending tags, even if the ending tag is partial or missing 54 | function_call = re.sub( 55 | r"^\n|\n$", 56 | "", 57 | re.sub( 58 | r"<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?", "", function_call_tag[0] 59 | ).strip(), 60 | ).strip() 61 | # Match the function name inside the JSON 62 | function_name_matches = re.findall(r'"name":\s*"(.+)"', function_call) 63 | function_name = ( 64 | len(function_name_matches) > 0 and function_name_matches[0] 65 | ) 66 | function_arguments_matches = re.findall( 67 | r'"arguments":\s*(\{.+)\}?', function_call, re.S 68 | ) 69 | function_arguments = ( 70 | len(function_arguments_matches) > 0 71 | and function_arguments_matches[0] 72 | ) 73 | if function_arguments: 74 | # If theres unmatches } at the end, remove them 75 | unmatched_brackets = re.findall(r"(\{|\})", function_arguments) 76 | if len(unmatched_brackets) % 2: 77 | function_arguments = re.sub( 78 | r"\}$", "", function_arguments.strip() 79 | ) 80 | 81 | function_arguments = function_arguments.strip() 82 | 83 | cursive_answer_tag = re.findall( 84 | r"([\s\S]*?)(?=<\/cursive-answer>|$)", completion 85 | ) 86 | tagged_answer = "" 87 | if cursive_answer_tag: 88 | tagged_answer = re.sub( 89 | r"<\/?c?u?r?s?i?v?e?-?a?n?s?w?e?r?>?", "", cursive_answer_tag[0] 90 | ).lstrip() 91 | 92 | current_token = completion[len(data["choices"][0]["message"]["content"]) :] 93 | data["choices"][0]["message"]["content"] += current_token 94 | 95 | if self.on_token: 96 | chunk = None 
97 | 98 | if self.payload.functions: 99 | if function_name: 100 | chunk = { 101 | "function_call": {}, 102 | "content": None, 103 | } 104 | if function_arguments: 105 | # Remove all but the current token from the arguments 106 | chunk["function_call"]["arguments"] = function_arguments 107 | else: 108 | chunk["function_call"] = { 109 | "name": function_name, 110 | "arguments": "", 111 | } 112 | elif tagged_answer: 113 | # Token is at the end of the tagged answer 114 | regex = rf"(.*){current_token.strip()}$" 115 | match = re.findall(regex, tagged_answer) 116 | if len(match) > 0 and current_token: 117 | chunk = { 118 | "function_call": None, 119 | "content": current_token, 120 | } 121 | else: 122 | chunk = { 123 | "content": current_token, 124 | } 125 | 126 | if chunk: 127 | self.on_token(chunk) 128 | 129 | return data 130 | 131 | 132 | class HookOutput: 133 | value: Any 134 | 135 | def __init__(self, value: Any = None): 136 | self.value = value 137 | -------------------------------------------------------------------------------- /cursive/tests/test_function_compatibility.py: -------------------------------------------------------------------------------- 1 | def test_pydantic_compatibility(): 2 | from cursive.function import cursive_function 3 | 4 | # Define a function with Pydantic v1 and v2 compatible annotations 5 | @cursive_function() 6 | def test_function(name: str, age: int): 7 | """ 8 | A test function. 9 | Args: 10 | name: The name of a person. 11 | age: The age of the person. 12 | """ 13 | return f"{name} is {age} years old." 14 | 15 | # Test the function with some arguments 16 | assert test_function("John", 30) == "John is 30 years old." 
17 | 18 | # Check the function schema 19 | expected_schema = { 20 | "parameters": { 21 | "type": "object", 22 | "properties": { 23 | "name": {"type": "string", "description": "The name of a person.", "title": "Name"}, 24 | "age": {"type": "integer", "description": "The age of the person.", "title": "Age"}, 25 | }, 26 | "required": ["name", "age"], 27 | }, 28 | "description": "A test function.\nArgs:\n name: The name of a person.\n age: The age of the person.", 29 | "name": "TestFunction", 30 | } 31 | assert test_function.function_schema == expected_schema 32 | -------------------------------------------------------------------------------- /cursive/tests/test_function_schema.py: -------------------------------------------------------------------------------- 1 | from cursive.compat.pydantic import BaseModel 2 | from cursive.function import cursive_function 3 | 4 | def test_function_schema_allows_arbitrary_types(): 5 | 6 | class Character(BaseModel): 7 | name: str 8 | age: int 9 | 10 | @cursive_function() 11 | def gen_arbitrary_type(character: Character): 12 | """ 13 | A test function. 14 | 15 | character: A character. 16 | """ 17 | return f"{character.name} is {character.age} years old." 
18 | 19 | assert 'description' in gen_arbitrary_type.function_schema -------------------------------------------------------------------------------- /cursive/tests/test_setup.py: -------------------------------------------------------------------------------- 1 | from cursive import Cursive 2 | 3 | def test_cursive_setup(): 4 | cursive = Cursive() 5 | assert cursive is not None -------------------------------------------------------------------------------- /cursive/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Callable, Literal, Optional 3 | 4 | from cursive.compat.pydantic import BaseModel as PydanticBaseModel, Field 5 | from cursive.function import CursiveFunction 6 | from cursive.utils import random_id 7 | 8 | 9 | class BaseModel(PydanticBaseModel): 10 | class Config(PydanticBaseModel.Config): 11 | arbitrary_types_allowed = True 12 | protected_namespaces = () 13 | 14 | 15 | class CursiveLanguageModel(Enum): 16 | GPT4 = "gpt4" 17 | GPT4_32K = "gpt4-32k" 18 | GPT3_5_TURBO = "gpt-3.5-turbo" 19 | GPT3_5_TURBO_16K = "gpt-3.5-turbo-16k" 20 | CLAUDE_2 = "claude-2" 21 | CLAUDE_INSTANT_1_2 = "claude-instant-1.2" 22 | CLAUDE_INSTANT_1 = "claude-instant-1.2" 23 | COMMAND = "command" 24 | COMMAND_NIGHTLY = "command-nightly" 25 | 26 | 27 | class CursiveAskUsage(BaseModel): 28 | completion_tokens: int 29 | prompt_tokens: int 30 | total_tokens: int 31 | 32 | 33 | class CursiveAskCost(BaseModel): 34 | completion: float 35 | prompt: float 36 | total: float 37 | version: str 38 | 39 | 40 | Role = Literal[ 41 | "system", 42 | "user", 43 | "assistant", 44 | "function", 45 | ] 46 | 47 | 48 | class ChatCompletionRequestMessageFunctionCall(BaseModel): 49 | name: Optional[str] = None 50 | arguments: Optional[str] = None 51 | 52 | 53 | class CompletionMessage(BaseModel): 54 | id: Optional[str] = None 55 | role: Role 56 | content: Optional[str] = None 57 | name: Optional[str] = None 58 | 
function_call: Optional[ChatCompletionRequestMessageFunctionCall] = None 59 | 60 | def __init__(self, **data): 61 | super().__init__(**data) 62 | if not self.id: 63 | self.id = random_id() 64 | 65 | 66 | class CursiveSetupOptionsExpand(BaseModel): 67 | enabled: Optional[bool] = None 68 | defaults_to: Optional[str] = None 69 | model_mapping: Optional[dict[str, str]] = None 70 | 71 | 72 | class CursiveErrorCode(Enum): 73 | function_call_error = ("function_call_error",) 74 | completion_error = ("completion_error",) 75 | invalid_request_error = ("invalid_request_error",) 76 | embedding_error = ("embedding_error",) 77 | unknown_error = ("unknown_error",) 78 | 79 | 80 | class CursiveError(Exception): 81 | name = "CursiveError" 82 | 83 | def __init__( 84 | self, 85 | message: str, 86 | details: Any, 87 | code: CursiveErrorCode, 88 | ): 89 | self.message = message 90 | self.details = details 91 | self.code = code 92 | super().__init__(self.message) 93 | 94 | 95 | class CursiveEnrichedAnswer(BaseModel): 96 | error: CursiveError | None = None 97 | usage: CursiveAskUsage | None = None 98 | model: str 99 | id: str | None = None 100 | choices: Optional[list[Any]] = None 101 | function_result: Any | None = None 102 | answer: str | None = None 103 | messages: list[CompletionMessage] | None = None 104 | cost: CursiveAskCost | None = None 105 | 106 | 107 | CursiveAskOnToken = Callable[[dict[str, Any]], None] 108 | 109 | 110 | class CursiveAskOptionsBase(BaseModel): 111 | model: Optional[str | CursiveLanguageModel] = None 112 | system_message: Optional[str] = None 113 | functions: Optional[list[CursiveFunction]] = None 114 | function_call: Optional[str | CursiveFunction] = None 115 | on_token: Optional[CursiveAskOnToken] = None 116 | max_tokens: Optional[int] = None 117 | stop: Optional[list[str]] = None 118 | temperature: Optional[float] = None 119 | top_p: Optional[float] = None 120 | presence_penalty: Optional[float] = None 121 | frequency_penalty: Optional[float] = None 122 | 
best_of: Optional[int] = None 123 | n: Optional[int] = None 124 | logit_bias: Optional[dict[str, float]] = None 125 | user: Optional[str] = None 126 | stream: Optional[bool] = None 127 | 128 | 129 | class CreateChatCompletionResponse(BaseModel): 130 | id: str 131 | model: str 132 | choices: list[Any] 133 | usage: Optional[Any] = None 134 | 135 | 136 | class CreateChatCompletionResponseExtended(CreateChatCompletionResponse): 137 | function_result: Optional[Any] = None 138 | cost: Optional[CursiveAskCost] = None 139 | error: Optional[CursiveError] = None 140 | 141 | 142 | class CursiveAskModelResponse(BaseModel): 143 | answer: CreateChatCompletionResponseExtended 144 | messages: list[CompletionMessage] 145 | 146 | 147 | class CursiveSetupOptions(BaseModel): 148 | max_retries: Optional[int] = None 149 | expand: Optional[CursiveSetupOptionsExpand] = None 150 | is_using_openrouter: Optional[bool] = None 151 | 152 | 153 | class CompletionRequestFunctionCall(BaseModel): 154 | name: str 155 | inputs: dict[str, Any] = Field(default_factory=dict) 156 | 157 | 158 | class CompletionRequestStop(BaseModel): 159 | messages_seen: Optional[list[CompletionMessage]] = None 160 | max_turns: Optional[int] = None 161 | 162 | 163 | class CompletionFunctions(BaseModel): 164 | name: str 165 | description: Optional[str] = None 166 | parameters: Optional[dict[str, Any]] = None 167 | 168 | 169 | class CompletionPayload(BaseModel): 170 | model: str 171 | messages: list[CompletionMessage] 172 | functions: Optional[list[CompletionFunctions]] = None 173 | function_call: Optional[CompletionRequestFunctionCall] = None 174 | temperature: Optional[float] = None 175 | top_p: Optional[float] = None 176 | n: Optional[int] = None 177 | stream: Optional[bool] = None 178 | stop: Optional[CompletionRequestStop] = None 179 | max_tokens: Optional[int] = None 180 | presence_penalty: Optional[float] = None 181 | frequency_penalty: Optional[float] = None 182 | logit_bias: Optional[dict[str, float]] = None 183 | 
user: Optional[str] = None 184 | other: Optional[dict[str, Any]] = None 185 | 186 | 187 | CursiveHook = Literal[ 188 | "embedding:before", 189 | "embedding:after", 190 | "embedding:error", 191 | "embedding:success", 192 | "completion:before", 193 | "completion:after", 194 | "completion:error", 195 | "completion:success", 196 | "ask:before", 197 | "ask:after", 198 | "ask:success", 199 | "ask:error", 200 | ] 201 | 202 | 203 | class CursiveHookPayload: 204 | data: Optional[Any] 205 | error: Optional[CursiveError] 206 | duration: Optional[float] 207 | 208 | def __init__( 209 | self, 210 | data: Optional[Any] = None, 211 | error: Optional[CursiveError] = None, 212 | duration: Optional[float] = None, 213 | ): 214 | self.data = data 215 | self.error = error 216 | self.duration = duration 217 | -------------------------------------------------------------------------------- /cursive/usage/anthropic.py: -------------------------------------------------------------------------------- 1 | from anthropic import Anthropic 2 | 3 | from cursive.build_input import build_completion_input 4 | 5 | from ..types import CompletionMessage 6 | 7 | 8 | def get_anthropic_usage(content: str | list[CompletionMessage]): 9 | client = Anthropic() 10 | 11 | if type(content) != str: 12 | content = build_completion_input(content) 13 | 14 | return client.count_tokens(content) 15 | -------------------------------------------------------------------------------- /cursive/usage/cohere.py: -------------------------------------------------------------------------------- 1 | from tokenizers import Tokenizer 2 | from cursive.build_input import build_completion_input 3 | from cursive.types import CompletionMessage 4 | 5 | 6 | def get_cohere_usage(content: str | list[CompletionMessage]): 7 | tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly") 8 | 9 | if type(content) != str: 10 | content = build_completion_input(content) 11 | 12 | return len(tokenizer.encode(content).ids) 13 | 
-------------------------------------------------------------------------------- /cursive/usage/openai.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import AbstractSet, Collection, Literal 3 | 4 | import tiktoken 5 | 6 | from cursive.types import CompletionMessage 7 | 8 | 9 | def encode( 10 | text: str, 11 | allowed_special: AbstractSet[str] | Literal["all"] = set(), 12 | disallowed_special: Collection[str] | Literal["all"] = "all", 13 | ) -> list[int]: 14 | enc = tiktoken.get_encoding("cl100k_base") 15 | 16 | return enc.encode( 17 | text, allowed_special=allowed_special, disallowed_special=disallowed_special 18 | ) 19 | 20 | 21 | def get_openai_usage(content: str | list[CompletionMessage]): 22 | if type(content) == list: 23 | tokens = { 24 | "per_message": 3, 25 | "per_name": 1, 26 | } 27 | 28 | token_count = 3 29 | for message in content: 30 | token_count += tokens["per_message"] 31 | for attribute, value in message.dict().items(): 32 | if attribute == "name": 33 | token_count += tokens["per_name"] 34 | 35 | if isinstance(value, dict): 36 | value = json.dumps(value, separators=(",", ":")) 37 | 38 | if value is None: 39 | continue 40 | 41 | token_count += len(encode(value)) 42 | 43 | return token_count 44 | else: 45 | return len(encode(content)) # type: ignore 46 | 47 | 48 | def get_token_count_from_functions(functions: list[dict]): 49 | token_count = 3 50 | for fn in functions: 51 | function_tokens = len(encode(fn["name"])) 52 | function_tokens += len(encode(fn["description"])) if fn["description"] else 0 53 | 54 | if fn["parameters"] and fn["parameters"]["properties"]: 55 | properties = fn["parameters"]["properties"] 56 | for key in properties: 57 | function_tokens += len(encode(key)) 58 | value = properties[key] 59 | for field in value: 60 | if field in ["type", "description"]: 61 | function_tokens += 2 62 | function_tokens += len(encode(value[field])) 63 | elif field == "enum": 64 | 
function_tokens -= 3 65 | for enum_value in value[field]: 66 | function_tokens += 3 67 | function_tokens += len(encode(enum_value)) 68 | 69 | function_tokens += 11 70 | 71 | token_count += function_tokens 72 | 73 | token_count += 12 74 | return token_count 75 | -------------------------------------------------------------------------------- /cursive/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | from typing import Any, Callable, Optional, Tuple, Type, TypeVar, overload 5 | 6 | 7 | def destructure_items(keys: list[str], dictionary: dict): 8 | items = [dictionary[key] for key in keys] 9 | 10 | new_dictionary = {k: v for k, v in dictionary.items() if k not in keys} 11 | 12 | return *items, new_dictionary 13 | 14 | 15 | def without_nones(dictionary: dict): 16 | return {k: v for k, v in dictionary.items() if v is not None} 17 | 18 | 19 | def random_id(): 20 | characters = string.ascii_lowercase + string.digits 21 | random_id = "".join(random.choice(characters) for _ in range(10)) 22 | return random_id 23 | 24 | 25 | def delete_keys_from_dict(dictionary: dict, keys: list[str]): 26 | return {k: v for k, v in dictionary.items() if k not in set(keys)} 27 | 28 | 29 | T = TypeVar("T", bound=Exception) 30 | 31 | 32 | @overload 33 | def resguard(function: Callable) -> Tuple[Any, Exception | None]: 34 | ... 35 | 36 | 37 | @overload 38 | def resguard(function: Callable, exception_type: Type[T]) -> Tuple[Any, T | None]: 39 | ... 
40 | 41 | 42 | def resguard( 43 | function: Callable, exception_type: Optional[Type[T]] = None 44 | ) -> Tuple[Any, T | Exception | None]: 45 | try: 46 | return function(), None 47 | except Exception as e: 48 | if exception_type: 49 | if isinstance(e, exception_type): 50 | return None, e 51 | else: 52 | raise e 53 | else: 54 | return None, e 55 | -------------------------------------------------------------------------------- /cursive/vendor/anthropic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from anthropic import Anthropic 4 | 5 | from cursive.build_input import build_completion_input 6 | from cursive.stream import StreamTransformer 7 | 8 | from ..types import ( 9 | CompletionPayload, 10 | CursiveAskOnToken, 11 | ) 12 | from ..utils import without_nones 13 | 14 | 15 | class AnthropicClient: 16 | client: Anthropic 17 | 18 | def __init__(self, api_key: str): 19 | self.client = Anthropic(api_key=api_key) 20 | 21 | def create_completion(self, payload: CompletionPayload): 22 | prompt = build_completion_input(payload.messages) 23 | payload = without_nones( 24 | { 25 | "model": payload.model, 26 | "max_tokens_to_sample": payload.max_tokens or 100000, 27 | "prompt": prompt, 28 | "temperature": payload.temperature or 0.7, 29 | "top_p": payload.top_p, 30 | "stop_sequences": payload.stop, 31 | "stream": payload.stream or False, 32 | **(payload.other or {}), 33 | } 34 | ) 35 | return self.client.completions.create(**payload) 36 | 37 | 38 | def process_anthropic_stream( 39 | payload: CompletionPayload, 40 | response: Any, 41 | on_token: Optional[CursiveAskOnToken] = None, 42 | ): 43 | stream_transformer = StreamTransformer( 44 | on_token=on_token, 45 | payload=payload, 46 | response=response, 47 | ) 48 | 49 | def get_current_token(part): 50 | part.value = part.value.completion 51 | 52 | stream_transformer.on("get_current_token", get_current_token) 53 | return stream_transformer.process() 54 
| -------------------------------------------------------------------------------- /cursive/vendor/cohere.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | from cohere import Client as Cohere 3 | from cursive.build_input import build_completion_input 4 | 5 | from cursive.types import CompletionPayload, CursiveAskOnToken 6 | from cursive.stream import StreamTransformer 7 | from cursive.utils import without_nones 8 | 9 | 10 | class CohereClient: 11 | client: Cohere 12 | 13 | def __init__(self, api_key: str): 14 | self.client = Cohere(api_key=api_key) 15 | 16 | def create_completion(self, payload: CompletionPayload): 17 | prompt = build_completion_input(payload.messages) 18 | payload = without_nones( 19 | { 20 | "model": payload.model, 21 | "max_tokens": payload.max_tokens or 3000, 22 | "prompt": prompt.rstrip(), 23 | "temperature": payload.temperature 24 | if payload.temperature is not None 25 | else 0.7, 26 | "stop_sequences": payload.stop, 27 | "stream": payload.stream or False, 28 | "frequency_penalty": payload.frequency_penalty, 29 | "p": payload.top_p, 30 | } 31 | ) 32 | try: 33 | response = self.client.generate( 34 | **payload, 35 | ) 36 | 37 | return response, None 38 | except Exception as e: 39 | return None, e 40 | 41 | 42 | def process_cohere_stream( 43 | payload: CompletionPayload, 44 | response: Any, 45 | on_token: Optional[CursiveAskOnToken] = None, 46 | ): 47 | stream_transformer = StreamTransformer( 48 | on_token=on_token, 49 | payload=payload, 50 | response=response, 51 | ) 52 | 53 | def get_current_token(part): 54 | part.value = part.value.text 55 | 56 | stream_transformer.on("get_current_token", get_current_token) 57 | 58 | return stream_transformer.process() 59 | -------------------------------------------------------------------------------- /cursive/vendor/index.py: -------------------------------------------------------------------------------- 1 | 
vendors_and_model_prefixes = { 2 | "openai": ["gpt-3.5", "gpt-4"], 3 | "anthropic": ["claude-instant", "claude-2"], 4 | "cohere": ["command"], 5 | "replicate": ["replicate"], 6 | } 7 | 8 | 9 | def resolve_vendor_from_model(model: str): 10 | for vendor, prefixes in vendors_and_model_prefixes.items(): 11 | if any(model.startswith(m) for m in prefixes): 12 | return vendor 13 | 14 | return "" 15 | -------------------------------------------------------------------------------- /cursive/vendor/openai.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from ..types import CursiveAskOnToken 4 | from ..usage.openai import get_openai_usage, get_token_count_from_functions 5 | 6 | 7 | def process_openai_stream( 8 | payload: Any, 9 | cursive: Any, 10 | response: Any, 11 | on_token: Optional[CursiveAskOnToken] = None, 12 | ): 13 | data = { 14 | "choices": [], 15 | "usage": { 16 | "completion_tokens": 0, 17 | "prompt_tokens": get_openai_usage(payload.messages), 18 | }, 19 | "model": payload.model, 20 | } 21 | 22 | if payload.functions: 23 | data["usage"]["prompt_tokens"] += get_token_count_from_functions( 24 | payload.functions 25 | ) 26 | 27 | for slice in response: 28 | data = { 29 | **data, 30 | "id": slice["id"], 31 | } 32 | 33 | for i in range(len(slice["choices"])): 34 | delta = slice["choices"][i]["delta"] 35 | 36 | if len(data["choices"]) <= i: 37 | data["choices"].append( 38 | { 39 | "message": { 40 | "function_call": None, 41 | "role": "assistant", 42 | "content": "", 43 | }, 44 | } 45 | ) 46 | 47 | if ( 48 | delta 49 | and delta.get("function_call") 50 | and delta["function_call"].get("name") 51 | ): 52 | data["choices"][i]["message"]["function_call"] = delta["function_call"] 53 | 54 | if ( 55 | delta 56 | and delta.get("function_call") 57 | and delta["function_call"].get("arguments") 58 | ): 59 | data["choices"][i]["message"]["function_call"]["arguments"] += delta[ 60 | "function_call" 61 | 
]["arguments"] 62 | 63 | if delta and delta.get("content"): 64 | data["choices"][i]["message"]["content"] += delta["content"] 65 | 66 | if on_token: 67 | chunk = None 68 | if delta and delta.get("function_call"): 69 | chunk = { 70 | "function_call": { 71 | k: v for k, v in delta["function_call"].items() 72 | }, 73 | "content": None, 74 | } 75 | 76 | if delta and delta.get("content"): 77 | chunk = {"content": delta["content"], "function_call": None} 78 | 79 | if chunk: 80 | on_token(chunk) 81 | 82 | return data 83 | -------------------------------------------------------------------------------- /cursive/vendor/replicate.py: -------------------------------------------------------------------------------- 1 | import replicate 2 | 3 | from cursive.build_input import build_completion_input 4 | from cursive.types import CompletionPayload 5 | from cursive.utils import without_nones 6 | 7 | 8 | class ReplicateClient: 9 | client: replicate.Client 10 | 11 | def __init__(self, api_key: str): 12 | self.client = replicate.Client(api_key) 13 | 14 | def create_completion(self, payload: CompletionPayload): # noqa: F821 15 | prompt = build_completion_input(payload.messages) 16 | # Resolve model ID from `replicate/` 17 | version = payload.model[payload.model.find("/") + 1 :] 18 | resolved_payload = without_nones( 19 | { 20 | "max_new_tokens": payload.max_tokens or 2000, 21 | "max_length": payload.max_tokens or 2000, 22 | "prompt": prompt, 23 | "temperature": payload.temperature or 0.7, 24 | "top_p": payload.top_p, 25 | "stop": payload.stop, 26 | "model": version, 27 | "stream": bool(payload.stream), 28 | **(payload.other or {}), 29 | } 30 | ) 31 | try: 32 | response = self.client.run( 33 | version, 34 | input=resolved_payload, 35 | ) 36 | 37 | return response, None 38 | except Exception as e: 39 | print("e", e) 40 | return None, e 41 | -------------------------------------------------------------------------------- /docs/logo-dark.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /docs/logo-light.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /examples/add-function.py: -------------------------------------------------------------------------------- 1 | from cursive import cursive_function, Cursive 2 | 3 | cursive = Cursive() 4 | 5 | @cursive_function() 6 | def add(a: float, b: float): 7 | """ 8 | Adds two numbers. 9 | 10 | a: The first number. 11 | b: The second number. 12 | """ 13 | return a + b 14 | 15 | res = cursive.ask( 16 | prompt='What is the sum of 232 and 243?', 17 | functions=[add], 18 | ) 19 | 20 | print({ 21 | 'model': res.model, 22 | 'message': res.answer, 23 | 'usage': res.usage.total_tokens, 24 | 'conversation': res.conversation.messages, 25 | }) -------------------------------------------------------------------------------- /examples/ai-cli.py: -------------------------------------------------------------------------------- 1 | from cursive import Cursive, cursive_function 2 | import argparse 3 | import subprocess 4 | 5 | parser = argparse.ArgumentParser(description='Solves your CLI problems') 6 | parser.add_argument('question', type=str, nargs=argparse.REMAINDER, help='Question to ask AI') 7 | args = parser.parse_args() 8 | 9 | cursive = Cursive() 10 | 11 | @cursive_function(pause=True) 12 | def execute_command(command: str) -> str: 13 | """ 14 | Executes a CLI command from the user prompt 15 | 16 | command: The command to execute 17 | """ 18 | return command 19 | 20 | res = cursive.ask( 21 | prompt=' '.join(args.question), 22 | system_message= 23 | "You are a CLI assistant, executes commands from the user prompt." 
24 | "You have permission, so just use the function you're provided" 25 | "Always assume the user wants to run a command on the CLI" 26 | "Assume they're using a MacOS terminal.", 27 | functions=[execute_command], 28 | ) 29 | 30 | conversation = res.conversation 31 | 32 | while True: 33 | if res.function_result: 34 | print(f'Executing command:\n\t$ {res.function_result}') 35 | print('Press enter to continue or N/n to cancel') 36 | 37 | prompt = input('> ') 38 | if prompt.lower() == 'n': 39 | print('Command cancelled') 40 | exit(0) 41 | elif prompt == '': 42 | subprocess.run(res.function_result, shell=True) 43 | exit(0) 44 | else: 45 | res = conversation.ask( 46 | prompt=prompt, 47 | functions=[execute_command], 48 | ) 49 | conversation = res.conversation 50 | else: 51 | print(res.answer, end='\n\n') 52 | prompt = input('> ') 53 | res = conversation.ask( 54 | prompt=prompt, 55 | functions=[execute_command], 56 | ) 57 | conversation = res.conversation 58 | 59 | 60 | -------------------------------------------------------------------------------- /examples/compare-embeddings.py: -------------------------------------------------------------------------------- 1 | from cursive import Cursive 2 | import numpy as np 3 | 4 | cursive = Cursive() 5 | 6 | x1 = cursive.embed("""Pizza""") 7 | 8 | x2 = cursive.embed("""Cat""") 9 | 10 | def cosine_similarity(x, y): 11 | return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)) 12 | 13 | print(cosine_similarity(x1, x2)) -------------------------------------------------------------------------------- /examples/generate-list-of-objects.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from cursive.compat.pydantic import BaseModel 3 | from cursive.function import cursive_function 4 | from cursive import Cursive 5 | 6 | class Input(BaseModel): 7 | input: str 8 | idx: int 9 | 10 | @cursive_function(pause=True) 11 | def gen_character_list(inputs: List[Input]): 12 | 
""" 13 | Given a prompt (which is directives for a LLM), generate possible inputs that could be fed to it. 14 | Generate 10 inputs. 15 | 16 | inputs: A list of inputs. 17 | """ 18 | return inputs 19 | 20 | c = Cursive() 21 | 22 | res = c.ask( 23 | prompt="Generate a input for prompt that generates headlines for a SaaS company.", 24 | model="gpt-4", 25 | function_call=Input 26 | ) 27 | 28 | print(res.function_result) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "cursivepy" 3 | version = "0.7.2" 4 | description = "" 5 | authors = [ 6 | "Rodrigo Godinho ", 7 | "Henrique Cunha ", 8 | "Cyrus Nouroozi " 9 | ] 10 | readme = "README.md" 11 | packages = [{ include = "cursive" }] 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.10.0,<3.12" 15 | tiktoken = "^0.4.0" 16 | openai = "^0.27.8" 17 | anthropic = "^0.3.6" 18 | pydantic = ">=1,<3" 19 | cohere = "^4.19.3" 20 | tokenizers = "^0.13.3" 21 | replicate = "^0.11.0" 22 | sorcery = "^0.2.2" 23 | numpy = "^1.26.0" 24 | 25 | [tool.poetry.group.dev.dependencies] 26 | pytest = "^7.4.0" 27 | ipython = "^8.14.0" 28 | ipdb = "^0.13.13" 29 | 30 | [build-system] 31 | requires = ["poetry-core"] 32 | build-backend = "poetry.core.masonry.api" 33 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = pydantic1,pydantic2 3 | 4 | [testenv] 5 | deps = 6 | pytest 7 | pydantic1: pydantic==1.* 8 | pydantic2: pydantic==2.* 9 | 10 | commands = pytest --------------------------------------------------------------------------------