├── .coveragerc
├── .gitignore
├── LICENSE
├── README.md
├── promptmodel
│   ├── STARTER.py
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   └── base.py
│   ├── chat_model.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   ├── __init__.py
│   │   │   ├── configure.py
│   │   │   ├── connect.py
│   │   │   ├── dev.py
│   │   │   ├── fix.py
│   │   │   ├── init.py
│   │   │   ├── login.py
│   │   │   └── project.py
│   │   ├── main.py
│   │   ├── signal_handler.py
│   │   └── utils.py
│   ├── constants.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── crud.py
│   │   ├── crud_chat.py
│   │   ├── models.py
│   │   ├── models_chat.py
│   │   └── orm.py
│   ├── dev_app.py
│   ├── function_model.py
│   ├── llms
│   │   ├── __init__.py
│   │   ├── llm.py
│   │   ├── llm_dev.py
│   │   └── llm_proxy.py
│   ├── promptmodel_init.py
│   ├── types
│   │   ├── __init__.py
│   │   ├── enums.py
│   │   ├── request.py
│   │   └── response.py
│   ├── unit_logger.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── async_utils.py
│   │   ├── config_utils.py
│   │   ├── crypto.py
│   │   ├── logger.py
│   │   ├── output_utils.py
│   │   ├── random_utils.py
│   │   └── token_counting.py
│   └── websocket
│       ├── __init__.py
│       ├── reload_handler.py
│       └── websocket_client.py
├── publish.sh
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    ├── api_client
    │   ├── __init__.py
    │   └── api_client_test.py
    ├── chat_model
    │   ├── __init__.py
    │   ├── chat_model_test.py
    │   ├── conftest.py
    │   └── registering_meta_test.py
    ├── cli
    │   └── dev_test.py
    ├── constants.py
    ├── function_model
    │   ├── __init__.py
    │   ├── conftest.py
    │   ├── function_model_test.py
    │   └── registering_meta_test.py
    ├── llm
    │   ├── __init__.py
    │   ├── error_case_test.py
    │   ├── function_call_test.py
    │   ├── llm_dev
    │   │   ├── __init__.py
    │   │   └── llm_dev_test.py
    │   ├── llm_proxy
    │   │   ├── __init__.py
    │   │   ├── conftest.py
    │   │   ├── proxy_chat_test.py
    │   │   └── proxy_test.py
    │   ├── parse_test.py
    │   ├── stream_function_call_test.py
    │   ├── stream_test.py
    │   ├── stream_tool_calls_test.py
    │   └── tool_calls_test.py
    ├── utils
    │   └── async_util_test.py
    └── websocket_client
        ├── __init__.py
        ├── conftest.py
        ├── local_task_test.py
        ├── run_chatmodel_test.py
        └── run_promptmodel_test.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | exclude_lines =
3 | pragma: no cover
4 | def __repr__
5 | if self\.debug:
6 | raise AssertionError
7 | raise NotImplementedError
8 | if 0:
9 | if __name__ == .__main__.:
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | venv
3 | .ipynb_checkpoints
4 | *.ipynb
5 | .env
6 |
7 | .fastllm
8 | .fastllm/*
9 | .promptmodel
10 | .promptmodel/*
11 |
12 | */__pycache__
13 | __pycache__
14 |
15 | .vscode
16 | .vscode/*
17 | */.DS_Store
18 | *.egg-info
19 |
20 | ipynb_test
21 | ipynb_test/*
22 | testapp
23 | testapp/*
24 |
25 | build
26 | build/*
27 | dist
28 | dist/*
29 |
30 | .coverage
31 | key.key
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Promptmodel
6 |
7 | Prompt & model versioning on the cloud, built for developers.
8 |
9 | We are currently in closed alpha.
10 | Request access 🚀
11 |
12 |
13 |
21 |
22 |
23 | ## Installation
24 |
25 | ```bash
26 | pip install promptmodel
27 | ```
28 |
29 | ## Documentation
30 |
31 | You can find our full documentation [here](https://www.promptmodel.run/docs/introduction).
32 |
--------------------------------------------------------------------------------
/promptmodel/STARTER.py:
--------------------------------------------------------------------------------
1 | """This single file is needed to build the DevApp development dashboard."""
2 |
3 | from promptmodel import DevApp
4 |
5 | # Example imports
6 | # from import
7 |
8 | app = DevApp()
9 |
10 | # Example usage
11 | # This is needed to integrate your codebase with the prompt engineering dashboard
12 | # app.include_client()
13 |
--------------------------------------------------------------------------------
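
The file above is the template that `prompt init` copies into the working directory as `promptmodel_dev.py` (see `cli/commands/init.py` and `PROMPTMODEL_DEV_STARTER_FILENAME` in `constants.py`). The following is only a rough sketch of what a filled-in `promptmodel_dev.py` might look like, using the `DevClient`/`DevApp` methods defined in `dev_app.py` later in this pack; the model name and the `FunctionModel("...")` constructor usage are illustrative assumptions, since `function_model.py` is not included in this section.

```python
# promptmodel_dev.py -- illustrative sketch; "summarizer" and the FunctionModel
# constructor usage are assumptions, not taken from this repository's docs.
from promptmodel import DevApp, DevClient, FunctionModel

client = DevClient()

@client.register  # scans the function's bytecode for FunctionModel("...") names
def summarize(text: str):
    summarizer = FunctionModel("summarizer")  # assumed: model name as first argument
    ...

app = DevApp()
app.include_client(client)  # expose the names found above on the dev dashboard
app.register_sample("summarize/sample-1", {"text": "Promptmodel keeps prompts versioned."})
```

`prompt connect` then imports this module as `promptmodel_dev:app` (see `cli/commands/connect.py`).
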
/promptmodel/__init__.py:
--------------------------------------------------------------------------------
1 | from .dev_app import DevClient, DevApp
2 | from .function_model import FunctionModel, PromptModel
3 | from .chat_model import ChatModel
4 | from .promptmodel_init import init
5 | from .unit_logger import UnitLogger
6 |
7 | __version__ = "0.1.19"
8 |
9 |
--------------------------------------------------------------------------------
/promptmodel/apis/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
--------------------------------------------------------------------------------
/promptmodel/apis/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict
3 |
4 | import requests
5 |
6 | import httpx
7 | from rich import print
8 |
9 | from promptmodel.utils.config_utils import read_config
10 | from promptmodel.constants import ENDPOINT_URL
11 | from promptmodel.utils.crypto import decrypt_message
12 |
13 |
14 | class APIClient:
15 | """
16 | A class to represent an API request client.
17 |
18 | ...
19 |
20 | Methods
21 | -------
22 | get_headers():
23 | Generates headers for the API request.
24 | execute(method="GET", params=None, data=None, json=None, **kwargs):
25 | Executes the API request.
26 | """
27 |
28 | @classmethod
29 | def _get_headers(cls, use_cli_key: bool = True) -> Dict:
30 | """
31 | Reads, decrypts the api_key, and returns headers for API request.
32 |
33 | Returns
34 | -------
35 | dict
36 | a dictionary containing the Authorization header
37 | """
38 | config = read_config()
39 | if use_cli_key:
40 | if "connection" not in config:
41 | print(
42 | "User not logged in. Please run [violet]prompt login[/violet] first."
43 | )
44 | exit()
45 |
46 | encrypted_key = (
47 | config["connection"]["encrypted_api_key"]
48 | if "encrypted_api_key" in config["connection"]
49 | else None
50 | )
51 | if encrypted_key is None:
52 | raise Exception("No API key found. Please run 'prompt login' first.")
53 | decrypted_key = decrypt_message(encrypted_key)
54 | else:
55 | decrypted_key = os.environ.get("PROMPTMODEL_API_KEY")
56 | headers = {"Authorization": f"Bearer {decrypted_key}"}
57 | return headers
58 |
59 | @classmethod
60 | def execute(
61 | cls,
62 | path: str,
63 | method="GET",
64 | params: Dict = None,
65 | data: Dict = None,
66 | json: Dict = None,
67 | ignore_auth_error: bool = False,
68 | use_cli_key: bool = True,
69 | **kwargs,
70 | ) -> requests.Response:
71 | """
72 | Executes the API request with the decrypted API key in the headers.
73 |
74 | Parameters
75 | ----------
76 | method : str, optional
77 | The HTTP method of the request (default is "GET")
78 | params : dict, optional
79 | The URL parameters to be sent with the request
80 | data : dict, optional
81 | The request body to be sent with the request
82 | json : dict, optional
83 | The JSON-encoded request body to be sent with the request
84 | ignore_auth_error: bool, optional
85 | Whether to ignore authentication errors (default is False)
86 | **kwargs : dict
87 | Additional arguments to pass to the requests.request function
88 |
89 | Returns
90 | -------
91 | requests.Response
92 | The response object returned by the requests library
93 | """
94 | url = f"{ENDPOINT_URL}{path}"
95 | headers = cls._get_headers(use_cli_key)
96 | try:
97 | response = requests.request(
98 | method,
99 | url,
100 | headers=headers,
101 | params=params,
102 | data=data,
103 | json=json,
104 | **kwargs,
105 | )
106 | if not response:
107 | print(f"[red]Error: {response}[/red]")
108 | exit()
109 | if response.status_code == 200:
110 | return response
111 | elif response.status_code == 403:
112 | if not ignore_auth_error:
113 | print(
114 | "[red]Authentication failed. Please run [violet][bold]prompt login[/bold][/violet] first.[/red]"
115 | )
116 | exit()
117 | else:
118 | print(f"[red]Error: {response}[/red]")
119 | exit()
120 | except requests.exceptions.ConnectionError:
121 | print("[red]Could not connect to the Promptmodel API.[/red]")
122 | except requests.exceptions.Timeout:
123 | print("[red]The request timed out.[/red]")
124 |
125 |
126 | class AsyncAPIClient:
127 | """
128 | A class to represent an Async API request client.
129 | Used in Deployment stage.
130 |
131 | ...
132 |
133 | Methods
134 | -------
135 | get_headers():
136 | Generates headers for the API request.
137 | execute(method="GET", params=None, data=None, json=None, **kwargs):
138 | Executes the API request.
139 | """
140 |
141 | @classmethod
142 | async def _get_headers(cls, use_cli_key: bool = True) -> Dict:
143 | """
144 | Reads, decrypts the api_key, and returns headers for API request.
145 |
146 | Returns
147 | -------
148 | dict
149 | a dictionary containing the Authorization header
150 | """
151 | config = read_config()
152 | if use_cli_key:
153 | if "connection" not in config:
154 | print(
155 | "User not logged in. Please run [violet]prompt login[/violet] first."
156 | )
157 | exit()
158 |
159 | encrypted_key = config["connection"]["encrypted_api_key"]
160 | if encrypted_key is None:
161 | raise Exception("No API key found. Please run 'prompt login' first.")
162 | decrypted_key = decrypt_message(encrypted_key)
163 | else:
164 | decrypted_key = os.environ.get("PROMPTMODEL_API_KEY")
165 | if decrypted_key is None:
166 | raise Exception(
167 | "PROMPTMODEL_API_KEY was not found in the current environment."
168 | )
169 | headers = {"Authorization": f"Bearer {decrypted_key}"}
170 | return headers
171 |
172 | @classmethod
173 | async def execute(
174 | cls,
175 | path: str,
176 | method="GET",
177 | params: Dict = None,
178 | data: Dict = None,
179 | json: Dict = None,
180 | ignore_auth_error: bool = False,
181 | use_cli_key: bool = True,
182 | **kwargs,
183 | ) -> requests.Response:
184 | """
185 | Executes the API request with the decrypted API key in the headers.
186 |
187 | Parameters
188 | ----------
189 | method : str, optional
190 | The HTTP method of the request (default is "GET")
191 | params : dict, optional
192 | The URL parameters to be sent with the request
193 | data : dict, optional
194 | The request body to be sent with the request
195 | json : dict, optional
196 | The JSON-encoded request body to be sent with the request
197 | ignore_auth_error: bool, optional
198 | Whether to ignore authentication errors (default is False)
199 | **kwargs : dict
200 | Additional arguments to pass to the requests.request function
201 |
202 | Returns
203 | -------
204 | requests.Response
205 | The response object returned by the requests library
206 | """
207 | url = f"{ENDPOINT_URL}{path}"
208 | headers = await cls._get_headers(use_cli_key)
209 | try:
210 | async with httpx.AsyncClient(http2=True) as _client:
211 | response = await _client.request(
212 | method,
213 | url,
214 | headers=headers,
215 | params=params,
216 | data=data,
217 | json=json,
218 | **kwargs,
219 | )
220 | if not response:
221 | print(f"[red]Error: {response}[/red]")
222 | if response.status_code == 200:
223 | return response
224 | elif response.status_code == 403:
225 | if not ignore_auth_error:
226 | print("[red]Authentication failed.[/red]")
227 | else:
228 | print(f"[red]Error: {response}[/red]")
229 |
230 | return response
231 | except requests.exceptions.ConnectionError:
232 | print("[red]Could not connect to the Promptmodel API.[/red]")
233 | except requests.exceptions.Timeout:
234 | print("[red]The request timed out.[/red]")
235 | except Exception as exception:
236 | print(f"[red]Error: {exception}[/red]")
237 |
--------------------------------------------------------------------------------
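
For reference, the call shape used throughout the CLI (e.g. in `cli/commands/configure.py`) looks like the sketch below. The synchronous client relies on the encrypted key stored by `prompt login`; the async client can instead read `PROMPTMODEL_API_KEY` from the environment when `use_cli_key=False`. Paths and parameter names are taken from the CLI call sites.

```python
# Sketch of calling the API clients the same way the CLI commands do.
import asyncio
from promptmodel.apis.base import APIClient, AsyncAPIClient

# Synchronous call, authenticated with the CLI's encrypted key (requires `prompt login`).
orgs = APIClient.execute(method="GET", path="/organizations").json()

async def fetch_projects(organization_id: str):
    # Async call, authenticated via the PROMPTMODEL_API_KEY environment variable.
    res = await AsyncAPIClient.execute(
        method="GET",
        path="/projects",
        params={"organization_id": organization_id},
        use_cli_key=False,
    )
    return res.json()

projects = asyncio.run(fetch_projects(orgs[0]["organization_id"]))
```
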
/promptmodel/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/configure.py:
--------------------------------------------------------------------------------
1 | from promptmodel.apis.base import APIClient
2 | import typer
3 | from InquirerPy import inquirer
4 |
5 | from promptmodel.utils.config_utils import upsert_config
6 |
7 |
8 | def configure():
9 | """Saves user's default organization and project."""
10 |
11 | orgs = APIClient.execute(method="GET", path="/organizations").json()
12 | choices = [
13 | {
14 | "key": org["name"],
15 | "name": org["name"],
16 | "value": org,
17 | }
18 | for org in orgs
19 | ]
20 | org = inquirer.select(
21 | message="Select default organization:", choices=choices
22 | ).execute()
23 |
24 | projects = APIClient.execute(
25 | method="GET",
26 | path="/projects",
27 | params={"organization_id": org["organization_id"]},
28 | ).json()
29 | choices = [
30 | {
31 | "key": project["name"],
32 | "name": project["name"],
33 | "value": project,
34 | }
35 | for project in projects
36 | ]
37 | project = inquirer.select(
38 | message="Select default project:", choices=choices
39 | ).execute()
40 |
41 | upsert_config({"org": org, "project": project}, section="connection")
42 |
43 |
44 | app = typer.Typer(invoke_without_command=True, callback=configure)
45 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/connect.py:
--------------------------------------------------------------------------------
1 | import time
2 | import asyncio
3 | import typer
4 | import importlib
5 | import signal
6 | from typing import Dict, Any, List
7 | from playhouse.shortcuts import model_to_dict
8 |
9 | import webbrowser
10 | from rich import print
11 | from InquirerPy import inquirer
12 | from watchdog.observers import Observer
13 |
14 | from promptmodel import DevApp
15 | from promptmodel.apis.base import APIClient
16 | from promptmodel.constants import ENDPOINT_URL, WEB_CLIENT_URL
17 | from promptmodel.cli.commands.init import init as promptmodel_init
18 | from promptmodel.cli.utils import get_org, get_project
19 | from promptmodel.cli.signal_handler import dev_terminate_signal_handler
20 | from promptmodel.utils.config_utils import read_config, upsert_config
21 | from promptmodel.websocket import DevWebsocketClient, CodeReloadHandler
22 | from promptmodel.database.orm import initialize_db
23 |
24 |
25 | def connect():
26 | """Connect websocket and opens up DevApp in the browser."""
27 | upsert_config({"initializing": True}, "connection")
28 | signal.signal(signal.SIGINT, dev_terminate_signal_handler)
29 | promptmodel_init(from_cli=False)
30 |
31 | config = read_config()
32 |
33 | if "project" not in config["connection"]:
34 | org = get_org(config)
35 | project = get_project(config=config, org=org)
36 |
37 | # connect
38 | res = APIClient.execute(
39 | method="POST",
40 | path="/project/cli_connect",
41 | params={"project_uuid": project["uuid"]},
42 | )
43 | if res.status_code != 200:
44 | print(f"Error: {res.json()['detail']}")
45 | return
46 |
47 | upsert_config(
48 | {
49 | "project": project,
50 | "org": org,
51 | },
52 | section="connection",
53 | )
54 |
55 | else:
56 | org = config["connection"]["org"]
57 | project = config["connection"]["project"]
58 |
59 | res = APIClient.execute(
60 | method="POST",
61 | path="/project/cli_connect",
62 | params={"project_uuid": project["uuid"]},
63 | )
64 | if res.status_code != 200:
65 | print(f"Error: {res.json()['detail']}")
66 | return
67 |
68 | _devapp_filename, devapp_instance_name = "promptmodel_dev:app".split(":")
69 |
70 | devapp_module = importlib.import_module(_devapp_filename)
71 | devapp_instance: DevApp = getattr(devapp_module, devapp_instance_name)
72 |
73 | dev_url = f"{WEB_CLIENT_URL}/org/{org['slug']}/projects/{project['uuid']}"
74 |
75 | # Open websocket connection to backend server
76 | dev_websocket_client = DevWebsocketClient(_devapp=devapp_instance)
77 |
78 | import threading
79 |
80 | try:
81 | main_loop = asyncio.get_running_loop()
82 | except RuntimeError:
83 | main_loop = asyncio.new_event_loop()
84 | asyncio.set_event_loop(main_loop)
85 |
86 | # Save the samples, FunctionSchemas, FunctionModels, and ChatModels defined in code to the cloud server
87 | res = APIClient.execute(
88 | method="POST",
89 | path="/save_instances_in_code",
90 | params={"project_uuid": project["uuid"]},
91 | json={
92 | "function_models": devapp_instance._get_function_model_name_list(),
93 | "chat_models": devapp_instance._get_chat_model_name_list(),
94 | "function_schemas": devapp_instance._get_function_schema_list(),
95 | "samples": devapp_instance.samples,
96 | },
97 | )
98 | if res.status_code != 200:
99 | print(f"Error: {res.json()['detail']}")
100 | return
101 |
102 | print(
103 | f"\nOpening [violet]Promptmodel[/violet] prompt engineering environment with the following configuration:\n"
104 | )
105 | print(f"📌 Organization: [blue]{org['name']}[/blue]")
106 | print(f"📌 Project: [blue]{project['name']}[/blue]")
107 | print(
108 | f"\nIf browser doesn't open automatically, please visit [link={dev_url}]{dev_url}[/link]"
109 | )
110 | webbrowser.open(dev_url)
111 |
112 | upsert_config({"online": True, "initializing": False}, section="connection")
113 |
114 | reloader_thread = threading.Thread(
115 | target=start_code_reloader,
116 | args=(_devapp_filename, devapp_instance_name, dev_websocket_client, main_loop),
117 | )
118 | reloader_thread.daemon = True # Set the thread as a daemon
119 | reloader_thread.start()
120 |
121 | # Open Websocket
122 | asyncio.run(
123 | dev_websocket_client.connect_to_gateway(
124 | project_uuid=project["uuid"],
125 | connection_name=project["name"],
126 | cli_access_header=APIClient._get_headers(),
127 | )
128 | )
129 |
130 |
131 | app = typer.Typer(invoke_without_command=True, callback=connect)
132 |
133 |
134 | def start_code_reloader(
135 | _devapp_filename, devapp_instance_name, dev_websocket_client, main_loop
136 | ):
137 | event_handler = CodeReloadHandler(
138 | _devapp_filename, devapp_instance_name, dev_websocket_client, main_loop
139 | )
140 | observer = Observer()
141 | observer.schedule(event_handler, path=".", recursive=True)
142 | observer.start()
143 |
144 | try:
145 | while True:
146 | time.sleep(1)
147 | except KeyboardInterrupt:
148 | observer.stop()
149 | observer.join()
150 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/fix.py:
--------------------------------------------------------------------------------
1 | import typer
2 | import os
3 | from promptmodel.utils.config_utils import read_config, upsert_config
4 | from promptmodel.database.orm import initialize_db
5 |
6 |
7 | def fix():
8 | """Fix Bugs that can be occured in the promptmodel."""
9 | config = read_config()
10 | if "connection" in config:
11 | connection = config["connection"]
12 | if "initializing" in connection and connection["initializing"] == True:
13 | upsert_config({"initializing": False}, "connection")
14 | if "online" in connection and connection["online"] == True:
15 | upsert_config({"online": False}, "connection")
16 | if "reloading" in connection and connection["reloading"] == True:
17 | upsert_config({"reloading": False}, "connection")
18 | # delete .promptmodel/promptmodel.db
19 | # if .promptmodel/promptmodel.db exists, delete it
20 | if os.path.exists(".promptmodel/promptmodel.db"):
21 | os.remove(".promptmodel/promptmodel.db")
22 | # initialize_db()
23 |
24 | return
25 |
26 |
27 | app = typer.Typer(invoke_without_command=True, callback=fix)
28 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/init.py:
--------------------------------------------------------------------------------
1 | from importlib import resources
2 | import typer
3 | from rich import print
4 | from promptmodel.constants import (
5 | PROMPTMODEL_DEV_FILENAME,
6 | PROMPTMODEL_DEV_STARTER_FILENAME,
7 | )
8 |
9 |
10 | def init(from_cli: bool = True):
11 | """Initialize a new promptmodel project."""
12 | import os
13 |
14 | if not os.path.exists(PROMPTMODEL_DEV_FILENAME):
15 | # Read the content from the source file
16 | content = resources.read_text("promptmodel", PROMPTMODEL_DEV_STARTER_FILENAME)
17 |
18 | # Write the content to the target file
19 | with open(PROMPTMODEL_DEV_FILENAME, "w") as target_file:
20 | target_file.write(content)
21 | print(
22 | "[violet][bold]promptmodel_dev.py[/bold][/violet] was successfully created!"
23 | )
24 | print(
25 | "Add promptmodels in your code, then run [violet][bold]prompt dev[/bold][/violet] to start engineering prompts."
26 | )
27 | elif from_cli:
28 | print(
29 | "[yellow]promptmodel_dev.py[/yellow] was already initialized in this directory."
30 | )
31 | print(
32 | "Run [violet][bold]prompt dev[/bold][/violet] to start engineering prompts."
33 | )
34 |
35 |
36 | app = typer.Typer(invoke_without_command=True, callback=init)
37 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/login.py:
--------------------------------------------------------------------------------
1 | import time
2 | from promptmodel.apis.base import APIClient
3 | import typer
4 | import webbrowser
5 | from rich import print
6 |
7 | from promptmodel.constants import ENDPOINT_URL, GRANT_ACCESS_URL
8 | from promptmodel.utils.config_utils import upsert_config
9 | from promptmodel.utils.crypto import generate_api_key, encrypt_message
10 |
11 |
12 | def login():
13 | """Authenticate Client CLI."""
14 | # TODO: Check if already logged in
15 | api_key = generate_api_key()
16 | encrypted_key = encrypt_message(api_key)
17 | upsert_config({"encrypted_api_key": encrypted_key}, section="connection")
18 | url = f"{GRANT_ACCESS_URL}?token={api_key}"
19 | webbrowser.open(url)
20 | print("Please grant access to the CLI by visiting the URL in your browser.")
21 | print("Once you have granted access, you can close the browser tab.")
22 | print(f"\nURL: [link={url}]{url}[/link]\n")
23 | print("Waiting...\n")
24 | waiting_time = 0
25 | while waiting_time < 300:
26 | # Check access every 5 seconds
27 | try:
28 | res = APIClient.execute("/cli_access/check", ignore_auth_error=True)
29 | if res.json() == True:
30 | print("[green]Access granted![/green] 🎉")
31 | print(
32 | "Run [violet][bold]prompt init[/bold][/violet] to start developing prompts.\n"
33 | )
34 | return
35 | except Exception as err:
36 | print(f"[red]Error: {err}[/red]")
37 | time.sleep(5)
38 | waiting_time += 5
39 | print("Please try again later.")
40 |
41 |
42 | app = typer.Typer(invoke_without_command=True, callback=login)
43 |
--------------------------------------------------------------------------------
/promptmodel/cli/commands/project.py:
--------------------------------------------------------------------------------
1 | """CLI for project management."""
2 |
3 | from promptmodel.apis.base import APIClient
4 | import typer
5 | from rich import print
6 |
7 | from promptmodel.utils.config_utils import read_config, upsert_config
8 | from ..utils import get_org
9 |
10 | app = typer.Typer(no_args_is_help=True, short_help="Manage Client projects.")
11 |
12 |
13 | @app.command()
14 | def list():
15 | """List all projects."""
16 | config = read_config()
17 | org = get_org(config)
18 | projects = APIClient.execute(
19 | method="GET",
20 | path="/projects",
21 | params={"organization_id": org["organization_id"]},
22 | ).json()
23 | print("\nProjects:")
24 | for project in projects:
25 | print(f"📌 {project['name']} ({project['version']})")
26 | if project["description"]:
27 | print(f" {project['description']}\n")
28 |
--------------------------------------------------------------------------------
/promptmodel/cli/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import typer
4 | from promptmodel.cli.commands.login import app as login
5 | from promptmodel.cli.commands.init import app as init
6 |
7 | # from promptmodel.cli.commands.dev import app as dev
8 | from promptmodel.cli.commands.connect import app as connect
9 | from promptmodel.cli.commands.project import app as project
10 | from promptmodel.cli.commands.configure import app as configure
11 | from promptmodel.cli.commands.fix import app as fix
12 |
13 | # Add the current working directory to sys.path
14 | current_working_directory = os.getcwd()
15 | if current_working_directory not in sys.path:
16 | sys.path.append(current_working_directory)
17 |
18 | app = typer.Typer(no_args_is_help=True, pretty_exceptions_enable=False)
19 |
20 | app.add_typer(login, name="login")
21 | app.add_typer(init, name="init")
22 | # app.add_typer(dev, name="dev")
23 | app.add_typer(connect, name="connect")
24 | app.add_typer(project, name="project")
25 | app.add_typer(configure, name="configure")
26 | app.add_typer(fix, name="fix")
27 |
28 | if __name__ == "__main__":
29 | app()
30 |
--------------------------------------------------------------------------------
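
`main.py` simply mounts each command module on a single Typer app. The CLI's own messages refer to the executable as `prompt` (e.g. "prompt login", "prompt init"); assuming that is the console-script name defined in `setup.py` (not shown in this section), a typical session looks like:

```bash
prompt login      # authenticate the CLI and store an encrypted API key
prompt init       # create promptmodel_dev.py from the STARTER.py template
prompt configure  # choose a default organization and project
prompt connect    # sync code instances and open the dashboard over a websocket
prompt fix        # reset stale connection flags and the local cache DB
```
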
/promptmodel/cli/signal_handler.py:
--------------------------------------------------------------------------------
1 | import signal
2 | import sys
3 | from promptmodel.utils.config_utils import upsert_config, read_config
4 |
5 |
6 | def dev_terminate_signal_handler(sig, frame):
7 | config = read_config()
8 | print("\nTerminating...")
9 | if "connection" in config:
10 | upsert_config({"online": False}, section="connection")
11 | upsert_config({"initializing": False}, "connection")
12 | sys.exit(0)
13 |
--------------------------------------------------------------------------------
/promptmodel/cli/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 | from InquirerPy import inquirer
3 | from rich import print
4 | from promptmodel.apis.base import APIClient
5 |
6 |
7 | def get_org(config: Dict[str, Any]) -> Dict[str, Any]:
8 | """
9 | Gets the current organization from the configuration.
10 |
11 | :return: A dictionary containing the current organization.
12 | """
13 | if "connection" not in config:
14 | print("User not logged in. Please run [violet]prompt login[/violet] first.")
15 | exit()
16 | if "org" not in config["connection"]:
17 | orgs = APIClient.execute(method="GET", path="/organizations").json()
18 | choices = [
19 | {
20 | "key": org["name"],
21 | "name": org["name"],
22 | "value": org,
23 | }
24 | for org in orgs
25 | ]
26 | org = inquirer.select(message="Select organization:", choices=choices).execute()
27 | else:
28 | org = config["connection"]["org"]
29 | return org
30 |
31 |
32 | def get_project(config: Dict[str, Any], org: Dict[str, Any]) -> Dict[str, Any]:
33 | """
34 | Gets the current project from the configuration.
35 |
36 | :return: A dictionary containing the current project.
37 | """
38 | if "project" not in config["connection"]:
39 | projects = APIClient.execute(
40 | method="GET",
41 | path="/projects",
42 | params={"organization_id": org["organization_id"]},
43 | ).json()
44 | choices = [
45 | {
46 | "key": project["name"],
47 | "name": project["name"],
48 | "value": project,
49 | }
50 | for project in projects
51 | ]
52 | project = inquirer.select(message="Select project:", choices=choices).execute()
53 | else:
54 | project = config["connection"]["project"]
55 | return project
56 |
--------------------------------------------------------------------------------
/promptmodel/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 |
4 | load_dotenv()
5 | testmode: str = os.environ.get("TESTMODE", "false")
6 |
7 | if testmode == "true":
8 | ENDPOINT_URL = (
9 | os.environ.get(
10 | "TESTMODE_PROMPTMODEL_BACKEND_PUBLIC_URL", "http://localhost:8000"
11 | )
12 | + "/api/cli"
13 | )
14 | WEB_CLIENT_URL = os.environ.get(
15 | "TESTMODE_PROMPTMODEL_FRONTEND_PUBLIC_URL", "http://localhost:3000"
16 | )
17 | GRANT_ACCESS_URL = WEB_CLIENT_URL + "/cli/grant-access"
18 | else:
19 | ENDPOINT_URL = (
20 | os.environ.get(
21 | "PROMPTMODEL_BACKEND_PUBLIC_URL", "https://promptmodel.up.railway.app"
22 | )
23 | + "/api/cli"
24 | )
25 |
26 | WEB_CLIENT_URL = os.environ.get(
27 | "PROMPTMODEL_FRONTEND_PUBLIC_URL", "https://app.promptmodel.run"
28 | )
29 | GRANT_ACCESS_URL = WEB_CLIENT_URL + "/cli/grant-access"
30 |
31 | PROMPTMODEL_DEV_FILENAME = os.path.join(os.getcwd(), "promptmodel_dev.py")
32 | PROMPTMODEL_DEV_STARTER_FILENAME = "STARTER.py"
33 |
--------------------------------------------------------------------------------
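
The endpoints are resolved from environment variables loaded with python-dotenv, with `TESTMODE=true` switching to the localhost defaults. A hedged `.env` sketch using only the variable names read above (the values shown are just the defaults repeated):

```bash
# .env -- example values; the variable names are the ones read in constants.py
TESTMODE=true
TESTMODE_PROMPTMODEL_BACKEND_PUBLIC_URL=http://localhost:8000
TESTMODE_PROMPTMODEL_FRONTEND_PUBLIC_URL=http://localhost:3000

# Used when TESTMODE is not "true":
# PROMPTMODEL_BACKEND_PUBLIC_URL=https://promptmodel.up.railway.app
# PROMPTMODEL_FRONTEND_PUBLIC_URL=https://app.promptmodel.run
```
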
/promptmodel/database/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
--------------------------------------------------------------------------------
/promptmodel/database/config.py:
--------------------------------------------------------------------------------
1 | from peewee import SqliteDatabase, Model
2 |
3 | db = SqliteDatabase("./.promptmodel/promptmodel.db")
4 |
5 |
6 | class BaseModel(Model):
7 | class Meta:
8 | database = db
9 |
--------------------------------------------------------------------------------
/promptmodel/database/crud.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple
2 | from promptmodel.database.models import (
3 | DeployedFunctionModel,
4 | DeployedFunctionModelVersion,
5 | DeployedPrompt,
6 | )
7 | from playhouse.shortcuts import model_to_dict
8 | from promptmodel.utils.random_utils import select_version_by_ratio
9 | from promptmodel.utils import logger
10 | from promptmodel.database.config import db
11 | from promptmodel.database.crud_chat import *
12 |
13 |
14 | # Insert
15 |
16 | # Select all
17 |
18 | # Select one
19 |
20 |
21 | def get_deployed_prompts(function_model_name: str) -> Tuple[List[DeployedPrompt], Dict[str, Any]]:
22 | try:
23 | with db.atomic():
24 | versions: List[DeployedFunctionModelVersion] = list(
25 | DeployedFunctionModelVersion.select()
26 | .join(DeployedFunctionModel)
27 | .where(
28 | DeployedFunctionModelVersion.function_model_uuid
29 | == DeployedFunctionModel.get(
30 | DeployedFunctionModel.name == function_model_name
31 | ).uuid
32 | )
33 | )
34 | prompts: List[DeployedPrompt] = list(
35 | DeployedPrompt.select()
36 | .where(
37 | DeployedPrompt.version_uuid.in_(
38 | [version.uuid for version in versions]
39 | )
40 | )
41 | .order_by(DeployedPrompt.step.asc())
42 | )
43 | # select version by ratio
44 | selected_version = select_version_by_ratio(
45 | [version.__data__ for version in versions]
46 | )
47 | selected_prompts = list(
48 | filter(
49 | lambda prompt: str(prompt.version_uuid.uuid)
50 | == str(selected_version["uuid"]),
51 | prompts,
52 | )
53 | )
54 |
55 | version_details = {
56 | "model": selected_version["model"],
57 | "version" : selected_version["version"],
58 | "uuid": selected_version["uuid"],
59 | "parsing_type": selected_version["parsing_type"],
60 | "output_keys": selected_version["output_keys"],
61 | }
62 |
63 | return selected_prompts, version_details
64 | except Exception as e:
65 | logger.error(e)
66 | return None, None
67 |
68 |
69 | # Update
70 |
71 |
72 | async def update_deployed_cache(project_status: dict):
73 | """Update Deployed Prompts Cache"""
74 | # TODO: make this more efficient
75 | # (currently deletes everything and re-inserts)
76 | function_models = project_status["function_models"]
77 | function_model_versions = project_status["function_model_versions"]
78 | for version in function_model_versions:
79 | if version["is_published"] is True:
80 | version["ratio"] = 1.0
81 | prompts = project_status["prompts"]
82 |
83 | with db.atomic():
84 | DeployedFunctionModel.delete().execute()
85 | DeployedFunctionModelVersion.delete().execute()
86 | DeployedPrompt.delete().execute()
87 | DeployedFunctionModel.insert_many(function_models).execute()
88 | DeployedFunctionModelVersion.insert_many(function_model_versions).execute()
89 | DeployedPrompt.insert_many(prompts).execute()
90 | return
91 |
--------------------------------------------------------------------------------
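
Taken together with `initialize_db()` from `database/orm.py`, the local cache is used roughly as sketched below. The `project_status` payload shows only the keys that `update_deployed_cache` reads; the field values are invented for illustration and assume the column names defined in `database/models.py`.

```python
# Illustrative use of the local deployed-prompt cache; the payload values are made up.
import asyncio
from uuid import uuid4

from promptmodel.database.orm import initialize_db
from promptmodel.database.crud import update_deployed_cache, get_deployed_prompts

initialize_db()  # creates ./.promptmodel/promptmodel.db and the Deployed* tables

model_uuid, version_uuid = str(uuid4()), str(uuid4())
project_status = {
    "function_models": [{"uuid": model_uuid, "name": "summarizer"}],
    "function_model_versions": [
        {
            "uuid": version_uuid,
            "version": 1,
            "function_model_uuid": model_uuid,
            "model": "gpt-3.5-turbo",
            "is_published": True,  # published versions get ratio=1.0 in update_deployed_cache
        }
    ],
    "prompts": [
        {"version_uuid": version_uuid, "role": "system", "step": 1, "content": "You are helpful."}
    ],
}

asyncio.run(update_deployed_cache(project_status))
prompts, version_details = get_deployed_prompts("summarizer")
```
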
/promptmodel/database/crud_chat.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple, Any
2 | from uuid import UUID
3 | from playhouse.shortcuts import model_to_dict
4 | from peewee import fn, JOIN
5 |
6 | from promptmodel.utils import logger
7 | from promptmodel.database.config import db
8 |
9 |
10 | # def delete_fake_sessions():
11 | # """Delete ChatLogSession which len(ChatLog) == 1 where ChatLog.session_uuid == ChatLogSession.uuid"""
12 | # with db.atomic():
13 | # sessions_to_delete: List[ChatLogSession] = list(
14 | # ChatLogSession.select()
15 | # .join(ChatLog, JOIN.LEFT_OUTER)
16 | # .group_by(ChatLogSession)
17 | # .having(fn.COUNT(ChatLog.id) <= 1)
18 | # )
19 | # (
20 | # (
21 | # ChatLogSession.delete().where(
22 | # ChatLogSession.uuid.in_(
23 | # [session.uuid for session in sessions_to_delete]
24 | # )
25 | # )
26 | # ).execute()
27 | # )
28 | # return
29 |
30 |
31 | # def hide_chat_model_not_in_code(local_chat_model_list: List):
32 | # return (
33 | # ChatModel.update(used_in_code=False)
34 | # .where(ChatModel.name.not_in(local_chat_model_list))
35 | # .execute()
36 | # )
37 |
38 |
39 | # def update_chat_model_uuid(local_uuid, new_uuid):
40 | # """Update ChatModel.uuid"""
41 | # if str(local_uuid) == str(new_uuid):
42 | # return
43 | # else:
44 | # with db.atomic():
45 | # local_chat_model: ChatModel = ChatModel.get(ChatModel.uuid == local_uuid)
46 | # ChatModel.create(
47 | # uuid=new_uuid,
48 | # name=local_chat_model.name,
49 | # project_uuid=local_chat_model.project_uuid,
50 | # created_at=local_chat_model.created_at,
51 | # used_in_code=local_chat_model.used_in_code,
52 | # is_deployed=True,
53 | # )
54 | # ChatModelVersion.update(function_model_uuid=new_uuid).where(
55 | # ChatModelVersion.chat_model_uuid == local_uuid
56 | # ).execute()
57 | # ChatModel.delete().where(ChatModel.uuid == local_uuid).execute()
58 | # return
59 |
60 |
61 | # def update_candidate_chat_model_version(new_candidates: dict):
62 | # """Update candidate ChatModelVersion's candidate_version_id(version)"""
63 | # with db.atomic():
64 | # for uuid, version in new_candidates.items():
65 | # (
66 | # ChatModelVersion.update(version=version, is_deployed=True)
67 | # .where(ChatModelVersion.uuid == uuid)
68 | # .execute()
69 | # )
70 | # # Find ChatModel
71 | # chat_model_versions: List[ChatModelVersion] = list(
72 | # ChatModelVersion.select().where(
73 | # ChatModelVersion.uuid.in_(list(new_candidates.keys()))
74 | # )
75 | # )
76 | # chat_model_uuids = [
77 | # chat_model.chat_model_uuid.uuid for chat_model in chat_model_versions
78 | # ]
79 | # ChatModel.update(is_deployed=True).where(
80 | # ChatModel.uuid.in_(chat_model_uuids)
81 | # ).execute()
82 |
83 |
84 | # def rename_chat_model(chat_model_uuid: str, new_name: str):
85 | # """Update the name of the given ChatModel."""
86 | # return (
87 | # ChatModel.update(name=new_name)
88 | # .where(ChatModel.uuid == chat_model_uuid)
89 | # .execute()
90 | # )
91 |
92 |
93 | # def fetch_chat_log_with_uuid(session_uuid: str):
94 | # try:
95 | # try:
96 | # chat_logs: List[ChatLog] = list(
97 | # ChatLog.select()
98 | # .where(ChatLog.session_uuid == UUID(session_uuid))
99 | # .order_by(ChatLog.created_at.asc())
100 | # )
101 | # except:
102 | # return []
103 |
104 | # chat_log_to_return = [
105 | # {
106 | # "role": chat_log.role,
107 | # "content": chat_log.content,
108 | # "tool_calls": chat_log.tool_calls,
109 | # }
110 | # for chat_log in chat_logs
111 | # ]
112 | # return chat_log_to_return
113 | # except Exception as e:
114 | # logger.error(e)
115 | # return None
116 |
117 |
118 | # def get_latest_version_chat_model(
119 | # chat_model_name: str,
120 | # session_uuid: Optional[str] = None,
121 | # ) -> Tuple[List[Dict[str, Any]], str]:
122 | # try:
123 | # if session_uuid:
124 | # if type(session_uuid) == str:
125 | # session_uuid = UUID(session_uuid)
126 | # session: ChatLogSession = ChatLogSession.get(
127 | # ChatLogSession.uuid == session_uuid
128 | # )
129 |
130 | # version: ChatModelVersion = ChatModelVersion.get(
131 | # ChatModelVersion.uuid == session.version_uuid
132 | # )
133 |
134 | # instruction: List[Dict[str, Any]] = [version["system_prompt"]]
135 | # else:
136 | # with db.atomic():
137 | # try:
138 | # sessions_with_version: List[ChatLogSession] = list(
139 | # ChatLogSession.select()
140 | # .join(ChatModelVersion)
141 | # .where(
142 | # ChatModelVersion.chat_model_uuid
143 | # == ChatModel.get(ChatModel.name == chat_model_name).uuid
144 | # )
145 | # )
146 | # session_uuids = [x.uuid for x in sessions_with_version]
147 |
148 | # latest_chat_log: List[ChatLog] = list(
149 | # ChatLog.select()
150 | # .where(ChatLog.session_uuid.in_(session_uuids))
151 | # .order_by(ChatLog.created_at.desc())
152 | # )
153 |
154 | # latest_chat_log: ChatLog = latest_chat_log[0]
155 | # latest_session: ChatLogSession = ChatLogSession.get(
156 | # ChatLogSession.uuid == latest_chat_log.session_uuid
157 | # )
158 |
159 | # version: ChatModelVersion = (
160 | # ChatModelVersion.select()
161 | # .where(ChatModelVersion.uuid == latest_session.uuid)
162 | # .get()
163 | # )
164 | # except:
165 | # version: ChatModelVersion = list(
166 | # ChatModelVersion.select()
167 | # .join(ChatModel)
168 | # .where(ChatModel.name == chat_model_name)
169 | # .order_by(ChatModelVersion.created_at.desc())
170 | # .get()
171 | # )
172 |
173 | # instruction: List[Dict[str, Any]] = [version["system_prompt"]]
174 |
175 | # version_details = {
176 | # "model": version.model,
177 | # "uuid": version.uuid,
178 | # }
179 |
180 | # return instruction, version_details
181 |
182 | # except Exception as e:
183 | # logger.error(e)
184 | # return None, None
185 |
186 |
187 | # def find_ancestor_chat_model_version(
188 | # chat_model_version_uuid: str, versions: Optional[list] = None
189 | # ):
190 | # """Find ancestor ChatModel version"""
191 |
192 | # # get all versions
193 | # if versions is None:
194 | # response = list(ChatModelVersion.select())
195 | # versions = [model_to_dict(x, recurse=False) for x in response]
196 |
197 | # # find target version
198 | # target = list(
199 | # filter(lambda version: version["uuid"] == chat_model_version_uuid, versions)
200 | # )[0]
201 |
202 | # target = _find_ancestor(target, versions)
203 |
204 | # return target
205 |
206 |
207 | # def find_ancestor_chat_model_versions(target_chat_model_uuid: Optional[str] = None):
208 | # """find ancestor versions for each versions in input"""
209 | # # get all versions
210 | # if target_chat_model_uuid is not None:
211 | # response = list(
212 | # ChatModelVersion.select().where(
213 | # ChatModelVersion.chat_model_uuid == target_chat_model_uuid
214 | # )
215 | # )
216 | # else:
217 | # response = list(ChatModelVersion.select())
218 | # versions = [model_to_dict(x, recurse=False) for x in response]
219 |
220 | # targets = list(
221 | # filter(
222 | # lambda version: version["status"] == ModelVersionStatus.CANDIDATE.value
223 | # and version["version"] is None,
224 | # versions,
225 | # )
226 | # )
227 |
228 | # targets = [
229 | # find_ancestor_chat_model_version(target["uuid"], versions) for target in targets
230 | # ]
231 | # targets_with_real_ancestor = [target for target in targets]
232 |
233 | # return targets_with_real_ancestor
234 |
235 |
236 | # def _find_ancestor(target: dict, versions: List[Dict]):
237 | # ancestor = None
238 | # temp = target
239 | # if target["from_uuid"] is None:
240 | # ancestor = None
241 | # else:
242 | # while temp["from_uuid"] is not None:
243 | # new_temp = [
244 | # version for version in versions if version["uuid"] == temp["from_uuid"]
245 | # ][0]
246 | # if (
247 | # new_temp["version"] is not None
248 | # or new_temp["status"] == ModelVersionStatus.CANDIDATE.value
249 | # ):
250 | # ancestor = new_temp
251 | # break
252 | # else:
253 | # temp = new_temp
254 | # target["from_uuid"] = ancestor["uuid"] if ancestor is not None else None
255 |
256 | # return target
257 |
--------------------------------------------------------------------------------
/promptmodel/database/models.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 |
4 | from uuid import uuid4
5 | from peewee import (
6 | CharField,
7 | DateTimeField,
8 | IntegerField,
9 | ForeignKeyField,
10 | BooleanField,
11 | UUIDField,
12 | TextField,
13 | AutoField,
14 | FloatField,
15 | Check,
16 | )
17 |
18 | from promptmodel.database.config import BaseModel
19 | from promptmodel.database.models_chat import *
20 | from promptmodel.types.enums import ParsingType
21 |
22 |
23 | class JSONField(TextField):
24 | def db_value(self, value):
25 | return json.dumps(value)
26 |
27 | def python_value(self, value):
28 | return json.loads(value)
29 |
30 |
31 | class DeployedFunctionModel(BaseModel):
32 | uuid = UUIDField(unique=True, default=uuid4)
33 | name = CharField()
34 |
35 |
36 | class DeployedFunctionModelVersion(BaseModel):
37 | uuid = UUIDField(unique=True, default=uuid4)
38 | version = IntegerField(null=False)
39 | from_version = IntegerField(null=True)
40 | function_model_uuid = ForeignKeyField(
41 | DeployedFunctionModel,
42 | field=DeployedFunctionModel.uuid,
43 | backref="versions",
44 | on_delete="CASCADE",
45 | )
46 | model = CharField()
47 | is_published = BooleanField(default=False)
48 | is_ab_test = BooleanField(default=False)
49 | ratio = FloatField(null=True)
50 | parsing_type = CharField(
51 | null=True,
52 | default=None,
53 | constraints=[
54 | Check(
55 | f"parsing_type IN ('{ParsingType.COLON.value}', '{ParsingType.SQUARE_BRACKET.value}', '{ParsingType.DOUBLE_SQUARE_BRACKET.value}')"
56 | )
57 | ],
58 | )
59 | output_keys = JSONField(null=True, default=None)
60 | functions = JSONField(default=[])
61 |
62 |
63 | class DeployedPrompt(BaseModel):
64 | id = AutoField()
65 | version_uuid = ForeignKeyField(
66 | DeployedFunctionModelVersion,
67 | field=DeployedFunctionModelVersion.uuid,
68 | backref="prompts",
69 | on_delete="CASCADE",
70 | )
71 | role = CharField()
72 | step = IntegerField()
73 | content = TextField()
74 |
--------------------------------------------------------------------------------
/promptmodel/database/models_chat.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 |
4 | from uuid import uuid4
5 | from peewee import (
6 | CharField,
7 | DateTimeField,
8 | IntegerField,
9 | ForeignKeyField,
10 | BooleanField,
11 | UUIDField,
12 | TextField,
13 | AutoField,
14 | FloatField,
15 | Check,
16 | )
17 |
18 |
19 | class JSONField(TextField):
20 | def db_value(self, value):
21 | return json.dumps(value)
22 |
23 | def python_value(self, value):
24 | return json.loads(value)
25 |
26 |
27 | # class DeployedChatModel(BaseModel):
28 | # uuid = UUIDField(unique=True, default=uuid4)
29 | # name = CharField()
30 |
31 |
32 | # class DeployedChatModelVersion(BaseModel):
33 | # uuid = UUIDField(unique=True, default=uuid4)
34 | # from_uuid = UUIDField(null=True)
35 | # chat_model_uuid = ForeignKeyField(
36 | # DeployedChatModel,
37 | # field=DeployedChatModel.uuid,
38 | # backref="versions",
39 | # on_delete="CASCADE",
40 | # )
41 | # model = CharField()
42 | # is_published = BooleanField(default=False)
43 | # is_ab_test = BooleanField(default=False)
44 | # ratio = FloatField(null=True)
45 | # system_prompt = JSONField(null=True, default={})
46 | # functions = JSONField(default=[])
47 |
--------------------------------------------------------------------------------
/promptmodel/database/orm.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .models import (
4 | DeployedFunctionModel,
5 | DeployedFunctionModelVersion,
6 | DeployedPrompt,
7 | )
8 | from .config import db
9 |
10 |
11 | def initialize_db():
12 | if not os.path.exists("./.promptmodel"):
13 | os.mkdir("./.promptmodel")
14 | # Check if db connection exists
15 | if db.is_closed():
16 | db.connect()
17 | with db.atomic():
18 | if not DeployedFunctionModel.table_exists():
19 | db.create_tables(
20 | [
21 | # FunctionModel,
22 | # FunctionModelVersion,
23 | # Prompt,
24 | # RunLog,
25 | # SampleInputs,
26 | DeployedFunctionModel,
27 | DeployedFunctionModelVersion,
28 | DeployedPrompt,
29 | # ChatModel,
30 | # ChatModelVersion,
31 | # ChatLogSession,
32 | # ChatLog,
33 | ]
34 | )
35 | db.close()
36 |
--------------------------------------------------------------------------------
/promptmodel/dev_app.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import dis
3 | import nest_asyncio
4 | from dataclasses import dataclass
5 | from typing import Callable, Dict, Any, List, Optional, Union
6 |
7 | from promptmodel.types.response import FunctionSchema
8 |
9 |
10 | @dataclass
11 | class FunctionModelInterface:
12 | name: str
13 |
14 |
15 | @dataclass
16 | class ChatModelInterface:
17 | name: str
18 |
19 |
20 | class DevClient:
21 | """DevClient main class"""
22 |
23 | def __init__(self):
24 | self.function_models: List[FunctionModelInterface] = []
25 | self.chat_models: List[ChatModelInterface] = []
26 |
27 | def register(self, func):
28 | instructions = list(dis.get_instructions(func))
29 | for idx in range(
30 | len(instructions) - 1
31 | ): # We check up to len-1 because we access idx+1 inside loop
32 | instruction = instructions[idx]
33 | # print(instruction)
34 | if instruction.opname in ["LOAD_ATTR", "LOAD_METHOD", "LOAD_GLOBAL"] and (
35 | instruction.argval == "FunctionModel"
36 | or instruction.argval == "ChatModel"
37 | ):
38 | next_instruction = instructions[idx + 1]
39 |
40 | # Check if the next instruction is LOAD_CONST with string value
41 | if next_instruction.opname == "LOAD_CONST" and isinstance(
42 | next_instruction.argval, str
43 | ):
44 | if instruction.argval == "FunctionModel":
45 | self.function_models.append(
46 | FunctionModelInterface(name=next_instruction.argval)
47 | )
48 | elif instruction.argval == "ChatModel":
49 | self.chat_models.append(
50 | ChatModelInterface(name=next_instruction.argval)
51 | )
52 |
53 | def wrapper(*args, **kwargs):
54 | return func(*args, **kwargs)
55 |
56 | return wrapper
57 |
58 | def register_function_model(self, name):
59 | for function_model in self.function_models:
60 | if function_model.name == name:
61 | return
62 |
63 | self.function_models.append(FunctionModelInterface(name=name))
64 |
65 | def register_chat_model(self, name):
66 | for chat_model in self.chat_models:
67 | if chat_model.name == name:
68 | return
69 |
70 | self.chat_models.append(ChatModelInterface(name=name))
71 |
72 | def _get_function_model_name_list(self) -> List[str]:
73 | return [function_model.name for function_model in self.function_models]
74 |
75 |
76 | class DevApp:
77 | _nest_asyncio_applied = False
78 |
79 | def __init__(self):
80 | self.function_models: List[FunctionModelInterface] = []
81 | self.chat_models: List[ChatModelInterface] = []
82 | self.samples: List[Dict[str, Any]] = []
83 | self.functions: Dict[
84 | str, Dict[str, Union[FunctionSchema, Optional[Callable]]]
85 | ] = {}
86 |
87 | if not DevApp._nest_asyncio_applied:
88 | DevApp._nest_asyncio_applied = True
89 | nest_asyncio.apply()
90 |
91 | def include_client(self, client: DevClient):
92 | self.function_models.extend(client.function_models)
93 | self.chat_models.extend(client.chat_models)
94 |
95 | def register_function(
96 | self, schema: Union[Dict[str, Any], FunctionSchema], function: Callable
97 | ):
98 | function_name = schema["name"] if isinstance(schema, dict) else schema.name
99 | if isinstance(schema, dict):
100 | try:
101 | schema = FunctionSchema(**schema)
102 | except:
103 | raise ValueError("schema is not a valid function call schema.")
104 |
105 | if function_name not in self.functions:
106 | self.functions[function_name] = {
107 | "schema": schema,
108 | "function": function,
109 | }
110 |
111 | def _call_register_function(self, name: str, arguments: Dict[str, str]):
112 | function_to_call: Optional[Callable] = self.functions[name]["function"]
113 | if not function_to_call:
114 | return
115 | try:
116 | function_response = function_to_call(**arguments)
117 | return function_response
118 | except Exception as e:
119 | raise e
120 |
121 | def _get_function_name_list(self) -> List[str]:
122 | return list(self.functions.keys())
123 |
124 | def _get_function_schema_list(self) -> List[Dict]:
125 | return [
126 | self.functions[function_name]["schema"].model_dump()
127 | for function_name in self._get_function_name_list()
128 | ]
129 |
130 | def _get_function_schemas(self, function_names: List[str] = []):
131 | try:
132 | function_schemas = [
133 | self.functions[function_name]["schema"].model_dump()
134 | for function_name in function_names
135 | ]
136 | return function_schemas
137 | except Exception as e:
138 | raise e
139 |
140 | def register_sample(self, name: str, content: Dict[str, Any]):
141 | self.samples.append({"name": name, "content": content})
142 |
143 | def _get_function_model_name_list(self) -> List[str]:
144 | return [function_model.name for function_model in self.function_models]
145 |
146 | def _get_chat_model_name_list(self) -> List[str]:
147 | return [chat_model.name for chat_model in self.chat_models]
148 |
--------------------------------------------------------------------------------
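
Besides the bytecode-scanning `DevClient.register` decorator, `DevApp.register_function` is how local Python functions are exposed for function calling on the dashboard. A sketch follows, assuming `FunctionSchema` (defined in `types/response.py`, not shown in this section) accepts the OpenAI function-calling layout of name/description/parameters.

```python
# Sketch of DevApp.register_function; the schema layout is an assumption about FunctionSchema.
from promptmodel import DevApp

def get_weather(location: str) -> str:
    """Toy local function used only for this sketch."""
    return f"It is sunny in {location}."

weather_schema = {
    "name": "get_weather",
    "description": "Get the current weather for a location.",
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
}

app = DevApp()
app.register_function(weather_schema, get_weather)

app._get_function_schema_list()  # schemas sent to the server by `prompt connect`
app._call_register_function("get_weather", {"location": "Seoul"})  # -> "It is sunny in Seoul."
```
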
/promptmodel/llms/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
--------------------------------------------------------------------------------
/promptmodel/llms/llm_dev.py:
--------------------------------------------------------------------------------
1 | """LLM for Development TestRun"""
2 |
3 | import re
4 | import json
5 | from datetime import datetime
6 | from typing import Any, AsyncGenerator, List, Dict, Optional
7 | from pydantic import BaseModel
8 | from dotenv import load_dotenv
9 | from litellm import acompletion, get_max_tokens
10 |
11 | from promptmodel.types.enums import ParsingType, get_pattern_by_type
12 | from promptmodel.utils import logger
13 | from promptmodel.utils.output_utils import convert_str_to_type
14 | from promptmodel.utils.token_counting import (
15 | num_tokens_for_messages,
16 | num_tokens_for_messages_for_each,
17 | num_tokens_from_functions_input,
18 | )
19 | from promptmodel.types.response import (
20 | LLMStreamResponse,
21 | ModelResponse,
22 | Usage,
23 | Choices,
24 | Message,
25 | )
26 |
27 | load_dotenv()
28 |
29 |
30 | class OpenAIMessage(BaseModel):
31 | role: Optional[str] = None
32 | content: Optional[str] = ""
33 | function_call: Optional[Dict[str, Any]] = None
34 | name: Optional[str] = None
35 |
36 |
37 | class LLMDev:
38 | def __init__(self):
39 | self._model: str
40 |
41 | def __validate_openai_messages(
42 | self, messages: List[Dict[str, Any]]
43 | ) -> List[OpenAIMessage]:
44 | """Validate and convert list of dictionaries to list of OpenAIMessage."""
45 | res = []
46 | for message in messages:
47 | res.append(OpenAIMessage(**message))
48 | return res
49 |
50 | async def dev_run(
51 | self,
52 | messages: List[Dict[str, Any]],
53 | parsing_type: Optional[ParsingType] = None,
54 | functions: Optional[List[Any]] = None,
55 | model: Optional[str] = None,
56 | **kwargs,
57 | ) -> AsyncGenerator[Any, None]:
58 | """Parse & stream output from openai chat completion."""
59 | _model = model or self._model
60 | raw_output = ""
61 | if functions == []:
62 | functions = None
63 |
64 | start_time = datetime.now()
65 |
66 | response: AsyncGenerator[ModelResponse, None] = await acompletion(
67 | model=_model,
68 | messages=[
69 | message.model_dump(exclude_none=True)
70 | for message in self.__validate_openai_messages(messages)
71 | ],
72 | stream=True,
73 | functions=functions,
74 | response_format=(
75 | {"type": "json_object"}
76 | if parsing_type == ParsingType.JSON
77 | else {"type": "text"}
78 | ),
79 | **kwargs,
80 | )
81 | function_call = {"name": "", "arguments": ""}
82 | finish_reason_function_call = False
83 | async for chunk in response:
84 | if getattr(chunk.choices[0].delta, "content", None) is not None:
85 | stream_value = chunk.choices[0].delta.content
86 | raw_output += stream_value # append raw output
87 | yield LLMStreamResponse(raw_output=stream_value) # return raw output
88 |
89 | if getattr(chunk.choices[0].delta, "function_call", None) is not None:
90 | for key, value in (
91 | chunk.choices[0].delta.function_call.model_dump().items()
92 | ):
93 | if value is not None:
94 | function_call[key] += value
95 |
96 | if chunk.choices[0].finish_reason == "function_call":
97 | finish_reason_function_call = True
98 | yield LLMStreamResponse(function_call=function_call)
99 |
100 | if chunk.choices[0].finish_reason is not None:
101 | end_time = datetime.now()
102 | response_ms = (end_time - start_time).total_seconds() * 1000
103 | yield LLMStreamResponse(
104 | api_response=self.make_model_response_dev(
105 | chunk,
106 | response_ms,
107 | messages,
108 | raw_output,
109 | functions=functions,
110 | function_call=(
111 | function_call
112 | if chunk.choices[0].finish_reason == "function_call"
113 | else None
114 | ),
115 | tools=None,
116 | tool_calls=None,
117 | )
118 | )
119 |
120 | # parsing
121 | if parsing_type and not finish_reason_function_call:
122 | if parsing_type == ParsingType.JSON:
123 | yield LLMStreamResponse(parsed_outputs=json.loads(raw_output))
124 | else:
125 | parsing_pattern: Dict[str, str] = get_pattern_by_type(parsing_type)
126 | whole_pattern = parsing_pattern["whole"]
127 | parsed_results = re.findall(whole_pattern, raw_output, flags=re.DOTALL)
128 | for parsed_result in parsed_results:
129 | key = parsed_result[0]
130 | type_str = parsed_result[1]
131 | value = convert_str_to_type(parsed_result[2], type_str)
132 | yield LLMStreamResponse(parsed_outputs={key: value})
133 |
134 | async def dev_chat(
135 | self,
136 | messages: List[Dict[str, Any]],
137 | functions: Optional[List[Any]] = None,
138 | tools: Optional[List[Any]] = None,
139 | model: Optional[str] = None,
140 | **kwargs,
141 | ) -> AsyncGenerator[LLMStreamResponse, None]:
142 | """Parse & stream output from openai chat completion."""
143 | _model = model or self._model
144 | raw_output = ""
145 | if functions == []:
146 | functions = None
147 |
148 | if _model != "HCX-002":
149 | # Truncate the input if it is too long:
150 | # trim messages so that their total length <= the model's max context length
151 | token_per_functions = num_tokens_from_functions_input(
152 | functions=functions, model=_model
153 | )
154 | model_max_tokens = get_max_tokens(model=_model)
155 | token_per_messages = num_tokens_for_messages_for_each(messages, _model)
156 | token_limit_exceeded = (
157 | sum(token_per_messages) + token_per_functions
158 | ) - model_max_tokens
159 | if token_limit_exceeded > 0:
160 | while token_limit_exceeded > 0:
161 | # erase the second oldest message (first one is system prompt, so it should not be erased)
162 | if len(messages) == 1:
163 | # with only one message left, the limit cannot be satisfied; call the LLM anyway and surface its error
164 | break
165 | token_limit_exceeded -= token_per_messages[1]
166 | del messages[1]
167 | del token_per_messages[1]
168 |
169 | args = dict(
170 | model=_model,
171 | messages=[
172 | message.model_dump(exclude_none=True)
173 | for message in self.__validate_openai_messages(messages)
174 | ],
175 | functions=functions,
176 | tools=tools,
177 | )
178 |
179 | is_stream_unsupported = _model in ["HCX-002"]
180 | if not is_stream_unsupported:
181 | args["stream"] = True
182 |
183 | start_time = datetime.now()
184 | response: AsyncGenerator[ModelResponse, None] = await acompletion(
185 | **args, **kwargs
186 | )
187 | if is_stream_unsupported:
188 | yield LLMStreamResponse(raw_output=response.choices[0].message.content)
189 | else:
190 | async for chunk in response:
191 | yield_api_response_with_fc = False
192 | logger.debug(chunk)
193 | if getattr(chunk.choices[0].delta, "function_call", None) is not None:
194 | yield LLMStreamResponse(
195 | api_response=chunk,
196 | function_call=chunk.choices[0].delta.function_call,
197 | )
198 | yield_api_response_with_fc = True
199 |
200 | if getattr(chunk.choices[0].delta, "tool_calls", None) is not None:
201 | yield LLMStreamResponse(
202 | api_response=chunk,
203 | tool_calls=chunk.choices[0].delta.tool_calls,
204 | )
205 | yield_api_response_with_fc = True
206 |
207 | if getattr(chunk.choices[0].delta, "content", None) is not None:
208 | raw_output += chunk.choices[0].delta.content
209 | yield LLMStreamResponse(
210 | api_response=chunk if not yield_api_response_with_fc else None,
211 | raw_output=chunk.choices[0].delta.content,
212 | )
213 |
214 | if getattr(chunk.choices[0].delta, "finish_reason", None) is not None:
215 | end_time = datetime.now()
216 | response_ms = (end_time - start_time).total_seconds() * 1000
217 | yield LLMStreamResponse(
218 | api_response=self.make_model_response_dev(
219 | chunk,
220 | response_ms,
221 | messages,
222 | raw_output,
223 | functions=None,
224 |                             function_call=(
225 |                                 # function_call deltas are already yielded above
226 |                                 # as they stream in; the summary omits them
227 |                                 None
228 |                             ),
229 | tools=None,
230 | tool_calls=None,
231 | )
232 | )
233 |
234 | def make_model_response_dev(
235 | self,
236 | chunk: ModelResponse,
237 | response_ms,
238 | messages: List[Dict[str, str]],
239 | raw_output: str,
240 | functions: Optional[List[Any]] = None,
241 | function_call: Optional[Dict[str, Any]] = None,
242 | tools: Optional[List[Any]] = None,
243 | tool_calls: Optional[List[Dict[str, Any]]] = None,
244 | ) -> ModelResponse:
245 | """Make ModelResponse object from openai response."""
246 | count_start_time = datetime.now()
247 | prompt_token: int = num_tokens_for_messages(
248 | messages=messages, model=chunk["model"]
249 | )
250 | completion_token: int = num_tokens_for_messages(
251 | model=chunk["model"],
252 | messages=[{"role": "assistant", "content": raw_output}],
253 | )
254 |
255 | if functions and len(functions) > 0:
256 | functions_token = num_tokens_from_functions_input(
257 | functions=functions, model=chunk["model"]
258 | )
259 | prompt_token += functions_token
260 |
261 | if tools and len(tools) > 0:
262 | tools_token = num_tokens_from_functions_input(
263 | functions=[tool["function"] for tool in tools], model=chunk["model"]
264 | )
265 | prompt_token += tools_token
266 | # if function_call:
267 | # function_call_token = num_tokens_from_function_call_output(
268 | # function_call_output=function_call, model=chunk["model"]
269 | # )
270 | # completion_token += function_call_token
271 |
272 | count_end_time = datetime.now()
273 | logger.debug(
274 | f"counting token time : {(count_end_time - count_start_time).total_seconds() * 1000} ms"
275 | )
276 |
277 | usage = Usage(
278 | **{
279 | "prompt_tokens": prompt_token,
280 | "completion_tokens": completion_token,
281 | "total_tokens": prompt_token + completion_token,
282 | }
283 | )
284 |
285 | last_message = Message(
286 | role=(
287 | chunk.choices[0].delta.role
288 | if getattr(chunk.choices[0].delta, "role", None)
289 | else "assistant"
290 | ),
291 | content=raw_output if raw_output != "" else None,
292 | function_call=function_call if function_call else None,
293 | tool_calls=tool_calls if tool_calls else None,
294 | )
295 | choices = [
296 | Choices(finish_reason=chunk.choices[0].finish_reason, message=last_message)
297 | ]
298 |
299 | res = ModelResponse(
300 | id=chunk["id"],
301 | created=chunk["created"],
302 | model=chunk["model"],
303 | stream=True,
304 | )
305 | res.choices = choices
306 | res.usage = usage
307 | res._response_ms = response_ms
308 |
309 | return res
310 |
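A minimal consumption sketch for the streaming `dev_chat` generator above; `llm` is a hypothetical variable standing for an instance of the class defined in this module, and the message content is illustrative.

```python
async def collect_raw_output(llm) -> str:
    """Concatenate streamed text deltas from dev_chat into one string."""
    raw_output = ""
    async for chunk in llm.dev_chat(
        messages=[{"role": "user", "content": "Hello!"}],
        model="gpt-3.5-turbo",
    ):
        # each chunk is an LLMStreamResponse: text deltas arrive in raw_output,
        # and the final chunk carries the reconstructed api_response
        if chunk.raw_output is not None:
            raw_output += chunk.raw_output
    return raw_output

# run with: asyncio.run(collect_raw_output(llm))
```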
--------------------------------------------------------------------------------
/promptmodel/promptmodel_init.py:
--------------------------------------------------------------------------------
1 | import os
2 | import nest_asyncio
3 | import threading
4 | import asyncio
5 | import atexit
6 | import time
7 |
8 | from typing import Optional, Dict, Any
9 | from datetime import datetime
10 |
11 | from promptmodel.utils.config_utils import upsert_config, read_config
12 | from promptmodel.utils import logger
13 | from promptmodel.database.orm import initialize_db
14 | from promptmodel.database.crud import update_deployed_cache
15 | from promptmodel.apis.base import AsyncAPIClient
16 | from promptmodel.types.enums import InstanceType
17 |
18 |
19 | def init(use_cache: Optional[bool] = True, mask_inputs: Optional[bool] = False):
20 | nest_asyncio.apply()
21 |
22 | config = read_config()
23 | if (
24 | config
25 | and "connection" in config
26 | and (
27 | (
28 | "online" in config["connection"]
29 | and config["connection"]["online"] == True
30 | )
31 | or (
32 | "initializing" in config["connection"]
33 | and config["connection"]["initializing"] == True
34 | )
35 | )
36 | ):
37 | cache_manager = None
38 | else:
39 | if use_cache:
40 | upsert_config({"use_cache": True}, section="project")
41 | cache_manager = CacheManager()
42 | else:
43 | upsert_config({"use_cache": False}, section="project")
44 | cache_manager = None
45 | initialize_db() # init db for local usage
46 |
47 | if mask_inputs is True:
48 | upsert_config({"mask_inputs": True}, section="project")
49 | else:
50 | upsert_config({"mask_inputs": False}, section="project")
51 |
52 |
53 | class CacheManager:
54 | _instance = None
55 | _lock = threading.Lock()
56 |
57 | def __new__(cls):
58 | with cls._lock:
59 | if cls._instance is None:
60 | instance = super(CacheManager, cls).__new__(cls)
61 | instance.last_update_time = 0 # to manage update frequency
62 | instance.update_interval = 60 * 60 * 6 # seconds, 6 hours
63 | instance.program_alive = True
64 | instance.background_tasks = []
65 | initialize_db()
66 | atexit.register(instance._terminate)
67 |                     asyncio.run(instance.update_cache())  # update the cache once synchronously at startup
68 | instance.cache_thread = threading.Thread(
69 | target=instance._run_cache_loop
70 | )
71 | instance.cache_thread.daemon = True
72 | instance.cache_thread.start()
73 | cls._instance = instance
74 | return cls._instance
75 |
76 | def cache_update_background_task(self, config):
77 | asyncio.run(update_deployed_db(config))
78 |
79 | def _run_cache_loop(self):
80 | asyncio.run(self._update_cache_periodically())
81 |
82 | async def _update_cache_periodically(self):
83 | while True:
84 | await asyncio.sleep(self.update_interval) # Non-blocking sleep
85 | await self.update_cache()
86 |
87 | async def update_cache(self):
88 | # Current time
89 | current_time = time.time()
90 | config = read_config()
91 |
92 | if not config:
93 | upsert_config({"version": 0}, section="project")
94 | config = {"project": {"version": 0}}
95 | if "project" not in config:
96 | upsert_config({"version": 0}, section="project")
97 | config = {"project": {"version": 0}}
98 |
99 | if "version" not in config["project"]:
100 | upsert_config({"version": 0}, section="project")
101 | config = {"project": {"version": 0}}
102 |
103 | # Check if we need to update the cache
104 | if current_time - self.last_update_time > self.update_interval:
105 | # Update cache logic
106 | try:
107 | await update_deployed_db(config)
108 |             except Exception:
109 | # try once more
110 | await update_deployed_db(config)
111 | # Update the last update time
112 | self.last_update_time = current_time
113 |
114 | def _terminate(self):
115 | self.program_alive = False
116 |
117 | # async def cleanup_background_tasks(self):
118 | # for task in self.background_tasks:
119 | # if not task.done():
120 | # task.cancel()
121 | # try:
122 | # await task
123 | # except asyncio.CancelledError:
124 |     #             pass # the task was cancelled
125 |
126 |
127 | async def update_deployed_db(config):
128 | if "project" not in config or "version" not in config["project"]:
129 | cached_project_version = 0
130 | else:
131 | cached_project_version = int(config["project"]["version"])
132 | try:
133 | res = await AsyncAPIClient.execute(
134 | method="GET",
135 | path="/check_update",
136 | params={"cached_version": cached_project_version},
137 | use_cli_key=False,
138 | )
139 | res = res.json()
140 | if res["need_update"]:
141 | # update local DB with res['project_status']
142 | project_status = res["project_status"]
143 | await update_deployed_cache(project_status)
144 | upsert_config({"version": res["version"]}, section="project")
145 | else:
146 | upsert_config({"version": res["version"]}, section="project")
147 | except Exception as exception:
148 | logger.error(f"Deployment cache update error: {exception}")
149 |
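A minimal usage sketch for `init()` above, assuming a promptmodel API key is already configured in the environment:

```python
from promptmodel.promptmodel_init import init

# use_cache=True keeps a local copy of deployed versions that CacheManager
# refreshes every 6 hours; mask_inputs is persisted to the project config.
init(use_cache=True, mask_inputs=False)
```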
--------------------------------------------------------------------------------
/promptmodel/types/__init__.py:
--------------------------------------------------------------------------------
1 | from .enums import InstanceType
2 | from .response import (
3 | LLMResponse,
4 | LLMStreamResponse,
5 | FunctionModelConfig,
6 | PromptModelConfig,
7 | ChatModelConfig,
8 | FunctionSchema,
9 | )
10 | from .request import ChatLogRequest, RunLogRequest
11 |
--------------------------------------------------------------------------------
/promptmodel/types/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class InstanceType(str, Enum):
5 | ChatLog = "ChatLog"
6 | RunLog = "RunLog"
7 | ChatLogSession = "ChatLogSession"
8 |
9 |
10 | class LocalTask(str, Enum):
11 | RUN_PROMPT_MODEL = "RUN_PROMPT_MODEL"
12 | RUN_CHAT_MODEL = "RUN_CHAT_MODEL"
13 |
14 | LIST_CODE_CHAT_MODELS = "LIST_CHAT_MODELS"
15 | LIST_CODE_PROMPT_MODELS = "LIST_PROMPT_MODELS"
16 | LIST_CODE_FUNCTIONS = "LIST_FUNCTIONS"
17 |
18 |
19 | class LocalTaskErrorType(str, Enum):
20 | NO_FUNCTION_NAMED_ERROR = "NO_FUNCTION_NAMED_ERROR" # no DB update is needed
21 | FUNCTION_CALL_FAILED_ERROR = "FUNCTION_CALL_FAILED_ERROR" # create FunctionModelVersion, create Prompt, create RunLog
22 | PARSING_FAILED_ERROR = "PARSING_FAILED_ERROR" # create FunctionModelVersion, create Prompt, create RunLog
23 |
24 | SERVICE_ERROR = "SERVICE_ERROR" # no DB update is needed
25 |
26 |
27 | class ServerTask(str, Enum):
28 | UPDATE_RESULT_RUN = "UPDATE_RESULT_RUN"
29 | UPDATE_RESULT_CHAT_RUN = "UPDATE_RESULT_CHAT_RUN"
30 |
31 | SYNC_CODE = "SYNC_CODE"
32 |
33 |
34 | class Role(str, Enum):
35 | SYSTEM = "system"
36 | USER = "user"
37 | ASSISTANT = "assistant"
38 | FUNCTION = "function"
39 |
40 |
41 | class ParsingType(str, Enum):
42 | COLON = "colon"
43 | SQUARE_BRACKET = "square_bracket"
44 | DOUBLE_SQUARE_BRACKET = "double_square_bracket"
45 | HTML = "html"
46 | JSON = "json"
47 |
48 |
49 | class ParsingPattern(dict, Enum):
50 | COLON = {
51 | "start": r"(\w+)\s+type=([\w,\s\[\]]+): \n",
52 | "start_fstring": "{key}: \n",
53 | "end_fstring": None,
54 | "whole": r"(.*?): (.*?)\n",
55 | "start_token": None,
56 | "end_token": None,
57 | }
58 | SQUARE_BRACKET = {
59 | "start": r"\[(\w+)\s+type=([\w,\s\[\]]+)\]",
60 | "start_fstring": "[{key} type={type}]",
61 | "end_fstring": "[/{key}]",
62 | "whole": r"\[(\w+)\s+type=([\w,\s\[\]]+)\](.*?)\[/\1\]",
63 | "start_token": r"[",
64 | "end_token": r"]",
65 | }
66 | DOUBLE_SQUARE_BRACKET = {
67 | "start": r"\[\[(\w+)\s+type=([\w,\s\[\]]+)\]\]",
68 | "start_fstring": "[[{key} type={type}]]",
69 | "end_fstring": "[[/{key}]]",
70 | "whole": r"\[\[(\w+)\s+type=([\w,\s\[\]]+)\]\](.*?)\[\[/\1\]\]",
71 | "start_token": r"[",
72 | "end_token": r"]",
73 | }
74 | HTML = {
75 | "start": r"<(\w+)\s+type=([\w,\s\[\]]+)>",
76 | "start_fstring": "<{key} type={type}>",
77 |         "end_fstring": "</{key}>",
78 |         "whole": r"<(\w+)\s+type=([\w,\s\[\]]+)>(.*?)</\1>",  # also captures type
79 | "start_token": r"<",
80 | "end_token": r">",
81 | }
82 |
83 |
84 | def get_pattern_by_type(parsing_type_value):
85 | return ParsingPattern[ParsingType(parsing_type_value).name].value
86 |
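An illustrative sketch (invented output string) of how `get_pattern_by_type` and the `whole` regex are used to pull typed key/value pairs out of a raw model output:

```python
import re

from promptmodel.types.enums import ParsingType, get_pattern_by_type
from promptmodel.utils.output_utils import convert_str_to_type

raw_output = "[score type=int]42[/score]"
pattern = get_pattern_by_type(ParsingType.SQUARE_BRACKET.value)
for key, type_str, value in re.findall(pattern["whole"], raw_output, flags=re.DOTALL):
    print(key, convert_str_to_type(value, type_str))  # score 42
```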
--------------------------------------------------------------------------------
/promptmodel/types/request.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | Dict,
3 | Any,
4 | Optional,
5 | )
6 | from pydantic import BaseModel
7 | from litellm.utils import (
8 | ModelResponse,
9 | )
10 | from openai.types.chat.chat_completion import *
11 |
12 |
13 | class ChatLogRequest(BaseModel):
14 | uuid: Optional[str] = None
15 | message: Dict[str, Any]
16 | metadata: Optional[Dict] = None
17 | api_response: Optional[ModelResponse] = None
18 |
19 |     # pydantic BaseModel does not call __post_init__, so use the v2 hook instead
20 |     def model_post_init(self, __context) -> None:
21 |         # fill message from api_response when the message is missing
22 | if self.api_response is not None and self.message is None:
23 | self.message = self.api_response.choices[0].message.model_dump()
24 |
25 |
26 | class RunLogRequest(BaseModel):
27 | uuid: Optional[str] = None
28 | inputs: Optional[Dict[str, Any]] = None
29 | metadata: Optional[Dict] = None
30 | api_response: Optional[ModelResponse] = None
31 |
--------------------------------------------------------------------------------
/promptmodel/types/response.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | List,
3 | Dict,
4 | Any,
5 | Optional,
6 | )
7 | from pydantic import BaseModel
8 | from litellm.utils import (
9 | ModelResponse,
10 | )
11 | from litellm.types.utils import (
12 | FunctionCall,
13 | ChatCompletionMessageToolCall,
14 | Usage,
15 | Choices,
16 | Message,
17 | Delta,
18 | StreamingChoices,
19 | )
20 | from openai._models import BaseModel as OpenAIObject
21 | from openai.types.chat.chat_completion import *
22 | from openai.types.chat.chat_completion_chunk import (
23 | ChoiceDeltaFunctionCall,
24 | ChoiceDeltaToolCall,
25 | ChoiceDeltaToolCallFunction,
26 | )
27 |
28 |
29 | class PMDetail(BaseModel):
30 | model: str
31 | name: str
32 | version_uuid: str
33 | version: int
34 | log_uuid: str
35 |
36 |
37 | class LLMResponse(OpenAIObject):
38 | api_response: Optional[ModelResponse] = None
39 | raw_output: Optional[str] = None
40 | parsed_outputs: Optional[Dict[str, Any]] = None
41 | error: Optional[bool] = None
42 | error_log: Optional[str] = None
43 | function_call: Optional[FunctionCall] = None
44 | tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
45 | pm_detail: Optional[PMDetail] = None
46 |
47 |
48 | class LLMStreamResponse(OpenAIObject):
49 | api_response: Optional[ModelResponse] = None
50 | raw_output: Optional[str] = None
51 | parsed_outputs: Optional[Dict[str, Any]] = None
52 | error: Optional[bool] = None
53 | error_log: Optional[str] = None
54 | function_call: Optional[ChoiceDeltaFunctionCall] = None
55 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
56 | pm_detail: Optional[PMDetail] = None
57 |
58 |
59 | class FunctionModelConfig(BaseModel):
60 | """Response Class for FunctionModel.get_config()
61 | prompts: List[Dict[str, Any]] = []
62 | each prompt can have role, content, name, function_call, and tool_calls
63 |     (version-level details such as "model", "parsing_type" and "output_keys"
64 |         are exposed as the flat fields below rather than a nested version_detail dict)
65 | model: str
66 | model name (e.g. "gpt-3.5-turbo")
67 | name: str
68 | name of the FunctionModel.
69 | version_uuid: str
70 | version uuid of the FunctionModel.
71 | version: int
72 | version id of the FunctionModel.
73 | parsing_type: Optional[str] = None
74 | parsing type of the FunctionModel.
75 | output_keys: Optional[List[str]] = None
76 | output keys of the FunctionModel.
77 | """
78 |
79 | prompts: List[Dict[str, Any]]
80 | model: str
81 | name: str
82 | version_uuid: str
83 | version: int
84 | parsing_type: Optional[str] = None
85 | output_keys: Optional[List[str]] = None
86 |
87 |
88 | class PromptModelConfig(FunctionModelConfig):
89 | """Deprecated. Use FunctionModelConfig instead."""
90 |
91 |
92 | class ChatModelConfig(BaseModel):
93 | system_prompt: str
94 | model: str
95 | name: str
96 | version_uuid: str
97 | version: int
98 | message_logs: Optional[List[Dict]] = []
99 |
100 |
101 | class FunctionSchema(BaseModel):
102 | """
103 | {
104 | "name": str,
105 | "description": Optional[str],
106 | "parameters": {
107 | "type": "object",
108 | "properties": {
109 | "argument_name": {
110 | "type": str,
111 | "description": Optional[str],
112 | "enum": Optional[List[str]]
113 | },
114 | },
115 | "required": Optional[List[str]],
116 | },
117 | }
118 | """
119 |
120 | class _Parameters(BaseModel):
121 | class _Properties(BaseModel):
122 | type: str
123 | description: Optional[str] = ""
124 | enum: Optional[List[str]] = []
125 |
126 | type: str = "object"
127 | properties: Dict[str, _Properties] = {}
128 | required: Optional[List[str]] = []
129 |
130 | name: str
131 | description: Optional[str] = None
132 | parameters: _Parameters
133 |
134 |
135 | class UnitConfig(BaseModel):
136 | """Response Class for UnitLogger.get_config().
137 | Created after calling UnitLogger.log_start()
138 | name: str
139 | name of the UnitLogger.
140 | version_uuid: str
141 | version uuid of the UnitLogger.
142 | version: int
143 | version id of the UnitLogger.
144 | log_uuid: str
145 | log_uuid for current trace.
146 | """
147 |
148 | name: str
149 | version_uuid: str
150 | log_uuid: str
151 | version: int
152 |
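As an illustration of the documented shape, a `FunctionSchema` can be built directly from such a dict (the values below are invented for the example):

```python
from promptmodel.types.response import FunctionSchema

schema = FunctionSchema(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "City and state"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    },
)
print(schema.model_dump())  # nested dicts are coerced into _Parameters/_Properties
```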
--------------------------------------------------------------------------------
/promptmodel/unit_logger.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | Dict,
3 | Optional,
4 | )
5 |
6 | import promptmodel.utils.logger as logger
7 | from promptmodel.utils.async_utils import run_async_in_sync
8 | from promptmodel.utils.config_utils import check_connection_status_decorator
9 | from promptmodel.types.response import (
10 | UnitConfig,
11 | )
12 | from promptmodel.apis.base import AsyncAPIClient
13 |
14 | class UnitLogger:
15 | def __init__(
16 | self,
17 | name: str,
18 | version: int
19 | ):
20 | self.name: str = name
21 | self.version: int = version
22 | self.config: Optional[UnitConfig] = None
23 |
24 | @check_connection_status_decorator
25 | def get_config(self) -> Optional[UnitConfig]:
26 | """Get the config of the component
27 |         You can get the config directly from the UnitLogger.config attribute.
28 | """
29 | return self.config
30 |
31 | @check_connection_status_decorator
32 | async def log_start(self, *args, **kwargs) -> Optional["UnitLogger"]:
33 |         """Create a component log on the cloud.
34 |         It returns the UnitLogger itself, so you can use it like this:
35 |         >>> component = await UnitLogger("intent_classification_unit", 1).log_start()
36 |         >>> res = FunctionModel("intent_classifier", unit_config=component.config).run(...)
37 |         >>> await component.log_score({"accuracy": 0.9})
38 | """
39 | res = await AsyncAPIClient.execute(
40 | method="POST",
41 | path="/unit/log",
42 | json={
43 | "name": self.name,
44 | "version": self.version,
45 | },
46 | use_cli_key=False
47 | )
48 | if res.status_code != 200:
49 | logger.error(f"Failed to log start for component {self.name} v{self.version}")
50 | return None
51 | else:
52 | self.config = UnitConfig(**res.json())
53 |
54 | return self
55 |
56 |
57 | @check_connection_status_decorator
58 | async def log_score(self, scores: Dict[str, float], *args, **kwargs):
59 | res = await AsyncAPIClient.execute(
60 | method="POST",
61 | path="/unit/score",
62 | json={
63 |                 "unit_log_uuid": self.config.log_uuid,
64 | "scores": scores,
65 | },
66 | use_cli_key=False
67 | )
68 | if res.status_code != 200:
69 | logger.error(f"Failed to log score for component {self.name} v{self.version}")
70 |
71 | return
72 |
--------------------------------------------------------------------------------
/promptmodel/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
--------------------------------------------------------------------------------
/promptmodel/utils/async_utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Coroutine
3 |
4 |
5 | def run_async_in_sync(coro: Coroutine):
6 | try:
7 | loop = asyncio.get_running_loop()
8 | except RuntimeError: # No running loop
9 | loop = asyncio.new_event_loop()
10 | asyncio.set_event_loop(loop)
11 | result = loop.run_until_complete(coro)
12 | # loop.close()
13 | return result
14 |
15 | return loop.run_until_complete(coro)
16 |
17 |
18 | def run_async_in_sync_threadsafe(coro: Coroutine, main_loop: asyncio.AbstractEventLoop):
19 | future = asyncio.run_coroutine_threadsafe(coro, main_loop)
20 | res = future.result()
21 | return res
22 |
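A minimal sketch of `run_async_in_sync`: it runs a coroutine from synchronous code whether or not an event loop already exists (re-entering a running loop relies on `nest_asyncio.apply()`, which `init()` performs):

```python
from promptmodel.utils.async_utils import run_async_in_sync


async def fetch_value() -> int:
    return 42


print(run_async_in_sync(fetch_value()))  # 42
```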
--------------------------------------------------------------------------------
/promptmodel/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | from typing import Any, Dict
4 | import yaml
5 | from functools import wraps
6 |
7 | CONFIG_FILE = "./.promptmodel/config.yaml"
8 |
9 |
10 | def read_config():
11 | """
12 |     Reads the configuration from CONFIG_FILE ("./.promptmodel/config.yaml").
13 |
14 | :return: A dictionary containing the configuration.
15 | """
16 | if not os.path.exists(CONFIG_FILE):
17 | return {}
18 |
19 | with open(CONFIG_FILE, "r") as file:
20 | config = yaml.safe_load(file) or {}
21 | return config
22 |
23 |
24 | def merge_dict(d1: Dict[str, Any], d2: Dict[str, Any]):
25 | """
26 | Merge two dictionaries recursively.
27 |
28 | :param d1: The first dictionary.
29 | :param d2: The second dictionary.
30 | :return: The merged dictionary.
31 | """
32 | for key, value in d2.items():
33 | if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict):
34 | d1[key] = merge_dict(d1[key], value)
35 | else:
36 | d1[key] = value
37 | return d1
38 |
39 |
40 | def upsert_config(new_config: Dict[str, Any], section: str = None):
41 | """
42 |     Merges the given configuration into the configuration file, creating it if needed.
43 |
44 | :param new_config: A dictionary containing the new configuration.
45 | :param section: The section of the configuration to update.
46 | """
47 | config = read_config()
48 | if section:
49 | config_section = config.get(section, {})
50 | new_config = {section: merge_dict(config_section, new_config)}
51 | config = merge_dict(config, new_config)
52 |     # If the ./.promptmodel directory does not exist, create it
53 | if not os.path.exists("./.promptmodel"):
54 | os.mkdir("./.promptmodel")
55 |
56 | with open(CONFIG_FILE, "w") as file:
57 | yaml.safe_dump(config, file, default_flow_style=False)
58 |
59 |
60 | def check_connection_status_decorator(method):
61 | if asyncio.iscoroutinefunction(method):
62 |
63 | @wraps(method)
64 | async def async_wrapper(self, *args, **kwargs):
65 | config = read_config()
66 | if "connection" in config and (
67 | (
68 | "initializing" in config["connection"]
69 | and config["connection"]["initializing"]
70 | )
71 | or (
72 | "reloading" in config["connection"]
73 | and config["connection"]["reloading"]
74 | )
75 | ):
76 | return
77 | else:
78 | if "config" not in kwargs:
79 | kwargs["config"] = config
80 | return await method(self, *args, **kwargs)
81 |
82 | # async_wrapper.__name__ = method.__name__
83 | # async_wrapper.__doc__ = method.__doc__
84 | return async_wrapper
85 | else:
86 |
87 | @wraps(method)
88 | def wrapper(self, *args, **kwargs):
89 | config = read_config()
90 | if "connection" in config and (
91 | (
92 | "initializing" in config["connection"]
93 | and config["connection"]["initializing"]
94 | )
95 | or (
96 | "reloading" in config["connection"]
97 | and config["connection"]["reloading"]
98 | )
99 | ):
100 | return
101 | else:
102 | return method(self, *args, **kwargs)
103 |
104 | # wrapper.__name__ = method.__name__
105 | # wrapper.__doc__ = method.__doc__
106 | return wrapper
107 |
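A small usage sketch for the config helpers above; note that they read and write `./.promptmodel/config.yaml` relative to the current working directory:

```python
from promptmodel.utils.config_utils import read_config, upsert_config

upsert_config({"use_cache": True, "version": 0}, section="project")
config = read_config()
print(config["project"]["use_cache"])  # True
```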
--------------------------------------------------------------------------------
/promptmodel/utils/crypto.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict
3 | from cryptography.fernet import Fernet
4 | import secrets
5 |
6 |
7 | def generate_api_key():
8 | """Generate a random token with 32 bytes"""
9 | token = secrets.token_hex(32)
10 | return token
11 |
12 |
13 | def generate_crypto_key():
14 | """
15 |     Generates a key and saves it to the key.key file
16 | """
17 | key = Fernet.generate_key()
18 | with open("key.key", "wb") as key_file:
19 | key_file.write(key)
20 |
21 |
22 | def load_crypto_key():
23 | """
24 | Loads the key named `key.key` from the directory.
25 | Generates a new key if it does not exist.
26 | """
27 | if not os.path.exists("key.key"):
28 | generate_crypto_key()
29 | return open("key.key", "rb").read()
30 |
31 |
32 | def encrypt_message(message: str) -> bytes:
33 | """
34 | Encrypts the message
35 | """
36 | key = load_crypto_key()
37 | f = Fernet(key)
38 | message = message.encode()
39 | encrypted_message = f.encrypt(message)
40 | return encrypted_message
41 |
42 |
43 | def decrypt_message(encrypted_message: bytes) -> str:
44 | """
45 | Decrypts the message
46 | """
47 | key = load_crypto_key()
48 | f = Fernet(key)
49 | decrypted_message = f.decrypt(encrypted_message)
50 | return decrypted_message.decode()
51 |
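A minimal round trip for the helpers above; note that `key.key` is created in the current working directory on first use:

```python
from promptmodel.utils.crypto import decrypt_message, encrypt_message

token = encrypt_message("my-api-key")  # bytes
assert decrypt_message(token) == "my-api-key"
```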
--------------------------------------------------------------------------------
/promptmodel/utils/logger.py:
--------------------------------------------------------------------------------
1 | """Logger module"""
2 |
3 | import os
4 | from typing import Any
5 | import termcolor
6 |
7 |
8 | def debug(msg: Any, *args):
9 | if os.environ.get("TESTMODE_LOGGING", "false") != "true":
10 | return
11 | print(termcolor.colored("[DEBUG] " + str(msg) + str(*args), "light_yellow"))
12 |
13 |
14 | def success(msg: Any, *args):
15 | if os.environ.get("TESTMODE_LOGGING", "false") != "true":
16 | return
17 | print(termcolor.colored("[SUCCESS] " + str(msg) + str(*args), "green"))
18 |
19 |
20 | def info(msg: Any, *args):
21 | if os.environ.get("TESTMODE_LOGGING", "false") != "true":
22 | return
23 | print(termcolor.colored("[INFO] " + str(msg) + str(*args), "blue"))
24 |
25 |
26 | def warning(msg: Any, *args):
27 | if os.environ.get("TESTMODE_LOGGING", "false") != "true":
28 | return
29 | print(termcolor.colored("[WARNING] " + str(msg) + str(*args), "yellow"))
30 |
31 |
32 | def error(msg: Any, *args):
33 | print(termcolor.colored("[Error] " + str(msg) + str(*args), "red"))
34 |
--------------------------------------------------------------------------------
/promptmodel/utils/output_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, Dict
3 |
4 |
5 | def update_dict(
6 | target: Dict[str, str],
7 | source: Dict[str, str],
8 | ):
9 | for key, value in source.items():
10 | if value is not None:
11 | if key not in target:
12 | target[key] = value
13 | else:
14 | target[key] += value
15 | return target
16 |
17 |
18 | def convert_str_to_type(value: str, type_str: str) -> Any:
19 | if type_str == "str":
20 | return value.strip()
21 | elif type_str == "bool":
22 | return value.lower() == "true"
23 | elif type_str == "int":
24 | return int(value)
25 | elif type_str == "float":
26 | return float(value)
27 | elif type_str.startswith("List"):
28 | return json.loads(value)
29 | elif type_str.startswith("Dict"):
30 | return json.loads(value)
31 | return value # Default: Return as is
32 |
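An illustrative sketch of how `update_dict` accumulates streamed `parsed_outputs` chunks and `convert_str_to_type` casts the final values:

```python
from promptmodel.utils.output_utils import convert_str_to_type, update_dict

parsed: dict = {}
for chunk in ({"answer": "Hel"}, {"answer": "lo"}, {"score": "4"}):
    update_dict(parsed, chunk)
print(parsed)  # {'answer': 'Hello', 'score': '4'}
print(convert_str_to_type(parsed["score"], "int"))  # 4
```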
--------------------------------------------------------------------------------
/promptmodel/utils/random_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | def select_version_by_ratio(versions):
5 | epsilon = 1e-10
6 | ratios = [version["ratio"] for version in versions]
7 |
8 | if not abs(sum(ratios) - 1.0) <= epsilon:
9 | raise ValueError(f"Sum of ratios must be 1.0, now {sum(ratios)}")
10 |
11 | cumulative_ratios = []
12 | cumulative_sum = 0
13 | for ratio in ratios:
14 | cumulative_sum += ratio
15 | cumulative_ratios.append(cumulative_sum)
16 |
17 | random_value = random.random()
18 | for idx, cumulative_ratio in enumerate(cumulative_ratios):
19 | if random_value <= cumulative_ratio:
20 | return versions[idx]
21 |
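A short sketch of `select_version_by_ratio`, which picks a version weighted by its `ratio` field (the ratios must sum to 1.0):

```python
from promptmodel.utils.random_utils import select_version_by_ratio

versions = [
    {"uuid": "v1", "ratio": 0.8},
    {"uuid": "v2", "ratio": 0.2},
]
print(select_version_by_ratio(versions)["uuid"])  # "v1" about 80% of the time
```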
--------------------------------------------------------------------------------
/promptmodel/utils/token_counting.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 | from litellm import token_counter
3 |
4 |
5 | def set_inputs_to_prompts(inputs: Dict[str, Any], prompts: List[Dict[str, str]]):
6 | messages = []
7 | for prompt in prompts:
8 | prompt["content"] = prompt["content"].replace("{", "{{").replace("}", "}}")
9 | prompt["content"] = prompt["content"].replace("{{{{", "{").replace("}}}}", "}")
10 | messages.append(
11 | {
12 | "content": prompt["content"].format(**inputs),
13 | "role": prompt["role"],
14 | }
15 | )
16 |     return messages
17 |
18 | def num_tokens_for_messages(
19 | messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0613"
20 | ) -> int:
21 | # tokens_per_message = 0
22 | # tokens_per_name = 0
23 | # if model.startswith("gpt-3.5-turbo"):
24 | # tokens_per_message = 4
25 | # tokens_per_name = -1
26 |
27 | # if model.startswith("gpt-4"):
28 | # tokens_per_message = 3
29 | # tokens_per_name = 1
30 |
31 | # if model.endswith("-0613") or model == "gpt-3.5-turbo-16k":
32 | # tokens_per_message = 3
33 | # tokens_per_name = 1
34 | # sum = 0
35 | processed_messages = [
36 | {**message, "function_call": str(message["function_call"])}
37 | if "function_call" in message
38 | else message
39 | for message in messages
40 | ]
41 | sum = token_counter(model=model, messages=processed_messages)
42 | # for message in messages:
43 | # sum += tokens_per_message
44 | # if "name" in message:
45 | # sum += tokens_per_name
46 | return sum
47 |
48 |
49 | def num_tokens_for_messages_for_each(
50 | messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0613"
51 | ) -> List[int]:
52 | processed_messages = [
53 | {**message, "function_call": str(message["function_call"])}
54 | if "function_call" in message
55 | else message
56 | for message in messages
57 | ]
58 | processed_messages = [
59 | {**message, "tool_calls": str(message["tool_calls"])}
60 | if "tool_calls" in message
61 | else message
62 | for message in processed_messages
63 | ]
64 | return [
65 | token_counter(model=model, messages=[message]) for message in processed_messages
66 | ]
67 |
68 |
69 | def num_tokens_from_functions_input(
70 | functions: Optional[List[Any]] = None, model="gpt-3.5-turbo-0613"
71 | ) -> int:
72 | """Return the number of tokens used by a list of functions."""
73 | if functions is None:
74 | return 0
75 | num_tokens = 0
76 | for function in functions:
77 | function_tokens = token_counter(model=model, text=function["name"])
78 | function_tokens += token_counter(model=model, text=function["description"])
79 |
80 | if "parameters" in function:
81 | parameters = function["parameters"]
82 | if "properties" in parameters:
83 | for properties_key in parameters["properties"]:
84 | function_tokens += token_counter(model=model, text=properties_key)
85 | v = parameters["properties"][properties_key]
86 | for field in v:
87 | if field == "type":
88 | function_tokens += 2
89 | function_tokens += token_counter(
90 | model=model, text=v["type"]
91 | )
92 | elif field == "description":
93 | function_tokens += 2
94 | function_tokens += token_counter(
95 | model=model, text=v["description"]
96 | )
97 | elif field == "enum":
98 | function_tokens -= 3
99 | for o in v["enum"]:
100 | function_tokens += 3
101 | function_tokens += token_counter(model=model, text=o)
102 | else:
103 | print(f"Warning: not supported field {field}")
104 | function_tokens += 11
105 |
106 | num_tokens += function_tokens
107 |
108 | num_tokens += 12
109 | return num_tokens
110 |
111 |
112 | def num_tokens_from_function_call_output(
113 | function_call_output: Dict[str, str] = {}, model="gpt-3.5-turbo-0613"
114 | ) -> int:
115 | num_tokens = 1
116 | num_tokens += token_counter(model=model, text=function_call_output["name"])
117 | if "arguments" in function_call_output:
118 | num_tokens += token_counter(model=model, text=function_call_output["arguments"])
119 | return num_tokens
120 |
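An illustrative sketch of the token counters above; the counts are estimates delegated to litellm's `token_counter`, and the function schema is invented for the example:

```python
from promptmodel.utils.token_counting import (
    num_tokens_for_messages,
    num_tokens_from_functions_input,
)

messages = [{"role": "user", "content": "What's the weather in Paris?"}]
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string", "description": "City"}},
        },
    }
]
print(num_tokens_for_messages(messages) + num_tokens_from_functions_input(functions))
```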
--------------------------------------------------------------------------------
/promptmodel/websocket/__init__.py:
--------------------------------------------------------------------------------
1 | from .websocket_client import DevWebsocketClient
2 | from .reload_handler import CodeReloadHandler
3 |
--------------------------------------------------------------------------------
/promptmodel/websocket/reload_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import importlib
4 | import asyncio
5 | from typing import Any, Dict, List
6 | from threading import Timer
7 | from rich import print
8 | from watchdog.events import FileSystemEventHandler
9 | from playhouse.shortcuts import model_to_dict
10 |
11 | from promptmodel.apis.base import APIClient
12 | from promptmodel.utils.config_utils import read_config, upsert_config
13 | from promptmodel.utils.async_utils import run_async_in_sync_threadsafe
14 | from promptmodel.utils import logger
15 | from promptmodel import DevApp
16 | from promptmodel.database.models import (
17 | DeployedFunctionModel,
18 | DeployedFunctionModelVersion,
19 | DeployedPrompt,
20 | )
21 |
22 | from promptmodel.websocket.websocket_client import DevWebsocketClient
23 | from promptmodel.types.enums import ServerTask
24 |
25 |
26 | class CodeReloadHandler(FileSystemEventHandler):
27 | def __init__(
28 | self,
29 | _devapp_filename: str,
30 | _instance_name: str,
31 | dev_websocket_client: DevWebsocketClient,
32 | main_loop: asyncio.AbstractEventLoop,
33 | ):
34 | self._devapp_filename: str = _devapp_filename
35 | self.devapp_instance_name: str = _instance_name
36 | self.dev_websocket_client: DevWebsocketClient = (
37 | dev_websocket_client # save dev_websocket_client instance
38 | )
39 | self.timer = None
40 | self.main_loop = main_loop
41 |
42 | def on_modified(self, event):
43 | """Called when a file or directory is modified."""
44 | if event.src_path.endswith(".py"):
45 | upsert_config({"reloading": True}, "connection")
46 | if self.timer:
47 | self.timer.cancel()
48 | # reload modified file & main file
49 | self.timer = Timer(0.5, self.reload_code, args=(event.src_path,))
50 | self.timer.start()
51 |
52 | def reload_code(self, modified_file_path: str):
53 | print(
54 | f"[violet]promptmodel:dev:[/violet] Reloading {self._devapp_filename} module due to changes..."
55 | )
56 | relative_modified_path = os.path.relpath(modified_file_path, os.getcwd())
57 | # Reload the devapp module
58 | module_name = relative_modified_path.replace("./", "").replace("/", ".")[
59 | :-3
60 | ] # assuming the file is in the PYTHONPATH
61 |
62 | if module_name in sys.modules:
63 | module = sys.modules[module_name]
64 | importlib.reload(module)
65 |
66 | reloaded_module = importlib.reload(sys.modules[self._devapp_filename])
67 | print(
68 | f"[violet]promptmodel:dev:[/violet] {self._devapp_filename} module reloaded successfully."
69 | )
70 |
71 | new_devapp_instance: DevApp = getattr(
72 | reloaded_module, self.devapp_instance_name
73 | )
74 |
75 | config = read_config()
76 | org = config["connection"]["org"]
77 | project = config["connection"]["project"]
78 |
79 | # save samples, FunctionSchema, FunctionModel, ChatModel to cloud server by websocket ServerTask request
80 | new_function_model_name_list = (
81 | new_devapp_instance._get_function_model_name_list()
82 | )
83 | new_chat_model_name_list = new_devapp_instance._get_chat_model_name_list()
84 | new_samples = new_devapp_instance.samples
85 | new_function_schemas = new_devapp_instance._get_function_schema_list()
86 |
87 | res = run_async_in_sync_threadsafe(
88 | self.dev_websocket_client.request(
89 | ServerTask.SYNC_CODE,
90 | message={
91 | "new_function_model": new_function_model_name_list,
92 | "new_chat_model": new_chat_model_name_list,
93 | "new_samples": new_samples,
94 | "new_schemas": new_function_schemas,
95 | },
96 | ),
97 | main_loop=self.main_loop,
98 | )
99 |
100 | # update_samples(new_devapp_instance.samples)
101 | upsert_config({"reloading": False}, "connection")
102 | self.dev_websocket_client.update_devapp_instance(new_devapp_instance)
103 |
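A hedged wiring sketch (not necessarily how the CLI does it) showing how `CodeReloadHandler` can be attached to a watchdog `Observer`; the module and instance names passed in are hypothetical:

```python
import asyncio

from watchdog.observers import Observer

from promptmodel.websocket.reload_handler import CodeReloadHandler


def watch_for_changes(dev_websocket_client, main_loop: asyncio.AbstractEventLoop):
    """Re-sync code with the cloud whenever a .py file under cwd changes."""
    handler = CodeReloadHandler(
        "main",  # hypothetical: module name of the file defining the DevApp
        "app",   # hypothetical: variable name of the DevApp instance
        dev_websocket_client,
        main_loop,
    )
    observer = Observer()
    observer.schedule(handler, path=".", recursive=True)
    observer.start()
    return observer
```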
--------------------------------------------------------------------------------
/publish.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Exit on first error
4 | set -e
5 |
6 | # Clean up old distribution archives
7 | echo "Cleaning up old builds..."
8 | rm -rf dist/
9 | rm -rf build/
10 | rm -rf *.egg-info
11 |
12 | # Generate distribution archives
13 | echo "Building distribution..."
14 | python setup.py sdist bdist_wheel
15 |
16 | # Upload using twine
17 | echo "Uploading to PyPI..."
18 | twine upload dist/*
19 |
20 | echo "Publishing complete!"
21 |
22 | # If you want to automatically open a web browser to check your library on PyPI
23 | # Uncomment the following line and replace 'your_package_name' with the name of your library
24 | # xdg-open "https://pypi.org/project/your_package_name/"
25 |
26 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pydantic
2 | peewee
3 | typer[all]
4 | cryptography
5 | pyyaml
6 | InquirerPy
7 | litellm
8 | python-dotenv
9 | websockets==10.4
10 | termcolor
11 | watchdog
12 | readerwriterlock
13 | httpx[http2]
14 | pytest
15 | pytest-asyncio
16 | pytest-mock
17 | pytest-cov
18 | nest-asyncio
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt & model versioning on the cloud, built for developers.
3 | """
4 |
5 | from setuptools import setup, find_namespace_packages
6 |
7 | # Read README.md for the long description
8 | with open("README.md", "r") as fh:
9 | long_description = fh.read()
10 |
11 | setup(
12 | name="promptmodel",
13 | version="0.2.1",
14 | packages=find_namespace_packages(),
15 | entry_points={
16 | "console_scripts": [
17 | "prompt = promptmodel.cli.main:app",
18 | ],
19 | },
20 | description="Prompt & model versioning on the cloud, built for developers.",
21 | long_description=long_description,
22 | long_description_content_type="text/markdown",
23 | author="weavel",
24 | url="https://github.com/weavel-ai/promptmodel",
25 | install_requires=[
26 | "httpx[http2]",
27 | "pydantic>=2.4.2",
28 | "peewee",
29 | "typer[all]",
30 | "cryptography",
31 | "pyyaml",
32 | "InquirerPy",
33 | "litellm>=1.7.1",
34 | # "litellm@git+https://github.com/weavel-ai/litellm.git@llms_add_clova_support",
35 | "python-dotenv",
36 | "websockets",
37 | "termcolor",
38 | "watchdog",
39 | "readerwriterlock",
40 | "nest-asyncio",
41 | ],
42 | python_requires=">=3.8.10",
43 | keywords=[
44 | "weavel",
45 | "agent",
46 | "llm",
47 | "tools",
48 | "promptmodel",
49 | "llm agent",
50 | "prompt",
51 | "versioning",
52 | "eval",
53 | "evaluation",
54 | "collaborative",
55 | ],
56 | )
57 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/__init__.py
--------------------------------------------------------------------------------
/tests/api_client/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/api_client/__init__.py
--------------------------------------------------------------------------------
/tests/api_client/api_client_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import MagicMock
3 |
4 |
5 | from promptmodel.apis.base import APIClient, AsyncAPIClient
6 | from promptmodel.utils.crypto import decrypt_message, generate_api_key, encrypt_message
7 |
8 |
9 | @pytest.mark.asyncio
10 | async def test_get_headers(mocker):
11 | api_client = APIClient()
12 | async_api_client = AsyncAPIClient()
13 |
14 | mock_exit = mocker.patch("builtins.exit", side_effect=SystemExit)
15 | api_key = generate_api_key()
16 |     mock_environ_get = MagicMock(return_value=api_key)
17 |
18 | gt_cli_output = {
19 | "Authorization": f"Bearer {decrypt_message(encrypt_message(api_key))}"
20 | }
21 | gt_api_output = {"Authorization": f"Bearer {api_key}"}
22 | # No user
23 | mocker.patch("promptmodel.apis.base.read_config", return_value={})
24 | with pytest.raises(SystemExit):
25 | res = api_client._get_headers()
26 | mock_exit.assert_called_once()
27 | mock_exit.reset_mock()
28 |
29 | with pytest.raises(SystemExit):
30 | res = await async_api_client._get_headers()
31 | mock_exit.assert_called_once()
32 | mock_exit.reset_mock()
33 |
34 | mocker.patch("promptmodel.apis.base.read_config", return_value={"user": {}})
35 |
36 | with pytest.raises(Exception):
37 | res = api_client._get_headers()
38 | with pytest.raises(Exception):
39 | res = await async_api_client._get_headers()
40 |
41 | mocker.patch(
42 | "promptmodel.apis.base.read_config",
43 | return_value={"user": {"encrypted_api_key": encrypt_message(api_key)}},
44 | )
45 | res = api_client._get_headers()
46 | assert res == gt_cli_output, "API key is not decrypted properly"
47 | res = await async_api_client._get_headers()
48 | assert res == gt_api_output, "API key is not decrypted properly"
49 |
50 |     mocker.patch("promptmodel.apis.base.os.environ.get", mock_environ_get)
51 | res = api_client._get_headers(use_cli_key=False)
52 | assert res == gt_api_output, "API key is not retrieved properly"
53 | res = await async_api_client._get_headers(use_cli_key=False)
54 | assert res == gt_api_output, "API key is not retrieved properly"
55 |
56 |
57 | @pytest.mark.asyncio
58 | async def test_execute(mocker):
59 | api_client = APIClient()
60 | async_api_client = AsyncAPIClient()
61 | mock_exit = mocker.patch("builtins.exit", side_effect=SystemExit)
62 | mocker.patch("promptmodel.apis.base.APIClient._get_headers", return_value={})
63 | mocker.patch("promptmodel.apis.base.AsyncAPIClient._get_headers", return_value={})
64 |
65 | mock_request = mocker.patch(
66 | "promptmodel.apis.base.requests.request", return_value=None
67 | )
68 | mock_async_request = mocker.patch(
69 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=None
70 | )
71 | with pytest.raises(SystemExit):
72 | res = api_client.execute(path="test")
73 | mock_request.assert_called_once()
74 | mock_request.reset_mock()
75 | mock_exit.assert_called_once()
76 | mock_exit.reset_mock()
77 |
78 | res = await async_api_client.execute(path="test")
79 | mock_async_request.assert_called_once()
80 | mock_async_request.reset_mock()
81 |
82 | mock_response = MagicMock()
83 | mock_response.status_code = 200
84 | mock_request = mocker.patch(
85 | "promptmodel.apis.base.requests.request", return_value=mock_response
86 | )
87 | mock_async_request = mocker.patch(
88 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=mock_response
89 | )
90 |
91 | res = api_client.execute(path="test")
92 | mock_request.assert_called_once()
93 | assert res == mock_response, "Response is not returned properly"
94 | mock_request.reset_mock()
95 |
96 | res = await async_api_client.execute(path="test")
97 | mock_async_request.assert_called_once()
98 | assert res == mock_response, "Response is not returned properly"
99 | mock_async_request.reset_mock()
100 |
101 | mock_response.status_code = 403
102 | mock_request = mocker.patch(
103 | "promptmodel.apis.base.requests.request", return_value=mock_response
104 | )
105 | mock_async_request = mocker.patch(
106 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=mock_response
107 | )
108 | with pytest.raises(SystemExit):
109 | res = api_client.execute(path="test")
110 | mock_request.assert_called_once()
111 | mock_request.reset_mock()
112 | mock_exit.assert_called_once()
113 | mock_exit.reset_mock()
114 |
115 | res = await async_api_client.execute(path="test")
116 | mock_async_request.assert_called_once()
117 | mock_async_request.reset_mock()
118 |
119 | mock_response.status_code = 500
120 | mock_request = mocker.patch(
121 | "promptmodel.apis.base.requests.request", return_value=mock_response
122 | )
123 | mock_async_request = mocker.patch(
124 | "promptmodel.apis.base.httpx.AsyncClient.request", return_value=mock_response
125 | )
126 | with pytest.raises(SystemExit):
127 | res = api_client.execute(path="test")
128 | mock_request.assert_called_once()
129 | mock_request.reset_mock()
130 | mock_exit.assert_called_once()
131 | mock_exit.reset_mock()
132 |
133 | res = await async_api_client.execute(path="test")
134 | mock_async_request.assert_called_once()
135 | mock_async_request.reset_mock()
136 |
--------------------------------------------------------------------------------
/tests/chat_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/chat_model/__init__.py
--------------------------------------------------------------------------------
/tests/chat_model/chat_model_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, MagicMock
3 |
4 | from typing import Generator, AsyncGenerator, List
5 |
6 | from promptmodel.types.response import LLMResponse, LLMStreamResponse, ChatModelConfig
7 | from promptmodel import ChatModel, DevClient
8 | from promptmodel.dev_app import ChatModelInterface
9 |
10 | client = DevClient()
11 |
12 |
13 | def test_find_client(
14 | mocker,
15 | mock_fetch_chat_model,
16 | mock_async_chat_log_to_cloud: AsyncMock,
17 | ):
18 | fetch_chat_model = mocker.patch(
19 | "promptmodel.chat_model.LLMProxy.fetch_chat_model",
20 | mock_fetch_chat_model,
21 | )
22 |
23 | mocker.patch(
24 | "promptmodel.chat_model.LLMProxy._async_make_session_cloud",
25 | mock_async_chat_log_to_cloud,
26 | )
27 | pm = ChatModel("test")
28 | assert client.chat_models == [ChatModelInterface(name="test")]
29 |
30 |
31 | def test_get_config(
32 | mocker,
33 | mock_fetch_chat_model,
34 | mock_async_chat_log_to_cloud: AsyncMock,
35 | ):
36 | fetch_chat_model = mocker.patch(
37 | "promptmodel.chat_model.LLMProxy.fetch_chat_model", mock_fetch_chat_model
38 | )
39 |
40 | mocker.patch(
41 | "promptmodel.chat_model.LLMProxy._async_make_session_cloud",
42 | mock_async_chat_log_to_cloud,
43 | )
44 | # mock registering_meta
45 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock())
46 | chat_model = ChatModel("test")
47 | assert len(client.chat_models) == 1
48 | config: ChatModelConfig = chat_model.get_config()
49 | assert config.system_prompt == "You are a helpful assistant."
50 |
51 |
52 | def test_add_messages(
53 | mocker,
54 | mock_async_chat_log_to_cloud: AsyncMock,
55 | ):
56 | pass
57 |
58 |
59 | def test_run(
60 | mocker,
61 | mock_fetch_chat_model: AsyncMock,
62 | mock_async_chat_log_to_cloud: AsyncMock,
63 | ):
64 | fetch_chat_model = mocker.patch(
65 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
66 | )
67 |
68 | async_chat_log_to_cloud = mocker.patch(
69 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
70 | mock_async_chat_log_to_cloud,
71 | )
72 |
73 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock())
74 | chat_model = ChatModel("test", session_uuid="testuuid")
75 |
76 | res: LLMResponse = chat_model.run()
77 | print(res.api_response.model_dump())
78 |
79 | fetch_chat_model.assert_called_once()
80 | async_chat_log_to_cloud.assert_called_once()
81 |
82 | assert res.raw_output is not None
83 | assert res.error is None or res.error is False
84 | assert res.api_response is not None
85 | assert res.parsed_outputs is None
86 |
87 | fetch_chat_model.reset_mock()
88 | async_chat_log_to_cloud.reset_mock()
89 | mocker.patch(
90 | "promptmodel.utils.config_utils.read_config",
91 | return_value={"connection": {"initializing": True}},
92 | )
93 | res: LLMResponse = chat_model.run()
94 | print(res)
95 | fetch_chat_model.assert_not_called()
96 | async_chat_log_to_cloud.assert_not_called()
97 |
98 |
99 | @pytest.mark.asyncio
100 | async def test_arun(
101 | mocker,
102 | mock_fetch_chat_model: AsyncMock,
103 | mock_async_chat_log_to_cloud: AsyncMock,
104 | ):
105 | fetch_chat_model = mocker.patch(
106 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
107 | )
108 |
109 | async_chat_log_to_cloud = mocker.patch(
110 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
111 | mock_async_chat_log_to_cloud,
112 | )
113 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock())
114 | chat_model = ChatModel("test", session_uuid="testuuid")
115 |
116 | res: LLMResponse = await chat_model.arun()
117 | print(res.api_response.model_dump())
118 |
119 | fetch_chat_model.assert_called_once()
120 | async_chat_log_to_cloud.assert_called_once()
121 |
122 | assert res.raw_output is not None
123 | assert res.error is None or res.error is False
124 | assert res.api_response is not None
125 | assert res.parsed_outputs is None
126 |
127 | fetch_chat_model.reset_mock()
128 | async_chat_log_to_cloud.reset_mock()
129 | mocker.patch(
130 | "promptmodel.utils.config_utils.read_config",
131 | return_value={"connection": {"initializing": True}},
132 | )
133 | res: LLMResponse = await chat_model.arun()
134 | print(res)
135 | fetch_chat_model.assert_not_called()
136 | async_chat_log_to_cloud.assert_not_called()
137 |
138 |
139 | def test_stream(
140 | mocker,
141 | mock_fetch_chat_model: AsyncMock,
142 | mock_async_chat_log_to_cloud: AsyncMock,
143 | ):
144 | fetch_chat_model = mocker.patch(
145 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
146 | )
147 |
148 | async_chat_log_to_cloud = mocker.patch(
149 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
150 | mock_async_chat_log_to_cloud,
151 | )
152 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock())
153 | chat_model = ChatModel("test", session_uuid="testuuid")
154 |
155 | res: Generator[LLMStreamResponse, None, None] = chat_model.run(stream=True)
156 | chunks: List[LLMStreamResponse] = []
157 | for chunk in res:
158 | chunks.append(chunk)
159 |
160 | fetch_chat_model.assert_called_once()
161 | async_chat_log_to_cloud.assert_called_once()
162 |
163 | assert chunks[-1].api_response is not None
164 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
165 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
166 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
167 |
168 | fetch_chat_model.reset_mock()
169 | async_chat_log_to_cloud.reset_mock()
170 | mocker.patch(
171 | "promptmodel.utils.config_utils.read_config",
172 | return_value={"connection": {"initializing": True}},
173 | )
174 | res: LLMResponse = chat_model.run(stream=True)
175 | print(res)
176 | fetch_chat_model.assert_not_called()
177 | async_chat_log_to_cloud.assert_not_called()
178 |
179 |
180 | @pytest.mark.asyncio
181 | async def test_astream(
182 | mocker,
183 | mock_fetch_chat_model: AsyncMock,
184 | mock_async_chat_log_to_cloud: AsyncMock,
185 | ):
186 | fetch_chat_model = mocker.patch(
187 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
188 | )
189 |
190 | async_chat_log_to_cloud = mocker.patch(
191 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
192 | mock_async_chat_log_to_cloud,
193 | )
194 |
195 | mocker.patch("promptmodel.chat_model.RegisteringMeta", MagicMock())
196 | chat_model = ChatModel("test", session_uuid="testuuid")
197 |
198 | res: AsyncGenerator[LLMStreamResponse, None] = await chat_model.arun(stream=True)
199 |
200 | chunks: List[LLMStreamResponse] = []
201 | async for chunk in res:
202 | chunks.append(chunk)
203 |
204 | fetch_chat_model.assert_called_once()
205 | async_chat_log_to_cloud.assert_called_once()
206 |
207 | assert chunks[-1].api_response is not None
208 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
209 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
210 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
211 |
212 | fetch_chat_model.reset_mock()
213 | async_chat_log_to_cloud.reset_mock()
214 | mocker.patch(
215 | "promptmodel.utils.config_utils.read_config",
216 | return_value={"connection": {"initializing": True}},
217 | )
218 |
219 | res: LLMResponse = await chat_model.arun(stream=True)
220 | print(res)
221 | fetch_chat_model.assert_not_called()
222 | async_chat_log_to_cloud.assert_not_called()
223 |
--------------------------------------------------------------------------------
/tests/chat_model/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from promptmodel.llms.llm_proxy import LLMProxy
5 | from promptmodel.types.response import ChatModelConfig
6 |
7 |
8 | async def echo_coroutine(*args, **kwargs):
9 | # print(args, kwargs)
10 | return args, kwargs
11 |
12 |
13 | @pytest.fixture
14 | def mock_fetch_chat_model():
15 | mock_fetch_chat_model = AsyncMock()
16 | mock_instruction = "You are a helpful assistant."
17 | mock_version_details = {
18 | "model": "gpt-4-1106-preview",
19 | "uuid": "testuuid",
20 | "version": 1,
21 | }
22 | mock_message_logs = [
23 | {
24 | "role": "system",
25 | "content": "You are a helpful assistant.",
26 | "session_uuid": "testuuid",
27 | },
28 | {"role": "user", "content": "Hello!", "session_uuid": "testuuid"},
29 | {
30 | "role": "assistant",
31 | "content": "Hello! How can I help you?",
32 | "session_uuid": "testuuid",
33 | },
34 | ]
35 |
36 | mock_fetch_chat_model.return_value = (
37 | mock_instruction,
38 | mock_version_details,
39 | mock_message_logs,
40 | )
41 |
42 | return mock_fetch_chat_model
43 |
44 |
45 | @pytest.fixture
46 | def mock_async_chat_log_to_cloud():
47 | mock_async_chat_log_to_cloud = AsyncMock()
48 |
49 | mock_response = MagicMock()
50 | mock_response.status_code = 200
51 | mock_async_chat_log_to_cloud.return_value = mock_response
52 |
53 | return mock_async_chat_log_to_cloud
54 |
55 |
56 | @pytest.fixture
57 | def mock_async_make_session_cloud():
58 | mock_async_make_session_cloud = AsyncMock()
59 |
60 | mock_response = MagicMock()
61 | mock_response.status_code = 200
62 | mock_async_make_session_cloud.return_value = mock_response
63 |
64 | return mock_async_make_session_cloud
65 |
--------------------------------------------------------------------------------
/tests/chat_model/registering_meta_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from promptmodel.chat_model import RegisteringMeta
5 |
6 |
7 | def test_registering_meta(mocker):
8 | # Fail to find DevClient instance
9 | client_instance = RegisteringMeta.find_client_instance()
10 | assert client_instance is None
11 |
--------------------------------------------------------------------------------
/tests/cli/dev_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from typing import Generator, AsyncGenerator, Dict, List, Any
5 | from litellm import ModelResponse
6 |
7 | from promptmodel.llms.llm import LLM
8 | from promptmodel.llms.llm_proxy import LLMProxy
9 | from promptmodel.types.response import LLMResponse, LLMStreamResponse
10 |
--------------------------------------------------------------------------------
/tests/constants.py:
--------------------------------------------------------------------------------
1 | # JSON Schema to pass to OpenAI
2 | function_shemas = [
3 | {
4 | "name": "get_current_weather",
5 | "description": "Get the current weather in a given location",
6 | "parameters": {
7 | "type": "object",
8 | "properties": {
9 | "location": {
10 | "type": "string",
11 | "description": "The city and state, e.g. San Francisco, CA",
12 | },
13 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
14 | },
15 | "required": ["location"],
16 | },
17 | }
18 | ]
19 |
--------------------------------------------------------------------------------
/tests/function_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/function_model/__init__.py
--------------------------------------------------------------------------------
/tests/function_model/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from promptmodel.llms.llm_proxy import LLMProxy
5 |
6 |
7 | async def echo_coroutine(*args, **kwargs):
8 | # print(args, kwargs)
9 | return args, kwargs
10 |
11 |
12 | @pytest.fixture
13 | def mock_fetch_prompts():
14 | mock_fetch_prompts = AsyncMock()
15 | mock_prompts = [
16 | {"role": "system", "content": "You are a helpful assistant.", "step": 1},
17 | {"role": "user", "content": "Hello!", "step": 2},
18 | ]
19 | mock_version_details = {
20 | "model": "gpt-3.5-turbo",
21 | "uuid": "testuuid",
22 | "version": 1,
23 | "parsing_type": None,
24 | "output_keys": None,
25 | }
26 | mock_fetch_prompts.return_value = (mock_prompts, mock_version_details)
27 |
28 | return mock_fetch_prompts
29 |
30 |
31 | @pytest.fixture
32 | def mock_async_log_to_cloud():
33 | mock_async_log_to_cloud = AsyncMock()
34 |
35 | mock_response = MagicMock()
36 | mock_response.status_code = 200
37 | mock_async_log_to_cloud.return_value = mock_response
38 |
39 | return mock_async_log_to_cloud
40 |
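
These fixtures only build pre-configured `AsyncMock` objects; each test wires them in with `mocker.patch`, as the test module below does. A small standalone sketch of the pattern, with illustrative names only:

```python
import asyncio
from unittest.mock import AsyncMock

# An AsyncMock is awaitable, so it can stand in for async methods such as
# LLMProxy.fetch_prompts; its return_value is what `await` produces.
mock_fetch = AsyncMock()
mock_fetch.return_value = (
    [{"role": "system", "content": "You are a helpful assistant.", "step": 1}],
    {"model": "gpt-3.5-turbo", "uuid": "testuuid", "version": 1},
)

async def main() -> None:
    prompts, version_details = await mock_fetch("some_function_model")
    assert version_details["model"] == "gpt-3.5-turbo"
    mock_fetch.assert_awaited_once()

asyncio.run(main())
```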
--------------------------------------------------------------------------------
/tests/function_model/function_model_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, MagicMock
3 |
4 | from typing import Generator, AsyncGenerator, List
5 |
6 | from promptmodel.types.response import LLMResponse, LLMStreamResponse
7 | from promptmodel import FunctionModel, DevClient
8 | from promptmodel.dev_app import FunctionModelInterface
9 |
10 | client = DevClient()
11 |
12 |
13 | def test_find_client(mocker):
14 | pm = FunctionModel("test")
15 | assert client.function_models == [FunctionModelInterface(name="test")]
16 |
17 |
18 | def test_get_config(mocker, mock_fetch_prompts):
19 | fetch_prompts = mocker.patch(
20 | "promptmodel.function_model.LLMProxy.fetch_prompts", mock_fetch_prompts
21 | )
22 | # mock registering_meta
23 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
24 | function_model = FunctionModel("test")
25 | assert len(client.function_models) == 1
26 | config = function_model.get_config()
27 | assert len(config.prompts) == 2
28 |
29 |
30 | def test_run(mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock):
31 | fetch_prompts = mocker.patch(
32 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
33 | )
34 | async_log_to_cloud = mocker.patch(
35 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
36 | mock_async_log_to_cloud,
37 | )
38 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
39 | function_model = FunctionModel("test")
40 | res: LLMResponse = function_model.run({})
41 | fetch_prompts.assert_called_once()
42 | async_log_to_cloud.assert_called_once()
43 | assert res.raw_output is not None
44 | assert res.error is None or res.error is False
45 | assert res.api_response is not None
46 | assert res.parsed_outputs is None
47 |
48 | fetch_prompts.reset_mock()
49 | async_log_to_cloud.reset_mock()
50 | mocker.patch(
51 | "promptmodel.utils.config_utils.read_config",
52 | return_value={"connection": {"initializing": True}},
53 | )
54 | res = function_model.run()
55 | print(res)
56 | fetch_prompts.assert_not_called()
57 | async_log_to_cloud.assert_not_called()
58 |
59 |
60 | @pytest.mark.asyncio
61 | async def test_arun(
62 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
63 | ):
64 | fetch_prompts = mocker.patch(
65 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
66 | )
67 | async_log_to_cloud = mocker.patch(
68 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
69 | mock_async_log_to_cloud,
70 | )
71 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
72 | function_model = FunctionModel("test")
73 |
74 | res: LLMResponse = await function_model.arun({})
75 | fetch_prompts.assert_called_once()
76 | async_log_to_cloud.assert_called_once()
77 | assert res.raw_output is not None
78 | assert res.error is None or res.error is False
79 | assert res.api_response is not None
80 | assert res.parsed_outputs is None
81 |
82 | fetch_prompts.reset_mock()
83 | async_log_to_cloud.reset_mock()
84 | mocker.patch(
85 | "promptmodel.utils.config_utils.read_config",
86 | return_value={"connection": {"initializing": True}},
87 | )
88 | res = await function_model.arun()
89 | print(res)
90 | fetch_prompts.assert_not_called()
91 | async_log_to_cloud.assert_not_called()
92 |
93 |
94 | def test_stream(
95 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
96 | ):
97 | fetch_prompts = mocker.patch(
98 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
99 | )
100 | async_log_to_cloud = mocker.patch(
101 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
102 | mock_async_log_to_cloud,
103 | )
104 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
105 | function_model = FunctionModel("test")
106 |
107 | res: Generator[LLMStreamResponse, None, None] = function_model.stream({})
108 | chunks: List[LLMStreamResponse] = []
109 | for chunk in res:
110 | chunks.append(chunk)
111 | fetch_prompts.assert_called_once()
112 | async_log_to_cloud.assert_called_once()
113 |
114 | assert chunks[-1].api_response is not None
115 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
116 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
117 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
118 |
119 | fetch_prompts.reset_mock()
120 | async_log_to_cloud.reset_mock()
121 | mocker.patch(
122 | "promptmodel.utils.config_utils.read_config",
123 | return_value={"connection": {"initializing": True}},
124 | )
125 | res = function_model.stream()
126 | print(res)
127 | fetch_prompts.assert_not_called()
128 | async_log_to_cloud.assert_not_called()
129 |
130 |
131 | @pytest.mark.asyncio
132 | async def test_astream(
133 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
134 | ):
135 | fetch_prompts = mocker.patch(
136 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
137 | )
138 | async_log_to_cloud = mocker.patch(
139 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
140 | mock_async_log_to_cloud,
141 | )
142 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
143 | function_model = FunctionModel("test")
144 |
145 | res: AsyncGenerator[LLMStreamResponse, None] = await function_model.astream({})
146 | chunks: List[LLMStreamResponse] = []
147 | async for chunk in res:
148 | chunks.append(chunk)
149 | fetch_prompts.assert_called_once()
150 | async_log_to_cloud.assert_called_once()
151 |
152 | assert chunks[-1].api_response is not None
153 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
154 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
155 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
156 |
157 | fetch_prompts.reset_mock()
158 | async_log_to_cloud.reset_mock()
159 | mocker.patch(
160 | "promptmodel.utils.config_utils.read_config",
161 | return_value={"connection": {"initializing": True}},
162 | )
163 | res = await function_model.astream({})
164 | print(res)
165 | fetch_prompts.assert_not_called()
166 | async_log_to_cloud.assert_not_called()
167 |
168 |
169 | def test_run_and_parse(
170 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
171 | ):
172 | fetch_prompts = mocker.patch(
173 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
174 | )
175 | async_log_to_cloud = mocker.patch(
176 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
177 | mock_async_log_to_cloud,
178 | )
179 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
180 | function_model = FunctionModel("test")
181 | res: LLMResponse = function_model.run_and_parse({})
182 | fetch_prompts.assert_called_once()
183 | async_log_to_cloud.assert_called_once()
184 | assert res.raw_output is not None
185 | assert res.error is None or res.error is False
186 | assert res.api_response is not None
187 | assert res.parsed_outputs == {}
188 |
189 | fetch_prompts.reset_mock()
190 | async_log_to_cloud.reset_mock()
191 | mocker.patch(
192 | "promptmodel.utils.config_utils.read_config",
193 | return_value={"connection": {"initializing": True}},
194 | )
195 | res = function_model.run_and_parse({})
196 | print(res)
197 | fetch_prompts.assert_not_called()
198 | async_log_to_cloud.assert_not_called()
199 |
200 |
201 | @pytest.mark.asyncio
202 | async def test_arun_and_parse(
203 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
204 | ):
205 | fetch_prompts = mocker.patch(
206 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
207 | )
208 | async_log_to_cloud = mocker.patch(
209 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
210 | mock_async_log_to_cloud,
211 | )
212 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
213 | function_model = FunctionModel("test")
214 |
215 | res: LLMResponse = await function_model.arun_and_parse({})
216 | fetch_prompts.assert_called_once()
217 | async_log_to_cloud.assert_called_once()
218 | assert res.raw_output is not None
219 | assert res.error is None or res.error is False
220 | assert res.api_response is not None
221 | assert res.parsed_outputs == {}
222 |
223 | fetch_prompts.reset_mock()
224 | async_log_to_cloud.reset_mock()
225 | mocker.patch(
226 | "promptmodel.utils.config_utils.read_config",
227 | return_value={"connection": {"initializing": True}},
228 | )
229 | res = await function_model.arun_and_parse({})
230 | print(res)
231 | fetch_prompts.assert_not_called()
232 | async_log_to_cloud.assert_not_called()
233 |
234 |
235 | def test_stream_and_parse(
236 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
237 | ):
238 | fetch_prompts = mocker.patch(
239 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
240 | )
241 | async_log_to_cloud = mocker.patch(
242 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
243 | mock_async_log_to_cloud,
244 | )
245 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
246 | function_model = FunctionModel("test")
247 |
248 | res: Generator[LLMStreamResponse, None, None] = function_model.stream_and_parse({})
249 | chunks: List[LLMStreamResponse] = []
250 | for chunk in res:
251 | chunks.append(chunk)
252 | fetch_prompts.assert_called_once()
253 | async_log_to_cloud.assert_called_once()
254 |
255 | assert chunks[-1].api_response is not None
256 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
257 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
258 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
259 |
260 | fetch_prompts.reset_mock()
261 | async_log_to_cloud.reset_mock()
262 | mocker.patch(
263 | "promptmodel.utils.config_utils.read_config",
264 | return_value={"connection": {"initializing": True}},
265 | )
266 | res = function_model.stream_and_parse({})
267 | print(res)
268 | fetch_prompts.assert_not_called()
269 | async_log_to_cloud.assert_not_called()
270 |
271 |
272 | @pytest.mark.asyncio
273 | async def test_astream_and_parse(
274 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
275 | ):
276 | fetch_prompts = mocker.patch(
277 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
278 | )
279 | async_log_to_cloud = mocker.patch(
280 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
281 | mock_async_log_to_cloud,
282 | )
283 | mocker.patch("promptmodel.function_model.RegisteringMeta", MagicMock())
284 | function_model = FunctionModel("test")
285 |
286 | res: AsyncGenerator[
287 | LLMStreamResponse, None
288 | ] = await function_model.astream_and_parse({})
289 | chunks: List[LLMStreamResponse] = []
290 | async for chunk in res:
291 | chunks.append(chunk)
292 | fetch_prompts.assert_called_once()
293 | async_log_to_cloud.assert_called_once()
294 |
295 | assert chunks[-1].api_response is not None
296 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
297 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
298 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
299 |
300 | fetch_prompts.reset_mock()
301 | async_log_to_cloud.reset_mock()
302 | mocker.patch(
303 | "promptmodel.utils.config_utils.read_config",
304 | return_value={"connection": {"initializing": True}},
305 | )
306 | res = await function_model.astream_and_parse({})
307 | print(res)
308 | fetch_prompts.assert_not_called()
309 | async_log_to_cloud.assert_not_called()
310 |
--------------------------------------------------------------------------------
/tests/function_model/registering_meta_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from promptmodel.function_model import RegisteringMeta
5 |
6 |
7 | def test_registering_meta(mocker):
8 | # Fail to find DevClient instance
9 | client_instance = RegisteringMeta.find_client_instance()
10 | assert client_instance is None
11 |
--------------------------------------------------------------------------------
/tests/llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/__init__.py
--------------------------------------------------------------------------------
/tests/llm/function_call_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | from ..constants import function_shemas
5 | from promptmodel.llms.llm import LLM
6 | from promptmodel.types.response import *
7 | from promptmodel.types.enums import ParsingType
8 |
9 | html_output_format = """\
10 | You must follow the provided output format. Keep the string between <> as it is.
11 | Output format:
12 | <response type=str>
13 | (value here)
14 | </response>
15 | """
16 |
17 | double_bracket_output_format = """\
18 | You must follow the provided output format. Keep the string between [[ ]] as it is.
19 | Output format:
20 | [[response type=str]]
21 | (value here)
22 | [[/response]]
23 | """
24 |
25 | colon_output_format = """\
26 | You must follow the provided output format. Keep the string before : as it is
27 | Output format:
28 | response type=str: (value here)
29 |
30 | """
31 |
32 |
33 | def test_run_with_functions(mocker):
34 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
35 | llm = LLM()
36 | res: LLMResponse = llm.run(
37 | messages=messages,
38 | functions=function_shemas,
39 | model="gpt-3.5-turbo-0613",
40 | )
41 |
42 | assert res.error is None, "error is not None"
43 | assert res.api_response is not None, "api_response is None"
44 | assert (
45 | res.api_response.choices[0].finish_reason == "function_call"
46 | ), "finish_reason is not function_call"
47 | print(res.api_response.model_dump())
48 | print(res.__dict__)
49 | assert res.function_call is not None, "function_call is None"
50 | assert isinstance(res.function_call, FunctionCall)
51 |
52 | messages = [{"role": "user", "content": "Hello, How are you?"}]
53 |
54 | res: LLMResponse = llm.run(
55 | messages=messages,
56 | functions=function_shemas,
57 | model="gpt-3.5-turbo-0613",
58 | )
59 |
60 | assert res.error is None, "error is not None"
61 | assert res.api_response is not None, "api_response is None"
62 | assert (
63 | res.api_response.choices[0].finish_reason == "stop"
64 | ), "finish_reason is not stop"
65 |
66 | assert res.function_call is None, "function_call is not None"
67 | assert res.raw_output is not None, "raw_output is None"
68 |
69 |
70 | @pytest.mark.asyncio
71 | async def test_arun_with_functions(mocker):
72 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
73 | llm = LLM()
74 | res: LLMResponse = await llm.arun(
75 | messages=messages,
76 | functions=function_shemas,
77 | model="gpt-3.5-turbo-0613",
78 | )
79 |
80 | assert res.error is None, "error is not None"
81 | assert res.api_response is not None, "api_response is None"
82 | assert (
83 | res.api_response.choices[0].finish_reason == "function_call"
84 | ), "finish_reason is not function_call"
85 |
86 | assert res.function_call is not None, "function_call is None"
87 | assert isinstance(res.function_call, FunctionCall)
88 |
89 | messages = [{"role": "user", "content": "Hello, How are you?"}]
90 |
91 | res: LLMResponse = await llm.arun(
92 | messages=messages,
93 | functions=function_shemas,
94 | model="gpt-3.5-turbo-0613",
95 | )
96 |
97 | assert res.error is None, "error is not None"
98 | assert res.api_response is not None, "api_response is None"
99 | assert (
100 | res.api_response.choices[0].finish_reason == "stop"
101 | ), "finish_reason is not stop"
102 |
103 | assert res.function_call is None, "function_call is not None"
104 | assert res.raw_output is not None, "raw_output is None"
105 |
106 |
107 | def test_run_and_parse_with_functions(mocker):
108 | # With parsing_type = None
109 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
110 | llm = LLM()
111 | res: LLMResponse = llm.run_and_parse(
112 | messages=messages,
113 | functions=function_shemas,
114 | model="gpt-3.5-turbo-0613",
115 | parsing_type=None,
116 | )
117 | assert res.error is False, "error is not False"
118 | assert res.api_response is not None, "api_response is None"
119 | assert (
120 | res.api_response.choices[0].finish_reason == "function_call"
121 | ), "finish_reason is not function_call"
122 |
123 | assert res.function_call is not None, "function_call is None"
124 | assert isinstance(res.function_call, FunctionCall)
125 |
126 | messages = [{"role": "user", "content": "Hello, How are you?"}]
127 |
128 | res: LLMResponse = llm.run_and_parse(
129 | messages=messages,
130 | functions=function_shemas,
131 | model="gpt-3.5-turbo-0613",
132 | parsing_type=None,
133 | )
134 |
135 | assert res.error is False, "error is not False"
136 | assert res.api_response is not None, "api_response is None"
137 | assert (
138 | res.api_response.choices[0].finish_reason == "stop"
139 | ), "finish_reason is not stop"
140 |
141 | assert res.function_call is None, "function_call is not None"
142 | assert res.raw_output is not None, "raw_output is None"
143 |
144 | # with parsing_type = "HTML"
145 |
146 | messages = [
147 | {
148 | "role": "user",
149 | "content": "What is the weather like in Boston? \n" + html_output_format,
150 | }
151 | ]
152 | llm = LLM()
153 | res: LLMResponse = llm.run_and_parse(
154 | messages=messages,
155 | functions=function_shemas,
156 | model="gpt-3.5-turbo-0613",
157 | parsing_type=ParsingType.HTML.value,
158 | output_keys=["response"],
159 | )
160 |
161 |     # 1. Output format respected + function call -> (Pass)
162 |     # 2. Output format respected + stop -> OK
163 |     # 3. Output format ignored + function call -> OK (parsing is skipped when a function call appears)
164 |
165 |     # In this case, error is False because parsing is not performed when a function call is returned
166 | assert res.error is False, "error is not False"
167 | assert res.api_response is not None, "api_response is None"
168 | assert (
169 | res.api_response.choices[0].finish_reason == "function_call"
170 | ), "finish_reason is not function_call"
171 |
172 | assert res.function_call is not None, "function_call is None"
173 | assert isinstance(res.function_call, FunctionCall)
174 |
175 |     assert res.parsed_outputs is None, "parsed_outputs is not None"
176 |
177 | messages = [
178 | {
179 | "role": "user",
180 | "content": "Hello, How are you?\n" + html_output_format,
181 | }
182 | ]
183 |
184 | res: LLMResponse = llm.run_and_parse(
185 | messages=messages,
186 | functions=function_shemas,
187 | model="gpt-4-1106-preview",
188 | parsing_type=ParsingType.HTML.value,
189 | output_keys=["response"],
190 | )
191 |
192 | print(res.__dict__)
193 |
194 | if not "str" in res.raw_output:
195 | # if "str" in res.raw_output, it means that LLM make mistakes
196 | assert res.error is False, "error is not False"
197 | assert res.api_response is not None, "api_response is None"
198 | assert (
199 | res.api_response.choices[0].finish_reason == "stop"
200 | ), "finish_reason is not stop"
201 |
202 | assert res.function_call is None, "function_call is not None"
203 | assert res.raw_output is not None, "raw_output is None"
204 | assert res.parsed_outputs != {}, "parsed_outputs is empty dict"
205 |
206 |
207 | @pytest.mark.asyncio
208 | async def test_arun_and_parse_with_functions(mocker):
209 | # With parsing_type = None
210 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
211 | llm = LLM()
212 | res: LLMResponse = await llm.arun_and_parse(
213 | messages=messages,
214 | functions=function_shemas,
215 | model="gpt-3.5-turbo-0613",
216 | parsing_type=None,
217 | )
218 |
219 | assert res.error is False, "error is not False"
220 | assert res.api_response is not None, "api_response is None"
221 | assert (
222 | res.api_response.choices[0].finish_reason == "function_call"
223 | ), "finish_reason is not function_call"
224 |
225 | assert res.function_call is not None, "function_call is None"
226 | assert isinstance(res.function_call, FunctionCall)
227 |
228 | messages = [{"role": "user", "content": "Hello, How are you?"}]
229 |
230 | res: LLMResponse = await llm.arun_and_parse(
231 | messages=messages,
232 | functions=function_shemas,
233 | model="gpt-3.5-turbo-0613",
234 | parsing_type=None,
235 | )
236 |
237 | assert res.error is False, "error is not False"
238 | assert res.api_response is not None, "api_response is None"
239 | assert (
240 | res.api_response.choices[0].finish_reason == "stop"
241 | ), "finish_reason is not stop"
242 |
243 | assert res.function_call is None, "function_call is not None"
244 | assert res.raw_output is not None, "raw_output is None"
245 |
246 | # with parsing_type = "HTML"
247 |
248 | messages = [
249 | {
250 | "role": "user",
251 | "content": "What is the weather like in Boston? \n" + html_output_format,
252 | }
253 | ]
254 | llm = LLM()
255 | res: LLMResponse = await llm.arun_and_parse(
256 | messages=messages,
257 | functions=function_shemas,
258 | model="gpt-3.5-turbo-0613",
259 | parsing_type=ParsingType.HTML.value,
260 | output_keys=["response"],
261 | )
262 |
263 |     # In this case, error is False because parsing is not performed when a function_call is returned
264 | assert res.error is False, "error is not False"
265 | assert res.api_response is not None, "api_response is None"
266 | assert (
267 | res.api_response.choices[0].finish_reason == "function_call"
268 | ), "finish_reason is not function_call"
269 |
270 | assert res.function_call is not None, "function_call is None"
271 | assert isinstance(res.function_call, FunctionCall)
272 |
273 |     assert res.parsed_outputs is None, "parsed_outputs is not None"
274 |
275 | messages = [
276 | {
277 | "role": "user",
278 | "content": "Hello, How are you?\n" + html_output_format,
279 | }
280 | ]
281 |
282 | res: LLMResponse = await llm.arun_and_parse(
283 | messages=messages,
284 | functions=function_shemas,
285 | model="gpt-4-1106-preview",
286 | parsing_type=ParsingType.HTML.value,
287 | output_keys=["response"],
288 | )
289 | # if not "str" in res.raw_output:
290 | # # if "str" in res.raw_output, it means that LLM make mistakes
291 | assert res.error is False, "error is not False"
292 | assert res.parsed_outputs != {}, "parsed_outputs is empty"
293 |
294 | assert res.api_response is not None, "api_response is None"
295 | assert (
296 | res.api_response.choices[0].finish_reason == "stop"
297 | ), "finish_reason is not stop"
298 |
299 | assert res.function_call is None, "function_call is not None"
300 | assert res.raw_output is not None, "raw_output is None"
301 |
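
For reference, a hedged sketch of what a completion that follows `html_output_format` looks like, and roughly the `parsed_outputs` the assertions above expect when `parsing_type=ParsingType.HTML.value` and `output_keys=["response"]` are used (exact whitespace handling depends on the parser):

```python
# A raw completion that respects html_output_format:
raw_output = """\
<response type=str>
I'm doing well, thank you! How can I help you today?
</response>"""

# After HTML-style parsing, the tests expect a non-empty dict keyed by the
# requested output key, roughly:
expected_parsed_outputs = {
    "response": "I'm doing well, thank you! How can I help you today?"
}
assert expected_parsed_outputs != {}  # mirrors the `parsed_outputs != {}` checks above
```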
--------------------------------------------------------------------------------
/tests/llm/llm_dev/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/llm_dev/__init__.py
--------------------------------------------------------------------------------
/tests/llm/llm_dev/llm_dev_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/llm_dev/llm_dev_test.py
--------------------------------------------------------------------------------
/tests/llm/llm_proxy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/llm/llm_proxy/__init__.py
--------------------------------------------------------------------------------
/tests/llm/llm_proxy/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from promptmodel.llms.llm_proxy import LLMProxy
5 |
6 |
7 | async def echo_coroutine(*args, **kwargs):
8 | # print(args, kwargs)
9 | return args, kwargs
10 |
11 |
12 | @pytest.fixture
13 | def mock_fetch_prompts():
14 | mock_fetch_prompts = AsyncMock()
15 | mock_prompts = [
16 | {"role": "system", "content": "You are a helpful assistant.", "step": 1},
17 | {"role": "user", "content": "Hello!", "step": 2},
18 | ]
19 | mock_version_details = {
20 | "model": "gpt-3.5-turbo",
21 | "uuid": "testuuid",
22 | "version": 1,
23 | "parsing_type": None,
24 | "output_keys": None,
25 | }
26 | mock_fetch_prompts.return_value = (mock_prompts, mock_version_details)
27 |
28 | return mock_fetch_prompts
29 |
30 |
31 | @pytest.fixture
32 | def mock_fetch_chat_model():
33 | mock_fetch_chat_model = AsyncMock()
34 | mock_instruction = "You are a helpful assistant."
35 | mock_version_details = {
36 | "model": "gpt-4-1106-preview",
37 | "uuid": "testuuid",
38 | "version": 1,
39 | }
40 | mock_message_logs = [
41 | {
42 | "role": "system",
43 | "content": "You are a helpful assistant.",
44 | "session_uuid": "testuuid",
45 | },
46 | {"role": "user", "content": "Hello!", "session_uuid": "testuuid"},
47 | {
48 | "role": "assistant",
49 | "content": "Hello! How can I help you?",
50 | "session_uuid": "testuuid",
51 | },
52 | ]
53 |
54 | mock_fetch_chat_model.return_value = (
55 | mock_instruction,
56 | mock_version_details,
57 | mock_message_logs,
58 | )
59 |
60 | return mock_fetch_chat_model
61 |
62 |
63 | @pytest.fixture
64 | def mock_async_log_to_cloud():
65 | mock_async_log_to_cloud = AsyncMock()
66 |
67 | mock_response = MagicMock()
68 | mock_response.status_code = 200
69 | mock_async_log_to_cloud.return_value = mock_response
70 |
71 | return mock_async_log_to_cloud
72 |
73 |
74 | @pytest.fixture
75 | def mock_async_chat_log_to_cloud():
76 | mock_async_chat_log_to_cloud = AsyncMock()
77 |
78 | mock_response = MagicMock()
79 | mock_response.status_code = 200
80 | mock_async_chat_log_to_cloud.return_value = mock_response
81 |
82 | return mock_async_chat_log_to_cloud
83 |
84 |
85 | @pytest.fixture
86 | def mock_async_make_session_cloud():
87 | mock_async_make_session_cloud = AsyncMock()
88 |
89 | mock_response = MagicMock()
90 | mock_response.status_code = 200
91 | mock_async_make_session_cloud.return_value = mock_response
92 |
93 | return mock_async_make_session_cloud
94 |
--------------------------------------------------------------------------------
/tests/llm/llm_proxy/proxy_chat_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock
3 |
4 | from typing import Generator, AsyncGenerator, List
5 |
6 | from promptmodel.llms.llm_proxy import LLMProxy
7 | from promptmodel.types.response import LLMResponse, LLMStreamResponse
8 |
9 | proxy = LLMProxy(name="test")
10 |
11 |
12 | def test_chat_run(
13 | mocker,
14 | mock_fetch_chat_model: AsyncMock,
15 | mock_async_chat_log_to_cloud: AsyncMock,
16 | ):
17 | fetch_chat_model = mocker.patch(
18 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
19 | )
20 |
21 | async_chat_log_to_cloud = mocker.patch(
22 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
23 | mock_async_chat_log_to_cloud,
24 | )
25 |
26 | res: LLMResponse = proxy.chat_run(session_uuid="testuuid")
27 |
28 | fetch_chat_model.assert_called_once()
29 | async_chat_log_to_cloud.assert_called_once()
30 |
31 | assert res.raw_output is not None
32 | assert res.error is None or res.error is False
33 | assert res.api_response is not None
34 | assert res.parsed_outputs is None
35 | if isinstance(res.api_response.usage, dict):
36 | assert res.api_response.usage["prompt_tokens"] > 15
37 | print(res.api_response.usage["prompt_tokens"])
38 | else:
39 | assert res.api_response.usage.prompt_tokens > 15
40 | print(res.api_response.usage.prompt_tokens)
41 |
42 |
43 | @pytest.mark.asyncio
44 | async def test_chat_arun(
45 | mocker,
46 | mock_fetch_chat_model: AsyncMock,
47 | mock_async_chat_log_to_cloud: AsyncMock,
48 | ):
49 | fetch_chat_model = mocker.patch(
50 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
51 | )
52 |
53 | async_chat_log_to_cloud = mocker.patch(
54 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
55 | mock_async_chat_log_to_cloud,
56 | )
57 |
58 | res: LLMResponse = await proxy.chat_arun(session_uuid="testuuid")
59 |
60 | fetch_chat_model.assert_called_once()
61 | async_chat_log_to_cloud.assert_called_once()
62 |
63 | assert res.raw_output is not None
64 | assert res.error is None or res.error is False
65 | assert res.api_response is not None
66 | assert res.parsed_outputs is None
67 |
68 |
69 | def test_chat_stream(
70 | mocker,
71 | mock_fetch_chat_model: AsyncMock,
72 | mock_async_chat_log_to_cloud: AsyncMock,
73 | ):
74 | fetch_chat_model = mocker.patch(
75 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
76 | )
77 |
78 | async_chat_log_to_cloud = mocker.patch(
79 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
80 | mock_async_chat_log_to_cloud,
81 | )
82 |
83 | res: Generator[LLMStreamResponse, None, None] = proxy.chat_stream(
84 | session_uuid="testuuid"
85 | )
86 | chunks: List[LLMStreamResponse] = []
87 | for chunk in res:
88 | chunks.append(chunk)
89 |
90 | fetch_chat_model.assert_called_once()
91 | async_chat_log_to_cloud.assert_called_once()
92 |
93 | assert chunks[-1].api_response is not None
94 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
95 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
96 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
97 |
98 |
99 | @pytest.mark.asyncio
100 | async def test_chat_astream(
101 | mocker,
102 | mock_fetch_chat_model: AsyncMock,
103 | mock_async_chat_log_to_cloud: AsyncMock,
104 | ):
105 | fetch_chat_model = mocker.patch(
106 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
107 | )
108 |
109 | async_chat_log_to_cloud = mocker.patch(
110 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
111 | mock_async_chat_log_to_cloud,
112 | )
113 |
114 | res: AsyncGenerator[LLMStreamResponse, None] = proxy.chat_astream(
115 | session_uuid="testuuid"
116 | )
117 | chunks: List[LLMStreamResponse] = []
118 | async for chunk in res:
119 | chunks.append(chunk)
120 |
121 | fetch_chat_model.assert_called_once()
122 | async_chat_log_to_cloud.assert_called_once()
123 |
124 | assert chunks[-1].api_response is not None
125 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
126 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
127 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
128 |
129 |
130 | def test_chat_run_extra_long_input(
131 | mocker,
132 | mock_fetch_chat_model: AsyncMock,
133 | mock_async_chat_log_to_cloud: AsyncMock,
134 | ):
135 | mocker.patch("promptmodel.llms.llm_proxy.get_max_tokens", return_value=10)
136 |
137 | fetch_chat_model = mocker.patch(
138 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_chat_model", mock_fetch_chat_model
139 | )
140 |
141 | async_chat_log_to_cloud = mocker.patch(
142 | "promptmodel.llms.llm_proxy.LLMProxy._async_chat_log_to_cloud",
143 | mock_async_chat_log_to_cloud,
144 | )
145 |
146 | res: LLMResponse = proxy.chat_run(session_uuid="testuuid")
147 |
148 | fetch_chat_model.assert_called_once()
149 | async_chat_log_to_cloud.assert_called_once()
150 |
151 | assert res.raw_output is not None
152 | assert res.error is None or res.error is False
153 | assert res.api_response is not None
154 | assert res.parsed_outputs is None
155 | if isinstance(res.api_response.usage, dict):
156 | assert res.api_response.usage["prompt_tokens"] < 15
157 | print(res.api_response.usage["prompt_tokens"])
158 | else:
159 | assert res.api_response.usage.prompt_tokens < 15
160 | print(res.api_response.usage.prompt_tokens)
161 |
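
`test_chat_run_extra_long_input` patches `get_max_tokens` down to 10 so that older chat logs no longer fit, which is why the logged `prompt_tokens` drops below 15. A hypothetical sketch of the truncation idea the test relies on (this is not the library's actual implementation):

```python
from typing import Callable, Dict, List

def truncate_message_logs(
    messages: List[Dict[str, str]],
    max_tokens: int,
    count_tokens: Callable[[str], int],
) -> List[Dict[str, str]]:
    """Drop the oldest non-system turns until the prompt fits the token budget."""
    kept = list(messages)
    while len(kept) > 1 and sum(count_tokens(m["content"]) for m in kept) > max_tokens:
        kept.pop(1)  # keep the system instruction at index 0, drop the oldest turn
    return kept

# Usage with a trivial word-based token counter, just for illustration:
logs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hello! How can I help you?"},
]
print(truncate_message_logs(logs, max_tokens=10, count_tokens=lambda s: len(s.split())))
```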
--------------------------------------------------------------------------------
/tests/llm/llm_proxy/proxy_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, MagicMock
3 |
4 | from typing import Generator, AsyncGenerator, List
5 |
6 | from promptmodel.llms.llm_proxy import LLMProxy
7 | from promptmodel.types.response import LLMResponse, LLMStreamResponse
8 |
9 | proxy = LLMProxy(name="test")
10 |
11 |
12 | def test_run(mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock):
13 | fetch_prompts = mocker.patch(
14 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
15 | )
16 | async_log_to_cloud = mocker.patch(
17 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
18 | mock_async_log_to_cloud,
19 | )
20 |
21 | res: LLMResponse = proxy.run({})
22 | fetch_prompts.assert_called_once()
23 | async_log_to_cloud.assert_called_once()
24 | assert res.raw_output is not None
25 | assert res.error is None or res.error is False
26 | assert res.api_response is not None
27 | assert res.parsed_outputs is None
28 |
29 |
30 | @pytest.mark.asyncio
31 | async def test_arun(
32 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
33 | ):
34 | fetch_prompts = mocker.patch(
35 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
36 | )
37 | async_log_to_cloud = mocker.patch(
38 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
39 | mock_async_log_to_cloud,
40 | )
41 |
42 | res: LLMResponse = await proxy.arun({})
43 | fetch_prompts.assert_called_once()
44 | async_log_to_cloud.assert_called_once()
45 | assert res.raw_output is not None
46 | assert res.error is None or res.error is False
47 | assert res.api_response is not None
48 | assert res.parsed_outputs is None
49 |
50 |
51 | def test_stream(
52 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
53 | ):
54 | fetch_prompts = mocker.patch(
55 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
56 | )
57 | async_log_to_cloud = mocker.patch(
58 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
59 | mock_async_log_to_cloud,
60 | )
61 |
62 | res: Generator[LLMStreamResponse, None, None] = proxy.stream({})
63 | chunks: List[LLMStreamResponse] = []
64 | for chunk in res:
65 | chunks.append(chunk)
66 | fetch_prompts.assert_called_once()
67 | async_log_to_cloud.assert_called_once()
68 |
69 | assert chunks[-1].api_response is not None
70 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
71 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
72 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
73 |
74 |
75 | @pytest.mark.asyncio
76 | async def test_astream(
77 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
78 | ):
79 | fetch_prompts = mocker.patch(
80 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
81 | )
82 | async_log_to_cloud = mocker.patch(
83 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
84 | mock_async_log_to_cloud,
85 | )
86 |
87 | res: AsyncGenerator[LLMStreamResponse, None] = proxy.astream({})
88 | chunks: List[LLMStreamResponse] = []
89 | async for chunk in res:
90 | chunks.append(chunk)
91 | fetch_prompts.assert_called_once()
92 | async_log_to_cloud.assert_called_once()
93 |
94 | assert chunks[-1].api_response is not None
95 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
96 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
97 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
98 |
99 |
100 | def test_run_and_parse(
101 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
102 | ):
103 | fetch_prompts = mocker.patch(
104 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
105 | )
106 | async_log_to_cloud = mocker.patch(
107 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
108 | mock_async_log_to_cloud,
109 | )
110 |
111 | res: LLMResponse = proxy.run({})
112 | fetch_prompts.assert_called_once()
113 | async_log_to_cloud.assert_called_once()
114 | assert res.raw_output is not None
115 | assert res.error is None or res.error is False
116 | assert res.api_response is not None
117 | assert res.parsed_outputs is None
118 |
119 | fetch_prompts.reset_mock()
120 | async_log_to_cloud.reset_mock()
121 |
122 | # mock run
123 | mock_run = MagicMock()
124 | mock_run.return_value = LLMResponse(
125 | api_response=res.api_response, parsed_outputs={"key": "value"}
126 | )
127 | mocker.patch("promptmodel.llms.llm.LLM.run", mock_run)
128 | mock_res = proxy.run({})
129 | fetch_prompts.assert_called_once()
130 | async_log_to_cloud.assert_called_once()
131 | assert mock_res.raw_output is None
132 |     assert mock_res.error is None or mock_res.error is False
133 | assert mock_res.api_response is not None
134 | assert mock_res.parsed_outputs is not None
135 |
136 |
137 | @pytest.mark.asyncio
138 | async def test_arun_and_parse(
139 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
140 | ):
141 | fetch_prompts = mocker.patch(
142 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
143 | )
144 | async_log_to_cloud = mocker.patch(
145 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
146 | mock_async_log_to_cloud,
147 | )
148 |
149 | res: LLMResponse = await proxy.arun({})
150 | fetch_prompts.assert_called_once()
151 | async_log_to_cloud.assert_called_once()
152 | assert res.raw_output is not None
153 | assert res.error is None or res.error is False
154 | assert res.api_response is not None
155 | assert res.parsed_outputs is None
156 |
157 | fetch_prompts.reset_mock()
158 | async_log_to_cloud.reset_mock()
159 |
160 | # mock run
161 | mock_run = AsyncMock()
162 | mock_run.return_value = LLMResponse(
163 | api_response=res.api_response, parsed_outputs={"key": "value"}
164 | )
165 | mocker.patch("promptmodel.llms.llm.LLM.arun", mock_run)
166 | mock_res: LLMResponse = await proxy.arun({})
167 | fetch_prompts.assert_called_once()
168 | async_log_to_cloud.assert_called_once()
169 | assert mock_res.raw_output is None
170 |     assert mock_res.error is None or mock_res.error is False
171 | assert mock_res.api_response is not None
172 | assert mock_res.parsed_outputs is not None
173 |
174 |
175 | def test_stream_and_parse(
176 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
177 | ):
178 | fetch_prompts = mocker.patch(
179 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
180 | )
181 | async_log_to_cloud = mocker.patch(
182 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
183 | mock_async_log_to_cloud,
184 | )
185 |
186 | res: Generator[LLMStreamResponse, None, None] = proxy.stream({})
187 | chunks: List[LLMStreamResponse] = []
188 | for chunk in res:
189 | chunks.append(chunk)
190 | fetch_prompts.assert_called_once()
191 | async_log_to_cloud.assert_called_once()
192 |
193 | assert chunks[-1].api_response is not None
194 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
195 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
196 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
197 |
198 | fetch_prompts.reset_mock()
199 | async_log_to_cloud.reset_mock()
200 |
201 | def mock_stream_generator(*args, **kwargs):
202 | yield LLMStreamResponse(parsed_outputs={"key": "value"})
203 | yield LLMStreamResponse(api_response=chunks[-1].api_response)
204 |
205 | mock_run = MagicMock(side_effect=mock_stream_generator)
206 | mocker.patch("promptmodel.llms.llm.LLM.stream", mock_run)
207 |
208 | mock_res: Generator[LLMStreamResponse, None, None] = proxy.stream({})
209 | mock_chunks: List[LLMStreamResponse] = []
210 | for chunk in mock_res:
211 | mock_chunks.append(chunk)
212 | fetch_prompts.assert_called_once()
213 | async_log_to_cloud.assert_called_once()
214 |
215 | assert mock_chunks[-1].api_response is not None
216 | assert len([chunk for chunk in mock_chunks if chunk.error is not None]) == 0
217 | assert len([chunk for chunk in mock_chunks if chunk.parsed_outputs is not None]) > 0
218 | assert len([chunk for chunk in mock_chunks if chunk.raw_output is not None]) == 0
219 |
220 |
221 | @pytest.mark.asyncio
222 | async def test_astream_and_parse(
223 | mocker, mock_fetch_prompts: AsyncMock, mock_async_log_to_cloud: AsyncMock
224 | ):
225 | fetch_prompts = mocker.patch(
226 | "promptmodel.llms.llm_proxy.LLMProxy.fetch_prompts", mock_fetch_prompts
227 | )
228 | async_log_to_cloud = mocker.patch(
229 | "promptmodel.llms.llm_proxy.LLMProxy._async_log_to_cloud",
230 | mock_async_log_to_cloud,
231 | )
232 |
233 | res: AsyncGenerator[LLMStreamResponse, None] = proxy.astream({})
234 | chunks: List[LLMStreamResponse] = []
235 | async for chunk in res:
236 | chunks.append(chunk)
237 | fetch_prompts.assert_called_once()
238 | async_log_to_cloud.assert_called_once()
239 |
240 | assert chunks[-1].api_response is not None
241 | assert len([chunk for chunk in chunks if chunk.error is not None]) == 0
242 | assert len([chunk for chunk in chunks if chunk.parsed_outputs is not None]) == 0
243 | assert len([chunk for chunk in chunks if chunk.raw_output is not None]) > 0
244 |
245 | fetch_prompts.reset_mock()
246 | async_log_to_cloud.reset_mock()
247 |
248 | async def mock_stream_generator(*args, **kwargs):
249 | yield LLMStreamResponse(parsed_outputs={"key": "value"})
250 | yield LLMStreamResponse(api_response=chunks[-1].api_response)
251 |
252 | mock_run = MagicMock(side_effect=mock_stream_generator)
253 | mocker.patch("promptmodel.llms.llm.LLM.astream", mock_run)
254 |
255 | mock_res: AsyncGenerator[LLMStreamResponse, None] = proxy.astream({})
256 | mock_chunks: List[LLMStreamResponse] = []
257 | async for chunk in mock_res:
258 | mock_chunks.append(chunk)
259 | fetch_prompts.assert_called_once()
260 | async_log_to_cloud.assert_called_once()
261 |
262 | assert mock_chunks[-1].api_response is not None
263 | assert len([chunk for chunk in mock_chunks if chunk.error is not None]) == 0
264 | assert len([chunk for chunk in mock_chunks if chunk.parsed_outputs is not None]) > 0
265 | assert len([chunk for chunk in mock_chunks if chunk.raw_output is not None]) == 0
266 |
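
The `stream_and_parse` tests above use a handy `unittest.mock` pattern: a `MagicMock` whose `side_effect` is a generator function, so every call to the patched method returns a fresh generator. A minimal standalone sketch of the same pattern (names are illustrative):

```python
from unittest.mock import MagicMock

def fake_stream(*args, **kwargs):
    yield "chunk-1"
    yield "chunk-2"

# Each call runs fake_stream(...) and returns the new generator it produces.
mock_stream = MagicMock(side_effect=fake_stream)

assert list(mock_stream()) == ["chunk-1", "chunk-2"]
assert list(mock_stream()) == ["chunk-1", "chunk-2"]
assert mock_stream.call_count == 2
```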
--------------------------------------------------------------------------------
/tests/llm/stream_function_call_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, MagicMock
3 |
4 | import nest_asyncio
5 | from typing import Generator, AsyncGenerator, List, Optional
6 | from litellm import ModelResponse
7 |
8 | from ..constants import function_shemas
9 | from promptmodel.llms.llm import LLM
10 | from promptmodel.llms.llm_proxy import LLMProxy
11 | from promptmodel.types.response import *
12 | from promptmodel.types.enums import ParsingType
13 | from promptmodel.utils.async_utils import run_async_in_sync
14 |
15 | html_output_format = """\
16 | You must follow the provided output format. Keep the string between <> as it is.
17 | Output format:
18 | <response type=str>
19 | (value here)
20 | </response>
21 | """
22 |
23 | double_bracket_output_format = """\
24 | You must follow the provided output format. Keep the string between [[ ]] as it is.
25 | Output format:
26 | [[response type=str]]
27 | (value here)
28 | [[/response]]
29 | """
30 |
31 | colon_output_format = """\
32 | You must follow the provided output format. Keep the string before : as it is
33 | Output format:
34 | response type=str: (value here)
35 |
36 | """
37 |
38 |
39 | def test_stream_with_functions(mocker):
40 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
41 | llm = LLM()
42 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream(
43 | messages=messages,
44 | functions=function_shemas,
45 | model="gpt-3.5-turbo-0613",
46 | )
47 |
48 | error_count = 0
49 | api_responses: List[ModelResponse] = []
50 | final_response: Optional[LLMStreamResponse] = None
51 | chunks: List[LLMStreamResponse] = []
52 | for res in stream_res:
53 | chunks.append(res)
54 | if res.error:
55 | error_count += 1
56 | if (
57 | res.api_response
58 | and getattr(res.api_response.choices[0], "delta", None) is None
59 | ):
60 | api_responses.append(res.api_response)
61 | final_response = res
62 |
63 | assert error_count == 0, "error_count is not 0"
64 | assert len(api_responses) == 1, "api_count is not 1"
65 | assert (
66 | getattr(final_response.api_response.choices[0].message, "function_call", None)
67 | is not None
68 | ), "function_call is None"
69 | assert isinstance(
70 | final_response.api_response.choices[0].message.function_call, FunctionCall
71 | )
72 |
73 | assert len([c for c in chunks if c.function_call is not None]) > 0
74 |     assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall)
75 |
76 | assert api_responses[0].choices[0].message.content is None, "content is not None"
77 | assert api_responses[0]._response_ms is not None, "response_ms is None"
78 | assert api_responses[0].usage is not None, "usage is None"
79 |
80 | # test logging
81 | llm_proxy = LLMProxy("test")
82 |
83 | mock_execute = AsyncMock()
84 | mock_response = MagicMock()
85 | mock_response.status_code = 200
86 | mock_execute.return_value = mock_response
87 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute)
88 |
89 | nest_asyncio.apply()
90 | run_async_in_sync(
91 | llm_proxy._async_log_to_cloud(
92 | log_uuid="test",
93 | version_uuid="test",
94 | inputs={},
95 | api_response=api_responses[0],
96 | parsed_outputs={},
97 | metadata={},
98 | )
99 | )
100 |
101 | mock_execute.assert_called_once()
102 | _, kwargs = mock_execute.call_args
103 | api_response_dict = api_responses[0].model_dump()
104 | api_response_dict.update({"response_ms": api_responses[0]._response_ms})
105 |
106 | assert (
107 | kwargs["json"]["api_response"] == api_response_dict
108 | ), "api_response is not equal"
109 |
110 |
111 | @pytest.mark.asyncio
112 | async def test_astream_with_functions(mocker):
113 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
114 | llm = LLM()
115 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream(
116 | messages=messages,
117 | functions=function_shemas,
118 | model="gpt-3.5-turbo-0613",
119 | )
120 |
121 | error_count = 0
122 | api_responses: List[ModelResponse] = []
123 | final_response: Optional[LLMStreamResponse] = None
124 | chunks: List[LLMStreamResponse] = []
125 | async for res in stream_res:
126 | chunks.append(res)
127 | if res.error:
128 | error_count += 1
129 | if (
130 | res.api_response
131 | and getattr(res.api_response.choices[0], "delta", None) is None
132 | ):
133 | api_responses.append(res.api_response)
134 | final_response = res
135 |
136 | assert error_count == 0, "error_count is not 0"
137 | assert len(api_responses) == 1, "api_count is not 1"
138 | assert (
139 | getattr(final_response.api_response.choices[0].message, "function_call", None)
140 | is not None
141 | ), "function_call is None"
142 | assert isinstance(
143 | final_response.api_response.choices[0].message.function_call, FunctionCall
144 | )
145 | assert len([c for c in chunks if c.function_call is not None]) > 0
146 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall)
147 |
148 | assert api_responses[0].choices[0].message.content is None, "content is not None"
149 | assert api_responses[0]._response_ms is not None, "response_ms is None"
150 | assert api_responses[0].usage is not None, "usage is None"
151 | # assert api_responses[0].usage.prompt_tokens == 74, "prompt_tokens is not 74"
152 |
153 | # test logging
154 | llm_proxy = LLMProxy("test")
155 |
156 | mock_execute = AsyncMock()
157 | mock_response = MagicMock()
158 | mock_response.status_code = 200
159 | mock_execute.return_value = mock_response
160 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute)
161 | nest_asyncio.apply()
162 | await llm_proxy._async_log_to_cloud(
163 | log_uuid="test",
164 | version_uuid="test",
165 | inputs={},
166 | api_response=api_responses[0],
167 | parsed_outputs={},
168 | metadata={},
169 | )
170 |
171 | mock_execute.assert_called_once()
172 | _, kwargs = mock_execute.call_args
173 | api_response_dict = api_responses[0].model_dump()
174 | api_response_dict.update({"response_ms": api_responses[0]._response_ms})
175 |
176 | assert (
177 | kwargs["json"]["api_response"] == api_response_dict
178 | ), "api_response is not equal"
179 |
180 |
181 | def test_stream_and_parse_with_functions(mocker):
182 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
183 | llm = LLM()
184 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream_and_parse(
185 | messages=messages,
186 | functions=function_shemas,
187 | model="gpt-3.5-turbo-0613",
188 | parsing_type=None,
189 | )
190 |
191 | error_count = 0
192 | api_responses: List[ModelResponse] = []
193 | final_response: Optional[LLMStreamResponse] = None
194 | chunks: List[LLMStreamResponse] = []
195 | for res in stream_res:
196 | chunks.append(res)
197 | if res.error:
198 | error_count += 1
199 | print("ERROR")
200 | print(res.error)
201 | print(res.error_log)
202 | if (
203 | res.api_response
204 | and getattr(res.api_response.choices[0], "delta", None) is None
205 | ):
206 | api_responses.append(res.api_response)
207 | final_response = res
208 |
209 | assert error_count == 0, "error_count is not 0"
210 | assert len(api_responses) == 1, "api_count is not 1"
211 | assert (
212 | getattr(final_response.api_response.choices[0].message, "function_call", None)
213 | is not None
214 | ), "function_call is None"
215 | assert isinstance(
216 | final_response.api_response.choices[0].message.function_call, FunctionCall
217 | )
218 | assert len([c for c in chunks if c.function_call is not None]) > 0
219 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall)
220 |
221 | assert api_responses[0].choices[0].message.content is None, "content is not None"
222 |
223 | # Not call function, parsing case
224 | messages = [
225 | {
226 | "role": "user",
227 | "content": "Hello, How are you?\n" + html_output_format,
228 | }
229 | ]
230 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream_and_parse(
231 | messages=messages,
232 | functions=function_shemas,
233 | model="gpt-4-1106-preview",
234 | parsing_type=ParsingType.HTML.value,
235 | output_keys=["response"],
236 | )
237 |
238 | error_count = 0
239 | error_log = ""
240 | api_responses: List[ModelResponse] = []
241 | final_response: Optional[LLMStreamResponse] = None
242 | chunks: List[LLMStreamResponse] = []
243 | for res in stream_res:
244 | chunks.append(res)
245 | if res.error:
246 | error_count += 1
247 | error_log = res.error_log
248 | print("ERROR")
249 | if (
250 | res.api_response
251 | and getattr(res.api_response.choices[0], "delta", None) is None
252 | ):
253 | api_responses.append(res.api_response)
254 | final_response = res
255 |
256 | if not "str" in api_responses[0].choices[0].message.content:
257 | # if "str" in content, just LLM make mistake in generation.
258 | assert (
259 | error_count == 0
260 | ), f"error_count is not 0, {error_log}, {api_responses[0].model_dump()}"
261 | assert final_response.parsed_outputs is not None, "parsed_outputs is None"
262 | assert len(api_responses) == 1, "api_count is not 1"
263 | assert (
264 | getattr(final_response.api_response.choices[0].message, "function_call", None)
265 | is None
266 | ), "function_call is not None"
267 | assert len([c for c in chunks if c.function_call is not None]) == 0
268 |
269 | assert api_responses[0].choices[0].message.content is not None, "content is None"
270 |
271 |
272 | @pytest.mark.asyncio
273 | async def test_astream_and_parse_with_functions(mocker):
274 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
275 | llm = LLM()
276 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream_and_parse(
277 | messages=messages,
278 | functions=function_shemas,
279 | model="gpt-3.5-turbo-0613",
280 | parsing_type=None,
281 | )
282 |
283 | error_count = 0
284 | api_responses: List[ModelResponse] = []
285 | final_response: Optional[LLMStreamResponse] = None
286 | chunks: List[LLMStreamResponse] = []
287 | async for res in stream_res:
288 | chunks.append(res)
289 | if res.error:
290 | error_count += 1
291 |
292 | if (
293 | res.api_response
294 | and getattr(res.api_response.choices[0], "delta", None) is None
295 | ):
296 | api_responses.append(res.api_response)
297 | final_response = res
298 |
299 | assert error_count == 0, "error_count is not 0"
300 | assert len(api_responses) == 1, "api_count is not 1"
301 |
302 | assert api_responses[0].choices[0].message.content is None, "content is not None"
303 | assert (
304 | getattr(final_response.api_response.choices[0].message, "function_call", None)
305 | is not None
306 | ), "function_call is None"
307 | assert isinstance(
308 | final_response.api_response.choices[0].message.function_call, FunctionCall
309 | )
310 | assert len([c for c in chunks if c.function_call is not None]) > 0
311 | assert isinstance(chunks[1].function_call, ChoiceDeltaFunctionCall)
312 |
313 | # Not call function, parsing case
314 | messages = [
315 | {
316 | "role": "user",
317 | "content": "Hello, How are you?\n" + html_output_format,
318 | }
319 | ]
320 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream_and_parse(
321 | messages=messages,
322 | functions=function_shemas,
323 | model="gpt-3.5-turbo-0613",
324 | parsing_type=ParsingType.HTML.value,
325 | output_keys=["response"],
326 | )
327 |
328 | error_count = 0
329 | error_log = ""
330 | api_responses: List[ModelResponse] = []
331 | final_response: Optional[LLMStreamResponse] = None
332 | chunks: List[LLMStreamResponse] = []
333 | async for res in stream_res:
334 | chunks.append(res)
335 | if res.error:
336 | error_count += 1
337 | error_log = res.error_log
338 |
339 | if (
340 | res.api_response
341 | and getattr(res.api_response.choices[0], "delta", None) is None
342 | ):
343 | api_responses.append(res.api_response)
344 | final_response = res
345 |
346 | if not "str" in api_responses[0].choices[0].message.content:
347 | # if "str" in content, just LLM make mistake in generation.
348 | assert (
349 | error_count == 0
350 | ), f"error_count is not 0, {error_log}, {api_responses[0].model_dump()}"
351 | assert final_response.parsed_outputs is not None, "parsed_outputs is None"
352 | assert len(api_responses) == 1, "api_count is not 1"
353 | assert (
354 | getattr(final_response.api_response.choices[0].message, "function_call", None)
355 | is None
356 | ), "function_call is not None"
357 |
358 | assert len([c for c in chunks if c.function_call is not None]) == 0
359 |
360 | assert api_responses[0].choices[0].message.content is not None, "content is None"
361 |
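
The logging assertions above unpack `mock_execute.call_args` to verify the JSON payload sent to the cloud. A minimal standalone reminder of how `call_args` unpacks, using illustrative values:

```python
import asyncio
from unittest.mock import AsyncMock

mock_execute = AsyncMock()

async def main() -> None:
    await mock_execute("/log", json={"api_response": {"id": "abc"}, "response_ms": 12})
    args, kwargs = mock_execute.call_args  # positional args tuple, keyword args dict
    assert args == ("/log",)
    assert kwargs["json"]["response_ms"] == 12

asyncio.run(main())
```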
--------------------------------------------------------------------------------
/tests/llm/stream_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, MagicMock
3 |
4 | import nest_asyncio
5 | from typing import Generator, AsyncGenerator, List
6 | from litellm import ModelResponse
7 |
8 | from promptmodel.llms.llm import LLM
9 | from promptmodel.llms.llm_proxy import LLMProxy
10 | from promptmodel.types.response import LLMStreamResponse
11 | from promptmodel.utils.async_utils import run_async_in_sync
12 |
13 |
14 | def test_stream(mocker):
15 | llm = LLM()
16 | test_messages = [
17 | {"role": "system", "content": "You are a helpful assistant."},
18 | {"role": "user", "content": "Introduce yourself in 50 words."},
19 | ]
20 |
21 | stream_res: Generator[LLMStreamResponse, None, None] = llm.stream(
22 | messages=test_messages,
23 | model="gpt-3.5-turbo",
24 | )
25 | error_count = 0
26 | api_responses: List[ModelResponse] = []
27 | for res in stream_res:
28 | if res.error:
29 | error_count += 1
30 | print("ERROR")
31 | print(res.error)
32 | print(res.error_log)
33 | if (
34 | res.api_response
35 | and getattr(res.api_response.choices[0], "delta", None) is None
36 | ):
37 | api_responses.append(res.api_response)
38 | print(api_responses)
39 | assert error_count == 0, "error_count is not 0"
40 | assert len(api_responses) == 1, "api_count is not 1"
41 |
42 | assert api_responses[0].choices[0].message.content is not None, "content is None"
43 | assert api_responses[0]._response_ms is not None, "response_ms is None"
44 |
45 | # test logging
46 | llm_proxy = LLMProxy("test")
47 |
48 | mock_execute = AsyncMock()
49 | mock_response = MagicMock()
50 | mock_response.status_code = 200
51 | mock_execute.return_value = mock_response
52 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute)
53 |
54 | nest_asyncio.apply()
55 | run_async_in_sync(
56 | llm_proxy._async_log_to_cloud(
57 | log_uuid="test",
58 | version_uuid="test",
59 | inputs={},
60 | api_response=api_responses[0],
61 | parsed_outputs={},
62 | metadata={},
63 | )
64 | )
65 |
66 | mock_execute.assert_called_once()
67 | _, kwargs = mock_execute.call_args
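    # model_dump() omits the private `_response_ms` attribute, so add it back before comparing with the logged payload.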
68 | api_response_dict = api_responses[0].model_dump()
69 | api_response_dict.update({"response_ms": api_responses[0]._response_ms})
70 | assert (
71 | kwargs["json"]["api_response"] == api_response_dict
72 | ), "api_response is not equal"
73 |
74 |
75 | @pytest.mark.asyncio
76 | async def test_astream(mocker):
77 | llm = LLM()
78 | test_messages = [
79 | {"role": "system", "content": "You are a helpful assistant."},
80 | {"role": "user", "content": "Introduce yourself in 50 words."},
81 | ]
82 |
83 | stream_res: AsyncGenerator[LLMStreamResponse, None] = llm.astream(
84 | messages=test_messages,
85 | model="gpt-3.5-turbo",
86 | )
87 | error_count = 0
88 | api_responses: List[ModelResponse] = []
89 | async for res in stream_res:
90 | if res.error:
91 | error_count += 1
92 | print("ERROR")
93 | print(res.error)
94 | print(res.error_log)
95 | if (
96 | res.api_response
97 | and getattr(res.api_response.choices[0], "delta", None) is None
98 | ):
99 | api_responses.append(res.api_response)
100 |
101 | assert error_count == 0, "error_count is not 0"
102 | assert len(api_responses) == 1, "api_count is not 1"
103 |
104 | assert api_responses[0].choices[0].message.content is not None, "content is None"
105 | assert api_responses[0]._response_ms is not None, "response_ms is None"
106 |
107 | # test logging
108 | llm_proxy = LLMProxy("test")
109 |
110 | mock_execute = AsyncMock()
111 | mock_response = MagicMock()
112 | mock_response.status_code = 200
113 | mock_execute.return_value = mock_response
114 | mocker.patch("promptmodel.llms.llm_proxy.AsyncAPIClient.execute", new=mock_execute)
115 |
116 | await llm_proxy._async_log_to_cloud(
117 | log_uuid="test",
118 | version_uuid="test",
119 | inputs={},
120 | api_response=api_responses[0],
121 | parsed_outputs={},
122 | metadata={},
123 | )
124 |
125 | mock_execute.assert_called_once()
126 | _, kwargs = mock_execute.call_args
127 |
128 | api_response_dict = api_responses[0].model_dump()
129 | api_response_dict.update({"response_ms": api_responses[0]._response_ms})
130 |
131 | assert (
132 | kwargs["json"]["api_response"] == api_response_dict
133 | ), "api_response is not equal"
134 |
--------------------------------------------------------------------------------
/tests/llm/tool_calls_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from typing import Generator, AsyncGenerator, Dict, List, Any, Optional
5 | from litellm import ModelResponse
6 |
7 | from ..constants import function_shemas
8 | from promptmodel.llms.llm import LLM
9 | from promptmodel.llms.llm_proxy import LLMProxy
10 | from promptmodel.types.response import *
11 | from promptmodel.types.enums import ParsingType
12 |
13 | html_output_format = """\
14 | You must follow the provided output format. Keep the string between <> as it is.
15 | Output format:
16 | <response type=str>
17 | (value here)
18 | </response>
19 | """
20 |
21 | double_bracket_output_format = """\
22 | You must follow the provided output format. Keep the string between [[ ]] as it is.
23 | Output format:
24 | [[response type=str]]
25 | (value here)
26 | [[/response]]
27 | """
28 |
29 | colon_output_format = """\
30 | You must follow the provided output format. Keep the string before : as it is
31 | Output format:
32 | response type=str: (value here)
33 |
34 | """
35 |
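# Wrap each legacy function schema in the newer `tools` parameter format.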
36 | tool_schemas = [{"type": "function", "function": schema} for schema in function_shemas]
37 |
38 |
39 | def test_run_with_tools(mocker):
40 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
41 | llm = LLM()
42 | res: LLMResponse = llm.run(
43 | messages=messages,
44 | tools=tool_schemas,
45 | model="gpt-4-1106-preview",
46 | )
47 |
48 | assert res.error is None, "error is not None"
49 | assert res.api_response is not None, "api_response is None"
50 | assert (
51 | res.api_response.choices[0].finish_reason == "tool_calls"
52 | ), "finish_reason is not tool_calls"
53 | print(res.api_response.model_dump())
54 | print(res.__dict__)
55 | assert res.tool_calls is not None, "tool_calls is None"
56 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
57 |
58 | messages = [{"role": "user", "content": "Hello, How are you?"}]
59 |
60 | res: LLMResponse = llm.run(
61 | messages=messages,
62 | tools=tool_schemas,
63 | model="gpt-4-1106-preview",
64 | )
65 |
66 | assert res.error is None, "error is not None"
67 | assert res.api_response is not None, "api_response is None"
68 | assert (
69 | res.api_response.choices[0].finish_reason == "stop"
70 | ), "finish_reason is not stop"
71 |
72 | assert res.tool_calls is None, "tool_calls is not None"
73 | assert res.raw_output is not None, "raw_output is None"
74 |
75 |
76 | @pytest.mark.asyncio
77 | async def test_arun_with_tools(mocker):
78 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
79 | llm = LLM()
80 | res: LLMResponse = await llm.arun(
81 | messages=messages,
82 | tools=tool_schemas,
83 | model="gpt-4-1106-preview",
84 | )
85 |
86 | assert res.error is None, "error is not None"
87 | assert res.api_response is not None, "api_response is None"
88 | assert (
89 | res.api_response.choices[0].finish_reason == "tool_calls"
90 | ), "finish_reason is not tool_calls"
91 |
92 | assert res.tool_calls is not None, "tool_calls is None"
93 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
94 |
95 | messages = [{"role": "user", "content": "Hello, How are you?"}]
96 |
97 | res: LLMResponse = await llm.arun(
98 | messages=messages,
99 | tools=tool_schemas,
100 | model="gpt-4-1106-preview",
101 | )
102 |
103 | assert res.error is None, "error is not None"
104 | assert res.api_response is not None, "api_response is None"
105 | assert (
106 | res.api_response.choices[0].finish_reason == "stop"
107 | ), "finish_reason is not stop"
108 |
109 | assert res.tool_calls is None, "tool_calls is not None"
110 | assert res.raw_output is not None, "raw_output is None"
111 |
112 |
113 | def test_run_and_parse_with_tools(mocker):
114 | # With parsing_type = None
115 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
116 | llm = LLM()
117 | res: LLMResponse = llm.run_and_parse(
118 | messages=messages,
119 | tools=tool_schemas,
120 | model="gpt-4-1106-preview",
121 | parsing_type=None,
122 | )
123 |
124 | assert res.error is False, "error is not False"
125 | assert res.api_response is not None, "api_response is None"
126 | assert (
127 | res.api_response.choices[0].finish_reason == "tool_calls"
128 | ), "finish_reason is not tool_calls"
129 |
130 | assert res.tool_calls is not None, "tool_calls is None"
131 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
132 |
133 | messages = [{"role": "user", "content": "Hello, How are you?"}]
134 |
135 | res: LLMResponse = llm.run_and_parse(
136 | messages=messages,
137 | tools=tool_schemas,
138 | model="gpt-4-1106-preview",
139 | parsing_type=None,
140 | )
141 |
142 | assert res.error is False, "error is not False"
143 | assert res.api_response is not None, "api_response is None"
144 | assert (
145 | res.api_response.choices[0].finish_reason == "stop"
146 | ), "finish_reason is not stop"
147 |
148 | assert res.tool_calls is None, "tool_calls is not None"
149 | assert res.raw_output is not None, "raw_output is None"
150 |
151 | # with parsing_type = "HTML"
152 |
153 | messages = [
154 | {
155 | "role": "user",
156 | "content": "What is the weather like in Boston? \n" + html_output_format,
157 | }
158 | ]
159 | llm = LLM()
160 | res: LLMResponse = llm.run_and_parse(
161 | messages=messages,
162 | tools=tool_schemas,
163 | model="gpt-4-1106-preview",
164 | parsing_type=ParsingType.HTML.value,
165 | output_keys=["response"],
166 | )
167 |
168 |     # 1. Follows the output format and makes a function call -> (Pass)
169 |     # 2. Follows the output format and stops -> OK
170 |     # 3. Ignores the output format and makes a function call -> OK (parsing is skipped when a function call appears)
171 |
172 |     # In this case, error is False because parsing is skipped when tool_calls are returned
173 | assert res.error is False, "error is not False"
174 | assert res.api_response is not None, "api_response is None"
175 | assert (
176 | res.api_response.choices[0].finish_reason == "tool_calls"
177 | ), "finish_reason is not tool_calls"
178 |
179 | assert res.tool_calls is not None, "tool_calls is None"
180 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
181 |
182 | assert res.parsed_outputs is None, "parsed_outputs is not empty"
183 |
184 | messages = [
185 | {
186 | "role": "user",
187 | "content": "Hello, How are you?\n" + html_output_format,
188 | }
189 | ]
190 |
191 | res: LLMResponse = llm.run_and_parse(
192 | messages=messages,
193 | tools=tool_schemas,
194 | model="gpt-4-1106-preview",
195 | parsing_type=ParsingType.HTML.value,
196 | output_keys=["response"],
197 | )
198 |
199 | print(res.__dict__)
200 |
201 | if not "str" in res.raw_output:
202 | # if "str" in res.raw_output, it means that LLM make mistakes
203 | assert res.error is False, "error is not False"
204 | assert res.parsed_outputs is not None, "parsed_outputs is None"
205 |
206 | assert res.api_response is not None, "api_response is None"
207 | assert (
208 | res.api_response.choices[0].finish_reason == "stop"
209 | ), "finish_reason is not stop"
210 |
211 | assert res.tool_calls is None, "tool_calls is not None"
212 | assert res.raw_output is not None, "raw_output is None"
213 |
214 |
215 | @pytest.mark.asyncio
216 | async def test_arun_and_parse_with_tools(mocker):
217 | # With parsing_type = None
218 | messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
219 | llm = LLM()
220 | res: LLMResponse = await llm.arun_and_parse(
221 | messages=messages,
222 | tools=tool_schemas,
223 | model="gpt-4-1106-preview",
224 | parsing_type=None,
225 | )
226 |
227 | assert res.error is False, "error is not False"
228 | assert res.api_response is not None, "api_response is None"
229 | assert (
230 | res.api_response.choices[0].finish_reason == "tool_calls"
231 | ), "finish_reason is not tool_calls"
232 | print(res)
233 | assert res.tool_calls is not None, "tool_calls is None"
234 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
235 |
236 | messages = [{"role": "user", "content": "Hello, How are you?"}]
237 |
238 | res: LLMResponse = await llm.arun_and_parse(
239 | messages=messages,
240 | tools=tool_schemas,
241 | model="gpt-4-1106-preview",
242 | parsing_type=None,
243 | )
244 |
245 | assert res.error is False, "error is not False"
246 | assert res.api_response is not None, "api_response is None"
247 | assert (
248 | res.api_response.choices[0].finish_reason == "stop"
249 | ), "finish_reason is not stop"
250 |
251 | assert res.tool_calls is None, "tool_calls is not None"
252 | assert res.raw_output is not None, "raw_output is None"
253 |
254 | # with parsing_type = "HTML"
255 |
256 | messages = [
257 | {
258 | "role": "user",
259 | "content": "What is the weather like in Boston? \n" + html_output_format,
260 | }
261 | ]
262 | llm = LLM()
263 | res: LLMResponse = await llm.arun_and_parse(
264 | messages=messages,
265 | tools=tool_schemas,
266 | model="gpt-4-1106-preview",
267 | parsing_type=ParsingType.HTML.value,
268 | output_keys=["response"],
269 | )
270 |
271 |     # In this case, error is False because parsing is not performed when tool_calls are returned
272 | assert res.error is False, "error is not False"
273 | assert res.api_response is not None, "api_response is None"
274 | assert (
275 | res.api_response.choices[0].finish_reason == "tool_calls"
276 | ), "finish_reason is not tool_calls"
277 |
278 | assert res.tool_calls is not None, "tool_calls is None"
279 | assert isinstance(res.tool_calls[0], ChatCompletionMessageToolCall)
280 |
281 | assert res.parsed_outputs is None, "parsed_outputs is not empty"
282 |
283 | messages = [
284 | {
285 | "role": "user",
286 | "content": "Hello, How are you?\n" + html_output_format,
287 | }
288 | ]
289 |
290 | res: LLMResponse = await llm.arun_and_parse(
291 | messages=messages,
292 | tools=tool_schemas,
293 | model="gpt-4-1106-preview",
294 | parsing_type=ParsingType.HTML.value,
295 | output_keys=["response"],
296 | )
297 | if not "str" in res.raw_output:
298 | assert res.error is False, "error is not False"
299 | assert res.parsed_outputs is not None, "parsed_outputs is None"
300 |
301 | assert res.api_response is not None, "api_response is None"
302 | assert (
303 | res.api_response.choices[0].finish_reason == "stop"
304 | ), "finish_reason is not stop"
305 |
306 | assert res.tool_calls is None, "tool_calls is not None"
307 | assert res.raw_output is not None, "raw_output is None"
308 |
--------------------------------------------------------------------------------
/tests/utils/async_util_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import asyncio
3 | import nest_asyncio
4 | from typing import Coroutine
5 | from unittest.mock import AsyncMock
6 |
7 | from promptmodel.utils.async_utils import run_async_in_sync
8 |
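# nest_asyncio allows run_async_in_sync to be driven from inside an already-running event loop (see test_async_sync_context).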
9 | nest_asyncio.apply()
10 |
11 |
12 | def test_sync_context(mocker):
13 | coro = AsyncMock(return_value="test")
14 | res = run_async_in_sync(coro())
15 | coro.assert_called_once()
16 |
17 | assert res == "test", "res is not test"
18 |
19 |
20 | def test_sync_async_context(mocker):
21 | coro = AsyncMock(return_value="test")
22 |
23 | async def async_context(coro: Coroutine):
24 | res = await coro
25 | return res
26 |
27 | res = asyncio.run(async_context(coro()))
28 | coro.assert_called_once()
29 |
30 | assert res == "test", "res is not test"
31 |
32 |
33 | @pytest.mark.asyncio
34 | async def test_async_context(mocker):
35 | coro = AsyncMock(return_value="test")
36 | res = await coro()
37 | coro.assert_called_once()
38 |
39 | assert res == "test", "res is not test"
40 |
41 |
42 | @pytest.mark.asyncio
43 | async def test_async_sync_context(mocker):
44 | coro = AsyncMock(return_value="test")
45 | print("ready")
46 |
47 | def sync_context(coro: Coroutine):
48 | return run_async_in_sync(coro)
49 |
50 | res = sync_context(coro())
51 | coro.assert_called_once()
52 |
53 | assert res == "test", "res is not test"
54 |
--------------------------------------------------------------------------------
/tests/websocket_client/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weavel-ai/promptmodel-python/a3b8d1095bc3abfd463b92722b9efd945fe75ea3/tests/websocket_client/__init__.py
--------------------------------------------------------------------------------
/tests/websocket_client/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | from promptmodel.websocket.websocket_client import DevWebsocketClient
5 | from promptmodel.dev_app import DevApp
6 |
7 |
8 | async def echo_coroutine(*args, **kwargs):
9 | # print(args, kwargs)
10 | return args, kwargs
11 |
12 |
13 | @pytest.fixture
14 | def websocket_client():
15 | websocket_client = DevWebsocketClient(_devapp=DevApp())
16 | return websocket_client
17 |
18 |
19 | @pytest.fixture
20 | def mock_websocket():
21 |     # Create a mock WebSocketClientProtocol object
22 | mock_websocket = AsyncMock()
23 |
24 | async def aenter(self):
25 | return self
26 |
27 | async def aexit(self, exc_type, exc_value, traceback):
28 | pass
29 |
30 | mock_websocket.__aenter__ = aenter
31 | mock_websocket.__aexit__ = aexit
32 | mock_websocket.recv = AsyncMock(return_value='{"key" : "value"}')
33 | mock_websocket.send = AsyncMock()
34 |
35 | return mock_websocket
36 |
37 |
38 | @pytest.fixture
39 | def mock_json_dumps():
40 | mock_json_dumps = MagicMock()
41 |     # return the input unchanged so tests can inspect the raw payload sent through the mocked websocket
42 | mock_json_dumps.side_effect = lambda data, *args, **kwargs: data
43 | return mock_json_dumps
44 |
--------------------------------------------------------------------------------
/tests/websocket_client/local_task_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 | import asyncio
4 |
5 | from websockets.exceptions import ConnectionClosedOK
6 | from promptmodel.websocket.websocket_client import DevWebsocketClient
7 | from promptmodel.types.enums import LocalTask
8 |
9 |
10 | @pytest.mark.asyncio
11 | async def test_connect_to_gateway(
12 | mocker, websocket_client: DevWebsocketClient, mock_websocket: AsyncMock
13 | ):
14 | project_uuid = "test_uuid"
15 | connection_name = "test_project"
16 | cli_access_header = {"Authorization": "Bearer testtoken"}
17 |
18 | with patch.object(
19 | websocket_client, "_DevWebsocketClient__handle_message", new_callable=AsyncMock
20 | ) as mock_function:
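        # Raising ConnectionClosedOK from the handler ends the gateway loop after the first received message.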
21 | mock_function.side_effect = ConnectionClosedOK(None, None)
22 | with patch(
23 | "promptmodel.websocket.websocket_client.connect",
24 | new_callable=MagicMock,
25 | return_value=mock_websocket,
26 | ) as mock_connect:
27 |             # The test ends automatically after 5 seconds
28 | await websocket_client.connect_to_gateway(
29 | project_uuid, connection_name, cli_access_header, retries=1
30 |             )
31 | mock_connect.assert_called_once()
32 | mock_websocket.recv.assert_called_once()
33 | websocket_client._DevWebsocketClient__handle_message.assert_called_once()
34 |
35 |
36 | @pytest.mark.asyncio
37 | async def test_local_tasks(
38 | mocker, websocket_client: DevWebsocketClient, mock_websocket: AsyncMock
39 | ):
40 | websocket_client._devapp.function_models = {}
41 | websocket_client._devapp.samples = {}
42 | websocket_client._devapp.functions = {"test_function": "test_function"}
43 |
44 | await websocket_client._DevWebsocketClient__handle_message(
45 | message={"type": LocalTask.LIST_CODE_FUNCTIONS}, ws=mock_websocket
46 | )
47 |
--------------------------------------------------------------------------------
/tests/websocket_client/run_chatmodel_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | import asyncio
5 | from typing import Optional
6 | from uuid import uuid4
7 | from dataclasses import dataclass
8 | from websockets.exceptions import ConnectionClosedOK
9 | from promptmodel.websocket.websocket_client import DevWebsocketClient
10 | from promptmodel.types.enums import LocalTask, LocalTaskErrorType
11 | from promptmodel.types.response import FunctionSchema
12 |
13 |
14 | @dataclass
15 | class ChatModelInterface:
16 | name: str
17 | default_model: str = "gpt-3.5-turbo"
18 |
19 |
20 | def get_current_weather(location: str, unit: Optional[str] = "celsius"):
21 | return "13 degrees celsius"
22 |
23 |
24 | get_current_weather_desc = {
25 | "name": "get_current_weather",
26 | "description": "Get the current weather in a given location",
27 | "parameters": {
28 | "type": "object",
29 | "properties": {
30 | "location": {
31 | "type": "string",
32 | "description": "The city and state, e.g. San Francisco, CA",
33 | },
34 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
35 | },
36 | "required": ["location"],
37 | },
38 | }
39 |
40 |
41 | @pytest.mark.asyncio
42 | async def test_run_model_function_call(
43 | mocker,
44 | websocket_client: DevWebsocketClient,
45 | mock_websocket: AsyncMock,
46 | mock_json_dumps: MagicMock,
47 | ):
48 | websocket_client._devapp.chat_models = [ChatModelInterface("test_module")]
49 | websocket_client._devapp.samples = {
50 | "sample_1": {"user_message": "What is the weather like in Boston?"}
51 | }
52 | websocket_client._devapp.functions = {
53 | "get_current_weather": {
54 | "schema": FunctionSchema(**get_current_weather_desc),
55 | "function": get_current_weather,
56 | }
57 | }
58 |
59 | function_schemas_in_db = [FunctionSchema(**get_current_weather_desc).model_dump()]
60 |
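    # mock_response is what the handler falls back to when the function is not registered locally (see the "function not in code" case below).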
61 | function_schemas_in_db[0]["mock_response"] = "13"
62 |
63 | mocker.patch("promptmodel.websocket.websocket_client.json.dumps", mock_json_dumps)
64 |
65 | # success case with function_call
66 | await websocket_client._DevWebsocketClient__handle_message(
67 | message={
68 | "type": LocalTask.RUN_CHAT_MODEL,
69 | "chat_model_name": "test_module",
70 | "system_prompt": {
71 | "role": "system",
72 | "content": "You are a helpful assistant.",
73 | },
74 | "old_messages": [],
75 | "new_messages": [
76 | {
77 | "role": "user",
78 | "content": "What is the weather in Boston?",
79 | }
80 | ],
81 | "model": "gpt-3.5-turbo",
82 | "functions": ["get_current_weather"],
83 | "function_schemas": function_schemas_in_db,
84 | },
85 | ws=mock_websocket,
86 | )
87 |
88 | call_args_list = mock_websocket.send.call_args_list
89 | data = [arg.args[0] for arg in call_args_list]
90 |
91 | assert len([d for d in data if d["status"] == "failed"]) == 0
92 | assert len([d for d in data if d["status"] == "completed"]) == 1
93 |
94 | assert len([d for d in data if "function_response" in d]) == 1
95 | assert [d for d in data if "function_response" in d][0]["function_response"][
96 | "name"
97 | ] == "get_current_weather"
98 |
99 | assert len([d for d in data if "function_call" in d]) > 1
100 | assert len([d for d in data if "raw_output" in d]) > 0
101 |
102 | mock_websocket.send.reset_mock()
103 | function_schemas_in_db[0]["mock_response"] = "13"
104 |
105 | print(
106 | "======================================================================================="
107 | )
108 |
109 | # success case with no function call
110 | await websocket_client._DevWebsocketClient__handle_message(
111 | message={
112 | "type": LocalTask.RUN_CHAT_MODEL,
113 | "chat_model_name": "test_module",
114 | "system_prompt": {
115 | "role": "system",
116 | "content": "You are a helpful assistant.",
117 | },
118 | "old_messages": [],
119 | "new_messages": [
120 | {
121 | "role": "user",
122 | "content": "Hello?",
123 | }
124 | ],
125 | "model": "gpt-3.5-turbo",
126 | },
127 | ws=mock_websocket,
128 | )
129 |
130 | call_args_list = mock_websocket.send.call_args_list
131 | data = [arg.args[0] for arg in call_args_list]
132 |
133 | assert len([d for d in data if d["status"] == "failed"]) == 0
134 | assert len([d for d in data if d["status"] == "completed"]) == 1
135 |
136 | assert len([d for d in data if "function_response" in d]) == 0
137 |
138 | assert len([d for d in data if "function_call" in d]) == 0
139 | assert len([d for d in data if "raw_output" in d]) > 0
140 |
141 | mock_websocket.send.reset_mock()
142 | function_schemas_in_db[0]["mock_response"] = "13"
143 |
144 | print(
145 | "======================================================================================="
146 | )
147 |
148 | # FUNCTION_CALL_FAILED_ERROR case
149 | def error_raise_function(*args, **kwargs):
150 | raise Exception("error")
151 |
152 | websocket_client._devapp.functions = {
153 | "get_current_weather": {
154 | "schema": FunctionSchema(**get_current_weather_desc),
155 | "function": error_raise_function,
156 | }
157 | }
158 |
159 | await websocket_client._DevWebsocketClient__handle_message(
160 | message={
161 | "type": LocalTask.RUN_CHAT_MODEL,
162 | "chat_model_name": "test_module",
163 | "system_prompt": {
164 | "role": "system",
165 | "content": "You are a helpful assistant.",
166 | },
167 | "old_messages": [],
168 | "new_messages": [
169 | {
170 | "role": "user",
171 | "content": "What is the weather in Boston?",
172 | }
173 | ],
174 | "model": "gpt-3.5-turbo",
175 | "functions": ["get_current_weather"],
176 | "function_schemas": function_schemas_in_db,
177 | },
178 | ws=mock_websocket,
179 | )
180 | call_args_list = mock_websocket.send.call_args_list
181 | # print(call_args_list)
182 | data = [arg.args[0] for arg in call_args_list]
183 |
184 | assert len([d for d in data if d["status"] == "failed"]) == 1
185 | assert [d for d in data if d["status"] == "failed"][0][
186 | "error_type"
187 | ] == LocalTaskErrorType.FUNCTION_CALL_FAILED_ERROR.value
188 | assert len([d for d in data if d["status"] == "completed"]) == 0
189 | assert len([d for d in data if "function_response" in d]) == 0
190 | assert len([d for d in data if "function_call" in d]) > 1
191 | assert len([d for d in data if "raw_output" in d]) == 0
192 | mock_websocket.send.reset_mock()
193 | function_schemas_in_db[0]["mock_response"] = "13"
194 |
195 | # function not in code case, should use mock_response
196 | websocket_client._devapp.functions = {}
197 | await websocket_client._DevWebsocketClient__handle_message(
198 | message={
199 | "type": LocalTask.RUN_CHAT_MODEL,
200 | "chat_model_name": "test_module",
201 | "system_prompt": {
202 | "role": "system",
203 | "content": "You are a helpful assistant.",
204 | },
205 | "old_messages": [],
206 | "new_messages": [
207 | {
208 | "role": "user",
209 | "content": "What is the weather in Boston?",
210 | }
211 | ],
212 | "model": "gpt-3.5-turbo",
213 | "functions": ["get_weather"],
214 | "function_schemas": function_schemas_in_db,
215 | },
216 | ws=mock_websocket,
217 | )
218 | call_args_list = mock_websocket.send.call_args_list
219 | print(call_args_list)
220 | data = [arg.args[0] for arg in call_args_list]
221 |
222 | assert len([d for d in data if d["status"] == "failed"]) == 0
223 | assert len([d for d in data if d["status"] == "completed"]) == 1
224 |
225 | assert len([d for d in data if "function_response" in d]) == 1
226 | assert (
227 | "FAKE RESPONSE"
228 | in [d for d in data if "function_response" in d][0]["function_response"][
229 | "response"
230 | ]
231 | )
232 | assert len([d for d in data if "function_call" in d]) > 1
233 | assert len([d for d in data if "raw_output" in d]) > 0
234 |
--------------------------------------------------------------------------------
/tests/websocket_client/run_promptmodel_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, patch, MagicMock
3 |
4 | import asyncio
5 | from typing import Optional
6 | from uuid import uuid4
7 | from dataclasses import dataclass
8 | from websockets.exceptions import ConnectionClosedOK
9 | from promptmodel.websocket.websocket_client import DevWebsocketClient
10 | from promptmodel.types.enums import LocalTask, ParsingType, LocalTaskErrorType
11 | from promptmodel.types.response import FunctionSchema
12 |
13 |
14 | @dataclass
15 | class FunctionModelInterface:
16 | name: str
17 | default_model: str = "gpt-3.5-turbo"
18 |
19 |
20 | def get_current_weather(location: str, unit: Optional[str] = "celsius"):
21 | return "13 degrees celsius"
22 |
23 |
24 | get_current_weather_desc = {
25 | "name": "get_current_weather",
26 | "description": "Get the current weather in a given location",
27 | "parameters": {
28 | "type": "object",
29 | "properties": {
30 | "location": {
31 | "type": "string",
32 | "description": "The city and state, e.g. San Francisco, CA",
33 | },
34 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
35 | },
36 | "required": ["location"],
37 | },
38 | }
39 |
40 |
41 | @pytest.mark.asyncio
42 | async def test_run_model_function_call(
43 | mocker,
44 | websocket_client: DevWebsocketClient,
45 | mock_websocket: AsyncMock,
46 | mock_json_dumps: MagicMock,
47 | ):
48 | websocket_client._devapp.function_models = [FunctionModelInterface("test_module")]
49 |
50 | websocket_client._devapp.functions = {
51 | "get_current_weather": {
52 | "schema": FunctionSchema(**get_current_weather_desc),
53 | "function": get_current_weather,
54 | }
55 | }
56 |
57 | function_schemas_in_db = [FunctionSchema(**get_current_weather_desc).model_dump()]
58 |
59 | function_schemas_in_db[0]["mock_response"] = "13"
60 |
61 | mocker.patch("promptmodel.websocket.websocket_client.json.dumps", mock_json_dumps)
62 |
63 | # success case
64 | await websocket_client._DevWebsocketClient__handle_message(
65 | message={
66 | "type": LocalTask.RUN_PROMPT_MODEL,
67 | "messages_for_run": [
68 | {
69 | "role": "system",
70 | "content": "You are a helpful assistant.",
71 | "step": 1,
72 | },
73 | {
74 | "role": "user",
75 | "content": "What is the weather like in Boston?",
76 | "step": 2,
77 | },
78 | ],
79 | "model": "gpt-3.5-turbo",
80 | "parsing_type": None,
81 | "output_keys": None,
82 | "functions": ["get_current_weather"],
83 | "function_schemas": function_schemas_in_db,
84 | },
85 | ws=mock_websocket,
86 | )
87 | call_args_list = mock_websocket.send.call_args_list
88 | # print(call_args_list)
89 | data = [arg.args[0] for arg in call_args_list]
90 |
91 | assert len([d for d in data if d["status"] == "failed"]) == 0
92 | assert len([d for d in data if d["status"] == "completed"]) == 1
93 | assert len([d for d in data if "function_response" in d]) == 1
94 | assert [d for d in data if "function_response" in d][0]["function_response"][
95 | "name"
96 | ] == "get_current_weather"
97 | assert len([d for d in data if "function_call" in d]) == 1
98 | assert len([d for d in data if "raw_output" in d]) == 0
99 | mock_websocket.send.reset_mock()
100 | function_schemas_in_db[0]["mock_response"] = "13"
101 |
102 | print(
103 | "======================================================================================="
104 | )
105 |
106 | # success case with no function call
107 | await websocket_client._DevWebsocketClient__handle_message(
108 | message={
109 | "type": LocalTask.RUN_PROMPT_MODEL,
110 | "messages_for_run": [
111 | {
112 | "role": "system",
113 | "content": "You are a helpful assistant.",
114 | "step": 1,
115 | },
116 | {
117 | "role": "user",
118 | "content": "What is the weather like in Boston?",
119 | "step": 2,
120 | },
121 | ],
122 | "model": "gpt-3.5-turbo",
123 | "parsing_type": None,
124 | "output_keys": None,
125 | },
126 | ws=mock_websocket,
127 | )
128 | call_args_list = mock_websocket.send.call_args_list
129 | # print(call_args_list)
130 | data = [arg.args[0] for arg in call_args_list]
131 |
132 | assert len([d for d in data if "function_call" in d]) == 0
133 | assert len([d for d in data if "raw_output" in d]) > 0
134 | assert len([d for d in data if d["status"] == "failed"]) == 0
135 | assert len([d for d in data if d["status"] == "completed"]) == 1
136 | mock_websocket.send.reset_mock()
137 |
138 | print(
139 | "======================================================================================="
140 | )
141 |
142 | # system_prompt_with_format = """
143 | # You are a helpful assistant.
144 |
145 | # This is your output format. Keep the string between "[" and "]" as same as given below. You should response content first before call functions. You MUST FOLLOW this output format.
146 | # Output Format:
147 | # [response type=str]
148 | # (value here)
149 | # [/response]
150 | # """
151 | # await websocket_client._DevWebsocketClient__handle_message(
152 | # message={
153 | # "type": LocalTask.RUN_PROMPT_MODEL,
154 | # "messages_for_run": [
155 | # {"role": "system", "content": system_prompt_with_format, "step": 1},
156 | # {
157 | # "role": "user",
158 | # "content": "What is the weather like in Boston?",
159 | # "step": 2,
160 | # },
161 | # ],
162 | # "model": "gpt-4-1106-preview",
163 | # "uuid": None,
164 | # "from_uuid": None,
165 | # "parsing_type": ParsingType.SQUARE_BRACKET.value,
166 | # "output_keys": ["response"],
167 | # "functions": ["get_current_weather"],
168 | # "function_schemas": function_schemas_in_db,
169 | # },
170 | # ws=mock_websocket,
171 | # )
172 | # call_args_list = mock_websocket.send.call_args_list
173 | # print(call_args_list)
174 | # data = [arg.args[0] for arg in call_args_list]
175 |
176 | # assert len([d for d in data if d["status"] == "failed"]) == 0
177 | # assert len([d for d in data if d["status"] == "completed"]) == 1
178 | # assert len([d for d in data if "function_response" in d]) == 1
179 | # assert [d for d in data if "function_response" in d][0]["function_response"][
180 | # "name"
181 | # ] == "get_current_weather"
182 | # assert len([d for d in data if "function_call" in d]) == 1
183 | # assert len([d for d in data if "raw_output" in d]) > 0
184 | # assert len([d for d in data if "parsed_outputs" in d]) > 0
185 | # assert (
186 | # len(
187 | # list(
188 | # set([d["parsed_outputs"].keys() for d in data if "parsed_outputs" in d])
189 | # )
190 | # )
191 | # == 1
192 | # )
193 | # mock_websocket.send.reset_mock()
194 | # function_schemas_in_db[0]["mock_response"] = "13"
195 |
196 | print(
197 | "======================================================================================="
198 | )
199 |
200 | # FUNCTION_CALL_FAILED_ERROR case
201 | def error_raise_function(*args, **kwargs):
202 | raise Exception("error")
203 |
204 | websocket_client._devapp.functions = {
205 | "get_current_weather": {
206 | "schema": FunctionSchema(**get_current_weather_desc),
207 | "function": error_raise_function,
208 | }
209 | }
210 |
211 |     # run with the error-raising function to trigger FUNCTION_CALL_FAILED_ERROR
212 | await websocket_client._DevWebsocketClient__handle_message(
213 | message={
214 | "type": LocalTask.RUN_PROMPT_MODEL,
215 | "messages_for_run": [
216 | {
217 | "role": "system",
218 | "content": "You are a helpful assistant.",
219 | "step": 1,
220 | },
221 | {
222 | "role": "user",
223 | "content": "What is the weather like in Boston?",
224 | "step": 2,
225 | },
226 | ],
227 | "model": "gpt-3.5-turbo",
228 | "parsing_type": None,
229 | "output_keys": None,
230 | "functions": ["get_current_weather"],
231 | "function_schemas": function_schemas_in_db,
232 | },
233 | ws=mock_websocket,
234 | )
235 | call_args_list = mock_websocket.send.call_args_list
236 | # print(call_args_list)
237 | data = [arg.args[0] for arg in call_args_list]
238 |
239 | assert len([d for d in data if d["status"] == "failed"]) == 1
240 | assert [d for d in data if d["status"] == "failed"][0][
241 | "error_type"
242 | ] == LocalTaskErrorType.FUNCTION_CALL_FAILED_ERROR.value
243 | assert len([d for d in data if d["status"] == "completed"]) == 0
244 | assert len([d for d in data if "function_response" in d]) == 0
245 | assert len([d for d in data if "function_call" in d]) == 1
246 | assert len([d for d in data if "raw_output" in d]) == 0
247 | mock_websocket.send.reset_mock()
248 | function_schemas_in_db[0]["mock_response"] = "13"
249 |
250 | # PARSING_FAILED_ERROR case
251 |
252 | system_prompt_with_format = """
253 | You are a helpful assistant.
254 |
255 | This is your output format. Keep the string between "[" and "]" as same as given below. You MUST FOLLOW this output format.
256 | Output Format:
257 | [response type=str]
258 | (value here)
259 | [/response]
260 | """
261 | await websocket_client._DevWebsocketClient__handle_message(
262 | message={
263 | "type": LocalTask.RUN_PROMPT_MODEL,
264 | "messages_for_run": [
265 | {"role": "system", "content": system_prompt_with_format, "step": 1},
266 | {
267 | "role": "user",
268 | "content": "Hello!",
269 | "step": 2,
270 | },
271 | ],
272 | "model": "gpt-4-1106-preview",
273 | "parsing_type": ParsingType.SQUARE_BRACKET.value,
274 | "output_keys": ["respond"],
275 | "functions": ["get_current_weather"],
276 | "function_schemas": function_schemas_in_db,
277 | },
278 | ws=mock_websocket,
279 | )
280 | call_args_list = mock_websocket.send.call_args_list
281 | # print(call_args_list)
282 | data = [arg.args[0] for arg in call_args_list]
283 |
284 | assert len([d for d in data if d["status"] == "failed"]) == 1
285 | assert [d for d in data if d["status"] == "failed"][0][
286 | "error_type"
287 | ] == LocalTaskErrorType.PARSING_FAILED_ERROR.value
288 | assert len([d for d in data if d["status"] == "completed"]) == 0
289 | assert len([d for d in data if "function_response" in d]) == 0
290 | assert len([d for d in data if "function_call" in d]) == 0
291 | assert len([d for d in data if "raw_output" in d]) > 0
292 | assert len([d for d in data if "parsed_outputs" in d]) > 0
293 |
294 | mock_websocket.send.reset_mock()
295 | function_schemas_in_db[0]["mock_response"] = "13"
296 |
297 | # function not in code case, should use mock_response
298 | websocket_client._devapp.functions = {}
299 | await websocket_client._DevWebsocketClient__handle_message(
300 | message={
301 | "type": LocalTask.RUN_PROMPT_MODEL,
302 | "messages_for_run": [
303 | {
304 | "role": "system",
305 | "content": "You are a helpful assistant.",
306 | "step": 1,
307 | },
308 | {
309 | "role": "user",
310 | "content": "What is the weather like in Boston?",
311 | "step": 2,
312 | },
313 | ],
314 | "model": "gpt-3.5-turbo",
315 | "parsing_type": None,
316 | "output_keys": None,
317 | "functions": ["get_weather"],
318 | "function_schemas": function_schemas_in_db,
319 | },
320 | ws=mock_websocket,
321 | )
322 | call_args_list = mock_websocket.send.call_args_list
323 | print(call_args_list)
324 | data = [arg.args[0] for arg in call_args_list]
325 |
326 | assert len([d for d in data if d["status"] == "failed"]) == 0
327 | assert len([d for d in data if d["status"] == "completed"]) == 1
328 |
329 | assert len([d for d in data if "function_response" in d]) == 1
330 | assert (
331 | "FAKE RESPONSE"
332 | in [d for d in data if "function_response" in d][0]["function_response"][
333 | "response"
334 | ]
335 | )
336 | assert len([d for d in data if "function_call" in d]) == 1
337 | assert len([d for d in data if "raw_output" in d]) == 0
338 |
--------------------------------------------------------------------------------