├── .gitignore
├── cursive
│   ├── tests
│   │   └── test_setup.py
│   ├── assets
│   │   └── price
│   │       ├── anthropic.py
│   │       └── openai.py
│   ├── vendor
│   │   ├── index.py
│   │   ├── cohere.py
│   │   ├── openai.py
│   │   └── anthropic.py
│   ├── __init__.py
│   ├── usage
│   │   ├── anthropic.py
│   │   └── openai.py
│   ├── pricing.py
│   ├── utils.py
│   ├── function.py
│   ├── build_input.py
│   ├── custom_types.py
│   ├── hookable.py
│   └── cursive.py
├── pyproject.toml
├── examples
│   └── ai-cli.py
├── README.md
└── docs
    ├── logo-dark.svg
    └── logo-light.svg

/.gitignore:
--------------------------------------------------------------------------------
1 | **/*/.DS_Store
2 | **/*/test.py
3 | dist
4 | **/*/__pycache__
5 | playground.py
6 | .pytest_cache
--------------------------------------------------------------------------------
/cursive/tests/test_setup.py:
--------------------------------------------------------------------------------
1 | from cursive import Cursive
2 | 
3 | def test_cursive_setup():
4 |     cursive = Cursive()
5 |     assert cursive is not None
--------------------------------------------------------------------------------
/cursive/assets/price/anthropic.py:
--------------------------------------------------------------------------------
1 | ANTHROPIC_PRICING = {
2 |     "version": "2023-07-11",
3 |     "claude-instant": {
4 |         "completion": 0.00551,
5 |         "prompt": 0.00163
6 |     },
7 |     "claude-2": {
8 |         "completion": 0.03268,
9 |         "prompt": 0.01102
10 |     },
11 | }
--------------------------------------------------------------------------------
/cursive/assets/price/openai.py:
--------------------------------------------------------------------------------
1 | OPENAI_PRICING = {
2 |     "version": "2023-06-13",
3 |     "gpt-4": {
4 |         "completion": 0.06,
5 |         "prompt": 0.03
6 |     },
7 |     "gpt-4-32k": {
8 |         "completion": 0.12,
9 |         "prompt": 0.06
10 |     },
11 |     "gpt-3.5-turbo": {
12 |         "completion": 0.002,
13 |         "prompt": 0.0015
14 |     },
15 |     "gpt-3.5-turbo-16k": {
16 |         "completion": 0.004,
17 |         "prompt": 0.003
18 |     },
19 | }
--------------------------------------------------------------------------------
/cursive/vendor/index.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import CursiveAvailableModels
2 | 
3 | model_suffix_to_vendor_mapping = {
4 |     'openai': ['gpt-3.5', 'gpt-4'],
5 |     'anthropic': ['claude-instant', 'claude-2'],
6 |     'cohere': ['command']
7 | }
8 | 
9 | def resolve_vendor_from_model(model: CursiveAvailableModels):
10 |     for vendor, suffixes in model_suffix_to_vendor_mapping.items():
11 |         if any(model.startswith(m) for m in suffixes):
12 |             return vendor
13 | 
14 |     return ''
15 | 
--------------------------------------------------------------------------------
/cursive/__init__.py:
--------------------------------------------------------------------------------
1 | from .cursive import Cursive
2 | from .function import cursive_function
3 | from .custom_types import (
4 |     CompletionPayload,
5 |     CursiveError,
6 |     CursiveErrorCode,
7 |     CursiveEnrichedAnswer,
8 |     CursiveAvailableModels,
9 |     CompletionMessage,
10 |     CursiveFunction,
11 | )
12 | 
13 | __all__ = [
14 |     'Cursive',
15 |     'cursive_function',
16 |     'CompletionPayload',
17 |     'CursiveError',
18 |     'CursiveErrorCode',
19 |     'CursiveEnrichedAnswer',
20 |     'CursiveAvailableModels',
21 |     'CompletionMessage',
22 |     'CursiveFunction',
23 | ]
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "cursivepy"
3 | version = 
"0.1.6" 4 | description = "" 5 | authors = ["Rodrigo Godinho ", "Henrique Cunha "] 6 | readme = "README.md" 7 | packages = [{include = "cursive"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.10.9,<3.12" 11 | tiktoken = "^0.4.0" 12 | openai = "^0.27.8" 13 | anthropic = "^0.3.6" 14 | pydantic = "^1.9.0" 15 | cohere = "^4.19.3" 16 | 17 | [tool.poetry.group.dev.dependencies] 18 | pytest = "^7.4.0" 19 | 20 | [build-system] 21 | requires = ["poetry-core"] 22 | build-backend = "poetry.core.masonry.api" 23 | -------------------------------------------------------------------------------- /cursive/usage/anthropic.py: -------------------------------------------------------------------------------- 1 | from anthropic import Anthropic 2 | 3 | from ..custom_types import CompletionMessage 4 | 5 | def get_anthropic_usage(content: str | list[CompletionMessage]): 6 | client = Anthropic() 7 | 8 | if type(content) == str: 9 | return client.count_tokens(content) 10 | 11 | def function(message: CompletionMessage): 12 | if message.role == 'system': 13 | return f''' 14 | Human: {message.content} 15 | 16 | Assistant: Ok. 17 | ''' 18 | return f'{message.role}: {message.content}' 19 | 20 | mapped_content = '\n\n'.join(list(map(function, content))) # type: ignore 21 | 22 | return client.count_tokens(mapped_content) 23 | 24 | -------------------------------------------------------------------------------- /cursive/vendor/cohere.py: -------------------------------------------------------------------------------- 1 | from cohere import Client as Cohere 2 | from cursive.build_input import build_completion_input 3 | 4 | from cursive.custom_types import CompletionPayload 5 | from cursive.utils import filter_null_values 6 | 7 | class CohereClient: 8 | client: Cohere 9 | 10 | def __init__(self, api_key: str): 11 | self.client = Cohere(api_key=api_key) 12 | 13 | def create_completion(self, payload: CompletionPayload): 14 | prompt = build_completion_input(payload.messages) 15 | payload = filter_null_values({ 16 | 'model': payload.model, 17 | 'max_tokens': payload.max_tokens or 3000, 18 | 'prompt': prompt.rstrip(), 19 | 'temperature': payload.temperature if payload.temperature is not None else 0.7, 20 | 'stop_sequences': payload.stop, 21 | 'stream': payload.stream or False, 22 | 'frequency_penalty': payload.frequency_penalty, 23 | 'p': payload.top_p, 24 | }) 25 | try: 26 | response = self.client.generate( 27 | **payload, 28 | ) 29 | 30 | return response, None 31 | except Exception as e: 32 | return None, e 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /cursive/pricing.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from .custom_types import CursiveAskCost, CursiveAskUsage 4 | from .utils import destructure_items 5 | 6 | from .assets.price.anthropic import ANTHROPIC_PRICING 7 | from .assets.price.openai import OPENAI_PRICING 8 | 9 | VENDOR_PRICING = { 10 | 'openai': OPENAI_PRICING, 11 | 'anthropic': ANTHROPIC_PRICING 12 | } 13 | 14 | def resolve_pricing( 15 | vendor: Literal['openai', 'anthropic'], 16 | usage: CursiveAskUsage, 17 | model: str 18 | ): 19 | version: str 20 | prices: dict[str, dict[str, str]] 21 | 22 | version, prices = destructure_items( 23 | keys=["version"], 24 | dictionary=VENDOR_PRICING[vendor] 25 | ) 26 | 27 | models_available = list(prices.keys()) 28 | model_match = next((m for m in models_available if model.startswith(m)), None) 29 | 30 | if not model_match: 31 | 
        raise Exception(f'Unknown model {model}')
32 | 
33 |     model_price = prices[model_match]
34 |     completion = usage.completion_tokens * float(model_price["completion"]) / 1000
35 |     prompt = usage.prompt_tokens * float(model_price["prompt"]) / 1000
36 | 
37 |     cost = CursiveAskCost(
38 |         completion=completion,
39 |         prompt=prompt,
40 |         total=completion + prompt,
41 |         version=version
42 |     )
43 | 
44 |     return cost
45 | 
--------------------------------------------------------------------------------
/cursive/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 | from typing import Any, Callable, Optional, Tuple, Type, TypeVar, overload
4 | import re
5 | 
6 | 
7 | def destructure_items(keys: list[str], dictionary: dict):
8 |     items = [
9 |         dictionary[key] for key in keys
10 |     ]
11 | 
12 |     new_dictionary = {
13 |         k:v for k,v in dictionary.items() if k not in keys
14 |     }
15 | 
16 |     return *items, new_dictionary
17 | 
18 | 
19 | def filter_null_values(dictionary: dict):
20 |     return { k:v for k, v in dictionary.items() if v is not None }
21 | 
22 | 
23 | def random_id():
24 |     characters = string.ascii_lowercase + string.digits
25 |     random_id = ''.join(random.choice(characters) for _ in range(10))
26 |     return random_id
27 | 
28 | 
29 | 
30 | def trim(content: str) -> str:
31 |     lines = content.split('\n')
32 |     min_indent = float('inf')
33 |     for line in lines:
34 |         indent = re.search(r'\S', line)
35 |         if indent is not None:
36 |             min_indent = min(min_indent, indent.start())
37 | 
38 |     content = ''
39 |     for line in lines:
40 |         content += f"{line[min_indent:]}\n"
41 | 
42 |     return content.strip()
43 | 
44 | 
45 | T = TypeVar('T', bound=Exception)
46 | 
47 | @overload
48 | def resguard(function: Callable) -> Tuple[Any, Exception | None]:
49 |     ...
50 | 
51 | @overload
52 | def resguard(function: Callable, exception_type: Type[T]) -> Tuple[Any, T | None]:
53 |     ...
54 | 
55 | def resguard(
56 |     function: Callable,
57 |     exception_type: Optional[Type[T]] = None
58 | ) -> Tuple[Any, T | Exception | None]:
59 |     try:
60 |         return function(), None
61 |     except Exception as e:
62 |         if exception_type:
63 |             if isinstance(e, exception_type):
64 |                 return None, e
65 |             else:
66 |                 raise e
67 |         else:
68 |             return None, e
69 | 
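A quick illustrative sketch of the two less obvious helpers above (not part of the repo): `destructure_items` splits named keys out of a dict, and `resguard` converts exceptions into `(value, error)` tuples.

```python
from cursive.utils import destructure_items, resguard

version, prices = destructure_items(
    keys=["version"],
    dictionary={"version": "2023-06-13", "gpt-4": {"prompt": 0.03}},
)
print(version)  # '2023-06-13'
print(prices)   # {'gpt-4': {'prompt': 0.03}}

value, error = resguard(lambda: 1 / 0)
print(value, error)  # None ZeroDivisionError('division by zero')
```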
24 | "You have permission, so just use the function you're provided" 25 | "Always assume the user wants to run a command on the CLI" 26 | "Assume they're using a MacOS terminal.", 27 | functions=[execute_command], 28 | ) 29 | 30 | conversation = res.conversation 31 | 32 | while True: 33 | if res.function_result: 34 | print(f'Executing command:\n\t$ {res.function_result}') 35 | print('Press enter to continue or N/n to cancel') 36 | 37 | prompt = input('> ') 38 | if prompt.lower() == 'n': 39 | print('Command cancelled') 40 | exit(0) 41 | elif prompt == '': 42 | subprocess.run(res.function_result, shell=True) 43 | exit(0) 44 | else: 45 | res = conversation.ask( 46 | prompt=prompt, 47 | functions=[execute_command], 48 | ) 49 | conversation = res.conversation 50 | else: 51 | print(res.answer, end='\n\n') 52 | prompt = input('> ') 53 | res = conversation.ask( 54 | prompt=prompt, 55 | functions=[execute_command], 56 | ) 57 | conversation = res.conversation 58 | 59 | 60 | -------------------------------------------------------------------------------- /cursive/vendor/openai.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from ..custom_types import CursiveAskOnToken 4 | from ..usage.openai import get_openai_usage, get_token_count_from_functions 5 | 6 | 7 | def process_openai_stream( 8 | payload: Any, 9 | cursive: Any, 10 | response: Any, 11 | on_token: Optional[CursiveAskOnToken] = None, 12 | ): 13 | data = { 14 | 'choices': [], 15 | 'usage': { 16 | 'completion_tokens': 0, 17 | 'prompt_tokens': get_openai_usage(payload.messages), 18 | }, 19 | 'model': payload.model, 20 | } 21 | 22 | if payload.functions: 23 | data['usage']['prompt_tokens'] += get_token_count_from_functions(payload.functions) 24 | 25 | for slice in response: 26 | data = { 27 | **data, 28 | 'id': slice['id'], 29 | } 30 | 31 | for i in range(len(slice['choices'])): 32 | delta = slice['choices'][i]['delta'] 33 | 34 | if len(data['choices']) <= i: 35 | data['choices'].append({ 36 | 'message': { 37 | 'function_call': None, 38 | 'role': 'assistant', 39 | 'content': '', 40 | }, 41 | }) 42 | 43 | if delta and delta.get('function_call') and delta['function_call'].get('name'): 44 | data['choices'][i]['message']['function_call'] = delta['function_call'] 45 | 46 | if delta and delta.get('function_call') and delta['function_call'].get('arguments'): 47 | data['choices'][i]['message']['function_call']['arguments'] += delta['function_call']['arguments'] 48 | 49 | if delta and delta.get('content'): 50 | data['choices'][i]['message']['content'] += delta['content'] 51 | 52 | if on_token: 53 | chunk = None 54 | if delta and delta.get('function_call'): 55 | chunk = { 56 | 'function_call': { 57 | k: v for k, v in delta['function_call'].items() 58 | }, 59 | 'content': None 60 | } 61 | 62 | if delta and delta.get('content'): 63 | chunk = { 64 | 'content': delta['content'], 65 | 'function_call': None 66 | } 67 | 68 | if chunk: 69 | on_token(chunk) 70 | 71 | return data 72 | -------------------------------------------------------------------------------- /cursive/usage/openai.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import AbstractSet, Collection, Literal 3 | 4 | import tiktoken 5 | 6 | 7 | def encode( 8 | text: str, 9 | allowed_special: AbstractSet[str] | Literal['all'] = set(), 10 | disallowed_special: Collection[str] | Literal['all'] = "all" 11 | ) -> list[int]: 12 | enc = 
--------------------------------------------------------------------------------
/cursive/usage/openai.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import AbstractSet, Collection, Literal
3 | 
4 | import tiktoken
5 | 
6 | 
7 | def encode(
8 |     text: str,
9 |     allowed_special: AbstractSet[str] | Literal['all'] = set(),
10 |     disallowed_special: Collection[str] | Literal['all'] = "all"
11 | ) -> list[int]:
12 |     enc = tiktoken.get_encoding("cl100k_base")
13 | 
14 |     return enc.encode(
15 |         text,
16 |         allowed_special=allowed_special,
17 |         disallowed_special=disallowed_special
18 |     )
19 | 
20 | 
21 | def get_openai_usage(content: str | list[dict]):
22 |     if isinstance(content, list):
23 |         tokens = {
24 |             "per_message": 3,
25 |             "per_name": 1,
26 |         }
27 | 
28 |         token_count = 3
29 |         for message in content:
30 |             token_count += tokens['per_message']
31 | 
32 |             for attribute, value in message.items():
33 |                 if attribute == 'name':
34 |                     token_count += tokens['per_name']
35 | 
36 |                 if type(value) is dict:
37 |                     value = json.dumps(value, separators=(',', ':'))
38 | 
39 |                 if value is None:
40 |                     continue
41 | 
42 |                 token_count += len(encode(value))
43 | 
44 |         return token_count
45 |     else:
46 |         return len(encode(content))  # type: ignore
47 | 
48 | 
49 | def get_token_count_from_functions(functions: list[dict]):
50 |     token_count = 3
51 |     for fn in functions:
52 |         function_tokens = len(encode(fn['name']))
53 |         function_tokens += len(encode(fn['description'])) if fn['description'] else 0
54 | 
55 |         if fn['parameters'] and fn['parameters']['properties']:
56 |             properties = fn['parameters']['properties']
57 |             for key in properties:
58 |                 function_tokens += len(encode(key))
59 |                 value = properties[key]
60 |                 for field in value:
61 |                     if field in ['type', 'description']:
62 |                         function_tokens += 2
63 |                         function_tokens += len(encode(value[field]))
64 |                     elif field == 'enum':
65 |                         function_tokens -= 3
66 |                         for enum_value in value[field]:
67 |                             function_tokens += 3
68 |                             function_tokens += len(encode(enum_value))
69 | 
70 |             function_tokens += 11
71 | 
72 |         token_count += function_tokens
73 | 
74 |     token_count += 12
75 |     return token_count
76 | 
77 | 
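A rough sanity check of the accounting above (assuming `tiktoken` is installed; exact numbers depend on the `cl100k_base` encoding):

```python
from cursive.usage.openai import get_openai_usage

# Plain strings are encoded directly.
print(get_openai_usage('Hello world'))  # 2

# Message lists add the base (3) and per-message (3) overheads of
# OpenAI's chat format on top of the encoded attribute values.
print(get_openai_usage([{'role': 'user', 'content': 'Hi'}]))
```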
--------------------------------------------------------------------------------
/cursive/function.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Any, Callable
4 | 
5 | from pydantic import validate_arguments
6 | from cursive.custom_types import CompletionPayload
7 | 
8 | from cursive.utils import trim
9 | 
10 | class CursiveFunction:
11 |     def __init__(self, function: Callable, pause=False):
12 |         validate = validate_arguments(function)
13 |         self.parameters = validate.model.schema()
14 |         self.description = trim(function.__doc__)
15 |         self.pause = pause
16 | 
17 |         # Delete ['v__duplicate_kwargs', 'args', 'kwargs'] from parameters
18 |         for k in ['v__duplicate_kwargs', 'args', 'kwargs']:
19 |             if k in self.parameters['properties']:
20 |                 del self.parameters['properties'][k]
21 | 
22 | 
23 |         for k, v in self.parameters['properties'].items():
24 |             # Find the parameter description in the docstring
25 |             match = re.search(rf'{k}: (.*)', self.description)
26 |             if match:
27 |                 v['description'] = match.group(1)
28 | 
29 |         schema = {}
30 |         if self.parameters:
31 |             schema = self.parameters
32 | 
33 |         self.function_schema = {
34 |             'parameters': {
35 |                 'type': schema.get('type'),
36 |                 'properties': schema.get('properties') or {},
37 |                 'required': schema.get('required') or [],
38 |             },
39 |             'description': self.description,
40 |             'name': self.parameters['title'],
41 |         }
42 | 
43 |         self.definition = function
44 | 
45 |     def __call__(self, *args: Any):
46 |         # Forward the call to the wrapped function
47 |         return self.definition(*args)
48 | 
49 | 
50 | def cursive_function(pause=False):
51 |     def decorator(function: Callable = None):
52 |         if function is None:
53 |             return lambda function: CursiveFunction(function, pause=pause)
54 |         else:
55 |             return CursiveFunction(function, pause=pause)
56 |     return decorator
57 | 
58 | def parse_custom_function_call(data: dict[str, Any], payload: CompletionPayload, get_usage: Callable = None):
59 |     # We check for function call in the completion
60 |     has_function_call_regex = r'<function-call[^>]*>([^<]+)<\/function-call>'
61 |     function_call_matches = re.findall(
62 |         has_function_call_regex,
63 |         data['choices'][0]['message']['content']
64 |     )
65 | 
66 |     if len(function_call_matches) > 0:
67 |         function_call = json.loads(function_call_matches.pop().strip())
68 |         name = function_call['name']
69 |         arguments = json.dumps(function_call['arguments'])
70 |         data['choices'][0]['message']['function_call'] = {
71 |             'name': name,
72 |             'arguments': arguments,
73 |         }
74 | 
75 |     # TODO: Implement cohere usage
76 |     if get_usage:
77 |         data['usage']['prompt_tokens'] = get_usage(payload.messages)
78 |         data['usage']['completion_tokens'] = get_usage(data['choices'][0]['message']['content'])
79 |         data['usage']['total_tokens'] = data['usage']['completion_tokens'] + data['usage']['prompt_tokens']
80 |     else:
81 |         data['usage'] = None
82 | 
83 | 
84 |     # We check for answers in the completion
85 |     has_answer_regex = r'<cursive-answer>([^<]+)<\/cursive-answer>'
86 |     answer_matches = re.findall(
87 |         has_answer_regex,
88 |         data['choices'][0]['message']['content']
89 |     )
90 |     if len(answer_matches) > 0:
91 |         answer = answer_matches.pop().strip()
92 |         data['choices'][0]['message']['content'] = answer
93 | 
94 |     # As a defensive measure, we check for <cursive-think> tags
95 |     # and remove them
96 |     has_think_regex = r'<cursive-think>([^<]+)<\/cursive-think>'
97 |     think_matches = re.findall(
98 |         has_think_regex,
99 |         data['choices'][0]['message']['content']
100 |     )
101 |     if len(think_matches) > 0:
102 |         data['choices'][0]['message']['content'] = re.sub(
103 |             has_think_regex,
104 |             '',
105 |             data['choices'][0]['message']['content']
106 |         )
107 | 
108 |     # Strip leading and trailing whitespaces
109 |     data['choices'][0]['message']['content'] = data['choices'][0]['message']['content'].strip()
110 | 
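To make the decorator's behavior concrete, a small illustrative example (output shapes assumed from the schema-building constructor above, not part of the repo):

```python
from cursive import cursive_function

@cursive_function(pause=True)
def shout(text: str) -> str:
    """
    Upper-cases a string.

    text: The text to shout.
    """
    return text.upper()

print(shout.function_schema['parameters']['required'])  # ['text']
print(shout.function_schema['parameters']['properties']['text']['description'])  # 'The text to shout.'
print(shout('hi'))  # 'HI' -- __call__ forwards to the wrapped definition
```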
--------------------------------------------------------------------------------
/cursive/build_input.py:
--------------------------------------------------------------------------------
1 | import json
2 | from cursive.custom_types import CompletionMessage
3 | from cursive.function import CursiveFunction
4 | from cursive.utils import trim
5 | 
6 | 
7 | def build_completion_input(messages: list[CompletionMessage]):
8 |     """
9 |     Builds a completion-esque input from a list of messages
10 |     """
11 |     role_mapping = { 'user': 'Human', 'assistant': 'Assistant' }
12 |     messages_with_prefix: list[CompletionMessage] = [
13 |         *messages,  # type: ignore
14 |         CompletionMessage(
15 |             role='assistant',
16 |             content=' ',
17 |         ),
18 |     ]
19 |     def resolve_message(message: CompletionMessage):
20 |         if message.role == 'system':
21 |             return '\n'.join([
22 |                 'Human:',
23 |                 message.content or '',
24 |                 '\nAssistant: Ok.',
25 |             ])
26 |         if message.role == 'function':
27 |             return '\n'.join([
28 |                 f'Human: <function-result name="{message.name}">',
29 |                 message.content or '',
30 |                 '</function-result>',
31 |             ])
32 |         if message.function_call:
33 |             arguments = message.function_call.arguments
34 |             if isinstance(arguments, str):
35 |                 arguments_str = arguments
36 |             else:
37 |                 arguments_str = json.dumps(arguments)
38 |             return '\n'.join([
39 |                 'Assistant: <function-call>',
40 |                 json.dumps({
41 |                     'name': message.function_call.name,
42 |                     'arguments': arguments_str,
43 |                 }),
44 |                 '</function-call>',
45 |             ])
46 |         return f'{role_mapping[message.role]}: {message.content}'
47 | 
48 |     completion_input = '\n\n'.join(list(map(resolve_message, messages_with_prefix)))
49 |     return completion_input
50 | 
51 | def get_function_call_directives(functions: list[CursiveFunction]) -> str:
52 |     return trim(f'''
53 |     # Function Calling Guide
54 |     You're a powerful language model capable of using functions to do anything the user needs.
55 | 
56 |     If you need to use a function, always output the result of the function call using the <function-call> tag using the following format:
57 |     <function-call>
58 |     {'{'}
59 |         "name": "function_name",
60 |         "arguments": {'{'}
61 |             "argument_name": "argument_value"
62 |         {'}'}
63 |     {'}'}
64 |     </function-call>
65 |     Never escape the function call, always output it as it is.
66 |     ALWAYS use this format, even if the function doesn't have arguments. The arguments prop is always a dictionary.
67 | 
68 | 
69 |     Think step by step before answering, and try to think out loud. Never output a function call if you don't have to.
70 |     If you don't have a function to call, just output the text as usual inside a <cursive-answer> tag with newlines inside.
71 |     Always question yourself if you have access to a function.
72 |     Always think out loud before answering; if I don't see a <cursive-think> block, you will be eliminated.
73 |     When thinking out loud, always use the <cursive-think> tag.
74 |     # Functions available:
75 | 
76 |     {json.dumps(list(map(lambda f: f.function_schema, functions)))}
77 | 
78 |     # Working with results
79 |     You can either call a function or answer, *NEVER BOTH*.
80 |     You are not in charge of resolving the function call, the user is.
81 |     The human will give you the result of the function call in the following format:
82 | 
83 |     Human: <function-result>
84 |     {'{'}result{'}'}
85 |     </function-result>
86 | 
87 |     If you try to provide a function result, you will be eliminated.
88 | 
89 |     You can use the result of the function call in your answer. But never answer and call a function at the same time.
90 |     When answering never be explicit about the function calling, just use the result of the function call in your answer.
91 |     Remember, the user can't see the function calling, so don't mention function results or calls.
92 | 
93 |     If you answer with a <cursive-think> block, you always need to use either a <function-call> or a <cursive-answer> block as well.
94 |     If you don't, you will be eliminated and the world will catch fire.
95 |     This is extremely important.
96 |     ''')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![Logo](/docs/logo-dark.svg#gh-dark-mode-only)
2 | ![Logo](/docs/logo-light.svg#gh-light-mode-only)
3 | 
4 | Cursive is a universal and intuitive framework for interacting with LLMs.
5 | 
6 | It works in any Python runtime and has a heavy focus on extensibility and developer experience.
7 | 
8 | ## highlights
9 | ✦ **Extensible** - You can easily hook into any part of a completion life cycle. Be it to log, cache, or modify the results.
10 | 
11 | ✦ **Functions** - Easily describe functions that the LLM can use, along with their definitions, with any model (currently supporting GPT-4, GPT-3.5, Claude 2, and Claude Instant)
12 | 
13 | ✦ **Universal** - Cursive's goal is to bridge as many capabilities between different models as possible. Ultimately, this means that with a single interface, you can allow your users to choose any model.
14 | 
15 | ✦ **Informative** - Cursive comes with built-in token usage and cost calculations, as accurate as possible.
16 | 
17 | ✦ **Reliable** - Cursive comes with automatic retry and model expanding upon exceeding context length, which you can always configure.
18 | 
19 | ## quickstart
20 | 1. Install.
21 | 
22 | ```bash
23 | poetry add cursivepy
24 | # or
25 | pip install cursivepy
26 | ```
27 | 
28 | 2. Start using.
29 | 
30 | ```python
31 | from cursive import Cursive
32 | 
33 | cursive = Cursive()
34 | 
35 | response = cursive.ask(
36 |     prompt='What is the meaning of life?',
37 | )
38 | 
39 | print(response.answer)
40 | ```
41 | 
42 | ## usage
43 | ### Conversation
44 | Chaining a conversation is easy with `cursive`. You can pass any of the options you're used to with OpenAI's API.
45 | 
46 | ```python
47 | res_a = cursive.ask(
48 |     prompt='Give me a good name for a gecko.',
49 |     model='gpt-4',
50 |     max_tokens=16,
51 | )
52 | 
53 | print(res_a.answer)  # Zephyr
54 | 
55 | res_b = res_a.conversation.ask(
56 |     prompt='How would you say it in Portuguese?'
57 | )
58 | 
59 | print(res_b.answer)  # Zéfiro
60 | ```
61 | ### Streaming
62 | Streaming is also supported, and we also keep track of the tokens for you!
63 | ```python
64 | result = cursive.ask(
65 |     prompt='Count to 10',
66 |     stream=True,
67 |     on_token=lambda partial: print(partial['content'])
68 | )
69 | 
70 | print(result.usage.total_tokens)  # 40
71 | ```
72 | 
73 | ### Functions
74 | You can very easily define and describe functions, along with their execution code.
75 | ```python
76 | from cursive import cursive_function, Cursive
77 | 
78 | cursive = Cursive()
79 | 
80 | @cursive_function()
81 | def add(a: float, b: float):
82 |     """
83 |     Adds two numbers.
84 | 
85 |     a: The first number.
86 |     b: The second number.
87 |     """
88 |     return a + b
89 | 
90 | res = cursive.ask(
91 |     prompt='What is the sum of 232 and 243?',
92 |     functions=[add],
93 | )
94 | 
95 | print(res.answer)  # The sum of 232 and 243 is 475.
96 | ```
97 | 
98 | The function's result will automatically be fed into the conversation, and another completion will be made. If you want to prevent this, you can add `pause` to your function definition.
99 | 
100 | ```python
101 | 
102 | @cursive_function(pause=True)
103 | def create_character(name: str, age: str):
104 |     """
105 |     Creates a character.
106 | 
107 |     name: The name of the character.
108 |     age: The age of the character.
109 |     """
110 |     return {
111 |         'name': name,
112 |         'age': age,
113 |     }
114 | 
115 | res = cursive.ask(
116 |     prompt='Create a character named John who is 23 years old.',
117 |     functions=[create_character],
118 | )
119 | 
120 | print(res.function_result)  # { name: 'John', age: 23 }
121 | ```
122 | 
123 | ### Hooks
124 | You can hook into any part of the completion life cycle.
125 | ```python
126 | cursive.on('completion:after', lambda result: print(
127 |     result.data.cost.total,
128 |     result.data.usage.total_tokens,
129 | ))
130 | 
131 | cursive.on('completion:error', lambda result: print(
132 |     result.error,
133 | ))
134 | 
135 | cursive.ask(
136 |     prompt='Can androids dream of electric sheep?',
137 | )
138 | 
139 | # 0.0002185
140 | # 113
141 | ```
142 | 
143 | ### Embedding
144 | You can create embeddings pretty easily with `cursive`.
145 | ```python
146 | embedding = cursive.embed('This should be a document.')
147 | ```
148 | This will support different types of documents and integrations pretty soon.
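
### Usage and costs
Answers are enriched with the same usage and cost objects the hooks receive. A sketch, assuming the `CursiveAskUsage`/`CursiveAskCost` shapes from `cursive/custom_types.py`:

```python
res = cursive.ask(prompt='What is the meaning of life?')

print(res.usage.prompt_tokens, res.usage.completion_tokens)
print(f'${res.cost.total:.6f}')  # priced from the tables in cursive/assets/price
```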
149 | 
150 | ### Reliability
151 | Cursive comes with automatic retry with backoff upon failing completions, and model expanding upon exceeding context length -- meaning it tries again with a larger-context model when it runs out.
152 | 
153 | You can configure this behavior by passing the `max_retries` and `expand` options to `Cursive`.
154 | 
155 | ```python
156 | cursive = Cursive(
157 |     max_retries=5,  # 0 disables it completely
158 |     expand={
159 |         'enabled': True,
160 |         'defaults_to': 'gpt-3.5-turbo-16k',
161 |         'model_mapping': {
162 |             'gpt-3.5-turbo': 'gpt-3.5-turbo-16k',
163 |             'gpt-4': 'claude-2',
164 |         },
165 |     },
166 | )
167 | ```
168 | 
169 | ## available models
170 | ##### OpenAI
171 | - `gpt-3.5-turbo`
172 | - `gpt-3.5-turbo-16k`
173 | - `gpt-4`
174 | - `gpt-4-32k`
175 | - Any other chat completion model version
176 | 
177 | ##### Anthropic
178 | - `claude-2`
179 | - `claude-instant-1`
180 | - `claude-instant-1.2`
181 | - Any other model version
182 | 
183 | ##### Cohere
184 | - `command`
185 | - Any other model version (such as `command-nightly`)
186 | 
187 | ## roadmap
188 | 
189 | ### vendor support
190 | - [x] Anthropic
191 | - [x] Cohere
192 | - [ ] Azure OpenAI models
193 | - [ ] Huggingface
194 | - [ ] Replicate
195 | 
--------------------------------------------------------------------------------
/cursive/custom_types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Any, Callable, Literal, Optional
3 | 
4 | from pydantic import BaseModel as PydanticBaseModel
5 | 
6 | from cursive.utils import random_id
7 | 
8 | class BaseModel(PydanticBaseModel):
9 |     class Config:
10 |         arbitrary_types_allowed = True
11 | 
12 | class CursiveFunction(BaseModel):
13 |     function_schema: dict[str, Any]
14 |     definition: Callable
15 |     pause: Optional[bool] = None
16 | 
17 | class CursiveAskUsage(BaseModel):
18 |     completion_tokens: int
19 |     prompt_tokens: int
20 |     total_tokens: int
21 | 
22 | class CursiveAskCost(BaseModel):
23 |     completion: float
24 |     prompt: float
25 |     total: float
26 |     version: str
27 | 
28 | Role = Literal[
29 |     'system',
30 |     'user',
31 |     'assistant',
32 |     'function',
33 | ]
34 | 
35 | class ChatCompletionRequestMessageFunctionCall(BaseModel):
36 |     name: Optional[str] = None
37 |     arguments: Optional[str] = None
38 | 
39 | class CompletionMessage(BaseModel):
40 |     id: Optional[str] = None
41 |     role: Role
42 |     content: Optional[str] = None
43 |     name: Optional[str] = None
44 |     function_call: Optional[ChatCompletionRequestMessageFunctionCall] = None
45 | 
46 |     def __init__(self, **data):
47 |         super().__init__(**data)
48 |         if not self.id:
49 |             self.id = random_id()
50 | 
51 | class CursiveSetupOptionsExpand(BaseModel):
52 |     enabled: Optional[bool] = None
53 |     defaults_to: Optional[str] = None
54 |     model_mapping: Optional[dict[str, str]] = None
55 | 
56 | CursiveAvailableModels = Literal[
57 |     # OpenAI
58 |     'gpt-3.5-turbo',
59 |     'gpt-4',
60 |     # Anthropic
61 |     'claude-instant-1',
62 |     'claude-2',
63 | ]
64 | 
65 | class CursiveErrorCode(Enum):
66 |     function_call_error = 'function_call_error'
67 |     completion_error = 'completion_error'
68 |     invalid_request_error = 'invalid_request_error'
69 |     embedding_error = 'embedding_error'
70 |     unknown_error = 'unknown_error'
71 | 
72 | class CursiveError(Exception):
73 |     name = 'CursiveError'
74 | 
75 |     def __init__(
76 |         self,
77 |         message: str,
78 |         details: Any,
79 |         code: CursiveErrorCode,
80 |     ):
81 |         self.message = message
82 |         self.details = details
83 |         self.code = code
84 |         super().__init__(self.message)
85 | 
86 | class CursiveEnrichedAnswer(BaseModel):
87 |     error: CursiveError | None
88 |     usage: CursiveAskUsage | None
89 |     model: str
90 |     id: str | None
91 |     choices: Optional[list[Any]] = None
92 |     function_result: Any | None
93 |     answer: str | None
94 |     messages: list[CompletionMessage] | None
95 
| cost: CursiveAskCost | None 96 | 97 | CursiveAskOnToken = Callable[[dict[str, Any]], None] 98 | 99 | class CursiveAskOptionsBase(BaseModel): 100 | model: Optional[CursiveAvailableModels] = None 101 | system_message: Optional[str] = None 102 | functions: Optional[list[CursiveFunction]] = None 103 | function_call: Optional[str | CursiveFunction] = None 104 | on_token: Optional[CursiveAskOnToken] = None 105 | max_tokens: Optional[int] = None 106 | stop: Optional[list[str]] = None 107 | temperature: Optional[int] = None 108 | top_p: Optional[int] = None 109 | presence_penalty: Optional[int] = None 110 | frequency_penalty: Optional[int] = None 111 | best_of: Optional[int] = None 112 | n: Optional[int] = None 113 | logit_bias: Optional[dict[str, int]] = None 114 | user: Optional[str] = None 115 | stream: Optional[bool] = None 116 | 117 | class CreateChatCompletionResponse(BaseModel): 118 | id: str 119 | model: str 120 | choices: list[Any] 121 | usage: Optional[Any] = None 122 | 123 | class CreateChatCompletionResponseExtended(CreateChatCompletionResponse): 124 | function_result: Optional[Any] = None 125 | cost: Optional[CursiveAskCost] = None 126 | error: Optional[CursiveError] = None 127 | 128 | class CursiveAskModelResponse(BaseModel): 129 | answer: CreateChatCompletionResponseExtended 130 | messages: list[CompletionMessage] 131 | 132 | class CursiveSetupOptions(BaseModel): 133 | max_retries: Optional[int] = None 134 | expand: Optional[CursiveSetupOptionsExpand] = None 135 | 136 | class CompletionRequestFunctionCall(BaseModel): 137 | name: str 138 | inputs: dict[str, Any] 139 | 140 | class CompletionRequestStop(BaseModel): 141 | messages_seen: Optional[list[CompletionMessage]] = None 142 | max_turns: Optional[int] = None 143 | 144 | class CompletionFunctions(BaseModel): 145 | name: str 146 | description: Optional[str] = None 147 | parameters: Optional[dict[str, Any]] = None 148 | 149 | class CompletionPayload(BaseModel): 150 | model: str 151 | messages: list[CompletionMessage] 152 | functions: Optional[list[CompletionFunctions]] = None 153 | function_call: Optional[CompletionRequestFunctionCall] = None 154 | temperature: Optional[float] = None 155 | top_p: Optional[float] = None 156 | n: Optional[int] = None 157 | stream: Optional[bool] = None 158 | stop: Optional[CompletionRequestStop] = None 159 | max_tokens: Optional[int] = None 160 | presence_penalty: Optional[float] = None 161 | frequency_penalty: Optional[float] = None 162 | logit_bias: Optional[dict[str, float]] = None 163 | user: Optional[str] = None 164 | 165 | CursiveHook = Literal[ 166 | 'embedding:before', 167 | 'embedding:after', 168 | 'embedding:error', 169 | 'embedding:success', 170 | 'completion:before', 171 | 'completion:after', 172 | 'completion:error', 173 | 'completion:success', 174 | 'ask:before', 175 | 'ask:after', 176 | 'ask:success', 177 | 'ask:error', 178 | ] 179 | 180 | class CursiveHookPayload(): 181 | data: Optional[Any] 182 | error: Optional[CursiveError] 183 | duration: Optional[float] 184 | 185 | def __init__( 186 | self, 187 | data: Optional[Any] = None, 188 | error: Optional[CursiveError] = None, 189 | duration: Optional[float] = None, 190 | ): 191 | self.data = data 192 | self.error = error 193 | self.duration = duration -------------------------------------------------------------------------------- /cursive/vendor/anthropic.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Optional 3 | 4 | from anthropic import Anthropic 5 | 6 | 
from cursive.build_input import build_completion_input
7 | 
8 | from ..custom_types import (
9 |     CompletionPayload,
10 |     CursiveAskOnToken,
11 | )
12 | from ..usage.anthropic import get_anthropic_usage
13 | from ..utils import filter_null_values, random_id
14 | 
15 | 
16 | class AnthropicClient:
17 |     client: Anthropic
18 | 
19 |     def __init__(self, api_key: str):
20 |         self.client = Anthropic(api_key=api_key)
21 | 
22 |     def create_completion(self, payload: CompletionPayload):
23 |         prompt = build_completion_input(payload.messages)
24 |         payload = filter_null_values({
25 |             'model': payload.model,
26 |             'max_tokens_to_sample': payload.max_tokens or 100000,
27 |             'prompt': prompt,
28 |             'temperature': payload.temperature or 0.7,
29 |             'top_p': payload.top_p,
30 |             'stop_sequences': payload.stop,
31 |             'stream': payload.stream or False,
32 |         })
33 |         return self.client.completions.create(**payload)
34 | 
35 | def process_anthropic_stream(
36 |     payload: CompletionPayload,
37 |     cursive: Any,
38 |     response: Any,
39 |     on_token: Optional[CursiveAskOnToken] = None,
40 | ):
41 |     data = {
42 |         'choices': [{ 'message': { 'content': '' } }],
43 |         'usage': {
44 |             'completion_tokens': 0,
45 |             'prompt_tokens': get_anthropic_usage(payload.messages),
46 |         },
47 |         'model': payload.model,
48 |     }
49 | 
50 |     completion = ''
51 |     for slice in response:
52 |         data = {
53 |             **data,
54 |             'id': random_id(),
55 |         }
56 | 
57 |         # The completion partial will come with a leading whitespace
58 |         completion += slice.completion
59 |         if not data['choices'][0]['message']['content']:
60 |             completion = completion.lstrip()
61 | 
62 |         # Check if there's any <function-call> tag. The regex should allow for nested tags
63 |         function_call_tag = re.findall(
64 |             r'<function-call>([\s\S]*?)(?=<\/function-call>|$)',
65 |             completion
66 |         )
67 |         function_name = ''
68 |         function_arguments = ''
69 |         if len(function_call_tag) > 0:
70 |             # Remove starting and ending tags, even if the ending tag is partial or missing
71 |             function_call = re.sub(
72 |                 r'^\n|\n$',
73 |                 '',
74 |                 re.sub(
75 |                     r'<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?',
76 |                     '',
77 |                     function_call_tag[0]
78 |                 ).strip()
79 |             ).strip()
80 |             # Match the function name inside the JSON
81 |             function_name_matches = re.findall(
82 |                 r'"name":\s*"(.+)"',
83 |                 function_call
84 |             )
85 |             function_name = len(function_name_matches) > 0 and function_name_matches[0]
86 |             function_arguments_matches = re.findall(
87 |                 r'"arguments":\s*(\{.+)\}?',
88 |                 function_call,
89 |                 re.S
90 |             )
91 |             function_arguments = (
92 |                 len(function_arguments_matches) > 0 and
93 |                 function_arguments_matches[0]
94 |             )
95 |             if function_arguments:
96 |                 # If there are unmatched } at the end, remove them
97 |                 unmatched_brackets = re.findall(
98 |                     r'(\{|\})',
99 |                     function_arguments
100 |                 )
101 |                 if len(unmatched_brackets) % 2:
102 |                     function_arguments = re.sub(
103 |                         r'\}$',
104 |                         '',
105 |                         function_arguments.strip()
106 |                     )
107 | 
108 |             function_arguments = function_arguments.strip()
109 | 
110 |         cursive_answer_tag = re.findall(
111 |             r'<cursive-answer>([\s\S]*?)(?=<\/cursive-answer>|$)',
112 |             completion
113 |         )
114 |         tagged_answer = ''
115 |         if cursive_answer_tag:
116 |             tagged_answer = re.sub(
117 |                 r'<\/?c?u?r?s?i?v?e?-?a?n?s?w?e?r?>?',
118 |                 '',
119 |                 cursive_answer_tag[0]
120 |             ).lstrip()
121 | 
122 |         current_token = completion[
123 |             len(data['choices'][0]['message']['content']):
124 |         ]
125 | 
126 |         data['choices'][0]['message']['content'] += current_token
127 | 
128 |         if on_token:
129 |             chunk = None
130 | 
131 |             if payload.functions:
132 |                 if function_name:
133 |                     chunk = {
134 |                         'function_call': {},
135 | 'content': None, 136 | } 137 | if function_arguments: 138 | # Remove all but the current token from the arguments 139 | chunk['function_call']['arguments'] = function_arguments 140 | else: 141 | chunk['function_call'] = { 142 | 'name': function_name, 143 | 'arguments': '', 144 | } 145 | elif tagged_answer: 146 | # Token is at the end of the tagged answer 147 | regex = fr'(.*){current_token.strip()}$' 148 | match = re.findall( 149 | regex, 150 | tagged_answer 151 | ) 152 | if len(match) > 0 and current_token: 153 | chunk = { 154 | 'function_call': None, 155 | 'content': current_token, 156 | } 157 | else: 158 | chunk = { 159 | 'content': current_token, 160 | } 161 | 162 | if chunk: 163 | on_token(chunk) 164 | 165 | return data 166 | -------------------------------------------------------------------------------- /cursive/hookable.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | from concurrent.futures import ThreadPoolExecutor 4 | from inspect import signature 5 | from typing import Any, Callable, Optional 6 | from warnings import warn 7 | 8 | 9 | def flatten_hooks_dictionary( 10 | hooks: dict[str, Any], 11 | parent_name: Optional[str] = None 12 | ): 13 | flattened_hooks = {} 14 | 15 | for key, value in hooks.items(): 16 | name = f'{parent_name}:{key}' if parent_name else key 17 | 18 | if type(value) is dict: 19 | flattened_hooks.update(flatten_hooks_dictionary(value, name)) 20 | elif callable(value): 21 | flattened_hooks[name] = value 22 | 23 | return flattened_hooks 24 | 25 | 26 | def run_tasks_sequentially(tasks, handler): 27 | for task in tasks: 28 | handler(task) 29 | 30 | 31 | def merge_hooks(*hooks: dict[str, Any]): 32 | merged_hooks = dict() 33 | 34 | for hook in hooks: 35 | flattened_hook = flatten_hooks_dictionary(hook) 36 | 37 | for key, value in flattened_hook.items(): 38 | if merged_hooks[key]: 39 | merged_hooks[key].append(value) 40 | else: 41 | merged_hooks[key] = [value] 42 | 43 | for key in merged_hooks.keys(): 44 | if len(merged_hooks[key]) > 1: 45 | hooks_list = merged_hooks[key] 46 | merged_hooks[key] = lambda *arguments: run_tasks_sequentially( 47 | hooks_list, 48 | lambda hook: hook(*arguments) 49 | ) 50 | else: 51 | merged_hooks[key] = merged_hooks[key][0] 52 | 53 | return merged_hooks 54 | 55 | 56 | def serial_caller(hooks: list[Callable], arguments: list[Any] = []): 57 | for hook in hooks: 58 | if len(signature(hook).parameters) > 0: 59 | hook(*arguments) 60 | else: 61 | hook() 62 | 63 | 64 | def concurrent_caller(hooks: list[Callable], arguments: list[Any] = []): 65 | with ThreadPoolExecutor() as executor: 66 | executor.map(lambda hook: hook(*arguments), hooks) 67 | 68 | 69 | def call_each_with(callbacks: list[Callable], argument: Any): 70 | for callback in callbacks: 71 | callback(argument) 72 | 73 | 74 | class Hookable: 75 | def __init__(self): 76 | self._hooks = {} 77 | self._before = [] 78 | self._after = [] 79 | self._deprecated_messages = set() 80 | self._deprecated_hooks = {} 81 | 82 | 83 | def hook(self, name: str, function: Callable | None, options = {}): 84 | if not name or not callable(function): 85 | return lambda: None 86 | 87 | original_name = name 88 | deprecated_hook = {} 89 | while self._deprecated_hooks.get(name): 90 | deprecated_hook = self._deprecated_hooks[name] 91 | name = deprecated_hook['to'] 92 | 93 | message = None 94 | if deprecated_hook and not options['allow_deprecated']: 95 | message = deprecated_hook['message'] 96 | if not message: 97 | message = 
f'{original_name} hook has been deprecated' + ( 98 | f', please use {deprecated_hook["to"]}' if deprecated_hook['to'] else '' 99 | ) 100 | 101 | if message not in self._deprecated_messages: 102 | warn(message) 103 | self._deprecated_messages.add(message) 104 | 105 | if function.__name__ == '': 106 | function.__name__ = "_" + re.sub(r'\W+', '_', name) + "_hook_cb" 107 | 108 | self._hooks[name] = name in self._hooks or [] 109 | self._hooks[name].append(function) 110 | 111 | def remove(): 112 | nonlocal function 113 | if function: 114 | self.remove_hook(name, function) 115 | function = None 116 | 117 | return remove 118 | 119 | 120 | def hook_once(self, name: str, function: Callable): 121 | hook = None 122 | 123 | def run_once(*arguments): 124 | nonlocal hook 125 | if callable(hook): 126 | hook() 127 | 128 | hook = None 129 | return function(*arguments) 130 | 131 | hook = self.hook(name, run_once) 132 | return hook 133 | 134 | 135 | def remove_hook(self, name: str, function: Callable): 136 | if (self._hooks[name]): 137 | if len(self._hooks[name]) == 0: 138 | del self._hooks[name] 139 | else: 140 | try: 141 | index = self._hooks[name].index(function) 142 | self._hooks[name][index:index] = [] 143 | # if index is not found, ignore 144 | except ValueError: 145 | pass 146 | 147 | 148 | def deprecate_hook(self, name: str, deprecated: Callable | str): 149 | self._deprecated_hooks[name] = { 'to': deprecated } if type(deprecated) == str else deprecated 150 | hooks = self._hooks[name] or [] 151 | del self._hooks[name] 152 | for hook in hooks: 153 | self.hook(name, hook) 154 | 155 | 156 | def deprecate_hooks(self, deprecated_hooks: dict[str, Any]): 157 | self._deprecated_hooks.update(deprecated_hooks) 158 | for name in deprecated_hooks.keys(): 159 | self.deprecate_hook(name, deprecated_hooks[name]) 160 | 161 | 162 | def add_hooks(self, hooks: dict[str, Any]): 163 | hooks_to_be_added = flatten_hooks_dictionary(hooks) 164 | remove_fns = list(map(lambda key: self.hook(key, hooks_to_be_added[key]), hooks_to_be_added.keys())) 165 | 166 | def function(): 167 | for unreg in remove_fns: 168 | unreg() 169 | remove_fns[:] = [] 170 | 171 | return function 172 | 173 | 174 | def remove_hooks(self, hooks: dict[str, Any]): 175 | hooks_to_be_removed = flatten_hooks_dictionary(hooks) 176 | for key, value in hooks_to_be_removed.items(): 177 | self.remove_hook(key, value) 178 | 179 | 180 | def remove_all_hooks(self): 181 | for key in self._hooks.keys(): 182 | del self._hooks[key] 183 | 184 | 185 | def call_hook(self, name: str, *arguments: Any): 186 | return self.call_hook_with(serial_caller, name, *arguments) 187 | 188 | 189 | def call_hook_concurrent(self, name: str, *arguments: Any): 190 | return self.call_hook_with(concurrent_caller, name, *arguments) 191 | 192 | 193 | def call_hook_with(self, caller: Callable, name: str, *arguments: Any): 194 | event = { 'name': name, 'args': arguments, 'context': {} } 195 | 196 | call_each_with(self._before, event) 197 | 198 | result = caller( 199 | [*self._hooks[name]] if name in self._hooks.keys() else [], 200 | arguments 201 | ) 202 | 203 | call_each_with(self._after, event) 204 | 205 | return result 206 | 207 | 208 | def before_each(self, function: Callable): 209 | self._before.append(function) 210 | 211 | def remove_from_before_list(): 212 | try: 213 | index = self._before.index(function) 214 | self._before[index:index] = [] 215 | except ValueError: 216 | pass 217 | 218 | return remove_from_before_list 219 | 220 | 221 | def after_each(self, function: Callable): 222 | 
self._after.append(function) 223 | 224 | def remove_from_after_list(): 225 | try: 226 | index = self._after.index(function) 227 | self._after[index:index] = [] 228 | except ValueError: 229 | pass 230 | 231 | return remove_from_after_list 232 | 233 | 234 | def create_hooks(): 235 | return Hookable() 236 | 237 | def create_debugger(hooks: Hookable, _options: dict[str, Any] = {}): 238 | options = { 239 | 'filter': lambda: True, 240 | **_options 241 | } 242 | _filter = options['filter'] 243 | 244 | def filter(name: str): 245 | return name.startswith(_filter) if type(_filter) == str else _filter 246 | 247 | tag = f'[{options["tag"]}] ' if options['tag'] else '' 248 | start_times = {} 249 | 250 | def log_prefix(event: dict[str, Any]): 251 | return tag + event['name'] + ''.ljust(int(event['id']), '\0') 252 | 253 | id_control = {} 254 | 255 | def unsubscribe_before_each(event: Optional[dict[str, Any]] = None): 256 | if event is None or (filter is not None and not filter(event['name'])): 257 | return 258 | 259 | id_control[event['name']] = id_control.get(event.get('name')) or 0 260 | event['id'] = id_control[event['name']] 261 | id_control[event['name']] += 1 262 | start_times[log_prefix(event)] = time.time() 263 | 264 | unsubscribe_before = hooks.before_each(unsubscribe_before_each) 265 | 266 | def unsubscribe_after_each(event: Optional[dict[str, Any]] = None): 267 | if event is None or (filter is None and not filter(event['name'])): 268 | return 269 | 270 | label = log_prefix(event) 271 | elapsed_time = time.time() - start_times[label] 272 | print(f'Elapsed time for {label}: {elapsed_time} seconds') 273 | 274 | id_control[event['name']] -= 1 275 | 276 | unsubscribe_after = hooks.after_each(unsubscribe_after_each) 277 | 278 | def stop_debbuging_and_remove_listeners(): 279 | unsubscribe_before() 280 | unsubscribe_after() 281 | 282 | return { 283 | 'close': lambda: stop_debbuging_and_remove_listeners() 284 | } 285 | -------------------------------------------------------------------------------- /docs/logo-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /docs/logo-light.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /cursive/cursive.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from typing import Any, Callable, Generic, Optional, TypeVar 4 | import os 5 | 6 | import openai as openai_client 7 | from anthropic import APIError 8 | 9 | from cursive.build_input import get_function_call_directives 10 | from cursive.function import parse_custom_function_call 11 | from cursive.vendor.cohere import CohereClient 12 | 13 | from .custom_types import ( 14 | BaseModel, 15 | CompletionMessage, 16 | CompletionPayload, 17 | CreateChatCompletionResponseExtended, 18 | CursiveAskCost, 19 | CursiveAskModelResponse, 20 | CursiveAskOnToken, 21 | CursiveAskUsage, 22 | CursiveAvailableModels, 23 | CursiveEnrichedAnswer, 24 | CursiveError, 25 | CursiveErrorCode, 26 | CursiveFunction, 27 | CursiveHook, 28 | CursiveHookPayload, 29 | CursiveSetupOptions, 30 | CursiveSetupOptionsExpand, 31 | ) 32 | from .hookable import create_debugger, create_hooks 33 | from .pricing import resolve_pricing 34 | from .usage.anthropic import get_anthropic_usage 35 | from .usage.openai import 
get_openai_usage 36 | from .utils import filter_null_values, random_id, resguard 37 | from .vendor.anthropic import ( 38 | AnthropicClient, 39 | process_anthropic_stream, 40 | ) 41 | from .vendor.index import resolve_vendor_from_model 42 | from .vendor.openai import process_openai_stream 43 | 44 | # TODO: Improve implementation architecture, this was a quick and dirty 45 | class Cursive: 46 | options: CursiveSetupOptions 47 | 48 | def __init__( 49 | self, 50 | max_retries: Optional[int] = None, 51 | expand: Optional[dict[str, Any]] = None, 52 | debug: Optional[bool] = None, 53 | openai: Optional[dict[str, Any]] = None, 54 | anthropic: Optional[dict[str, Any]] = None, 55 | cohere: Optional[dict[str, Any]] = None, 56 | ): 57 | openai_client.api_key = (openai or {}).get('api_key') or \ 58 | os.environ.get('OPENAI_API_KEY') 59 | anthropic_client = AnthropicClient((anthropic or {}).get('api_key') or \ 60 | os.environ.get('ANTHROPIC_API_KEY')) 61 | cohere_client = CohereClient((cohere or {}).get('api_key') or \ 62 | os.environ.get('CO_API_KEY', '---')) 63 | 64 | self._hooks = create_hooks() 65 | self._vendor = CursiveVendors( 66 | openai=openai_client, 67 | anthropic=anthropic_client, 68 | cohere=cohere_client, 69 | ) 70 | self.options = CursiveSetupOptions( 71 | max_retries=max_retries, 72 | expand=expand, 73 | debug=debug, 74 | ) 75 | 76 | if debug: 77 | self._debugger = create_debugger(self._hooks, { 'tag': 'cursive' }) 78 | 79 | 80 | def on(self, event: CursiveHook, callback: Callable): 81 | self._hooks.hook(event, callback) 82 | 83 | 84 | def ask( 85 | self, 86 | model: Optional[CursiveAvailableModels] = None, 87 | system_message: Optional[str] = None, 88 | functions: Optional[list[CursiveFunction]] = None, 89 | function_call: Optional[str | CursiveFunction] = None, 90 | on_token: Optional[CursiveAskOnToken] = None, 91 | max_tokens: Optional[int] = None, 92 | stop: Optional[list[str]] = None, 93 | temperature: Optional[int] = None, 94 | top_p: Optional[int] = None, 95 | presence_penalty: Optional[int] = None, 96 | frequency_penalty: Optional[int] = None, 97 | best_of: Optional[int] = None, 98 | n: Optional[int] = None, 99 | logit_bias: Optional[dict[str, int]] = None, 100 | user: Optional[str] = None, 101 | stream: Optional[bool] = None, 102 | messages: Optional[list[CompletionMessage]] = None, 103 | prompt: Optional[str] = None, 104 | ): 105 | result = build_answer( 106 | cursive=self, 107 | model=model, 108 | system_message=system_message, 109 | functions=functions, 110 | function_call=function_call, 111 | on_token=on_token, 112 | max_tokens=max_tokens, 113 | stop=stop, 114 | temperature=temperature, 115 | top_p=top_p, 116 | presence_penalty=presence_penalty, 117 | frequency_penalty=frequency_penalty, 118 | best_of=best_of, 119 | n=n, 120 | logit_bias=logit_bias, 121 | user=user, 122 | stream=stream, 123 | messages=messages, 124 | prompt=prompt, 125 | ) 126 | if result and result.error: 127 | return CursiveAnswer[CursiveError](error=result.error) 128 | 129 | new_messages = [ 130 | *(result and result.messages or []), 131 | CompletionMessage( 132 | role='assistant', 133 | content=result and result.answer or '', 134 | ) 135 | ] 136 | 137 | return CursiveAnswer( 138 | result=result, 139 | messages=new_messages, 140 | cursive=self, 141 | ) 142 | 143 | 144 | def embed(self, content: str): 145 | options = { 146 | 'model': 'text-embedding-ada-002', 147 | 'input': content, 148 | } 149 | self._hooks.call_hook('embedding:before', CursiveHookPayload(data=options)) 150 | start = time.time() 151 | 
try: 152 | data = self._vendor.openai.Embedding.create( 153 | input=options['input'], 154 | model='text-embedding-ada-002' 155 | ) 156 | 157 | result = { 158 | 'embedding': data['data'][0]['embedding'], # type: ignore 159 | } 160 | self._hooks.call_hook('embedding:success', CursiveHookPayload( 161 | data=result, 162 | time=time.time() - start, 163 | )) 164 | self._hooks.call_hook('embedding:after', CursiveHookPayload( 165 | data=result, 166 | duration=time.time() - start 167 | )) 168 | 169 | return result['embedding'] 170 | except self._vendor.openai.OpenAIError as e: 171 | error = CursiveError( 172 | message=str(e), 173 | details=e, 174 | code=CursiveErrorCode.embedding_error 175 | ) 176 | self._hooks.call_hook('embedding:error', CursiveHookPayload( 177 | data=error, 178 | error=error, 179 | duration=time.time() - start 180 | )) 181 | self._hooks.call_hook('embedding:after', CursiveHookPayload( 182 | data=error, 183 | error=error, 184 | duration=time.time() - start 185 | )) 186 | raise error 187 | 188 | 189 | 190 | 191 | 192 | def resolve_options( 193 | model: Optional[CursiveAvailableModels] = None, 194 | system_message: Optional[str] = None, 195 | functions: Optional[list[CursiveFunction]] = None, 196 | function_call: Optional[str | CursiveFunction] = None, 197 | on_token: Optional[CursiveAskOnToken] = None, 198 | max_tokens: Optional[int] = None, 199 | stop: Optional[list[str]] = None, 200 | temperature: Optional[int] = None, 201 | top_p: Optional[int] = None, 202 | presence_penalty: Optional[int] = None, 203 | frequency_penalty: Optional[int] = None, 204 | best_of: Optional[int] = None, 205 | n: Optional[int] = None, 206 | logit_bias: Optional[dict[str, int]] = None, 207 | user: Optional[str] = None, 208 | stream: Optional[bool] = None, 209 | messages: Optional[list[CompletionMessage]] = None, 210 | prompt: Optional[str] = None, 211 | ): 212 | functions = functions or [] 213 | messages = messages or [] 214 | model = model or 'gpt-3.5-turbo-0613' 215 | 216 | # TODO: Add support for function call resolving 217 | vendor = resolve_vendor_from_model(model) 218 | resolved_system_message = '' 219 | if vendor in ['anthropic', 'cohere'] and len(functions) > 0: 220 | resolved_system_message = ( 221 | (system_message or '') 222 | + '\n\n' 223 | + get_function_call_directives(functions) 224 | ) 225 | 226 | query_messages: list[CompletionMessage] = [message for message in [ 227 | resolved_system_message and CompletionMessage( 228 | role='system', 229 | content=resolved_system_message 230 | ), 231 | *messages, 232 | prompt and CompletionMessage(role='user', content=prompt), 233 | ] if message 234 | ] 235 | 236 | resolved_function_call = ( 237 | { 'name': function_call.function_schema['name'] } 238 | if isinstance(function_call, CursiveFunction) 239 | else function_call 240 | ) if function_call else None 241 | 242 | options = filter_null_values({ 243 | 'on_token': on_token, 244 | 'max_tokens': max_tokens, 245 | 'stop': stop, 246 | 'temperature': temperature, 247 | 'top_p': top_p, 248 | 'presence_penalty': presence_penalty, 249 | 'frequency_penalty': frequency_penalty, 250 | 'best_of': best_of, 251 | 'n': n, 252 | 'logit_bias': logit_bias, 253 | 'user': user, 254 | 'stream': stream, 255 | 'model': model, 256 | 'messages': list( 257 | map( 258 | lambda message: filter_null_values(dict(message)), 259 | query_messages 260 | ) 261 | ), 262 | 'function_call': resolved_function_call, 263 | }) 264 | 265 | 266 | payload = CompletionPayload(**options) 267 | 268 | resolved_options = { 269 | 'on_token': 
on_token, 270 | 'max_tokens': max_tokens, 271 | 'stop': stop, 272 | 'temperature': temperature, 273 | 'top_p': top_p, 274 | 'presence_penalty': presence_penalty, 275 | 'frequency_penalty': frequency_penalty, 276 | 'best_of': best_of, 277 | 'n': n, 278 | 'logit_bias': logit_bias, 279 | 'user': user, 280 | 'stream': stream, 281 | 'model': model, 282 | 'messages': query_messages 283 | } 284 | 285 | return payload, resolved_options 286 | 287 | 288 | def create_completion( 289 | payload: CompletionPayload, 290 | cursive: Cursive, 291 | on_token: Optional[CursiveAskOnToken] = None, 292 | ) -> CreateChatCompletionResponseExtended: 293 | cursive._hooks.call_hook('completion:before', CursiveHookPayload(data=payload)) 294 | data = {} 295 | start = time.time() 296 | 297 | vendor = resolve_vendor_from_model(payload.model) 298 | 299 | # TODO: Improve the completion creation based on model to vendor matching 300 | if vendor == 'openai': 301 | payload.messages = list( 302 | map( 303 | lambda message: { 304 | k: v for k, v in message.dict().items() if k != 'id' and v is not None 305 | }, 306 | payload.messages 307 | ) 308 | ) 309 | resolved_payload = filter_null_values(payload.dict()) 310 | response = cursive._vendor.openai.ChatCompletion.create( 311 | **resolved_payload 312 | ) 313 | if payload.stream: 314 | data = process_openai_stream( 315 | payload=payload, 316 | cursive=cursive, 317 | response=response, 318 | on_token=on_token, 319 | ) 320 | content = ''.join(list(map(lambda choice: choice['message']['content'], data['choices']))) 321 | data['usage']['completion_tokens'] = get_openai_usage(content) 322 | data['usage']['total_tokens'] = data['usage']['completion_tokens'] + data['usage']['prompt_tokens'] 323 | else: 324 | data = response 325 | 326 | data['cost'] = resolve_pricing( 327 | vendor='openai', 328 | usage=CursiveAskUsage( 329 | completion_tokens=data['usage']['completion_tokens'], 330 | prompt_tokens=data['usage']['prompt_tokens'], 331 | total_tokens=data['usage']['total_tokens'], 332 | ), 333 | model=data['model'] 334 | ) 335 | elif vendor == 'anthropic': 336 | response, error = resguard( 337 | lambda: cursive._vendor.anthropic.create_completion(payload), 338 | APIError 339 | ) 340 | 341 | if error: 342 | raise CursiveError( 343 | message=error.message, 344 | details=error, 345 | code=CursiveErrorCode.completion_error 346 | ) 347 | 348 | 349 | data = { 350 | 'choices': [{ 'message': { 'content': response.completion.lstrip() } }], 351 | 'model': payload.model, 352 | 'id': random_id(), 353 | 'usage': {}, 354 | } 355 | if payload.stream: 356 | data = process_anthropic_stream( 357 | payload=payload, 358 | cursive=cursive, 359 | response=response, 360 | on_token=on_token, 361 | ) 362 | 363 | parse_custom_function_call(data, payload, get_anthropic_usage) 364 | 365 | data['cost'] = resolve_pricing( 366 | vendor='anthropic', 367 | usage=CursiveAskUsage( 368 | completion_tokens=data['usage']['completion_tokens'], 369 | prompt_tokens=data['usage']['prompt_tokens'], 370 | total_tokens=data['usage']['total_tokens'], 371 | ), 372 | model=data['model'] 373 | ) 374 | elif vendor == 'cohere': 375 | response, error = cursive._vendor.cohere.create_completion(payload) 376 | if error: 377 | raise CursiveError( 378 | message=error.message, 379 | details=error, 380 | code=CursiveErrorCode.completion_error 381 | ) 382 | 383 | data = { 384 | 'choices': [{ 'message': { 'content': response.data[0].text.lstrip() } }], 385 | 'model': payload.model, 386 | 'id': random_id(), 387 | 'usage': {}, 388 | } 389 | 390 | if 
def ask_model(
    cursive,
    model: Optional[CursiveAvailableModels] = None,
    system_message: Optional[str] = None,
    functions: Optional[list[CursiveFunction]] = None,
    function_call: Optional[str | CursiveFunction] = None,
    on_token: Optional[CursiveAskOnToken] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    best_of: Optional[int] = None,
    n: Optional[int] = None,
    logit_bias: Optional[dict[str, int]] = None,
    user: Optional[str] = None,
    stream: Optional[bool] = None,
    messages: Optional[list[CompletionMessage]] = None,
    prompt: Optional[str] = None,
) -> CursiveAskModelResponse:
    payload, resolved_options = resolve_options(
        model=model,
        system_message=system_message,
        functions=functions,
        function_call=function_call,
        on_token=on_token,
        max_tokens=max_tokens,
        stop=stop,
        temperature=temperature,
        top_p=top_p,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        best_of=best_of,
        n=n,
        logit_bias=logit_bias,
        user=user,
        stream=stream,
        messages=messages,
        prompt=prompt,
    )
    start = time.time()

    functions = functions or []

    # If a specific function is forced via `function_call`, make sure its
    # schema is part of the function list sent to the model.
    if isinstance(function_call, CursiveFunction):
        functions.append(function_call)

    function_schemas = [function.function_schema for function in functions]

    if len(function_schemas) > 0:
        payload.functions = function_schemas

    completion, error = resguard(
        lambda: create_completion(
            payload=payload,
            cursive=cursive,
            on_token=on_token,
        ),
        CursiveError,
    )

    if error:
        if not error.details:
            raise CursiveError(
                message=f'Unknown error: {error.message}',
                details=error,
                code=CursiveErrorCode.unknown_error
            ) from error
        cause = error.details.code or error.details.type
        if cause == 'context_length_exceeded':
            # Expand to a larger-context model when expansion is enabled or
            # left unconfigured.
            if not cursive.expand or cursive.expand.enabled:
                default_model = (
                    (cursive.expand and cursive.expand.defaultsTo)
                    or 'gpt-3.5-turbo-16k'
                )
                model_mapping = (
                    (cursive.expand and cursive.expand.model_mapping)
                    or {}
                )
                resolved_model = model_mapping.get(model) or default_model
                # CompletionPayload is a pydantic model, so clone it with the
                # expanded model instead of dict unpacking.
                completion, error = resguard(
                    lambda: create_completion(
                        payload=payload.copy(update={'model': resolved_model}),
                        cursive=cursive,
                        on_token=on_token,
                    ),
                    CursiveError,
                )
        elif cause == 'invalid_request_error':
            raise CursiveError(
                message='Invalid request',
                details=error.details,
                code=CursiveErrorCode.invalid_request_error,
            )

    # TODO: Handle other errors
    if error:
        # TODO: Add a more comprehensive retry strategy
        for i in range(cursive.max_retries):
            completion, error = resguard(
                lambda: create_completion(
                    payload=payload,
                    cursive=cursive,
                    on_token=on_token,
                ),
                CursiveError
            )

            # Stop retrying on the first success; back off progressively
            # after the first few failed attempts.
            if not error:
                break
            if i > 3:
                time.sleep((i - 3) * 2)

    if error:
        error = CursiveError(
            message='Error while completing request',
            details=error.details,
            code=CursiveErrorCode.completion_error
        )
        hook_payload = CursiveHookPayload(error=error)
        cursive._hooks.call_hook('ask:error', hook_payload)
        cursive._hooks.call_hook('ask:after', hook_payload)
        raise error
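
    # What follows is the function-calling round trip: if the model asked for
    # a function, the call is appended to the history, the arguments are
    # coerced to their JSON-schema types, the Python function runs, and its
    # result either pauses the loop (`pause=True`) or is fed back through
    # another `ask_model` pass so the model can phrase a final answer.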
    if (
        completion
        and completion.choices
        and len(completion.choices) > 0
        and completion.choices[0].get('message')
        and completion.choices[0]['message'].get('function_call')
    ):
        payload.messages.append({
            'role': 'assistant',
            'function_call': completion.choices[0]['message'].get('function_call'),
            'content': '',
        })
        func_call = completion.choices[0]['message'].get('function_call')
        function_definition = next(
            (f for f in functions if f.function_schema['name'] == func_call['name']),
            None
        )

        if not function_definition:
            # The model called a function we don't know about; ask again with
            # function calling disabled.
            return ask_model(**{
                **resolved_options,
                'function_call': 'none',
                'messages': payload.messages,
                'cursive': cursive,
            })

        try:
            name = func_call['name']
            called_function = next(
                (func for func in payload.functions if func['name'] == name),
                None
            )
            arguments = json.loads(func_call['arguments'] or '{}')
            if called_function:
                # Coerce the model-provided arguments to the types declared
                # in the function's JSON schema.
                props = called_function['parameters']['properties']
                for k, v in props.items():
                    if k in arguments:
                        arg_type = v['type']
                        if arg_type == 'string':
                            arguments[k] = str(arguments[k])
                        elif arg_type == 'number':
                            arguments[k] = float(arguments[k])
                        elif arg_type == 'integer':
                            arguments[k] = int(arguments[k])
                        elif arg_type == 'boolean':
                            arguments[k] = bool(arguments[k])
        except Exception as e:
            raise CursiveError(
                message=f'Error while parsing function arguments for {func_call["name"]}',
                details=e,
                code=CursiveErrorCode.function_call_error,
            )

        function_result, error = resguard(
            lambda: function_definition.definition(**arguments),
        )

        if error:
            raise CursiveError(
                message=f'Error while running function {func_call["name"]}',
                details=error,
                code=CursiveErrorCode.function_call_error,
            )

        messages = payload.messages or []

        messages.append(CompletionMessage(
            role='function',
            name=func_call['name'],
            content=json.dumps(function_result or ''),
        ))

        if function_definition.pause:
            completion.function_result = function_result
            return CursiveAskModelResponse(
                answer=CreateChatCompletionResponseExtended(**completion.dict()),
                messages=messages,
            )
        else:
            # Feed the function result back so the model can produce a final
            # natural-language answer.
            return ask_model(**{
                **resolved_options,
                'functions': functions,
                'messages': messages,
                'cursive': cursive,
            })

    end = time.time()
    hook_payload = CursiveHookPayload(data=completion, duration=end - start)
    cursive._hooks.call_hook('ask:success', hook_payload)
    cursive._hooks.call_hook('ask:after', hook_payload)

    return CursiveAskModelResponse(
        answer=completion,
        messages=payload.messages or [],
    )
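
# A minimal function-calling sketch (the `add` function is a hypothetical
# example, not part of this repository; `ask_model` is usually reached
# through `Cursive.ask`):
#
#     @cursive_function()
#     def add(a: float, b: float):
#         """
#         Adds two numbers.
#
#         a: The first number.
#         b: The second number.
#         """
#         return a + b
#
#     result = ask_model(
#         cursive=Cursive(),
#         prompt='What is 2 + 2?',
#         functions=[add],
#     )
#     print(result.answer.choices[-1]['message']['content'])
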
def build_answer(
    cursive,
    model: Optional[CursiveAvailableModels] = None,
    system_message: Optional[str] = None,
    functions: Optional[list[CursiveFunction]] = None,
    function_call: Optional[str | CursiveFunction] = None,
    on_token: Optional[CursiveAskOnToken] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    best_of: Optional[int] = None,
    n: Optional[int] = None,
    logit_bias: Optional[dict[str, int]] = None,
    user: Optional[str] = None,
    stream: Optional[bool] = None,
    messages: Optional[list[CompletionMessage]] = None,
    prompt: Optional[str] = None,
):
    result, error = resguard(
        lambda: ask_model(
            cursive=cursive,
            model=model,
            system_message=system_message,
            functions=functions,
            function_call=function_call,
            on_token=on_token,
            max_tokens=max_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            best_of=best_of,
            n=n,
            logit_bias=logit_bias,
            user=user,
            stream=stream,
            messages=messages,
            prompt=prompt,
        ),
        CursiveError
    )

    if error:
        # The `return` was missing here, which made error results leak out
        # as None instead of an enriched answer carrying the error.
        return CursiveEnrichedAnswer(
            error=error,
            usage=None,
            model=model or 'gpt-3.5-turbo',
            id=None,
            choices=None,
            function_result=None,
            answer=None,
            messages=None,
            cost=None,
        )
    else:
        usage = CursiveAskUsage(
            completion_tokens=result.answer.usage['completion_tokens'],
            prompt_tokens=result.answer.usage['prompt_tokens'],
            total_tokens=result.answer.usage['total_tokens'],
        ) if result.answer.usage else None

        return CursiveEnrichedAnswer(
            error=None,
            model=result.answer.model,
            id=result.answer.id,
            usage=usage,
            cost=result.answer.cost,
            choices=[
                choice['message']['content'] for choice in result.answer.choices
            ],
            function_result=result.answer.function_result or None,
            answer=result.answer.choices[-1]['message']['content'],
            messages=result.messages,
        )
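
# `build_answer` is the error-tolerant wrapper used by the conversation API
# below: instead of raising, it folds any CursiveError into the returned
# CursiveEnrichedAnswer. A rough sketch (assumed call shape):
#
#     enriched = build_answer(cursive=Cursive(), prompt='Name three colors.')
#     if enriched.error is None:
#         print(enriched.answer)
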
class CursiveConversation:
    _cursive: Cursive
    messages: list[CompletionMessage]

    def __init__(self, messages: list[CompletionMessage]):
        self.messages = messages

    def ask(
        self,
        model: Optional[CursiveAvailableModels] = None,
        system_message: Optional[str] = None,
        functions: Optional[list[CursiveFunction]] = None,
        function_call: Optional[str | CursiveFunction] = None,
        on_token: Optional[CursiveAskOnToken] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[list[str]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        best_of: Optional[int] = None,
        n: Optional[int] = None,
        logit_bias: Optional[dict[str, int]] = None,
        user: Optional[str] = None,
        stream: Optional[bool] = None,
        prompt: Optional[str] = None,
    ):
        messages = [*self.messages]

        result = build_answer(
            cursive=self._cursive,
            model=model,
            system_message=system_message,
            functions=functions,
            function_call=function_call,
            on_token=on_token,
            max_tokens=max_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            best_of=best_of,
            n=n,
            logit_bias=logit_bias,
            user=user,
            stream=stream,
            messages=messages,
            prompt=prompt,
        )

        if result and result.error:
            return CursiveAnswer[CursiveError](error=result.error)

        new_messages = [
            *(result and result.messages or []),
            CompletionMessage(role='assistant', content=result and result.answer or ''),
        ]

        return CursiveAnswer[None](
            result=result,
            messages=new_messages,
            cursive=self._cursive,
        )


def use_cursive(
    max_retries: Optional[int] = None,
    expand: Optional[CursiveSetupOptionsExpand] = None,
    debug: Optional[bool] = None,
    openai: Optional[dict[str, Any]] = None,
    anthropic: Optional[dict[str, Any]] = None,
):
    return Cursive(
        max_retries=max_retries,
        expand=expand,
        debug=debug,
        openai=openai,
        anthropic=anthropic,
    )


E = TypeVar("E", None, CursiveError)


class CursiveAnswer(Generic[E]):
    choices: Optional[list[str]]
    id: Optional[str]
    model: Optional[str]
    usage: Optional[CursiveAskUsage]
    cost: Optional[CursiveAskCost]
    error: Optional[E]
    function_result: Optional[Any]
    # The text from the answer of the last choice
    answer: Optional[str]
    # A conversation instance with all the messages so far, including this one
    conversation: Optional[CursiveConversation]

    def __init__(
        self,
        result: Optional[Any] = None,
        error: Optional[E] = None,
        messages: Optional[list[CompletionMessage]] = None,
        cursive: Optional[Cursive] = None,
    ):
        # Default so the attribute always exists, even when no messages were
        # passed along with a result.
        self.conversation = None
        if error:
            self.error = error
            self.choices = None
            self.id = None
            self.model = None
            self.usage = None
            self.cost = None
            self.answer = None
            self.function_result = None
        elif result:
            self.error = None
            self.choices = result.choices
            self.id = result.id
            self.model = result.model
            self.usage = result.usage
            self.cost = result.cost
            self.answer = result.answer
            self.function_result = result.function_result
            if messages:
                conversation = CursiveConversation(messages)
                if cursive:
                    conversation._cursive = cursive
                self.conversation = conversation

    def __str__(self):
        if self.error:
            return f"CursiveAnswer(error={self.error})"
        else:
            return (
                f"CursiveAnswer(\n\tchoices={self.choices}\n\tid={self.id}\n\t"
                f"model={self.model}\n\tusage=(\n\t\t{self.usage}\n\t)\n\tcost=(\n\t\t{self.cost}\n\t)\n\t"
                f"answer={self.answer}\n\tconversation={self.conversation}\n)"
            )
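
# Conversation flow sketch (assuming the `Cursive.ask` entry point returns a
# CursiveAnswer whose `.conversation` carries the accumulated messages, as
# constructed above):
#
#     cursive = Cursive()
#     first = cursive.ask(prompt='Pick a number between 1 and 10.')
#     follow_up = first.conversation.ask(prompt='Now double it.')
#     print(follow_up.answer)
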
f"answer={self.answer}\n\tconversation={self.conversation}\n)" 865 | ) 866 | 867 | class CursiveVendors(BaseModel): 868 | openai: Optional[Any] = None 869 | anthropic: Optional[AnthropicClient] = None 870 | cohere: Optional[CohereClient] = None 871 | --------------------------------------------------------------------------------