├── .coveragerc ├── .gitignore ├── LICENSE ├── README.md ├── promptmodel ├── STARTER.py ├── __init__.py ├── apis │ ├── __init__.py │ └── base.py ├── chat_model.py ├── cli │ ├── __init__.py │ ├── commands │ │ ├── __init__.py │ │ ├── configure.py │ │ ├── connect.py │ │ ├── dev.py │ │ ├── fix.py │ │ ├── init.py │ │ ├── login.py │ │ └── project.py │ ├── main.py │ ├── signal_handler.py │ └── utils.py ├── constants.py ├── database │ ├── __init__.py │ ├── config.py │ ├── crud.py │ ├── crud_chat.py │ ├── models.py │ ├── models_chat.py │ └── orm.py ├── dev_app.py ├── function_model.py ├── llms │ ├── __init__.py │ ├── llm.py │ ├── llm_dev.py │ └── llm_proxy.py ├── promptmodel_init.py ├── types │ ├── __init__.py │ ├── enums.py │ ├── request.py │ └── response.py ├── unit_logger.py ├── utils │ ├── __init__.py │ ├── async_utils.py │ ├── config_utils.py │ ├── crypto.py │ ├── logger.py │ ├── output_utils.py │ ├── random_utils.py │ └── token_counting.py └── websocket │ ├── __init__.py │ ├── reload_handler.py │ └── websocket_client.py ├── publish.sh ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── api_client ├── __init__.py └── api_client_test.py ├── chat_model ├── __init__.py ├── chat_model_test.py ├── conftest.py └── registering_meta_test.py ├── cli └── dev_test.py ├── constants.py ├── function_model ├── __init__.py ├── conftest.py ├── function_model_test.py └── registering_meta_test.py ├── llm ├── __init__.py ├── error_case_test.py ├── function_call_test.py ├── llm_dev │ ├── __init__.py │ └── llm_dev_test.py ├── llm_proxy │ ├── __init__.py │ ├── conftest.py │ ├── proxy_chat_test.py │ └── proxy_test.py ├── parse_test.py ├── stream_function_call_test.py ├── stream_test.py ├── stream_tool_calls_test.py └── tool_calls_test.py ├── utils └── async_util_test.py └── websocket_client ├── __init__.py ├── conftest.py ├── local_task_test.py ├── run_chatmodel_test.py └── run_promptmodel_test.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: no cover 4 | def __repr__ 5 | if self\.debug: 6 | raise AssertionError 7 | raise NotImplementedError 8 | if 0: 9 | if __name__ == .__main__.: -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | venv 3 | .ipynb_checkpoints 4 | *.ipynb 5 | .env 6 | 7 | .fastllm 8 | .fastllm/* 9 | .promptmodel 10 | .promptmodel/* 11 | 12 | */__pycache__ 13 | __pycache__ 14 | 15 | .vscode 16 | .vscode/* 17 | */.DS_Store 18 | *.egg-info 19 | 20 | ipynb_test 21 | ipynb_test/* 22 | testapp 23 | testapp/* 24 | 25 | build 26 | build/* 27 | dist 28 | dist/* 29 | 30 | .coverage 31 | key.key -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Promptmodel

Prompt & model versioning on the cloud, built for developers.

We are currently in closed alpha. Request access 🚀

[PyPI Version badge]

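To connect your codebase to the dashboard, `prompt init` generates a `promptmodel_dev.py` file that exposes a `DevApp` instance. The sketch below is a minimal, illustrative example built on the `DevClient`/`DevApp` registration helpers defined in this package; the model names are placeholders, not part of the library.

```python
# promptmodel_dev.py (minimal sketch; model names are placeholders)
from promptmodel import DevApp, DevClient

client = DevClient()
client.register_function_model("extract_keywords")  # a FunctionModel name used in your code
client.register_chat_model("support_chat")          # a ChatModel name used in your code

app = DevApp()                # the `connect` CLI command looks for this `app` instance
app.include_client(client)    # makes the registered models visible to the dashboard
```

Running the `connect` CLI command in the same directory then uploads these registrations and opens the prompt engineering dashboard in your browser.
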
22 | 23 | ## Installation 24 | 25 | ```bash 26 | pip install promptmodel 27 | ``` 28 | 29 | ## Documentation 30 | 31 | You can find our full documentation [here](https://www.promptmodel.run/docs/introduction). 32 | -------------------------------------------------------------------------------- /promptmodel/STARTER.py: -------------------------------------------------------------------------------- 1 | """This single file is needed to build the DevApp development dashboard.""" 2 | 3 | from promptmodel import DevApp 4 | 5 | # Example imports 6 | # from import 7 | 8 | app = DevApp() 9 | 10 | # Example usage 11 | # This is needed to integrate your codebase with the prompt engineering dashboard 12 | # app.include_client() 13 | -------------------------------------------------------------------------------- /promptmodel/__init__.py: -------------------------------------------------------------------------------- 1 | from .dev_app import DevClient, DevApp 2 | from .function_model import FunctionModel, PromptModel 3 | from .chat_model import ChatModel 4 | from .promptmodel_init import init 5 | from .unit_logger import UnitLogger 6 | 7 | __version__ = "0.1.19" 8 | 9 | -------------------------------------------------------------------------------- /promptmodel/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | -------------------------------------------------------------------------------- /promptmodel/apis/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import requests 5 | 6 | import httpx 7 | from rich import print 8 | 9 | from promptmodel.utils.config_utils import read_config 10 | from promptmodel.constants import ENDPOINT_URL 11 | from promptmodel.utils.crypto import decrypt_message 12 | 13 | 14 | class APIClient: 15 | """ 16 | A class to represent an API request client. 17 | 18 | ... 19 | 20 | Methods 21 | ------- 22 | get_headers(): 23 | Generates headers for the API request. 24 | execute(method="GET", params=None, data=None, json=None, **kwargs): 25 | Executes the API request. 26 | """ 27 | 28 | @classmethod 29 | def _get_headers(cls, use_cli_key: bool = True) -> Dict: 30 | """ 31 | Reads, decrypts the api_key, and returns headers for API request. 32 | 33 | Returns 34 | ------- 35 | dict 36 | a dictionary containing the Authorization header 37 | """ 38 | config = read_config() 39 | if use_cli_key: 40 | if "connection" not in config: 41 | print( 42 | "User not logged in. Please run [violet]prompt login[/violet] first." 43 | ) 44 | exit() 45 | 46 | encrypted_key = ( 47 | config["connection"]["encrypted_api_key"] 48 | if "encrypted_api_key" in config["connection"] 49 | else None 50 | ) 51 | if encrypted_key is None: 52 | raise Exception("No API key found. Please run 'prompt login' first.") 53 | decrypted_key = decrypt_message(encrypted_key) 54 | else: 55 | decrypted_key = os.environ.get("PROMPTMODEL_API_KEY") 56 | headers = {"Authorization": f"Bearer {decrypted_key}"} 57 | return headers 58 | 59 | @classmethod 60 | def execute( 61 | cls, 62 | path: str, 63 | method="GET", 64 | params: Dict = None, 65 | data: Dict = None, 66 | json: Dict = None, 67 | ignore_auth_error: bool = False, 68 | use_cli_key: bool = True, 69 | **kwargs, 70 | ) -> requests.Response: 71 | """ 72 | Executes the API request with the decrypted API key in the headers. 
73 | 74 | Parameters 75 | ---------- 76 | method : str, optional 77 | The HTTP method of the request (default is "GET") 78 | params : dict, optional 79 | The URL parameters to be sent with the request 80 | data : dict, optional 81 | The request body to be sent with the request 82 | json : dict, optional 83 | The JSON-encoded request body to be sent with the request 84 | ignore_auth_error: bool, optional 85 | Whether to ignore authentication errors (default is False) 86 | **kwargs : dict 87 | Additional arguments to pass to the requests.request function 88 | 89 | Returns 90 | ------- 91 | requests.Response 92 | The response object returned by the requests library 93 | """ 94 | url = f"{ENDPOINT_URL}{path}" 95 | headers = cls._get_headers(use_cli_key) 96 | try: 97 | response = requests.request( 98 | method, 99 | url, 100 | headers=headers, 101 | params=params, 102 | data=data, 103 | json=json, 104 | **kwargs, 105 | ) 106 | if not response: 107 | print(f"[red]Error: {response}[/red]") 108 | exit() 109 | if response.status_code == 200: 110 | return response 111 | elif response.status_code == 403: 112 | if not ignore_auth_error: 113 | print( 114 | "[red]Authentication failed. Please run [violet][bold]prompt login[/bold][/violet] first.[/red]" 115 | ) 116 | exit() 117 | else: 118 | print(f"[red]Error: {response}[/red]") 119 | exit() 120 | except requests.exceptions.ConnectionError: 121 | print("[red]Could not connect to the Promptmodel API.[/red]") 122 | except requests.exceptions.Timeout: 123 | print("[red]The request timed out.[/red]") 124 | 125 | 126 | class AsyncAPIClient: 127 | """ 128 | A class to represent an Async API request client. 129 | Used in Deployment stage. 130 | 131 | ... 132 | 133 | Methods 134 | ------- 135 | get_headers(): 136 | Generates headers for the API request. 137 | execute(method="GET", params=None, data=None, json=None, **kwargs): 138 | Executes the API request. 139 | """ 140 | 141 | @classmethod 142 | async def _get_headers(cls, use_cli_key: bool = True) -> Dict: 143 | """ 144 | Reads, decrypts the api_key, and returns headers for API request. 145 | 146 | Returns 147 | ------- 148 | dict 149 | a dictionary containing the Authorization header 150 | """ 151 | config = read_config() 152 | if use_cli_key: 153 | if "connection" not in config: 154 | print( 155 | "User not logged in. Please run [violet]prompt login[/violet] first." 156 | ) 157 | exit() 158 | 159 | encrypted_key = config["connection"]["encrypted_api_key"] 160 | if encrypted_key is None: 161 | raise Exception("No API key found. Please run 'prompt login' first.") 162 | decrypted_key = decrypt_message(encrypted_key) 163 | else: 164 | decrypted_key = os.environ.get("PROMPTMODEL_API_KEY") 165 | if decrypted_key is None: 166 | raise Exception( 167 | "PROMPTMODEL_API_KEY was not found in the current environment." 168 | ) 169 | headers = {"Authorization": f"Bearer {decrypted_key}"} 170 | return headers 171 | 172 | @classmethod 173 | async def execute( 174 | cls, 175 | path: str, 176 | method="GET", 177 | params: Dict = None, 178 | data: Dict = None, 179 | json: Dict = None, 180 | ignore_auth_error: bool = False, 181 | use_cli_key: bool = True, 182 | **kwargs, 183 | ) -> requests.Response: 184 | """ 185 | Executes the API request with the decrypted API key in the headers. 
186 | 187 | Parameters 188 | ---------- 189 | method : str, optional 190 | The HTTP method of the request (default is "GET") 191 | params : dict, optional 192 | The URL parameters to be sent with the request 193 | data : dict, optional 194 | The request body to be sent with the request 195 | json : dict, optional 196 | The JSON-encoded request body to be sent with the request 197 | ignore_auth_error: bool, optional 198 | Whether to ignore authentication errors (default is False) 199 | **kwargs : dict 200 | Additional arguments to pass to the requests.request function 201 | 202 | Returns 203 | ------- 204 | requests.Response 205 | The response object returned by the requests library 206 | """ 207 | url = f"{ENDPOINT_URL}{path}" 208 | headers = await cls._get_headers(use_cli_key) 209 | try: 210 | async with httpx.AsyncClient(http2=True) as _client: 211 | response = await _client.request( 212 | method, 213 | url, 214 | headers=headers, 215 | params=params, 216 | data=data, 217 | json=json, 218 | **kwargs, 219 | ) 220 | if not response: 221 | print(f"[red]Error: {response}[/red]") 222 | if response.status_code == 200: 223 | return response 224 | elif response.status_code == 403: 225 | if not ignore_auth_error: 226 | print("[red]Authentication failed.[/red]") 227 | else: 228 | print(f"[red]Error: {response}[/red]") 229 | 230 | return response 231 | except requests.exceptions.ConnectionError: 232 | print("[red]Could not connect to the Promptmodel API.[/red]") 233 | except requests.exceptions.Timeout: 234 | print("[red]The request timed out.[/red]") 235 | except Exception as exception: 236 | print(f"[red]Error: {exception}[/red]") 237 | -------------------------------------------------------------------------------- /promptmodel/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/configure.py: -------------------------------------------------------------------------------- 1 | from promptmodel.apis.base import APIClient 2 | import typer 3 | from InquirerPy import inquirer 4 | 5 | from promptmodel.utils.config_utils import upsert_config 6 | 7 | 8 | def configure(): 9 | """Saves user's default organization and project.""" 10 | 11 | orgs = APIClient.execute(method="GET", path="/organizations").json() 12 | choices = [ 13 | { 14 | "key": org["name"], 15 | "name": org["name"], 16 | "value": org, 17 | } 18 | for org in orgs 19 | ] 20 | org = inquirer.select( 21 | message="Select default organization:", choices=choices 22 | ).execute() 23 | 24 | projects = APIClient.execute( 25 | method="GET", 26 | path="/projects", 27 | params={"organization_id": org["organization_id"]}, 28 | ).json() 29 | choices = [ 30 | { 31 | "key": project["name"], 32 | "name": project["name"], 33 | "value": project, 34 | } 35 | for project in projects 36 | ] 37 | project = inquirer.select( 38 | message="Select default project:", choices=choices 39 | ).execute() 40 | 41 | upsert_config({"org": org, "project": project}, section="connection") 42 | 43 | 44 | app = typer.Typer(invoke_without_command=True, callback=configure) 45 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/connect.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio 3 | import typer 4 | import importlib 5 | import signal 6 | from typing import Dict, Any, List 7 | from playhouse.shortcuts import model_to_dict 8 | 9 | import webbrowser 10 | from rich import print 11 | from InquirerPy import inquirer 12 | from watchdog.observers import Observer 13 | 14 | from promptmodel import DevApp 15 | from promptmodel.apis.base import APIClient 16 | from promptmodel.constants import ENDPOINT_URL, WEB_CLIENT_URL 17 | from promptmodel.cli.commands.init import init as promptmodel_init 18 | from promptmodel.cli.utils import get_org, get_project 19 | from promptmodel.cli.signal_handler import dev_terminate_signal_handler 20 | from promptmodel.utils.config_utils import read_config, upsert_config 21 | from promptmodel.websocket import DevWebsocketClient, CodeReloadHandler 22 | from promptmodel.database.orm import initialize_db 23 | 24 | 25 | def connect(): 26 | """Connect websocket and opens up DevApp in the browser.""" 27 | upsert_config({"initializing": True}, "connection") 28 | signal.signal(signal.SIGINT, dev_terminate_signal_handler) 29 | promptmodel_init(from_cli=False) 30 | 31 | config = read_config() 32 | 33 | if "project" not in config["connection"]: 34 | org = get_org(config) 35 | project = get_project(config=config, org=org) 36 | 37 | # connect 38 | res = APIClient.execute( 39 | method="POST", 40 | path="/project/cli_connect", 41 | params={"project_uuid": project["uuid"]}, 42 | ) 43 | if res.status_code != 200: 44 | print(f"Error: {res.json()['detail']}") 45 | return 46 | 47 | upsert_config( 48 | { 49 | "project": project, 50 | "org": org, 51 | }, 52 | section="connection", 53 | ) 54 | 55 | else: 56 | org = config["connection"]["org"] 57 | project = config["connection"]["project"] 58 | 59 | res = APIClient.execute( 60 | method="POST", 61 | path="/project/cli_connect", 62 | params={"project_uuid": project["uuid"]}, 63 | ) 64 | if res.status_code != 200: 65 | print(f"Error: {res.json()['detail']}") 66 | 
return 67 | 68 | _devapp_filename, devapp_instance_name = "promptmodel_dev:app".split(":") 69 | 70 | devapp_module = importlib.import_module(_devapp_filename) 71 | devapp_instance: DevApp = getattr(devapp_module, devapp_instance_name) 72 | 73 | dev_url = f"{WEB_CLIENT_URL}/org/{org['slug']}/projects/{project['uuid']}" 74 | 75 | # Open websocket connection to backend server 76 | dev_websocket_client = DevWebsocketClient(_devapp=devapp_instance) 77 | 78 | import threading 79 | 80 | try: 81 | main_loop = asyncio.get_running_loop() 82 | except RuntimeError: 83 | main_loop = asyncio.new_event_loop() 84 | asyncio.set_event_loop(main_loop) 85 | 86 | # save samples, FunctionSchema, FunctionModel, ChatModel to cloud server in dev_websocket_client.connect_to_gateway 87 | res = APIClient.execute( 88 | method="POST", 89 | path="/save_instances_in_code", 90 | params={"project_uuid": project["uuid"]}, 91 | json={ 92 | "function_models": devapp_instance._get_function_model_name_list(), 93 | "chat_models": devapp_instance._get_chat_model_name_list(), 94 | "function_schemas": devapp_instance._get_function_schema_list(), 95 | "samples": devapp_instance.samples, 96 | }, 97 | ) 98 | if res.status_code != 200: 99 | print(f"Error: {res.json()['detail']}") 100 | return 101 | 102 | print( 103 | f"\nOpening [violet]Promptmodel[/violet] prompt engineering environment with the following configuration:\n" 104 | ) 105 | print(f"📌 Organization: [blue]{org['name']}[/blue]") 106 | print(f"📌 Project: [blue]{project['name']}[/blue]") 107 | print( 108 | f"\nIf browser doesn't open automatically, please visit [link={dev_url}]{dev_url}[/link]" 109 | ) 110 | webbrowser.open(dev_url) 111 | 112 | upsert_config({"online": True, "initializing": False}, section="connection") 113 | 114 | reloader_thread = threading.Thread( 115 | target=start_code_reloader, 116 | args=(_devapp_filename, devapp_instance_name, dev_websocket_client, main_loop), 117 | ) 118 | reloader_thread.daemon = True # Set the thread as a daemon 119 | reloader_thread.start() 120 | 121 | # Open Websocket 122 | asyncio.run( 123 | dev_websocket_client.connect_to_gateway( 124 | project_uuid=project["uuid"], 125 | connection_name=project["name"], 126 | cli_access_header=APIClient._get_headers(), 127 | ) 128 | ) 129 | 130 | 131 | app = typer.Typer(invoke_without_command=True, callback=connect) 132 | 133 | 134 | def start_code_reloader( 135 | _devapp_filename, devapp_instance_name, dev_websocket_client, main_loop 136 | ): 137 | event_handler = CodeReloadHandler( 138 | _devapp_filename, devapp_instance_name, dev_websocket_client, main_loop 139 | ) 140 | observer = Observer() 141 | observer.schedule(event_handler, path=".", recursive=True) 142 | observer.start() 143 | 144 | try: 145 | while True: 146 | time.sleep(1) 147 | except KeyboardInterrupt: 148 | observer.stop() 149 | observer.join() 150 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/fix.py: -------------------------------------------------------------------------------- 1 | import typer 2 | import os 3 | from promptmodel.utils.config_utils import read_config, upsert_config 4 | from promptmodel.database.orm import initialize_db 5 | 6 | 7 | def fix(): 8 | """Fix Bugs that can be occured in the promptmodel.""" 9 | config = read_config() 10 | if "connection" in config: 11 | connection = config["connection"] 12 | if "initializing" in connection and connection["initializing"] == True: 13 | upsert_config({"initializing": False}, "connection") 14 | if "online" in 
connection and connection["online"] == True: 15 | upsert_config({"online": False}, "connection") 16 | if "reloading" in connection and connection["reloading"] == True: 17 | upsert_config({"reloading": False}, "connection") 18 | # delete .promptmodel/promptmodel.db 19 | # if .promptmodel/promptmodel.db exist, delete it 20 | if os.path.exists(".promptmodel/promptmodel.db"): 21 | os.remove(".promptmodel/promptmodel.db") 22 | # initialize_db() 23 | 24 | return 25 | 26 | 27 | app = typer.Typer(invoke_without_command=True, callback=fix) 28 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/init.py: -------------------------------------------------------------------------------- 1 | from importlib import resources 2 | import typer 3 | from rich import print 4 | from promptmodel.constants import ( 5 | PROMPTMODEL_DEV_FILENAME, 6 | PROMPTMODEL_DEV_STARTER_FILENAME, 7 | ) 8 | 9 | 10 | def init(from_cli: bool = True): 11 | """Initialize a new promptmodel project.""" 12 | import os 13 | 14 | if not os.path.exists(PROMPTMODEL_DEV_FILENAME): 15 | # Read the content from the source file 16 | content = resources.read_text("promptmodel", PROMPTMODEL_DEV_STARTER_FILENAME) 17 | 18 | # Write the content to the target file 19 | with open(PROMPTMODEL_DEV_FILENAME, "w") as target_file: 20 | target_file.write(content) 21 | print( 22 | "[violet][bold]promptmodel_dev.py[/bold][/violet] was successfully created!" 23 | ) 24 | print( 25 | "Add promptmodels in your code, then run [violet][bold]prompt dev[/bold][/violet] to start engineering prompts." 26 | ) 27 | elif from_cli: 28 | print( 29 | "[yellow]promptmodel_dev.py[/yellow] was already initialized in this directory." 30 | ) 31 | print( 32 | "Run [violet][bold]prompt dev[/bold][/violet] to start engineering prompts." 
33 | ) 34 | 35 | 36 | app = typer.Typer(invoke_without_command=True, callback=init) 37 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/login.py: -------------------------------------------------------------------------------- 1 | import time 2 | from promptmodel.apis.base import APIClient 3 | import typer 4 | import webbrowser 5 | from rich import print 6 | 7 | from promptmodel.constants import ENDPOINT_URL, GRANT_ACCESS_URL 8 | from promptmodel.utils.config_utils import upsert_config 9 | from promptmodel.utils.crypto import generate_api_key, encrypt_message 10 | 11 | 12 | def login(): 13 | """Authenticate Client CLI.""" 14 | # TODO: Check if already logged in 15 | api_key = generate_api_key() 16 | encrypted_key = encrypt_message(api_key) 17 | upsert_config({"encrypted_api_key": encrypted_key}, section="connection") 18 | url = f"{GRANT_ACCESS_URL}?token={api_key}" 19 | webbrowser.open(url) 20 | print("Please grant access to the CLI by visiting the URL in your browser.") 21 | print("Once you have granted access, you can close the browser tab.") 22 | print(f"\nURL: [link={url}]{url}[/link]\n") 23 | print("Waiting...\n") 24 | waiting_time = 0 25 | while waiting_time < 300: 26 | # Check access every 5 seconds 27 | try: 28 | res = APIClient.execute("/cli_access/check", ignore_auth_error=True) 29 | if res.json() == True: 30 | print("[green]Access granted![/green] 🎉") 31 | print( 32 | "Run [violet][bold]prompt init[/bold][/violet] to start developing prompts.\n" 33 | ) 34 | return 35 | except Exception as err: 36 | print(f"[red]Error: {err}[/red]") 37 | time.sleep(5) 38 | waiting_time += 5 39 | print("Please try again later.") 40 | 41 | 42 | app = typer.Typer(invoke_without_command=True, callback=login) 43 | -------------------------------------------------------------------------------- /promptmodel/cli/commands/project.py: -------------------------------------------------------------------------------- 1 | """CLI for project management.""" 2 | 3 | from promptmodel.apis.base import APIClient 4 | import typer 5 | from rich import print 6 | 7 | from promptmodel.utils.config_utils import read_config, upsert_config 8 | from ..utils import get_org 9 | 10 | app = typer.Typer(no_args_is_help=True, short_help="Manage Client projects.") 11 | 12 | 13 | @app.command() 14 | def list(): 15 | """List all projects.""" 16 | config = read_config() 17 | org = get_org(config) 18 | projects = APIClient.execute( 19 | method="GET", 20 | path="/projects", 21 | params={"organization_id": org["organization_id"]}, 22 | ).json() 23 | print("\nProjects:") 24 | for project in projects: 25 | print(f"📌 {project['name']} ({project['version']})") 26 | if project["description"]: 27 | print(f" {project['description']}\n") 28 | -------------------------------------------------------------------------------- /promptmodel/cli/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import typer 4 | from promptmodel.cli.commands.login import app as login 5 | from promptmodel.cli.commands.init import app as init 6 | 7 | # from promptmodel.cli.commands.dev import app as dev 8 | from promptmodel.cli.commands.connect import app as connect 9 | from promptmodel.cli.commands.project import app as project 10 | from promptmodel.cli.commands.configure import app as configure 11 | from promptmodel.cli.commands.fix import app as fix 12 | 13 | # 현재 작업 디렉토리를 sys.path에 추가 14 | current_working_directory = os.getcwd() 15 | if 
current_working_directory not in sys.path: 16 | sys.path.append(current_working_directory) 17 | 18 | app = typer.Typer(no_args_is_help=True, pretty_exceptions_enable=False) 19 | 20 | app.add_typer(login, name="login") 21 | app.add_typer(init, name="init") 22 | # app.add_typer(dev, name="dev") 23 | app.add_typer(connect, name="connect") 24 | app.add_typer(project, name="project") 25 | app.add_typer(configure, name="configure") 26 | app.add_typer(fix, name="fix") 27 | 28 | if __name__ == "__main__": 29 | app() 30 | -------------------------------------------------------------------------------- /promptmodel/cli/signal_handler.py: -------------------------------------------------------------------------------- 1 | import signal 2 | import sys 3 | from promptmodel.utils.config_utils import upsert_config, read_config 4 | 5 | 6 | def dev_terminate_signal_handler(sig, frame): 7 | config = read_config() 8 | print("\nTerminating...") 9 | if "connection" in config: 10 | upsert_config({"online": False}, section="connection") 11 | upsert_config({"initializing": False}, "connection") 12 | sys.exit(0) 13 | -------------------------------------------------------------------------------- /promptmodel/cli/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from InquirerPy import inquirer 3 | from rich import print 4 | from promptmodel.apis.base import APIClient 5 | 6 | 7 | def get_org(config: Dict[str, Any]) -> Dict[str, Any]: 8 | """ 9 | Gets the current organization from the configuration. 10 | 11 | :return: A dictionary containing the current organization. 12 | """ 13 | if "connection" not in config: 14 | print("User not logged in. Please run [violet]prompt login[/violet] first.") 15 | exit() 16 | if "org" not in config["connection"]: 17 | orgs = APIClient.execute(method="GET", path="/organizations").json() 18 | choices = [ 19 | { 20 | "key": org["name"], 21 | "name": org["name"], 22 | "value": org, 23 | } 24 | for org in orgs 25 | ] 26 | org = inquirer.select(message="Select organization:", choices=choices).execute() 27 | else: 28 | org = config["connection"]["org"] 29 | return org 30 | 31 | 32 | def get_project(config: Dict[str, Any], org: Dict[str, Any]) -> Dict[str, Any]: 33 | """ 34 | Gets the current project from the configuration. 35 | 36 | :return: A dictionary containing the current project. 
37 | """ 38 | if "project" not in config["connection"]: 39 | projects = APIClient.execute( 40 | method="GET", 41 | path="/projects", 42 | params={"organization_id": org["organization_id"]}, 43 | ).json() 44 | choices = [ 45 | { 46 | "key": project["name"], 47 | "name": project["name"], 48 | "value": project, 49 | } 50 | for project in projects 51 | ] 52 | project = inquirer.select(message="Select project:", choices=choices).execute() 53 | else: 54 | project = config["connection"]["project"] 55 | return project 56 | -------------------------------------------------------------------------------- /promptmodel/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | 4 | load_dotenv() 5 | testmode: str = os.environ.get("TESTMODE", "false") 6 | 7 | if testmode == "true": 8 | ENDPOINT_URL = ( 9 | os.environ.get( 10 | "TESTMODE_PROMPTMODEL_BACKEND_PUBLIC_URL", "http://localhost:8000" 11 | ) 12 | + "/api/cli" 13 | ) 14 | WEB_CLIENT_URL = os.environ.get( 15 | "TESTMODE_PROMPTMODEL_FRONTEND_PUBLIC_URL", "http://localhost:3000" 16 | ) 17 | GRANT_ACCESS_URL = WEB_CLIENT_URL + "/cli/grant-access" 18 | else: 19 | ENDPOINT_URL = ( 20 | os.environ.get( 21 | "PROMPTMODEL_BACKEND_PUBLIC_URL", "https://promptmodel.up.railway.app" 22 | ) 23 | + "/api/cli" 24 | ) 25 | 26 | WEB_CLIENT_URL = os.environ.get( 27 | "PROMPTMODEL_FRONTEND_PUBLIC_URL", "https://app.promptmodel.run" 28 | ) 29 | GRANT_ACCESS_URL = WEB_CLIENT_URL + "/cli/grant-access" 30 | 31 | PROMPTMODEL_DEV_FILENAME = os.path.join(os.getcwd(), "promptmodel_dev.py") 32 | PROMPTMODEL_DEV_STARTER_FILENAME = "STARTER.py" 33 | -------------------------------------------------------------------------------- /promptmodel/database/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /promptmodel/database/config.py: -------------------------------------------------------------------------------- 1 | from peewee import SqliteDatabase, Model 2 | 3 | db = SqliteDatabase("./.promptmodel/promptmodel.db") 4 | 5 | 6 | class BaseModel(Model): 7 | class Meta: 8 | database = db 9 | -------------------------------------------------------------------------------- /promptmodel/database/crud.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | from promptmodel.database.models import ( 3 | DeployedFunctionModel, 4 | DeployedFunctionModelVersion, 5 | DeployedPrompt, 6 | ) 7 | from playhouse.shortcuts import model_to_dict 8 | from promptmodel.utils.random_utils import select_version_by_ratio 9 | from promptmodel.utils import logger 10 | from promptmodel.database.config import db 11 | from promptmodel.database.crud_chat import * 12 | 13 | 14 | # Insert 15 | 16 | # Select all 17 | 18 | # Select one 19 | 20 | 21 | def get_deployed_prompts(function_model_name: str) -> Tuple[List[DeployedPrompt], str]: 22 | try: 23 | with db.atomic(): 24 | versions: List[DeployedFunctionModelVersion] = list( 25 | DeployedFunctionModelVersion.select() 26 | .join(DeployedFunctionModel) 27 | .where( 28 | DeployedFunctionModelVersion.function_model_uuid 29 | == DeployedFunctionModel.get( 30 | DeployedFunctionModel.name == function_model_name 31 | ).uuid 32 | ) 33 | ) 34 | prompts: List[DeployedPrompt] = list( 35 | DeployedPrompt.select() 36 | .where( 37 | DeployedPrompt.version_uuid.in_( 38 | [version.uuid for version in versions] 39 | ) 40 | ) 41 | .order_by(DeployedPrompt.step.asc()) 42 | ) 43 | # select version by ratio 44 | selected_version = select_version_by_ratio( 45 | [version.__data__ for version in versions] 46 | ) 47 | selected_prompts = list( 48 | filter( 49 | lambda prompt: str(prompt.version_uuid.uuid) 50 | == str(selected_version["uuid"]), 51 | prompts, 52 | ) 53 | ) 54 | 55 | version_details = { 56 | "model": selected_version["model"], 57 | "version" : selected_version["version"], 58 | "uuid": selected_version["uuid"], 59 | "parsing_type": selected_version["parsing_type"], 60 | "output_keys": selected_version["output_keys"], 61 | } 62 | 63 | return selected_prompts, version_details 64 | except Exception as e: 65 | logger.error(e) 66 | return None, None 67 | 68 | 69 | # Update 70 | 71 | 72 | async def update_deployed_cache(project_status: dict): 73 | """Update Deployed Prompts Cache""" 74 | # TODO: 효율적으로 수정 75 | # 현재는 delete all & insert all 76 | function_models = project_status["function_models"] 77 | function_model_versions = project_status["function_model_versions"] 78 | for version in function_model_versions: 79 | if version["is_published"] is True: 80 | version["ratio"] = 1.0 81 | prompts = project_status["prompts"] 82 | 83 | with db.atomic(): 84 | DeployedFunctionModel.delete().execute() 85 | DeployedFunctionModelVersion.delete().execute() 86 | DeployedPrompt.delete().execute() 87 | DeployedFunctionModel.insert_many(function_models).execute() 88 | DeployedFunctionModelVersion.insert_many(function_model_versions).execute() 89 | DeployedPrompt.insert_many(prompts).execute() 90 | return 91 | -------------------------------------------------------------------------------- /promptmodel/database/crud_chat.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, 
List, Optional, Tuple, Any 2 | from uuid import UUID 3 | from playhouse.shortcuts import model_to_dict 4 | from peewee import fn, JOIN 5 | 6 | from promptmodel.utils import logger 7 | from promptmodel.database.config import db 8 | 9 | 10 | # def delete_fake_sessions(): 11 | # """Delete ChatLogSession which len(ChatLog) == 1 where ChatLog.session_uuid == ChatLogSession.uuid""" 12 | # with db.atomic(): 13 | # sessions_to_delete: List[ChatLogSession] = list( 14 | # ChatLogSession.select() 15 | # .join(ChatLog, JOIN.LEFT_OUTER) 16 | # .group_by(ChatLogSession) 17 | # .having(fn.COUNT(ChatLog.id) <= 1) 18 | # ) 19 | # ( 20 | # ( 21 | # ChatLogSession.delete().where( 22 | # ChatLogSession.uuid.in_( 23 | # [session.uuid for session in sessions_to_delete] 24 | # ) 25 | # ) 26 | # ).execute() 27 | # ) 28 | # return 29 | 30 | 31 | # def hide_chat_model_not_in_code(local_chat_model_list: List): 32 | # return ( 33 | # ChatModel.update(used_in_code=False) 34 | # .where(ChatModel.name.not_in(local_chat_model_list)) 35 | # .execute() 36 | # ) 37 | 38 | 39 | # def update_chat_model_uuid(local_uuid, new_uuid): 40 | # """Update ChatModel.uuid""" 41 | # if str(local_uuid) == str(new_uuid): 42 | # return 43 | # else: 44 | # with db.atomic(): 45 | # local_chat_model: ChatModel = ChatModel.get(ChatModel.uuid == local_uuid) 46 | # ChatModel.create( 47 | # uuid=new_uuid, 48 | # name=local_chat_model.name, 49 | # project_uuid=local_chat_model.project_uuid, 50 | # created_at=local_chat_model.created_at, 51 | # used_in_code=local_chat_model.used_in_code, 52 | # is_deployed=True, 53 | # ) 54 | # ChatModelVersion.update(function_model_uuid=new_uuid).where( 55 | # ChatModelVersion.chat_model_uuid == local_uuid 56 | # ).execute() 57 | # ChatModel.delete().where(ChatModel.uuid == local_uuid).execute() 58 | # return 59 | 60 | 61 | # def update_candidate_chat_model_version(new_candidates: dict): 62 | # """Update candidate ChatModelVersion's candidate_version_id(version)""" 63 | # with db.atomic(): 64 | # for uuid, version in new_candidates.items(): 65 | # ( 66 | # ChatModelVersion.update(version=version, is_deployed=True) 67 | # .where(ChatModelVersion.uuid == uuid) 68 | # .execute() 69 | # ) 70 | # # Find ChatModel 71 | # chat_model_versions: List[ChatModelVersion] = list( 72 | # ChatModelVersion.select().where( 73 | # ChatModelVersion.uuid.in_(list(new_candidates.keys())) 74 | # ) 75 | # ) 76 | # chat_model_uuids = [ 77 | # chat_model.chat_model_uuid.uuid for chat_model in chat_model_versions 78 | # ] 79 | # ChatModel.update(is_deployed=True).where( 80 | # ChatModel.uuid.in_(chat_model_uuids) 81 | # ).execute() 82 | 83 | 84 | # def rename_chat_model(chat_model_uuid: str, new_name: str): 85 | # """Update the name of the given ChatModel.""" 86 | # return ( 87 | # ChatModel.update(name=new_name) 88 | # .where(ChatModel.uuid == chat_model_uuid) 89 | # .execute() 90 | # ) 91 | 92 | 93 | # def fetch_chat_log_with_uuid(session_uuid: str): 94 | # try: 95 | # try: 96 | # chat_logs: List[ChatLog] = list( 97 | # ChatLog.select() 98 | # .where(ChatLog.session_uuid == UUID(session_uuid)) 99 | # .order_by(ChatLog.created_at.asc()) 100 | # ) 101 | # except: 102 | # return [] 103 | 104 | # chat_log_to_return = [ 105 | # { 106 | # "role": chat_log.role, 107 | # "content": chat_log.content, 108 | # "tool_calls": chat_log.tool_calls, 109 | # } 110 | # for chat_log in chat_logs 111 | # ] 112 | # return chat_log_to_return 113 | # except Exception as e: 114 | # logger.error(e) 115 | # return None 116 | 117 | 118 | # def 
get_latest_version_chat_model( 119 | # chat_model_name: str, 120 | # session_uuid: Optional[str] = None, 121 | # ) -> Tuple[List[Dict[str, Any]], str]: 122 | # try: 123 | # if session_uuid: 124 | # if type(session_uuid) == str: 125 | # session_uuid = UUID(session_uuid) 126 | # session: ChatLogSession = ChatLogSession.get( 127 | # ChatLogSession.uuid == session_uuid 128 | # ) 129 | 130 | # version: ChatModelVersion = ChatModelVersion.get( 131 | # ChatModelVersion.uuid == session.version_uuid 132 | # ) 133 | 134 | # instruction: List[Dict[str, Any]] = [version["system_prompt"]] 135 | # else: 136 | # with db.atomic(): 137 | # try: 138 | # sessions_with_version: List[ChatLogSession] = list( 139 | # ChatLogSession.select() 140 | # .join(ChatModelVersion) 141 | # .where( 142 | # ChatModelVersion.chat_model_uuid 143 | # == ChatModel.get(ChatModel.name == chat_model_name).uuid 144 | # ) 145 | # ) 146 | # session_uuids = [x.uuid for x in sessions_with_version] 147 | 148 | # latest_chat_log: List[ChatLog] = list( 149 | # ChatLog.select() 150 | # .where(ChatLog.session_uuid.in_(session_uuids)) 151 | # .order_by(ChatLog.created_at.desc()) 152 | # ) 153 | 154 | # latest_chat_log: ChatLog = latest_chat_log[0] 155 | # latest_session: ChatLogSession = ChatLogSession.get( 156 | # ChatLogSession.uuid == latest_chat_log.session_uuid 157 | # ) 158 | 159 | # version: ChatModelVersion = ( 160 | # ChatModelVersion.select() 161 | # .where(ChatModelVersion.uuid == latest_session.uuid) 162 | # .get() 163 | # ) 164 | # except: 165 | # version: ChatModelVersion = list( 166 | # ChatModelVersion.select() 167 | # .join(ChatModel) 168 | # .where(ChatModel.name == chat_model_name) 169 | # .order_by(ChatModelVersion.created_at.desc()) 170 | # .get() 171 | # ) 172 | 173 | # instruction: List[Dict[str, Any]] = [version["system_prompt"]] 174 | 175 | # version_details = { 176 | # "model": version.model, 177 | # "uuid": version.uuid, 178 | # } 179 | 180 | # return instruction, version_details 181 | 182 | # except Exception as e: 183 | # logger.error(e) 184 | # return None, None 185 | 186 | 187 | # def find_ancestor_chat_model_version( 188 | # chat_model_version_uuid: str, versions: Optional[list] = None 189 | # ): 190 | # """Find ancestor ChatModel version""" 191 | 192 | # # get all versions 193 | # if versions is None: 194 | # response = list(ChatModelVersion.select()) 195 | # versions = [model_to_dict(x, recurse=False) for x in response] 196 | 197 | # # find target version 198 | # target = list( 199 | # filter(lambda version: version["uuid"] == chat_model_version_uuid, versions) 200 | # )[0] 201 | 202 | # target = _find_ancestor(target, versions) 203 | 204 | # return target 205 | 206 | 207 | # def find_ancestor_chat_model_versions(target_chat_model_uuid: Optional[str] = None): 208 | # """find ancestor versions for each versions in input""" 209 | # # get all versions 210 | # if target_chat_model_uuid is not None: 211 | # response = list( 212 | # ChatModelVersion.select().where( 213 | # ChatModelVersion.chat_model_uuid == target_chat_model_uuid 214 | # ) 215 | # ) 216 | # else: 217 | # response = list(ChatModelVersion.select()) 218 | # versions = [model_to_dict(x, recurse=False) for x in response] 219 | 220 | # targets = list( 221 | # filter( 222 | # lambda version: version["status"] == ModelVersionStatus.CANDIDATE.value 223 | # and version["version"] is None, 224 | # versions, 225 | # ) 226 | # ) 227 | 228 | # targets = [ 229 | # find_ancestor_chat_model_version(target["uuid"], versions) for target in targets 230 | # ] 231 | # 
targets_with_real_ancestor = [target for target in targets] 232 | 233 | # return targets_with_real_ancestor 234 | 235 | 236 | # def _find_ancestor(target: dict, versions: List[Dict]): 237 | # ancestor = None 238 | # temp = target 239 | # if target["from_uuid"] is None: 240 | # ancestor = None 241 | # else: 242 | # while temp["from_uuid"] is not None: 243 | # new_temp = [ 244 | # version for version in versions if version["uuid"] == temp["from_uuid"] 245 | # ][0] 246 | # if ( 247 | # new_temp["version"] is not None 248 | # or new_temp["status"] == ModelVersionStatus.CANDIDATE.value 249 | # ): 250 | # ancestor = new_temp 251 | # break 252 | # else: 253 | # temp = new_temp 254 | # target["from_uuid"] = ancestor["uuid"] if ancestor is not None else None 255 | 256 | # return target 257 | -------------------------------------------------------------------------------- /promptmodel/database/models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | 4 | from uuid import uuid4 5 | from peewee import ( 6 | CharField, 7 | DateTimeField, 8 | IntegerField, 9 | ForeignKeyField, 10 | BooleanField, 11 | UUIDField, 12 | TextField, 13 | AutoField, 14 | FloatField, 15 | Check, 16 | ) 17 | 18 | from promptmodel.database.config import BaseModel 19 | from promptmodel.database.models_chat import * 20 | from promptmodel.types.enums import ParsingType 21 | 22 | 23 | class JSONField(TextField): 24 | def db_value(self, value): 25 | return json.dumps(value) 26 | 27 | def python_value(self, value): 28 | return json.loads(value) 29 | 30 | 31 | class DeployedFunctionModel(BaseModel): 32 | uuid = UUIDField(unique=True, default=uuid4) 33 | name = CharField() 34 | 35 | 36 | class DeployedFunctionModelVersion(BaseModel): 37 | uuid = UUIDField(unique=True, default=uuid4) 38 | version = IntegerField(null=False) 39 | from_version = IntegerField(null=True) 40 | function_model_uuid = ForeignKeyField( 41 | DeployedFunctionModel, 42 | field=DeployedFunctionModel.uuid, 43 | backref="versions", 44 | on_delete="CASCADE", 45 | ) 46 | model = CharField() 47 | is_published = BooleanField(default=False) 48 | is_ab_test = BooleanField(default=False) 49 | ratio = FloatField(null=True) 50 | parsing_type = CharField( 51 | null=True, 52 | default=None, 53 | constraints=[ 54 | Check( 55 | f"parsing_type IN ('{ParsingType.COLON.value}', '{ParsingType.SQUARE_BRACKET.value}', '{ParsingType.DOUBLE_SQUARE_BRACKET.value}')" 56 | ) 57 | ], 58 | ) 59 | output_keys = JSONField(null=True, default=None) 60 | functions = JSONField(default=[]) 61 | 62 | 63 | class DeployedPrompt(BaseModel): 64 | id = AutoField() 65 | version_uuid = ForeignKeyField( 66 | DeployedFunctionModelVersion, 67 | field=DeployedFunctionModelVersion.uuid, 68 | backref="prompts", 69 | on_delete="CASCADE", 70 | ) 71 | role = CharField() 72 | step = IntegerField() 73 | content = TextField() 74 | -------------------------------------------------------------------------------- /promptmodel/database/models_chat.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | 4 | from uuid import uuid4 5 | from peewee import ( 6 | CharField, 7 | DateTimeField, 8 | IntegerField, 9 | ForeignKeyField, 10 | BooleanField, 11 | UUIDField, 12 | TextField, 13 | AutoField, 14 | FloatField, 15 | Check, 16 | ) 17 | 18 | 19 | class JSONField(TextField): 20 | def db_value(self, value): 21 | return json.dumps(value) 22 | 23 | def python_value(self, value): 24 | return 
json.loads(value) 25 | 26 | 27 | # class DeployedChatModel(BaseModel): 28 | # uuid = UUIDField(unique=True, default=uuid4) 29 | # name = CharField() 30 | 31 | 32 | # class DeployedChatModelVersion(BaseModel): 33 | # uuid = UUIDField(unique=True, default=uuid4) 34 | # from_uuid = UUIDField(null=True) 35 | # chat_model_uuid = ForeignKeyField( 36 | # DeployedChatModel, 37 | # field=DeployedChatModel.uuid, 38 | # backref="versions", 39 | # on_delete="CASCADE", 40 | # ) 41 | # model = CharField() 42 | # is_published = BooleanField(default=False) 43 | # is_ab_test = BooleanField(default=False) 44 | # ratio = FloatField(null=True) 45 | # system_prompt = JSONField(null=True, default={}) 46 | # functions = JSONField(default=[]) 47 | -------------------------------------------------------------------------------- /promptmodel/database/orm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .models import ( 4 | DeployedFunctionModel, 5 | DeployedFunctionModelVersion, 6 | DeployedPrompt, 7 | ) 8 | from .config import db 9 | 10 | 11 | def initialize_db(): 12 | if not os.path.exists("./.promptmodel"): 13 | os.mkdir("./.promptmodel") 14 | # Check if db connection exists 15 | if db.is_closed(): 16 | db.connect() 17 | with db.atomic(): 18 | if not DeployedFunctionModel.table_exists(): 19 | db.create_tables( 20 | [ 21 | # FunctionModel, 22 | # FunctionModelVersion, 23 | # Prompt, 24 | # RunLog, 25 | # SampleInputs, 26 | DeployedFunctionModel, 27 | DeployedFunctionModelVersion, 28 | DeployedPrompt, 29 | # ChatModel, 30 | # ChatModelVersion, 31 | # ChatLogSession, 32 | # ChatLog, 33 | ] 34 | ) 35 | db.close() 36 | -------------------------------------------------------------------------------- /promptmodel/dev_app.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import dis 3 | import nest_asyncio 4 | from dataclasses import dataclass 5 | from typing import Callable, Dict, Any, List, Optional, Union 6 | 7 | from promptmodel.types.response import FunctionSchema 8 | 9 | 10 | @dataclass 11 | class FunctionModelInterface: 12 | name: str 13 | 14 | 15 | @dataclass 16 | class ChatModelInterface: 17 | name: str 18 | 19 | 20 | class DevClient: 21 | """DevClient main class""" 22 | 23 | def __init__(self): 24 | self.function_models: List[FunctionModelInterface] = [] 25 | self.chat_models: List[ChatModelInterface] = [] 26 | 27 | def register(self, func): 28 | instructions = list(dis.get_instructions(func)) 29 | for idx in range( 30 | len(instructions) - 1 31 | ): # We check up to len-1 because we access idx+1 inside loop 32 | instruction = instructions[idx] 33 | # print(instruction) 34 | if instruction.opname in ["LOAD_ATTR", "LOAD_METHOD", "LOAD_GLOBAL"] and ( 35 | instruction.argval == "FunctionModel" 36 | or instruction.argval == "ChatModel" 37 | ): 38 | next_instruction = instructions[idx + 1] 39 | 40 | # Check if the next instruction is LOAD_CONST with string value 41 | if next_instruction.opname == "LOAD_CONST" and isinstance( 42 | next_instruction.argval, str 43 | ): 44 | if instruction.argval == "FunctionModel": 45 | self.function_models.append( 46 | FunctionModelInterface(name=next_instruction.argval) 47 | ) 48 | elif instruction.argval == "ChatModel": 49 | self.chat_models.append( 50 | ChatModelInterface(name=next_instruction.argval) 51 | ) 52 | 53 | def wrapper(*args, **kwargs): 54 | return func(*args, **kwargs) 55 | 56 | return wrapper 57 | 58 | def 
register_function_model(self, name): 59 | for function_model in self.function_models: 60 | if function_model.name == name: 61 | return 62 | 63 | self.function_models.append(FunctionModelInterface(name=name)) 64 | 65 | def register_chat_model(self, name): 66 | for chat_model in self.chat_models: 67 | if chat_model.name == name: 68 | return 69 | 70 | self.chat_models.append(ChatModelInterface(name=name)) 71 | 72 | def _get_function_model_name_list(self) -> List[str]: 73 | return [function_model.name for function_model in self.function_models] 74 | 75 | 76 | class DevApp: 77 | _nest_asyncio_applied = False 78 | 79 | def __init__(self): 80 | self.function_models: List[FunctionModelInterface] = [] 81 | self.chat_models: List[ChatModelInterface] = [] 82 | self.samples: List[Dict[str, Any]] = [] 83 | self.functions: Dict[ 84 | str, Dict[str, Union[FunctionSchema, Optional[Callable]]] 85 | ] = {} 86 | 87 | if not DevApp._nest_asyncio_applied: 88 | DevApp._nest_asyncio_applied = True 89 | nest_asyncio.apply() 90 | 91 | def include_client(self, client: DevClient): 92 | self.function_models.extend(client.function_models) 93 | self.chat_models.extend(client.chat_models) 94 | 95 | def register_function( 96 | self, schema: Union[Dict[str, Any], FunctionSchema], function: Callable 97 | ): 98 | function_name = schema["name"] 99 | if isinstance(schema, dict): 100 | try: 101 | schema = FunctionSchema(**schema) 102 | except: 103 | raise ValueError("schema is not a valid function call schema.") 104 | 105 | if function_name not in self.functions: 106 | self.functions[function_name] = { 107 | "schema": schema, 108 | "function": function, 109 | } 110 | 111 | def _call_register_function(self, name: str, arguments: Dict[str, str]): 112 | function_to_call: Optional[Callable] = self.functions[name]["function"] 113 | if not function_to_call: 114 | return 115 | try: 116 | function_response = function_to_call(**arguments) 117 | return function_response 118 | except Exception as e: 119 | raise e 120 | 121 | def _get_function_name_list(self) -> List[str]: 122 | return list(self.functions.keys()) 123 | 124 | def _get_function_schema_list(self) -> List[Dict]: 125 | return [ 126 | self.functions[function_name]["schema"].model_dump() 127 | for function_name in self._get_function_name_list() 128 | ] 129 | 130 | def _get_function_schemas(self, function_names: List[str] = []): 131 | try: 132 | function_schemas = [ 133 | self.functions[function_name]["schema"].model_dump() 134 | for function_name in function_names 135 | ] 136 | return function_schemas 137 | except Exception as e: 138 | raise e 139 | 140 | def register_sample(self, name: str, content: Dict[str, Any]): 141 | self.samples.append({"name": name, "content": content}) 142 | 143 | def _get_function_model_name_list(self) -> List[str]: 144 | return [function_model.name for function_model in self.function_models] 145 | 146 | def _get_chat_model_name_list(self) -> List[str]: 147 | return [chat_model.name for chat_model in self.chat_models] 148 | -------------------------------------------------------------------------------- /promptmodel/llms/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /promptmodel/llms/llm_dev.py: -------------------------------------------------------------------------------- 1 | """LLM for Development TestRun""" 2 | 3 | import re 4 | import json 5 | from datetime import datetime 6 | from typing import Any, AsyncGenerator, List, Dict, Optional 7 | from pydantic import BaseModel 8 | from dotenv import load_dotenv 9 | from litellm import acompletion, get_max_tokens 10 | 11 | from promptmodel.types.enums import ParsingType, get_pattern_by_type 12 | from promptmodel.utils import logger 13 | from promptmodel.utils.output_utils import convert_str_to_type 14 | from promptmodel.utils.token_counting import ( 15 | num_tokens_for_messages, 16 | num_tokens_for_messages_for_each, 17 | num_tokens_from_functions_input, 18 | ) 19 | from promptmodel.types.response import ( 20 | LLMStreamResponse, 21 | ModelResponse, 22 | Usage, 23 | Choices, 24 | Message, 25 | ) 26 | 27 | load_dotenv() 28 | 29 | 30 | class OpenAIMessage(BaseModel): 31 | role: Optional[str] = None 32 | content: Optional[str] = "" 33 | function_call: Optional[Dict[str, Any]] = None 34 | name: Optional[str] = None 35 | 36 | 37 | class LLMDev: 38 | def __init__(self): 39 | self._model: str 40 | 41 | def __validate_openai_messages( 42 | self, messages: List[Dict[str, Any]] 43 | ) -> List[OpenAIMessage]: 44 | """Validate and convert list of dictionaries to list of OpenAIMessage.""" 45 | res = [] 46 | for message in messages: 47 | res.append(OpenAIMessage(**message)) 48 | return res 49 | 50 | async def dev_run( 51 | self, 52 | messages: List[Dict[str, Any]], 53 | parsing_type: Optional[ParsingType] = None, 54 | functions: Optional[List[Any]] = None, 55 | model: Optional[str] = None, 56 | **kwargs, 57 | ) -> AsyncGenerator[Any, None]: 58 | """Parse & stream output from openai chat completion.""" 59 | _model = model or self._model 60 | raw_output = "" 61 | if functions == []: 62 | functions = None 63 | 64 | start_time = datetime.now() 65 | 66 | response: AsyncGenerator[ModelResponse, None] = await acompletion( 67 | model=_model, 68 | messages=[ 69 | message.model_dump(exclude_none=True) 70 | for message in self.__validate_openai_messages(messages) 71 | ], 72 | stream=True, 73 | functions=functions, 74 | response_format=( 75 | {"type": "json_object"} 76 | if parsing_type == ParsingType.JSON 77 | else {"type": "text"} 78 | ), 79 | **kwargs, 80 | ) 81 | function_call = {"name": "", "arguments": ""} 82 | finish_reason_function_call = False 83 | async for chunk in response: 84 | if getattr(chunk.choices[0].delta, "content", None) is not None: 85 | stream_value = chunk.choices[0].delta.content 86 | raw_output += stream_value # append raw output 87 | yield LLMStreamResponse(raw_output=stream_value) # return raw output 88 | 89 | if getattr(chunk.choices[0].delta, "function_call", None) is not None: 90 | for key, value in ( 91 | chunk.choices[0].delta.function_call.model_dump().items() 92 | ): 93 | if value is not None: 94 | function_call[key] += value 95 | 96 | if chunk.choices[0].finish_reason == "function_call": 97 | finish_reason_function_call = True 98 | yield LLMStreamResponse(function_call=function_call) 99 | 100 | if chunk.choices[0].finish_reason != None: 101 | end_time = datetime.now() 102 | response_ms = (end_time - start_time).total_seconds() * 1000 103 | yield LLMStreamResponse( 104 | api_response=self.make_model_response_dev( 105 | chunk, 106 | response_ms, 107 | messages, 108 | raw_output, 109 | 
functions=functions, 110 | function_call=( 111 | function_call 112 | if chunk.choices[0].finish_reason == "function_call" 113 | else None 114 | ), 115 | tools=None, 116 | tool_calls=None, 117 | ) 118 | ) 119 | 120 | # parsing 121 | if parsing_type and not finish_reason_function_call: 122 | if parsing_type == ParsingType.JSON: 123 | yield LLMStreamResponse(parsed_outputs=json.loads(raw_output)) 124 | else: 125 | parsing_pattern: Dict[str, str] = get_pattern_by_type(parsing_type) 126 | whole_pattern = parsing_pattern["whole"] 127 | parsed_results = re.findall(whole_pattern, raw_output, flags=re.DOTALL) 128 | for parsed_result in parsed_results: 129 | key = parsed_result[0] 130 | type_str = parsed_result[1] 131 | value = convert_str_to_type(parsed_result[2], type_str) 132 | yield LLMStreamResponse(parsed_outputs={key: value}) 133 | 134 | async def dev_chat( 135 | self, 136 | messages: List[Dict[str, Any]], 137 | functions: Optional[List[Any]] = None, 138 | tools: Optional[List[Any]] = None, 139 | model: Optional[str] = None, 140 | **kwargs, 141 | ) -> AsyncGenerator[LLMStreamResponse, None]: 142 | """Parse & stream output from openai chat completion.""" 143 | _model = model or self._model 144 | raw_output = "" 145 | if functions == []: 146 | functions = None 147 | 148 | if model != "HCX-002": 149 | # Truncate the output if it is too long 150 | # truncate messages to make length <= model's max length 151 | token_per_functions = num_tokens_from_functions_input( 152 | functions=functions, model=model 153 | ) 154 | model_max_tokens = get_max_tokens(model=model) 155 | token_per_messages = num_tokens_for_messages_for_each(messages, model) 156 | token_limit_exceeded = ( 157 | sum(token_per_messages) + token_per_functions 158 | ) - model_max_tokens 159 | if token_limit_exceeded > 0: 160 | while token_limit_exceeded > 0: 161 | # erase the second oldest message (first one is system prompt, so it should not be erased) 162 | if len(messages) == 1: 163 | # if there is only one message, Error cannot be solved. 
Just call LLM and get error response 164 | break 165 | token_limit_exceeded -= token_per_messages[1] 166 | del messages[1] 167 | del token_per_messages[1] 168 | 169 | args = dict( 170 | model=_model, 171 | messages=[ 172 | message.model_dump(exclude_none=True) 173 | for message in self.__validate_openai_messages(messages) 174 | ], 175 | functions=functions, 176 | tools=tools, 177 | ) 178 | 179 | is_stream_unsupported = model in ["HCX-002"] 180 | if not is_stream_unsupported: 181 | args["stream"] = True 182 | 183 | start_time = datetime.now() 184 | response: AsyncGenerator[ModelResponse, None] = await acompletion( 185 | **args, **kwargs 186 | ) 187 | if is_stream_unsupported: 188 | yield LLMStreamResponse(raw_output=response.choices[0].message.content) 189 | else: 190 | async for chunk in response: 191 | yield_api_response_with_fc = False 192 | logger.debug(chunk) 193 | if getattr(chunk.choices[0].delta, "function_call", None) is not None: 194 | yield LLMStreamResponse( 195 | api_response=chunk, 196 | function_call=chunk.choices[0].delta.function_call, 197 | ) 198 | yield_api_response_with_fc = True 199 | 200 | if getattr(chunk.choices[0].delta, "tool_calls", None) is not None: 201 | yield LLMStreamResponse( 202 | api_response=chunk, 203 | tool_calls=chunk.choices[0].delta.tool_calls, 204 | ) 205 | yield_api_response_with_fc = True 206 | 207 | if getattr(chunk.choices[0].delta, "content", None) is not None: 208 | raw_output += chunk.choices[0].delta.content 209 | yield LLMStreamResponse( 210 | api_response=chunk if not yield_api_response_with_fc else None, 211 | raw_output=chunk.choices[0].delta.content, 212 | ) 213 | 214 | if getattr(chunk.choices[0].delta, "finish_reason", None) is not None: 215 | end_time = datetime.now() 216 | response_ms = (end_time - start_time).total_seconds() * 1000 217 | yield LLMStreamResponse( 218 | api_response=self.make_model_response_dev( 219 | chunk, 220 | response_ms, 221 | messages, 222 | raw_output, 223 | functions=None, 224 | function_call=( 225 | None 226 | if chunk.choices[0].finish_reason == "function_call" 227 | else None 228 | ), 229 | tools=None, 230 | tool_calls=None, 231 | ) 232 | ) 233 | 234 | def make_model_response_dev( 235 | self, 236 | chunk: ModelResponse, 237 | response_ms, 238 | messages: List[Dict[str, str]], 239 | raw_output: str, 240 | functions: Optional[List[Any]] = None, 241 | function_call: Optional[Dict[str, Any]] = None, 242 | tools: Optional[List[Any]] = None, 243 | tool_calls: Optional[List[Dict[str, Any]]] = None, 244 | ) -> ModelResponse: 245 | """Make ModelResponse object from openai response.""" 246 | count_start_time = datetime.now() 247 | prompt_token: int = num_tokens_for_messages( 248 | messages=messages, model=chunk["model"] 249 | ) 250 | completion_token: int = num_tokens_for_messages( 251 | model=chunk["model"], 252 | messages=[{"role": "assistant", "content": raw_output}], 253 | ) 254 | 255 | if functions and len(functions) > 0: 256 | functions_token = num_tokens_from_functions_input( 257 | functions=functions, model=chunk["model"] 258 | ) 259 | prompt_token += functions_token 260 | 261 | if tools and len(tools) > 0: 262 | tools_token = num_tokens_from_functions_input( 263 | functions=[tool["function"] for tool in tools], model=chunk["model"] 264 | ) 265 | prompt_token += tools_token 266 | # if function_call: 267 | # function_call_token = num_tokens_from_function_call_output( 268 | # function_call_output=function_call, model=chunk["model"] 269 | # ) 270 | # completion_token += function_call_token 271 | 272 | 
count_end_time = datetime.now() 273 | logger.debug( 274 | f"counting token time : {(count_end_time - count_start_time).total_seconds() * 1000} ms" 275 | ) 276 | 277 | usage = Usage( 278 | **{ 279 | "prompt_tokens": prompt_token, 280 | "completion_tokens": completion_token, 281 | "total_tokens": prompt_token + completion_token, 282 | } 283 | ) 284 | 285 | last_message = Message( 286 | role=( 287 | chunk.choices[0].delta.role 288 | if getattr(chunk.choices[0].delta, "role", None) 289 | else "assistant" 290 | ), 291 | content=raw_output if raw_output != "" else None, 292 | function_call=function_call if function_call else None, 293 | tool_calls=tool_calls if tool_calls else None, 294 | ) 295 | choices = [ 296 | Choices(finish_reason=chunk.choices[0].finish_reason, message=last_message) 297 | ] 298 | 299 | res = ModelResponse( 300 | id=chunk["id"], 301 | created=chunk["created"], 302 | model=chunk["model"], 303 | stream=True, 304 | ) 305 | res.choices = choices 306 | res.usage = usage 307 | res._response_ms = response_ms 308 | 309 | return res 310 | -------------------------------------------------------------------------------- /promptmodel/promptmodel_init.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nest_asyncio 3 | import threading 4 | import asyncio 5 | import atexit 6 | import time 7 | 8 | from typing import Optional, Dict, Any 9 | from datetime import datetime 10 | 11 | from promptmodel.utils.config_utils import upsert_config, read_config 12 | from promptmodel.utils import logger 13 | from promptmodel.database.orm import initialize_db 14 | from promptmodel.database.crud import update_deployed_cache 15 | from promptmodel.apis.base import AsyncAPIClient 16 | from promptmodel.types.enums import InstanceType 17 | 18 | 19 | def init(use_cache: Optional[bool] = True, mask_inputs: Optional[bool] = False): 20 | nest_asyncio.apply() 21 | 22 | config = read_config() 23 | if ( 24 | config 25 | and "connection" in config 26 | and ( 27 | ( 28 | "online" in config["connection"] 29 | and config["connection"]["online"] == True 30 | ) 31 | or ( 32 | "initializing" in config["connection"] 33 | and config["connection"]["initializing"] == True 34 | ) 35 | ) 36 | ): 37 | cache_manager = None 38 | else: 39 | if use_cache: 40 | upsert_config({"use_cache": True}, section="project") 41 | cache_manager = CacheManager() 42 | else: 43 | upsert_config({"use_cache": False}, section="project") 44 | cache_manager = None 45 | initialize_db() # init db for local usage 46 | 47 | if mask_inputs is True: 48 | upsert_config({"mask_inputs": True}, section="project") 49 | else: 50 | upsert_config({"mask_inputs": False}, section="project") 51 | 52 | 53 | class CacheManager: 54 | _instance = None 55 | _lock = threading.Lock() 56 | 57 | def __new__(cls): 58 | with cls._lock: 59 | if cls._instance is None: 60 | instance = super(CacheManager, cls).__new__(cls) 61 | instance.last_update_time = 0 # to manage update frequency 62 | instance.update_interval = 60 * 60 * 6 # seconds, 6 hours 63 | instance.program_alive = True 64 | instance.background_tasks = [] 65 | initialize_db() 66 | atexit.register(instance._terminate) 67 | asyncio.run(instance.update_cache()) # updae cache first synchronously 68 | instance.cache_thread = threading.Thread( 69 | target=instance._run_cache_loop 70 | ) 71 | instance.cache_thread.daemon = True 72 | instance.cache_thread.start() 73 | cls._instance = instance 74 | return cls._instance 75 | 76 | def cache_update_background_task(self, config): 
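# Synchronously runs the async deployed-cache refresh to completion
# (asyncio.run creates its own event loop for this single call).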
77 | asyncio.run(update_deployed_db(config)) 78 | 79 | def _run_cache_loop(self): 80 | asyncio.run(self._update_cache_periodically()) 81 | 82 | async def _update_cache_periodically(self): 83 | while True: 84 | await asyncio.sleep(self.update_interval) # Non-blocking sleep 85 | await self.update_cache() 86 | 87 | async def update_cache(self): 88 | # Current time 89 | current_time = time.time() 90 | config = read_config() 91 | 92 | if not config: 93 | upsert_config({"version": 0}, section="project") 94 | config = {"project": {"version": 0}} 95 | if "project" not in config: 96 | upsert_config({"version": 0}, section="project") 97 | config = {"project": {"version": 0}} 98 | 99 | if "version" not in config["project"]: 100 | upsert_config({"version": 0}, section="project") 101 | config = {"project": {"version": 0}} 102 | 103 | # Check if we need to update the cache 104 | if current_time - self.last_update_time > self.update_interval: 105 | # Update cache logic 106 | try: 107 | await update_deployed_db(config) 108 | except: 109 | # try once more 110 | await update_deployed_db(config) 111 | # Update the last update time 112 | self.last_update_time = current_time 113 | 114 | def _terminate(self): 115 | self.program_alive = False 116 | 117 | # async def cleanup_background_tasks(self): 118 | # for task in self.background_tasks: 119 | # if not task.done(): 120 | # task.cancel() 121 | # try: 122 | # await task 123 | # except asyncio.CancelledError: 124 | # pass # 작업이 취소됨 125 | 126 | 127 | async def update_deployed_db(config): 128 | if "project" not in config or "version" not in config["project"]: 129 | cached_project_version = 0 130 | else: 131 | cached_project_version = int(config["project"]["version"]) 132 | try: 133 | res = await AsyncAPIClient.execute( 134 | method="GET", 135 | path="/check_update", 136 | params={"cached_version": cached_project_version}, 137 | use_cli_key=False, 138 | ) 139 | res = res.json() 140 | if res["need_update"]: 141 | # update local DB with res['project_status'] 142 | project_status = res["project_status"] 143 | await update_deployed_cache(project_status) 144 | upsert_config({"version": res["version"]}, section="project") 145 | else: 146 | upsert_config({"version": res["version"]}, section="project") 147 | except Exception as exception: 148 | logger.error(f"Deployment cache update error: {exception}") 149 | -------------------------------------------------------------------------------- /promptmodel/types/__init__.py: -------------------------------------------------------------------------------- 1 | from .enums import InstanceType 2 | from .response import ( 3 | LLMResponse, 4 | LLMStreamResponse, 5 | FunctionModelConfig, 6 | PromptModelConfig, 7 | ChatModelConfig, 8 | FunctionSchema, 9 | ) 10 | from .request import ChatLogRequest, RunLogRequest 11 | -------------------------------------------------------------------------------- /promptmodel/types/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class InstanceType(str, Enum): 5 | ChatLog = "ChatLog" 6 | RunLog = "RunLog" 7 | ChatLogSession = "ChatLogSession" 8 | 9 | 10 | class LocalTask(str, Enum): 11 | RUN_PROMPT_MODEL = "RUN_PROMPT_MODEL" 12 | RUN_CHAT_MODEL = "RUN_CHAT_MODEL" 13 | 14 | LIST_CODE_CHAT_MODELS = "LIST_CHAT_MODELS" 15 | LIST_CODE_PROMPT_MODELS = "LIST_PROMPT_MODELS" 16 | LIST_CODE_FUNCTIONS = "LIST_FUNCTIONS" 17 | 18 | 19 | class LocalTaskErrorType(str, Enum): 20 | NO_FUNCTION_NAMED_ERROR = "NO_FUNCTION_NAMED_ERROR" # no DB 
update is needed 21 | FUNCTION_CALL_FAILED_ERROR = "FUNCTION_CALL_FAILED_ERROR" # create FunctionModelVersion, create Prompt, create RunLog 22 | PARSING_FAILED_ERROR = "PARSING_FAILED_ERROR" # create FunctionModelVersion, create Prompt, create RunLog 23 | 24 | SERVICE_ERROR = "SERVICE_ERROR" # no DB update is needed 25 | 26 | 27 | class ServerTask(str, Enum): 28 | UPDATE_RESULT_RUN = "UPDATE_RESULT_RUN" 29 | UPDATE_RESULT_CHAT_RUN = "UPDATE_RESULT_CHAT_RUN" 30 | 31 | SYNC_CODE = "SYNC_CODE" 32 | 33 | 34 | class Role(str, Enum): 35 | SYSTEM = "system" 36 | USER = "user" 37 | ASSISTANT = "assistant" 38 | FUNCTION = "function" 39 | 40 | 41 | class ParsingType(str, Enum): 42 | COLON = "colon" 43 | SQUARE_BRACKET = "square_bracket" 44 | DOUBLE_SQUARE_BRACKET = "double_square_bracket" 45 | HTML = "html" 46 | JSON = "json" 47 | 48 | 49 | class ParsingPattern(dict, Enum): 50 | COLON = { 51 | "start": r"(\w+)\s+type=([\w,\s\[\]]+): \n", 52 | "start_fstring": "{key}: \n", 53 | "end_fstring": None, 54 | "whole": r"(.*?): (.*?)\n", 55 | "start_token": None, 56 | "end_token": None, 57 | } 58 | SQUARE_BRACKET = { 59 | "start": r"\[(\w+)\s+type=([\w,\s\[\]]+)\]", 60 | "start_fstring": "[{key} type={type}]", 61 | "end_fstring": "[/{key}]", 62 | "whole": r"\[(\w+)\s+type=([\w,\s\[\]]+)\](.*?)\[/\1\]", 63 | "start_token": r"[", 64 | "end_token": r"]", 65 | } 66 | DOUBLE_SQUARE_BRACKET = { 67 | "start": r"\[\[(\w+)\s+type=([\w,\s\[\]]+)\]\]", 68 | "start_fstring": "[[{key} type={type}]]", 69 | "end_fstring": "[[/{key}]]", 70 | "whole": r"\[\[(\w+)\s+type=([\w,\s\[\]]+)\]\](.*?)\[\[/\1\]\]", 71 | "start_token": r"[", 72 | "end_token": r"]", 73 | } 74 | HTML = { 75 | "start": r"<(\w+)\s+type=([\w,\s\[\]]+)>", 76 | "start_fstring": "<{key} type={type}>", 77 | "end_fstring": "", 78 | "whole": r"<(\w+)\s+type=([\w,\s\[\]]+)>(.*?)", # also captures type 79 | "start_token": r"<", 80 | "end_token": r">", 81 | } 82 | 83 | 84 | def get_pattern_by_type(parsing_type_value): 85 | return ParsingPattern[ParsingType(parsing_type_value).name].value 86 | -------------------------------------------------------------------------------- /promptmodel/types/request.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Dict, 3 | Any, 4 | Optional, 5 | ) 6 | from pydantic import BaseModel 7 | from litellm.utils import ( 8 | ModelResponse, 9 | ) 10 | from openai.types.chat.chat_completion import * 11 | 12 | 13 | class ChatLogRequest(BaseModel): 14 | uuid: Optional[str] = None 15 | message: Dict[str, Any] 16 | metadata: Optional[Dict] = None 17 | api_response: Optional[ModelResponse] = None 18 | 19 | def __post_init__( 20 | self, 21 | ): 22 | if self.api_response is not None and self.message is None: 23 | self.message = self.api_response.choices[0].message.model_dump() 24 | 25 | 26 | class RunLogRequest(BaseModel): 27 | uuid: Optional[str] = None 28 | inputs: Optional[Dict[str, Any]] = None 29 | metadata: Optional[Dict] = None 30 | api_response: Optional[ModelResponse] = None 31 | -------------------------------------------------------------------------------- /promptmodel/types/response.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | List, 3 | Dict, 4 | Any, 5 | Optional, 6 | ) 7 | from pydantic import BaseModel 8 | from litellm.utils import ( 9 | ModelResponse, 10 | ) 11 | from litellm.types.utils import ( 12 | FunctionCall, 13 | ChatCompletionMessageToolCall, 14 | Usage, 15 | Choices, 16 | Message, 17 | 
Delta, 18 | StreamingChoices, 19 | ) 20 | from openai._models import BaseModel as OpenAIObject 21 | from openai.types.chat.chat_completion import * 22 | from openai.types.chat.chat_completion_chunk import ( 23 | ChoiceDeltaFunctionCall, 24 | ChoiceDeltaToolCall, 25 | ChoiceDeltaToolCallFunction, 26 | ) 27 | 28 | 29 | class PMDetail(BaseModel): 30 | model: str 31 | name: str 32 | version_uuid: str 33 | version: int 34 | log_uuid: str 35 | 36 | 37 | class LLMResponse(OpenAIObject): 38 | api_response: Optional[ModelResponse] = None 39 | raw_output: Optional[str] = None 40 | parsed_outputs: Optional[Dict[str, Any]] = None 41 | error: Optional[bool] = None 42 | error_log: Optional[str] = None 43 | function_call: Optional[FunctionCall] = None 44 | tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None 45 | pm_detail: Optional[PMDetail] = None 46 | 47 | 48 | class LLMStreamResponse(OpenAIObject): 49 | api_response: Optional[ModelResponse] = None 50 | raw_output: Optional[str] = None 51 | parsed_outputs: Optional[Dict[str, Any]] = None 52 | error: Optional[bool] = None 53 | error_log: Optional[str] = None 54 | function_call: Optional[ChoiceDeltaFunctionCall] = None 55 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = None 56 | pm_detail: Optional[PMDetail] = None 57 | 58 | 59 | class FunctionModelConfig(BaseModel): 60 | """Response Class for FunctionModel.get_config() 61 | prompts: List[Dict[str, Any]] = [] 62 | each prompt can have role, content, name, function_call, and tool_calls 63 | version_detail: Dict[str, Any] = {} 64 | version_detail has "model", "uuid", "parsing_type" and "output_keys". 65 | model: str 66 | model name (e.g. "gpt-3.5-turbo") 67 | name: str 68 | name of the FunctionModel. 69 | version_uuid: str 70 | version uuid of the FunctionModel. 71 | version: int 72 | version id of the FunctionModel. 73 | parsing_type: Optional[str] = None 74 | parsing type of the FunctionModel. 75 | output_keys: Optional[List[str]] = None 76 | output keys of the FunctionModel. 77 | """ 78 | 79 | prompts: List[Dict[str, Any]] 80 | model: str 81 | name: str 82 | version_uuid: str 83 | version: int 84 | parsing_type: Optional[str] = None 85 | output_keys: Optional[List[str]] = None 86 | 87 | 88 | class PromptModelConfig(FunctionModelConfig): 89 | """Deprecated. Use FunctionModelConfig instead.""" 90 | 91 | 92 | class ChatModelConfig(BaseModel): 93 | system_prompt: str 94 | model: str 95 | name: str 96 | version_uuid: str 97 | version: int 98 | message_logs: Optional[List[Dict]] = [] 99 | 100 | 101 | class FunctionSchema(BaseModel): 102 | """ 103 | { 104 | "name": str, 105 | "description": Optional[str], 106 | "parameters": { 107 | "type": "object", 108 | "properties": { 109 | "argument_name": { 110 | "type": str, 111 | "description": Optional[str], 112 | "enum": Optional[List[str]] 113 | }, 114 | }, 115 | "required": Optional[List[str]], 116 | }, 117 | } 118 | """ 119 | 120 | class _Parameters(BaseModel): 121 | class _Properties(BaseModel): 122 | type: str 123 | description: Optional[str] = "" 124 | enum: Optional[List[str]] = [] 125 | 126 | type: str = "object" 127 | properties: Dict[str, _Properties] = {} 128 | required: Optional[List[str]] = [] 129 | 130 | name: str 131 | description: Optional[str] = None 132 | parameters: _Parameters 133 | 134 | 135 | class UnitConfig(BaseModel): 136 | """Response Class for UnitLogger.get_config(). 137 | Created after calling UnitLogger.log_start() 138 | name: str 139 | name of the UnitLogger. 140 | version_uuid: str 141 | version uuid of the UnitLogger. 
142 | version: int 143 | version id of the UnitLogger. 144 | log_uuid: str 145 | log_uuid for current trace. 146 | """ 147 | 148 | name: str 149 | version_uuid: str 150 | log_uuid: str 151 | version: int 152 | -------------------------------------------------------------------------------- /promptmodel/unit_logger.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Dict, 3 | Optional, 4 | ) 5 | 6 | import promptmodel.utils.logger as logger 7 | from promptmodel.utils.async_utils import run_async_in_sync 8 | from promptmodel.utils.config_utils import check_connection_status_decorator 9 | from promptmodel.types.response import ( 10 | UnitConfig, 11 | ) 12 | from promptmodel.apis.base import AsyncAPIClient 13 | 14 | class UnitLogger: 15 | def __init__( 16 | self, 17 | name: str, 18 | version: int 19 | ): 20 | self.name: str = name 21 | self.version: int = version 22 | self.config: Optional[UnitConfig] = None 23 | 24 | @check_connection_status_decorator 25 | def get_config(self) -> Optional[UnitConfig]: 26 | """Get the config of the component 27 | You can get the config directly from the UnitLogger.confg attribute. 28 | """ 29 | return self.config 30 | 31 | @check_connection_status_decorator 32 | async def log_start(self, *args, **kwargs) -> Optional["UnitLogger"]: 33 | """Create Component Log on the cloud. 34 | It returns the UnitLogger itself, so you can use it like this: 35 | >>> component = UnitLogger("intent_classification_unit", 1).log_start() 36 | >>> res = FunctionModel("intent_classifier", unit_config=component.config).run(...) 37 | >>> component.log_score({"accuracy": 0.9}) 38 | """ 39 | res = await AsyncAPIClient.execute( 40 | method="POST", 41 | path="/unit/log", 42 | json={ 43 | "name": self.name, 44 | "version": self.version, 45 | }, 46 | use_cli_key=False 47 | ) 48 | if res.status_code != 200: 49 | logger.error(f"Failed to log start for component {self.name} v{self.version}") 50 | return None 51 | else: 52 | self.config = UnitConfig(**res.json()) 53 | 54 | return self 55 | 56 | 57 | @check_connection_status_decorator 58 | async def log_score(self, scores: Dict[str, float], *args, **kwargs): 59 | res = await AsyncAPIClient.execute( 60 | method="POST", 61 | path="/unit/score", 62 | json={ 63 | "unit_log_uuid" : self.config.log_uuid, 64 | "scores": scores, 65 | }, 66 | use_cli_key=False 67 | ) 68 | if res.status_code != 200: 69 | logger.error(f"Failed to log score for component {self.name} v{self.version}") 70 | 71 | return 72 | -------------------------------------------------------------------------------- /promptmodel/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /promptmodel/utils/async_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Coroutine 3 | 4 | 5 | def run_async_in_sync(coro: Coroutine): 6 | try: 7 | loop = asyncio.get_running_loop() 8 | except RuntimeError: # No running loop 9 | loop = asyncio.new_event_loop() 10 | asyncio.set_event_loop(loop) 11 | result = loop.run_until_complete(coro) 12 | # loop.close() 13 | return result 14 | 15 | return loop.run_until_complete(coro) 16 | 17 | 18 | def run_async_in_sync_threadsafe(coro: Coroutine, main_loop: asyncio.AbstractEventLoop): 19 | future = asyncio.run_coroutine_threadsafe(coro, main_loop) 20 | res = future.result() 21 | return res 22 | -------------------------------------------------------------------------------- /promptmodel/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from typing import Any, Dict 4 | import yaml 5 | from functools import wraps 6 | 7 | CONFIG_FILE = "./.promptmodel/config.yaml" 8 | 9 | 10 | def read_config(): 11 | """ 12 | Reads the configuration from the given filename. 13 | 14 | :return: A dictionary containing the configuration. 15 | """ 16 | if not os.path.exists(CONFIG_FILE): 17 | return {} 18 | 19 | with open(CONFIG_FILE, "r") as file: 20 | config = yaml.safe_load(file) or {} 21 | return config 22 | 23 | 24 | def merge_dict(d1: Dict[str, Any], d2: Dict[str, Any]): 25 | """ 26 | Merge two dictionaries recursively. 27 | 28 | :param d1: The first dictionary. 29 | :param d2: The second dictionary. 30 | :return: The merged dictionary. 31 | """ 32 | for key, value in d2.items(): 33 | if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict): 34 | d1[key] = merge_dict(d1[key], value) 35 | else: 36 | d1[key] = value 37 | return d1 38 | 39 | 40 | def upsert_config(new_config: Dict[str, Any], section: str = None): 41 | """ 42 | Upserts the given configuration file with the given configuration. 43 | 44 | :param new_config: A dictionary containing the new configuration. 45 | :param section: The section of the configuration to update. 46 | """ 47 | config = read_config() 48 | if section: 49 | config_section = config.get(section, {}) 50 | new_config = {section: merge_dict(config_section, new_config)} 51 | config = merge_dict(config, new_config) 52 | # If . 
directory does not exist, create it 53 | if not os.path.exists("./.promptmodel"): 54 | os.mkdir("./.promptmodel") 55 | 56 | with open(CONFIG_FILE, "w") as file: 57 | yaml.safe_dump(config, file, default_flow_style=False) 58 | 59 | 60 | def check_connection_status_decorator(method): 61 | if asyncio.iscoroutinefunction(method): 62 | 63 | @wraps(method) 64 | async def async_wrapper(self, *args, **kwargs): 65 | config = read_config() 66 | if "connection" in config and ( 67 | ( 68 | "initializing" in config["connection"] 69 | and config["connection"]["initializing"] 70 | ) 71 | or ( 72 | "reloading" in config["connection"] 73 | and config["connection"]["reloading"] 74 | ) 75 | ): 76 | return 77 | else: 78 | if "config" not in kwargs: 79 | kwargs["config"] = config 80 | return await method(self, *args, **kwargs) 81 | 82 | # async_wrapper.__name__ = method.__name__ 83 | # async_wrapper.__doc__ = method.__doc__ 84 | return async_wrapper 85 | else: 86 | 87 | @wraps(method) 88 | def wrapper(self, *args, **kwargs): 89 | config = read_config() 90 | if "connection" in config and ( 91 | ( 92 | "initializing" in config["connection"] 93 | and config["connection"]["initializing"] 94 | ) 95 | or ( 96 | "reloading" in config["connection"] 97 | and config["connection"]["reloading"] 98 | ) 99 | ): 100 | return 101 | else: 102 | return method(self, *args, **kwargs) 103 | 104 | # wrapper.__name__ = method.__name__ 105 | # wrapper.__doc__ = method.__doc__ 106 | return wrapper 107 | -------------------------------------------------------------------------------- /promptmodel/utils/crypto.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | from cryptography.fernet import Fernet 4 | import secrets 5 | 6 | 7 | def generate_api_key(): 8 | """Generate a random token with 32 bytes""" 9 | token = secrets.token_hex(32) 10 | return token 11 | 12 | 13 | def generate_crypto_key(): 14 | """ 15 | Generates a key and save it into a file 16 | """ 17 | key = Fernet.generate_key() 18 | with open("key.key", "wb") as key_file: 19 | key_file.write(key) 20 | 21 | 22 | def load_crypto_key(): 23 | """ 24 | Loads the key named `key.key` from the directory. 25 | Generates a new key if it does not exist. 
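Note that `key.key` is read from (and written to) the current working directory.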
26 | """ 27 | if not os.path.exists("key.key"): 28 | generate_crypto_key() 29 | return open("key.key", "rb").read() 30 | 31 | 32 | def encrypt_message(message: str) -> bytes: 33 | """ 34 | Encrypts the message 35 | """ 36 | key = load_crypto_key() 37 | f = Fernet(key) 38 | message = message.encode() 39 | encrypted_message = f.encrypt(message) 40 | return encrypted_message 41 | 42 | 43 | def decrypt_message(encrypted_message: bytes) -> str: 44 | """ 45 | Decrypts the message 46 | """ 47 | key = load_crypto_key() 48 | f = Fernet(key) 49 | decrypted_message = f.decrypt(encrypted_message) 50 | return decrypted_message.decode() 51 | -------------------------------------------------------------------------------- /promptmodel/utils/logger.py: -------------------------------------------------------------------------------- 1 | """Logger module""" 2 | 3 | import os 4 | from typing import Any 5 | import termcolor 6 | 7 | 8 | def debug(msg: Any, *args): 9 | if os.environ.get("TESTMODE_LOGGING", "false") != "true": 10 | return 11 | print(termcolor.colored("[DEBUG] " + str(msg) + str(*args), "light_yellow")) 12 | 13 | 14 | def success(msg: Any, *args): 15 | if os.environ.get("TESTMODE_LOGGING", "false") != "true": 16 | return 17 | print(termcolor.colored("[SUCCESS] " + str(msg) + str(*args), "green")) 18 | 19 | 20 | def info(msg: Any, *args): 21 | if os.environ.get("TESTMODE_LOGGING", "false") != "true": 22 | return 23 | print(termcolor.colored("[INFO] " + str(msg) + str(*args), "blue")) 24 | 25 | 26 | def warning(msg: Any, *args): 27 | if os.environ.get("TESTMODE_LOGGING", "false") != "true": 28 | return 29 | print(termcolor.colored("[WARNING] " + str(msg) + str(*args), "yellow")) 30 | 31 | 32 | def error(msg: Any, *args): 33 | print(termcolor.colored("[Error] " + str(msg) + str(*args), "red")) 34 | -------------------------------------------------------------------------------- /promptmodel/utils/output_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict 3 | 4 | 5 | def update_dict( 6 | target: Dict[str, str], 7 | source: Dict[str, str], 8 | ): 9 | for key, value in source.items(): 10 | if value is not None: 11 | if key not in target: 12 | target[key] = value 13 | else: 14 | target[key] += value 15 | return target 16 | 17 | 18 | def convert_str_to_type(value: str, type_str: str) -> Any: 19 | if type_str == "str": 20 | return value.strip() 21 | elif type_str == "bool": 22 | return value.lower() == "true" 23 | elif type_str == "int": 24 | return int(value) 25 | elif type_str == "float": 26 | return float(value) 27 | elif type_str.startswith("List"): 28 | return json.loads(value) 29 | elif type_str.startswith("Dict"): 30 | return json.loads(value) 31 | return value # Default: Return as is 32 | -------------------------------------------------------------------------------- /promptmodel/utils/random_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def select_version_by_ratio(versions): 5 | epsilon = 1e-10 6 | ratios = [version["ratio"] for version in versions] 7 | 8 | if not abs(sum(ratios) - 1.0) <= epsilon: 9 | raise ValueError(f"Sum of ratios must be 1.0, now {sum(ratios)}") 10 | 11 | cumulative_ratios = [] 12 | cumulative_sum = 0 13 | for ratio in ratios: 14 | cumulative_sum += ratio 15 | cumulative_ratios.append(cumulative_sum) 16 | 17 | random_value = random.random() 18 | for idx, cumulative_ratio in enumerate(cumulative_ratios): 19 | if 
random_value <= cumulative_ratio: 20 | return versions[idx] 21 | -------------------------------------------------------------------------------- /promptmodel/utils/token_counting.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | from litellm import token_counter 3 | 4 | 5 | def set_inputs_to_prompts(inputs: Dict[str, Any], prompts: List[Dict[str, str]]): 6 | messages = [] 7 | for prompt in prompts: 8 | prompt["content"] = prompt["content"].replace("{", "{{").replace("}", "}}") 9 | prompt["content"] = prompt["content"].replace("{{{{", "{").replace("}}}}", "}") 10 | messages.append( 11 | { 12 | "content": prompt["content"].format(**inputs), 13 | "role": prompt["role"], 14 | } 15 | ) 16 | 17 | 18 | def num_tokens_for_messages( 19 | messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0613" 20 | ) -> int: 21 | # tokens_per_message = 0 22 | # tokens_per_name = 0 23 | # if model.startswith("gpt-3.5-turbo"): 24 | # tokens_per_message = 4 25 | # tokens_per_name = -1 26 | 27 | # if model.startswith("gpt-4"): 28 | # tokens_per_message = 3 29 | # tokens_per_name = 1 30 | 31 | # if model.endswith("-0613") or model == "gpt-3.5-turbo-16k": 32 | # tokens_per_message = 3 33 | # tokens_per_name = 1 34 | # sum = 0 35 | processed_messages = [ 36 | {**message, "function_call": str(message["function_call"])} 37 | if "function_call" in message 38 | else message 39 | for message in messages 40 | ] 41 | sum = token_counter(model=model, messages=processed_messages) 42 | # for message in messages: 43 | # sum += tokens_per_message 44 | # if "name" in message: 45 | # sum += tokens_per_name 46 | return sum 47 | 48 | 49 | def num_tokens_for_messages_for_each( 50 | messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0613" 51 | ) -> List[int]: 52 | processed_messages = [ 53 | {**message, "function_call": str(message["function_call"])} 54 | if "function_call" in message 55 | else message 56 | for message in messages 57 | ] 58 | processed_messages = [ 59 | {**message, "tool_calls": str(message["tool_calls"])} 60 | if "tool_calls" in message 61 | else message 62 | for message in processed_messages 63 | ] 64 | return [ 65 | token_counter(model=model, messages=[message]) for message in processed_messages 66 | ] 67 | 68 | 69 | def num_tokens_from_functions_input( 70 | functions: Optional[List[Any]] = None, model="gpt-3.5-turbo-0613" 71 | ) -> int: 72 | """Return the number of tokens used by a list of functions.""" 73 | if functions is None: 74 | return 0 75 | num_tokens = 0 76 | for function in functions: 77 | function_tokens = token_counter(model=model, text=function["name"]) 78 | function_tokens += token_counter(model=model, text=function["description"]) 79 | 80 | if "parameters" in function: 81 | parameters = function["parameters"] 82 | if "properties" in parameters: 83 | for properties_key in parameters["properties"]: 84 | function_tokens += token_counter(model=model, text=properties_key) 85 | v = parameters["properties"][properties_key] 86 | for field in v: 87 | if field == "type": 88 | function_tokens += 2 89 | function_tokens += token_counter( 90 | model=model, text=v["type"] 91 | ) 92 | elif field == "description": 93 | function_tokens += 2 94 | function_tokens += token_counter( 95 | model=model, text=v["description"] 96 | ) 97 | elif field == "enum": 98 | function_tokens -= 3 99 | for o in v["enum"]: 100 | function_tokens += 3 101 | function_tokens += token_counter(model=model, text=o) 102 | else: 103 | 
print(f"Warning: not supported field {field}") 104 | function_tokens += 11 105 | 106 | num_tokens += function_tokens 107 | 108 | num_tokens += 12 109 | return num_tokens 110 | 111 | 112 | def num_tokens_from_function_call_output( 113 | function_call_output: Dict[str, str] = {}, model="gpt-3.5-turbo-0613" 114 | ) -> int: 115 | num_tokens = 1 116 | num_tokens += token_counter(model=model, text=function_call_output["name"]) 117 | if "arguments" in function_call_output: 118 | num_tokens += token_counter(model=model, text=function_call_output["arguments"]) 119 | return num_tokens 120 | -------------------------------------------------------------------------------- /promptmodel/websocket/__init__.py: -------------------------------------------------------------------------------- 1 | from .websocket_client import DevWebsocketClient 2 | from .reload_handler import CodeReloadHandler 3 | -------------------------------------------------------------------------------- /promptmodel/websocket/reload_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import importlib 4 | import asyncio 5 | from typing import Any, Dict, List 6 | from threading import Timer 7 | from rich import print 8 | from watchdog.events import FileSystemEventHandler 9 | from playhouse.shortcuts import model_to_dict 10 | 11 | from promptmodel.apis.base import APIClient 12 | from promptmodel.utils.config_utils import read_config, upsert_config 13 | from promptmodel.utils.async_utils import run_async_in_sync_threadsafe 14 | from promptmodel.utils import logger 15 | from promptmodel import DevApp 16 | from promptmodel.database.models import ( 17 | DeployedFunctionModel, 18 | DeployedFunctionModelVersion, 19 | DeployedPrompt, 20 | ) 21 | 22 | from promptmodel.websocket.websocket_client import DevWebsocketClient 23 | from promptmodel.types.enums import ServerTask 24 | 25 | 26 | class CodeReloadHandler(FileSystemEventHandler): 27 | def __init__( 28 | self, 29 | _devapp_filename: str, 30 | _instance_name: str, 31 | dev_websocket_client: DevWebsocketClient, 32 | main_loop: asyncio.AbstractEventLoop, 33 | ): 34 | self._devapp_filename: str = _devapp_filename 35 | self.devapp_instance_name: str = _instance_name 36 | self.dev_websocket_client: DevWebsocketClient = ( 37 | dev_websocket_client # save dev_websocket_client instance 38 | ) 39 | self.timer = None 40 | self.main_loop = main_loop 41 | 42 | def on_modified(self, event): 43 | """Called when a file or directory is modified.""" 44 | if event.src_path.endswith(".py"): 45 | upsert_config({"reloading": True}, "connection") 46 | if self.timer: 47 | self.timer.cancel() 48 | # reload modified file & main file 49 | self.timer = Timer(0.5, self.reload_code, args=(event.src_path,)) 50 | self.timer.start() 51 | 52 | def reload_code(self, modified_file_path: str): 53 | print( 54 | f"[violet]promptmodel:dev:[/violet] Reloading {self._devapp_filename} module due to changes..." 55 | ) 56 | relative_modified_path = os.path.relpath(modified_file_path, os.getcwd()) 57 | # Reload the devapp module 58 | module_name = relative_modified_path.replace("./", "").replace("/", ".")[ 59 | :-3 60 | ] # assuming the file is in the PYTHONPATH 61 | 62 | if module_name in sys.modules: 63 | module = sys.modules[module_name] 64 | importlib.reload(module) 65 | 66 | reloaded_module = importlib.reload(sys.modules[self._devapp_filename]) 67 | print( 68 | f"[violet]promptmodel:dev:[/violet] {self._devapp_filename} module reloaded successfully." 
69 | ) 70 | 71 | new_devapp_instance: DevApp = getattr( 72 | reloaded_module, self.devapp_instance_name 73 | ) 74 | 75 | config = read_config() 76 | org = config["connection"]["org"] 77 | project = config["connection"]["project"] 78 | 79 | # save samples, FunctionSchema, FunctionModel, ChatModel to cloud server by websocket ServerTask request 80 | new_function_model_name_list = ( 81 | new_devapp_instance._get_function_model_name_list() 82 | ) 83 | new_chat_model_name_list = new_devapp_instance._get_chat_model_name_list() 84 | new_samples = new_devapp_instance.samples 85 | new_function_schemas = new_devapp_instance._get_function_schema_list() 86 | 87 | res = run_async_in_sync_threadsafe( 88 | self.dev_websocket_client.request( 89 | ServerTask.SYNC_CODE, 90 | message={ 91 | "new_function_model": new_function_model_name_list, 92 | "new_chat_model": new_chat_model_name_list, 93 | "new_samples": new_samples, 94 | "new_schemas": new_function_schemas, 95 | }, 96 | ), 97 | main_loop=self.main_loop, 98 | ) 99 | 100 | # update_samples(new_devapp_instance.samples) 101 | upsert_config({"reloading": False}, "connection") 102 | self.dev_websocket_client.update_devapp_instance(new_devapp_instance) 103 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit on first error 4 | set -e 5 | 6 | # Clean up old distribution archives 7 | echo "Cleaning up old builds..." 8 | rm -rf dist/ 9 | rm -rf build/ 10 | rm -rf *.egg-info 11 | 12 | # Generate distribution archives 13 | echo "Building distribution..." 14 | python setup.py sdist bdist_wheel 15 | 16 | # Upload using twine 17 | echo "Uploading to PyPI..." 18 | twine upload dist/* 19 | 20 | echo "Publishing complete!" 21 | 22 | # If you want to automatically open a web browser to check your library on PyPI 23 | # Uncomment the following line and replace 'your_package_name' with the name of your library 24 | # xdg-open "https://pypi.org/project/your_package_name/" 25 | 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pydantic 2 | peewee 3 | typer[all] 4 | cryptography 5 | pyyaml 6 | InquirerPy 7 | litellm 8 | python-dotenv 9 | websockets==10.4 10 | termcolor 11 | watchdog 12 | readerwriterlock 13 | httpx[http2] 14 | pytest 15 | pytest-asyncio 16 | pytest-mock 17 | pytest-cov 18 | nest-asyncio -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt & model versioning on the cloud, built for developers. 
3 | """ 4 | 5 | from setuptools import setup, find_namespace_packages 6 | 7 | # Read README.md for the long description 8 | with open("README.md", "r") as fh: 9 | long_description = fh.read() 10 | 11 | setup( 12 | name="promptmodel", 13 | version="0.2.1", 14 | packages=find_namespace_packages(), 15 | entry_points={ 16 | "console_scripts": [ 17 | "prompt = promptmodel.cli.main:app", 18 | ], 19 | }, 20 | description="Prompt & model versioning on the cloud, built for developers.", 21 | long_description=long_description, 22 | long_description_content_type="text/markdown", 23 | author="weavel", 24 | url="https://github.com/weavel-ai/promptmodel", 25 | install_requires=[ 26 | "httpx[http2]", 27 | "pydantic>=2.4.2", 28 | "peewee", 29 | "typer[all]", 30 | "cryptography", 31 | "pyyaml", 32 | "InquirerPy", 33 | "litellm>=1.7.1", 34 | # "litellm@git+https://github.com/weavel-ai/litellm.git@llms_add_clova_support", 35 | "python-dotenv", 36 | "websockets", 37 | "termcolor", 38 | "watchdog", 39 | "readerwriterlock", 40 | "nest-asyncio", 41 | ], 42 | python_requires=">=3.8.10", 43 | keywords=[ 44 | "weavel", 45 | "agent", 46 | "llm", 47 | "tools", 48 | "promptmodel", 49 | "llm agent", 50 | "prompt", 51 | "versioning", 52 | "eval", 53 | "evaluation", 54 | "collaborative", 55 | ], 56 | ) 57 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/__init__.py -------------------------------------------------------------------------------- /tests/api_client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/api_client/__init__.py -------------------------------------------------------------------------------- /tests/api_client/api_client_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import MagicMock 3 | 4 | 5 | from promptmodel.apis.base import APIClient, AsyncAPIClient 6 | from promptmodel.utils.crypto import decrypt_message, generate_api_key, encrypt_message 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_get_headers(mocker): 11 | api_client = APIClient() 12 | async_api_client = AsyncAPIClient() 13 | 14 | mock_exit = mocker.patch("builtins.exit", side_effect=SystemExit) 15 | api_key = generate_api_key() 16 | mock_envirion_get = MagicMock(return_value=api_key) 17 | 18 | gt_cli_output = { 19 | "Authorization": f"Bearer {decrypt_message(encrypt_message(api_key))}" 20 | } 21 | gt_api_output = {"Authorization": f"Bearer {api_key}"} 22 | # No user 23 | mocker.patch("promptmodel.apis.base.read_config", return_value={}) 24 | with pytest.raises(SystemExit): 25 | res = api_client._get_headers() 26 | mock_exit.assert_called_once() 27 | mock_exit.reset_mock() 28 | 29 | with pytest.raises(SystemExit): 30 | res = await async_api_client._get_headers() 31 | mock_exit.assert_called_once() 32 | mock_exit.reset_mock() 33 | 34 | mocker.patch("promptmodel.apis.base.read_config", return_value={"user": {}}) 35 | 36 | with pytest.raises(Exception): 37 | res = api_client._get_headers() 38 | with pytest.raises(Exception): 39 | res = await async_api_client._get_headers() 40 | 41 | mocker.patch( 42 | "promptmodel.apis.base.read_config", 43 | return_value={"user": 
{"encrypted_api_key": encrypt_message(api_key)}}, 44 | ) 45 | res = api_client._get_headers() 46 | assert res == gt_cli_output, "API key is not decrypted properly" 47 | res = await async_api_client._get_headers() 48 | assert res == gt_api_output, "API key is not decrypted properly" 49 | 50 | mocker.patch("promptmodel.apis.base.os.environ.get", mock_envirion_get) 51 | res = api_client._get_headers(use_cli_key=False) 52 | assert res == gt_api_output, "API key is not retrieved properly" 53 | res = await async_api_client._get_headers(use_cli_key=False) 54 | assert res == gt_api_output, "API key is not retrieved properly" 55 | 56 | 57 | @pytest.mark.asyncio 58 | async def test_execute(mocker): 59 | api_client = APIClient() 60 | async_api_client = AsyncAPIClient() 61 | mock_exit = mocker.patch("builtins.exit", side_effect=SystemExit) 62 | mocker.patch("promptmodel.apis.base.APIClient._get_headers", return_value={}) 63 | mocker.patch("promptmodel.apis.base.AsyncAPIClient._get_headers", return_value={}) 64 | 65 | mock_request = mocker.patch( 66 | "promptmodel.apis.base.requests.request", return_value=None 67 | ) 68 | mock_async_request = mocker.patch( 69 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=None 70 | ) 71 | with pytest.raises(SystemExit): 72 | res = api_client.execute(path="test") 73 | mock_request.assert_called_once() 74 | mock_request.reset_mock() 75 | mock_exit.assert_called_once() 76 | mock_exit.reset_mock() 77 | 78 | res = await async_api_client.execute(path="test") 79 | mock_async_request.assert_called_once() 80 | mock_async_request.reset_mock() 81 | 82 | mock_response = MagicMock() 83 | mock_response.status_code = 200 84 | mock_request = mocker.patch( 85 | "promptmodel.apis.base.requests.request", return_value=mock_response 86 | ) 87 | mock_async_request = mocker.patch( 88 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=mock_response 89 | ) 90 | 91 | res = api_client.execute(path="test") 92 | mock_request.assert_called_once() 93 | assert res == mock_response, "Response is not returned properly" 94 | mock_request.reset_mock() 95 | 96 | res = await async_api_client.execute(path="test") 97 | mock_async_request.assert_called_once() 98 | assert res == mock_response, "Response is not returned properly" 99 | mock_async_request.reset_mock() 100 | 101 | mock_response.status_code = 403 102 | mock_request = mocker.patch( 103 | "promptmodel.apis.base.requests.request", return_value=mock_response 104 | ) 105 | mock_async_request = mocker.patch( 106 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=mock_response 107 | ) 108 | with pytest.raises(SystemExit): 109 | res = api_client.execute(path="test") 110 | mock_request.assert_called_once() 111 | mock_request.reset_mock() 112 | mock_exit.assert_called_once() 113 | mock_exit.reset_mock() 114 | 115 | res = await async_api_client.execute(path="test") 116 | mock_async_request.assert_called_once() 117 | mock_async_request.reset_mock() 118 | 119 | mock_response.status_code = 500 120 | mock_request = mocker.patch( 121 | "promptmodel.apis.base.requests.request", return_value=mock_response 122 | ) 123 | mock_async_request = mocker.patch( 124 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=mock_response 125 | ) 126 | with pytest.raises(SystemExit): 127 | res = api_client.execute(path="test") 128 | mock_request.assert_called_once() 129 | mock_request.reset_mock() 130 | mock_exit.assert_called_once() 131 | mock_exit.reset_mock() 132 | 133 | res = await 
async_api_client.execute(path="test") 134 | mock_async_request.assert_called_once() 135 | mock_async_request.reset_mock() 136 | -------------------------------------------------------------------------------- /tests/chat_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/chat_model/__init__.py -------------------------------------------------------------------------------- /tests/chat_model/chat_model_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, MagicMock 3 | 4 | from typing import Generator, AsyncGenerator, List 5 | 6 | from promptmodel.types.response import LLMResponse, LLMStreamResponse, ChatModelConfig 7 | from promptmodel import ChatModel, DevClient 8 | from promptmodel.dev_app import ChatModelInterface, ChatModelInterface 9 | 10 | client = DevClient() 11 | 12 | 13 | def test_find_client( 14 | mocker, 15 | mock_fetch_chat_model, 16 | mock_async_chat_log_to_cloud: AsyncMock, 17 | ): 18 | fetch_chat_model = mocker.patch( 19 | "promptmodel.chat_model.LLMProxy.fetch_chat_model", 20 | mock_fetch_chat_model, 21 | ) 22 | 23 | mocker.patch( 24 | "promptmodel.chat_model.LLMProxy._async_make_session_cloud", 25 | mock_async_chat_log_to_cloud, 26 | ) 27 | pm = ChatModel("test") 28 | assert client.chat_models == [ChatModelInterface(name="test")] 29 | 30 | 31 | def test_get_config( 32 | mocker, 33 | mock_fetch_chat_model, 34 | mock_async_chat_log_to_cloud: AsyncMock, 35 | ): 36 | fetch_chat_model = mocker.patch( 37 | "promptmodel.chat_model.LLMProxy.fetch_chat_model", mock_fetch_chat_model 38 | ) 39 | 40 | mocker.patch( 41 | "promptmodel.chat_model.LLMProxy._async_make_session_cloud", 42 | mock_async_chat_log_to_cloud, 43 | ) 44 | # mock registering_meta 45 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock()) 46 | chat_model = ChatModel("test") 47 | assert len(client.chat_models) == 1 48 | config: ChatModelConfig = chat_model.get_config() 49 | assert config.system_prompt == "You are a helpful assistant." 
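# A minimal usage sketch of the flow these tests exercise, assuming a deployed
# ChatModel named "test" and a completed promptmodel.init() / API-key setup
# (the names below are illustrative, not fixtures from this module):
#
#     chat = ChatModel("test")
#     config = chat.get_config()         # ChatModelConfig: system_prompt, model, version, ...
#     res = chat.run()                   # LLMResponse: raw_output, api_response, ...
#     for chunk in chat.run(stream=True):
#         ...                            # LLMStreamResponse deltas; the final chunk carries api_response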
50 | 51 | 52 | def test_add_messages( 53 | mocker, 54 | mock_async_chat_log_to_cloud: AsyncMock, 55 | ): 56 | pass 57 | 58 | 59 | def test_run( 60 | mocker, 61 | mock_fetch_chat_model: AsyncMock, 62 | mock_async_chat_log_to_cloud: AsyncMock, 63 | ): 64 | fetch_chat_model = mocker.patch( 65 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 66 | ) 67 | 68 | async_chat_log_to_cloud = mocker.patch( 69 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 70 | mock_async_chat_log_to_cloud, 71 | ) 72 | 73 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock()) 74 | chat_model = ChatModel("test", session_uuid="testuuid") 75 | 76 | res: LLMResponse = chat_model.run() 77 | print(res.api_response.model_dump()) 78 | 79 | fetch_chat_model.assert_called_once() 80 | async_chat_log_to_cloud.assert_called_once() 81 | 82 | assert res.raw_output is not None 83 | assert res.error is None or res.error is False 84 | assert res.api_response is not None 85 | assert res.parsed_outputs is None 86 | 87 | fetch_chat_model.reset_mock() 88 | async_chat_log_to_cloud.reset_mock() 89 | mocker.patch( 90 | "promptmodel.utils.config_utils.read_config", 91 | return_value={"connection": {"initializing": True}}, 92 | ) 93 | res: LLMResponse = chat_model.run() 94 | print(res) 95 | fetch_chat_model.assert_not_called() 96 | async_chat_log_to_cloud.assert_not_called() 97 | 98 | 99 | @pytest.mark.asyncio 100 | async def test_arun( 101 | mocker, 102 | mock_fetch_chat_model: AsyncMock, 103 | mock_async_chat_log_to_cloud: AsyncMock, 104 | ): 105 | fetch_chat_model = mocker.patch( 106 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 107 | ) 108 | 109 | async_chat_log_to_cloud = mocker.patch( 110 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 111 | mock_async_chat_log_to_cloud, 112 | ) 113 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock()) 114 | chat_model = ChatModel("test", session_uuid="testuuid") 115 | 116 | res: LLMResponse = await chat_model.arun() 117 | print(res.api_response.model_dump()) 118 | 119 | fetch_chat_model.assert_called_once() 120 | async_chat_log_to_cloud.assert_called_once() 121 | 122 | assert res.raw_output is not None 123 | assert res.error is None or res.error is False 124 | assert res.api_response is not None 125 | assert res.parsed_outputs is None 126 | 127 | fetch_chat_model.reset_mock() 128 | async_chat_log_to_cloud.reset_mock() 129 | mocker.patch( 130 | "promptmodel.utils.config_utils.read_config", 131 | return_value={"connection": {"initializing": True}}, 132 | ) 133 | res: LLMResponse = await chat_model.arun() 134 | print(res) 135 | fetch_chat_model.assert_not_called() 136 | async_chat_log_to_cloud.assert_not_called() 137 | 138 | 139 | def test_stream( 140 | mocker, 141 | mock_fetch_chat_model: AsyncMock, 142 | mock_async_chat_log_to_cloud: AsyncMock, 143 | ): 144 | fetch_chat_model = mocker.patch( 145 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 146 | ) 147 | 148 | async_chat_log_to_cloud = mocker.patch( 149 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 150 | mock_async_chat_log_to_cloud, 151 | ) 152 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock()) 153 | chat_model = ChatModel("test", session_uuid="testuuid") 154 | 155 | res: Generator[LLMStreamResponse, None, None] = chat_model.run(stream=True) 156 | chunks: List[LLMStreamResponse] = [] 157 | for chunk in res: 158 | chunks.append(chunk) 159 | 160 | 
fetch_chat_model.assert_called_once() 161 | async_chat_log_to_cloud.assert_called_once() 162 | 163 | assert chunks[-1].api_response is not None 164 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 165 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 166 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 167 | 168 | fetch_chat_model.reset_mock() 169 | async_chat_log_to_cloud.reset_mock() 170 | mocker.patch( 171 | "promptmodel.utils.config_utils.read_config", 172 | return_value={"connection": {"initializing": True}}, 173 | ) 174 | res: LLMResponse = chat_model.run(stream=True) 175 | print(res) 176 | fetch_chat_model.assert_not_called() 177 | async_chat_log_to_cloud.assert_not_called() 178 | 179 | 180 | @pytest.mark.asyncio 181 | async def test_astream( 182 | mocker, 183 | mock_fetch_chat_model: AsyncMock, 184 | mock_async_chat_log_to_cloud: AsyncMock, 185 | ): 186 | fetch_chat_model = mocker.patch( 187 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 188 | ) 189 | 190 | async_chat_log_to_cloud = mocker.patch( 191 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 192 | mock_async_chat_log_to_cloud, 193 | ) 194 | 195 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock()) 196 | chat_model = ChatModel("test", session_uuid="testuuid") 197 | 198 | res: AsyncGenerator[LLMStreamResponse, None] = await chat_model.arun(stream=True) 199 | 200 | chunks: List[LLMStreamResponse] = [] 201 | async for chunk in res: 202 | chunks.append(chunk) 203 | 204 | fetch_chat_model.assert_called_once() 205 | async_chat_log_to_cloud.assert_called_once() 206 | 207 | assert chunks[-1].api_response is not None 208 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 209 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 210 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 211 | 212 | fetch_chat_model.reset_mock() 213 | async_chat_log_to_cloud.reset_mock() 214 | mocker.patch( 215 | "promptmodel.utils.config_utils.read_config", 216 | return_value={"connection": {"initializing": True}}, 217 | ) 218 | 219 | res: LLMResponse = await chat_model.arun(stream=True) 220 | print(res) 221 | fetch_chat_model.assert_not_called() 222 | async_chat_log_to_cloud.assert_not_called() 223 | -------------------------------------------------------------------------------- /tests/chat_model/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from promptmodel.llms.llm_proxy import LLMProxy 5 | from promptmodel.types.response import ChatModelConfig 6 | 7 | 8 | async def echo_coroutine(*args, **kwargs): 9 | # print(args, kwargs) 10 | return args, kwargs 11 | 12 | 13 | @pytest.fixture 14 | def mock_fetch_chat_model(): 15 | mock_fetch_chat_model = AsyncMock() 16 | mock_instruction = "You are a helpful assistant." 17 | mock_version_details = { 18 | "model": "gpt-4-1106-preview", 19 | "uuid": "testuuid", 20 | "version": 1, 21 | } 22 | mock_message_logs = [ 23 | { 24 | "role": "system", 25 | "content": "You are a helpful assistant.", 26 | "session_uuid": "testuuid", 27 | }, 28 | {"role": "user", "content": "Hello!", "session_uuid": "testuuid"}, 29 | { 30 | "role": "assistant", 31 | "content": "Hello! 
How can I help you?", 32 | "session_uuid": "testuuid", 33 | }, 34 | ] 35 | 36 | mock_fetch_chat_model.return_value = ( 37 | mock_instruction, 38 | mock_version_details, 39 | mock_message_logs, 40 | ) 41 | 42 | return mock_fetch_chat_model 43 | 44 | 45 | @pytest.fixture 46 | def mock_async_chat_log_to_cloud(): 47 | mock_async_chat_log_to_cloud = AsyncMock() 48 | 49 | mock_response = MagicMock() 50 | mock_response.status_code = 200 51 | mock_async_chat_log_to_cloud.return_value = mock_response 52 | 53 | return mock_async_chat_log_to_cloud 54 | 55 | 56 | @pytest.fixture 57 | def mock_async_make_session_cloud(): 58 | mock_async_make_session_cloud = AsyncMock() 59 | 60 | mock_response = MagicMock() 61 | mock_response.status_code = 200 62 | mock_async_make_session_cloud.return_value = mock_response 63 | 64 | return mock_async_make_session_cloud 65 | -------------------------------------------------------------------------------- /tests/chat_model/registering_meta_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from promptmodel.chat_model import RegisteringMeta 5 | 6 | 7 | def test_registering_meta(mocker): 8 | # Fail to find DevClient instance 9 | client_instance = RegisteringMeta.find_client_instance() 10 | assert client_instance is None 11 | -------------------------------------------------------------------------------- /tests/cli/dev_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from typing import Generator, AsyncGenerator, Dict, List, Any 5 | from litellm import ModelResponse 6 | 7 | from promptmodel.llms.llm import LLM 8 | from promptmodel.llms.llm_proxy import LLMProxy 9 | from promptmodel.types.response import LLMResponse, LLMStreamResponse 10 | -------------------------------------------------------------------------------- /tests/constants.py: -------------------------------------------------------------------------------- 1 | # JSON Schema to pass to OpenAI 2 | function_shemas = [ 3 | { 4 | "name": "get_current_weather", 5 | "description": "Get the current weather in a given location", 6 | "parameters": { 7 | "type": "object", 8 | "properties": { 9 | "location": { 10 | "type": "string", 11 | "description": "The city and state, e.g. 
San Francisco, CA", 12 | }, 13 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, 14 | }, 15 | "required": ["location"], 16 | }, 17 | } 18 | ] 19 | -------------------------------------------------------------------------------- /tests/function_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/function_model/__init__.py -------------------------------------------------------------------------------- /tests/function_model/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from promptmodel.llms.llm_proxy import LLMProxy 5 | 6 | 7 | async def echo_coroutine(*args, **kwargs): 8 | # print(args, kwargs) 9 | return args, kwargs 10 | 11 | 12 | @pytest.fixture 13 | def mock_fetch_prompts(): 14 | mock_fetch_prompts = AsyncMock() 15 | mock_prompts = [ 16 | {"role": "system", "content": "You are a helpful assistant.", "step": 1}, 17 | {"role": "user", "content": "Hello!", "step": 2}, 18 | ] 19 | mock_version_details = { 20 | "model": "gpt-3.5-turbo", 21 | "uuid": "testuuid", 22 | "version": 1, 23 | "parsing_type": None, 24 | "output_keys": None, 25 | } 26 | mock_fetch_prompts.return_value = (mock_prompts, mock_version_details) 27 | 28 | return mock_fetch_prompts 29 | 30 | 31 | @pytest.fixture 32 | def mock_async_log_to_cloud(): 33 | mock_async_log_to_cloud = AsyncMock() 34 | 35 | mock_response = MagicMock() 36 | mock_response.status_code = 200 37 | mock_async_log_to_cloud.return_value = mock_response 38 | 39 | return mock_async_log_to_cloud 40 | -------------------------------------------------------------------------------- /tests/function_model/function_model_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, MagicMock 3 | 4 | from typing import Generator, AsyncGenerator, List 5 | 6 | from promptmodel.types.response import LLMResponse, LLMStreamResponse 7 | from promptmodel import FunctionModel, DevClient 8 | from promptmodel.dev_app import FunctionModelInterface 9 | 10 | client = DevClient() 11 | 12 | 13 | def test_find_client(mocker): 14 | pm = FunctionModel("test") 15 | assert client.function_models == [FunctionModelInterface(name="test")] 16 | 17 | 18 | def test_get_config(mocker, mock_fetch_prompts): 19 | fetch_prompts = mocker.patch( 20 | "promptmodel.function_model.LLMProxy.fetch_prompts", mock_fetch_prompts 21 | ) 22 | # mock registering_meta 23 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 24 | function_model = FunctionModel("test") 25 | assert len(client.function_models) == 1 26 | config = function_model.get_config() 27 | assert len(config.prompts) == 2 28 | 29 | 30 | def test_run(mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock): 31 | fetch_prompts = mocker.patch( 32 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 33 | ) 34 | async_log_to_cloud = mocker.patch( 35 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 36 | mock_async_log_to_cloud, 37 | ) 38 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 39 | function_model = FunctionModel("test") 40 | res: LLMResponse = function_model.run({}) 41 | fetch_prompts.assert_called_once() 42 | async_log_to_cloud.assert_called_once() 43 | assert 
res.raw_output is not None 44 | assert res.error is None or res.error is False 45 | assert res.api_response is not None 46 | assert res.parsed_outputs is None 47 | 48 | fetch_prompts.reset_mock() 49 | async_log_to_cloud.reset_mock() 50 | mocker.patch( 51 | "promptmodel.utils.config_utils.read_config", 52 | return_value={"connection": {"initializing": True}}, 53 | ) 54 | res = function_model.run() 55 | print(res) 56 | fetch_prompts.assert_not_called() 57 | async_log_to_cloud.assert_not_called() 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_arun( 62 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 63 | ): 64 | fetch_prompts = mocker.patch( 65 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 66 | ) 67 | async_log_to_cloud = mocker.patch( 68 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 69 | mock_async_log_to_cloud, 70 | ) 71 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 72 | function_model = FunctionModel("test") 73 | 74 | res: LLMResponse = await function_model.arun({}) 75 | fetch_prompts.assert_called_once() 76 | async_log_to_cloud.assert_called_once() 77 | assert res.raw_output is not None 78 | assert res.error is None or res.error is False 79 | assert res.api_response is not None 80 | assert res.parsed_outputs is None 81 | 82 | fetch_prompts.reset_mock() 83 | async_log_to_cloud.reset_mock() 84 | mocker.patch( 85 | "promptmodel.utils.config_utils.read_config", 86 | return_value={"connection": {"initializing": True}}, 87 | ) 88 | res = await function_model.arun() 89 | print(res) 90 | fetch_prompts.assert_not_called() 91 | async_log_to_cloud.assert_not_called() 92 | 93 | 94 | def test_stream( 95 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 96 | ): 97 | fetch_prompts = mocker.patch( 98 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 99 | ) 100 | async_log_to_cloud = mocker.patch( 101 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 102 | mock_async_log_to_cloud, 103 | ) 104 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 105 | function_model = FunctionModel("test") 106 | 107 | res: Generator[LLMStreamResponse, None, None] = function_model.stream({}) 108 | chunks: List[LLMStreamResponse] = [] 109 | for chunk in res: 110 | chunks.append(chunk) 111 | fetch_prompts.assert_called_once() 112 | async_log_to_cloud.assert_called_once() 113 | 114 | assert chunks[-1].api_response is not None 115 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 116 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 117 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 118 | 119 | fetch_prompts.reset_mock() 120 | async_log_to_cloud.reset_mock() 121 | mocker.patch( 122 | "promptmodel.utils.config_utils.read_config", 123 | return_value={"connection": {"initializing": True}}, 124 | ) 125 | res = function_model.stream() 126 | print(res) 127 | fetch_prompts.assert_not_called() 128 | async_log_to_cloud.assert_not_called() 129 | 130 | 131 | @pytest.mark.asyncio 132 | async def test_astream( 133 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 134 | ): 135 | fetch_prompts = mocker.patch( 136 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 137 | ) 138 | async_log_to_cloud = mocker.patch( 139 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 140 | mock_async_log_to_cloud, 141 | ) 142 | 
mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 143 | function_model = FunctionModel("test") 144 | 145 | res: AsyncGenerator[LLMStreamResponse, None] = await function_model.astream({}) 146 | chunks: List[LLMStreamResponse] = [] 147 | async for chunk in res: 148 | chunks.append(chunk) 149 | fetch_prompts.assert_called_once() 150 | async_log_to_cloud.assert_called_once() 151 | 152 | assert chunks[-1].api_response is not None 153 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 154 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 155 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 156 | 157 | fetch_prompts.reset_mock() 158 | async_log_to_cloud.reset_mock() 159 | mocker.patch( 160 | "promptmodel.utils.config_utils.read_config", 161 | return_value={"connection": {"initializing": True}}, 162 | ) 163 | res = await function_model.astream({}) 164 | print(res) 165 | fetch_prompts.assert_not_called() 166 | async_log_to_cloud.assert_not_called() 167 | 168 | 169 | def test_run_and_parse( 170 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 171 | ): 172 | fetch_prompts = mocker.patch( 173 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 174 | ) 175 | async_log_to_cloud = mocker.patch( 176 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 177 | mock_async_log_to_cloud, 178 | ) 179 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 180 | function_model = FunctionModel("test") 181 | res: LLMResponse = function_model.run_and_parse({}) 182 | fetch_prompts.assert_called_once() 183 | async_log_to_cloud.assert_called_once() 184 | assert res.raw_output is not None 185 | assert res.error is None or res.error is False 186 | assert res.api_response is not None 187 | assert res.parsed_outputs == {} 188 | 189 | fetch_prompts.reset_mock() 190 | async_log_to_cloud.reset_mock() 191 | mocker.patch( 192 | "promptmodel.utils.config_utils.read_config", 193 | return_value={"connection": {"initializing": True}}, 194 | ) 195 | res = function_model.run_and_parse({}) 196 | print(res) 197 | fetch_prompts.assert_not_called() 198 | async_log_to_cloud.assert_not_called() 199 | 200 | 201 | @pytest.mark.asyncio 202 | async def test_arun_and_parse( 203 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 204 | ): 205 | fetch_prompts = mocker.patch( 206 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 207 | ) 208 | async_log_to_cloud = mocker.patch( 209 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 210 | mock_async_log_to_cloud, 211 | ) 212 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 213 | function_model = FunctionModel("test") 214 | 215 | res: LLMResponse = await function_model.arun_and_parse({}) 216 | fetch_prompts.assert_called_once() 217 | async_log_to_cloud.assert_called_once() 218 | assert res.raw_output is not None 219 | assert res.error is None or res.error is False 220 | assert res.api_response is not None 221 | assert res.parsed_outputs == {} 222 | 223 | fetch_prompts.reset_mock() 224 | async_log_to_cloud.reset_mock() 225 | mocker.patch( 226 | "promptmodel.utils.config_utils.read_config", 227 | return_value={"connection": {"initializing": True}}, 228 | ) 229 | res = await function_model.arun_and_parse({}) 230 | print(res) 231 | fetch_prompts.assert_not_called() 232 | async_log_to_cloud.assert_not_called() 233 | 234 | 235 | def 
test_stream_and_parse( 236 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 237 | ): 238 | fetch_prompts = mocker.patch( 239 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 240 | ) 241 | async_log_to_cloud = mocker.patch( 242 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 243 | mock_async_log_to_cloud, 244 | ) 245 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 246 | function_model = FunctionModel("test") 247 | 248 | res: Generator[LLMStreamResponse, None, None] = function_model.stream_and_parse({}) 249 | chunks: List[LLMStreamResponse] = [] 250 | for chunk in res: 251 | chunks.append(chunk) 252 | fetch_prompts.assert_called_once() 253 | async_log_to_cloud.assert_called_once() 254 | 255 | assert chunks[-1].api_response is not None 256 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 257 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 258 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 259 | 260 | fetch_prompts.reset_mock() 261 | async_log_to_cloud.reset_mock() 262 | mocker.patch( 263 | "promptmodel.utils.config_utils.read_config", 264 | return_value={"connection": {"initializing": True}}, 265 | ) 266 | res = function_model.stream_and_parse({}) 267 | print(res) 268 | fetch_prompts.assert_not_called() 269 | async_log_to_cloud.assert_not_called() 270 | 271 | 272 | @pytest.mark.asyncio 273 | async def test_astream_and_parse( 274 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 275 | ): 276 | fetch_prompts = mocker.patch( 277 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 278 | ) 279 | async_log_to_cloud = mocker.patch( 280 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 281 | mock_async_log_to_cloud, 282 | ) 283 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock()) 284 | function_model = FunctionModel("test") 285 | 286 | res: AsyncGenerator[ 287 | LLMStreamResponse, None 288 | ] = await function_model.astream_and_parse({}) 289 | chunks: List[LLMStreamResponse] = [] 290 | async for chunk in res: 291 | chunks.append(chunk) 292 | fetch_prompts.assert_called_once() 293 | async_log_to_cloud.assert_called_once() 294 | 295 | assert chunks[-1].api_response is not None 296 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 297 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 298 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 299 | 300 | fetch_prompts.reset_mock() 301 | async_log_to_cloud.reset_mock() 302 | mocker.patch( 303 | "promptmodel.utils.config_utils.read_config", 304 | return_value={"connection": {"initializing": True}}, 305 | ) 306 | res = await function_model.astream_and_parse({}) 307 | print(res) 308 | fetch_prompts.assert_not_called() 309 | async_log_to_cloud.assert_not_called() 310 | -------------------------------------------------------------------------------- /tests/function_model/registering_meta_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from promptmodel.function_model import RegisteringMeta 5 | 6 | 7 | def test_registering_meta(mocker): 8 | # Fail to find DevClient instance 9 | client_instance = RegisteringMeta.find_client_instance() 10 | assert client_instance is None 11 | 
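The registration tests above exercise both directions of model discovery: `function_model_test.py` checks that a module-level `DevClient` ends up holding a `FunctionModelInterface` for every `FunctionModel` constructed in that module, while `registering_meta_test.py` checks that `RegisteringMeta.find_client_instance()` returns `None` when no client exists in the calling scope. A minimal sketch of the positive path, using only the public imports already seen in these tests (the model name `"extract_keyword"` is illustrative, not taken from the repository):

```python
# Sketch only -- mirrors what test_find_client asserts; not an official usage example.
from promptmodel import DevClient, FunctionModel

client = DevClient()  # module-level client, discoverable from the caller's stack frame

# Constructing a FunctionModel should let RegisteringMeta locate `client`
# and register a FunctionModelInterface(name="extract_keyword") on it.
FunctionModel("extract_keyword")

print([fm.name for fm in client.function_models])  # expected: ["extract_keyword"]
```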
-------------------------------------------------------------------------------- /tests/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/__init__.py -------------------------------------------------------------------------------- /tests/llm/function_call_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | from ..constants import function_shemas 5 | from promptmodel.llms.llm import LLM 6 | from promptmodel.types.response import * 7 | from promptmodel.types.enums import ParsingType 8 | 9 | html_output_format = """\ 10 | You must follow the provided output format. Keep the string between <> as it is. 11 | Output format: 12 | 13 | (value here) 14 | 15 | """ 16 | 17 | double_bracket_output_format = """\ 18 | You must follow the provided output format. Keep the string between [[ ]] as it is. 19 | Output format: 20 | [[response type=str]] 21 | (value here) 22 | [[/response]] 23 | """ 24 | 25 | colon_output_format = """\ 26 | You must follow the provided output format. Keep the string before : as it is 27 | Output format: 28 | response type=str: (value here) 29 | 30 | """ 31 | 32 | 33 | def test_run_with_functions(mocker): 34 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 35 | llm = LLM() 36 | res: LLMResponse = llm.run( 37 | messages=messages, 38 | functions=function_shemas, 39 | model="gpt-3.5-turbo-0613", 40 | ) 41 | 42 | assert res.error is None, "error is not None" 43 | assert res.api_response is not None, "api_response is None" 44 | assert ( 45 | res.api_response.choices[0].finish_reason == "function_call" 46 | ), "finish_reason is not function_call" 47 | print(res.api_response.model_dump()) 48 | print(res.__dict__) 49 | assert res.function_call is not None, "function_call is None" 50 | assert isinstance(res.function_call, FunctionCall) 51 | 52 | messages = [{"role": "user", "content": "Hello, How are you?"}] 53 | 54 | res: LLMResponse = llm.run( 55 | messages=messages, 56 | functions=function_shemas, 57 | model="gpt-3.5-turbo-0613", 58 | ) 59 | 60 | assert res.error is None, "error is not None" 61 | assert res.api_response is not None, "api_response is None" 62 | assert ( 63 | res.api_response.choices[0].finish_reason == "stop" 64 | ), "finish_reason is not stop" 65 | 66 | assert res.function_call is None, "function_call is not None" 67 | assert res.raw_output is not None, "raw_output is None" 68 | 69 | 70 | @pytest.mark.asyncio 71 | async def test_arun_with_functions(mocker): 72 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 73 | llm = LLM() 74 | res: LLMResponse = await llm.arun( 75 | messages=messages, 76 | functions=function_shemas, 77 | model="gpt-3.5-turbo-0613", 78 | ) 79 | 80 | assert res.error is None, "error is not None" 81 | assert res.api_response is not None, "api_response is None" 82 | assert ( 83 | res.api_response.choices[0].finish_reason == "function_call" 84 | ), "finish_reason is not function_call" 85 | 86 | assert res.function_call is not None, "function_call is None" 87 | assert isinstance(res.function_call, FunctionCall) 88 | 89 | messages = [{"role": "user", "content": "Hello, How are you?"}] 90 | 91 | res: LLMResponse = await llm.arun( 92 | messages=messages, 93 | functions=function_shemas, 94 | model="gpt-3.5-turbo-0613", 95 | ) 96 | 97 | assert res.error is None, "error 
is not None" 98 | assert res.api_response is not None, "api_response is None" 99 | assert ( 100 | res.api_response.choices[0].finish_reason == "stop" 101 | ), "finish_reason is not stop" 102 | 103 | assert res.function_call is None, "function_call is not None" 104 | assert res.raw_output is not None, "raw_output is None" 105 | 106 | 107 | def test_run_and_parse_with_functions(mocker): 108 | # With parsing_type = None 109 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 110 | llm = LLM() 111 | res: LLMResponse = llm.run_and_parse( 112 | messages=messages, 113 | functions=function_shemas, 114 | model="gpt-3.5-turbo-0613", 115 | parsing_type=None, 116 | ) 117 | assert res.error is False, "error is not False" 118 | assert res.api_response is not None, "api_response is None" 119 | assert ( 120 | res.api_response.choices[0].finish_reason == "function_call" 121 | ), "finish_reason is not function_call" 122 | 123 | assert res.function_call is not None, "function_call is None" 124 | assert isinstance(res.function_call, FunctionCall) 125 | 126 | messages = [{"role": "user", "content": "Hello, How are you?"}] 127 | 128 | res: LLMResponse = llm.run_and_parse( 129 | messages=messages, 130 | functions=function_shemas, 131 | model="gpt-3.5-turbo-0613", 132 | parsing_type=None, 133 | ) 134 | 135 | assert res.error is False, "error is not False" 136 | assert res.api_response is not None, "api_response is None" 137 | assert ( 138 | res.api_response.choices[0].finish_reason == "stop" 139 | ), "finish_reason is not stop" 140 | 141 | assert res.function_call is None, "function_call is not None" 142 | assert res.raw_output is not None, "raw_output is None" 143 | 144 | # with parsing_type = "HTML" 145 | 146 | messages = [ 147 | { 148 | "role": "user", 149 | "content": "What is the weather like in Boston? \n" + html_output_format, 150 | } 151 | ] 152 | llm = LLM() 153 | res: LLMResponse = llm.run_and_parse( 154 | messages=messages, 155 | functions=function_shemas, 156 | model="gpt-3.5-turbo-0613", 157 | parsing_type=ParsingType.HTML.value, 158 | output_keys=["response"], 159 | ) 160 | 161 | # 1. Output 지키고 function call -> (Pass) 162 | # 2. Output 지키고 stop -> OK 163 | # 3. 
Ignores the output format and calls a function -> OK (modified so that parsing is skipped when a function call appears) 164 | 165 | # In this case, error is False because parsing is skipped when a function_call is returned 166 | assert res.error is False, "error is not False" 167 | assert res.api_response is not None, "api_response is None" 168 | assert ( 169 | res.api_response.choices[0].finish_reason == "function_call" 170 | ), "finish_reason is not function_call" 171 | 172 | assert res.function_call is not None, "function_call is None" 173 | assert isinstance(res.function_call, FunctionCall) 174 | 175 | assert res.parsed_outputs is None, "parsed_outputs is not empty" 176 | 177 | messages = [ 178 | { 179 | "role": "user", 180 | "content": "Hello, How are you?\n" + html_output_format, 181 | } 182 | ] 183 | 184 | res: LLMResponse = llm.run_and_parse( 185 | messages=messages, 186 | functions=function_shemas, 187 | model="gpt-4-1106-preview", 188 | parsing_type=ParsingType.HTML.value, 189 | output_keys=["response"], 190 | ) 191 | 192 | print(res.__dict__) 193 | 194 | if not "str" in res.raw_output: 195 | # if "str" is in res.raw_output, it means the LLM made a mistake 196 | assert res.error is False, "error is not False" 197 | assert res.api_response is not None, "api_response is None" 198 | assert ( 199 | res.api_response.choices[0].finish_reason == "stop" 200 | ), "finish_reason is not stop" 201 | 202 | assert res.function_call is None, "function_call is not None" 203 | assert res.raw_output is not None, "raw_output is None" 204 | assert res.parsed_outputs != {}, "parsed_outputs is empty dict" 205 | 206 | 207 | @pytest.mark.asyncio 208 | async def test_arun_and_parse_with_functions(mocker): 209 | # With parsing_type = None 210 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 211 | llm = LLM() 212 | res: LLMResponse = await llm.arun_and_parse( 213 | messages=messages, 214 | functions=function_shemas, 215 | model="gpt-3.5-turbo-0613", 216 | parsing_type=None, 217 | ) 218 | 219 | assert res.error is False, "error is not False" 220 | assert res.api_response is not None, "api_response is None" 221 | assert ( 222 | res.api_response.choices[0].finish_reason == "function_call" 223 | ), "finish_reason is not function_call" 224 | 225 | assert res.function_call is not None, "function_call is None" 226 | assert isinstance(res.function_call, FunctionCall) 227 | 228 | messages = [{"role": "user", "content": "Hello, How are you?"}] 229 | 230 | res: LLMResponse = await llm.arun_and_parse( 231 | messages=messages, 232 | functions=function_shemas, 233 | model="gpt-3.5-turbo-0613", 234 | parsing_type=None, 235 | ) 236 | 237 | assert res.error is False, "error is not False" 238 | assert res.api_response is not None, "api_response is None" 239 | assert ( 240 | res.api_response.choices[0].finish_reason == "stop" 241 | ), "finish_reason is not stop" 242 | 243 | assert res.function_call is None, "function_call is not None" 244 | assert res.raw_output is not None, "raw_output is None" 245 | 246 | # with parsing_type = "HTML" 247 | 248 | messages = [ 249 | { 250 | "role": "user", 251 | "content": "What is the weather like in Boston? 
\n" + html_output_format, 252 | } 253 | ] 254 | llm = LLM() 255 | res: LLMResponse = await llm.arun_and_parse( 256 | messages=messages, 257 | functions=function_shemas, 258 | model="gpt-3.5-turbo-0613", 259 | parsing_type=ParsingType.HTML.value, 260 | output_keys=["response"], 261 | ) 262 | 263 | # In this case, error is False becuase if function_call, parsing is not performed 264 | assert res.error is False, "error is not False" 265 | assert res.api_response is not None, "api_response is None" 266 | assert ( 267 | res.api_response.choices[0].finish_reason == "function_call" 268 | ), "finish_reason is not function_call" 269 | 270 | assert res.function_call is not None, "function_call is None" 271 | assert isinstance(res.function_call, FunctionCall) 272 | 273 | assert res.parsed_outputs is None, "parsed_outputs is not empty" 274 | 275 | messages = [ 276 | { 277 | "role": "user", 278 | "content": "Hello, How are you?\n" + html_output_format, 279 | } 280 | ] 281 | 282 | res: LLMResponse = await llm.arun_and_parse( 283 | messages=messages, 284 | functions=function_shemas, 285 | model="gpt-4-1106-preview", 286 | parsing_type=ParsingType.HTML.value, 287 | output_keys=["response"], 288 | ) 289 | # if not "str" in res.raw_output: 290 | # # if "str" in res.raw_output, it means that LLM make mistakes 291 | assert res.error is False, "error is not False" 292 | assert res.parsed_outputs != {}, "parsed_outputs is empty" 293 | 294 | assert res.api_response is not None, "api_response is None" 295 | assert ( 296 | res.api_response.choices[0].finish_reason == "stop" 297 | ), "finish_reason is not stop" 298 | 299 | assert res.function_call is None, "function_call is not None" 300 | assert res.raw_output is not None, "raw_output is None" 301 | -------------------------------------------------------------------------------- /tests/llm/llm_dev/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/llm_dev/__init__.py -------------------------------------------------------------------------------- /tests/llm/llm_dev/llm_dev_test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/llm_dev/llm_dev_test.py -------------------------------------------------------------------------------- /tests/llm/llm_proxy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/llm_proxy/__init__.py -------------------------------------------------------------------------------- /tests/llm/llm_proxy/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from promptmodel.llms.llm_proxy import LLMProxy 5 | 6 | 7 | async def echo_coroutine(*args, **kwargs): 8 | # print(args, kwargs) 9 | return args, kwargs 10 | 11 | 12 | @pytest.fixture 13 | def mock_fetch_prompts(): 14 | mock_fetch_prompts = AsyncMock() 15 | mock_prompts = [ 16 | {"role": "system", "content": "You are a helpful assistant.", "step": 1}, 17 | {"role": "user", "content": "Hello!", "step": 2}, 18 | ] 19 | mock_version_details = { 20 | "model": "gpt-3.5-turbo", 21 | "uuid": "testuuid", 22 | "version": 1, 23 
| "parsing_type": None, 24 | "output_keys": None, 25 | } 26 | mock_fetch_prompts.return_value = (mock_prompts, mock_version_details) 27 | 28 | return mock_fetch_prompts 29 | 30 | 31 | @pytest.fixture 32 | def mock_fetch_chat_model(): 33 | mock_fetch_chat_model = AsyncMock() 34 | mock_instruction = "You are a helpful assistant." 35 | mock_version_details = { 36 | "model": "gpt-4-1106-preview", 37 | "uuid": "testuuid", 38 | "version": 1, 39 | } 40 | mock_message_logs = [ 41 | { 42 | "role": "system", 43 | "content": "You are a helpful assistant.", 44 | "session_uuid": "testuuid", 45 | }, 46 | {"role": "user", "content": "Hello!", "session_uuid": "testuuid"}, 47 | { 48 | "role": "assistant", 49 | "content": "Hello! How can I help you?", 50 | "session_uuid": "testuuid", 51 | }, 52 | ] 53 | 54 | mock_fetch_chat_model.return_value = ( 55 | mock_instruction, 56 | mock_version_details, 57 | mock_message_logs, 58 | ) 59 | 60 | return mock_fetch_chat_model 61 | 62 | 63 | @pytest.fixture 64 | def mock_async_log_to_cloud(): 65 | mock_async_log_to_cloud = AsyncMock() 66 | 67 | mock_response = MagicMock() 68 | mock_response.status_code = 200 69 | mock_async_log_to_cloud.return_value = mock_response 70 | 71 | return mock_async_log_to_cloud 72 | 73 | 74 | @pytest.fixture 75 | def mock_async_chat_log_to_cloud(): 76 | mock_async_chat_log_to_cloud = AsyncMock() 77 | 78 | mock_response = MagicMock() 79 | mock_response.status_code = 200 80 | mock_async_chat_log_to_cloud.return_value = mock_response 81 | 82 | return mock_async_chat_log_to_cloud 83 | 84 | 85 | @pytest.fixture 86 | def mock_async_make_session_cloud(): 87 | mock_async_make_session_cloud = AsyncMock() 88 | 89 | mock_response = MagicMock() 90 | mock_response.status_code = 200 91 | mock_async_make_session_cloud.return_value = mock_response 92 | 93 | return mock_async_make_session_cloud 94 | -------------------------------------------------------------------------------- /tests/llm/llm_proxy/proxy_chat_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock 3 | 4 | from typing import Generator, AsyncGenerator, List 5 | 6 | from promptmodel.llms.llm_proxy import LLMProxy 7 | from promptmodel.types.response import LLMResponse, LLMStreamResponse 8 | 9 | proxy = LLMProxy(name="test") 10 | 11 | 12 | def test_chat_run( 13 | mocker, 14 | mock_fetch_chat_model: AsyncMock, 15 | mock_async_chat_log_to_cloud: AsyncMock, 16 | ): 17 | fetch_chat_model = mocker.patch( 18 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 19 | ) 20 | 21 | async_chat_log_to_cloud = mocker.patch( 22 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 23 | mock_async_chat_log_to_cloud, 24 | ) 25 | 26 | res: LLMResponse = proxy.chat_run(session_uuid="testuuid") 27 | 28 | fetch_chat_model.assert_called_once() 29 | async_chat_log_to_cloud.assert_called_once() 30 | 31 | assert res.raw_output is not None 32 | assert res.error is None or res.error is False 33 | assert res.api_response is not None 34 | assert res.parsed_outputs is None 35 | if isinstance(res.api_response.usage, dict): 36 | assert res.api_response.usage["prompt_tokens"] > 15 37 | print(res.api_response.usage["prompt_tokens"]) 38 | else: 39 | assert res.api_response.usage.prompt_tokens > 15 40 | print(res.api_response.usage.prompt_tokens) 41 | 42 | 43 | @pytest.mark.asyncio 44 | async def test_chat_arun( 45 | mocker, 46 | mock_fetch_chat_model: AsyncMock, 47 | mock_async_chat_log_to_cloud: 
AsyncMock, 48 | ): 49 | fetch_chat_model = mocker.patch( 50 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 51 | ) 52 | 53 | async_chat_log_to_cloud = mocker.patch( 54 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 55 | mock_async_chat_log_to_cloud, 56 | ) 57 | 58 | res: LLMResponse = await proxy.chat_arun(session_uuid="testuuid") 59 | 60 | fetch_chat_model.assert_called_once() 61 | async_chat_log_to_cloud.assert_called_once() 62 | 63 | assert res.raw_output is not None 64 | assert res.error is None or res.error is False 65 | assert res.api_response is not None 66 | assert res.parsed_outputs is None 67 | 68 | 69 | def test_chat_stream( 70 | mocker, 71 | mock_fetch_chat_model: AsyncMock, 72 | mock_async_chat_log_to_cloud: AsyncMock, 73 | ): 74 | fetch_chat_model = mocker.patch( 75 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 76 | ) 77 | 78 | async_chat_log_to_cloud = mocker.patch( 79 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 80 | mock_async_chat_log_to_cloud, 81 | ) 82 | 83 | res: Generator[LLMStreamResponse, None, None] = proxy.chat_stream( 84 | session_uuid="testuuid" 85 | ) 86 | chunks: List[LLMStreamResponse] = [] 87 | for chunk in res: 88 | chunks.append(chunk) 89 | 90 | fetch_chat_model.assert_called_once() 91 | async_chat_log_to_cloud.assert_called_once() 92 | 93 | assert chunks[-1].api_response is not None 94 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 95 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 96 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 97 | 98 | 99 | @pytest.mark.asyncio 100 | async def test_chat_astream( 101 | mocker, 102 | mock_fetch_chat_model: AsyncMock, 103 | mock_async_chat_log_to_cloud: AsyncMock, 104 | ): 105 | fetch_chat_model = mocker.patch( 106 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 107 | ) 108 | 109 | async_chat_log_to_cloud = mocker.patch( 110 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 111 | mock_async_chat_log_to_cloud, 112 | ) 113 | 114 | res: AsyncGenerator[LLMStreamResponse, None] = proxy.chat_astream( 115 | session_uuid="testuuid" 116 | ) 117 | chunks: List[LLMStreamResponse] = [] 118 | async for chunk in res: 119 | chunks.append(chunk) 120 | 121 | fetch_chat_model.assert_called_once() 122 | async_chat_log_to_cloud.assert_called_once() 123 | 124 | assert chunks[-1].api_response is not None 125 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 126 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 127 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 128 | 129 | 130 | def test_chat_run_extra_long_input( 131 | mocker, 132 | mock_fetch_chat_model: AsyncMock, 133 | mock_async_chat_log_to_cloud: AsyncMock, 134 | ): 135 | mocker.patch("promptmodel.llms.llm_proxy.get_max_tokens", return_value=10) 136 | 137 | fetch_chat_model = mocker.patch( 138 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model 139 | ) 140 | 141 | async_chat_log_to_cloud = mocker.patch( 142 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud", 143 | mock_async_chat_log_to_cloud, 144 | ) 145 | 146 | res: LLMResponse = proxy.chat_run(session_uuid="testuuid") 147 | 148 | fetch_chat_model.assert_called_once() 149 | async_chat_log_to_cloud.assert_called_once() 150 | 151 | assert res.raw_output is not None 
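# NOTE: get_max_tokens is patched to return 10 in this test, so LLMProxy is expected to truncate the fetched chat history; the usage assertions below therefore check prompt_tokens < 15, whereas test_chat_run above expects > 15.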
152 | assert res.error is None or res.error is False 153 | assert res.api_response is not None 154 | assert res.parsed_outputs is None 155 | if isinstance(res.api_response.usage, dict): 156 | assert res.api_response.usage["prompt_tokens"] < 15 157 | print(res.api_response.usage["prompt_tokens"]) 158 | else: 159 | assert res.api_response.usage.prompt_tokens < 15 160 | print(res.api_response.usage.prompt_tokens) 161 | -------------------------------------------------------------------------------- /tests/llm/llm_proxy/proxy_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, MagicMock 3 | 4 | from typing import Generator, AsyncGenerator, List 5 | 6 | from promptmodel.llms.llm_proxy import LLMProxy 7 | from promptmodel.types.response import LLMResponse, LLMStreamResponse 8 | 9 | proxy = LLMProxy(name="test") 10 | 11 | 12 | def test_run(mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock): 13 | fetch_prompts = mocker.patch( 14 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 15 | ) 16 | async_log_to_cloud = mocker.patch( 17 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 18 | mock_async_log_to_cloud, 19 | ) 20 | 21 | res: LLMResponse = proxy.run({}) 22 | fetch_prompts.assert_called_once() 23 | async_log_to_cloud.assert_called_once() 24 | assert res.raw_output is not None 25 | assert res.error is None or res.error is False 26 | assert res.api_response is not None 27 | assert res.parsed_outputs is None 28 | 29 | 30 | @pytest.mark.asyncio 31 | async def test_arun( 32 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 33 | ): 34 | fetch_prompts = mocker.patch( 35 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 36 | ) 37 | async_log_to_cloud = mocker.patch( 38 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 39 | mock_async_log_to_cloud, 40 | ) 41 | 42 | res: LLMResponse = await proxy.arun({}) 43 | fetch_prompts.assert_called_once() 44 | async_log_to_cloud.assert_called_once() 45 | assert res.raw_output is not None 46 | assert res.error is None or res.error is False 47 | assert res.api_response is not None 48 | assert res.parsed_outputs is None 49 | 50 | 51 | def test_stream( 52 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 53 | ): 54 | fetch_prompts = mocker.patch( 55 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 56 | ) 57 | async_log_to_cloud = mocker.patch( 58 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 59 | mock_async_log_to_cloud, 60 | ) 61 | 62 | res: Generator[LLMStreamResponse, None, None] = proxy.stream({}) 63 | chunks: List[LLMStreamResponse] = [] 64 | for chunk in res: 65 | chunks.append(chunk) 66 | fetch_prompts.assert_called_once() 67 | async_log_to_cloud.assert_called_once() 68 | 69 | assert chunks[-1].api_response is not None 70 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 71 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 72 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 73 | 74 | 75 | @pytest.mark.asyncio 76 | async def test_astream( 77 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 78 | ): 79 | fetch_prompts = mocker.patch( 80 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 81 | ) 82 | async_log_to_cloud = mocker.patch( 83 | 
"promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 84 | mock_async_log_to_cloud, 85 | ) 86 | 87 | res: AsyncGenerator[LLMStreamResponse, None] = proxy.astream({}) 88 | chunks: List[LLMStreamResponse] = [] 89 | async for chunk in res: 90 | chunks.append(chunk) 91 | fetch_prompts.assert_called_once() 92 | async_log_to_cloud.assert_called_once() 93 | 94 | assert chunks[-1].api_response is not None 95 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 96 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 97 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 98 | 99 | 100 | def test_run_and_parse( 101 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 102 | ): 103 | fetch_prompts = mocker.patch( 104 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 105 | ) 106 | async_log_to_cloud = mocker.patch( 107 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 108 | mock_async_log_to_cloud, 109 | ) 110 | 111 | res: LLMResponse = proxy.run({}) 112 | fetch_prompts.assert_called_once() 113 | async_log_to_cloud.assert_called_once() 114 | assert res.raw_output is not None 115 | assert res.error is None or res.error is False 116 | assert res.api_response is not None 117 | assert res.parsed_outputs is None 118 | 119 | fetch_prompts.reset_mock() 120 | async_log_to_cloud.reset_mock() 121 | 122 | # mock run 123 | mock_run = MagicMock() 124 | mock_run.return_value = LLMResponse( 125 | api_response=res.api_response, parsed_outputs={"key": "value"} 126 | ) 127 | mocker.patch("promptmodel.llms.llm.LLM.run", mock_run) 128 | mock_res = proxy.run({}) 129 | fetch_prompts.assert_called_once() 130 | async_log_to_cloud.assert_called_once() 131 | assert mock_res.raw_output is None 132 | assert mock_res.error is None or res.error is False 133 | assert mock_res.api_response is not None 134 | assert mock_res.parsed_outputs is not None 135 | 136 | 137 | @pytest.mark.asyncio 138 | async def test_arun_and_parse( 139 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 140 | ): 141 | fetch_prompts = mocker.patch( 142 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 143 | ) 144 | async_log_to_cloud = mocker.patch( 145 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 146 | mock_async_log_to_cloud, 147 | ) 148 | 149 | res: LLMResponse = await proxy.arun({}) 150 | fetch_prompts.assert_called_once() 151 | async_log_to_cloud.assert_called_once() 152 | assert res.raw_output is not None 153 | assert res.error is None or res.error is False 154 | assert res.api_response is not None 155 | assert res.parsed_outputs is None 156 | 157 | fetch_prompts.reset_mock() 158 | async_log_to_cloud.reset_mock() 159 | 160 | # mock run 161 | mock_run = AsyncMock() 162 | mock_run.return_value = LLMResponse( 163 | api_response=res.api_response, parsed_outputs={"key": "value"} 164 | ) 165 | mocker.patch("promptmodel.llms.llm.LLM.arun", mock_run) 166 | mock_res: LLMResponse = await proxy.arun({}) 167 | fetch_prompts.assert_called_once() 168 | async_log_to_cloud.assert_called_once() 169 | assert mock_res.raw_output is None 170 | assert mock_res.error is None or res.error is False 171 | assert mock_res.api_response is not None 172 | assert mock_res.parsed_outputs is not None 173 | 174 | 175 | def test_stream_and_parse( 176 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 177 | ): 178 | fetch_prompts = mocker.patch( 179 | 
"promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 180 | ) 181 | async_log_to_cloud = mocker.patch( 182 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 183 | mock_async_log_to_cloud, 184 | ) 185 | 186 | res: Generator[LLMStreamResponse, None, None] = proxy.stream({}) 187 | chunks: List[LLMStreamResponse] = [] 188 | for chunk in res: 189 | chunks.append(chunk) 190 | fetch_prompts.assert_called_once() 191 | async_log_to_cloud.assert_called_once() 192 | 193 | assert chunks[-1].api_response is not None 194 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 195 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 196 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 197 | 198 | fetch_prompts.reset_mock() 199 | async_log_to_cloud.reset_mock() 200 | 201 | def mock_stream_generator(*args, **kwargs): 202 | yield LLMStreamResponse(parsed_outputs={"key": "value"}) 203 | yield LLMStreamResponse(api_response=chunks[-1].api_response) 204 | 205 | mock_run = MagicMock(side_effect=mock_stream_generator) 206 | mocker.patch("promptmodel.llms.llm.LLM.stream", mock_run) 207 | 208 | mock_res: Generator[LLMStreamResponse, None, None] = proxy.stream({}) 209 | mock_chunks: List[LLMStreamResponse] = [] 210 | for chunk in mock_res: 211 | mock_chunks.append(chunk) 212 | fetch_prompts.assert_called_once() 213 | async_log_to_cloud.assert_called_once() 214 | 215 | assert mock_chunks[-1].api_response is not None 216 | assert len([chunk for chunk in mock_chunks if chunk.error is not None]) == 0 217 | assert len([chunk for chunk in mock_chunks if chunk.parsed_outputs is not None]) > 0 218 | assert len([chunk for chunk in mock_chunks if chunk.raw_output is not None]) == 0 219 | 220 | 221 | @pytest.mark.asyncio 222 | async def test_astream_and_parse( 223 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock 224 | ): 225 | fetch_prompts = mocker.patch( 226 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts 227 | ) 228 | async_log_to_cloud = mocker.patch( 229 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud", 230 | mock_async_log_to_cloud, 231 | ) 232 | 233 | res: AsyncGenerator[LLMStreamResponse, None] = proxy.astream({}) 234 | chunks: List[LLMStreamResponse] = [] 235 | async for chunk in res: 236 | chunks.append(chunk) 237 | fetch_prompts.assert_called_once() 238 | async_log_to_cloud.assert_called_once() 239 | 240 | assert chunks[-1].api_response is not None 241 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0 242 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0 243 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0 244 | 245 | fetch_prompts.reset_mock() 246 | async_log_to_cloud.reset_mock() 247 | 248 | async def mock_stream_generator(*args, **kwargs): 249 | yield LLMStreamResponse(parsed_outputs={"key": "value"}) 250 | yield LLMStreamResponse(api_response=chunks[-1].api_response) 251 | 252 | mock_run = MagicMock(side_effect=mock_stream_generator) 253 | mocker.patch("promptmodel.llms.llm.LLM.astream", mock_run) 254 | 255 | mock_res: AsyncGenerator[LLMStreamResponse, None] = proxy.astream({}) 256 | mock_chunks: List[LLMStreamResponse] = [] 257 | async for chunk in mock_res: 258 | mock_chunks.append(chunk) 259 | fetch_prompts.assert_called_once() 260 | async_log_to_cloud.assert_called_once() 261 | 262 | assert mock_chunks[-1].api_response is not None 263 | assert len([chunk 
for chunk in mock_chunks if chunk.error is not None]) == 0 264 | assert len([chunk for chunk in mock_chunks if chunk.parsed_outputs is not None]) > 0 265 | assert len([chunk for chunk in mock_chunks if chunk.raw_output is not None]) == 0 266 | -------------------------------------------------------------------------------- /tests/llm/stream_function_call_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, MagicMock 3 | 4 | import nest_asyncio 5 | from typing import Generator, AsyncGenerator, List, Optional 6 | from litellm import ModelResponse 7 | 8 | from ..constants import function_shemas 9 | from promptmodel.llms.llm import LLM 10 | from promptmodel.llms.llm_proxy import LLMProxy 11 | from promptmodel.types.response import * 12 | from promptmodel.types.enums import ParsingType 13 | from promptmodel.utils.async_utils import run_async_in_sync 14 | 15 | html_output_format = """\ 16 | You must follow the provided output format. Keep the string between <> as it is. 17 | Output format: 18 | 19 | (value here) 20 | 21 | """ 22 | 23 | double_bracket_output_format = """\ 24 | You must follow the provided output format. Keep the string between [[ ]] as it is. 25 | Output format: 26 | [[response type=str]] 27 | (value here) 28 | [[/response]] 29 | """ 30 | 31 | colon_output_format = """\ 32 | You must follow the provided output format. Keep the string before : as it is 33 | Output format: 34 | response type=str: (value here) 35 | 36 | """ 37 | 38 | 39 | def test_stream_with_functions(mocker): 40 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 41 | llm = LLM() 42 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream( 43 | messages=messages, 44 | functions=function_shemas, 45 | model="gpt-3.5-turbo-0613", 46 | ) 47 | 48 | error_count = 0 49 | api_responses: List[ModelResponse] = [] 50 | final_response: Optional[LLMStreamResponse] = None 51 | chunks: List[LLMStreamResponse] = [] 52 | for res in stream_res: 53 | chunks.append(res) 54 | if res.error: 55 | error_count += 1 56 | if ( 57 | res.api_response 58 | and getattr(res.api_response.choices[0], "delta", None) is None 59 | ): 60 | api_responses.append(res.api_response) 61 | final_response = res 62 | 63 | assert error_count == 0, "error_count is not 0" 64 | assert len(api_responses) == 1, "api_count is not 1" 65 | assert ( 66 | getattr(final_response.api_response.choices[0].message, "function_call", None) 67 | is not None 68 | ), "function_call is None" 69 | assert isinstance( 70 | final_response.api_response.choices[0].message.function_call, FunctionCall 71 | ) 72 | 73 | assert len([c for c in chunks if c.function_call is not None]) > 0 74 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall) is True 75 | 76 | assert api_responses[0].choices[0].message.content is None, "content is not None" 77 | assert api_responses[0]._response_ms is not None, "response_ms is None" 78 | assert api_responses[0].usage is not None, "usage is None" 79 | 80 | # test logging 81 | llm_proxy = LLMProxy("test") 82 | 83 | mock_execute = AsyncMock() 84 | mock_response = MagicMock() 85 | mock_response.status_code = 200 86 | mock_execute.return_value = mock_response 87 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute) 88 | 89 | nest_asyncio.apply() 90 | run_async_in_sync( 91 | llm_proxy._async_log_to_cloud( 92 | log_uuid="test", 93 | version_uuid="test", 94 | inputs={}, 95 | 
api_response=api_responses[0], 96 | parsed_outputs={}, 97 | metadata={}, 98 | ) 99 | ) 100 | 101 | mock_execute.assert_called_once() 102 | _, kwargs = mock_execute.call_args 103 | api_response_dict = api_responses[0].model_dump() 104 | api_response_dict.update({"response_ms": api_responses[0]._response_ms}) 105 | 106 | assert ( 107 | kwargs["json"]["api_response"] == api_response_dict 108 | ), "api_response is not equal" 109 | 110 | 111 | @pytest.mark.asyncio 112 | async def test_astream_with_functions(mocker): 113 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 114 | llm = LLM() 115 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream( 116 | messages=messages, 117 | functions=function_shemas, 118 | model="gpt-3.5-turbo-0613", 119 | ) 120 | 121 | error_count = 0 122 | api_responses: List[ModelResponse] = [] 123 | final_response: Optional[LLMStreamResponse] = None 124 | chunks: List[LLMStreamResponse] = [] 125 | async for res in stream_res: 126 | chunks.append(res) 127 | if res.error: 128 | error_count += 1 129 | if ( 130 | res.api_response 131 | and getattr(res.api_response.choices[0], "delta", None) is None 132 | ): 133 | api_responses.append(res.api_response) 134 | final_response = res 135 | 136 | assert error_count == 0, "error_count is not 0" 137 | assert len(api_responses) == 1, "api_count is not 1" 138 | assert ( 139 | getattr(final_response.api_response.choices[0].message, "function_call", None) 140 | is not None 141 | ), "function_call is None" 142 | assert isinstance( 143 | final_response.api_response.choices[0].message.function_call, FunctionCall 144 | ) 145 | assert len([c for c in chunks if c.function_call is not None]) > 0 146 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall) 147 | 148 | assert api_responses[0].choices[0].message.content is None, "content is not None" 149 | assert api_responses[0]._response_ms is not None, "response_ms is None" 150 | assert api_responses[0].usage is not None, "usage is None" 151 | # assert api_responses[0].usage.prompt_tokens == 74, "prompt_tokens is not 74" 152 | 153 | # test logging 154 | llm_proxy = LLMProxy("test") 155 | 156 | mock_execute = AsyncMock() 157 | mock_response = MagicMock() 158 | mock_response.status_code = 200 159 | mock_execute.return_value = mock_response 160 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute) 161 | nest_asyncio.apply() 162 | await llm_proxy._async_log_to_cloud( 163 | log_uuid="test", 164 | version_uuid="test", 165 | inputs={}, 166 | api_response=api_responses[0], 167 | parsed_outputs={}, 168 | metadata={}, 169 | ) 170 | 171 | mock_execute.assert_called_once() 172 | _, kwargs = mock_execute.call_args 173 | api_response_dict = api_responses[0].model_dump() 174 | api_response_dict.update({"response_ms": api_responses[0]._response_ms}) 175 | 176 | assert ( 177 | kwargs["json"]["api_response"] == api_response_dict 178 | ), "api_response is not equal" 179 | 180 | 181 | def test_stream_and_parse_with_functions(mocker): 182 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 183 | llm = LLM() 184 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream_and_parse( 185 | messages=messages, 186 | functions=function_shemas, 187 | model="gpt-3.5-turbo-0613", 188 | parsing_type=None, 189 | ) 190 | 191 | error_count = 0 192 | api_responses: List[ModelResponse] = [] 193 | final_response: Optional[LLMStreamResponse] = None 194 | chunks: List[LLMStreamResponse] = [] 195 | for res in 
stream_res: 196 | chunks.append(res) 197 | if res.error: 198 | error_count += 1 199 | print("ERROR") 200 | print(res.error) 201 | print(res.error_log) 202 | if ( 203 | res.api_response 204 | and getattr(res.api_response.choices[0], "delta", None) is None 205 | ): 206 | api_responses.append(res.api_response) 207 | final_response = res 208 | 209 | assert error_count == 0, "error_count is not 0" 210 | assert len(api_responses) == 1, "api_count is not 1" 211 | assert ( 212 | getattr(final_response.api_response.choices[0].message, "function_call", None) 213 | is not None 214 | ), "function_call is None" 215 | assert isinstance( 216 | final_response.api_response.choices[0].message.function_call, FunctionCall 217 | ) 218 | assert len([c for c in chunks if c.function_call is not None]) > 0 219 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall) 220 | 221 | assert api_responses[0].choices[0].message.content is None, "content is not None" 222 | 223 | # Case with no function call, where the output is parsed 224 | messages = [ 225 | { 226 | "role": "user", 227 | "content": "Hello, How are you?\n" + html_output_format, 228 | } 229 | ] 230 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream_and_parse( 231 | messages=messages, 232 | functions=function_shemas, 233 | model="gpt-4-1106-preview", 234 | parsing_type=ParsingType.HTML.value, 235 | output_keys=["response"], 236 | ) 237 | 238 | error_count = 0 239 | error_log = "" 240 | api_responses: List[ModelResponse] = [] 241 | final_response: Optional[LLMStreamResponse] = None 242 | chunks: List[LLMStreamResponse] = [] 243 | for res in stream_res: 244 | chunks.append(res) 245 | if res.error: 246 | error_count += 1 247 | error_log = res.error_log 248 | print("ERROR") 249 | if ( 250 | res.api_response 251 | and getattr(res.api_response.choices[0], "delta", None) is None 252 | ): 253 | api_responses.append(res.api_response) 254 | final_response = res 255 | 256 | if not "str" in api_responses[0].choices[0].message.content: 257 | # if "str" appears in the content, the LLM simply made a generation mistake. 
258 | assert ( 259 | error_count == 0 260 | ), f"error_count is not 0, {error_log}, {api_responses[0].model_dump()}" 261 | assert final_response.parsed_outputs is not None, "parsed_outputs is None" 262 | assert len(api_responses) == 1, "api_count is not 1" 263 | assert ( 264 | getattr(final_response.api_response.choices[0].message, "function_call", None) 265 | is None 266 | ), "function_call is not None" 267 | assert len([c for c in chunks if c.function_call is not None]) == 0 268 | 269 | assert api_responses[0].choices[0].message.content is not None, "content is None" 270 | 271 | 272 | @pytest.mark.asyncio 273 | async def test_astream_and_parse_with_functions(mocker): 274 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 275 | llm = LLM() 276 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream_and_parse( 277 | messages=messages, 278 | functions=function_shemas, 279 | model="gpt-3.5-turbo-0613", 280 | parsing_type=None, 281 | ) 282 | 283 | error_count = 0 284 | api_responses: List[ModelResponse] = [] 285 | final_response: Optional[LLMStreamResponse] = None 286 | chunks: List[LLMStreamResponse] = [] 287 | async for res in stream_res: 288 | chunks.append(res) 289 | if res.error: 290 | error_count += 1 291 | 292 | if ( 293 | res.api_response 294 | and getattr(res.api_response.choices[0], "delta", None) is None 295 | ): 296 | api_responses.append(res.api_response) 297 | final_response = res 298 | 299 | assert error_count == 0, "error_count is not 0" 300 | assert len(api_responses) == 1, "api_count is not 1" 301 | 302 | assert api_responses[0].choices[0].message.content is None, "content is not None" 303 | assert ( 304 | getattr(final_response.api_response.choices[0].message, "function_call", None) 305 | is not None 306 | ), "function_call is None" 307 | assert isinstance( 308 | final_response.api_response.choices[0].message.function_call, FunctionCall 309 | ) 310 | assert len([c for c in chunks if c.function_call is not None]) > 0 311 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall) 312 | 313 | # Case with no function call, where the output is parsed 314 | messages = [ 315 | { 316 | "role": "user", 317 | "content": "Hello, How are you?\n" + html_output_format, 318 | } 319 | ] 320 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream_and_parse( 321 | messages=messages, 322 | functions=function_shemas, 323 | model="gpt-3.5-turbo-0613", 324 | parsing_type=ParsingType.HTML.value, 325 | output_keys=["response"], 326 | ) 327 | 328 | error_count = 0 329 | error_log = "" 330 | api_responses: List[ModelResponse] = [] 331 | final_response: Optional[LLMStreamResponse] = None 332 | chunks: List[LLMStreamResponse] = [] 333 | async for res in stream_res: 334 | chunks.append(res) 335 | if res.error: 336 | error_count += 1 337 | error_log = res.error_log 338 | 339 | if ( 340 | res.api_response 341 | and getattr(res.api_response.choices[0], "delta", None) is None 342 | ): 343 | api_responses.append(res.api_response) 344 | final_response = res 345 | 346 | if not "str" in api_responses[0].choices[0].message.content: 347 | # if "str" appears in the content, the LLM simply made a generation mistake. 
348 | assert ( 349 | error_count == 0 350 | ), f"error_count is not 0, {error_log}, {api_responses[0].model_dump()}" 351 | assert final_response.parsed_outputs is not None, "parsed_outputs is None" 352 | assert len(api_responses) == 1, "api_count is not 1" 353 | assert ( 354 | getattr(final_response.api_response.choices[0].message, "function_call", None) 355 | is None 356 | ), "function_call is not None" 357 | 358 | assert len([c for c in chunks if c.function_call is not None]) == 0 359 | 360 | assert api_responses[0].choices[0].message.content is not None, "content is None" 361 | -------------------------------------------------------------------------------- /tests/llm/stream_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, MagicMock 3 | 4 | import nest_asyncio 5 | from typing import Generator, AsyncGenerator, List 6 | from litellm import ModelResponse 7 | 8 | from promptmodel.llms.llm import LLM 9 | from promptmodel.llms.llm_proxy import LLMProxy 10 | from promptmodel.types.response import LLMStreamResponse 11 | from promptmodel.utils.async_utils import run_async_in_sync 12 | 13 | 14 | def test_stream(mocker): 15 | llm = LLM() 16 | test_messages = [ 17 | {"role": "system", "content": "You are a helpful assistant."}, 18 | {"role": "user", "content": "Introduce yourself in 50 words."}, 19 | ] 20 | 21 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream( 22 | messages=test_messages, 23 | model="gpt-3.5-turbo", 24 | ) 25 | error_count = 0 26 | api_responses: List[ModelResponse] = [] 27 | for res in stream_res: 28 | if res.error: 29 | error_count += 1 30 | print("ERROR") 31 | print(res.error) 32 | print(res.error_log) 33 | if ( 34 | res.api_response 35 | and getattr(res.api_response.choices[0], "delta", None) is None 36 | ): 37 | api_responses.append(res.api_response) 38 | print(api_responses) 39 | assert error_count == 0, "error_count is not 0" 40 | assert len(api_responses) == 1, "api_count is not 1" 41 | 42 | assert api_responses[0].choices[0].message.content is not None, "content is None" 43 | assert api_responses[0]._response_ms is not None, "response_ms is None" 44 | 45 | # test logging 46 | llm_proxy = LLMProxy("test") 47 | 48 | mock_execute = AsyncMock() 49 | mock_response = MagicMock() 50 | mock_response.status_code = 200 51 | mock_execute.return_value = mock_response 52 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute) 53 | 54 | nest_asyncio.apply() 55 | run_async_in_sync( 56 | llm_proxy._async_log_to_cloud( 57 | log_uuid="test", 58 | version_uuid="test", 59 | inputs={}, 60 | api_response=api_responses[0], 61 | parsed_outputs={}, 62 | metadata={}, 63 | ) 64 | ) 65 | 66 | mock_execute.assert_called_once() 67 | _, kwargs = mock_execute.call_args 68 | api_response_dict = api_responses[0].model_dump() 69 | api_response_dict.update({"response_ms": api_responses[0]._response_ms}) 70 | assert ( 71 | kwargs["json"]["api_response"] == api_response_dict 72 | ), "api_response is not equal" 73 | 74 | 75 | @pytest.mark.asyncio 76 | async def test_astream(mocker): 77 | llm = LLM() 78 | test_messages = [ 79 | {"role": "system", "content": "You are a helpful assistant."}, 80 | {"role": "user", "content": "Introduce yourself in 50 words."}, 81 | ] 82 | 83 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream( 84 | messages=test_messages, 85 | model="gpt-3.5-turbo", 86 | ) 87 | error_count = 0 88 | api_responses: List[ModelResponse] = [] 
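# NOTE: only the terminal stream chunk carries a full ModelResponse whose choices[0] has no `delta`, so the loop below should collect exactly one entry; the assertion afterwards checks len(api_responses) == 1.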
89 | async for res in stream_res: 90 | if res.error: 91 | error_count += 1 92 | print("ERROR") 93 | print(res.error) 94 | print(res.error_log) 95 | if ( 96 | res.api_response 97 | and getattr(res.api_response.choices[0], "delta", None) is None 98 | ): 99 | api_responses.append(res.api_response) 100 | 101 | assert error_count == 0, "error_count is not 0" 102 | assert len(api_responses) == 1, "api_count is not 1" 103 | 104 | assert api_responses[0].choices[0].message.content is not None, "content is None" 105 | assert api_responses[0]._response_ms is not None, "response_ms is None" 106 | 107 | # test logging 108 | llm_proxy = LLMProxy("test") 109 | 110 | mock_execute = AsyncMock() 111 | mock_response = MagicMock() 112 | mock_response.status_code = 200 113 | mock_execute.return_value = mock_response 114 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute) 115 | 116 | await llm_proxy._async_log_to_cloud( 117 | log_uuid="test", 118 | version_uuid="test", 119 | inputs={}, 120 | api_response=api_responses[0], 121 | parsed_outputs={}, 122 | metadata={}, 123 | ) 124 | 125 | mock_execute.assert_called_once() 126 | _, kwargs = mock_execute.call_args 127 | 128 | api_response_dict = api_responses[0].model_dump() 129 | api_response_dict.update({"response_ms": api_responses[0]._response_ms}) 130 | 131 | assert ( 132 | kwargs["json"]["api_response"] == api_response_dict 133 | ), "api_response is not equal" 134 | -------------------------------------------------------------------------------- /tests/llm/tool_calls_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | from typing import Generator, AsyncGenerator, Dict, List, Any, Optional 5 | from litellm import ModelResponse 6 | 7 | from ..constants import function_shemas 8 | from promptmodel.llms.llm import LLM 9 | from promptmodel.llms.llm_proxy import LLMProxy 10 | from promptmodel.types.response import * 11 | from promptmodel.types.enums import ParsingType 12 | 13 | html_output_format = """\ 14 | You must follow the provided output format. Keep the string between <> as it is. 15 | Output format: 16 | 17 | (value here) 18 | 19 | """ 20 | 21 | double_bracket_output_format = """\ 22 | You must follow the provided output format. Keep the string between [[ ]] as it is. 23 | Output format: 24 | [[response type=str]] 25 | (value here) 26 | [[/response]] 27 | """ 28 | 29 | colon_output_format = """\ 30 | You must follow the provided output format. 
Keep the string before : as it is 31 | Output format: 32 | response type=str: (value here) 33 | 34 | """ 35 | 36 | tool_schemas = [{"type": "function", "function": schema} for schema in function_shemas] 37 | 38 | 39 | def test_run_with_tools(mocker): 40 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 41 | llm = LLM() 42 | res: LLMResponse = llm.run( 43 | messages=messages, 44 | tools=tool_schemas, 45 | model="gpt-4-1106-preview", 46 | ) 47 | 48 | assert res.error is None, "error is not None" 49 | assert res.api_response is not None, "api_response is None" 50 | assert ( 51 | res.api_response.choices[0].finish_reason == "tool_calls" 52 | ), "finish_reason is not tool_calls" 53 | print(res.api_response.model_dump()) 54 | print(res.__dict__) 55 | assert res.tool_calls is not None, "tool_calls is None" 56 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall) 57 | 58 | messages = [{"role": "user", "content": "Hello, How are you?"}] 59 | 60 | res: LLMResponse = llm.run( 61 | messages=messages, 62 | tools=tool_schemas, 63 | model="gpt-4-1106-preview", 64 | ) 65 | 66 | assert res.error is None, "error is not None" 67 | assert res.api_response is not None, "api_response is None" 68 | assert ( 69 | res.api_response.choices[0].finish_reason == "stop" 70 | ), "finish_reason is not stop" 71 | 72 | assert res.tool_calls is None, "tool_calls is not None" 73 | assert res.raw_output is not None, "raw_output is None" 74 | 75 | 76 | @pytest.mark.asyncio 77 | async def test_arun_with_tools(mocker): 78 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 79 | llm = LLM() 80 | res: LLMResponse = await llm.arun( 81 | messages=messages, 82 | tools=tool_schemas, 83 | model="gpt-4-1106-preview", 84 | ) 85 | 86 | assert res.error is None, "error is not None" 87 | assert res.api_response is not None, "api_response is None" 88 | assert ( 89 | res.api_response.choices[0].finish_reason == "tool_calls" 90 | ), "finish_reason is not tool_calls" 91 | 92 | assert res.tool_calls is not None, "tool_calls is None" 93 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall) 94 | 95 | messages = [{"role": "user", "content": "Hello, How are you?"}] 96 | 97 | res: LLMResponse = await llm.arun( 98 | messages=messages, 99 | tools=tool_schemas, 100 | model="gpt-4-1106-preview", 101 | ) 102 | 103 | assert res.error is None, "error is not None" 104 | assert res.api_response is not None, "api_response is None" 105 | assert ( 106 | res.api_response.choices[0].finish_reason == "stop" 107 | ), "finish_reason is not stop" 108 | 109 | assert res.tool_calls is None, "tool_calls is not None" 110 | assert res.raw_output is not None, "raw_output is None" 111 | 112 | 113 | def test_run_and_parse_with_tools(mocker): 114 | # With parsing_type = None 115 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 116 | llm = LLM() 117 | res: LLMResponse = llm.run_and_parse( 118 | messages=messages, 119 | tools=tool_schemas, 120 | model="gpt-4-1106-preview", 121 | parsing_type=None, 122 | ) 123 | 124 | assert res.error is False, "error is not False" 125 | assert res.api_response is not None, "api_response is None" 126 | assert ( 127 | res.api_response.choices[0].finish_reason == "tool_calls" 128 | ), "finish_reason is not tool_calls" 129 | 130 | assert res.tool_calls is not None, "tool_calls is None" 131 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall) 132 | 133 | messages = [{"role": "user", "content": "Hello, How are 
you?"}] 134 | 135 | res: LLMResponse = llm.run_and_parse( 136 | messages=messages, 137 | tools=tool_schemas, 138 | model="gpt-4-1106-preview", 139 | parsing_type=None, 140 | ) 141 | 142 | assert res.error is False, "error is not False" 143 | assert res.api_response is not None, "api_response is None" 144 | assert ( 145 | res.api_response.choices[0].finish_reason == "stop" 146 | ), "finish_reason is not stop" 147 | 148 | assert res.tool_calls is None, "tool_calls is not None" 149 | assert res.raw_output is not None, "raw_output is None" 150 | 151 | # with parsing_type = "HTML" 152 | 153 | messages = [ 154 | { 155 | "role": "user", 156 | "content": "What is the weather like in Boston? \n" + html_output_format, 157 | } 158 | ] 159 | llm = LLM() 160 | res: LLMResponse = llm.run_and_parse( 161 | messages=messages, 162 | tools=tool_schemas, 163 | model="gpt-4-1106-preview", 164 | parsing_type=ParsingType.HTML.value, 165 | output_keys=["response"], 166 | ) 167 | 168 | # 1. Output 지키고 function call -> (Pass) 169 | # 2. Output 지키고 stop -> OK 170 | # 3. Output 무시하고 function call -> OK (function call이 나타나면 파싱을 하지 않도록 수정) 171 | 172 | # In this case, error is True because the output is not in the correct format 173 | assert res.error is False, "error is not False" 174 | assert res.api_response is not None, "api_response is None" 175 | assert ( 176 | res.api_response.choices[0].finish_reason == "tool_calls" 177 | ), "finish_reason is not tool_calls" 178 | 179 | assert res.tool_calls is not None, "tool_calls is None" 180 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall) 181 | 182 | assert res.parsed_outputs is None, "parsed_outputs is not empty" 183 | 184 | messages = [ 185 | { 186 | "role": "user", 187 | "content": "Hello, How are you?\n" + html_output_format, 188 | } 189 | ] 190 | 191 | res: LLMResponse = llm.run_and_parse( 192 | messages=messages, 193 | tools=tool_schemas, 194 | model="gpt-4-1106-preview", 195 | parsing_type=ParsingType.HTML.value, 196 | output_keys=["response"], 197 | ) 198 | 199 | print(res.__dict__) 200 | 201 | if not "str" in res.raw_output: 202 | # if "str" in res.raw_output, it means that LLM make mistakes 203 | assert res.error is False, "error is not False" 204 | assert res.parsed_outputs is not None, "parsed_outputs is None" 205 | 206 | assert res.api_response is not None, "api_response is None" 207 | assert ( 208 | res.api_response.choices[0].finish_reason == "stop" 209 | ), "finish_reason is not stop" 210 | 211 | assert res.tool_calls is None, "tool_calls is not None" 212 | assert res.raw_output is not None, "raw_output is None" 213 | 214 | 215 | @pytest.mark.asyncio 216 | async def test_arun_and_parse_with_tools(mocker): 217 | # With parsing_type = None 218 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}] 219 | llm = LLM() 220 | res: LLMResponse = await llm.arun_and_parse( 221 | messages=messages, 222 | tools=tool_schemas, 223 | model="gpt-4-1106-preview", 224 | parsing_type=None, 225 | ) 226 | 227 | assert res.error is False, "error is not False" 228 | assert res.api_response is not None, "api_response is None" 229 | assert ( 230 | res.api_response.choices[0].finish_reason == "tool_calls" 231 | ), "finish_reason is not tool_calls" 232 | print(res) 233 | assert res.tool_calls is not None, "tool_calls is None" 234 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall) 235 | 236 | messages = [{"role": "user", "content": "Hello, How are you?"}] 237 | 238 | res: LLMResponse = await llm.arun_and_parse( 239 | 
240 |         tools=tool_schemas,
241 |         model="gpt-4-1106-preview",
242 |         parsing_type=None,
243 |     )
244 | 
245 |     assert res.error is False, "error is not False"
246 |     assert res.api_response is not None, "api_response is None"
247 |     assert (
248 |         res.api_response.choices[0].finish_reason == "stop"
249 |     ), "finish_reason is not stop"
250 | 
251 |     assert res.tool_calls is None, "tool_calls is not None"
252 |     assert res.raw_output is not None, "raw_output is None"
253 | 
254 |     # with parsing_type = "HTML"
255 | 
256 |     messages = [
257 |         {
258 |             "role": "user",
259 |             "content": "What is the weather like in Boston? \n" + html_output_format,
260 |         }
261 |     ]
262 |     llm = LLM()
263 |     res: LLMResponse = await llm.arun_and_parse(
264 |         messages=messages,
265 |         tools=tool_schemas,
266 |         model="gpt-4-1106-preview",
267 |         parsing_type=ParsingType.HTML.value,
268 |         output_keys=["response"],
269 |     )
270 | 
271 |     # In this case, error is False because parsing is not performed when tool_calls are returned
272 |     assert res.error is False, "error is not False"
273 |     assert res.api_response is not None, "api_response is None"
274 |     assert (
275 |         res.api_response.choices[0].finish_reason == "tool_calls"
276 |     ), "finish_reason is not tool_calls"
277 | 
278 |     assert res.tool_calls is not None, "tool_calls is None"
279 |     assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
280 | 
281 |     assert res.parsed_outputs is None, "parsed_outputs is not empty"
282 | 
283 |     messages = [
284 |         {
285 |             "role": "user",
286 |             "content": "Hello, How are you?\n" + html_output_format,
287 |         }
288 |     ]
289 | 
290 |     res: LLMResponse = await llm.arun_and_parse(
291 |         messages=messages,
292 |         tools=tool_schemas,
293 |         model="gpt-4-1106-preview",
294 |         parsing_type=ParsingType.HTML.value,
295 |         output_keys=["response"],
296 |     )
297 |     if "str" not in res.raw_output:
298 |         assert res.error is False, "error is not False"
299 |         assert res.parsed_outputs is not None, "parsed_outputs is None"
300 | 
301 |     assert res.api_response is not None, "api_response is None"
302 |     assert (
303 |         res.api_response.choices[0].finish_reason == "stop"
304 |     ), "finish_reason is not stop"
305 | 
306 |     assert res.tool_calls is None, "tool_calls is not None"
307 |     assert res.raw_output is not None, "raw_output is None"
308 | 
--------------------------------------------------------------------------------
/tests/utils/async_util_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import asyncio
3 | import nest_asyncio
4 | from typing import Coroutine
5 | from unittest.mock import AsyncMock
6 | 
7 | from promptmodel.utils.async_utils import run_async_in_sync
8 | 
9 | nest_asyncio.apply()
10 | 
11 | 
12 | def test_sync_context(mocker):
13 |     coro = AsyncMock(return_value="test")
14 |     res = run_async_in_sync(coro())
15 |     coro.assert_called_once()
16 | 
17 |     assert res == "test", "res is not test"
18 | 
19 | 
20 | def test_sync_async_context(mocker):
21 |     coro = AsyncMock(return_value="test")
22 | 
23 |     async def async_context(coro: Coroutine):
24 |         res = await coro
25 |         return res
26 | 
27 |     res = asyncio.run(async_context(coro()))
28 |     coro.assert_called_once()
29 | 
30 |     assert res == "test", "res is not test"
31 | 
32 | 
33 | @pytest.mark.asyncio
34 | async def test_async_context(mocker):
35 |     coro = AsyncMock(return_value="test")
36 |     res = await coro()
37 |     coro.assert_called_once()
38 | 
39 |     assert res == "test", "res is not test"
40 | 
41 | 
42 | @pytest.mark.asyncio
43 | async def test_async_sync_context(mocker):
44 |     coro = AsyncMock(return_value="test")
45 |     print("ready")
46 | 
47 |     def sync_context(coro: Coroutine):
48 |         return run_async_in_sync(coro)
49 | 
50 |     res = sync_context(coro())
51 |     coro.assert_called_once()
52 | 
53 |     assert res == "test", "res is not test"
54 | 
--------------------------------------------------------------------------------
/tests/websocket_client/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/websocket_client/__init__.py
--------------------------------------------------------------------------------
/tests/websocket_client/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 | 
4 | from promptmodel.websocket.websocket_client import DevWebsocketClient
5 | from promptmodel.dev_app import DevApp
6 | 
7 | 
8 | async def echo_coroutine(*args, **kwargs):
9 |     # print(args, kwargs)
10 |     return args, kwargs
11 | 
12 | 
13 | @pytest.fixture
14 | def websocket_client():
15 |     websocket_client = DevWebsocketClient(_devapp=DevApp())
16 |     return websocket_client
17 | 
18 | 
19 | @pytest.fixture
20 | def mock_websocket():
21 |     # Create a mock WebSocketClientProtocol object
22 |     mock_websocket = AsyncMock()
23 | 
24 |     async def aenter(self):
25 |         return self
26 | 
27 |     async def aexit(self, exc_type, exc_value, traceback):
28 |         pass
29 | 
30 |     mock_websocket.__aenter__ = aenter
31 |     mock_websocket.__aexit__ = aexit
32 |     mock_websocket.recv = AsyncMock(return_value='{"key" : "value"}')
33 |     mock_websocket.send = AsyncMock()
34 | 
35 |     return mock_websocket
36 | 
37 | 
38 | @pytest.fixture
39 | def mock_json_dumps():
40 |     mock_json_dumps = MagicMock()
41 |     # it should return exactly the same as the input
42 |     mock_json_dumps.side_effect = lambda data, *args, **kwargs: data
43 |     return mock_json_dumps
44 | 
--------------------------------------------------------------------------------
/tests/websocket_client/local_task_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 | import asyncio
4 | 
5 | from websockets.exceptions import ConnectionClosedOK
6 | from promptmodel.websocket.websocket_client import DevWebsocketClient
7 | from promptmodel.types.enums import LocalTask
8 | 
9 | 
10 | @pytest.mark.asyncio
11 | async def test_connect_to_gateway(
12 |     mocker, websocket_client: DevWebsocketClient, mock_websocket: AsyncMock
13 | ):
14 |     project_uuid = "test_uuid"
15 |     connection_name = "test_project"
16 |     cli_access_header = {"Authorization": "Bearer testtoken"}
17 | 
18 |     with patch.object(
19 |         websocket_client, "_DevWebsocketClient__handle_message", new_callable=AsyncMock
20 |     ) as mock_function:
21 |         mock_function.side_effect = ConnectionClosedOK(None, None)
22 |         with patch(
23 |             "promptmodel.websocket.websocket_client.connect",
24 |             new_callable=MagicMock,
25 |             return_value=mock_websocket,
26 |         ) as mock_connect:
27 |             # The test ends automatically after 5 seconds
28 |             await websocket_client.connect_to_gateway(
29 |                 project_uuid, connection_name, cli_access_header, retries=1
30 |             )
31 |             mock_connect.assert_called_once()
32 |             mock_websocket.recv.assert_called_once()
33 |             websocket_client._DevWebsocketClient__handle_message.assert_called_once()
34 | 
35 | 
36 | @pytest.mark.asyncio
37 | async def test_local_tasks(
mocker, websocket_client: DevWebsocketClient, mock_websocket: AsyncMock 39 | ): 40 | websocket_client._devapp.function_models = {} 41 | websocket_client._devapp.samples = {} 42 | websocket_client._devapp.functions = {"test_function": "test_function"} 43 | 44 | await websocket_client._DevWebsocketClient__handle_message( 45 | message={"type": LocalTask.LIST_CODE_FUNCTIONS}, ws=mock_websocket 46 | ) 47 | -------------------------------------------------------------------------------- /tests/websocket_client/run_chatmodel_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | import asyncio 5 | from typing import Optional 6 | from uuid import uuid4 7 | from dataclasses import dataclass 8 | from websockets.exceptions import ConnectionClosedOK 9 | from promptmodel.websocket.websocket_client import DevWebsocketClient 10 | from promptmodel.types.enums import LocalTask, LocalTaskErrorType 11 | from promptmodel.types.response import FunctionSchema 12 | 13 | 14 | @dataclass 15 | class ChatModelInterface: 16 | name: str 17 | default_model: str = "gpt-3.5-turbo" 18 | 19 | 20 | def get_current_weather(location: str, unit: Optional[str] = "celsius"): 21 | return "13 degrees celsius" 22 | 23 | 24 | get_current_weather_desc = { 25 | "name": "get_current_weather", 26 | "description": "Get the current weather in a given location", 27 | "parameters": { 28 | "type": "object", 29 | "properties": { 30 | "location": { 31 | "type": "string", 32 | "description": "The city and state, e.g. San Francisco, CA", 33 | }, 34 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, 35 | }, 36 | "required": ["location"], 37 | }, 38 | } 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def test_run_model_function_call( 43 | mocker, 44 | websocket_client: DevWebsocketClient, 45 | mock_websocket: AsyncMock, 46 | mock_json_dumps: MagicMock, 47 | ): 48 | websocket_client._devapp.chat_models = [ChatModelInterface("test_module")] 49 | websocket_client._devapp.samples = { 50 | "sample_1": {"user_message": "What is the weather like in Boston?"} 51 | } 52 | websocket_client._devapp.functions = { 53 | "get_current_weather": { 54 | "schema": FunctionSchema(**get_current_weather_desc), 55 | "function": get_current_weather, 56 | } 57 | } 58 | 59 | function_schemas_in_db = [FunctionSchema(**get_current_weather_desc).model_dump()] 60 | 61 | function_schemas_in_db[0]["mock_response"] = "13" 62 | 63 | mocker.patch("promptmodel.websocket.websocket_client.json.dumps", mock_json_dumps) 64 | 65 | # success case with function_call 66 | await websocket_client._DevWebsocketClient__handle_message( 67 | message={ 68 | "type": LocalTask.RUN_CHAT_MODEL, 69 | "chat_model_name": "test_module", 70 | "system_prompt": { 71 | "role": "system", 72 | "content": "You are a helpful assistant.", 73 | }, 74 | "old_messages": [], 75 | "new_messages": [ 76 | { 77 | "role": "user", 78 | "content": "What is the weather in Boston?", 79 | } 80 | ], 81 | "model": "gpt-3.5-turbo", 82 | "functions": ["get_current_weather"], 83 | "function_schemas": function_schemas_in_db, 84 | }, 85 | ws=mock_websocket, 86 | ) 87 | 88 | call_args_list = mock_websocket.send.call_args_list 89 | data = [arg.args[0] for arg in call_args_list] 90 | 91 | assert len([d for d in data if d["status"] == "failed"]) == 0 92 | assert len([d for d in data if d["status"] == "completed"]) == 1 93 | 94 | assert len([d for d in data if "function_response" in d]) == 1 95 | assert [d for 
d in data if "function_response" in d][0]["function_response"][ 96 | "name" 97 | ] == "get_current_weather" 98 | 99 | assert len([d for d in data if "function_call" in d]) > 1 100 | assert len([d for d in data if "raw_output" in d]) > 0 101 | 102 | mock_websocket.send.reset_mock() 103 | function_schemas_in_db[0]["mock_response"] = "13" 104 | 105 | print( 106 | "=======================================================================================" 107 | ) 108 | 109 | # success case with no function call 110 | await websocket_client._DevWebsocketClient__handle_message( 111 | message={ 112 | "type": LocalTask.RUN_CHAT_MODEL, 113 | "chat_model_name": "test_module", 114 | "system_prompt": { 115 | "role": "system", 116 | "content": "You are a helpful assistant.", 117 | }, 118 | "old_messages": [], 119 | "new_messages": [ 120 | { 121 | "role": "user", 122 | "content": "Hello?", 123 | } 124 | ], 125 | "model": "gpt-3.5-turbo", 126 | }, 127 | ws=mock_websocket, 128 | ) 129 | 130 | call_args_list = mock_websocket.send.call_args_list 131 | data = [arg.args[0] for arg in call_args_list] 132 | 133 | assert len([d for d in data if d["status"] == "failed"]) == 0 134 | assert len([d for d in data if d["status"] == "completed"]) == 1 135 | 136 | assert len([d for d in data if "function_response" in d]) == 0 137 | 138 | assert len([d for d in data if "function_call" in d]) == 0 139 | assert len([d for d in data if "raw_output" in d]) > 0 140 | 141 | mock_websocket.send.reset_mock() 142 | function_schemas_in_db[0]["mock_response"] = "13" 143 | 144 | print( 145 | "=======================================================================================" 146 | ) 147 | 148 | # FUNCTION_CALL_FAILED_ERROR case 149 | def error_raise_function(*args, **kwargs): 150 | raise Exception("error") 151 | 152 | websocket_client._devapp.functions = { 153 | "get_current_weather": { 154 | "schema": FunctionSchema(**get_current_weather_desc), 155 | "function": error_raise_function, 156 | } 157 | } 158 | 159 | await websocket_client._DevWebsocketClient__handle_message( 160 | message={ 161 | "type": LocalTask.RUN_CHAT_MODEL, 162 | "chat_model_name": "test_module", 163 | "system_prompt": { 164 | "role": "system", 165 | "content": "You are a helpful assistant.", 166 | }, 167 | "old_messages": [], 168 | "new_messages": [ 169 | { 170 | "role": "user", 171 | "content": "What is the weather in Boston?", 172 | } 173 | ], 174 | "model": "gpt-3.5-turbo", 175 | "functions": ["get_current_weather"], 176 | "function_schemas": function_schemas_in_db, 177 | }, 178 | ws=mock_websocket, 179 | ) 180 | call_args_list = mock_websocket.send.call_args_list 181 | # print(call_args_list) 182 | data = [arg.args[0] for arg in call_args_list] 183 | 184 | assert len([d for d in data if d["status"] == "failed"]) == 1 185 | assert [d for d in data if d["status"] == "failed"][0][ 186 | "error_type" 187 | ] == LocalTaskErrorType.FUNCTION_CALL_FAILED_ERROR.value 188 | assert len([d for d in data if d["status"] == "completed"]) == 0 189 | assert len([d for d in data if "function_response" in d]) == 0 190 | assert len([d for d in data if "function_call" in d]) > 1 191 | assert len([d for d in data if "raw_output" in d]) == 0 192 | mock_websocket.send.reset_mock() 193 | function_schemas_in_db[0]["mock_response"] = "13" 194 | 195 | # function not in code case, should use mock_response 196 | websocket_client._devapp.functions = {} 197 | await websocket_client._DevWebsocketClient__handle_message( 198 | message={ 199 | "type": LocalTask.RUN_CHAT_MODEL, 200 | 
"chat_model_name": "test_module", 201 | "system_prompt": { 202 | "role": "system", 203 | "content": "You are a helpful assistant.", 204 | }, 205 | "old_messages": [], 206 | "new_messages": [ 207 | { 208 | "role": "user", 209 | "content": "What is the weather in Boston?", 210 | } 211 | ], 212 | "model": "gpt-3.5-turbo", 213 | "functions": ["get_weather"], 214 | "function_schemas": function_schemas_in_db, 215 | }, 216 | ws=mock_websocket, 217 | ) 218 | call_args_list = mock_websocket.send.call_args_list 219 | print(call_args_list) 220 | data = [arg.args[0] for arg in call_args_list] 221 | 222 | assert len([d for d in data if d["status"] == "failed"]) == 0 223 | assert len([d for d in data if d["status"] == "completed"]) == 1 224 | 225 | assert len([d for d in data if "function_response" in d]) == 1 226 | assert ( 227 | "FAKE RESPONSE" 228 | in [d for d in data if "function_response" in d][0]["function_response"][ 229 | "response" 230 | ] 231 | ) 232 | assert len([d for d in data if "function_call" in d]) > 1 233 | assert len([d for d in data if "raw_output" in d]) > 0 234 | -------------------------------------------------------------------------------- /tests/websocket_client/run_promptmodel_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, patch, MagicMock 3 | 4 | import asyncio 5 | from typing import Optional 6 | from uuid import uuid4 7 | from dataclasses import dataclass 8 | from websockets.exceptions import ConnectionClosedOK 9 | from promptmodel.websocket.websocket_client import DevWebsocketClient 10 | from promptmodel.types.enums import LocalTask, ParsingType, LocalTaskErrorType 11 | from promptmodel.types.response import FunctionSchema 12 | 13 | 14 | @dataclass 15 | class FunctionModelInterface: 16 | name: str 17 | default_model: str = "gpt-3.5-turbo" 18 | 19 | 20 | def get_current_weather(location: str, unit: Optional[str] = "celsius"): 21 | return "13 degrees celsius" 22 | 23 | 24 | get_current_weather_desc = { 25 | "name": "get_current_weather", 26 | "description": "Get the current weather in a given location", 27 | "parameters": { 28 | "type": "object", 29 | "properties": { 30 | "location": { 31 | "type": "string", 32 | "description": "The city and state, e.g. 
San Francisco, CA", 33 | }, 34 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, 35 | }, 36 | "required": ["location"], 37 | }, 38 | } 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def test_run_model_function_call( 43 | mocker, 44 | websocket_client: DevWebsocketClient, 45 | mock_websocket: AsyncMock, 46 | mock_json_dumps: MagicMock, 47 | ): 48 | websocket_client._devapp.function_models = [FunctionModelInterface("test_module")] 49 | 50 | websocket_client._devapp.functions = { 51 | "get_current_weather": { 52 | "schema": FunctionSchema(**get_current_weather_desc), 53 | "function": get_current_weather, 54 | } 55 | } 56 | 57 | function_schemas_in_db = [FunctionSchema(**get_current_weather_desc).model_dump()] 58 | 59 | function_schemas_in_db[0]["mock_response"] = "13" 60 | 61 | mocker.patch("promptmodel.websocket.websocket_client.json.dumps", mock_json_dumps) 62 | 63 | # success case 64 | await websocket_client._DevWebsocketClient__handle_message( 65 | message={ 66 | "type": LocalTask.RUN_PROMPT_MODEL, 67 | "messages_for_run": [ 68 | { 69 | "role": "system", 70 | "content": "You are a helpful assistant.", 71 | "step": 1, 72 | }, 73 | { 74 | "role": "user", 75 | "content": "What is the weather like in Boston?", 76 | "step": 2, 77 | }, 78 | ], 79 | "model": "gpt-3.5-turbo", 80 | "parsing_type": None, 81 | "output_keys": None, 82 | "functions": ["get_current_weather"], 83 | "function_schemas": function_schemas_in_db, 84 | }, 85 | ws=mock_websocket, 86 | ) 87 | call_args_list = mock_websocket.send.call_args_list 88 | # print(call_args_list) 89 | data = [arg.args[0] for arg in call_args_list] 90 | 91 | assert len([d for d in data if d["status"] == "failed"]) == 0 92 | assert len([d for d in data if d["status"] == "completed"]) == 1 93 | assert len([d for d in data if "function_response" in d]) == 1 94 | assert [d for d in data if "function_response" in d][0]["function_response"][ 95 | "name" 96 | ] == "get_current_weather" 97 | assert len([d for d in data if "function_call" in d]) == 1 98 | assert len([d for d in data if "raw_output" in d]) == 0 99 | mock_websocket.send.reset_mock() 100 | function_schemas_in_db[0]["mock_response"] = "13" 101 | 102 | print( 103 | "=======================================================================================" 104 | ) 105 | 106 | # success case with no function call 107 | await websocket_client._DevWebsocketClient__handle_message( 108 | message={ 109 | "type": LocalTask.RUN_PROMPT_MODEL, 110 | "messages_for_run": [ 111 | { 112 | "role": "system", 113 | "content": "You are a helpful assistant.", 114 | "step": 1, 115 | }, 116 | { 117 | "role": "user", 118 | "content": "What is the weather like in Boston?", 119 | "step": 2, 120 | }, 121 | ], 122 | "model": "gpt-3.5-turbo", 123 | "parsing_type": None, 124 | "output_keys": None, 125 | }, 126 | ws=mock_websocket, 127 | ) 128 | call_args_list = mock_websocket.send.call_args_list 129 | # print(call_args_list) 130 | data = [arg.args[0] for arg in call_args_list] 131 | 132 | assert len([d for d in data if "function_call" in d]) == 0 133 | assert len([d for d in data if "raw_output" in d]) > 0 134 | assert len([d for d in data if d["status"] == "failed"]) == 0 135 | assert len([d for d in data if d["status"] == "completed"]) == 1 136 | mock_websocket.send.reset_mock() 137 | 138 | print( 139 | "=======================================================================================" 140 | ) 141 | 142 | # system_prompt_with_format = """ 143 | # You are a helpful assistant. 
144 | 145 | # This is your output format. Keep the string between "[" and "]" as same as given below. You should response content first before call functions. You MUST FOLLOW this output format. 146 | # Output Format: 147 | # [response type=str] 148 | # (value here) 149 | # [/response] 150 | # """ 151 | # await websocket_client._DevWebsocketClient__handle_message( 152 | # message={ 153 | # "type": LocalTask.RUN_PROMPT_MODEL, 154 | # "messages_for_run": [ 155 | # {"role": "system", "content": system_prompt_with_format, "step": 1}, 156 | # { 157 | # "role": "user", 158 | # "content": "What is the weather like in Boston?", 159 | # "step": 2, 160 | # }, 161 | # ], 162 | # "model": "gpt-4-1106-preview", 163 | # "uuid": None, 164 | # "from_uuid": None, 165 | # "parsing_type": ParsingType.SQUARE_BRACKET.value, 166 | # "output_keys": ["response"], 167 | # "functions": ["get_current_weather"], 168 | # "function_schemas": function_schemas_in_db, 169 | # }, 170 | # ws=mock_websocket, 171 | # ) 172 | # call_args_list = mock_websocket.send.call_args_list 173 | # print(call_args_list) 174 | # data = [arg.args[0] for arg in call_args_list] 175 | 176 | # assert len([d for d in data if d["status"] == "failed"]) == 0 177 | # assert len([d for d in data if d["status"] == "completed"]) == 1 178 | # assert len([d for d in data if "function_response" in d]) == 1 179 | # assert [d for d in data if "function_response" in d][0]["function_response"][ 180 | # "name" 181 | # ] == "get_current_weather" 182 | # assert len([d for d in data if "function_call" in d]) == 1 183 | # assert len([d for d in data if "raw_output" in d]) > 0 184 | # assert len([d for d in data if "parsed_outputs" in d]) > 0 185 | # assert ( 186 | # len( 187 | # list( 188 | # set([d["parsed_outputs"].keys() for d in data if "parsed_outputs" in d]) 189 | # ) 190 | # ) 191 | # == 1 192 | # ) 193 | # mock_websocket.send.reset_mock() 194 | # function_schemas_in_db[0]["mock_response"] = "13" 195 | 196 | print( 197 | "=======================================================================================" 198 | ) 199 | 200 | # FUNCTION_CALL_FAILED_ERROR case 201 | def error_raise_function(*args, **kwargs): 202 | raise Exception("error") 203 | 204 | websocket_client._devapp.functions = { 205 | "get_current_weather": { 206 | "schema": FunctionSchema(**get_current_weather_desc), 207 | "function": error_raise_function, 208 | } 209 | } 210 | 211 | # success case 212 | await websocket_client._DevWebsocketClient__handle_message( 213 | message={ 214 | "type": LocalTask.RUN_PROMPT_MODEL, 215 | "messages_for_run": [ 216 | { 217 | "role": "system", 218 | "content": "You are a helpful assistant.", 219 | "step": 1, 220 | }, 221 | { 222 | "role": "user", 223 | "content": "What is the weather like in Boston?", 224 | "step": 2, 225 | }, 226 | ], 227 | "model": "gpt-3.5-turbo", 228 | "parsing_type": None, 229 | "output_keys": None, 230 | "functions": ["get_current_weather"], 231 | "function_schemas": function_schemas_in_db, 232 | }, 233 | ws=mock_websocket, 234 | ) 235 | call_args_list = mock_websocket.send.call_args_list 236 | # print(call_args_list) 237 | data = [arg.args[0] for arg in call_args_list] 238 | 239 | assert len([d for d in data if d["status"] == "failed"]) == 1 240 | assert [d for d in data if d["status"] == "failed"][0][ 241 | "error_type" 242 | ] == LocalTaskErrorType.FUNCTION_CALL_FAILED_ERROR.value 243 | assert len([d for d in data if d["status"] == "completed"]) == 0 244 | assert len([d for d in data if "function_response" in d]) == 0 245 | assert 
len([d for d in data if "function_call" in d]) == 1 246 | assert len([d for d in data if "raw_output" in d]) == 0 247 | mock_websocket.send.reset_mock() 248 | function_schemas_in_db[0]["mock_response"] = "13" 249 | 250 | # PARSING_FAILED_ERROR case 251 | 252 | system_prompt_with_format = """ 253 | You are a helpful assistant. 254 | 255 | This is your output format. Keep the string between "[" and "]" as same as given below. You MUST FOLLOW this output format. 256 | Output Format: 257 | [response type=str] 258 | (value here) 259 | [/response] 260 | """ 261 | await websocket_client._DevWebsocketClient__handle_message( 262 | message={ 263 | "type": LocalTask.RUN_PROMPT_MODEL, 264 | "messages_for_run": [ 265 | {"role": "system", "content": system_prompt_with_format, "step": 1}, 266 | { 267 | "role": "user", 268 | "content": "Hello!", 269 | "step": 2, 270 | }, 271 | ], 272 | "model": "gpt-4-1106-preview", 273 | "parsing_type": ParsingType.SQUARE_BRACKET.value, 274 | "output_keys": ["respond"], 275 | "functions": ["get_current_weather"], 276 | "function_schemas": function_schemas_in_db, 277 | }, 278 | ws=mock_websocket, 279 | ) 280 | call_args_list = mock_websocket.send.call_args_list 281 | # print(call_args_list) 282 | data = [arg.args[0] for arg in call_args_list] 283 | 284 | assert len([d for d in data if d["status"] == "failed"]) == 1 285 | assert [d for d in data if d["status"] == "failed"][0][ 286 | "error_type" 287 | ] == LocalTaskErrorType.PARSING_FAILED_ERROR.value 288 | assert len([d for d in data if d["status"] == "completed"]) == 0 289 | assert len([d for d in data if "function_response" in d]) == 0 290 | assert len([d for d in data if "function_call" in d]) == 0 291 | assert len([d for d in data if "raw_output" in d]) > 0 292 | assert len([d for d in data if "parsed_outputs" in d]) > 0 293 | 294 | mock_websocket.send.reset_mock() 295 | function_schemas_in_db[0]["mock_response"] = "13" 296 | 297 | # function not in code case, should use mock_response 298 | websocket_client._devapp.functions = {} 299 | await websocket_client._DevWebsocketClient__handle_message( 300 | message={ 301 | "type": LocalTask.RUN_PROMPT_MODEL, 302 | "messages_for_run": [ 303 | { 304 | "role": "system", 305 | "content": "You are a helpful assistant.", 306 | "step": 1, 307 | }, 308 | { 309 | "role": "user", 310 | "content": "What is the weather like in Boston?", 311 | "step": 2, 312 | }, 313 | ], 314 | "model": "gpt-3.5-turbo", 315 | "parsing_type": None, 316 | "output_keys": None, 317 | "functions": ["get_weather"], 318 | "function_schemas": function_schemas_in_db, 319 | }, 320 | ws=mock_websocket, 321 | ) 322 | call_args_list = mock_websocket.send.call_args_list 323 | print(call_args_list) 324 | data = [arg.args[0] for arg in call_args_list] 325 | 326 | assert len([d for d in data if d["status"] == "failed"]) == 0 327 | assert len([d for d in data if d["status"] == "completed"]) == 1 328 | 329 | assert len([d for d in data if "function_response" in d]) == 1 330 | assert ( 331 | "FAKE RESPONSE" 332 | in [d for d in data if "function_response" in d][0]["function_response"][ 333 | "response" 334 | ] 335 | ) 336 | assert len([d for d in data if "function_call" in d]) == 1 337 | assert len([d for d in data if "raw_output" in d]) == 0 338 | --------------------------------------------------------------------------------
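Both run_chatmodel_test.py and run_promptmodel_test.py above assert that, when a function named in the request is not registered in `_devapp.functions`, the handler falls back to the schema's `mock_response` and streams a `function_response` containing "FAKE RESPONSE". The snippet below is a minimal, self-contained sketch of that fallback decision; it is inferred from those assertions rather than copied from the library's handler, and the helper name `resolve_function_response` is hypothetical.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical sketch of the fallback behaviour the websocket_client tests assert:
# registered functions are executed, unregistered ones fall back to the schema's
# mock_response so the run can still complete with a "FAKE RESPONSE".


def resolve_function_response(
    name: str,
    arguments: Dict[str, Any],
    registered_functions: Dict[str, Dict[str, Any]],
    mock_response: Optional[str] = None,
) -> str:
    if name in registered_functions:
        func: Callable[..., Any] = registered_functions[name]["function"]
        # If this raises, the handler reports FUNCTION_CALL_FAILED_ERROR instead.
        return str(func(**arguments))
    # The function exists only in the cloud schema: return a fake response built
    # from mock_response, which is what the "FAKE RESPONSE" assertions check for.
    return f"FAKE RESPONSE: {mock_response}"


if __name__ == "__main__":
    registered = {
        "get_current_weather": {
            "function": lambda location, unit="celsius": "13 degrees celsius"
        }
    }
    print(resolve_function_response("get_current_weather", {"location": "Boston, MA"}, registered))
    print(resolve_function_response("get_weather", {"location": "Boston, MA"}, {}, mock_response="13"))
```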