├── .gitattributes └── sql.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /sql.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "os.environ.setdefault(\"OPENAI_API_KEY\", \" \")  # keep a key already exported in the environment instead of overwriting it" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "File downloaded and saved as Chinook.db\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import requests\n", 29 | "\n", 30 | "url = \"https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db\"\n", 31 | "\n", 32 | "response = requests.get(url, timeout=60)  # fail instead of hanging forever on a stalled download\n", 33 | "\n", 34 | "if response.status_code == 200:\n", 35 | " # Open a local file in binary write mode\n", 36 | " with open(\"Chinook.db\", \"wb\") as file:\n", 37 | " # Write the content of the response (the file) to the local file\n", 38 | " file.write(response.content)\n", 39 | " print(\"File downloaded and saved as Chinook.db\")\n", 40 | "else:\n", 41 | " print(f\"Failed to download the file.
Status code: {response.status_code}\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "%%capture --no-stderr --no-display\n", 51 | "!pip install langgraph langchain_community langchain_openai" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "sqlite\n", 64 | "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" 65 | ] 66 | }, 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" 71 | ] 72 | }, 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | } 77 | ], 78 | "source": [ 79 | "from langchain_community.utilities import SQLDatabase\n", 80 | "\n", 81 | "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", 82 | "print(db.dialect)\n", 83 | "print(db.get_usable_table_names())\n", 84 | "db.run(\"SELECT * FROM Artist LIMIT 10;\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "from typing import Any\n", 94 | "\n", 95 | "from langchain_core.messages import ToolMessage\n", 96 | "from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks\n", 97 | "from langgraph.prebuilt import ToolNode\n", 98 | "\n", 99 | "\n", 100 | "def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:\n", 101 | " \"\"\"\n", 102 | " Create a ToolNode with a fallback to handle errors and surface them to the agent.\n", 103 | " \"\"\"\n", 104 | " return ToolNode(tools).with_fallbacks(\n", 105 | "
[RunnableLambda(handle_tool_error)], exception_key=\"error\"\n", 106 | " )\n", 107 | "\n", 108 | "\n", 109 | "def handle_tool_error(state) -> dict:\n", 110 | " error = state.get(\"error\")\n", 111 | " tool_calls = state[\"messages\"][-1].tool_calls\n", 112 | " return {\n", 113 | " \"messages\": [\n", 114 | " ToolMessage(\n", 115 | " content=f\"Error: {repr(error)}\\n please fix your mistakes.\",\n", 116 | " tool_call_id=tc[\"id\"],\n", 117 | " )\n", 118 | " for tc in tool_calls\n", 119 | " ]\n", 120 | " }" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 6, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n", 133 | "\n", 134 | "CREATE TABLE \"Artist\" (\n", 135 | "\t\"ArtistId\" INTEGER NOT NULL, \n", 136 | "\t\"Name\" NVARCHAR(120), \n", 137 | "\tPRIMARY KEY (\"ArtistId\")\n", 138 | ")\n", 139 | "\n", 140 | "/*\n", 141 | "3 rows from Artist table:\n", 142 | "ArtistId\tName\n", 143 | "1\tAC/DC\n", 144 | "2\tAccept\n", 145 | "3\tAerosmith\n", 146 | "*/\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n", 152 | "from langchain_openai import ChatOpenAI\n", 153 | "\n", 154 | "toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model=\"gpt-4o\"))\n", 155 | "tools = toolkit.get_tools()\n", 156 | "\n", 157 | "list_tables_tool = next(tool for tool in tools if tool.name == \"sql_db_list_tables\")\n", 158 | "get_schema_tool = next(tool for tool in tools if tool.name == \"sql_db_schema\")\n", 159 | "\n", 160 | "print(list_tables_tool.invoke(\"\"))\n", 161 | "\n", 162 | "print(get_schema_tool.invoke(\"Artist\"))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 9, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 |
"text": [ 174 | "[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "from langchain_core.tools import tool\n", 180 | "\n", 181 | "\n", 182 | "@tool\n", 183 | "def db_query_tool(query: str) -> str:\n", 184 | " \"\"\"\n", 185 | " Execute a SQL query against the database and get back the result.\n", 186 | " If the query is not correct, an error message will be returned.\n", 187 | " If an error is returned, rewrite the query, check the query, and try again.\n", 188 | " \"\"\"\n", 189 | " result = db.run_no_throw(query)\n", 190 | " if not result:  # run_no_throw returns \"Error: ...\" on failure, so an empty string means zero rows\n", 191 | " return \"Error: Query returned no rows. Please rewrite your query and try again.\"\n", 192 | " return result\n", 193 | "\n", 194 | "\n", 195 | "print(db_query_tool.invoke(\"SELECT * FROM Artist LIMIT 10;\"))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 54, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "data": { 205 | "text/plain": [ 206 | "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_uw4YDXhwEcgtPJpGWXEdnvqb', 'function': {'arguments': '{\\n \"query\": \"SELECT * FROM Artist LIMIT 10;\"\\n}', 'name': 'db_query_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 20, 'prompt_tokens': 222, 'total_tokens': 242}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_dd932ca5d1', 'finish_reason': 'stop', 'logprobs': None}, id='run-a0f2dfe8-99da-4593-87c6-0a8ad5d877b7-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': 'SELECT * FROM Artist LIMIT 10;'}, 'id': 'call_uw4YDXhwEcgtPJpGWXEdnvqb'}], usage_metadata={'input_tokens': 222, 'output_tokens': 20, 'total_tokens': 242})" 207 | ] 208 | }, 209 | "execution_count": 54, 210 | "metadata": {}, 211 | "output_type": "execute_result" 212 | } 213 | ], 214 | "source": [ 215
| "from langchain_core.prompts import ChatPromptTemplate\n", 216 | "\n", 217 | "query_check_system = \"\"\"You are a SQL expert with a strong attention to detail.\n", 218 | "Double check the SQLite query for common mistakes, including:\n", 219 | "- Using NOT IN with NULL values\n", 220 | "- Using UNION when UNION ALL should have been used\n", 221 | "- Using BETWEEN for exclusive ranges\n", 222 | "- Data type mismatch in predicates\n", 223 | "- Properly quoting identifiers\n", 224 | "- Using the correct number of arguments for functions\n", 225 | "- Casting to the correct data type\n", 226 | "- Using the proper columns for joins\n", 227 | "\n", 228 | "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n", 229 | "\n", 230 | "You will call the appropriate tool to execute the query after running this check.\"\"\"\n", 231 | "\n", 232 | "query_check_prompt = ChatPromptTemplate.from_messages(\n", 233 | " [(\"system\", query_check_system), (\"placeholder\", \"{messages}\")]\n", 234 | ")\n", 235 | "query_check = query_check_prompt | ChatOpenAI(model=\"gpt-4o\", temperature=0).bind_tools(\n", 236 | " [db_query_tool], tool_choice=\"required\"\n", 237 | ")\n", 238 | "\n", 239 | "query_check.invoke({\"messages\": [(\"user\", \"SEECT * FROM Artist LIMIT 10;\")]})" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 51, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "from typing import Annotated, Literal\n", 249 | "\n", 250 | "from langchain_core.messages import AIMessage\n", 251 | "from langchain_core.pydantic_v1 import BaseModel, Field\n", 252 | "from langchain_openai import ChatOpenAI\n", 253 | "from typing_extensions import TypedDict\n", 254 | "\n", 255 | "from langgraph.graph import END, StateGraph, START\n", 256 | "from langgraph.graph.message import AnyMessage, add_messages\n", 257 | "\n", 258 | "\n", 259 | "# Define the state for the agent\n", 260 | "class
State(TypedDict):\n", 261 | " messages: Annotated[list[AnyMessage], add_messages]\n", 262 | "\n", 263 | "\n", 264 | "# Define a new graph\n", 265 | "workflow = StateGraph(State)\n", 266 | "\n", 267 | "# Get relevant tables\n", 268 | "def list_table_node(state: State):\n", 269 | " messages = [\n", 270 | " *state[\"messages\"],  # unpack: state[\"messages\"] is already a list of messages, nesting it breaks message conversion\n", 271 | " AIMessage(\n", 272 | " content=\"\",\n", 273 | " tool_calls=[\n", 274 | " {\n", 275 | " \"name\": \"sql_db_list_tables\",\n", 276 | " \"args\": {},\n", 277 | " \"id\": \"tool_abcd123\",\n", 278 | " }\n", 279 | " ],\n", 280 | " ),\n", 281 | " ToolMessage(\n", 282 | " content=list_tables_tool.run(\"get\"), \n", 283 | " name='sql_db_list_tables', \n", 284 | " tool_call_id='tool_abcd123'\n", 285 | " ) \n", 286 | " ]\n", 287 | "\n", 288 | " model_get_schema = ChatOpenAI(model=\"gpt-4o\", temperature=0).bind_tools(\n", 289 | " [get_schema_tool]\n", 290 | " )\n", 291 | " model_message = model_get_schema.invoke(messages)\n", 292 | " messages.append(model_message)\n", 293 | "\n", 294 | " return {\"messages\": messages}\n", 295 | "\n", 296 | "workflow.add_node(\"list_table_tools\", list_table_node)\n", 297 | "\n", 298 | "# Get relevant schema\n", 299 | "workflow.add_node(\"get_schema_tool\", create_tool_node_with_fallback([get_schema_tool]))\n", 300 | "\n", 301 | "# Describe a tool to represent the end state\n", 302 | "class SubmitFinalAnswer(BaseModel):\n", 303 | " \"\"\"Submit the final answer to the user based on the query results.\"\"\"\n", 304 | "\n", 305 | " final_answer: str = Field(..., description=\"The final answer to the user\")\n", 306 | "\n", 307 | "\n", 308 | "# Add a node for a model to generate a query based on the question and schema\n", 309 | "query_gen_system = \"\"\"You are a SQL expert with a strong attention to detail.\n", 310 | "\n", 311 | "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n", 312 | "\n", 313 | "DO NOT call any tool
besides SubmitFinalAnswer to submit the final answer.\n", 314 | "\n", 315 | "When generating the query:\n", 316 | "\n", 317 | "Output the SQL query that answers the input question without a tool call.\n", 318 | "\n", 319 | "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n", 320 | "You can order the results by a relevant column to return the most interesting examples in the database.\n", 321 | "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", 322 | "\n", 323 | "If you get an error while executing a query, rewrite the query and try again.\n", 324 | "\n", 325 | "If you get an empty result set, you should try to rewrite the query to get a non-empty result set. \n", 326 | "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n", 327 | "\n", 328 | "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n", 329 | "\n", 330 | "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\"\"\"\n", 331 | "\n", 332 | "query_gen_prompt = ChatPromptTemplate.from_messages(\n", 333 | " [(\"system\", query_gen_system), (\"placeholder\", \"{messages}\")]\n", 334 | ")\n", 335 | "query_gen = query_gen_prompt | ChatOpenAI(model=\"gpt-4o\", temperature=0).bind_tools(\n", 336 | " [SubmitFinalAnswer]\n", 337 | ")\n", 338 | "\n", 339 | "\n", 340 | "def query_gen_node(state: State):\n", 341 | " message = query_gen.invoke(state)\n", 342 | "\n", 343 | " # Sometimes, the LLM will hallucinate and call the wrong tool.
We need to catch this and return an error message.\n", 344 | " tool_messages = []\n", 345 | " if message.tool_calls:\n", 346 | " for tc in message.tool_calls:\n", 347 | " if tc[\"name\"] != \"SubmitFinalAnswer\":\n", 348 | " tool_messages.append(\n", 349 | " ToolMessage(\n", 350 | " content=f\"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.\",\n", 351 | " tool_call_id=tc[\"id\"],\n", 352 | " )\n", 353 | " )\n", 354 | " else:\n", 355 | " tool_messages = []\n", 356 | " return {\"messages\": [message] + tool_messages}\n", 357 | "\n", 358 | "\n", 359 | "workflow.add_node(\"query_gen\", query_gen_node)\n", 360 | "\n", 361 | "\n", 362 | "# Add a node for the model to check the query before executing it\n", 363 | "def model_check_query(state: State) -> dict[str, list[AIMessage]]:\n", 364 | " \"\"\"\n", 365 | " Graph node: send the last message to query_check, which must reply with a db_query_tool call.\n", 366 | " \"\"\"\n", 367 | " return {\n", 368 | " \"messages\": [\n", 369 | " query_check.invoke({\"messages\": [state[\"messages\"][-1]]})]}\n", 370 | "\n", 371 | "workflow.add_node(\"correct_query\", model_check_query)\n", 372 | "\n", 373 | "\n", 374 | "# Add node for executing the query\n", 375 | "workflow.add_node(\"execute_query\", create_tool_node_with_fallback([db_query_tool]))\n", 376 | "\n", 377 | "\n", 378 | "\n", 379 | "# Define a conditional edge to decide whether to continue or end the workflow\n", 380 | "def should_continue(state: State) -> Literal[END, \"correct_query\", \"query_gen\"]:\n", 381 | " messages = state[\"messages\"]\n", 382 | " last_message = messages[-1]\n", 383 | " # If there is a tool call, then we finish\n", 384 | " if getattr(last_message, \"tool_calls\", None):\n", 385 | " return END\n", 386 | " if last_message.content.startswith(\"Error:\"):\n", 387 | " return \"query_gen\"\n", 388 | " else:\n", 389 | "
return \"correct_query\"\n", 390 | "\n", 391 | "\n", 392 | "# Specify the edges between the nodes\n", 393 | "workflow.add_edge(START, \"list_table_tools\")\n", 394 | "workflow.add_edge(\"list_table_tools\", \"get_schema_tool\")\n", 395 | "workflow.add_edge(\"get_schema_tool\", \"query_gen\")\n", 396 | "workflow.add_conditional_edges(\n", 397 | " \"query_gen\",\n", 398 | " should_continue,\n", 399 | ")\n", 400 | "workflow.add_edge(\"correct_query\", \"execute_query\")\n", 401 | "workflow.add_edge(\"execute_query\", \"query_gen\")\n", 402 | "\n", 403 | "# Compile the workflow into a runnable\n", 404 | "app = workflow.compile()" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 56, 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "data": { 414 | "image/jpeg": "", 415 | "text/plain": [ 416 | "" 417 | ] 418 | }, 419 | "metadata": {}, 420 | "output_type": "display_data" 421 | } 422 | ], 423 | "source": [ 424 | "from IPython.display import Image, display\n", 425 | "from langchain_core.runnables.graph import MermaidDrawMethod\n", 426 | "\n", 427 | "display(\n", 428 | " Image(\n", 429 | " app.get_graph().draw_mermaid_png(\n", 430 | " draw_method=MermaidDrawMethod.API,\n", 431 | " )\n", 432 | " )\n", 433 | ")" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 55, 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "'The sales agent who made the most in sales in 2009 is Steve Johnson with total sales of 164.34.'" 445 | ] 446 | }, 447 | "execution_count": 55, 448 | "metadata": {}, 449 | "output_type": "execute_result" 450 | } 451 | ], 452 | "source": [ 453 | "# The run ends on the SubmitFinalAnswer tool call (see should_continue); read its parsed args.\n", 454 | "\n", 455 | "messages = app.invoke(\n", 456 | " {\"messages\": [(\"user\", \"Which sales agent made the most in sales in 2009?\")]}\n", 457 | ")\n", 458 | "final_call = messages[\"messages\"][-1].tool_calls[0]\n", 459 | "# .tool_calls holds already-parsed args, so the raw provider JSON need not be json.loads-ed\n", 460 | "assert final_call[\"name\"] == \"SubmitFinalAnswer\", final_call\n", 461 |
"final_call[\"args\"][\"final_answer\"]" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 52, 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "name": "stdout", 471 | "output_type": "stream", 472 | "text": [ 473 | "{'list_table_tools': {'messages': [HumanMessage(content='Which sales agent made the most in sales in 2009?', id='5baa4d66-d372-4552-8453-a830b990b7f7'), AIMessage(content='', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123'}]), ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', tool_call_id='tool_abcd123'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_7X2B7tHwserfvHZ8rXwM8nFu', 'function': {'arguments': '{\"table_names\":\"Employee, Invoice\"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 177, 'total_tokens': 195}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d33f7b429e', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-ea46a77b-694f-40a5-b602-c4190df23554-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'Employee, Invoice'}, 'id': 'call_7X2B7tHwserfvHZ8rXwM8nFu'}], usage_metadata={'input_tokens': 177, 'output_tokens': 18, 'total_tokens': 195})]}}\n", 474 | "{'get_schema_tool': {'messages': [ToolMessage(content='\\nCREATE TABLE \"Employee\" (\\n\\t\"EmployeeId\" INTEGER NOT NULL, \\n\\t\"LastName\" NVARCHAR(20) NOT NULL, \\n\\t\"FirstName\" NVARCHAR(20) NOT NULL, \\n\\t\"Title\" NVARCHAR(30), \\n\\t\"ReportsTo\" INTEGER, \\n\\t\"BirthDate\" DATETIME, \\n\\t\"HireDate\" DATETIME, \\n\\t\"Address\" NVARCHAR(70), \\n\\t\"City\" NVARCHAR(40), \\n\\t\"State\" NVARCHAR(40), \\n\\t\"Country\" NVARCHAR(40), \\n\\t\"PostalCode\" NVARCHAR(10), \\n\\t\"Phone\" NVARCHAR(24), \\n\\t\"Fax\" NVARCHAR(24), \\n\\t\"Email\" NVARCHAR(60), \\n\\tPRIMARY KEY
(\"EmployeeId\"), \\n\\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\\n)\\n\\n/*\\n3 rows from Employee table:\\nEmployeeId\\tLastName\\tFirstName\\tTitle\\tReportsTo\\tBirthDate\\tHireDate\\tAddress\\tCity\\tState\\tCountry\\tPostalCode\\tPhone\\tFax\\tEmail\\n1\\tAdams\\tAndrew\\tGeneral Manager\\tNone\\t1962-02-18 00:00:00\\t2002-08-14 00:00:00\\t11120 Jasper Ave NW\\tEdmonton\\tAB\\tCanada\\tT5K 2N1\\t+1 (780) 428-9482\\t+1 (780) 428-3457\\tandrew@chinookcorp.com\\n2\\tEdwards\\tNancy\\tSales Manager\\t1\\t1958-12-08 00:00:00\\t2002-05-01 00:00:00\\t825 8 Ave SW\\tCalgary\\tAB\\tCanada\\tT2P 2T3\\t+1 (403) 262-3443\\t+1 (403) 262-3322\\tnancy@chinookcorp.com\\n3\\tPeacock\\tJane\\tSales Support Agent\\t2\\t1973-08-29 00:00:00\\t2002-04-01 00:00:00\\t1111 6 Ave SW\\tCalgary\\tAB\\tCanada\\tT2P 5M5\\t+1 (403) 262-3443\\t+1 (403) 262-6712\\tjane@chinookcorp.com\\n*/\\n\\n\\nCREATE TABLE \"Invoice\" (\\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"InvoiceDate\" DATETIME NOT NULL, \\n\\t\"BillingAddress\" NVARCHAR(70), \\n\\t\"BillingCity\" NVARCHAR(40), \\n\\t\"BillingState\" NVARCHAR(40), \\n\\t\"BillingCountry\" NVARCHAR(40), \\n\\t\"BillingPostalCode\" NVARCHAR(10), \\n\\t\"Total\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceId\"), \\n\\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\\n)\\n\\n/*\\n3 rows from Invoice table:\\nInvoiceId\\tCustomerId\\tInvoiceDate\\tBillingAddress\\tBillingCity\\tBillingState\\tBillingCountry\\tBillingPostalCode\\tTotal\\n1\\t2\\t2009-01-01 00:00:00\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t1.98\\n2\\t4\\t2009-01-02 00:00:00\\tUllevålsveien 14\\tOslo\\tNone\\tNorway\\t0171\\t3.96\\n3\\t8\\t2009-01-03 00:00:00\\tGrétrystraat 63\\tBrussels\\tNone\\tBelgium\\t1000\\t5.94\\n*/', name='sql_db_schema', tool_call_id='call_7X2B7tHwserfvHZ8rXwM8nFu')]}}\n", 475 | "{'query_gen': {'messages': [AIMessage(content='', 
additional_kwargs={'tool_calls': [{'id': 'call_xVTfaqdKieBCtQ6gRiPwm5O6', 'function': {'arguments': '{\"table_names\":\"Customer\"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 1179, 'total_tokens': 1195}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_dd932ca5d1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a43b968e-9fdb-4ae9-856d-dc4420230df5-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'Customer'}, 'id': 'call_xVTfaqdKieBCtQ6gRiPwm5O6'}], usage_metadata={'input_tokens': 1179, 'output_tokens': 16, 'total_tokens': 1195}), ToolMessage(content='Error: The wrong tool was called: sql_db_schema. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.', id='9e055c3b-c764-4323-8bf5-cdb215bd5124', tool_call_id='call_xVTfaqdKieBCtQ6gRiPwm5O6')]}}\n", 476 | "{'query_gen': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_NSM1mo7O2NxeU2PlPgKtZbWu', 'function': {'arguments': '{\"table_names\":[\"Customer\"]}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 1245, 'total_tokens': 1262}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d33f7b429e', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-70b0487b-f867-4d3b-a541-1ee42d0d1843-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': ['Customer']}, 'id': 'call_NSM1mo7O2NxeU2PlPgKtZbWu'}], usage_metadata={'input_tokens': 1245, 'output_tokens': 17, 'total_tokens': 1262}), ToolMessage(content='Error: The wrong tool was called: sql_db_schema. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. 
Generated queries should be outputted WITHOUT a tool call.', id='f55d66fb-19cd-42dc-9728-863cd5daff38', tool_call_id='call_NSM1mo7O2NxeU2PlPgKtZbWu')]}}\n", 477 | "{'query_gen': {'messages': [AIMessage(content=\"To determine which sales agent made the most in sales in 2009, we need to join the `Invoice`, `Customer`, and `Employee` tables. Here's the query to find the top sales agent:\\n\\n```sql\\nSELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales\\nFROM Invoice i\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\nJOIN Employee e ON c.SupportRepId = e.EmployeeId\\nWHERE strftime('%Y', i.InvoiceDate) = '2009'\\nGROUP BY e.EmployeeId\\nORDER BY TotalSales DESC\\nLIMIT 1;\\n```\", response_metadata={'token_usage': {'completion_tokens': 124, 'prompt_tokens': 1312, 'total_tokens': 1436}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d33f7b429e', 'finish_reason': 'stop', 'logprobs': None}, id='run-53ed8165-f812-4892-bc2c-9d20cc8a9d24-0', usage_metadata={'input_tokens': 1312, 'output_tokens': 124, 'total_tokens': 1436})]}}\n", 478 | "{'correct_query': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_pDgr0GkUv0pqwXPngy58lL8B', 'function': {'arguments': '{\"query\":\"SELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales\\\\nFROM Invoice i\\\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\\\nJOIN Employee e ON c.SupportRepId = e.EmployeeId\\\\nWHERE strftime(\\'%Y\\', i.InvoiceDate) = \\'2009\\'\\\\nGROUP BY e.EmployeeId\\\\nORDER BY TotalSales DESC\\\\nLIMIT 1;\"}', 'name': 'db_query_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 90, 'prompt_tokens': 336, 'total_tokens': 426}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_dd932ca5d1', 'finish_reason': 'stop', 'logprobs': None}, id='run-79845919-aea8-4139-a50b-fc9696cb56aa-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': \"SELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales\\nFROM Invoice 
i\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\nJOIN Employee e ON c.SupportRepId = e.EmployeeId\\nWHERE strftime('%Y', i.InvoiceDate) = '2009'\\nGROUP BY e.EmployeeId\\nORDER BY TotalSales DESC\\nLIMIT 1;\"}, 'id': 'call_pDgr0GkUv0pqwXPngy58lL8B'}], usage_metadata={'input_tokens': 336, 'output_tokens': 90, 'total_tokens': 426})]}}\n", 479 | "{'execute_query': {'messages': [ToolMessage(content=\"[('Steve', 'Johnson', 164.34)]\", name='db_query_tool', tool_call_id='call_pDgr0GkUv0pqwXPngy58lL8B')]}}\n", 480 | "{'query_gen': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_A3CKekODdXBGOYtZwCeqovTT', 'function': {'arguments': '{\"final_answer\":\"The sales agent who made the most in sales in 2009 is Steve Johnson with total sales of 164.34.\"}', 'name': 'SubmitFinalAnswer'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 41, 'prompt_tokens': 1552, 'total_tokens': 1593}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d33f7b429e', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-476625ab-5983-4982-a484-406ab907bd30-0', tool_calls=[{'name': 'SubmitFinalAnswer', 'args': {'final_answer': 'The sales agent who made the most in sales in 2009 is Steve Johnson with total sales of 164.34.'}, 'id': 'call_A3CKekODdXBGOYtZwCeqovTT'}], usage_metadata={'input_tokens': 1552, 'output_tokens': 41, 'total_tokens': 1593})]}}\n" 481 | ] 482 | } 483 | ], 484 | "source": [ 485 | "inputs = {\"messages\": [(\"user\", \"Which sales agent made the most in sales in 2009?\")]}\n", 486 | "# Each streamed item maps the node that just ran to the state update it returned.\n", 487 | "for update in app.stream(inputs):\n", 488 | " print(update)" 489 | ] 490 | } 491 | ], 492 | "metadata": { 493 | "kernelspec": { 494 | "display_name": "Python 3", 495 | "language": "python", 496 | "name": "python3" 497 | }, 498 | "language_info": { 499 | "codemirror_mode": { 500 | "name": "ipython", 501 | "version": 3 502 | }, 503 | "file_extension": ".py", 504 | "mimetype": "text/x-python", 505 | "name": "python", 506 |
"nbconvert_exporter": "python", 507 | "pygments_lexer": "ipython3", 508 | "version": "3.10.10" 509 | } 510 | }, 511 | "nbformat": 4, 512 | "nbformat_minor": 2 513 | } 514 | --------------------------------------------------------------------------------