├── .github
│   └── workflows
│       └── release.yml
├── .gitignore
├── README.md
├── cursive
│   ├── __init__.py
│   ├── assets
│   │   └── price
│   │       ├── anthropic.py
│   │       ├── cohere.py
│   │       └── openai.py
│   ├── build_input.py
│   ├── compat
│   │   └── pydantic.py
│   ├── cursive.py
│   ├── custom_function_call.py
│   ├── function.py
│   ├── hookable.py
│   ├── model.py
│   ├── pricing.py
│   ├── stream.py
│   ├── tests
│   │   ├── test_function_compatibility.py
│   │   ├── test_function_schema.py
│   │   └── test_setup.py
│   ├── types.py
│   ├── usage
│   │   ├── anthropic.py
│   │   ├── cohere.py
│   │   └── openai.py
│   ├── utils.py
│   └── vendor
│       ├── anthropic.py
│       ├── cohere.py
│       ├── index.py
│       ├── openai.py
│       └── replicate.py
├── docs
│   ├── logo-dark.svg
│   └── logo-light.svg
├── examples
│   ├── add-function.py
│   ├── ai-cli.py
│   ├── compare-embeddings.py
│   └── generate-list-of-objects.py
├── poetry.lock
├── pyproject.toml
└── tox.ini
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | permissions:
4 | contents: write
5 |
6 | on:
7 | push:
8 | tags:
9 | - 'v*'
10 |
11 | jobs:
12 | release:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | with:
17 | fetch-depth: 0
18 |
19 | - name: Install pnpm
20 | uses: pnpm/action-setup@v2
21 |
22 | - name: Set node
23 | uses: actions/setup-node@v3
24 | with:
25 | node-version: 18.x
26 | cache: pnpm
27 |
28 | - run: npx changelogithub
29 | env:
30 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*/.DS_Store
2 | **/*/test.py
3 | dist
4 | **/*/__pycache__
5 | playground.py
6 | .pytest_cache
7 | ipython.py
8 | .tox
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 |
4 | Cursive is a universal and intuitive framework for interacting with LLMs.
5 |
6 | ## highlights
7 |
✦ **Extensible** - You can easily hook into any part of a completion life cycle, be it to log, cache, or modify the results.
8 |
9 |
✦ **Functions** - Easily describe functions that the LLM can use, along with their definitions, with any model (currently supporting GPT-4, GPT-3.5, Claude 2, and Claude Instant).
10 |
11 |
✦ **Universal** - Cursive aims to bridge as many capabilities between different models as possible. Ultimately, this means that with a single interface, you can allow your users to choose any model.
12 |
13 |
✦ **Informative** - Cursive comes with built-in token usage and cost calculations, as accurate as possible.
14 |
15 |
✦ **Reliable** - Cursive comes with automatic retry and model expanding upon exceeding context length, both of which you can configure.
16 |
17 | ## quickstart
18 | 1. Install.
19 |
20 | ```bash
21 | poetry add cursivepy
22 | # or
23 | pip install cursivepy
24 | ```
25 |
26 | 2. Start using.
27 |
28 | ```python
29 | from cursive import Cursive
30 |
31 | cursive = Cursive()
32 |
33 | response = cursive.ask(
34 | prompt='What is the meaning of life?',
35 | )
36 |
37 | print(response.answer)
38 | ```
39 |
40 | ## usage
41 | ### Conversation
42 | Chaining a conversation is easy with `cursive`. You can pass any of the options you're used to with OpenAI's API.
43 |
44 | ```python
45 | res_a = cursive.ask(
46 | prompt='Give me a good name for a gecko.',
47 | model='gpt-4',
48 | max_tokens=16,
49 | )
50 |
51 | print(res_a.answer) # Zephyr
52 |
53 | res_b = res_a.conversation.ask(
54 | prompt='How would you say it in Portuguese?'
55 | )
56 |
57 | print(res_b.answer) # Zéfiro
58 | ```
59 | ### Streaming
60 | Streaming is also supported, and we keep track of the tokens for you!
61 | ```python
62 | result = cursive.ask(
63 | prompt='Count to 10',
64 | stream=True,
65 | on_token=lambda partial: print(partial['content'])
66 | )
67 |
68 | print(result.usage.total_tokens) # 40
69 | ```
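If you also want the full answer at the end of the stream, `on_token` composes naturally with a small accumulator. A minimal sketch (the partial dict with a `'content'` key follows the example above; the simulated loop below stands in for a real streamed `cursive.ask` call):

```python
# Collect streamed content into a single string.
# Each `partial` is assumed to be a dict with a 'content' key, as in the
# on_token example above; None content (e.g. function-call chunks) is skipped.
chunks: list[str] = []

def on_token(partial: dict) -> None:
    content = partial.get("content")
    if content is not None:
        chunks.append(content)

# Simulated stream in place of cursive.ask(..., stream=True, on_token=on_token):
for part in [{"content": "1, "}, {"content": "2, "}, {"content": "3"}]:
    on_token(part)

answer = "".join(chunks)
print(answer)  # 1, 2, 3
```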
70 |
71 | ### Functions
72 |
73 | You can very easily define and describe functions, alongside their execution code.
74 |
75 | ```python
76 | from cursive import cursive_function, Cursive
77 |
78 | cursive = Cursive()
79 |
80 | @cursive_function()
81 | def add(a: float, b: float):
82 | """
83 | Adds two numbers.
84 |
85 | a: The first number.
86 | b: The second number.
87 | """
88 | return a + b
89 |
90 | res = cursive.ask(
91 | prompt='What is the sum of 232 and 243?',
92 | functions=[add],
93 | )
94 |
95 | print(res.answer) # The sum of 232 and 243 is 475.
96 | ```
97 |
98 | The function's result will automatically be fed into the conversation, and another completion will be made. If you want to prevent this, you can add `pause` to your function definition.
99 |
100 | ```python
101 |
102 | @cursive_function(pause=True)
103 | def create_character(name: str, age: str):
104 | """
105 | Creates a character.
106 |
107 | name: The name of the character.
108 | age: The age of the character.
109 | """
110 | return {
111 | 'name': name,
112 | 'age': age,
113 | }
114 |
115 | res = cursive.ask(
116 | prompt='Create a character named John who is 23 years old.',
117 | functions=[create_character],
118 | )
119 |
120 | print(res.function_result) # {'name': 'John', 'age': 23}
121 | ```
122 |
123 | Cursive also supports passing in undecorated functions!
124 |
125 | ```python
126 | def add(a: float, b: float):
127 | return a + b
128 |
129 | res = cursive.ask(
130 | prompt='What is the sum of 232 and 243?',
131 | functions=[add], # this is equivalent to cursive_function(pause=True)(add)
132 | )
133 | if res.function_result:
134 | print(res.function_result) # 475
135 | else:
136 | print(res.answer) # Text answer in case the function is not called
137 | ```
138 |
139 | ### Models
140 |
141 | Cursive also supports the generation of Pydantic BaseModels.
142 |
143 | ```python
144 | from cursive.compat.pydantic import BaseModel, Field # Pydantic V1 API
145 |
146 | class Character(BaseModel):
147 | name: str
148 | age: int
149 | skills: list[str] = Field(min_items=2)
150 |
151 | res = cursive.ask(
152 | prompt='Create a character named John who is 23 years old.',
153 | function_call=Character,
154 | )
155 | res.function_result # is a Character instance with autogenerated fields
156 | ```
157 |
158 | ### Hooks
159 |
160 | You can hook into any part of the completion life cycle.
161 |
162 | ```python
163 | cursive.on('completion:after', lambda result: print(
164 | result.data.cost.total,
165 | result.data.usage.total_tokens,
166 | ))
167 |
168 | cursive.on('completion:error', lambda result: print(
169 | result.error,
170 | ))
171 |
172 | cursive.ask(
173 | prompt='Can androids dream of electric sheep?',
174 | )
175 |
176 | # 0.0002185
177 | # 113
178 | ```
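The `cost.total` figure printed above is derived from the per-model price tables bundled under `cursive/assets/price/`. A sketch of the arithmetic, assuming those table values are USD per 1K tokens (the `estimate_cost` helper is illustrative, not part of the library):

```python
# gpt-3.5-turbo prices from cursive/assets/price/openai.py
PROMPT_PRICE = 0.0015      # assumed USD per 1K prompt tokens
COMPLETION_PRICE = 0.002   # assumed USD per 1K completion tokens

def estimate_cost(prompt_tokens: int, completion_tokens: int) -> float:
    # Each direction is billed separately: tokens / 1000 * price-per-1K.
    return (
        prompt_tokens / 1000 * PROMPT_PRICE
        + completion_tokens / 1000 * COMPLETION_PRICE
    )

cost = estimate_cost(prompt_tokens=100, completion_tokens=13)
```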
179 |
180 | ### Embedding
181 | You can create embeddings pretty easily with `cursive`.
182 | ```python
183 | embedding = cursive.embed('This should be a document.')
184 | ```
185 | This will support different types of documents and integrations pretty soon.
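The returned embedding is a plain list of floats, so comparing documents (as in `examples/compare-embeddings.py`) comes down to vector math. A minimal cosine-similarity sketch; the dummy vectors below stand in for real `cursive.embed` results:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    # dot(a, b) / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Dummy vectors standing in for cursive.embed(...) results:
doc_a = [0.1, 0.2, 0.7]
doc_b = [0.1, 0.2, 0.69]
similarity = cosine_similarity(doc_a, doc_b)
```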
186 |
187 | ### Reliability
188 | Cursive comes with automatic retry with backoff upon failing completions, and with model expanding upon exceeding context length: when a completion fails because it ran out of context, Cursive tries again with a model that has a larger context window.
189 |
190 | You can configure this behavior by passing the `max_retries` and `expand` options to the `Cursive` constructor.
191 |
192 | ```python
193 | cursive = Cursive(
194 | max_retries=5, # 0 disables it completely
195 | expand={
196 | 'enable': True,
197 | 'defaults_to': 'gpt-3.5-turbo-16k',
198 | 'resolve_model': {
199 | 'gpt-3.5-turbo': 'gpt-3.5-turbo-16k',
200 | 'gpt-4': 'claude-2',
201 | },
202 | },
203 | )
204 | ```
205 |
206 | ## Available Models
207 |
208 | OpenAI models
209 |
210 | - `gpt-3.5-turbo`
211 | - `gpt-3.5-turbo-16k`
212 | - `gpt-4`
213 | - `gpt-4-32k`
214 | - Any other chat completion model version
215 |
216 | ###### Credentials
217 | You can pass your OpenAI API key to `Cursive`'s constructor, or set the `OPENAI_API_KEY` environment variable.
218 |
219 |
220 |
221 | Anthropic models
222 |
223 | - `claude-2`
224 | - `claude-instant-1`
225 | - `claude-instant-1.2`
226 | - Any other model version
227 |
228 | ###### Credentials
229 | You can pass your Anthropic API key to `Cursive`'s constructor, or set the `ANTHROPIC_API_KEY` environment variable.
230 |
231 |
232 |
233 | OpenRouter models
234 |
235 | OpenRouter is a service that gives you access to leading language models through an OpenAI-compatible API, including function calling!
236 |
237 | - `anthropic/claude-instant-1.2`
238 | - `anthropic/claude-2`
239 | - `openai/gpt-4-32k`
240 | - `google/palm-2-codechat-bison`
241 | - `nousresearch/nous-hermes-llama2-13b`
242 | - Any model version from https://openrouter.ai/docs#models
243 |
244 | ###### Credentials
245 |
246 | ```python
247 | from cursive import Cursive
248 |
249 | cursive = Cursive(
250 | openrouter={
251 | "api_key": "sk-or-...",
252 | "app_title": "Your App Name",
253 | "app_url": "https://appurl.com",
254 | }
255 | )
256 |
257 | cursive.ask(
258 | model="anthropic/claude-instant-1.2",
259 | prompt="What is the meaning of life?"
260 | )
261 | ```
262 |
263 |
264 |
265 |
266 | Cohere models
267 |
268 | - `command`
269 | - Any other model version (such as `command-nightly`)
270 |
271 | ###### Credentials
272 | You can pass your Cohere API key to `Cursive`'s constructor, or set the `CO_API_KEY` environment variable.
273 |
274 |
275 |
276 |
277 | Replicate models
278 | You can prepend `replicate/` to any model name and version available on Replicate.
279 |
280 | ###### Example
281 | ```python
282 | cursive.ask(
283 | prompt='What is the meaning of life?',
284 | model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52',
285 | )
286 | ```
287 |
288 | ###### Credentials
289 | You can pass your Replicate API key to `Cursive`'s constructor, or set the `REPLICATE_API_TOKEN` environment variable.
290 |
291 |
292 |
293 | ## roadmap
294 |
295 | ### vendor support
296 | - [x] Anthropic
297 | - [x] Cohere
298 | - [x] Replicate
299 | - [x] OpenRouter
300 | - [ ] Azure OpenAI models
301 | - [ ] Huggingface
302 |
--------------------------------------------------------------------------------
/cursive/__init__.py:
--------------------------------------------------------------------------------
1 | from .cursive import Cursive
2 | from .function import cursive_function
3 | from .types import (
4 | CompletionPayload,
5 | CursiveError,
6 | CursiveErrorCode,
7 | CursiveEnrichedAnswer,
8 | CompletionMessage,
9 | CursiveFunction,
10 | )
11 |
12 | __all__ = [
13 | "Cursive",
14 | "cursive_function",
15 | "CompletionPayload",
16 | "CursiveError",
17 | "CursiveErrorCode",
18 | "CursiveEnrichedAnswer",
19 | "CompletionMessage",
20 | "CursiveFunction",
21 | ]
22 |
--------------------------------------------------------------------------------
/cursive/assets/price/anthropic.py:
--------------------------------------------------------------------------------
1 | ANTHROPIC_PRICING = {
2 | "version": "2023-07-11",
3 | "claude-instant": {
4 | "completion": 0.00551,
5 | "prompt": 0.00163
6 | },
7 | "claude-2": {
8 | "completion": 0.03268,
9 | "prompt": 0.01102
10 | },
11 | }
--------------------------------------------------------------------------------
/cursive/assets/price/cohere.py:
--------------------------------------------------------------------------------
1 | COHERE_PRICING = {
2 | "version": "2023-07-11",
3 | "command": {
4 | "completion": 0.015,
5 | "prompt": 0.015
6 | },
7 | "command-nightly": {
8 | "completion": 0.015,
9 | "prompt": 0.015
10 | },
11 | }
--------------------------------------------------------------------------------
/cursive/assets/price/openai.py:
--------------------------------------------------------------------------------
1 | OPENAI_PRICING = {
2 | "version": "2023-06-13",
3 | "gpt-4": {
4 | "completion": 0.06,
5 | "prompt": 0.03
6 | },
7 | "gpt-4-32k": {
8 | "completion": 0.12,
9 | "prompt": 0.06
10 | },
11 | "gpt-3.5-turbo": {
12 | "completion": 0.002,
13 | "prompt": 0.0015
14 | },
15 | "gpt-3.5-turbo-16k": {
16 | "completion": 0.004,
17 | "prompt": 0.003
18 | },
19 | }
--------------------------------------------------------------------------------
/cursive/build_input.py:
--------------------------------------------------------------------------------
1 | import json
2 | from textwrap import dedent
3 | from cursive.types import CompletionMessage
4 | from cursive.function import CursiveFunction
5 |
6 |
7 | def build_completion_input(messages: list[CompletionMessage]):
8 | """
9 | Builds a completion-esque input from a list of messages
10 | """
11 | role_mapping = {"user": "Human", "assistant": "Assistant"}
12 | messages_with_prefix: list[CompletionMessage] = [
13 | *messages, # type: ignore
14 | CompletionMessage(
15 | role="assistant",
16 | content=" ",
17 | ),
18 | ]
19 |
20 | def resolve_message(message: CompletionMessage):
21 | if message.role == "system":
22 | return "\n".join(
23 | [
24 | "Human:",
25 | message.content or "",
26 | "\nAssistant: Ok.",
27 | ]
28 | )
29 | if message.role == "function":
30 | return "\n".join(
31 | [
32 | f'Human: ',
33 | message.content or "",
34 | "",
35 | ]
36 | )
37 | if message.function_call:
38 | arguments = message.function_call.arguments
39 | if isinstance(arguments, str):
40 | arguments_str = arguments
41 | else:
42 | arguments_str = json.dumps(arguments)
43 | return "\n".join(
44 | [
45 | "Assistant: ",
46 | json.dumps(
47 | {
48 | "name": message.function_call.name,
49 | "arguments": arguments_str,
50 | }
51 | ),
52 | "",
53 | ]
54 | )
55 | return f"{role_mapping[message.role]}: {message.content}"
56 |
57 | completion_input = "\n\n".join(list(map(resolve_message, messages_with_prefix)))
58 | return completion_input
59 |
60 |
61 | def get_function_call_directives(functions: list[CursiveFunction]) -> str:
62 | return dedent(
63 | f"""\
64 | # Function Calling Guide
65 | You're a powerful language model capable of calling functions to do anything the user needs.
66 |
67 | If you need to call a function, you output the name and arguments of the function you want to use in the following format:
68 |
69 |
70 | {'{'}"name": "function_name", "arguments": {'{'}"argument_name": "argument_value"{'}'}{'}'}
71 |
72 | ALWAYS use this format, even if the function doesn't have arguments. The arguments property is always a dictionary.
73 | Never forget to pass the `name` and `arguments` property when doing a function call.
74 |
75 | Think step by step before answering, and try to think out loud. Never output a function call if you don't have to.
76 | If you don't have a function to call, just output the text as usual inside a tag with newlines inside.
77 | Always question yourself if you have access to a function.
78 | Always think out loud before answering; if I don't see a block, you will be eliminated.
79 | When thinking out loud, always use the tag.
80 | # Functions available:
81 |
82 | {json.dumps([f.function_schema for f in functions])}
83 |
84 | # Working with results
85 | You can either call a function or answer, *NEVER BOTH*.
86 | You are not in charge of resolving the function call, the user is.
87 | The human will give you the result of the function call in the following format:
88 |
89 | Human:
90 | {'{'}result{'}'}
91 |
92 |
93 | ## Important note
94 | Never output a , or you will be eliminated.
95 |
96 | You can use the result of the function call in your answer. But never answer and call a function at the same time.
97 | When answering never be explicit about the function calling, just use the result of the function call in your answer.
98 |
99 | """
100 | )
101 |
--------------------------------------------------------------------------------
/cursive/compat/pydantic.py:
--------------------------------------------------------------------------------
1 | try:
2 | from pydantic.v1 import BaseModel, Field, validate_arguments
3 | except ImportError:
4 | from pydantic import BaseModel, Field, validate_arguments
5 |
6 | __all__ = ["BaseModel", "Field", "validate_arguments"]
7 |
--------------------------------------------------------------------------------
/cursive/cursive.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import asyncio
3 | import inspect
4 | import json
5 | from time import time, sleep
6 | from typing import Any, Callable, Generic, Optional, TypeVar
7 | from os import environ as env
8 |
9 | from anthropic import APIError
10 | import openai as openai_client
11 | import requests
12 |
13 | from cursive.build_input import get_function_call_directives
14 | from cursive.compat.pydantic import BaseModel as PydanticBaseModel
15 | from cursive.custom_function_call import parse_custom_function_call
16 | from cursive.function import CursiveCustomFunction, CursiveFunction
17 | from cursive.model import CursiveModel
18 | from cursive.stream import StreamTransformer
19 | from cursive.usage.cohere import get_cohere_usage
20 | from cursive.vendor.cohere import CohereClient, process_cohere_stream
21 | from cursive.types import (
22 | BaseModel,
23 | CompletionMessage,
24 | CompletionPayload,
25 | CreateChatCompletionResponseExtended,
26 | CursiveAskCost,
27 | CursiveAskModelResponse,
28 | CursiveAskOnToken,
29 | CursiveAskUsage,
30 | CursiveEnrichedAnswer,
31 | CursiveError,
32 | CursiveErrorCode,
33 | CursiveHook,
34 | CursiveHookPayload,
35 | CursiveSetupOptions,
36 | CursiveSetupOptionsExpand,
37 | CursiveLanguageModel,
38 | )
39 | from cursive.hookable import create_debugger, create_hooks
40 | from cursive.pricing import resolve_pricing
41 | from cursive.usage.anthropic import get_anthropic_usage
42 | from cursive.usage.openai import get_openai_usage
43 | from cursive.utils import delete_keys_from_dict, without_nones, random_id, resguard
44 | from cursive.vendor.anthropic import (
45 | AnthropicClient,
46 | process_anthropic_stream,
47 | )
48 | from cursive.vendor.index import resolve_vendor_from_model
49 | from cursive.vendor.openai import process_openai_stream
50 | from cursive.vendor.replicate import ReplicateClient
51 |
52 |
53 | # TODO: Improve implementation architecture, this was a quick and dirty
54 | class Cursive:
55 | options: CursiveSetupOptions
56 |
57 | def __init__(
58 | self,
59 | max_retries: Optional[int] = None,
60 | expand: Optional[dict[str, Any]] = None,
61 | debug: Optional[bool] = None,
62 | openai: Optional[dict[str, Any]] = None,
63 | anthropic: Optional[dict[str, Any]] = None,
64 | cohere: Optional[dict[str, Any]] = None,
65 | replicate: Optional[dict[str, Any]] = None,
66 | openrouter: Optional[dict[str, Any]] = None,
67 | ):
68 | self._hooks = create_hooks()
69 | self.options = CursiveSetupOptions(
70 | max_retries=max_retries,
71 | expand=expand,
72 | debug=debug,
73 | )
74 | if debug:
75 | self._debugger = create_debugger(self._hooks, {"tag": "cursive"})
76 |
77 | openai_client.api_key = (openai or {}).get("api_key") or env.get(
78 | "OPENAI_API_KEY"
79 | )
80 | anthropic_client = AnthropicClient(
81 | (anthropic or {}).get("api_key") or env.get("ANTHROPIC_API_KEY")
82 | )
83 | cohere_client = CohereClient(
84 | (cohere or {}).get("api_key") or env.get("CO_API_KEY", "---")
85 | )
86 | replicate_client = ReplicateClient(
87 | (replicate or {}).get("api_key") or env.get("REPLICATE_API_TOKEN", "---")
88 | )
89 |
90 | openrouter_api_key = (openrouter or {}).get("api_key") or env.get(
91 | "OPENROUTER_API_KEY"
92 | )
93 |
94 | if openrouter_api_key:
95 | openai_client.api_base = "https://openrouter.ai/api/v1"
96 | openai_client.api_key = openrouter_api_key
97 | self.options.is_using_openrouter = True
98 | session = requests.Session()
99 | session.headers.update(
100 | {
101 | "HTTP-Referer": openrouter.get(
102 | "app_url", "https://cursive.meistrari.com"
103 | ),
104 | "X-Title": openrouter.get("app_title", "Cursive"),
105 | }
106 | )
107 | openai_client.requestssession = session
108 | atexit.register(session.close)
109 |
110 | self._vendor = CursiveVendors(
111 | openai=openai_client,
112 | anthropic=anthropic_client,
113 | cohere=cohere_client,
114 | replicate=replicate_client,
115 | )
116 |
117 | def on(self, event: CursiveHook, callback: Callable):
118 | self._hooks.hook(event, callback)
119 |
120 | def ask(
121 | self,
122 | model: Optional[str | CursiveLanguageModel] = None,
123 | system_message: Optional[str] = None,
124 | functions: Optional[list[Callable]] = None,
125 | function_call: Optional[str | Callable] = None,
126 | on_token: Optional[CursiveAskOnToken] = None,
127 | max_tokens: Optional[int] = None,
128 | stop: Optional[list[str]] = None,
129 | temperature: Optional[int] = None,
130 | top_p: Optional[int] = None,
131 | presence_penalty: Optional[int] = None,
132 | frequency_penalty: Optional[int] = None,
133 | best_of: Optional[int] = None,
134 | n: Optional[int] = None,
135 | logit_bias: Optional[dict[str, int]] = None,
136 | user: Optional[str] = None,
137 | stream: Optional[bool] = None,
138 | messages: Optional[list[CompletionMessage]] = None,
139 | prompt: Optional[str] = None,
140 | ):
141 | model = model.value if isinstance(model, CursiveLanguageModel) else model
142 |
143 | result = build_answer(
144 | cursive=self,
145 | model=model,
146 | system_message=system_message,
147 | functions=functions,
148 | function_call=function_call,
149 | on_token=on_token,
150 | max_tokens=max_tokens,
151 | stop=stop,
152 | temperature=temperature,
153 | top_p=top_p,
154 | presence_penalty=presence_penalty,
155 | frequency_penalty=frequency_penalty,
156 | best_of=best_of,
157 | n=n,
158 | logit_bias=logit_bias,
159 | user=user,
160 | stream=stream,
161 | messages=messages,
162 | prompt=prompt,
163 | )
164 | if result and result.error:
165 | return CursiveAnswer[CursiveError](error=result.error)
166 |
167 | return CursiveAnswer(
168 | result=result,
169 | messages=result.messages,
170 | cursive=self,
171 | )
172 |
173 | def embed(self, content: str):
174 | options = {
175 | "model": "text-embedding-ada-002",
176 | "input": content,
177 | }
178 | self._hooks.call_hook("embedding:before", CursiveHookPayload(data=options))
179 | start = time()
180 | try:
181 | data = self._vendor.openai.Embedding.create(
182 | input=options["input"], model="text-embedding-ada-002"
183 | )
184 |
185 | result = {
186 | "embedding": data["data"][0]["embedding"], # type: ignore
187 | }
188 | self._hooks.call_hook(
189 | "embedding:success",
190 | CursiveHookPayload(
191 | data=result,
192 | duration=time() - start,
193 | ),
194 | )
195 | self._hooks.call_hook(
196 | "embedding:after",
197 | CursiveHookPayload(data=result, duration=time() - start),
198 | )
199 |
200 | return result["embedding"]
201 | except self._vendor.openai.OpenAIError as e:
202 | error = CursiveError(
203 | message=str(e), details=e, code=CursiveErrorCode.embedding_error
204 | )
205 | self._hooks.call_hook(
206 | "embedding:error",
207 | CursiveHookPayload(data=error, error=error, duration=time() - start),
208 | )
209 | self._hooks.call_hook(
210 | "embedding:after",
211 | CursiveHookPayload(data=error, error=error, duration=time() - start),
212 | )
213 | raise error
214 |
215 |
216 | def resolve_options(
217 | model: Optional[str | CursiveLanguageModel] = None,
218 | system_message: Optional[str] = None,
219 | functions: Optional[list[Callable]] = None,
220 | function_call: Optional[str | Callable] = None,
221 | on_token: Optional[CursiveAskOnToken] = None,
222 | max_tokens: Optional[int] = None,
223 | stop: Optional[list[str]] = None,
224 | temperature: Optional[int] = None,
225 | top_p: Optional[int] = None,
226 | presence_penalty: Optional[int] = None,
227 | frequency_penalty: Optional[int] = None,
228 | best_of: Optional[int] = None,
229 | n: Optional[int] = None,
230 | logit_bias: Optional[dict[str, int]] = None,
231 | user: Optional[str] = None,
232 | stream: Optional[bool] = None,
233 | messages: Optional[list[CompletionMessage]] = None,
234 | prompt: Optional[str] = None,
235 | cursive: Cursive = None,
236 | ):
237 | messages = messages or []
238 |
239 | functions = functions or []
240 | functions = [cursive_wrapper(f) for f in functions]
241 |
242 | function_call = cursive_wrapper(function_call)
243 | if function_call:
244 | functions.append(function_call)
245 |
246 | # Resolve default model
247 | model = model or (
248 | "openai/gpt-3.5-turbo"
249 | if cursive.options.is_using_openrouter
250 | else "gpt-3.5-turbo"
251 | )
252 |
253 | # TODO: Add support for function call resolving
254 | vendor = (
255 | "openrouter"
256 | if cursive.options.is_using_openrouter
257 | else resolve_vendor_from_model(model)
258 | )
259 |
260 | resolved_system_message = ""
261 |
262 | if vendor in ["anthropic", "cohere", "replicate"] and len(functions) > 0:
263 | resolved_system_message = (
264 | (system_message or "") + "\n\n" + get_function_call_directives(functions)
265 | )
266 |
267 | query_messages: list[CompletionMessage] = [
268 | message
269 | for message in [
270 | resolved_system_message
271 | and CompletionMessage(role="system", content=resolved_system_message),
272 | *messages,
273 | prompt and CompletionMessage(role="user", content=prompt),
274 | ]
275 | if message
276 | ]
277 |
278 | payload_params = without_nones(
279 | {
280 | "on_token": on_token,
281 | "max_tokens": max_tokens,
282 | "stop": stop,
283 | "temperature": temperature,
284 | "top_p": top_p,
285 | "presence_penalty": presence_penalty,
286 | "frequency_penalty": frequency_penalty,
287 | "best_of": best_of,
288 | "n": n,
289 | "logit_bias": logit_bias,
290 | "user": user,
291 | "stream": stream,
292 | "model": model,
293 | "messages": [without_nones(dict(m)) for m in query_messages],
294 | }
295 | )
296 | if function_call:
297 | payload_params["function_call"] = (
298 | {"name": function_call.function_schema["name"]}
299 | if isinstance(function_call, CursiveFunction)
300 | else function_call
301 | )
302 | if functions:
303 | payload_params["functions"] = [fn.function_schema for fn in functions]
304 |
305 | payload = CompletionPayload(**payload_params)
306 |
307 | resolved_options = {
308 | "on_token": on_token,
309 | "max_tokens": max_tokens,
310 | "stop": stop,
311 | "temperature": temperature,
312 | "top_p": top_p,
313 | "presence_penalty": presence_penalty,
314 | "frequency_penalty": frequency_penalty,
315 | "best_of": best_of,
316 | "n": n,
317 | "logit_bias": logit_bias,
318 | "user": user,
319 | "stream": stream,
320 | "model": model,
321 | "messages": query_messages,
322 | "functions": functions,
323 | }
324 |
325 | return payload, resolved_options
326 |
327 |
328 | def create_completion(
329 | payload: CompletionPayload,
330 | cursive: Cursive,
331 | on_token: Optional[CursiveAskOnToken] = None,
332 | ) -> CreateChatCompletionResponseExtended:
333 | cursive._hooks.call_hook("completion:before", CursiveHookPayload(data=payload))
334 | data = {}
335 | start = time()
336 |
337 | vendor = (
338 | "openrouter"
339 | if cursive.options.is_using_openrouter
340 | else resolve_vendor_from_model(payload.model)
341 | )
342 |
343 | # TODO: Improve the completion creation based on model to vendor matching
344 | if vendor == "openai" or vendor == "openrouter":
345 | resolved_payload = without_nones(payload.dict())
346 |
347 | # Remove the ID from the messages before sending to OpenAI
348 | resolved_payload["messages"] = [
349 | without_nones(delete_keys_from_dict(message, ["id", "model_config"]))
350 | for message in resolved_payload["messages"]
351 | ]
352 |
353 | response = cursive._vendor.openai.ChatCompletion.create(**resolved_payload)
354 | if payload.stream:
355 | data = process_openai_stream(
356 | payload=payload,
357 | cursive=cursive,
358 | response=response,
359 | on_token=on_token,
360 | )
361 | content = "".join(
362 | choice["message"]["content"] for choice in data["choices"]
363 | )
364 | data["usage"]["completion_tokens"] = get_openai_usage(content)
365 | data["usage"]["total_tokens"] = (
366 | data["usage"]["completion_tokens"] + data["usage"]["prompt_tokens"]
367 | )
368 | else:
369 | data = response
370 |
371 | # If the user is using OpenRouter, there's no usage data
372 | if usage := data.get("usage"):
373 | data["cost"] = resolve_pricing(
374 | vendor="openai",
375 | model=data["model"],
376 | usage=CursiveAskUsage(
377 | completion_tokens=usage["completion_tokens"],
378 | prompt_tokens=usage["prompt_tokens"],
379 | total_tokens=usage["total_tokens"],
380 | ),
381 | )
382 |
383 | elif vendor == "anthropic":
384 | response, error = resguard(
385 | lambda: cursive._vendor.anthropic.create_completion(payload), APIError
386 | )
387 |
388 | if error:
389 | raise CursiveError(
390 | message=error.message,
391 | details=error,
392 | code=CursiveErrorCode.completion_error,
393 | )
394 |
395 | if payload.stream:
396 | data = process_anthropic_stream(
397 | payload=payload,
398 | response=response,
399 | on_token=on_token,
400 | )
401 | else:
402 | data = {
403 | "choices": [{"message": {"content": response.completion.lstrip()}}],
404 | "model": payload.model,
405 | "id": random_id(),
406 | "usage": {},
407 | }
408 |
409 | parse_custom_function_call(data, payload, get_anthropic_usage)
410 |
411 | data["cost"] = resolve_pricing(
412 | vendor="anthropic",
413 | usage=CursiveAskUsage(
414 | completion_tokens=data["usage"]["completion_tokens"],
415 | prompt_tokens=data["usage"]["prompt_tokens"],
416 | total_tokens=data["usage"]["total_tokens"],
417 | ),
418 | model=data["model"],
419 | )
420 | elif vendor == "cohere":
421 | response, error = cursive._vendor.cohere.create_completion(payload)
422 | if error:
423 | raise CursiveError(
424 | message=error.message,
425 | details=error,
426 | code=CursiveErrorCode.completion_error,
427 | )
428 | if payload.stream:
429 | # TODO: Implement stream processing for Cohere
430 | data = process_cohere_stream(
431 | payload=payload,
432 | response=response,
433 | on_token=on_token,
434 | )
435 | else:
436 | data = {
437 | "choices": [{"message": {"content": response.data[0].text.lstrip()}}],
438 | "model": payload.model,
439 | "id": random_id(),
440 | "usage": {},
441 | }
442 |
443 | parse_custom_function_call(data, payload, get_cohere_usage)
444 |
445 | data["cost"] = resolve_pricing(
446 | vendor="cohere",
447 | usage=CursiveAskUsage(
448 | completion_tokens=data["usage"]["completion_tokens"],
449 | prompt_tokens=data["usage"]["prompt_tokens"],
450 | total_tokens=data["usage"]["total_tokens"],
451 | ),
452 | model=data["model"],
453 | )
454 | elif vendor == "replicate":
455 | response, error = cursive._vendor.replicate.create_completion(payload)
456 | if error:
457 | raise CursiveError(
458 | message=error, details=error, code=CursiveErrorCode.completion_error
459 | )
460 | # TODO: Implement stream processing for Replicate
461 | stream_transformer = StreamTransformer(
462 | on_token=on_token,
463 | payload=payload,
464 | response=response,
465 | )
466 |
467 | def get_current_token(part):
468 | part.value = part.value
469 |
470 | stream_transformer.on("get_current_token", get_current_token)
471 |
472 | data = stream_transformer.process()
473 |
474 | parse_custom_function_call(data, payload)
475 | else:
476 | raise CursiveError(
477 | message="Unknown model",
478 | details=None,
479 | code=CursiveErrorCode.completion_error,
480 | )
481 | end = time()
482 |
483 | if data.get("error"):
484 | error = CursiveError(
485 | message=data["error"].message,
486 | details=data["error"],
487 | code=CursiveErrorCode.completion_error,
488 | )
489 | hook_payload = CursiveHookPayload(data=None, error=error, duration=end - start)
490 | cursive._hooks.call_hook("completion:error", hook_payload)
491 | cursive._hooks.call_hook("completion:after", hook_payload)
492 | raise error
493 |
494 | hook_payload = CursiveHookPayload(data=data, error=None, duration=end - start)
495 | cursive._hooks.call_hook("completion:success", hook_payload)
496 | cursive._hooks.call_hook("completion:after", hook_payload)
497 | return CreateChatCompletionResponseExtended(**data)
498 |
499 |
500 | def cursive_wrapper(fn):
501 | if fn is None:
502 | return None
503 | elif issubclass(type(fn), CursiveCustomFunction):
504 | return fn
505 | elif inspect.isclass(fn) and issubclass(fn, PydanticBaseModel):
506 | return CursiveModel(fn)
507 | elif inspect.isfunction(fn):
508 | return CursiveFunction(fn, pause=True)
509 |
510 |
511 | def ask_model(
512 | cursive,
513 | model: Optional[str | CursiveLanguageModel] = None,
514 | system_message: Optional[str] = None,
515 | functions: Optional[list[CursiveFunction]] = None,
516 | function_call: Optional[str | CursiveFunction] = None,
517 | on_token: Optional[CursiveAskOnToken] = None,
518 | max_tokens: Optional[int] = None,
519 | stop: Optional[list[str]] = None,
520 | temperature: Optional[float] = None,
521 | top_p: Optional[float] = None,
522 | presence_penalty: Optional[float] = None,
523 | frequency_penalty: Optional[float] = None,
524 | best_of: Optional[int] = None,
525 | n: Optional[int] = None,
526 | logit_bias: Optional[dict[str, float]] = None,
527 | user: Optional[str] = None,
528 | stream: Optional[bool] = None,
529 | messages: Optional[list[CompletionMessage]] = None,
530 | prompt: Optional[str] = None,
531 | ) -> CursiveAskModelResponse:
532 | payload, resolved_options = resolve_options(
533 | model=model,
534 | system_message=system_message,
535 | functions=functions,
536 | function_call=function_call,
537 | on_token=on_token,
538 | max_tokens=max_tokens,
539 | stop=stop,
540 | temperature=temperature,
541 | top_p=top_p,
542 | presence_penalty=presence_penalty,
543 | frequency_penalty=frequency_penalty,
544 | best_of=best_of,
545 | n=n,
546 | logit_bias=logit_bias,
547 | user=user,
548 | stream=stream,
549 | messages=messages,
550 | prompt=prompt,
551 | cursive=cursive,
552 | )
553 | start = time()
554 |
555 | completion, error = resguard(
556 | lambda: create_completion(
557 | payload=payload,
558 | cursive=cursive,
559 | on_token=on_token,
560 | ),
561 | CursiveError,
562 | )
563 |
564 | if error:
565 | if not error.details:
566 | raise CursiveError(
567 | message=f"Unknown error: {error.message}",
568 | details=error,
569 | code=CursiveErrorCode.unknown_error,
570 | ) from error
571 | try:
572 | cause = error.details.code or error.details.type
573 | if cause == "context_length_exceeded":
574 |             if not cursive.expand or cursive.expand.enabled:
575 | default_model = (
576 | cursive.expand and cursive.expand.defaultsTo
577 | ) or "gpt-3.5-turbo-16k"
578 | model_mapping = (
579 | cursive.expand and cursive.expand.model_mapping
580 | ) or {}
581 |                 resolved_model = model_mapping.get(model) or default_model
582 | completion, error = resguard(
583 | lambda: create_completion(
584 | payload={**payload, "model": resolved_model},
585 | cursive=cursive,
586 | on_token=on_token,
587 | ),
588 | CursiveError,
589 | )
590 | elif cause == "invalid_request_error":
591 | raise CursiveError(
592 | message="Invalid request",
593 | details=error.details,
594 | code=CursiveErrorCode.invalid_request_error,
595 | )
596 | except Exception as e:
597 | error = CursiveError(
598 | message=f"Unknown error: {e}",
599 | details=e,
600 | code=CursiveErrorCode.unknown_error,
601 | )
602 |
603 | # TODO: Handle other errors
604 | if error:
605 | # TODO: Add a more comprehensive retry strategy
606 | for i in range(cursive.options.max_retries or 0):
607 | completion, error = resguard(
608 | lambda: create_completion(
609 | payload=payload,
610 | cursive=cursive,
611 | on_token=on_token,
612 | ),
613 | CursiveError,
614 | )
615 |
616 |             if not error:
617 |                 break
618 |             if i > 3:
619 |                 sleep((i - 3) * 2)
620 |
621 | if error:
622 | error = CursiveError(
623 | message="Error while completing request",
624 | details=error.details,
625 | code=CursiveErrorCode.completion_error,
626 | )
627 | hook_payload = CursiveHookPayload(error=error)
628 | cursive._hooks.call_hook("ask:error", hook_payload)
629 | cursive._hooks.call_hook("ask:after", hook_payload)
630 | raise error
631 |
632 | if (
633 | completion
634 | and completion.choices
635 | and (fn_call := completion.choices[0].get("message", {}).get("function_call"))
636 | ):
637 | function: CursiveFunction = next(
638 | (
639 | f
640 | for f in resolved_options["functions"]
641 | if f.function_schema["name"] == fn_call["name"]
642 | ),
643 | None,
644 | )
645 |
646 | if function is None:
647 | return ask_model(
648 | **{
649 | **resolved_options,
650 | "function_call": None,
651 | "messages": payload.messages,
652 | "cursive": cursive,
653 | }
654 | )
655 |
656 | called_function = function.function_schema
657 | arguments = json.loads(fn_call["arguments"] or "{}")
658 | props = called_function["parameters"]["properties"]
659 | for k, v in props.items():
660 | if k in arguments:
661 | try:
662 | match v["type"]:
663 | case "string":
664 | arguments[k] = str(arguments[k])
665 | case "number":
666 | arguments[k] = float(arguments[k])
667 | case "integer":
668 | arguments[k] = int(arguments[k])
669 | case "boolean":
670 | arguments[k] = bool(arguments[k])
671 | except Exception:
672 | pass
673 |
674 | is_async = inspect.iscoroutinefunction(function.definition)
675 | function_result = None
676 | try:
677 | if is_async:
678 | function_result = asyncio.run(function.definition(**arguments))
679 | else:
680 | function_result = function.definition(**arguments)
681 |         except Exception as e:
682 |             raise CursiveError(
683 |                 message=f'Error while running function {fn_call["name"]}',
684 |                 details=e,
685 |                 code=CursiveErrorCode.function_call_error,
686 |             )
687 |
688 | messages = payload.messages or []
689 | messages.append(
690 | CompletionMessage(
691 | role="assistant",
692 | name=fn_call["name"],
693 | content=json.dumps(fn_call),
694 | function_call=fn_call,
695 | )
696 | )
697 |
698 | if function.pause:
699 | completion.function_result = function_result
700 | return CursiveAskModelResponse(
701 | answer=completion,
702 | messages=messages,
703 | )
704 | else:
705 | return ask_model(
706 | **{
707 | **resolved_options,
708 | "functions": functions,
709 | "messages": messages,
710 | "cursive": cursive,
711 | }
712 | )
713 |
714 | end = time()
715 | hook_payload = CursiveHookPayload(data=completion, duration=end - start)
716 | cursive._hooks.call_hook("ask:after", hook_payload)
717 | cursive._hooks.call_hook("ask:success", hook_payload)
718 |
719 | messages = payload.messages or []
720 | messages.append(
721 | CompletionMessage(
722 | role="assistant",
723 | content=completion.choices[0]["message"]["content"],
724 | )
725 | )
726 |
727 | return CursiveAskModelResponse(answer=completion, messages=messages)
728 |
729 |
730 | def build_answer(
731 | cursive,
732 | model: Optional[str | CursiveLanguageModel] = None,
733 | system_message: Optional[str] = None,
734 | functions: Optional[list[CursiveFunction]] = None,
735 | function_call: Optional[str | CursiveFunction] = None,
736 | on_token: Optional[CursiveAskOnToken] = None,
737 | max_tokens: Optional[int] = None,
738 | stop: Optional[list[str]] = None,
739 |     temperature: Optional[float] = None,
740 |     top_p: Optional[float] = None,
741 |     presence_penalty: Optional[float] = None,
742 |     frequency_penalty: Optional[float] = None,
743 |     best_of: Optional[int] = None,
744 |     n: Optional[int] = None,
745 |     logit_bias: Optional[dict[str, float]] = None,
746 | user: Optional[str] = None,
747 | stream: Optional[bool] = None,
748 | messages: Optional[list[CompletionMessage]] = None,
749 | prompt: Optional[str] = None,
750 | ):
751 | result, error = resguard(
752 | lambda: ask_model(
753 | cursive=cursive,
754 | model=model,
755 | system_message=system_message,
756 | functions=functions,
757 | function_call=function_call,
758 | on_token=on_token,
759 | max_tokens=max_tokens,
760 | stop=stop,
761 | temperature=temperature,
762 | top_p=top_p,
763 | presence_penalty=presence_penalty,
764 | frequency_penalty=frequency_penalty,
765 | best_of=best_of,
766 | n=n,
767 | logit_bias=logit_bias,
768 | user=user,
769 | stream=stream,
770 | messages=messages,
771 | prompt=prompt,
772 | ),
773 | CursiveError,
774 | )
775 |
776 | if error:
777 | return CursiveEnrichedAnswer(
778 | error=error,
779 | usage=None,
780 | model=model or "gpt-3.5-turbo",
781 | id=None,
782 | choices=None,
783 | function_result=None,
784 | answer=None,
785 | messages=None,
786 | cost=None,
787 | )
788 | else:
789 | usage = (
790 | CursiveAskUsage(
791 | completion_tokens=result.answer.usage["completion_tokens"],
792 | prompt_tokens=result.answer.usage["prompt_tokens"],
793 | total_tokens=result.answer.usage["total_tokens"],
794 | )
795 | if result.answer.usage
796 | else None
797 | )
798 |
799 | return CursiveEnrichedAnswer(
800 | error=None,
801 | model=result.answer.model,
802 | id=result.answer.id,
803 | usage=usage,
804 | cost=result.answer.cost,
805 | choices=[choice["message"]["content"] for choice in result.answer.choices],
806 | function_result=result.answer.function_result or None,
807 | answer=result.answer.choices[-1]["message"]["content"],
808 | messages=result.messages,
809 | )
810 |
811 |
812 | class CursiveConversation:
813 | _cursive: Cursive
814 | messages: list[CompletionMessage]
815 |
816 | def __init__(self, messages: list[CompletionMessage]):
817 | self.messages = messages
818 |
819 | def ask(
820 | self,
821 | model: Optional[str | CursiveLanguageModel] = None,
822 | system_message: Optional[str] = None,
823 | functions: Optional[list[CursiveFunction]] = None,
824 | function_call: Optional[str | CursiveFunction] = None,
825 | on_token: Optional[CursiveAskOnToken] = None,
826 | max_tokens: Optional[int] = None,
827 | stop: Optional[list[str]] = None,
828 |         temperature: Optional[float] = None,
829 |         top_p: Optional[float] = None,
830 |         presence_penalty: Optional[float] = None,
831 |         frequency_penalty: Optional[float] = None,
832 |         best_of: Optional[int] = None,
833 |         n: Optional[int] = None,
834 |         logit_bias: Optional[dict[str, float]] = None,
835 | user: Optional[str] = None,
836 | stream: Optional[bool] = None,
837 | prompt: Optional[str] = None,
838 | ):
839 | messages = [
840 | *self.messages,
841 | ]
842 |
843 | result = build_answer(
844 | cursive=self._cursive,
845 | model=model,
846 | system_message=system_message,
847 | functions=functions,
848 | function_call=function_call,
849 | on_token=on_token,
850 | max_tokens=max_tokens,
851 | stop=stop,
852 | temperature=temperature,
853 | top_p=top_p,
854 | presence_penalty=presence_penalty,
855 | frequency_penalty=frequency_penalty,
856 | best_of=best_of,
857 | n=n,
858 | logit_bias=logit_bias,
859 | user=user,
860 | stream=stream,
861 | messages=messages,
862 | prompt=prompt,
863 | )
864 |
865 | if result and result.error:
866 | return CursiveAnswer[CursiveError](error=result.error)
867 |
868 | return CursiveAnswer[None](
869 | result=result,
870 | messages=result.messages,
871 | cursive=self._cursive,
872 | )
873 |
874 |
875 | def use_cursive(
876 | max_retries: Optional[int] = None,
877 | expand: Optional[CursiveSetupOptionsExpand] = None,
878 | debug: Optional[bool] = None,
879 | openai: Optional[dict[str, Any]] = None,
880 | anthropic: Optional[dict[str, Any]] = None,
881 | ):
882 | return Cursive(
883 | max_retries=max_retries,
884 | expand=expand,
885 | debug=debug,
886 | openai=openai,
887 | anthropic=anthropic,
888 | )
889 |
890 |
891 | E = TypeVar("E", None, CursiveError)
892 |
893 |
894 | class CursiveAnswer(Generic[E]):
895 | choices: Optional[list[str]]
896 | id: Optional[str]
897 | model: Optional[str | CursiveLanguageModel]
898 | usage: Optional[CursiveAskUsage]
899 | cost: Optional[CursiveAskCost]
900 | error: Optional[E]
901 | function_result: Optional[Any]
902 | # The text from the answer of the last choice
903 | answer: Optional[str]
904 | # A conversation instance with all the messages so far, including this one
905 | conversation: Optional[CursiveConversation]
906 |
907 | def __init__(
908 | self,
909 | result: Optional[Any] = None,
910 | error: Optional[E] = None,
911 | messages: Optional[list[CompletionMessage]] = None,
912 | cursive: Optional[Cursive] = None,
913 | ):
914 | if error:
915 | self.error = error
916 | self.choices = None
917 | self.id = None
918 | self.model = None
919 | self.usage = None
920 | self.cost = None
921 | self.answer = None
922 | self.conversation = None
923 |             self.function_result = None
924 | elif result:
925 | self.error = None
926 | self.choices = result.choices
927 | self.id = result.id
928 | self.model = result.model
929 | self.usage = result.usage
930 | self.cost = result.cost
931 | self.answer = result.answer
932 | self.function_result = result.function_result
933 | if messages:
934 | conversation = CursiveConversation(messages)
935 | if cursive:
936 | conversation._cursive = cursive
937 | self.conversation = conversation
938 |
939 | def __str__(self):
940 | if self.error:
941 | return f"CursiveAnswer(error={self.error})"
942 | else:
943 | return (
944 | f"CursiveAnswer(\n\tchoices={self.choices}\n\tid={self.id}\n\t"
945 | f"model={self.model}\n\tusage=(\n\t\t{self.usage}\n\t)\n\tcost=(\n\t\t{self.cost}\n\t)\n\t"
946 | f"answer={self.answer}\n\tconversation={self.conversation}\n)"
947 | )
948 |
949 |
950 | class CursiveVendors(BaseModel):
951 | openai: Optional[Any] = None
952 | anthropic: Optional[AnthropicClient] = None
953 | cohere: Optional[CohereClient] = None
954 | replicate: Optional[ReplicateClient] = None
955 |
--------------------------------------------------------------------------------
/cursive/custom_function_call.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Any, Callable
4 | from cursive.types import CompletionPayload
5 |
6 |
7 | def parse_custom_function_call(
8 |     data: dict[str, Any], payload: CompletionPayload, get_usage: Callable | None = None
9 | ):
10 | # We check for function call in the completion
11 |     has_function_call_regex = r"<function-call[^>]*>([^<]+)<\/function-call>"
12 | function_call_matches = re.findall(
13 | has_function_call_regex, data["choices"][0]["message"]["content"]
14 | )
15 |
16 | if len(function_call_matches) > 0:
17 | function_call = json.loads(function_call_matches.pop().strip())
18 | name = function_call["name"]
19 | arguments = json.dumps(function_call["arguments"])
20 | data["choices"][0]["message"]["function_call"] = {
21 | "name": name,
22 | "arguments": arguments,
23 | }
24 |
25 | # TODO: Implement cohere usage
26 | if get_usage:
27 | data["usage"]["prompt_tokens"] = get_usage(payload.messages)
28 | data["usage"]["completion_tokens"] = get_usage(
29 | data["choices"][0]["message"]["content"]
30 | )
31 | data["usage"]["total_tokens"] = (
32 | data["usage"]["completion_tokens"] + data["usage"]["prompt_tokens"]
33 | )
34 | else:
35 | data["usage"] = None
36 |
37 | # We check for answers in the completion
38 |     has_answer_regex = r"<cursive-answer>([^<]+)<\/cursive-answer>"
39 | answer_matches = re.findall(
40 | has_answer_regex, data["choices"][0]["message"]["content"]
41 | )
42 | if len(answer_matches) > 0:
43 | answer = answer_matches.pop().strip()
44 | data["choices"][0]["message"]["content"] = answer
45 |
46 |     # As a defensive measure, we check for <cursive-think> tags
47 |     # and remove them
48 |     has_think_regex = r"<cursive-think>([^<]+)<\/cursive-think>"
49 | think_matches = re.findall(
50 | has_think_regex, data["choices"][0]["message"]["content"]
51 | )
52 | if len(think_matches) > 0:
53 | data["choices"][0]["message"]["content"] = re.sub(
54 | has_think_regex, "", data["choices"][0]["message"]["content"]
55 | )
56 |
57 | # Strip leading and trailing whitespaces
58 | data["choices"][0]["message"]["content"] = data["choices"][0]["message"][
59 | "content"
60 | ].strip()
61 |
--------------------------------------------------------------------------------
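The `<function-call>` extraction in `parse_custom_function_call` can be tried standalone; the completion text below is a fabricated example of what a vendor without native function calling would emit:

```python
import json
import re

# Standalone sketch of the <function-call> tag extraction used by
# parse_custom_function_call; the completion content is made up.
content = (
    '<function-call>{"name": "add", "arguments": {"a": 1, "b": 2}}'
    "</function-call>"
)
matches = re.findall(r"<function-call[^>]*>([^<]+)<\/function-call>", content)
function_call = json.loads(matches.pop().strip())

print(function_call["name"])       # add
print(function_call["arguments"])  # {'a': 1, 'b': 2}
```

Note that `[^<]+` works only because JSON payloads never contain `<`, which keeps the match from running past the closing tag.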
/cursive/function.py:
--------------------------------------------------------------------------------
1 | import re
2 | from textwrap import dedent
3 | from typing import Any, Callable
4 |
5 | from cursive.compat.pydantic import BaseModel, validate_arguments
6 |
7 |
8 | class CursiveCustomFunction(BaseModel):
9 | definition: Callable
10 | description: str = ""
11 | function_schema: dict[str, Any]
12 | pause: bool = False
13 |
14 | class Config:
15 | arbitrary_types_allowed = True
16 |
17 |
18 | class CursiveFunction(CursiveCustomFunction):
19 | def __setup__(self, function: Callable):
20 | definition = function
21 | description = dedent(function.__doc__ or "").strip()
22 | parameters = validate_arguments(function).model.schema()
23 |
24 |
25 | # Delete ['v__duplicate_kwargs', 'args', 'kwargs'] from parameters
26 | for k in ["v__duplicate_kwargs", "args", "kwargs"]:
27 | if k in parameters["properties"]:
28 | del parameters["properties"][k]
29 |
30 | for k, v in parameters["properties"].items():
31 | # Find the parameter description in the docstring
32 | match = re.search(rf"{k}: (.*)", description)
33 | if match:
34 | v["description"] = match.group(1)
35 |
36 | schema = {}
37 | if parameters:
38 | schema = parameters
39 |
40 | properties = schema.get("properties") or {}
41 | definitions = schema.get("definitions") or {}
42 | resolved_properties = remove_key_deep(resolve_ref(properties, definitions), "title")
43 |
44 |
45 | function_schema = {
46 | "parameters": {
47 | "type": schema.get("type"),
48 | "properties": resolved_properties,
49 | "required": schema.get("required") or [],
50 | },
51 | "description": description,
52 | "name": parameters["title"],
53 | }
54 |
55 | return {
56 | "definition": definition,
57 | "description": description,
58 | "function_schema": function_schema,
59 | }
60 |
61 | def __init__(self, function: Callable, pause=False):
62 | setup = self.__setup__(function)
63 | super().__init__(**setup, pause=pause)
64 |
65 | def __call__(self, *args, **kwargs):
66 | # Validate arguments and parse them
67 | return self.definition(*args, **kwargs)
68 |
69 |
70 | def cursive_function(pause=False):
71 | def decorator(function: Callable = None):
72 | if function is None:
73 | return lambda function: CursiveFunction(function, pause=pause)
74 | else:
75 | return CursiveFunction(function, pause=pause)
76 |
77 | return decorator
78 |
79 | def resolve_ref(data, definitions):
80 |     """
81 |     Recursively checks for a $ref key in a dictionary and replaces the
82 |     referencing object with the matching entry from the definitions dictionary.
83 |
84 | Args:
85 | data (dict): The data dictionary to check for $ref keys.
86 | definitions (dict): The definitions dictionary to replace $ref keys with.
87 |
88 | Returns:
89 | dict: The data dictionary with $ref keys replaced.
90 | """
91 | if isinstance(data, dict):
92 | if "$ref" in data:
93 | ref = data["$ref"].split('/')[-1]
94 | if ref in definitions:
95 | definition = definitions[ref]
96 | data = definition
97 | else:
98 | for key, value in data.items():
99 | data[key] = resolve_ref(value, definitions)
100 | elif isinstance(data, list):
101 | for index, value in enumerate(data):
102 | data[index] = resolve_ref(value, definitions)
103 | return data
104 |
105 | def remove_key_deep(data, key):
106 | """
107 | Recursively removes a key from a dictionary.
108 |
109 | Args:
110 | data (dict): The data dictionary to remove the key from.
111 | key (str): The key to remove from the dictionary.
112 |
113 | Returns:
114 | dict: The data dictionary with the key removed.
115 | """
116 | if isinstance(data, dict):
117 | data.pop(key, None)
118 | for k, v in data.items():
119 | data[k] = remove_key_deep(v, key)
120 | elif isinstance(data, list):
121 | for index, value in enumerate(data):
122 | data[index] = remove_key_deep(value, key)
123 | return data
--------------------------------------------------------------------------------
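The docstring convention that `CursiveFunction.__setup__` relies on (the `re.search(rf"{k}: (.*)", description)` step) can be shown on its own; `add` below is a hypothetical function, not part of the library:

```python
import re
from textwrap import dedent

# Sketch of how CursiveFunction lifts per-parameter descriptions out of a
# docstring: each "name: text" line becomes that parameter's description.
def add(a: int, b: int) -> int:
    """
    Adds two numbers.

    a: the first operand
    b: the second operand
    """
    return a + b

description = dedent(add.__doc__ or "").strip()
for name in ("a", "b"):
    match = re.search(rf"{name}: (.*)", description)
    print(name, "->", match.group(1) if match else None)
# a -> the first operand
# b -> the second operand
```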
/cursive/hookable.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | from concurrent.futures import ThreadPoolExecutor
4 | from inspect import signature
5 | from typing import Any, Callable
6 | from warnings import warn
7 |
8 |
9 | def flatten_hooks_dictionary(
10 | hooks: dict[str, dict | Callable], parent_name: str | None = None
11 | ):
12 | flattened_hooks = {}
13 |
14 | for key, value in hooks.items():
15 | name = f"{parent_name}:{key}" if parent_name else key
16 |
17 | if isinstance(value, dict):
18 | flattened_hooks.update(flatten_hooks_dictionary(value, name))
19 | elif callable(value):
20 | flattened_hooks[name] = value
21 |
22 | return flattened_hooks
23 |
24 |
25 | def run_tasks_sequentially(tasks, handler):
26 | for task in tasks:
27 | handler(task)
28 |
29 |
30 | def merge_hooks(*hooks: dict[str, Any]):
31 | merged_hooks = dict()
32 |
33 | for hook in hooks:
34 | flattened_hook = flatten_hooks_dictionary(hook)
35 |
36 | for key, value in flattened_hook.items():
37 |             if key in merged_hooks:
38 | merged_hooks[key].append(value)
39 | else:
40 | merged_hooks[key] = [value]
41 |
42 | for key in merged_hooks.keys():
43 | if len(merged_hooks[key]) > 1:
44 | hooks_list = merged_hooks[key]
45 | merged_hooks[key] = lambda *arguments: run_tasks_sequentially(
46 | hooks_list, lambda hook: hook(*arguments)
47 | )
48 | else:
49 | merged_hooks[key] = merged_hooks[key][0]
50 |
51 | return merged_hooks
52 |
53 |
54 | def serial_caller(hooks: list[Callable], arguments: list[Any] = []):
55 | for hook in hooks:
56 | if len(signature(hook).parameters) > 0:
57 | hook(*arguments)
58 | else:
59 | hook()
60 |
61 |
62 | def concurrent_caller(hooks: list[Callable], arguments: list[Any] = []):
63 | with ThreadPoolExecutor() as executor:
64 | executor.map(lambda hook: hook(*arguments), hooks)
65 |
66 |
67 | def call_each_with(callbacks: list[Callable], argument: Any):
68 | for callback in callbacks:
69 | callback(argument)
70 |
71 |
72 | class Hookable:
73 | def __init__(self):
74 | self._hooks = {}
75 | self._before = []
76 | self._after = []
77 | self._deprecated_messages = set()
78 | self._deprecated_hooks = {}
79 |
80 | def hook(self, name: str, function: Callable | None, options={}):
81 | if not name or not callable(function):
82 | return lambda: None
83 |
84 | original_name = name
85 | deprecated_hook = {}
86 | while self._deprecated_hooks.get(name):
87 | deprecated_hook = self._deprecated_hooks[name]
88 | name = deprecated_hook["to"]
89 |
90 | message = None
91 |     if deprecated_hook and not options.get("allow_deprecated"):
92 | message = deprecated_hook["message"]
93 | if not message:
94 | message = f"{original_name} hook has been deprecated" + (
95 | f', please use {deprecated_hook["to"]}'
96 | if deprecated_hook["to"]
97 | else ""
98 | )
99 |
100 | if message not in self._deprecated_messages:
101 | warn(message)
102 | self._deprecated_messages.add(message)
103 |
104 |         if function.__name__ == "<lambda>":
105 | function.__name__ = "_" + re.sub(r"\W+", "_", name) + "_hook_cb"
106 |
107 |         self._hooks[name] = self._hooks.get(name) or []
108 | self._hooks[name].append(function)
109 |
110 | def remove():
111 | nonlocal function
112 | if function:
113 | self.remove_hook(name, function)
114 | function = None
115 |
116 | return remove
117 |
118 | def hook_once(self, name: str, function: Callable):
119 | hook = None
120 |
121 | def run_once(*arguments):
122 | nonlocal hook
123 | if callable(hook):
124 | hook()
125 |
126 | hook = None
127 | return function(*arguments)
128 |
129 | hook = self.hook(name, run_once)
130 | return hook
131 |
132 |     def remove_hook(self, name: str, function: Callable):
133 |         if self._hooks.get(name):
134 |             try:
135 |                 index = self._hooks[name].index(function)
136 |                 del self._hooks[name][index]
137 |                 # drop the hook list entirely once it becomes empty
138 |                 if len(self._hooks[name]) == 0:
139 |                     del self._hooks[name]
140 |             # if the function is not registered, ignore
141 |             except ValueError:
142 |                 pass
143 |
144 | def deprecate_hook(self, name: str, deprecated: Callable | str):
145 | self._deprecated_hooks[name] = (
146 | {"to": deprecated} if isinstance(deprecated, str) else deprecated
147 | )
148 |         hooks = self._hooks.get(name) or []
149 |         self._hooks.pop(name, None)
150 | for hook in hooks:
151 | self.hook(name, hook)
152 |
153 | def deprecate_hooks(self, deprecated_hooks: dict[str, Any]):
154 | self._deprecated_hooks.update(deprecated_hooks)
155 | for name in deprecated_hooks.keys():
156 | self.deprecate_hook(name, deprecated_hooks[name])
157 |
158 | def add_hooks(self, hooks: dict[str, Any]):
159 | hooks_to_be_added = flatten_hooks_dictionary(hooks)
160 | remove_fns = [self.hook(key, fn) for key, fn in hooks_to_be_added.items()]
161 |
162 | def function():
163 | for unreg in remove_fns:
164 | unreg()
165 | remove_fns[:] = []
166 |
167 | return function
168 |
169 | def remove_hooks(self, hooks: dict[str, Any]):
170 | hooks_to_be_removed = flatten_hooks_dictionary(hooks)
171 | for key, value in hooks_to_be_removed.items():
172 | self.remove_hook(key, value)
173 |
174 |     def remove_all_hooks(self):
175 |         # clear() avoids deleting keys while iterating over the dict
176 |         self._hooks.clear()
177 |
178 | def call_hook(self, name: str, *arguments: Any):
179 | return self.call_hook_with(serial_caller, name, *arguments)
180 |
181 | def call_hook_concurrent(self, name: str, *arguments: Any):
182 | return self.call_hook_with(concurrent_caller, name, *arguments)
183 |
184 | def call_hook_with(self, caller: Callable, name: str, *arguments: Any):
185 | event = {"name": name, "args": arguments, "context": {}}
186 |
187 | call_each_with(self._before, event)
188 |
189 | result = caller(self._hooks[name] if name in self._hooks else [], arguments)
190 |
191 | call_each_with(self._after, event)
192 |
193 | return result
194 |
195 | def before_each(self, function: Callable):
196 | self._before.append(function)
197 |
198 | def remove_from_before_list():
199 | try:
200 | index = self._before.index(function)
201 |                 del self._before[index]
202 | except ValueError:
203 | pass
204 |
205 | return remove_from_before_list
206 |
207 | def after_each(self, function: Callable):
208 | self._after.append(function)
209 |
210 | def remove_from_after_list():
211 | try:
212 | index = self._after.index(function)
213 |                 del self._after[index]
214 | except ValueError:
215 | pass
216 |
217 | return remove_from_after_list
218 |
219 |
220 | def create_hooks():
221 | return Hookable()
222 |
223 |
224 | def starts_with_predicate(prefix: str):
225 | return lambda name: name.startswith(prefix)
226 |
227 |
228 | def create_debugger(hooks: Hookable, _options: dict[str, Any] = {}):
229 |     options = {"filter": lambda _name: True, **_options}
230 |
231 | predicate = options["filter"]
232 | if isinstance(predicate, str):
233 | predicate = starts_with_predicate(predicate)
234 |
235 |     tag = f'[{options.get("tag")}] ' if options.get("tag") else ""
236 | start_times = {}
237 |
238 | def log_prefix(event: dict[str, Any]):
239 | return tag + event["name"] + "".ljust(int(event["id"]), "\0")
240 |
241 | id_control = {}
242 |
243 | def unsubscribe_before_each(event: dict[str, Any] | None = None):
244 | if event is None or not predicate(event["name"]):
245 | return
246 |
247 | id_control[event["name"]] = id_control.get(event.get("name")) or 0
248 | event["id"] = id_control[event["name"]]
249 | id_control[event["name"]] += 1
250 | start_times[log_prefix(event)] = time.time()
251 |
252 | unsubscribe_before = hooks.before_each(unsubscribe_before_each)
253 |
254 | def unsubscribe_after_each(event: dict[str, Any] | None = None):
255 | if event is None or not predicate(event["name"]):
256 | return
257 |
258 | label = log_prefix(event)
259 | elapsed_time = time.time() - start_times[label]
260 | print(f"Elapsed time for {label}: {elapsed_time} seconds")
261 |
262 | id_control[event["name"]] -= 1
263 |
264 | unsubscribe_after = hooks.after_each(unsubscribe_after_each)
265 |
266 |     def stop_debugging_and_remove_listeners():
267 |         unsubscribe_before()
268 |         unsubscribe_after()
269 |
270 |     return {"close": lambda: stop_debugging_and_remove_listeners()}
271 |
--------------------------------------------------------------------------------
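The hook-naming scheme above hinges on `flatten_hooks_dictionary`: nested dicts collapse into colon-separated event names like `completion:before`. A self-contained copy of the function (mirroring the listing above) makes the behaviour concrete:

```python
# Self-contained copy of flatten_hooks_dictionary from cursive/hookable.py,
# showing how nested hook dicts collapse into colon-separated names.
def flatten_hooks_dictionary(hooks, parent_name=None):
    flattened = {}
    for key, value in hooks.items():
        name = f"{parent_name}:{key}" if parent_name else key
        if isinstance(value, dict):
            flattened.update(flatten_hooks_dictionary(value, name))
        elif callable(value):
            flattened[name] = value
    return flattened

hooks = {
    "completion": {"before": lambda e: None, "after": lambda e: None},
    "ask:success": lambda e: None,
}
print(sorted(flatten_hooks_dictionary(hooks)))
# ['ask:success', 'completion:after', 'completion:before']
```

This is why `add_hooks` can accept either flat `{"completion:before": fn}` or nested `{"completion": {"before": fn}}` shapes interchangeably.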
/cursive/model.py:
--------------------------------------------------------------------------------
1 | from cursive.compat.pydantic import BaseModel
2 | from cursive.function import CursiveFunction
3 |
4 |
5 | class CursiveModel(CursiveFunction):
6 | def __init__(self, model: type[BaseModel]):
7 | super().__init__(model, pause=True)
8 |
9 |
10 | def cursive_model():
11 | def decorator(model: type[BaseModel] = None):
12 | if model is None:
13 |             return lambda model: CursiveModel(model)
14 | else:
15 | return CursiveModel(model)
16 |
17 | return decorator
18 |
--------------------------------------------------------------------------------
/cursive/pricing.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from .types import CursiveAskCost, CursiveAskUsage
4 | from .utils import destructure_items
5 |
6 | from .assets.price.anthropic import ANTHROPIC_PRICING
7 | from .assets.price.openai import OPENAI_PRICING
8 | from .assets.price.cohere import COHERE_PRICING
9 |
10 | VENDOR_PRICING = {
11 | "openai": OPENAI_PRICING,
12 | "anthropic": ANTHROPIC_PRICING,
13 | "cohere": COHERE_PRICING,
14 | }
15 |
16 |
17 | def resolve_pricing(
18 |     vendor: Literal["openai", "anthropic", "cohere"], usage: CursiveAskUsage, model: str
19 | ):
20 | if "/" in model:
21 | vendor, model = model.split("/")
22 |
23 | version: str
24 | prices: dict[str, dict[str, str]]
25 |
26 | version, prices = destructure_items(
27 | keys=["version"], dictionary=VENDOR_PRICING[vendor]
28 | )
29 |
30 | models_available = list(prices.keys())
31 | model_match = next((m for m in models_available if model.startswith(m)), None)
32 |
33 | if not model_match:
34 | raise Exception(f"Unknown model {model}")
35 |
36 | model_price = prices[model_match]
37 | completion = usage.completion_tokens * float(model_price["completion"]) / 1000
38 | prompt = usage.prompt_tokens * float(model_price["prompt"]) / 1000
39 |
40 | cost = CursiveAskCost(
41 | completion=completion, prompt=prompt, total=completion + prompt, version=version
42 | )
43 |
44 | return cost
45 |
--------------------------------------------------------------------------------
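The cost arithmetic in `resolve_pricing` is per 1,000 tokens. A quick sketch with a made-up price entry (the real tables live under `cursive/assets/price/`):

```python
import math

# Sketch of resolve_pricing's cost arithmetic; the price entry is
# illustrative, not a real vendor price.
model_price = {"completion": "0.002", "prompt": "0.0015"}  # USD per 1K tokens
completion_tokens, prompt_tokens = 150, 420

completion_cost = completion_tokens * float(model_price["completion"]) / 1000
prompt_cost = prompt_tokens * float(model_price["prompt"]) / 1000
total_cost = completion_cost + prompt_cost

assert math.isclose(total_cost, 0.00093)
print(f"{total_cost:.5f}")  # 0.00093
```

Prices are stored as strings in the tables, hence the `float(...)` conversions before the division by 1000.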
/cursive/stream.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Callable, Literal, Optional
3 | from cursive.types import CompletionPayload, CursiveAskOnToken
4 | from cursive.hookable import create_hooks
5 | from cursive.utils import random_id
6 |
7 |
8 | class StreamTransformer:
9 | def __init__(
10 | self,
11 | payload: CompletionPayload,
12 | response: Any,
13 | on_token: Optional[CursiveAskOnToken] = None,
14 | ):
15 | self._hooks = create_hooks()
16 | self.payload = payload
17 | self.response = response
18 | self.on_token = on_token
19 |
20 | def on(self, event: Literal["get_current_token"], function: Callable):
21 | self._hooks.hook(event, function)
22 |
23 | def call_output_hook(self, event: str, payload: Any):
24 | output = HookOutput(payload)
25 | self._hooks.call_hook(event, output)
26 | return output.value
27 |
28 | def process(self):
29 | data = {
30 | "choices": [{"message": {"content": ""}}],
31 | "usage": {
32 | "completion_tokens": 0,
33 | "prompt_tokens": 0,
34 | },
35 | "model": self.payload.model,
36 | "id": random_id(),
37 | }
38 | completion = ""
39 |
40 | for part in self.response:
41 | # The completion partial will come with a leading whitespace
42 | current_token = self.call_output_hook("get_current_token", part)
43 | if not data["choices"][0]["message"]["content"]:
44 | completion = completion.lstrip()
45 | completion += current_token
46 |             # Check if there's any <function-call> tag. The regex should allow for nested tags
47 |             function_call_tag = re.findall(
48 |                 r"<function-call>([\s\S]*?)(?=<\/function-call>|$)", completion
49 | )
50 | function_name = ""
51 | function_arguments = ""
52 | if len(function_call_tag) > 0:
53 | # Remove starting and ending tags, even if the ending tag is partial or missing
54 | function_call = re.sub(
55 | r"^\n|\n$",
56 | "",
57 | re.sub(
58 | r"<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?", "", function_call_tag[0]
59 | ).strip(),
60 | ).strip()
61 | # Match the function name inside the JSON
62 | function_name_matches = re.findall(r'"name":\s*"(.+)"', function_call)
63 | function_name = (
64 | len(function_name_matches) > 0 and function_name_matches[0]
65 | )
66 | function_arguments_matches = re.findall(
67 | r'"arguments":\s*(\{.+)\}?', function_call, re.S
68 | )
69 | function_arguments = (
70 | len(function_arguments_matches) > 0
71 | and function_arguments_matches[0]
72 | )
73 | if function_arguments:
74 |                     # If there's an unmatched } at the end, remove it
75 | unmatched_brackets = re.findall(r"(\{|\})", function_arguments)
76 | if len(unmatched_brackets) % 2:
77 | function_arguments = re.sub(
78 | r"\}$", "", function_arguments.strip()
79 | )
80 |
81 | function_arguments = function_arguments.strip()
82 |
83 |             cursive_answer_tag = re.findall(
84 |                 r"<cursive-answer>([\s\S]*?)(?=<\/cursive-answer>|$)", completion
85 |             )
86 | tagged_answer = ""
87 | if cursive_answer_tag:
88 | tagged_answer = re.sub(
89 | r"<\/?c?u?r?s?i?v?e?-?a?n?s?w?e?r?>?", "", cursive_answer_tag[0]
90 | ).lstrip()
91 |
92 | current_token = completion[len(data["choices"][0]["message"]["content"]) :]
93 | data["choices"][0]["message"]["content"] += current_token
94 |
95 | if self.on_token:
96 | chunk = None
97 |
98 | if self.payload.functions:
99 | if function_name:
100 | chunk = {
101 | "function_call": {},
102 | "content": None,
103 | }
104 | if function_arguments:
105 |                             # Stream the accumulated arguments parsed so far
106 | chunk["function_call"]["arguments"] = function_arguments
107 | else:
108 | chunk["function_call"] = {
109 | "name": function_name,
110 | "arguments": "",
111 | }
112 | elif tagged_answer:
113 | # Token is at the end of the tagged answer
114 |                     regex = rf"(.*){re.escape(current_token.strip())}$"
115 | match = re.findall(regex, tagged_answer)
116 | if len(match) > 0 and current_token:
117 | chunk = {
118 | "function_call": None,
119 | "content": current_token,
120 | }
121 | else:
122 | chunk = {
123 | "content": current_token,
124 | }
125 |
126 | if chunk:
127 | self.on_token(chunk)
128 |
129 | return data
130 |
131 |
132 | class HookOutput:
133 | value: Any
134 |
135 | def __init__(self, value: Any = None):
136 | self.value = value
137 |
--------------------------------------------------------------------------------
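The tag parsing above leans on one regex trick worth noting: making every character of the tag optional (`<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?`) lets a single `re.sub` strip complete, partial, or truncated `<function-call>` tags from a streamed chunk. A minimal self-contained sketch of that trick (the helper name is hypothetical):

```python
import re

# Every character after "<" is optional, so the pattern matches
# "<function-call>", "</function-call>", and truncated prefixes
# such as "</funct" that arrive mid-stream.
TAG_PATTERN = r"<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?"

def strip_function_call_tags(chunk: str) -> str:
    """Remove complete or partial <function-call> tags from a streamed chunk."""
    return re.sub(TAG_PATTERN, "", chunk).strip()

print(strip_function_call_tags('<function-call>{"name": "add"}</funct'))
# {"name": "add"}
```

The trade-off: the pattern also eats any stray `<` in the content, which is acceptable here because the tagged payloads are JSON.
--------------------------------------------------------------------------------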
/cursive/tests/test_function_compatibility.py:
--------------------------------------------------------------------------------
1 | def test_pydantic_compatibility():
2 | from cursive.function import cursive_function
3 |
4 | # Define a function with Pydantic v1 and v2 compatible annotations
5 | @cursive_function()
6 | def test_function(name: str, age: int):
7 | """
8 | A test function.
9 | Args:
10 | name: The name of a person.
11 | age: The age of the person.
12 | """
13 | return f"{name} is {age} years old."
14 |
15 | # Test the function with some arguments
16 | assert test_function("John", 30) == "John is 30 years old."
17 |
18 | # Check the function schema
19 | expected_schema = {
20 | "parameters": {
21 | "type": "object",
22 | "properties": {
23 | "name": {"type": "string", "description": "The name of a person.", "title": "Name"},
24 | "age": {"type": "integer", "description": "The age of the person.", "title": "Age"},
25 | },
26 | "required": ["name", "age"],
27 | },
28 | "description": "A test function.\nArgs:\n name: The name of a person.\n age: The age of the person.",
29 | "name": "TestFunction",
30 | }
31 | assert test_function.function_schema == expected_schema
32 |
--------------------------------------------------------------------------------
/cursive/tests/test_function_schema.py:
--------------------------------------------------------------------------------
1 | from cursive.compat.pydantic import BaseModel
2 | from cursive.function import cursive_function
3 |
4 | def test_function_schema_allows_arbitrary_types():
5 |
6 | class Character(BaseModel):
7 | name: str
8 | age: int
9 |
10 | @cursive_function()
11 | def gen_arbitrary_type(character: Character):
12 | """
13 | A test function.
14 |
15 | character: A character.
16 | """
17 | return f"{character.name} is {character.age} years old."
18 |
19 | assert 'description' in gen_arbitrary_type.function_schema
--------------------------------------------------------------------------------
/cursive/tests/test_setup.py:
--------------------------------------------------------------------------------
1 | from cursive import Cursive
2 |
3 | def test_cursive_setup():
4 | cursive = Cursive()
5 | assert cursive is not None
--------------------------------------------------------------------------------
/cursive/types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Any, Callable, Literal, Optional
3 |
4 | from cursive.compat.pydantic import BaseModel as PydanticBaseModel, Field
5 | from cursive.function import CursiveFunction
6 | from cursive.utils import random_id
7 |
8 |
9 | class BaseModel(PydanticBaseModel):
10 | class Config(PydanticBaseModel.Config):
11 | arbitrary_types_allowed = True
12 | protected_namespaces = ()
13 |
14 |
15 | class CursiveLanguageModel(Enum):
16 |     GPT4 = "gpt-4"
17 |     GPT4_32K = "gpt-4-32k"
18 | GPT3_5_TURBO = "gpt-3.5-turbo"
19 | GPT3_5_TURBO_16K = "gpt-3.5-turbo-16k"
20 | CLAUDE_2 = "claude-2"
21 | CLAUDE_INSTANT_1_2 = "claude-instant-1.2"
22 |     CLAUDE_INSTANT_1 = "claude-instant-1"
23 | COMMAND = "command"
24 | COMMAND_NIGHTLY = "command-nightly"
25 |
26 |
27 | class CursiveAskUsage(BaseModel):
28 | completion_tokens: int
29 | prompt_tokens: int
30 | total_tokens: int
31 |
32 |
33 | class CursiveAskCost(BaseModel):
34 | completion: float
35 | prompt: float
36 | total: float
37 | version: str
38 |
39 |
40 | Role = Literal[
41 | "system",
42 | "user",
43 | "assistant",
44 | "function",
45 | ]
46 |
47 |
48 | class ChatCompletionRequestMessageFunctionCall(BaseModel):
49 | name: Optional[str] = None
50 | arguments: Optional[str] = None
51 |
52 |
53 | class CompletionMessage(BaseModel):
54 | id: Optional[str] = None
55 | role: Role
56 | content: Optional[str] = None
57 | name: Optional[str] = None
58 | function_call: Optional[ChatCompletionRequestMessageFunctionCall] = None
59 |
60 | def __init__(self, **data):
61 | super().__init__(**data)
62 | if not self.id:
63 | self.id = random_id()
64 |
65 |
66 | class CursiveSetupOptionsExpand(BaseModel):
67 | enabled: Optional[bool] = None
68 | defaults_to: Optional[str] = None
69 | model_mapping: Optional[dict[str, str]] = None
70 |
71 |
72 | class CursiveErrorCode(Enum):
73 |     function_call_error = "function_call_error"
74 |     completion_error = "completion_error"
75 |     invalid_request_error = "invalid_request_error"
76 |     embedding_error = "embedding_error"
77 |     unknown_error = "unknown_error"
78 |
79 |
80 | class CursiveError(Exception):
81 | name = "CursiveError"
82 |
83 | def __init__(
84 | self,
85 | message: str,
86 | details: Any,
87 | code: CursiveErrorCode,
88 | ):
89 | self.message = message
90 | self.details = details
91 | self.code = code
92 | super().__init__(self.message)
93 |
94 |
95 | class CursiveEnrichedAnswer(BaseModel):
96 | error: CursiveError | None = None
97 | usage: CursiveAskUsage | None = None
98 | model: str
99 | id: str | None = None
100 | choices: Optional[list[Any]] = None
101 | function_result: Any | None = None
102 | answer: str | None = None
103 | messages: list[CompletionMessage] | None = None
104 | cost: CursiveAskCost | None = None
105 |
106 |
107 | CursiveAskOnToken = Callable[[dict[str, Any]], None]
108 |
109 |
110 | class CursiveAskOptionsBase(BaseModel):
111 | model: Optional[str | CursiveLanguageModel] = None
112 | system_message: Optional[str] = None
113 | functions: Optional[list[CursiveFunction]] = None
114 | function_call: Optional[str | CursiveFunction] = None
115 | on_token: Optional[CursiveAskOnToken] = None
116 | max_tokens: Optional[int] = None
117 | stop: Optional[list[str]] = None
118 | temperature: Optional[float] = None
119 | top_p: Optional[float] = None
120 | presence_penalty: Optional[float] = None
121 | frequency_penalty: Optional[float] = None
122 | best_of: Optional[int] = None
123 | n: Optional[int] = None
124 | logit_bias: Optional[dict[str, float]] = None
125 | user: Optional[str] = None
126 | stream: Optional[bool] = None
127 |
128 |
129 | class CreateChatCompletionResponse(BaseModel):
130 | id: str
131 | model: str
132 | choices: list[Any]
133 | usage: Optional[Any] = None
134 |
135 |
136 | class CreateChatCompletionResponseExtended(CreateChatCompletionResponse):
137 | function_result: Optional[Any] = None
138 | cost: Optional[CursiveAskCost] = None
139 | error: Optional[CursiveError] = None
140 |
141 |
142 | class CursiveAskModelResponse(BaseModel):
143 | answer: CreateChatCompletionResponseExtended
144 | messages: list[CompletionMessage]
145 |
146 |
147 | class CursiveSetupOptions(BaseModel):
148 | max_retries: Optional[int] = None
149 | expand: Optional[CursiveSetupOptionsExpand] = None
150 | is_using_openrouter: Optional[bool] = None
151 |
152 |
153 | class CompletionRequestFunctionCall(BaseModel):
154 | name: str
155 | inputs: dict[str, Any] = Field(default_factory=dict)
156 |
157 |
158 | class CompletionRequestStop(BaseModel):
159 | messages_seen: Optional[list[CompletionMessage]] = None
160 | max_turns: Optional[int] = None
161 |
162 |
163 | class CompletionFunctions(BaseModel):
164 | name: str
165 | description: Optional[str] = None
166 | parameters: Optional[dict[str, Any]] = None
167 |
168 |
169 | class CompletionPayload(BaseModel):
170 | model: str
171 | messages: list[CompletionMessage]
172 | functions: Optional[list[CompletionFunctions]] = None
173 | function_call: Optional[CompletionRequestFunctionCall] = None
174 | temperature: Optional[float] = None
175 | top_p: Optional[float] = None
176 | n: Optional[int] = None
177 | stream: Optional[bool] = None
178 | stop: Optional[CompletionRequestStop] = None
179 | max_tokens: Optional[int] = None
180 | presence_penalty: Optional[float] = None
181 | frequency_penalty: Optional[float] = None
182 | logit_bias: Optional[dict[str, float]] = None
183 | user: Optional[str] = None
184 | other: Optional[dict[str, Any]] = None
185 |
186 |
187 | CursiveHook = Literal[
188 | "embedding:before",
189 | "embedding:after",
190 | "embedding:error",
191 | "embedding:success",
192 | "completion:before",
193 | "completion:after",
194 | "completion:error",
195 | "completion:success",
196 | "ask:before",
197 | "ask:after",
198 | "ask:success",
199 | "ask:error",
200 | ]
201 |
202 |
203 | class CursiveHookPayload:
204 | data: Optional[Any]
205 | error: Optional[CursiveError]
206 | duration: Optional[float]
207 |
208 | def __init__(
209 | self,
210 | data: Optional[Any] = None,
211 | error: Optional[CursiveError] = None,
212 | duration: Optional[float] = None,
213 | ):
214 | self.data = data
215 | self.error = error
216 | self.duration = duration
217 |
--------------------------------------------------------------------------------
/cursive/usage/anthropic.py:
--------------------------------------------------------------------------------
1 | from anthropic import Anthropic
2 |
3 | from cursive.build_input import build_completion_input
4 |
5 | from ..types import CompletionMessage
6 |
7 |
8 | def get_anthropic_usage(content: str | list[CompletionMessage]):
9 | client = Anthropic()
10 |
11 |     if not isinstance(content, str):
12 | content = build_completion_input(content)
13 |
14 | return client.count_tokens(content)
15 |
--------------------------------------------------------------------------------
/cursive/usage/cohere.py:
--------------------------------------------------------------------------------
1 | from tokenizers import Tokenizer
2 | from cursive.build_input import build_completion_input
3 | from cursive.types import CompletionMessage
4 |
5 |
6 | def get_cohere_usage(content: str | list[CompletionMessage]):
7 | tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly")
8 |
9 |     if not isinstance(content, str):
10 | content = build_completion_input(content)
11 |
12 | return len(tokenizer.encode(content).ids)
13 |
--------------------------------------------------------------------------------
/cursive/usage/openai.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import AbstractSet, Collection, Literal
3 |
4 | import tiktoken
5 |
6 | from cursive.types import CompletionMessage
7 |
8 |
9 | def encode(
10 | text: str,
11 | allowed_special: AbstractSet[str] | Literal["all"] = set(),
12 | disallowed_special: Collection[str] | Literal["all"] = "all",
13 | ) -> list[int]:
14 | enc = tiktoken.get_encoding("cl100k_base")
15 |
16 | return enc.encode(
17 | text, allowed_special=allowed_special, disallowed_special=disallowed_special
18 | )
19 |
20 |
21 | def get_openai_usage(content: str | list[CompletionMessage]):
22 |     if isinstance(content, list):
23 | tokens = {
24 | "per_message": 3,
25 | "per_name": 1,
26 | }
27 |
28 | token_count = 3
29 | for message in content:
30 | token_count += tokens["per_message"]
31 | for attribute, value in message.dict().items():
32 | if attribute == "name":
33 | token_count += tokens["per_name"]
34 |
35 | if isinstance(value, dict):
36 | value = json.dumps(value, separators=(",", ":"))
37 |
38 | if value is None:
39 | continue
40 |
41 | token_count += len(encode(value))
42 |
43 | return token_count
44 | else:
45 | return len(encode(content)) # type: ignore
46 |
47 |
48 | def get_token_count_from_functions(functions: list[dict]):
49 | token_count = 3
50 | for fn in functions:
51 | function_tokens = len(encode(fn["name"]))
52 | function_tokens += len(encode(fn["description"])) if fn["description"] else 0
53 |
54 | if fn["parameters"] and fn["parameters"]["properties"]:
55 | properties = fn["parameters"]["properties"]
56 | for key in properties:
57 | function_tokens += len(encode(key))
58 | value = properties[key]
59 | for field in value:
60 | if field in ["type", "description"]:
61 | function_tokens += 2
62 | function_tokens += len(encode(value[field]))
63 | elif field == "enum":
64 | function_tokens -= 3
65 | for enum_value in value[field]:
66 | function_tokens += 3
67 | function_tokens += len(encode(enum_value))
68 |
69 | function_tokens += 11
70 |
71 | token_count += function_tokens
72 |
73 | token_count += 12
74 | return token_count
75 |
--------------------------------------------------------------------------------
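The per-message arithmetic in `get_openai_usage` mirrors OpenAI's published chat-format heuristic: 3 tokens of overhead per message, 1 extra when a `name` field is present, and a fixed 3-token primer for the assistant's reply. A sketch of just that accounting, with a whitespace-split stand-in for the encoder (an assumption for illustration; the real code uses tiktoken's `cl100k_base`):

```python
import json

# Stand-in encoder for illustration only; the real code encodes
# with tiktoken's cl100k_base.
def encode(text: str) -> list[str]:
    return text.split()

def count_message_tokens(messages: list[dict]) -> int:
    per_message, per_name = 3, 1
    token_count = 3  # every reply is primed with a fixed 3-token prefix
    for message in messages:
        token_count += per_message
        for attribute, value in message.items():
            if value is None:
                continue
            if attribute == "name":
                token_count += per_name
            if isinstance(value, dict):
                value = json.dumps(value, separators=(",", ":"))
            token_count += len(encode(value))
    return token_count

print(count_message_tokens([{"role": "user", "content": "hello there"}]))
# 9  (3 base + 3 per message + 1 for "user" + 2 for the content, with this stub encoder)
```

With the real encoder only the `len(encode(...))` terms change; the fixed overheads are the same.
--------------------------------------------------------------------------------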
/cursive/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 |
4 | from typing import Any, Callable, Optional, Tuple, Type, TypeVar, overload
5 |
6 |
7 | def destructure_items(keys: list[str], dictionary: dict):
8 | items = [dictionary[key] for key in keys]
9 |
10 | new_dictionary = {k: v for k, v in dictionary.items() if k not in keys}
11 |
12 | return *items, new_dictionary
13 |
14 |
15 | def without_nones(dictionary: dict):
16 | return {k: v for k, v in dictionary.items() if v is not None}
17 |
18 |
19 | def random_id():
20 | characters = string.ascii_lowercase + string.digits
21 | random_id = "".join(random.choice(characters) for _ in range(10))
22 | return random_id
23 |
24 |
25 | def delete_keys_from_dict(dictionary: dict, keys: list[str]):
26 | return {k: v for k, v in dictionary.items() if k not in set(keys)}
27 |
28 |
29 | T = TypeVar("T", bound=Exception)
30 |
31 |
32 | @overload
33 | def resguard(function: Callable) -> Tuple[Any, Exception | None]:
34 | ...
35 |
36 |
37 | @overload
38 | def resguard(function: Callable, exception_type: Type[T]) -> Tuple[Any, T | None]:
39 | ...
40 |
41 |
42 | def resguard(
43 | function: Callable, exception_type: Optional[Type[T]] = None
44 | ) -> Tuple[Any, T | Exception | None]:
45 | try:
46 | return function(), None
47 | except Exception as e:
48 | if exception_type:
49 | if isinstance(e, exception_type):
50 | return None, e
51 | else:
52 | raise e
53 | else:
54 | return None, e
55 |
--------------------------------------------------------------------------------
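`resguard` above turns exceptions into values, Go-style: it returns `(result, None)` on success and `(None, error)` on failure, re-raising anything that is not of the expected type. A self-contained sketch of the same shape:

```python
from typing import Any, Callable, Optional, Tuple, Type, TypeVar

T = TypeVar("T", bound=Exception)

def resguard(
    function: Callable, exception_type: Optional[Type[T]] = None
) -> Tuple[Any, Any]:
    try:
        return function(), None
    except Exception as e:
        # Re-raise anything that is not the expected exception type.
        if exception_type and not isinstance(e, exception_type):
            raise
        return None, e

value, error = resguard(lambda: 1 / 0, ZeroDivisionError)
print(value, type(error).__name__)  # None ZeroDivisionError
```

Callers can branch on `error` instead of wrapping every vendor call in try/except, which matches how the cohere and replicate clients report failures as `(response, None)` / `(None, e)` pairs.
--------------------------------------------------------------------------------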
/cursive/vendor/anthropic.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | from anthropic import Anthropic
4 |
5 | from cursive.build_input import build_completion_input
6 | from cursive.stream import StreamTransformer
7 |
8 | from ..types import (
9 | CompletionPayload,
10 | CursiveAskOnToken,
11 | )
12 | from ..utils import without_nones
13 |
14 |
15 | class AnthropicClient:
16 | client: Anthropic
17 |
18 | def __init__(self, api_key: str):
19 | self.client = Anthropic(api_key=api_key)
20 |
21 | def create_completion(self, payload: CompletionPayload):
22 | prompt = build_completion_input(payload.messages)
23 | payload = without_nones(
24 | {
25 | "model": payload.model,
26 | "max_tokens_to_sample": payload.max_tokens or 100000,
27 | "prompt": prompt,
28 | "temperature": payload.temperature or 0.7,
29 | "top_p": payload.top_p,
30 | "stop_sequences": payload.stop,
31 | "stream": payload.stream or False,
32 | **(payload.other or {}),
33 | }
34 | )
35 | return self.client.completions.create(**payload)
36 |
37 |
38 | def process_anthropic_stream(
39 | payload: CompletionPayload,
40 | response: Any,
41 | on_token: Optional[CursiveAskOnToken] = None,
42 | ):
43 | stream_transformer = StreamTransformer(
44 | on_token=on_token,
45 | payload=payload,
46 | response=response,
47 | )
48 |
49 | def get_current_token(part):
50 | part.value = part.value.completion
51 |
52 | stream_transformer.on("get_current_token", get_current_token)
53 | return stream_transformer.process()
54 |
--------------------------------------------------------------------------------
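Each vendor adapter registers a `get_current_token` hook that mutates a mutable holder in place (`HookOutput` in stream.py) rather than returning a value, which keeps `StreamTransformer` vendor-agnostic. A self-contained sketch of that pattern, with a stand-in part object (`FakeAnthropicPart` is hypothetical):

```python
from typing import Any, Callable

class HookOutput:
    """Mutable holder, mirroring cursive.stream.HookOutput."""
    def __init__(self, value: Any = None):
        self.value = value

def call_output_hook(hook: Callable, payload: Any) -> Any:
    output = HookOutput(payload)
    hook(output)  # the hook rewrites output.value in place
    return output.value

class FakeAnthropicPart:
    completion = " hello"  # Anthropic stream parts expose .completion

def get_current_token(part):
    # Same shape as the hook in vendor/anthropic.py: unwrap the vendor field.
    part.value = part.value.completion

print(repr(call_output_hook(get_current_token, FakeAnthropicPart())))  # ' hello'
```

The cohere adapter differs only in the unwrapped attribute (`part.value.text` instead of `.completion`).
--------------------------------------------------------------------------------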
/cursive/vendor/cohere.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 | from cohere import Client as Cohere
3 | from cursive.build_input import build_completion_input
4 |
5 | from cursive.types import CompletionPayload, CursiveAskOnToken
6 | from cursive.stream import StreamTransformer
7 | from cursive.utils import without_nones
8 |
9 |
10 | class CohereClient:
11 | client: Cohere
12 |
13 | def __init__(self, api_key: str):
14 | self.client = Cohere(api_key=api_key)
15 |
16 | def create_completion(self, payload: CompletionPayload):
17 | prompt = build_completion_input(payload.messages)
18 | payload = without_nones(
19 | {
20 | "model": payload.model,
21 | "max_tokens": payload.max_tokens or 3000,
22 | "prompt": prompt.rstrip(),
23 | "temperature": payload.temperature
24 | if payload.temperature is not None
25 | else 0.7,
26 | "stop_sequences": payload.stop,
27 | "stream": payload.stream or False,
28 | "frequency_penalty": payload.frequency_penalty,
29 | "p": payload.top_p,
30 | }
31 | )
32 | try:
33 | response = self.client.generate(
34 | **payload,
35 | )
36 |
37 | return response, None
38 | except Exception as e:
39 | return None, e
40 |
41 |
42 | def process_cohere_stream(
43 | payload: CompletionPayload,
44 | response: Any,
45 | on_token: Optional[CursiveAskOnToken] = None,
46 | ):
47 | stream_transformer = StreamTransformer(
48 | on_token=on_token,
49 | payload=payload,
50 | response=response,
51 | )
52 |
53 | def get_current_token(part):
54 | part.value = part.value.text
55 |
56 | stream_transformer.on("get_current_token", get_current_token)
57 |
58 | return stream_transformer.process()
59 |
--------------------------------------------------------------------------------
/cursive/vendor/index.py:
--------------------------------------------------------------------------------
1 | vendors_and_model_prefixes = {
2 | "openai": ["gpt-3.5", "gpt-4"],
3 | "anthropic": ["claude-instant", "claude-2"],
4 | "cohere": ["command"],
5 | "replicate": ["replicate"],
6 | }
7 |
8 |
9 | def resolve_vendor_from_model(model: str):
10 | for vendor, prefixes in vendors_and_model_prefixes.items():
11 | if any(model.startswith(m) for m in prefixes):
12 | return vendor
13 |
14 | return ""
15 |
--------------------------------------------------------------------------------
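Vendor resolution is a straight prefix match over the table above, so models from unknown families fall through to an empty string. For instance:

```python
# Same table and prefix match as cursive/vendor/index.py.
vendors_and_model_prefixes = {
    "openai": ["gpt-3.5", "gpt-4"],
    "anthropic": ["claude-instant", "claude-2"],
    "cohere": ["command"],
    "replicate": ["replicate"],
}

def resolve_vendor_from_model(model: str) -> str:
    for vendor, prefixes in vendors_and_model_prefixes.items():
        if any(model.startswith(prefix) for prefix in prefixes):
            return vendor
    return ""  # unknown model families resolve to an empty string

print(resolve_vendor_from_model("gpt-4-32k"))        # openai
print(resolve_vendor_from_model("command-nightly"))  # cohere
```
--------------------------------------------------------------------------------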
/cursive/vendor/openai.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | from ..types import CursiveAskOnToken
4 | from ..usage.openai import get_openai_usage, get_token_count_from_functions
5 |
6 |
7 | def process_openai_stream(
8 | payload: Any,
9 | cursive: Any,
10 | response: Any,
11 | on_token: Optional[CursiveAskOnToken] = None,
12 | ):
13 | data = {
14 | "choices": [],
15 | "usage": {
16 | "completion_tokens": 0,
17 | "prompt_tokens": get_openai_usage(payload.messages),
18 | },
19 | "model": payload.model,
20 | }
21 |
22 | if payload.functions:
23 | data["usage"]["prompt_tokens"] += get_token_count_from_functions(
24 | payload.functions
25 | )
26 |
27 | for slice in response:
28 | data = {
29 | **data,
30 | "id": slice["id"],
31 | }
32 |
33 | for i in range(len(slice["choices"])):
34 | delta = slice["choices"][i]["delta"]
35 |
36 | if len(data["choices"]) <= i:
37 | data["choices"].append(
38 | {
39 | "message": {
40 | "function_call": None,
41 | "role": "assistant",
42 | "content": "",
43 | },
44 | }
45 | )
46 |
47 | if (
48 | delta
49 | and delta.get("function_call")
50 | and delta["function_call"].get("name")
51 | ):
52 | data["choices"][i]["message"]["function_call"] = delta["function_call"]
53 |
54 | if (
55 | delta
56 | and delta.get("function_call")
57 | and delta["function_call"].get("arguments")
58 | ):
59 | data["choices"][i]["message"]["function_call"]["arguments"] += delta[
60 | "function_call"
61 | ]["arguments"]
62 |
63 | if delta and delta.get("content"):
64 | data["choices"][i]["message"]["content"] += delta["content"]
65 |
66 | if on_token:
67 | chunk = None
68 | if delta and delta.get("function_call"):
69 | chunk = {
70 | "function_call": {
71 | k: v for k, v in delta["function_call"].items()
72 | },
73 | "content": None,
74 | }
75 |
76 | if delta and delta.get("content"):
77 | chunk = {"content": delta["content"], "function_call": None}
78 |
79 | if chunk:
80 | on_token(chunk)
81 |
82 | return data
83 |
--------------------------------------------------------------------------------
/cursive/vendor/replicate.py:
--------------------------------------------------------------------------------
1 | import replicate
2 |
3 | from cursive.build_input import build_completion_input
4 | from cursive.types import CompletionPayload
5 | from cursive.utils import without_nones
6 |
7 |
8 | class ReplicateClient:
9 | client: replicate.Client
10 |
11 | def __init__(self, api_key: str):
12 | self.client = replicate.Client(api_key)
13 |
14 | def create_completion(self, payload: CompletionPayload): # noqa: F821
15 | prompt = build_completion_input(payload.messages)
16 | # Resolve model ID from `replicate/`
17 | version = payload.model[payload.model.find("/") + 1 :]
18 | resolved_payload = without_nones(
19 | {
20 | "max_new_tokens": payload.max_tokens or 2000,
21 | "max_length": payload.max_tokens or 2000,
22 | "prompt": prompt,
23 | "temperature": payload.temperature or 0.7,
24 | "top_p": payload.top_p,
25 | "stop": payload.stop,
26 | "model": version,
27 | "stream": bool(payload.stream),
28 | **(payload.other or {}),
29 | }
30 | )
31 | try:
32 | response = self.client.run(
33 | version,
34 | input=resolved_payload,
35 | )
36 |
37 | return response, None
38 | except Exception as e:
39 | print("e", e)
40 | return None, e
41 |
--------------------------------------------------------------------------------
/docs/logo-dark.svg:
--------------------------------------------------------------------------------
1 |
5 |
--------------------------------------------------------------------------------
/docs/logo-light.svg:
--------------------------------------------------------------------------------
1 |
5 |
--------------------------------------------------------------------------------
/examples/add-function.py:
--------------------------------------------------------------------------------
1 | from cursive import cursive_function, Cursive
2 |
3 | cursive = Cursive()
4 |
5 | @cursive_function()
6 | def add(a: float, b: float):
7 | """
8 | Adds two numbers.
9 |
10 | a: The first number.
11 | b: The second number.
12 | """
13 | return a + b
14 |
15 | res = cursive.ask(
16 | prompt='What is the sum of 232 and 243?',
17 | functions=[add],
18 | )
19 |
20 | print({
21 | 'model': res.model,
22 | 'message': res.answer,
23 | 'usage': res.usage.total_tokens,
24 | 'conversation': res.conversation.messages,
25 | })
--------------------------------------------------------------------------------
/examples/ai-cli.py:
--------------------------------------------------------------------------------
1 | from cursive import Cursive, cursive_function
2 | import argparse
3 | import subprocess
4 |
5 | parser = argparse.ArgumentParser(description='Solves your CLI problems')
6 | parser.add_argument('question', type=str, nargs=argparse.REMAINDER, help='Question to ask AI')
7 | args = parser.parse_args()
8 |
9 | cursive = Cursive()
10 |
11 | @cursive_function(pause=True)
12 | def execute_command(command: str) -> str:
13 | """
14 | Executes a CLI command from the user prompt
15 |
16 | command: The command to execute
17 | """
18 | return command
19 |
20 | res = cursive.ask(
21 | prompt=' '.join(args.question),
22 |     system_message=
23 |         "You are a CLI assistant that executes commands from the user prompt. "
24 |         "You have permission, so just use the function you're provided. "
25 |         "Always assume the user wants to run a command on the CLI. "
26 |         "Assume they're using a macOS terminal.",
27 | functions=[execute_command],
28 | )
29 |
30 | conversation = res.conversation
31 |
32 | while True:
33 | if res.function_result:
34 | print(f'Executing command:\n\t$ {res.function_result}')
35 | print('Press enter to continue or N/n to cancel')
36 |
37 | prompt = input('> ')
38 | if prompt.lower() == 'n':
39 | print('Command cancelled')
40 | exit(0)
41 | elif prompt == '':
42 | subprocess.run(res.function_result, shell=True)
43 | exit(0)
44 | else:
45 | res = conversation.ask(
46 | prompt=prompt,
47 | functions=[execute_command],
48 | )
49 | conversation = res.conversation
50 | else:
51 | print(res.answer, end='\n\n')
52 | prompt = input('> ')
53 | res = conversation.ask(
54 | prompt=prompt,
55 | functions=[execute_command],
56 | )
57 | conversation = res.conversation
58 |
59 |
60 |
--------------------------------------------------------------------------------
/examples/compare-embeddings.py:
--------------------------------------------------------------------------------
1 | from cursive import Cursive
2 | import numpy as np
3 |
4 | cursive = Cursive()
5 |
6 | x1 = cursive.embed("""Pizza""")
7 |
8 | x2 = cursive.embed("""Cat""")
9 |
10 | def cosine_similarity(x, y):
11 | return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
12 |
13 | print(cosine_similarity(x1, x2))
--------------------------------------------------------------------------------
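The cosine similarity used above needs no API call to verify; a pure-Python equivalent of the NumPy one-liner:

```python
import math

def cosine_similarity(x: list[float], y: list[float]) -> float:
    # dot(x, y) / (|x| * |y|), same formula as the NumPy version above
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0
```
--------------------------------------------------------------------------------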
/examples/generate-list-of-objects.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from cursive.compat.pydantic import BaseModel
3 | from cursive.function import cursive_function
4 | from cursive import Cursive
5 |
6 | class Input(BaseModel):
7 | input: str
8 | idx: int
9 |
10 | @cursive_function(pause=True)
11 | def gen_character_list(inputs: List[Input]):
12 | """
13 |     Given a prompt (which is a set of directives for an LLM), generate possible inputs that could be fed to it.
14 | Generate 10 inputs.
15 |
16 | inputs: A list of inputs.
17 | """
18 | return inputs
19 |
20 | c = Cursive()
21 |
22 | res = c.ask(
23 |     prompt="Generate an input for a prompt that generates headlines for a SaaS company.",
24 | model="gpt-4",
25 |     function_call=gen_character_list,
26 | )
27 |
28 | print(res.function_result)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "cursivepy"
3 | version = "0.7.2"
4 | description = ""
5 | authors = [
6 | "Rodrigo Godinho ",
7 | "Henrique Cunha ",
8 | "Cyrus Nouroozi "
9 | ]
10 | readme = "README.md"
11 | packages = [{ include = "cursive" }]
12 |
13 | [tool.poetry.dependencies]
14 | python = ">=3.10.0,<3.12"
15 | tiktoken = "^0.4.0"
16 | openai = "^0.27.8"
17 | anthropic = "^0.3.6"
18 | pydantic = ">=1,<3"
19 | cohere = "^4.19.3"
20 | tokenizers = "^0.13.3"
21 | replicate = "^0.11.0"
22 | sorcery = "^0.2.2"
23 | numpy = "^1.26.0"
24 |
25 | [tool.poetry.group.dev.dependencies]
26 | pytest = "^7.4.0"
27 | ipython = "^8.14.0"
28 | ipdb = "^0.13.13"
29 |
30 | [build-system]
31 | requires = ["poetry-core"]
32 | build-backend = "poetry.core.masonry.api"
33 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = pydantic1,pydantic2
3 |
4 | [testenv]
5 | deps =
6 | pytest
7 | pydantic1: pydantic==1.*
8 | pydantic2: pydantic==2.*
9 |
10 | commands = pytest
--------------------------------------------------------------------------------