"
25 | ]
26 | },
27 | "execution_count": 5,
28 | "metadata": {},
29 | "output_type": "execute_result"
30 | }
31 | ],
32 | "source": [
33 | "import mlflow\n",
34 | "from typing import Literal\n",
35 | "from langchain_core.messages import AIMessage, ToolCall\n",
36 | "from langchain_core.outputs import ChatGeneration, ChatResult\n",
37 | "from langchain_core.tools import tool\n",
38 | "from langchain_openai import ChatOpenAI\n",
39 | "from langgraph.prebuilt import create_react_agent\n",
40 | "from dotenv import load_dotenv\n",
41 | "load_dotenv()\n",
42 | "\n",
43 | "mlflow.langchain.autolog()\n",
44 | "\n",
45 | "mlflow.set_tracking_uri(\"http://localhost:5000\")\n",
46 | "mlflow.set_experiment(\"LangGraph\")\n"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "b3f57637",
52 | "metadata": {},
53 | "source": [
54 | "### Define our Tool & Graph\n",
55 | "Below is the code snippet provided in your request. We define a simple tool to get weather (with limited city options) and create a ReAct-style agent using LangGraph."
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 6,
61 | "id": "a674fe47",
62 | "metadata": {
63 | "tags": []
64 | },
65 | "outputs": [],
66 | "source": [
67 | "@tool\n",
68 | "def get_weather(city: Literal[\"nyc\", \"sf\"]):\n",
69 | " \"\"\"Use this to get weather information.\"\"\"\n",
70 | " if city == \"nyc\":\n",
71 | " return \"It might be cloudy in nyc\"\n",
72 | " elif city == \"sf\":\n",
73 | " return \"It's always sunny in sf\"\n",
74 | "\n",
75 | "# Instantiate the LLM\n",
76 | "llm = ChatOpenAI(model=\"gpt-4o-mini\") # placeholder model name\n",
77 | "\n",
78 | "# Create the ReAct agent\n",
79 | "tools = [get_weather]\n",
80 | "graph = create_react_agent(llm, tools)\n"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "8716eea8",
86 | "metadata": {},
87 | "source": [
88 | "### Invoke the Graph\n",
89 | "We now call `graph.invoke` with a user request about the weather in SF. "
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 7,
95 | "id": "5d0631e3",
96 | "metadata": {
97 | "tags": []
98 | },
99 | "outputs": [
100 | {
101 | "name": "stdout",
102 | "output_type": "stream",
103 | "text": [
104 | "Agent response: {'messages': [HumanMessage(content='what is the weather in sf?', additional_kwargs={}, response_metadata={}, id='81a232ed-b6f0-4d47-8f3d-0c70c17aa4d1'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_yok4TNqxoU2s6vHoCZqyo4Jf', 'function': {'arguments': '{\"city\":\"sf\"}', 'name': 'get_weather'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 58, 'total_tokens': 73, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_27322b4e16', 'id': 'chatcmpl-BF4ucn1Ex6HVdykOaITebDcCZw9jQ', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-aeccf54e-6c93-4b46-86b5-3c5f12efbfe7-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_yok4TNqxoU2s6vHoCZqyo4Jf', 'type': 'tool_call'}], usage_metadata={'input_tokens': 58, 'output_tokens': 15, 'total_tokens': 73, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}), ToolMessage(content=\"It's always sunny in sf\", name='get_weather', id='a8767e92-3503-42e7-914f-8e6740be610b', tool_call_id='call_yok4TNqxoU2s6vHoCZqyo4Jf'), AIMessage(content='The weather in San Francisco is sunny!', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 85, 'total_tokens': 95, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_27322b4e16', 'id': 'chatcmpl-BF4udE60FlEmnIF7MPBzGvFrjoPGr', 'finish_reason': 'stop', 'logprobs': None}, id='run-4e5c87bf-caa4-43d9-9a59-64dc59c5007e-0', usage_metadata={'input_tokens': 85, 'output_tokens': 10, 'total_tokens': 95, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}\n"
105 | ]
106 | },
107 | {
108 | "data": {
109 | "text/html": [
110 | "\n",
111 | "\n",
112 | " \n",
129 | " \n",
139 | " \n",
144 | "
\n"
145 | ],
146 | "text/plain": [
147 | "Trace(request_id=747be10c0e7245de8616e89df06d26da)"
148 | ]
149 | },
150 | "metadata": {},
151 | "output_type": "display_data"
152 | }
153 | ],
154 | "source": [
155 | "result = graph.invoke({\n",
156 | " \"messages\": [\n",
157 | " {\"role\": \"user\", \"content\": \"what is the weather in sf?\"}\n",
158 | " ]\n",
159 | "})\n",
160 | "print(\"Agent response:\", result)"
161 | ]
162 | },
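{
"cell_type": "markdown",
"id": "7c1a2b3d",
"metadata": {},
"source": [
"### Manual Instrumentation (Optional)\n",
"Autologging covers the LangGraph run above; MLflow also exposes manual tracing APIs (see the docs link in the next cell). Below is a minimal, hedged sketch assuming MLflow >= 2.14: `@mlflow.trace` records a function call as a span, and `mlflow.start_span` wraps an arbitrary block."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a2b3e",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of manual tracing (assumes MLflow >= 2.14 tracing APIs).\n",
"@mlflow.trace\n",
"def lookup_weather(city: str) -> str:\n",
"    # The decorator records this call as a span in the active trace.\n",
"    return get_weather.invoke({\"city\": city})\n",
"\n",
"with mlflow.start_span(name=\"manual-demo\") as span:\n",
"    span.set_inputs({\"city\": \"nyc\"})\n",
"    weather = lookup_weather(\"nyc\")\n",
"    span.set_outputs({\"weather\": weather})"
]
},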
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "4f0ef46d",
167 | "metadata": {},
168 | "outputs": [],
189 | "source": [
190 | "https://www.mlflow.org/docs/latest/tracing/api/manual-instrumentation/"
191 | ]
192 | }
193 | ],
194 | "metadata": {
195 | "colab": {
196 | "name": "mlflow_langgraph_example.ipynb"
197 | },
198 | "kernelspec": {
199 | "display_name": ".venv",
200 | "language": "python",
201 | "name": "python3"
202 | },
203 | "language_info": {
204 | "codemirror_mode": {
205 | "name": "ipython",
206 | "version": 3
207 | },
208 | "file_extension": ".py",
209 | "mimetype": "text/x-python",
210 | "name": "python",
211 | "nbconvert_exporter": "python",
212 | "pygments_lexer": "ipython3",
213 | "version": "3.11.0"
214 | }
215 | },
216 | "nbformat": 4,
217 | "nbformat_minor": 5
218 | }
219 |
--------------------------------------------------------------------------------
/responsesapi.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "af480eb1",
6 | "metadata": {},
7 | "source": [
8 | "# Responses API Tutorial: From Introduction to Summary\n",
9 | "\n",
10 | "Below is a comprehensive tutorial and set of examples demonstrating:\n",
11 | "\n",
12 | "1. **How the Responses API works and how it differs from Chat Completions** (particularly around stateful vs. stateless usage).\n",
13 | "2. **Examples** of multi-turn conversation (using `previous_response_id` for stateful flows) and built-in tools like web search and file search.\n",
14 | "3. **How to disable storage** (`store: false`) if you **do not** want your conversation state to persist on OpenAI’s servers—effectively making it stateless.\n",
15 | "\n",
16 | "---\n",
17 | "## 1. Chat Completions (Stateless) vs. Responses (Stateful)\n",
18 | "\n",
19 | "- **Chat Completions**:\n",
20 | " - Typically stateless: each new request must supply the entire conversation history in `messages`.\n",
21 | " - Stored by default only for new accounts; can be disabled.\n",
22 | "\n",
23 | "- **Responses**:\n",
24 | " - By default, **stateful**: each response has its own `id`. You can pass `previous_response_id` in subsequent calls, and the system automatically includes the prior context.\n",
25 | " - Provides built-in **tools** (web search, file search, etc.) that the model can call if relevant.\n",
26 | " - **Stored** by default. If you want ephemeral usage, set `store: false`.\n",
27 | "\n",
28 | "When you get a response back from the Responses API, the returned object differs slightly from Chat Completions:\n",
29 | "\n",
30 | "- Instead of a simple list of message choices, you receive a typed `response` object with top-level fields (e.g. `id`, `output`, `usage`, etc.).\n",
31 | "- To continue a conversation, pass `previous_response_id` to the next request.\n",
32 | "- If you do **not** want it stored, set `store: false`.\n"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "id": "98ba103f",
38 | "metadata": {},
39 | "source": [
40 | "---\n",
41 | "## 2. Multi-Turn Flow (Stateful) Example\n",
42 | "\n",
43 | "Using `previous_response_id` means the Responses API will store and automatically incorporate the entire conversation. Here’s a simple demonstration:"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "0bc2fb43",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "from dotenv import load_dotenv\n",
54 | "\n",
55 | "load_dotenv()"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "id": "f729187f",
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from openai import OpenAI\n",
66 | "\n",
67 | "client = OpenAI()\n",
68 | "\n",
69 | "\n",
70 | "resp1 = client.responses.create(\n",
71 | " model=\"gpt-4o-mini\",\n",
72 | " input=\"Hello there! You're a helpful math tutor. Could you help me with a question? What's 2 + 2?\"\n",
73 | ")\n",
74 | "print(\"First response:\\n\", resp1.output_text)\n"
75 | ]
76 | },
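{
"cell_type": "markdown",
"id": "3a9b1c2d",
"metadata": {},
"source": [
"The returned object is a typed `response`, not a raw dict. As a quick sketch (field names follow the OpenAI Python SDK's Responses object), we can inspect the top-level fields mentioned in the introduction:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a9b1c2e",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the typed response object's top-level fields (sketch).\n",
"print(\"id:\", resp1.id)\n",
"print(\"model:\", resp1.model)\n",
"print(\"usage:\", resp1.usage)\n",
"print(\"output items:\", len(resp1.output))"
]
},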
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "d1b8116f",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "resp2 = client.responses.create(\n",
85 | " model=\"gpt-4o\",\n",
86 | " input=\"Sure. How did you come to this conclusion?\",\n",
87 | " previous_response_id=resp1.id\n",
88 | ")\n",
89 | "print(\"\\nSecond response:\\n\", resp2.output_text)"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "id": "891a378a",
95 | "metadata": {},
96 | "source": [
97 | "---\n",
98 | "## 3. Using Built-In Tools\n",
99 | "\n",
100 | "### 3.1 Web Search\n",
101 | "Allows the model to gather recent info from the internet if relevant."
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "id": "2857562e",
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "# Example usage of the built-in web_search tool\n",
112 | "r1 = client.responses.create(\n",
113 | " model=\"gpt-4o\",\n",
114 | " input=\"Please find recent positive headlines about quantum computing.\",\n",
115 | " tools=[{\"type\": \"web_search\"}] # enabling built-in web search\n",
116 | ")\n",
117 | "print(r1.output_text)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "32a4f76c",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "# Continue the conversation referencing previous response\n",
128 | "r2 = client.responses.create(\n",
129 | " model=\"gpt-4o\",\n",
130 | " input=\"Interesting! Summarize the second article.\",\n",
131 | " previous_response_id=r1.id\n",
132 | ")\n",
133 | "print(\"\\nFollow-up:\\n\", r2.output_text)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "id": "599a206e",
139 | "metadata": {},
140 | "source": [
141 | "### 3.2 File Upload + File Search\n",
142 | "\n",
143 | "Below is a snippet showing how to:\n",
144 | "1. **Upload** a local file (e.g., `dragon_book.txt`).\n",
145 | "2. **Create** a vector store from that file.\n",
146 | "3. **Use** `file_search` in the Responses API to reference it.\n"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "id": "efc24105",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "upload_resp = client.files.create(\n",
157 | " file=open(\"dragon_book.txt\", \"rb\"),\n",
158 | " purpose=\"user_data\"\n",
159 | ")\n",
160 | "file_id = upload_resp.id\n",
161 | "print(\"Uploaded file ID:\", file_id)"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "id": "dc2b122e",
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "client.files.list()"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "id": "749870bc",
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "vstore_resp = client.vector_stores.create(\n",
182 | " name=\"DragonData\",\n",
183 | " file_ids=[file_id]\n",
184 | ")\n",
185 | "vstore_id = vstore_resp.id\n",
186 | "print(\"Vector store ID:\", vstore_id)"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "id": "dada70c3",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "resp1 = client.responses.create(\n",
197 | " model=\"gpt-4o\",\n",
198 | " tools=[{\n",
199 | " \"type\": \"file_search\",\n",
200 | " \"vector_store_ids\": [vstore_id],\n",
201 | " \"max_num_results\": 3\n",
202 | " }],\n",
203 | " input=\"What information do you have about red dragons?\"\n",
204 | ")\n",
205 | "print(resp1.output_text)"
206 | ]
207 | },
208 | {
209 | "cell_type": "markdown",
210 | "id": "ed7d1705",
211 | "metadata": {},
212 | "source": [
213 | "---\n",
214 | "## 4. Disable Storage (Stateless Mode)\n",
215 | "\n",
216 | "Although the Responses API is **stateful** by default, you can make calls **not** store any conversation by setting `store=False`. Then `previous_response_id` won’t work, because no data is retained on OpenAI’s servers."
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "id": "553d63e3",
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "# An ephemeral request that won't be stored\n",
227 | "ephemeral_resp = client.responses.create(\n",
228 | " model=\"gpt-4o\",\n",
229 | " input=\"Hello, let's do a single-turn question about geometry.\",\n",
230 | " store=False # ephemeral usage\n",
231 | ")\n",
232 | "print(ephemeral_resp.output_text)"
233 | ]
234 | },
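{
"cell_type": "markdown",
"id": "6e5d4c3b",
"metadata": {},
"source": [
"Since nothing was retained server-side, chaining off the ephemeral response should fail. A hedged sketch (the exact error type may vary by SDK version):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e5d4c3c",
"metadata": {},
"outputs": [],
"source": [
"# Referencing an unstored response is expected to fail (sketch; exact error may vary).\n",
"try:\n",
"    client.responses.create(\n",
"        model=\"gpt-4o\",\n",
"        input=\"Follow-up to the geometry question.\",\n",
"        previous_response_id=ephemeral_resp.id,\n",
"    )\n",
"except Exception as exc:\n",
"    print(\"Expected failure:\", exc)"
]
},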
235 | {
236 | "cell_type": "markdown",
237 | "id": "366df08f",
238 | "metadata": {},
239 | "source": [
240 | "### LangChain Integration"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": null,
246 | "id": "bcf04310",
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "from langchain_openai import ChatOpenAI\n",
251 | "\n",
252 | "llm = ChatOpenAI(model=\"gpt-4o-mini\", use_responses_api=True)\n",
253 | "\n",
254 | "\n",
255 | "tool = {\"type\": \"web_search_preview\"}\n",
256 | "\n",
257 | "\n",
258 | "llm_with_tools = llm.bind_tools([tool])\n",
259 | "\n",
260 | "\n",
261 | "response = llm_with_tools.invoke(input=\"What was a positive news story from today?\")\n",
262 | "\n",
263 | "print(\"Text content:\", response.content)\n",
264 | "print(\"Tool calls:\", response.tool_calls)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "id": "11d14f03",
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "from langchain_openai import ChatOpenAI\n",
275 | "\n",
276 | "llm_stateful = ChatOpenAI(\n",
277 | " model=\"gpt-4o-mini\",\n",
278 | " use_responses_api=True,\n",
279 | ")\n",
280 | "\n",
281 | "respA = llm_stateful.invoke(\"Hi, I'm Bob. Please remember my name.\")\n",
282 | "print(\"Response A:\", respA.content)\n",
283 | "print(\"A's ID:\", respA.response_metadata[\"id\"])\n",
284 | "\n",
285 | "respB = llm_stateful.invoke(\n",
286 | " \"What is my name?\",\n",
287 | " previous_response_id=respA.response_metadata[\"id\"]\n",
288 | ")\n",
289 | "print(\"Response B:\", respB.content)\n"
290 | ]
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": ".venv",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.11.0"
310 | }
311 | },
312 | "nbformat": 4,
313 | "nbformat_minor": 5
314 | }
315 |
--------------------------------------------------------------------------------
/agent_team.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### Team of Agents with a supervisor"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from langchain_openai import ChatOpenAI\n",
17 | "from langchain_core.prompts import ChatPromptTemplate\n",
18 | "from langchain_core.pydantic_v1 import BaseModel, Field\n",
19 | "\n",
20 | "\n",
21 | "class TransferNewsGrader(BaseModel):\n",
22 | " \"\"\"Binary score for relevance check on football transfer news.\"\"\"\n",
23 | "\n",
24 | " binary_score: str = Field(\n",
25 | " description=\"The article is about football transfers, 'yes' or 'no'\"\n",
26 | " )\n",
27 | "\n",
28 | "\n",
29 | "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
30 | "structured_llm_grader = llm.with_structured_output(TransferNewsGrader)\n",
31 | "\n",
32 | "system = \"\"\"You are a grader assessing whether a news article concerns a football transfer. \\n\n",
33 | " Check if the article explicitly mentions player transfers between clubs, potential transfers, or confirmed transfers. \\n\n",
34 | " Provide a binary score 'yes' or 'no' to indicate whether the news is about a football transfer.\"\"\"\n",
35 | "grade_prompt = ChatPromptTemplate.from_messages(\n",
36 | " [(\"system\", system), (\"human\", \"News Article:\\n\\n {article}\")]\n",
37 | ")\n",
38 | "evaluator = grade_prompt | structured_llm_grader\n",
39 | "result = evaluator.invoke(\n",
40 | " {\"There are rumors messi will switch from real madrid to FC Barcelona\"}\n",
41 | ")"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "from langchain_openai import ChatOpenAI\n",
51 | "from langchain_core.prompts import ChatPromptTemplate\n",
52 | "from langchain_core.pydantic_v1 import BaseModel, Field\n",
53 | "\n",
54 | "\n",
55 | "class ArticlePostabilityGrader(BaseModel):\n",
56 | " \"\"\"Binary scores for postability check, word count, sensationalism, and language verification of a news article.\"\"\"\n",
57 | "\n",
58 | " can_be_posted: str = Field(\n",
59 | " description=\"The article is ready to be posted, 'yes' or 'no'\"\n",
60 | " )\n",
61 | " meets_word_count: str = Field(\n",
62 | " description=\"The article has at least 200 words, 'yes' or 'no'\"\n",
63 | " )\n",
64 | " is_sensationalistic: str = Field(\n",
65 | " description=\"The article is written in a sensationalistic style, 'yes' or 'no'\"\n",
66 | " )\n",
67 | " is_language_german: str = Field(\n",
68 | " description=\"The language of the article is German, 'yes' or 'no'\"\n",
69 | " )\n",
70 | "\n",
71 | "\n",
72 | "llm_postability = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
73 | "structured_llm_postability_grader = llm_postability.with_structured_output(\n",
74 | " ArticlePostabilityGrader\n",
75 | ")\n",
76 | "\n",
77 | "postability_system = \"\"\"You are a grader assessing whether a news article is ready to be posted, if it meets the minimum word count of 200 words, is written in a sensationalistic style, and if it is in German. \\n\n",
78 | " Evaluate the article for grammatical errors, completeness, appropriateness for publication, and EXAGGERATED sensationalism. \\n\n",
79 | " Also, confirm if the language used in the article is German and it meets the word count requirement. \\n\n",
80 | " Provide four binary scores: one to indicate if the article can be posted ('yes' or 'no'), one for adequate word count ('yes' or 'no'), one for sensationalistic writing ('yes' or 'no'), and another if the language is German ('yes' or 'no').\"\"\"\n",
81 | "postability_grade_prompt = ChatPromptTemplate.from_messages(\n",
82 | " [(\"system\", postability_system), (\"human\", \"News Article:\\n\\n {article}\")]\n",
83 | ")\n",
84 | "\n",
85 | "news_chef = postability_grade_prompt | structured_llm_postability_grader\n",
86 | "\n",
87 | "result = news_chef.invoke(\n",
88 | " {\n",
89 | " \"article\": \"Es wurde gemeldet, dass Messi von Real Madrid zu FC Barcelona wechselt.\"\n",
90 | " }\n",
91 | ")"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "llm_translation = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
101 | "\n",
102 | "translation_system = \"\"\"You are a translator converting articles into German. Translate the text accurately while maintaining the original tone and style.\"\"\"\n",
103 | "translation_prompt = ChatPromptTemplate.from_messages(\n",
104 | " [(\"system\", translation_system), (\"human\", \"Article to translate:\\n\\n {article}\")]\n",
105 | ")\n",
106 | "\n",
107 | "translator = translation_prompt | llm_translation\n",
108 | "\n",
109 | "result = translator.invoke(\n",
110 | " {\n",
111 | " \"article\": \"It has been reported that Messi will transfer from Real Madrid to FC Barcelona.\"\n",
112 | " }\n",
113 | ")\n",
114 | "print(result)"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "llm_expansion = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0.5)\n",
124 | "expansion_system = \"\"\"You are a writer tasked with expanding the given article to at least 200 words while maintaining relevance, coherence, and the original tone.\"\"\"\n",
125 | "expansion_prompt = ChatPromptTemplate.from_messages(\n",
126 | " [(\"system\", expansion_system), (\"human\", \"Original article:\\n\\n {article}\")]\n",
127 | ")\n",
128 | "\n",
129 | "expander = expansion_prompt | llm_expansion\n",
130 | "\n",
131 | "article_content = \"Lionel Messi is reportedly considering a move from Real Madrid to FC Barcelona next season.\"\n",
132 | "result = expander.invoke({\"article\": article_content})\n",
133 | "print(result)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "from langgraph.graph import StateGraph, END\n",
143 | "from typing import TypedDict, Literal\n",
144 | "\n",
145 | "\n",
146 | "class AgentState(TypedDict):\n",
147 | " article_state: str"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "def get_transfer_news_grade(state: AgentState) -> AgentState:\n",
157 | " print(f\"get_transfer_news_grade: Current state: {state}\")\n",
158 | " print(\"Evaluator: Reading article but doing nothing to change it...\")\n",
159 | " return state"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "def evaluate_article(state: AgentState) -> AgentState:\n",
169 | " print(f\"evaluate_article: Current state: {state}\")\n",
170 | " print(\"News chef: Reading article but doing nothing to change it...\")\n",
171 | " return state"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "def translate_article(state: AgentState) -> AgentState:\n",
181 | " print(f\"translate_article: Current state: {state}\")\n",
182 | " article = state[\"article_state\"]\n",
183 | " result = translator.invoke({\"article\": article})\n",
184 | " state[\"article_state\"] = result.content\n",
185 | " return state"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "def expand_article(state: AgentState) -> AgentState:\n",
195 | " print(f\"expand_article: Current state: {state}\")\n",
196 | " article = state[\"article_state\"]\n",
197 | " result = expander.invoke({\"article\": article})\n",
198 | " state[\"article_state\"] = result.content\n",
199 | " return state"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "def publisher(state: AgentState) -> AgentState:\n",
209 | " print(f\"publisher: Current state: {state}\")\n",
210 | " print(\"FINAL_STATE in publisher:\", state)\n",
211 | " return state"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "def evaluator_router(state: AgentState) -> Literal[\"news_chef\", \"not_relevant\"]:\n",
221 | " article = state[\"article_state\"]\n",
222 | " evaluator = grade_prompt | structured_llm_grader\n",
223 | " result = evaluator.invoke({\"article\": article})\n",
224 | " print(f\"evaluator_router: Current state: {state}\")\n",
225 | " print(\"Evaluator result: \", result)\n",
226 | " if result.binary_score == \"yes\":\n",
227 | " return \"news_chef\"\n",
228 | " else:\n",
229 | " return \"not_relevant\""
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "def news_chef_router(\n",
239 | " state: AgentState,\n",
240 | ") -> Literal[\"translator\", \"publisher\", \"expander\"]:\n",
241 | " article = state[\"article_state\"]\n",
242 | " result = news_chef.invoke({\"article\": article})\n",
243 | " print(f\"news_chef_router: Current state: {state}\")\n",
244 | " print(\"News chef result: \", result)\n",
245 | " if result.can_be_posted == \"yes\":\n",
246 | " return \"publisher\"\n",
247 | " elif result.is_language_german == \"yes\":\n",
248 | " if result.meets_word_count == \"no\" or result.is_sensationalistic == \"no\":\n",
249 | " return \"expander\"\n",
250 | " return \"translator\""
251 | ]
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": null,
256 | "metadata": {},
257 | "outputs": [],
258 | "source": [
259 | "workflow = StateGraph(AgentState)\n",
260 | "\n",
261 | "workflow.add_node(\"evaluator\", get_transfer_news_grade)\n",
262 | "workflow.add_node(\"news_chef\", evaluate_article)\n",
263 | "workflow.add_node(\"translator\", translate_article)\n",
264 | "workflow.add_node(\"expander\", expand_article)\n",
265 | "workflow.add_node(\"publisher\", publisher)\n",
266 | "\n",
267 | "workflow.set_entry_point(\"evaluator\")\n",
268 | "\n",
269 | "workflow.add_conditional_edges(\n",
270 | " \"evaluator\", evaluator_router, {\"news_chef\": \"news_chef\", \"not_relevant\": END}\n",
271 | ")\n",
272 | "workflow.add_conditional_edges(\n",
273 | " \"news_chef\",\n",
274 | " news_chef_router,\n",
275 | " {\"translator\": \"translator\", \"publisher\": \"publisher\", \"expander\": \"expander\"},\n",
276 | ")\n",
277 | "workflow.add_edge(\"translator\", \"news_chef\")\n",
278 | "workflow.add_edge(\"expander\", \"news_chef\")\n",
279 | "workflow.add_edge(\"publisher\", END)\n",
280 | "\n",
281 | "app = workflow.compile()"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "from IPython.display import Image, display\n",
291 | "\n",
292 | "try:\n",
293 | " display(Image(app.get_graph(xray=True).draw_mermaid_png()))\n",
294 | "except Exception:\n",
295 | " pass"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": null,
301 | "metadata": {},
302 | "outputs": [],
303 | "source": [
304 | "initial_state = {\"article_state\": \"The Pope will visit Spain today\"}\n",
305 | "result = app.invoke(initial_state)\n",
306 | "\n",
307 | "print(\"Final result:\", result)"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": null,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": [
316 | "initial_state = {\"article_state\": \"Messi gonna switch from barca to real madrid\"}\n",
317 | "result = app.invoke(initial_state)\n",
318 | "\n",
319 | "print(\"Final result:\", result)"
320 | ]
321 | }
322 | ],
323 | "metadata": {
324 | "kernelspec": {
325 | "display_name": "app",
326 | "language": "python",
327 | "name": "python3"
328 | },
329 | "language_info": {
330 | "codemirror_mode": {
331 | "name": "ipython",
332 | "version": 3
333 | },
334 | "file_extension": ".py",
335 | "mimetype": "text/x-python",
336 | "name": "python",
337 | "nbconvert_exporter": "python",
338 | "pygments_lexer": "ipython3",
339 | "version": "3.11.0"
340 | }
341 | },
342 | "nbformat": 4,
343 | "nbformat_minor": 2
344 | }
345 |
--------------------------------------------------------------------------------
/intelligent_rag.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "from abc import ABC, abstractmethod\n",
11 | "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
12 | "from langchain_chroma import Chroma\n",
13 | "from langchain_core.tools import Tool\n",
14 | "from langchain.tools.retriever import create_retriever_tool\n",
15 | "from langchain.schema import Document\n",
16 | "from dotenv import load_dotenv\n",
17 | "\n",
18 | "load_dotenv()\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "def summarize_documents(docs: list[\"Document\"]) -> str:\n",
28 | " \"\"\"\n",
29 | " Summarize the given documents in one to two sentences.\n",
30 | " \"\"\"\n",
31 | " llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
32 | " all_text = \"\\n\".join(doc.page_content for doc in docs)\n",
33 | " prompt = f\"Please summarize the following text in 1-2 sentences:\\n---\\n{all_text}\\n---\"\n",
34 | " summary = llm.invoke(prompt)\n",
35 | " return summary.content.strip()\n"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "class AbstractVectorStoreObserver(ABC):\n",
45 | " \"\"\"\n",
46 | " An interface for all observers that want to be notified\n",
47 | " whenever a VectorStore receives new documents or is updated.\n",
48 | " \"\"\"\n",
49 | " @abstractmethod\n",
50 | " def on_vectorstore_update(self, manager: \"SingleVectorStoreManager\"):\n",
51 | " pass\n"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "class SingleVectorStoreManager:\n",
61 | " def __init__(self, persist_dir: str):\n",
62 | " self.embedding_function = OpenAIEmbeddings()\n",
63 | " self.persist_dir = persist_dir\n",
64 | "\n",
65 | " collection_name = os.path.basename(persist_dir)\n",
66 | " self.vs = Chroma(\n",
67 | " collection_name=collection_name,\n",
68 | " embedding_function=self.embedding_function,\n",
69 | " persist_directory=self.persist_dir\n",
70 | " )\n",
71 | "\n",
72 | " self.description = \"This vector store is empty.\"\n",
73 | "\n",
74 | " self.observers: list[AbstractVectorStoreObserver] = []\n",
75 | "\n",
76 | " def add_observer(self, observer: AbstractVectorStoreObserver):\n",
77 | " self.observers.append(observer)\n",
78 | "\n",
79 | " def remove_observer(self, observer: AbstractVectorStoreObserver):\n",
80 | " if observer in self.observers:\n",
81 | " self.observers.remove(observer)\n",
82 | "\n",
83 | " def notify_observers(self):\n",
84 | " for obs in self.observers:\n",
85 | " obs.on_vectorstore_update(self)\n",
86 | "\n",
87 | " def is_empty(self) -> bool:\n",
88 | " return (self.vs._collection.count() == 0)\n",
89 | "\n",
90 | " def create_retriever_tool(self, name: str, custom_description: str | None = None) -> Tool:\n",
91 | "\n",
92 | " retriever = self.vs.as_retriever()\n",
93 | " desc = custom_description if custom_description else self.description\n",
94 | " if self.is_empty():\n",
95 | " desc += \"\\n(Note: this vector store is currently empty.)\"\n",
96 | "\n",
97 | " tool = create_retriever_tool(\n",
98 | " retriever=retriever,\n",
99 | " name=name,\n",
100 | " description=desc\n",
101 | " )\n",
102 | " return tool\n",
103 | "\n",
104 | " def add_documents(self, docs: list[Document], update_description: bool = True):\n",
105 | "\n",
106 | " self.vs.add_documents(docs)\n",
107 | " if update_description:\n",
108 | " summary_text = summarize_documents(docs)\n",
111 | " self.description = summary_text\n",
112 | " self.notify_observers()"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "from langchain_core.tools import Tool\n",
122 | "from langchain_openai import ChatOpenAI\n",
123 | "from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage\n",
124 | "\n",
125 | "class LLMToolBinder:\n",
126 | " def __init__(self, llm_with_tools: ChatOpenAI, managers: list[\"SingleVectorStoreManager\"], extra_tools: list[Tool] | None = None):\n",
127 | " self.base_llm = llm_with_tools # unbound model; _bind_tools (re)binds tools from this\n",
128 | " self.llm_no_tools = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
129 | " self.managers = managers\n",
130 | " self.extra_tools = extra_tools or []\n",
131 | " self.tools: list[Tool] = []\n",
132 | " self._bind_tools()\n",
133 | "\n",
134 | " def _bind_tools(self):\n",
135 | " new_tools = []\n",
136 | " for i, m in enumerate(self.managers, start=1):\n",
137 | " tool_name = f\"retriever_store{i}\"\n",
138 | " new_tools.append(m.create_retriever_tool(name=tool_name))\n",
139 | " new_tools.extend(self.extra_tools)\n",
140 | " self.tools = new_tools\n",
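" # tool_choice=\"required\" forces the model to call at least one of the bound tools.\n",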
141 | " self.llm_with_tools = self.base_llm.bind_tools(self.tools, tool_choice=\"required\")\n",
142 | "\n",
143 | " def on_vectorstore_update(self, manager: \"SingleVectorStoreManager\"):\n",
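" # Observer hook: rebind tools so their descriptions reflect the updated store contents.\n",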
144 | " self._bind_tools()\n",
145 | "\n",
146 | " def invoke_llm(self, query: str) -> str:\n",
147 | " system_prompt = (\n",
148 | " \"You are a helpful assistant. You may call the available tools if needed. \"\n",
149 | " \"Once you receive tool outputs, focus on the last tool message and provide a final user-facing answer.\"\n",
150 | " )\n",
151 | " messages = [SystemMessage(content=system_prompt), HumanMessage(content=query)]\n",
152 | " first_output = self.llm_with_tools.invoke(messages)\n",
153 | " messages.append(first_output)\n",
154 | " if first_output.tool_calls:\n",
155 | " for tc in first_output.tool_calls:\n",
156 | " tool_name = tc[\"name\"]\n",
157 | " tool_args = tc[\"args\"]\n",
158 | " print(f\"Tool chosen: {tool_name} with args={tool_args}\")\n",
159 | " found_tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)\n",
160 | " if not found_tool:\n",
161 | " tool_result = f\"No matching tool named '{tool_name}'.\"\n",
162 | " else:\n",
163 | " tool_result = found_tool.invoke(tool_args)\n",
164 | " messages.append(ToolMessage(content=tool_result, tool_call_id=tc[\"id\"]))\n",
165 | " messages.append(SystemMessage(content=\"Focus on the last tool message. Provide your final answer.\"))\n",
166 | " second_output = self.llm_no_tools.invoke(messages)\n",
167 | " messages.append(second_output)\n",
168 | " return second_output.content\n",
169 | " else:\n",
170 | " return first_output.content\n",
171 | "\n",
172 | " def print_all_tool_descriptions(self):\n",
173 | " for tool in self.tools:\n",
174 | " print(tool.name, \":\", tool.description)\n"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "base_dir = \"my_chroma_db\"\n",
184 | "os.makedirs(base_dir, exist_ok=True)\n",
185 | "\n",
186 | "manager1 = SingleVectorStoreManager(os.path.join(base_dir, \"store1\"))\n",
187 | "manager2 = SingleVectorStoreManager(os.path.join(base_dir, \"store2\"))\n",
188 | "manager3 = SingleVectorStoreManager(os.path.join(base_dir, \"store3\"))\n"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "from langchain_core.tools import tool\n",
198 | "\n",
199 | "@tool\n",
200 | "def fallback_tool(message: str) -> str:\n",
201 | " \"\"\"\n",
202 | " A fallback tool if no other tool is appropriate.\n",
203 | "\n",
204 | " Args:\n",
205 | " message (str): The user query, or any text.\n",
206 | "\n",
207 | " Returns:\n",
208 | " str: A fallback response for questions that the model cannot answer\n",
209 | " with the other tools.\n",
210 | " \"\"\"\n",
211 | " return f\"I don't know how to answer: {message}\"\n"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "from langchain_core.tools import tool\n",
221 | "\n",
222 | "def create_vectorstore_info_tool(managers: list[\"SingleVectorStoreManager\"]):\n",
223 | " @tool\n",
224 | " def vectorstore_info(query: str) -> str:\n",
225 | " \"\"\"\n",
226 | " Use this tool to reveal internal knowledge about the agent, including:\n",
227 | " - The total number of vectorstores,\n",
228 | " - Each vectorstore’s document count,\n",
229 | " - Each vectorstore’s description or summary.\n",
230 | " \"\"\"\n",
231 | " lines = [f\"Total vectorstores: {len(managers)}\"]\n",
232 | " for i, m in enumerate(managers, start=1):\n",
233 | " doc_count = m.vs._collection.count()\n",
234 | " lines.append(\n",
235 | " f\"VectorStore {i}: {doc_count} documents\\n\"\n",
236 | " f\"Description: {m.description}\"\n",
237 | " )\n",
238 | " return \"\\n\\n\".join(lines)\n",
239 | "\n",
240 | " return vectorstore_info\n",
241 | "\n",
242 | "\n",
243 | "info_tool = create_vectorstore_info_tool(managers=[manager1, manager2, manager3])"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
253 | "\n",
254 | "binder = LLMToolBinder(llm, [manager1, manager2, manager3], extra_tools=[fallback_tool, info_tool])\n",
255 | "\n",
256 | "manager1.add_observer(binder)\n",
257 | "manager2.add_observer(binder)\n",
258 | "manager3.add_observer(binder)"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "binder.invoke_llm(\"Where is Lacarelli?\")"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "docs_store1 = [\n",
277 | " Document(\n",
278 | " page_content=(\n",
279 | " \"Lacarelli is a charming family-run Italian restaurant nestled in the \"\n",
280 | " \"heart of Berlin. Its menu features authentic dishes like homemade \"\n",
281 | " \"ravioli, wood-fired pizzas, and creamy tiramisu. With friendly staff, \"\n",
282 | " \"rustic decor, and a cozy atmosphere, Lacarelli provides an inviting \"\n",
283 | " \"dining experience for lovers of Italian cuisine and fine wines daily.\"\n",
284 | " )\n",
285 | " )\n",
286 | "]\n",
287 | "manager1.add_documents(docs_store1, update_description=True)"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "binder.print_all_tool_descriptions()"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "metadata": {},
303 | "outputs": [],
304 | "source": [
305 | "binder.invoke_llm(\"Where is Lacarelli?\")"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": null,
311 | "metadata": {},
312 | "outputs": [],
313 | "source": [
314 | "binder.invoke_llm(\"What do you know?\")"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "metadata": {},
321 | "outputs": [],
322 | "source": [
323 | "binder.invoke_llm(\"How many vectorstores do you manage?\")"
324 | ]
325 | }
326 | ],
327 | "metadata": {
328 | "kernelspec": {
329 | "display_name": ".venv",
330 | "language": "python",
331 | "name": "python3"
332 | },
333 | "language_info": {
334 | "codemirror_mode": {
335 | "name": "ipython",
336 | "version": 3
337 | },
338 | "file_extension": ".py",
339 | "mimetype": "text/x-python",
340 | "name": "python",
341 | "nbconvert_exporter": "python",
342 | "pygments_lexer": "ipython3",
343 | "version": "3.11.0"
344 | }
345 | },
346 | "nbformat": 4,
347 | "nbformat_minor": 2
348 | }
349 |
--------------------------------------------------------------------------------
/crag.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from dotenv import load_dotenv\n",
10 | "\n",
11 | "load_dotenv()"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "#### Creating VectorDatabase"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "from langchain.schema import Document\n",
28 | "from langchain_openai import OpenAIEmbeddings\n",
29 | "from langchain_community.vectorstores import Chroma\n",
30 | "\n",
31 | "embedding_function = OpenAIEmbeddings()\n",
32 | "\n",
33 | "docs = [\n",
34 | " Document(\n",
35 | " page_content=\"Bella Vista is owned by Antonio Rossi, a renowned chef with over 20 years of experience in the culinary industry. He started Bella Vista to bring authentic Italian flavors to the community.\",\n",
36 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
37 | " ),\n",
38 | " Document(\n",
39 | " page_content=\"Bella Vista offers a range of dishes with prices that cater to various budgets. Appetizers start at $8, main courses range from $15 to $35, and desserts are priced between $6 and $12.\",\n",
40 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
41 | " ),\n",
42 | " Document(\n",
43 | " page_content=\"Bella Vista is open from Monday to Sunday. Weekday hours are 11:00 AM to 10:00 PM, while weekend hours are extended from 11:00 AM to 11:00 PM.\",\n",
44 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
45 | " ),\n",
46 | " Document(\n",
47 | " page_content=\"Bella Vista offers a variety of menus including a lunch menu, dinner menu, and a special weekend brunch menu. The lunch menu features light Italian fare, the dinner menu offers a more extensive selection of traditional and contemporary dishes, and the brunch menu includes both classic breakfast items and Italian specialties.\",\n",
48 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
49 | " ),\n",
50 | "]\n",
51 | "\n",
52 | "db = Chroma.from_documents(docs, embedding_function)\n",
53 | "retriever = db.as_retriever()"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "from typing_extensions import TypedDict\n",
63 | "\n",
64 | "\n",
65 | "class AgentState(TypedDict):\n",
66 | " question: str\n",
67 | " grades: list[str]\n",
68 | " llm_output: str\n",
69 | " documents: list[str]\n",
70 | " on_topic: bool"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "def retrieve_docs(state: AgentState):\n",
80 | " question = state[\"question\"]\n",
81 | " documents = retriever.invoke(question)\n",
82 | " state[\"documents\"] = [doc.page_content for doc in documents]\n",
83 | " return state"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "from langchain_core.pydantic_v1 import BaseModel, Field\n",
93 | "from langchain_openai import ChatOpenAI\n",
94 | "from langchain_core.prompts import ChatPromptTemplate\n",
95 | "\n",
96 | "\n",
97 | "class GradeQuestion(BaseModel):\n",
98 | " \"\"\"Boolean value to check whether a question is related to the restaurant Bella Vista\"\"\"\n",
99 | "\n",
100 | " score: str = Field(\n",
101 | " description=\"Question is about restaurant? If yes -> 'Yes' if not -> 'No'\"\n",
102 | " )\n",
103 | "\n",
104 | "\n",
105 | "def question_classifier(state: AgentState):\n",
106 | " question = state[\"question\"]\n",
107 | "\n",
108 | " system = \"\"\"You are a grader assessing whether a user question is about the restaurant Bella Vista. \\n\n",
109 | " Only answer if the question is about one of the following topics:\n",
110 | " 1. Information about the owner of Bella Vista (Antonio Rossi).\n",
111 | " 2. Prices of dishes at Bella Vista.\n",
112 | " 3. Opening hours of Bella Vista.\n",
113 | " 4. Available menus at Bella Vista.\n",
114 | "\n",
115 | " If the question IS about these topics respond with \"Yes\", otherwise respond with \"No\".\n",
116 | " \"\"\"\n",
117 | "\n",
118 | " grade_prompt = ChatPromptTemplate.from_messages(\n",
119 | " [\n",
120 | " (\"system\", system),\n",
121 | " (\"human\", \"User question: {question}\"),\n",
122 | " ]\n",
123 | " )\n",
124 | "\n",
125 | " llm = ChatOpenAI()\n",
126 | " structured_llm = llm.with_structured_output(GradeQuestion)\n",
127 | " grader_llm = grade_prompt | structured_llm\n",
128 | " result = grader_llm.invoke({\"question\": question})\n",
129 | " state[\"on_topic\"] = result.score\n",
130 | " return state"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "def on_topic_router(state: AgentState):\n",
140 | " on_topic = state[\"on_topic\"]\n",
141 | " if on_topic.lower() == \"yes\":\n",
142 | " return \"on_topic\"\n",
143 | " return \"off_topic\""
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "def off_topic_response(state: AgentState):\n",
153 | " state[\"llm_output\"] = \"I can't respond to that!\"\n",
154 | " return state"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "class GradeDocuments(BaseModel):\n",
164 | " \"\"\"Boolean values to check for relevance on retrieved documents.\"\"\"\n",
165 | "\n",
166 | " score: str = Field(\n",
167 | " description=\"Documents are relevant to the question, 'Yes' or 'No'\"\n",
168 | " )\n",
169 | "\n",
170 | "\n",
171 | "def document_grader(state: AgentState):\n",
172 | " docs = state[\"documents\"]\n",
173 | " question = state[\"question\"]\n",
174 | "\n",
175 | " system = \"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n\n",
176 | " If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \\n\n",
177 | " Give a binary score 'Yes' or 'No' score to indicate whether the document is relevant to the question.\"\"\"\n",
178 | "\n",
179 | " grade_prompt = ChatPromptTemplate.from_messages(\n",
180 | " [\n",
181 | " (\"system\", system),\n",
182 | " (\n",
183 | " \"human\",\n",
184 | " \"Retrieved document: \\n\\n {document} \\n\\n User question: {question}\",\n",
185 | " ),\n",
186 | " ]\n",
187 | " )\n",
188 | "\n",
189 | " llm = ChatOpenAI()\n",
190 | " structured_llm = llm.with_structured_output(GradeDocuments)\n",
191 | " grader_llm = grade_prompt | structured_llm\n",
192 | " scores = []\n",
193 | " for doc in docs:\n",
194 | " result = grader_llm.invoke({\"document\": doc, \"question\": question})\n",
195 | " scores.append(result.score)\n",
196 | " state[\"grades\"] = scores\n",
197 | " return state"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "def gen_router(state: AgentState):\n",
207 | " grades = state[\"grades\"]\n",
208 | " print(\"DOCUMENT GRADES:\", grades)\n",
209 | "\n",
210 | " if any(grade.lower() == \"yes\" for grade in grades):\n",
211 | " filtered_grades = [grade for grade in grades if grade.lower() == \"yes\"]\n",
212 | " print(\"FILTERED DOCUMENT GRADES:\", filtered_grades)\n",
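" # Note: grades are filtered only for logging; state[\"documents\"] itself is not pruned here.\n",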
213 | " return \"generate\"\n",
214 | " else:\n",
215 | " return \"rewrite_query\""
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "from langchain_core.output_parsers import StrOutputParser\n",
225 | "\n",
226 | "\n",
227 | "def rewriter(state: AgentState):\n",
228 | " question = state[\"question\"]\n",
229 | " system = \"\"\"You are a question re-writer that converts an input question to a better version that is optimized \\n\n",
230 | " for retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\"\"\"\n",
231 | " re_write_prompt = ChatPromptTemplate.from_messages(\n",
232 | " [\n",
233 | " (\"system\", system),\n",
234 | " (\n",
235 | " \"human\",\n",
236 | " \"Here is the initial question: \\n\\n {question} \\n Formulate an improved question.\",\n",
237 | " ),\n",
238 | " ]\n",
239 | " )\n",
240 | " llm = ChatOpenAI()\n",
241 | " question_rewriter = re_write_prompt | llm | StrOutputParser()\n",
242 | " output = question_rewriter.invoke({\"question\": question})\n",
243 | " state[\"question\"] = output\n",
244 | " return state"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "from langchain_core.prompts import ChatPromptTemplate\n",
254 | "from langchain.schema.output_parser import StrOutputParser\n",
255 | "\n",
256 | "\n",
257 | "def generate_answer(state: AgentState):\n",
258 | " llm = ChatOpenAI()\n",
259 | " question = state[\"question\"]\n",
260 | " context = state[\"documents\"]\n",
261 | "\n",
262 | " template = \"\"\"Answer the question based only on the following context:\n",
263 | " {context}\n",
264 | "\n",
265 | " Question: {question}\n",
266 | " \"\"\"\n",
267 | "\n",
268 | " prompt = ChatPromptTemplate.from_template(\n",
269 | " template=template,\n",
270 | " )\n",
271 | " chain = prompt | llm | StrOutputParser()\n",
272 | " result = chain.invoke({\"question\": question, \"context\": context})\n",
273 | " state[\"llm_output\"] = result\n",
274 | " return state"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "from langgraph.graph import StateGraph, END\n",
284 | "\n",
285 | "workflow = StateGraph(AgentState)\n",
286 | "\n",
287 | "workflow.add_node(\"topic_decision\", question_classifier)\n",
288 | "workflow.add_node(\"off_topic_response\", off_topic_response)\n",
289 | "workflow.add_node(\"retrieve_docs\", retrieve_docs)\n",
290 | "workflow.add_node(\"rewrite_query\", rewriter)\n",
291 | "workflow.add_node(\"generate_answer\", generate_answer)\n",
292 | "workflow.add_node(\"document_grader\", document_grader)\n",
293 | "\n",
294 | "workflow.add_edge(\"off_topic_response\", END)\n",
295 | "workflow.add_edge(\"retrieve_docs\", \"document_grader\")\n",
296 | "workflow.add_conditional_edges(\n",
297 | " \"topic_decision\",\n",
298 | " on_topic_router,\n",
299 | " {\n",
300 | " \"on_topic\": \"retrieve_docs\",\n",
301 | " \"off_topic\": \"off_topic_response\",\n",
302 | " },\n",
303 | ")\n",
304 | "workflow.add_conditional_edges(\n",
305 | " \"document_grader\",\n",
306 | " gen_router,\n",
307 | " {\n",
308 | " \"generate\": \"generate_answer\",\n",
309 | " \"rewrite_query\": \"rewrite_query\",\n",
310 | " },\n",
311 | ")\n",
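"# rewrite_query loops back to retrieval, so the graph retries until a document is graded relevant.\n",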
312 | "workflow.add_edge(\"rewrite_query\", \"retrieve_docs\")\n",
313 | "workflow.add_edge(\"generate_answer\", END)\n",
314 | "\n",
315 | "\n",
316 | "workflow.set_entry_point(\"topic_decision\")\n",
317 | "\n",
318 | "app = workflow.compile()"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": null,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "from IPython.display import Image, display\n",
328 | "\n",
329 | "try:\n",
330 | " display(Image(app.get_graph(xray=True).draw_mermaid_png()))\n",
331 | "except Exception:\n",
332 | " pass"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": null,
338 | "metadata": {},
339 | "outputs": [],
340 | "source": [
341 | "result = app.invoke({\"question\": \"How is the weather?\"})\n",
342 | "result[\"llm_output\"]"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "metadata": {},
349 | "outputs": [],
350 | "source": [
351 | "result = app.invoke({\"question\": \"Who is the owner of Bella Vista?\"})\n",
352 | "result[\"llm_output\"]"
353 | ]
354 | }
355 | ],
356 | "metadata": {
357 | "kernelspec": {
358 | "display_name": ".venv",
359 | "language": "python",
360 | "name": "python3"
361 | },
362 | "language_info": {
363 | "codemirror_mode": {
364 | "name": "ipython",
365 | "version": 3
366 | },
367 | "file_extension": ".py",
368 | "mimetype": "text/x-python",
369 | "name": "python",
370 | "nbconvert_exporter": "python",
371 | "pygments_lexer": "ipython3",
372 | "version": "3.11.0"
373 | }
374 | },
375 | "nbformat": 4,
376 | "nbformat_minor": 2
377 | }
378 |
--------------------------------------------------------------------------------
/human_in_loop.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from dotenv import load_dotenv\n",
10 | "\n",
11 | "load_dotenv()"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from typing import Annotated\n",
21 | "\n",
22 | "from typing_extensions import TypedDict\n",
23 | "\n",
24 | "from langgraph.graph.message import add_messages\n",
25 | "\n",
26 | "\n",
27 | "class State(TypedDict):\n",
28 | " messages: Annotated[list, add_messages]"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "from langchain_community.tools.tavily_search import TavilySearchResults\n",
38 | "from langchain_openai import ChatOpenAI\n",
39 | "from typing_extensions import TypedDict\n",
40 | "\n",
41 | "from langgraph.checkpoint.sqlite import SqliteSaver\n",
42 | "from langgraph.graph import StateGraph\n",
43 | "from langgraph.graph.message import add_messages\n",
44 | "from langgraph.prebuilt import ToolNode, tools_condition\n",
45 | "\n",
46 | "memory = SqliteSaver.from_conn_string(\":memory:\")\n",
47 | "\n",
48 | "\n",
49 | "class State(TypedDict):\n",
50 | " messages: Annotated[list, add_messages]\n",
51 | "\n",
52 | "\n",
53 | "graph_builder = StateGraph(State)\n",
54 | "\n",
55 | "tool = TavilySearchResults(max_results=2)\n",
56 | "tools = [tool]\n",
57 | "llm = ChatOpenAI()\n",
58 | "llm_with_tools = llm.bind_tools(tools)\n",
59 | "\n",
60 | "\n",
61 | "def chatbot(state: State):\n",
62 | " return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "graph_builder.add_node(\"chatbot\", chatbot)\n",
72 | "\n",
73 | "tool_node = ToolNode(tools=[tool])\n",
74 | "graph_builder.add_node(\"tools\", tool_node)\n",
75 | "\n",
76 | "graph_builder.add_conditional_edges(\n",
77 | " \"chatbot\",\n",
78 | " tools_condition,\n",
79 | ")\n",
80 | "graph_builder.add_edge(\"tools\", \"chatbot\")\n",
81 | "graph_builder.set_entry_point(\"chatbot\")"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "graph = graph_builder.compile(\n",
91 | " checkpointer=memory,\n",
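" # Pause before the tools node so a human can review the pending tool call.\n",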
92 | " interrupt_before=[\"tools\"],\n",
93 | ")"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "from IPython.display import Image, display\n",
103 | "\n",
104 | "try:\n",
105 | " display(Image(graph.get_graph().draw_mermaid_png()))\n",
106 | "except Exception:\n",
107 | " pass"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "from langchain_core.messages import HumanMessage\n",
117 | "\n",
118 | "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
119 | "input_message = HumanMessage(content=\"Hello, I am John\")\n",
120 | "\n",
121 | "graph.invoke({\"messages\": input_message}, config=config)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "config = {\"configurable\": {\"thread_id\": \"100\"}}\n",
131 | "input_message = HumanMessage(content=\"Sorry, did I already introduce myself?\")\n",
132 | "\n",
133 | "graph.invoke({\"messages\": input_message}, config=config)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
143 | "input_message = HumanMessage(content=\"Sorry, did I already introduce myself?\")\n",
144 | "\n",
145 | "graph.invoke({\"messages\": input_message}, config=config)"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "from langchain_core.messages import HumanMessage\n",
155 | "\n",
156 | "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
157 | "input_message = HumanMessage(content=\"How is the weather in Los Angeles?\")\n",
158 | "\n",
159 | "graph.invoke({\"messages\": input_message}, config=config)"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "snapshot = graph.get_state(config)\n",
169 | "snapshot.next"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "graph.invoke(None, config=config)"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "from langchain_core.messages import HumanMessage\n",
188 | "\n",
189 | "config = {\"configurable\": {\"thread_id\": \"2\"}}\n",
190 | "input_message = HumanMessage(content=\"How is the weather in Los Angeles?\")\n",
191 | "\n",
192 | "graph.invoke({\"messages\": input_message}, config=config)"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "snapshot = graph.get_state(config)\n",
202 | "existing_message = snapshot.values[\"messages\"][-1]\n",
203 | "existing_message.pretty_print()"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "from langchain_core.messages import AIMessage, ToolMessage\n",
213 | "\n",
214 | "answer = \"It is only 5°C warm today!\"\n",
215 | "new_messages = [\n",
216 | " ToolMessage(content=answer, tool_call_id=existing_message.tool_calls[0][\"id\"]),\n",
217 | " AIMessage(content=answer),\n",
218 | "]"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "new_messages[-1].pretty_print()\n",
228 | "graph.update_state(\n",
229 | " config,\n",
230 | " {\"messages\": new_messages},\n",
231 | ")\n",
232 | "\n",
233 | "print(graph.get_state(config).values[\"messages\"][-2:])"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": null,
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "config = {\"configurable\": {\"thread_id\": \"2\"}}\n",
243 | "input_message = HumanMessage(content=\"How warm was it again?\")\n",
244 | "\n",
245 | "graph.invoke({\"messages\": input_message}, config=config)"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {},
251 | "source": [
252 | "### Custom State "
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "from typing import Annotated\n",
262 | "\n",
263 | "from langchain_openai import ChatOpenAI\n",
264 | "from langchain_community.tools.tavily_search import TavilySearchResults\n",
265 | "from typing_extensions import TypedDict\n",
266 | "\n",
267 | "from langgraph.checkpoint.sqlite import SqliteSaver\n",
268 | "from langgraph.graph import StateGraph\n",
269 | "from langgraph.graph.message import add_messages\n",
270 | "from langgraph.prebuilt import ToolNode, tools_condition\n",
271 | "\n",
272 | "\n",
273 | "class State(TypedDict):\n",
274 | " messages: Annotated[list, add_messages]\n",
275 | " ask_human: bool"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "metadata": {},
282 | "outputs": [],
283 | "source": [
284 | "from langchain_core.tools import tool\n",
285 | "\n",
286 | "\n",
287 | "@tool\n",
288 | "def request_assistance():\n",
289 | " \"\"\"Escalate the conversation to an expert. Use this if you are unable to assist directly or if the user requires support beyond your permissions.\n",
290 | "\n",
291 | " To use this function, relay the user's 'request' so the expert can provide the right guidance.\n",
292 | " \"\"\"\n",
293 | " return \"\""
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "metadata": {},
300 | "outputs": [],
301 | "source": [
302 | "tool = TavilySearchResults(max_results=2)\n",
303 | "tools = [tool]\n",
304 | "llm = ChatOpenAI()\n",
305 | "llm_with_tools = llm.bind_tools(tools + [request_assistance])\n",
306 | "\n",
307 | "\n",
308 | "def chatbot(state: State):\n",
309 | " response = llm_with_tools.invoke(state[\"messages\"])\n",
310 | " ask_human = False\n",
311 | " if response.tool_calls and response.tool_calls[0][\"name\"] == \"request_assistance\":\n",
312 | " ask_human = True\n",
313 | " return {\"messages\": [response], \"ask_human\": ask_human}"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "metadata": {},
320 | "outputs": [],
321 | "source": [
322 | "graph_builder = StateGraph(State)\n",
323 | "\n",
324 | "graph_builder.add_node(\"chatbot\", chatbot)\n",
325 | "graph_builder.add_node(\"tools\", ToolNode(tools=[tool]))"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": null,
331 | "metadata": {},
332 | "outputs": [],
333 | "source": [
334 | "from langchain_core.messages import AIMessage, ToolMessage\n",
335 | "\n",
336 | "\n",
337 | "def create_response(response: str, ai_message: AIMessage):\n",
338 | " return ToolMessage(\n",
339 | " content=response,\n",
340 | " tool_call_id=ai_message.tool_calls[0][\"id\"],\n",
341 | " )\n",
342 | "\n",
343 | "\n",
344 | "def human_node(state: State):\n",
345 | " new_messages = []\n",
346 | " if not isinstance(state[\"messages\"][-1], ToolMessage):\n",
347 | " new_messages.append(\n",
348 | " create_response(\"No response from human.\", state[\"messages\"][-1])\n",
349 | " )\n",
350 | " return {\n",
351 | " \"messages\": new_messages,\n",
352 | " \"ask_human\": False,\n",
353 | " }\n",
354 | "\n",
355 | "\n",
356 | "graph_builder.add_node(\"human\", human_node)"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": null,
362 | "metadata": {},
363 | "outputs": [],
364 | "source": [
365 | "def select_next_node(state: State):\n",
366 | " if state[\"ask_human\"]:\n",
367 | " return \"human\"\n",
368 | " return tools_condition(state)\n",
369 | "\n",
370 | "\n",
371 | "graph_builder.add_conditional_edges(\n",
372 | " \"chatbot\",\n",
373 | " select_next_node,\n",
374 | " {\"human\": \"human\", \"tools\": \"tools\", \"__end__\": \"__end__\"},\n",
375 | ")"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": null,
381 | "metadata": {},
382 | "outputs": [],
383 | "source": [
384 | "graph_builder.add_edge(\"tools\", \"chatbot\")\n",
385 | "graph_builder.add_edge(\"human\", \"chatbot\")\n",
386 | "graph_builder.set_entry_point(\"chatbot\")\n",
387 | "memory = SqliteSaver.from_conn_string(\":memory:\")\n",
388 | "graph = graph_builder.compile(\n",
389 | " checkpointer=memory,\n",
390 | " interrupt_before=[\"human\"],\n",
391 | ")"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "from IPython.display import Image, display\n",
401 | "\n",
402 | "try:\n",
403 | " display(Image(graph.get_graph().draw_mermaid_png()))\n",
404 | "except Exception:\n",
405 | " pass"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": null,
411 | "metadata": {},
412 | "outputs": [],
413 | "source": [
414 | "config = {\"configurable\": {\"thread_id\": \"50\"}}\n",
415 | "input_message = HumanMessage(\n",
416 | " content=\"I need some expert advice on how to plan a trip to barcelona\"\n",
417 | ")\n",
418 | "\n",
419 | "graph.invoke({\"messages\": input_message}, config=config)"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": null,
425 | "metadata": {},
426 | "outputs": [],
427 | "source": [
428 | "snapshot = graph.get_state(config)\n",
429 | "snapshot.next"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": null,
435 | "metadata": {},
436 | "outputs": [],
437 | "source": [
438 | "ai_message = snapshot.values[\"messages\"][-1]\n",
439 | "human_response = \"best hotel: hotelxyz; best flight, flightxyz\"\n",
440 | "tool_message = create_response(human_response, ai_message)\n",
441 | "graph.update_state(config, {\"messages\": [tool_message]})"
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": null,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "graph.invoke(None, config=config)"
451 | ]
452 | }
453 | ],
454 | "metadata": {
455 | "kernelspec": {
456 | "display_name": ".venv",
457 | "language": "python",
458 | "name": "python3"
459 | },
460 | "language_info": {
461 | "codemirror_mode": {
462 | "name": "ipython",
463 | "version": 3
464 | },
465 | "file_extension": ".py",
466 | "mimetype": "text/x-python",
467 | "name": "python",
468 | "nbconvert_exporter": "python",
469 | "pygments_lexer": "ipython3",
470 | "version": "3.11.0"
471 | }
472 | },
473 | "nbformat": 4,
474 | "nbformat_minor": 2
475 | }
476 |
--------------------------------------------------------------------------------
/crag_dynamic_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from dotenv import load_dotenv\n",
10 | "import os\n",
11 | "from langchain_ollama import ChatOllama\n",
12 | "from langchain_ollama import OllamaEmbeddings\n",
13 | "from langchain_openai import ChatOpenAI\n",
14 | "from langchain_openai import OpenAIEmbeddings\n",
15 | "\n",
16 | "load_dotenv()"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "def get_llm():\n",
26 | " llm_type = os.getenv(\"LLM_TYPE\", \"ollama\")\n",
27 | " if llm_type == \"ollama\":\n",
28 | " return ChatOllama(model=\"llama3.1\", temperature=0)\n",
29 | " else:\n",
30 | " return ChatOpenAI(temperature=0, model=\"gpt-4o-mini\")\n",
31 | "\n",
32 | "def get_embeddings():\n",
33 | " embedding_type = os.getenv(\"LLM_TYPE\", \"ollama\")\n",
34 | " if embedding_type == \"ollama\":\n",
35 | " return OllamaEmbeddings(model=\"llama3.1\")\n",
36 | " else:\n",
37 | " return OpenAIEmbeddings()"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "#### Creating VectorDatabase"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "from langchain.schema import Document\n",
54 | "from langchain_community.vectorstores import Chroma\n",
55 | "\n",
56 | "embedding_function = get_embeddings()\n",
57 | "\n",
58 | "docs = [\n",
59 | " Document(\n",
60 | " page_content=\"Bella Vista is owned by Antonio Rossi, a renowned chef with over 20 years of experience in the culinary industry. He started Bella Vista to bring authentic Italian flavors to the community.\",\n",
61 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
62 | " ),\n",
63 | " Document(\n",
64 | " page_content=\"Bella Vista offers a range of dishes with prices that cater to various budgets. Appetizers start at $8, main courses range from $15 to $35, and desserts are priced between $6 and $12.\",\n",
65 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
66 | " ),\n",
67 | " Document(\n",
68 | " page_content=\"Bella Vista is open from Monday to Sunday. Weekday hours are 11:00 AM to 10:00 PM, while weekend hours are extended from 11:00 AM to 11:00 PM.\",\n",
69 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
70 | " ),\n",
71 | " Document(\n",
72 | " page_content=\"Bella Vista offers a variety of menus including a lunch menu, dinner menu, and a special weekend brunch menu. The lunch menu features light Italian fare, the dinner menu offers a more extensive selection of traditional and contemporary dishes, and the brunch menu includes both classic breakfast items and Italian specialties.\",\n",
73 | " metadata={\"source\": \"restaurant_info.txt\"},\n",
74 | " ),\n",
75 | "]\n",
76 | "\n",
77 | "db = Chroma.from_documents(docs, embedding_function)\n",
78 | "retriever = db.as_retriever()"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "from typing_extensions import TypedDict\n",
88 | "\n",
89 | "\n",
90 | "class AgentState(TypedDict):\n",
91 | " question: str\n",
92 | " grades: list[str]\n",
93 | " llm_output: str\n",
94 | " documents: list[str]\n",
95 | " on_topic: bool"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "def retrieve_docs(state: AgentState):\n",
105 | " question = state[\"question\"]\n",
106 | " documents = retriever.get_relevant_documents(query=question)\n",
107 | " print(\"RETRIEVED DOCUMENTS:\", documents)\n",
108 | " state[\"documents\"] = [doc.page_content for doc in documents]\n",
109 | " return state"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "from langchain_core.pydantic_v1 import BaseModel, Field\n",
119 | "from langchain_core.prompts import ChatPromptTemplate\n",
120 | "\n",
121 | "\n",
122 | "class GradeQuestion(BaseModel):\n",
123 | " \"\"\"Boolean value to check whether a question is releated to the restaurant Bella Vista\"\"\"\n",
124 | "\n",
125 | " score: str = Field(\n",
126 | " description=\"Question is about restaurant? If yes -> 'Yes' if not -> 'No'\"\n",
127 | " )\n",
128 | "\n",
129 | "\n",
130 | "def question_classifier(state: AgentState):\n",
131 | " question = state[\"question\"]\n",
132 | "\n",
133 | " system = \"\"\"You are a grader assessing the topic a user question. \\n\n",
134 | " Only answer if the question is about one of the following topics:\n",
135 | " 1. Information about the owner of Bella Vista (Antonio Rossi).\n",
136 | " 2. Prices of dishes at Bella Vista.\n",
137 | " 3. Opening hours of Bella Vista.\n",
138 | " 4. Available menus at Bella Vista.\n",
139 | "\n",
140 | " Examples: How will the weather be today -> No\n",
141 | " Who owns the restaurant? -> Yes\n",
142 | " What food do you offer? -> Yes\n",
143 | "\n",
144 | " If the question IS about these topics response with \"Yes\", otherwise respond with \"No\".\n",
145 | " \"\"\"\n",
146 | "\n",
147 | " grade_prompt = ChatPromptTemplate.from_messages(\n",
148 | " [\n",
149 | " (\"system\", system),\n",
150 | " (\"human\", \"User question: {question}\"),\n",
151 | " ]\n",
152 | " )\n",
153 | "\n",
154 | " llm = get_llm()\n",
155 | " structured_llm = llm.with_structured_output(GradeQuestion)\n",
156 | " grader_llm = grade_prompt | structured_llm\n",
157 | " result = grader_llm.invoke({\"question\": question})\n",
158 | " print(f\"QUESTION and GRADE: {question} - {result.score}\")\n",
159 | " state[\"on_topic\"] = result.score\n",
160 | " return state"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "def on_topic_router(state: AgentState):\n",
170 | " on_topic = state[\"on_topic\"]\n",
171 | " if on_topic.lower() == \"yes\":\n",
172 | " return \"on_topic\"\n",
173 | " return \"off_topic\""
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "def off_topic_response(state: AgentState):\n",
183 | " state[\"llm_output\"] = \"I cant respond to that!\"\n",
184 | " return state"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "class GradeDocuments(BaseModel):\n",
194 | " \"\"\"Boolean values to check for relevance on retrieved documents.\"\"\"\n",
195 | "\n",
196 | " score: str = Field(\n",
197 | " description=\"Documents are relevant to the question, 'Yes' or 'No'\"\n",
198 | " )\n",
199 | "\n",
200 | "\n",
201 | "def document_grader(state: AgentState):\n",
202 | " docs = state[\"documents\"]\n",
203 | " question = state[\"question\"]\n",
204 | "\n",
205 | " system = \"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n\n",
206 | " If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \\n\n",
207 | " Give a binary score 'Yes' or 'No' score to indicate whether the document is relevant to the question.\"\"\"\n",
208 | "\n",
209 | " grade_prompt = ChatPromptTemplate.from_messages(\n",
210 | " [\n",
211 | " (\"system\", system),\n",
212 | " (\n",
213 | " \"human\",\n",
214 | " \"Retrieved document: \\n\\n {document} \\n\\n User question: {question}\",\n",
215 | " ),\n",
216 | " ]\n",
217 | " )\n",
218 | "\n",
219 | " llm = get_llm()\n",
220 | " structured_llm = llm.with_structured_output(GradeDocuments)\n",
221 | " grader_llm = grade_prompt | structured_llm\n",
222 | " scores = []\n",
223 | " for doc in docs:\n",
224 | " result = grader_llm.invoke({\"document\": doc, \"question\": question})\n",
225 | " scores.append(result.score)\n",
226 | " state[\"grades\"] = scores\n",
227 | " return state"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "def gen_router(state: AgentState):\n",
237 | " grades = state[\"grades\"]\n",
238 | " print(\"DOCUMENT GRADES:\", grades)\n",
239 | "\n",
240 | " if any(grade.lower() == \"yes\" for grade in grades):\n",
241 | " filtered_grades = [grade for grade in grades if grade.lower() == \"yes\"]\n",
242 | " print(\"FILTERED DOCUMENT GRADES:\", filtered_grades)\n",
243 | " return \"generate\"\n",
244 | " else:\n",
245 | " return \"rewrite_query\""
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "from langchain_core.output_parsers import StrOutputParser\n",
255 | "\n",
256 | "\n",
257 | "def rewriter(state: AgentState):\n",
258 | " question = state[\"question\"]\n",
259 | " system = \"\"\"You a question re-writer that converts an input question to a better version that is optimized \\n\n",
260 | " for retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\"\"\"\n",
261 | " re_write_prompt = ChatPromptTemplate.from_messages(\n",
262 | " [\n",
263 | " (\"system\", system),\n",
264 | " (\n",
265 | " \"human\",\n",
266 | " \"Here is the initial question: \\n\\n {question} \\n Formulate an improved question.\",\n",
267 | " ),\n",
268 | " ]\n",
269 | " )\n",
270 | " llm = get_llm()\n",
271 | " question_rewriter = re_write_prompt | llm | StrOutputParser()\n",
272 | " output = question_rewriter.invoke({\"question\": question})\n",
273 | " state[\"question\"] = output\n",
274 | " return state"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "from langchain_core.prompts import ChatPromptTemplate\n",
284 | "from langchain.schema.output_parser import StrOutputParser\n",
285 | "\n",
286 | "\n",
287 | "def generate_answer(state: AgentState):\n",
288 | " llm = get_llm()\n",
289 | " question = state[\"question\"]\n",
290 | " context = state[\"documents\"]\n",
291 | "\n",
292 | " template = \"\"\"Answer the question based only on the following context:\n",
293 | " {context}\n",
294 | "\n",
295 | " Question: {question}\n",
296 | " \"\"\"\n",
297 | "\n",
298 | " prompt = ChatPromptTemplate.from_template(\n",
299 | " template=template,\n",
300 | " )\n",
301 | " chain = prompt | llm | StrOutputParser()\n",
302 | " result = chain.invoke({\"question\": question, \"context\": context})\n",
303 | " state[\"llm_output\"] = result\n",
304 | " return state"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": null,
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "from langgraph.graph import StateGraph, END\n",
314 | "\n",
315 | "workflow = StateGraph(AgentState)\n",
316 | "\n",
317 | "workflow.add_node(\"topic_decision\", question_classifier)\n",
318 | "workflow.add_node(\"off_topic_response\", off_topic_response)\n",
319 | "workflow.add_node(\"retrieve_docs\", retrieve_docs)\n",
320 | "workflow.add_node(\"rewrite_query\", rewriter)\n",
321 | "workflow.add_node(\"generate_answer\", generate_answer)\n",
322 | "workflow.add_node(\"document_grader\", document_grader)\n",
323 | "\n",
324 | "workflow.add_edge(\"off_topic_response\", END)\n",
325 | "workflow.add_edge(\"retrieve_docs\", \"document_grader\")\n",
326 | "workflow.add_conditional_edges(\n",
327 | " \"topic_decision\",\n",
328 | " on_topic_router,\n",
329 | " {\n",
330 | " \"on_topic\": \"retrieve_docs\",\n",
331 | " \"off_topic\": \"off_topic_response\",\n",
332 | " },\n",
333 | ")\n",
334 | "workflow.add_conditional_edges(\n",
335 | " \"document_grader\",\n",
336 | " gen_router,\n",
337 | " {\n",
338 | " \"generate\": \"generate_answer\",\n",
339 | " \"rewrite_query\": \"rewrite_query\",\n",
340 | " },\n",
341 | ")\n",
342 | "workflow.add_edge(\"rewrite_query\", \"retrieve_docs\")\n",
343 | "workflow.add_edge(\"generate_answer\", END)\n",
344 | "\n",
345 | "\n",
346 | "workflow.set_entry_point(\"topic_decision\")\n",
347 | "\n",
348 | "app = workflow.compile()"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": null,
354 | "metadata": {},
355 | "outputs": [],
356 | "source": [
357 | "from IPython.display import Image, display\n",
358 | "\n",
359 | "try:\n",
360 | " display(Image(app.get_graph(xray=True).draw_mermaid_png()))\n",
361 | "except:\n",
362 | " pass"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": null,
368 | "metadata": {},
369 | "outputs": [],
370 | "source": [
371 | "result = app.invoke({\"question\": \"How is the weather?\"})\n",
372 | "result[\"llm_output\"]"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "metadata": {},
379 | "outputs": [],
380 | "source": [
381 | "result = app.invoke({\"question\": \"Who is the owner of bella vista?\"})\n",
382 | "result[\"llm_output\"]"
383 | ]
384 | }
385 | ],
386 | "metadata": {
387 | "kernelspec": {
388 | "display_name": ".venv",
389 | "language": "python",
390 | "name": "python3"
391 | },
392 | "language_info": {
393 | "codemirror_mode": {
394 | "name": "ipython",
395 | "version": 3
396 | },
397 | "file_extension": ".py",
398 | "mimetype": "text/x-python",
399 | "name": "python",
400 | "nbconvert_exporter": "python",
401 | "pygments_lexer": "ipython3",
402 | "version": "3.11.0"
403 | }
404 | },
405 | "nbformat": 4,
406 | "nbformat_minor": 2
407 | }
408 |
--------------------------------------------------------------------------------
/agent_supervisor.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | "True"
12 | ]
13 | },
14 | "execution_count": 1,
15 | "metadata": {},
16 | "output_type": "execute_result"
17 | }
18 | ],
19 | "source": [
20 | "from dotenv import load_dotenv\n",
21 | "\n",
22 | "load_dotenv()"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from langchain_openai import ChatOpenAI\n",
32 | "\n",
33 | "from langgraph_supervisor import create_supervisor\n",
34 | "from langgraph.prebuilt import create_react_agent\n",
35 | "\n",
36 | "model = ChatOpenAI(model=\"gpt-4o\")\n",
37 | "\n",
38 | "# Create specialized agents\n",
39 | "\n",
40 | "def add(a: float, b: float) -> float:\n",
41 | " \"\"\"Add two numbers.\"\"\"\n",
42 | " return a + b\n",
43 | "\n",
44 | "def multiply(a: float, b: float) -> float:\n",
45 | " \"\"\"Multiply two numbers.\"\"\"\n",
46 | " return a * b\n",
47 | "\n",
48 | "def web_search(query: str) -> str:\n",
49 | " \"\"\"Search the web for information.\"\"\"\n",
50 | " return (\n",
51 | " \"Here are the headcounts for each of the FAANG companies in 2024:\\n\"\n",
52 | " \"1. **Facebook (Meta)**: 67,317 employees.\\n\"\n",
53 | " \"2. **Apple**: 164,000 employees.\\n\"\n",
54 | " \"3. **Amazon**: 1,551,000 employees.\\n\"\n",
55 | " \"4. **Netflix**: 14,000 employees.\\n\"\n",
56 | " \"5. **Google (Alphabet)**: 181,269 employees.\"\n",
57 | " )\n",
58 | "\n",
59 | "math_agent = create_react_agent(\n",
60 | " model=model,\n",
61 | " tools=[add, multiply],\n",
62 | " name=\"math_expert\",\n",
63 | " prompt=\"You are a math expert. Always use one tool at a time.\"\n",
64 | ")\n",
65 | "\n",
66 | "research_agent = create_react_agent(\n",
67 | " model=model,\n",
68 | " tools=[web_search],\n",
69 | " name=\"research_expert\",\n",
70 | " prompt=\"You are a world class researcher with access to web search. Do not do any math.\"\n",
71 | ")\n",
72 | "\n"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 4,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "workflow = create_supervisor(\n",
82 | " [research_agent, math_agent],\n",
83 | " model=model,\n",
84 | " prompt=(\n",
85 | " \"You are a team supervisor managing a research expert and a math expert. \"\n",
86 | " \"For current events, use research_agent. \"\n",
87 | " \"For math problems, use math_agent.\"\n",
88 | " )\n",
89 | ")"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 5,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "app = workflow.compile()\n",
99 | "result = app.invoke({\n",
100 | " \"messages\": [\n",
101 | " {\n",
102 | " \"role\": \"user\",\n",
103 | " \"content\": \"what's the combined headcount of the FAANG companies in 2024?\"\n",
104 | " }\n",
105 | " ]\n",
106 | "})"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 6,
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "data": {
116 | "text/plain": [
117 | "{'messages': [HumanMessage(content=\"what's the combined headcount of the FAANG companies in 2024?\", additional_kwargs={}, response_metadata={}, id='647fb354-aebb-42f9-bb55-434fdf284172'),\n",
118 | " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_LlzrTtOdGsCa2WRMA5n89Pii', 'function': {'arguments': '{}', 'name': 'transfer_to_research_expert'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 104, 'total_tokens': 119, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_ff092ab25e', 'finish_reason': 'tool_calls', 'logprobs': None}, name='supervisor', id='run-c7844f24-b80f-4372-8ebf-3b6e7f7bb66a-0', tool_calls=[{'name': 'transfer_to_research_expert', 'args': {}, 'id': 'call_LlzrTtOdGsCa2WRMA5n89Pii', 'type': 'tool_call'}], usage_metadata={'input_tokens': 104, 'output_tokens': 15, 'total_tokens': 119, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
119 | " ToolMessage(content='Successfully transferred to research_expert', name='transfer_to_research_expert', id='b117eefb-4cac-40f4-9d85-ccf3b1942bb8', tool_call_id='call_LlzrTtOdGsCa2WRMA5n89Pii'),\n",
120 | " AIMessage(content='The combined headcount of the FAANG companies in 2024 is as follows:\\n\\n1. **Facebook (Meta)**: 67,317 employees\\n2. **Apple**: 164,000 employees\\n3. **Amazon**: 1,551,000 employees\\n4. **Netflix**: 14,000 employees\\n5. **Google (Alphabet)**: 181,269 employees\\n\\nIt seems I accidentally repeated the results several times from our tools. Apologies for any confusion!', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 104, 'prompt_tokens': 683, 'total_tokens': 787, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_523b9b6e5f', 'finish_reason': 'stop', 'logprobs': None}, name='research_expert', id='run-8ae2f48a-fde5-4da0-a6e3-b70b397a69f2-0', usage_metadata={'input_tokens': 683, 'output_tokens': 104, 'total_tokens': 787, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
121 | " AIMessage(content='Transferring back to supervisor', additional_kwargs={}, response_metadata={}, name='research_expert', id='46711867-ff47-4d01-bc2d-88cf31005839', tool_calls=[{'name': 'transfer_back_to_supervisor', 'args': {}, 'id': '5e1f84d9-8069-449e-a3a6-1af457e2155f', 'type': 'tool_call'}]),\n",
122 | " ToolMessage(content='Successfully transferred back to supervisor', name='transfer_back_to_supervisor', id='eb7225b7-2533-4d93-b2d3-d387009ddf86', tool_call_id='5e1f84d9-8069-449e-a3a6-1af457e2155f'),\n",
123 | " AIMessage(content='The combined headcount of the FAANG companies in 2024 is as follows:\\n\\n1. Facebook (Meta): 67,317 employees\\n2. Apple: 164,000 employees\\n3. Amazon: 1,551,000 employees\\n4. Netflix: 14,000 employees\\n5. Google (Alphabet): 181,269 employees\\n\\nAdding these together gives a total combined headcount of 1,977,586 employees.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 93, 'prompt_tokens': 296, 'total_tokens': 389, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_ff092ab25e', 'finish_reason': 'stop', 'logprobs': None}, name='supervisor', id='run-a192e46c-d141-42b7-97cf-04b7bff81d11-0', usage_metadata={'input_tokens': 296, 'output_tokens': 93, 'total_tokens': 389, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
124 | ]
125 | },
126 | "execution_count": 6,
127 | "metadata": {},
128 | "output_type": "execute_result"
129 | }
130 | ],
131 | "source": [
132 | "result"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "### Supervisor Agent with low level agent"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 7,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "from typing import Literal\n",
149 | "\n",
150 | "\n",
151 | "from langchain_openai import ChatOpenAI\n",
152 | "from langchain_core.tools import tool\n",
153 | "from langchain_core.messages import HumanMessage\n",
154 | "from langgraph.graph import END, START, StateGraph, MessagesState\n",
155 | "from langgraph.graph.state import CompiledStateGraph\n",
156 | "from langgraph.prebuilt import ToolNode\n",
157 | "\n",
158 | "# --- Weather tool ---\n",
159 | "@tool\n",
160 | "def get_weather(location: str):\n",
161 | " \"\"\"Call to get the current weather.\"\"\"\n",
162 | " if location.lower() in [\"munich\"]:\n",
163 | " return \"It's 15 degrees Celsius and cloudy.\"\n",
164 | " else:\n",
165 | " return \"It's 32 degrees Celsius and sunny.\"\n",
166 | "\n",
167 | "# We'll create a model and bind the tool so the LLM knows it can call `get_weather`.\n",
168 | "tools = [get_weather]\n",
169 | "model = ChatOpenAI(model=\"gpt-4o-mini\").bind_tools(tools)\n",
170 | "\n",
171 | "# --- Existing agent workflow definition ---\n",
172 | "def call_model(state: MessagesState):\n",
173 | " \"\"\"Call the LLM with the conversation so far.\"\"\"\n",
174 | " messages = state[\"messages\"]\n",
175 | " response = model.invoke(messages)\n",
176 | " return {\"messages\": [response]}\n",
177 | "\n",
178 | "def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n",
179 | " \"\"\"If there's a tool call requested, go to 'tools', else end.\"\"\"\n",
180 | " messages = state[\"messages\"]\n",
181 | " last_message = messages[-1]\n",
182 | " if last_message.tool_calls:\n",
183 | " return \"tools\"\n",
184 | " return END\n",
185 | "\n",
186 | "weather_workflow = StateGraph(MessagesState)\n",
187 | "\n",
188 | "tool_node = ToolNode(tools)\n",
189 | "\n",
190 | "weather_workflow.add_node(\"agent\", call_model)\n",
191 | "weather_workflow.add_node(\"tools\", tool_node)\n",
192 | "\n",
193 | "weather_workflow.add_edge(START, \"agent\")\n",
194 | "weather_workflow.add_conditional_edges(\"agent\", should_continue)\n",
195 | "weather_workflow.add_edge(\"tools\", \"agent\")\n",
196 | "\n",
197 | "weather_agent_graph = weather_workflow.compile(name=\"weather_agent\")"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 11,
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "from langgraph_supervisor import create_supervisor\n",
207 | "\n",
208 | "supervisor_workflow = create_supervisor(\n",
209 | " agents=[weather_agent_graph],\n",
210 | " model=model,\n",
211 | " prompt=(\n",
212 | " \"You are a supervisor managing a weather agent. \"\n",
213 | " \"For any weather-related question, call the 'weather_agent' to handle it.\"\n",
214 | " ),\n",
215 | " output_mode=\"last_message\",\n",
216 | " #output_mode=\"full_history\",\n",
217 | " supervisor_name=\"supervisor_agent\",\n",
218 | ")\n",
219 | "\n",
220 | "supervisor_app = supervisor_workflow.compile()"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 12,
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "data": {
230 | "text/plain": [
231 | "{'messages': [HumanMessage(content='Hello there, how are you?', additional_kwargs={}, response_metadata={}, id='2d7b2062-da5e-44bf-9038-cfda9aba1ba8'),\n",
232 | " AIMessage(content=\"I'm just a program, so I don't have feelings, but I'm here and ready to help you! How can I assist you today?\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 29, 'prompt_tokens': 71, 'total_tokens': 100, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_00428b782a', 'finish_reason': 'stop', 'logprobs': None}, name='supervisor_agent', id='run-f844fa58-516e-4cd8-a2be-5e80aedb5fde-0', usage_metadata={'input_tokens': 71, 'output_tokens': 29, 'total_tokens': 100, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
233 | ]
234 | },
235 | "execution_count": 12,
236 | "metadata": {},
237 | "output_type": "execute_result"
238 | }
239 | ],
240 | "source": [
241 | "supervisor_app.invoke(\n",
242 | " {\"messages\": [HumanMessage(content=\"Hello there, how are you?\")]}\n",
243 | ")"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 13,
249 | "metadata": {},
250 | "outputs": [
251 | {
252 | "data": {
253 | "text/plain": [
254 | "{'messages': [HumanMessage(content='How is the weather in Munich?', additional_kwargs={}, response_metadata={}, id='7b6e7f69-19f3-4af5-9327-66c0478cfe68'),\n",
255 | " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_pU8pdeSkHkerBjNZ9KrzGTe6', 'function': {'arguments': '{}', 'name': 'transfer_to_weather_agent'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 13, 'prompt_tokens': 71, 'total_tokens': 84, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_13eed4fce1', 'finish_reason': 'tool_calls', 'logprobs': None}, name='supervisor_agent', id='run-f01e26aa-d1ad-420a-ae5f-8fca1376adaa-0', tool_calls=[{'name': 'transfer_to_weather_agent', 'args': {}, 'id': 'call_pU8pdeSkHkerBjNZ9KrzGTe6', 'type': 'tool_call'}], usage_metadata={'input_tokens': 71, 'output_tokens': 13, 'total_tokens': 84, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
256 | " ToolMessage(content='Successfully transferred to weather_agent', name='transfer_to_weather_agent', id='d443fd8f-71ad-4759-9eda-87fd0119ad15', tool_call_id='call_pU8pdeSkHkerBjNZ9KrzGTe6'),\n",
257 | " AIMessage(content='The weather in Munich is currently 15 degrees Celsius and cloudy.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 113, 'total_tokens': 128, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_00428b782a', 'finish_reason': 'stop', 'logprobs': None}, id='run-6d8ba1ea-2ef2-4504-baf4-cba36dcc58c2-0', usage_metadata={'input_tokens': 113, 'output_tokens': 15, 'total_tokens': 128, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
258 | " AIMessage(content='Transferring back to supervisor_agent', additional_kwargs={}, response_metadata={}, name='weather_agent', id='d72b7506-90b0-4ab6-9714-41a9d38729d6', tool_calls=[{'name': 'transfer_back_to_supervisor_agent', 'args': {}, 'id': 'd1870f97-c144-46b1-87ae-b72720e7540b', 'type': 'tool_call'}]),\n",
259 | " ToolMessage(content='Successfully transferred back to supervisor_agent', name='transfer_back_to_supervisor_agent', id='c223164a-8edf-40e1-b971-33ac7eb08603', tool_call_id='d1870f97-c144-46b1-87ae-b72720e7540b'),\n",
260 | " AIMessage(content='The weather in Munich is currently 15 degrees Celsius and cloudy. If you have any more questions or need further information, feel free to ask!', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 31, 'prompt_tokens': 168, 'total_tokens': 199, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_13eed4fce1', 'finish_reason': 'stop', 'logprobs': None}, name='supervisor_agent', id='run-b555bf17-19b4-43a6-95e8-a59d9dae6357-0', usage_metadata={'input_tokens': 168, 'output_tokens': 31, 'total_tokens': 199, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
261 | ]
262 | },
263 | "execution_count": 13,
264 | "metadata": {},
265 | "output_type": "execute_result"
266 | }
267 | ],
268 | "source": [
269 | "supervisor_app.invoke(\n",
270 | " {\"messages\": [HumanMessage(content=\"How is the weather in Munich?\")]}\n",
271 | ")\n"
272 | ]
273 | }
274 | ],
275 | "metadata": {
276 | "kernelspec": {
277 | "display_name": ".venv",
278 | "language": "python",
279 | "name": "python3"
280 | },
281 | "language_info": {
282 | "codemirror_mode": {
283 | "name": "ipython",
284 | "version": 3
285 | },
286 | "file_extension": ".py",
287 | "mimetype": "text/x-python",
288 | "name": "python",
289 | "nbconvert_exporter": "python",
290 | "pygments_lexer": "ipython3",
291 | "version": "3.11.0"
292 | }
293 | },
294 | "nbformat": 4,
295 | "nbformat_minor": 2
296 | }
297 |
--------------------------------------------------------------------------------
/custom_persistence.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | "True"
12 | ]
13 | },
14 | "execution_count": 1,
15 | "metadata": {},
16 | "output_type": "execute_result"
17 | }
18 | ],
19 | "source": [
20 | "from dotenv import load_dotenv\n",
21 | "\n",
22 | "load_dotenv()"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from typing import Annotated\n",
32 | "\n",
33 | "from typing_extensions import TypedDict\n",
34 | "\n",
35 | "from langgraph.graph.message import add_messages\n",
36 | "\n",
37 | "\n",
38 | "class State(TypedDict):\n",
39 | " messages: Annotated[list, add_messages]"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "from langchain_core.tools import tool\n",
49 | "\n",
50 | "\n",
51 | "@tool\n",
52 | "def search(query: str):\n",
53 | " \"\"\"Call to surf the web.\"\"\"\n",
54 | " # This is a placeholder for the actual implementation\n",
55 | " return [\"This is a placeholder response.\"]\n",
56 | "\n",
57 | "\n",
58 | "tools = [search]"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 4,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "from langchain_openai import ChatOpenAI\n",
68 | "\n",
69 | "\n",
70 | "model = ChatOpenAI(temperature=0, streaming=True)\n",
71 | "\n",
72 | "bound_model = model.bind_tools(tools)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 5,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "from langgraph.prebuilt import ToolNode\n",
82 | "\n",
83 | "tool_node = ToolNode(tools)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 8,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "from typing import Literal\n",
93 | "\n",
94 | "\n",
95 | "def should_continue(state: State) -> Literal[\"action\", \"__end__\"]:\n",
96 | " \"\"\"Return the next node to execute.\"\"\"\n",
97 | " last_message = state[\"messages\"][-1]\n",
98 | " if not last_message.tool_calls:\n",
99 | " return \"__end__\"\n",
100 | " return \"action\"\n",
101 | "\n",
102 | "\n",
103 | "def call_model(state: State):\n",
104 | " response = bound_model.invoke(state[\"messages\"])\n",
105 | " return {\"messages\": response}"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "import psycopg2\n",
115 | "from psycopg2 import sql\n",
116 | "from contextlib import contextmanager\n",
117 | "from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointTuple\n",
118 | "from langchain_core.runnables import RunnableConfig\n",
119 | "from typing import Optional, Iterator, AsyncIterator\n",
120 | "from datetime import datetime, timezone\n",
121 | "\n",
122 | "\n",
123 | "class PostgresSaver(BaseCheckpointSaver):\n",
124 | " def __init__(self, connection):\n",
125 | " self.connection = connection\n",
126 | "\n",
127 | " @classmethod\n",
128 | " def from_conn_string(cls, conn_string):\n",
129 | " connection = psycopg2.connect(conn_string)\n",
130 | " return cls(connection)\n",
131 | "\n",
132 | " @contextmanager\n",
133 | " def cursor(self):\n",
134 | " \"\"\"Provide a transactional scope around a series of operations.\"\"\"\n",
135 | " cursor = self.connection.cursor()\n",
136 | " try:\n",
137 | " yield cursor\n",
138 | " self.connection.commit()\n",
139 | " except Exception as e:\n",
140 | " self.connection.rollback()\n",
141 | " raise e\n",
142 | " finally:\n",
143 | " cursor.close()\n",
144 | "\n",
145 | " def setup(self) -> None:\n",
146 | " with self.cursor() as cursor:\n",
147 | " create_table_query = \"\"\"\n",
148 | " CREATE TABLE IF NOT EXISTS checkpoints (\n",
149 | " thread_id TEXT NOT NULL,\n",
150 | " thread_ts TEXT NOT NULL,\n",
151 | " parent_ts TEXT,\n",
152 | " checkpoint BYTEA,\n",
153 | " metadata BYTEA,\n",
154 | " PRIMARY KEY (thread_id, thread_ts)\n",
155 | " );\n",
156 | " \"\"\"\n",
157 | " cursor.execute(create_table_query)\n",
158 | "\n",
159 | " def get_latest_timestamp(self, thread_id: str) -> str:\n",
160 | " with self.cursor() as cursor:\n",
161 | " select_query = sql.SQL(\n",
162 | " \"SELECT thread_ts FROM checkpoints WHERE thread_id = %s ORDER BY thread_ts DESC LIMIT 1\"\n",
163 | " )\n",
164 | " cursor.execute(select_query, (thread_id,))\n",
165 | " result = cursor.fetchone()\n",
166 | " return result[0] if result else None\n",
167 | "\n",
168 | " def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:\n",
169 | " thread_id = config[\"configurable\"][\"thread_id\"]\n",
170 | " thread_ts = config[\"configurable\"].get(\n",
171 | " \"thread_ts\", self.get_latest_timestamp(thread_id)\n",
172 | " )\n",
173 | "\n",
174 | " with self.cursor() as cursor:\n",
175 | " select_query = sql.SQL(\n",
176 | " \"SELECT checkpoint, metadata, parent_ts FROM checkpoints WHERE thread_id = %s AND thread_ts = %s\"\n",
177 | " )\n",
178 | " cursor.execute(select_query, (thread_id, thread_ts))\n",
179 | " result = cursor.fetchone()\n",
180 | " if result:\n",
181 | " checkpoint, metadata, parent_ts = result\n",
182 | " return CheckpointTuple(\n",
183 | " config,\n",
184 | " self.serde.loads(bytes(checkpoint)),\n",
185 | " self.serde.loads(bytes(metadata)),\n",
186 | " (\n",
187 | " {\n",
188 | " \"configurable\": {\n",
189 | " \"thread_id\": thread_id,\n",
190 | " \"thread_ts\": parent_ts,\n",
191 | " }\n",
192 | " }\n",
193 | " if parent_ts\n",
194 | " else None\n",
195 | " ),\n",
196 | " )\n",
197 | " return None\n",
198 | "\n",
199 | " def list(\n",
200 | " self,\n",
201 | " config: RunnableConfig,\n",
202 | " *,\n",
203 | " before: Optional[RunnableConfig] = None,\n",
204 | " limit: Optional[int] = None,\n",
205 | " ) -> Iterator[CheckpointTuple]:\n",
206 | " thread_id = config[\"configurable\"][\"thread_id\"]\n",
207 | " query = \"\"\"\n",
208 | " SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata\n",
209 | " FROM checkpoints\n",
210 | " WHERE thread_id = %s\n",
211 | " \"\"\"\n",
212 | " params = [thread_id]\n",
213 | " if before:\n",
214 | " query += \" AND thread_ts < %s\"\n",
215 | " params.append(before[\"configurable\"][\"thread_ts\"])\n",
216 | " query += \" ORDER BY thread_ts DESC\"\n",
217 | " if limit:\n",
218 | " query += f\" LIMIT {limit}\"\n",
219 | "\n",
220 | " with self.cursor() as cursor:\n",
221 | " cursor.execute(query, params)\n",
222 | " for thread_id, thread_ts, parent_ts, checkpoint, metadata in cursor:\n",
223 | " yield CheckpointTuple(\n",
224 | " {\"configurable\": {\"thread_id\": thread_id, \"thread_ts\": thread_ts}},\n",
225 | " self.serde.loads(bytes(checkpoint)),\n",
226 | " self.serde.loads(bytes(metadata)) if metadata else {},\n",
227 | " (\n",
228 | " {\n",
229 | " \"configurable\": {\n",
230 | " \"thread_id\": thread_id,\n",
231 | " \"thread_ts\": parent_ts,\n",
232 | " }\n",
233 | " }\n",
234 | " if parent_ts\n",
235 | " else None\n",
236 | " ),\n",
237 | " )\n",
238 | "\n",
239 | " def put(\n",
240 | " self, config: RunnableConfig, checkpoint: dict, metadata: dict\n",
241 | " ) -> RunnableConfig:\n",
242 | " thread_id = config[\"configurable\"][\"thread_id\"]\n",
243 | " thread_ts = datetime.now(timezone.utc).isoformat()\n",
244 | " parent_ts = config[\"configurable\"].get(\"thread_ts\")\n",
245 | "\n",
246 | " with self.cursor() as cursor:\n",
247 | " insert_query = sql.SQL(\n",
248 | " \"\"\"\n",
249 | " INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata)\n",
250 | " VALUES (%s, %s, %s, %s, %s)\n",
251 | " ON CONFLICT (thread_id, thread_ts) DO UPDATE\n",
252 | " SET parent_ts = EXCLUDED.parent_ts, checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata\n",
253 | " \"\"\"\n",
254 | " )\n",
255 | " cursor.execute(\n",
256 | " insert_query,\n",
257 | " (\n",
258 | " thread_id,\n",
259 | " thread_ts,\n",
260 | " parent_ts,\n",
261 | " self.serde.dumps(checkpoint),\n",
262 | " self.serde.dumps(metadata),\n",
263 | " ),\n",
264 | " )\n",
265 | "\n",
266 | " return {\n",
267 | " \"configurable\": {\n",
268 | " \"thread_id\": thread_id,\n",
269 | " \"thread_ts\": thread_ts,\n",
270 | " }\n",
271 | " }\n",
272 | "\n",
273 | " async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:\n",
274 | " return self.get_tuple(config)\n",
275 | "\n",
276 | " async def alist(\n",
277 | " self,\n",
278 | " config: RunnableConfig,\n",
279 | " *,\n",
280 | " before: Optional[RunnableConfig] = None,\n",
281 | " limit: Optional[int] = None,\n",
282 | " ) -> AsyncIterator[CheckpointTuple]:\n",
283 | " for checkpoint_tuple in self.list(config, before=before, limit=limit):\n",
284 | " yield checkpoint_tuple\n",
285 | "\n",
286 | " async def aput(\n",
287 | " self, config: RunnableConfig, checkpoint: dict, metadata: dict\n",
288 | " ) -> RunnableConfig:\n",
289 | " return self.put(config, checkpoint, metadata)\n",
290 | "\n",
291 | " def close(self):\n",
292 | " self.connection.close()"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 9,
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "from langgraph.graph import StateGraph\n",
302 | "\n",
303 | "graph = StateGraph(State)\n",
304 | "\n",
305 | "graph.add_node(\"agent\", call_model)\n",
306 | "graph.add_node(\"action\", tool_node)\n",
307 | "\n",
308 | "graph.set_entry_point(\"agent\")\n",
309 | "\n",
310 | "graph.add_conditional_edges(\n",
311 | " \"agent\",\n",
312 | " should_continue,\n",
313 | ")\n",
314 | "\n",
315 | "graph.add_edge(\"action\", \"agent\")"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 11,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "conn_string = (\n",
325 | " \"dbname=mydatabase user=myuser password=mypassword host=localhost port=5433\"\n",
326 | ")\n",
327 | "memory = PostgresSaver.from_conn_string(conn_string)\n",
328 | "\n",
329 | "runnable = graph.compile(checkpointer=memory)"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 12,
335 | "metadata": {},
336 | "outputs": [
337 | {
338 | "data": {
339 | "text/plain": [
340 | "{'messages': [HumanMessage(content='Hello, I am John', id='fc051ee8-490a-43bb-8a92-fd06c55c2568'),\n",
341 | " AIMessage(content='Hi John! How can I assist you today?', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-70bdbdf9-23b9-423e-b6ca-6a4609c33dbd-0')]}"
342 | ]
343 | },
344 | "execution_count": 12,
345 | "metadata": {},
346 | "output_type": "execute_result"
347 | }
348 | ],
349 | "source": [
350 | "from langchain_core.messages import HumanMessage\n",
351 | "\n",
352 | "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
353 | "input_message = HumanMessage(content=\"Hello, I am John\")\n",
354 | "\n",
355 | "runnable.invoke({\"messages\": input_message}, config=config)"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": 13,
361 | "metadata": {},
362 | "outputs": [
363 | {
364 | "data": {
365 | "text/plain": [
366 | "{'messages': [HumanMessage(content='Did I already introduce myself?', id='c3190aba-d1eb-4d78-84de-749233c3a8c4'),\n",
367 | " AIMessage(content=\"I'm not sure, would you like me to search for it?\", response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-d8a84f99-b496-40ad-b77b-332c1b4a4d71-0')]}"
368 | ]
369 | },
370 | "execution_count": 13,
371 | "metadata": {},
372 | "output_type": "execute_result"
373 | }
374 | ],
375 | "source": [
376 | "from langchain_core.messages import HumanMessage\n",
377 | "\n",
378 | "config = {\"configurable\": {\"thread_id\": \"42\"}}\n",
379 | "input_message = HumanMessage(content=\"Did I already introduce myself?\")\n",
380 | "\n",
381 | "runnable.invoke({\"messages\": input_message}, config=config)"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": 14,
387 | "metadata": {},
388 | "outputs": [
389 | {
390 | "data": {
391 | "text/plain": [
392 | "{'messages': [HumanMessage(content='Hello, I am John', id='fc051ee8-490a-43bb-8a92-fd06c55c2568'),\n",
393 | " AIMessage(content='Hi John! How can I assist you today?', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-70bdbdf9-23b9-423e-b6ca-6a4609c33dbd-0'),\n",
394 | " HumanMessage(content='Did I already introduce myself?', id='d2c01dc6-f3d6-49cb-a539-edae1b8fdaa7'),\n",
395 | " AIMessage(content='Yes, you introduced yourself as John. How can I help you further, John?', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-ef437b21-6b5e-48c1-9625-042c2e0a9cbd-0')]}"
396 | ]
397 | },
398 | "execution_count": 14,
399 | "metadata": {},
400 | "output_type": "execute_result"
401 | }
402 | ],
403 | "source": [
404 | "from langchain_core.messages import HumanMessage\n",
405 | "\n",
406 | "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
407 | "input_message = HumanMessage(content=\"Did I already introduce myself?\")\n",
408 | "\n",
409 | "runnable.invoke({\"messages\": input_message}, config=config)"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "metadata": {},
415 | "source": [
416 | "### How to manage memory?"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 15,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": [
425 | "import psycopg2\n",
426 | "from psycopg2 import sql\n",
427 | "from contextlib import contextmanager\n",
428 | "\n",
429 | "\n",
430 | "class MemoryManager:\n",
431 | " def __init__(self, conn_string):\n",
432 | " self.conn_string = conn_string\n",
433 | "\n",
434 | " @contextmanager\n",
435 | " def connection(self):\n",
436 | " \"\"\"Provide a transactional scope around a series of operations.\"\"\"\n",
437 | " connection = psycopg2.connect(self.conn_string)\n",
438 | " try:\n",
439 | " yield connection\n",
440 | " connection.commit()\n",
441 | " except Exception as e:\n",
442 | " connection.rollback()\n",
443 | " raise e\n",
444 | " finally:\n",
445 | " connection.close()\n",
446 | "\n",
447 | " @contextmanager\n",
448 | " def cursor(self):\n",
449 | " \"\"\"Provide a cursor for database operations.\"\"\"\n",
450 | " with self.connection() as connection:\n",
451 | " cursor = connection.cursor()\n",
452 | " try:\n",
453 | " yield cursor\n",
454 | " finally:\n",
455 | " cursor.close()\n",
456 | "\n",
457 | " def delete_by_thread_id(self, thread_id: str) -> None:\n",
458 | " \"\"\"Delete memory based on thread ID.\n",
459 | "\n",
460 | " This method deletes entries from the checkpoints table where the thread_id matches\n",
461 | " the specified value.\n",
462 | "\n",
463 | " Args:\n",
464 | " thread_id (str): The thread ID for which the memory should be deleted.\n",
465 | " \"\"\"\n",
466 | " with self.cursor() as cursor:\n",
467 | " delete_query = sql.SQL(\"DELETE FROM checkpoints WHERE thread_id = %s\")\n",
468 | " cursor.execute(delete_query, (thread_id,))\n",
469 | "\n",
470 | " def count_checkpoints_by_thread_id(self) -> None:\n",
471 | " \"\"\"Count the number of checkpoints for each thread ID.\n",
472 | "\n",
473 | " This method retrieves the count of checkpoints grouped by thread_id and prints\n",
474 | " the result.\n",
475 | "\n",
476 | " Returns:\n",
477 | " None\n",
478 | " \"\"\"\n",
479 | " with self.cursor() as cursor:\n",
480 | " count_query = \"\"\"\n",
481 | " SELECT thread_id, COUNT(*) AS count\n",
482 | " FROM checkpoints\n",
483 | " GROUP BY thread_id\n",
484 | " ORDER BY thread_id;\n",
485 | " \"\"\"\n",
486 | " cursor.execute(count_query)\n",
487 | " results = cursor.fetchall()\n",
488 | " print(\"Checkpoint counts by thread ID:\")\n",
489 | " for row in results:\n",
490 | " print(f\"Thread ID: {row[0]}, Count: {row[1]}\")\n",
491 | "\n",
492 | " def delete_all(self) -> None:\n",
493 | " \"\"\"Delete all memory.\n",
494 | "\n",
495 | " This method deletes all entries from the checkpoints table.\n",
496 | " \"\"\"\n",
497 | " with self.cursor() as cursor:\n",
498 | " delete_query = \"DELETE FROM checkpoints\"\n",
499 | " cursor.execute(delete_query)"
500 | ]
501 | },
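502 | {
503 | "cell_type": "markdown",
504 | "metadata": {},
505 | "source": [
506 | "`MemoryManager` only touches the `checkpoints` table. Depending on your checkpointer version, companion tables (for example `checkpoint_writes`) may exist as well; the table names in the sketch below are assumptions, so verify them against your actual schema before running it. The sketch reuses the `sql` import and `MemoryManager` from the cell above."
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": null,
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "# Hypothetical variant: also clear companion tables, if your schema has them.\n",
516 | "# The table names below are assumptions -- verify them before running this.\n",
517 | "def delete_thread_everywhere(manager: MemoryManager, thread_id: str) -> None:\n",
518 | "    tables = [\"checkpoints\", \"checkpoint_writes\"]  # assumed names\n",
519 | "    with manager.cursor() as cursor:  # one transaction for all deletes\n",
520 | "        for table in tables:\n",
521 | "            cursor.execute(\n",
522 | "                sql.SQL(\"DELETE FROM {} WHERE thread_id = %s\").format(\n",
523 | "                    sql.Identifier(table)\n",
524 | "                ),\n",
525 | "                (thread_id,),\n",
526 | "            )"
527 | ]
528 | },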
502 | {
503 | "cell_type": "code",
504 | "execution_count": 16,
505 | "metadata": {},
506 | "outputs": [],
507 | "source": [
508 | "conn_string = (\n",
509 | " \"dbname=mydatabase user=myuser password=mypassword host=localhost port=5433\"\n",
510 | ")\n",
511 | "memory_manager = MemoryManager(conn_string)"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 17,
517 | "metadata": {},
518 | "outputs": [
519 | {
520 | "name": "stdout",
521 | "output_type": "stream",
522 | "text": [
523 | "Checkpoint counts by thread ID:\n",
524 | "Thread ID: 1, Count: 6\n",
525 | "Thread ID: 42, Count: 3\n"
526 | ]
527 | }
528 | ],
529 | "source": [
530 | "memory_manager.count_checkpoints_by_thread_id()"
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": 18,
536 | "metadata": {},
537 | "outputs": [],
538 | "source": [
539 | "thread_id_to_delete = \"1\"\n",
540 | "memory_manager.delete_by_thread_id(thread_id_to_delete)"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": 19,
546 | "metadata": {},
547 | "outputs": [
548 | {
549 | "name": "stdout",
550 | "output_type": "stream",
551 | "text": [
552 | "Checkpoint counts by thread ID:\n",
553 | "Thread ID: 42, Count: 3\n"
554 | ]
555 | }
556 | ],
557 | "source": [
558 | "memory_manager.count_checkpoints_by_thread_id()"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": 20,
564 | "metadata": {},
565 | "outputs": [],
566 | "source": [
567 | "memory_manager.delete_all()"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": 21,
573 | "metadata": {},
574 | "outputs": [
575 | {
576 | "name": "stdout",
577 | "output_type": "stream",
578 | "text": [
579 | "Checkpoint counts by thread ID:\n"
580 | ]
581 | }
582 | ],
583 | "source": [
584 | "memory_manager.count_checkpoints_by_thread_id()"
585 | ]
586 | }
587 | ],
588 | "metadata": {
589 | "kernelspec": {
590 | "display_name": "app",
591 | "language": "python",
592 | "name": "python3"
593 | },
594 | "language_info": {
595 | "codemirror_mode": {
596 | "name": "ipython",
597 | "version": 3
598 | },
599 | "file_extension": ".py",
600 | "mimetype": "text/x-python",
601 | "name": "python",
602 | "nbconvert_exporter": "python",
603 | "pygments_lexer": "ipython3",
604 | "version": "3.11.0"
605 | }
606 | },
607 | "nbformat": 4,
608 | "nbformat_minor": 2
609 | }
610 |
--------------------------------------------------------------------------------
/memory.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from dotenv import load_dotenv\n",
10 | "\n",
11 | "load_dotenv()"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from langmem import create_memory_manager\n",
21 | "from pydantic import BaseModel\n",
22 | "from langchain_openai.chat_models import ChatOpenAI\n",
23 | "\n",
24 | "llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
25 | "\n",
26 | "class Person(BaseModel):\n",
27 | " \"\"\"Store a person's name, role, and preferences.\"\"\"\n",
28 | " name: str\n",
29 | " role: str\n",
30 | " preferences: list[str] | None = None\n",
31 | "\n",
32 | "\n",
33 | "manager = create_memory_manager(\n",
34 | " llm,\n",
35 | " schemas=[Person],\n",
36 | " instructions=\"Extract people's names, roles, and any mentioned preferences.\",\n",
37 | " enable_inserts=True,\n",
38 | " enable_updates=True,\n",
39 | " enable_deletes=True,\n",
40 | ")"
41 | ]
42 | },
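43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "With `enable_inserts`, `enable_updates`, and `enable_deletes` all switched on, the manager may add new `Person` memories, revise ones it extracted earlier, and drop entries that turn out to be stale, rather than only appending."
48 | ]
49 | },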
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [
48 | {
49 | "name": "stdout",
50 | "output_type": "stream",
51 | "text": [
52 | "[ExtractedMemory(id='7fcaa043-046d-4b25-9af9-23741ddc4134', content=Person(name='John', role='senior developer', preferences=['coffee'])), ExtractedMemory(id='7d065021-b8e1-4352-9a74-06fa8cc4f764', content=Person(name='Alice', role='junior developer', preferences=['hates coffee']))]\n"
53 | ]
54 | }
55 | ],
56 | "source": [
57 | "conversation = [\n",
58 | " {\n",
59 | " \"role\": \"user\",\n",
60 | " \"content\": (\n",
61 | " \"John is a senior developer who loves coffee. \"\n",
62 | " \"Alice is a junior developer who hates coffee.\"\n",
63 | " )\n",
64 | " }\n",
65 | "]\n",
66 | "memories = manager.invoke({\"messages\": conversation})\n",
67 | "print(memories)"
68 | ]
69 | },
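70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "Each person comes back as an `ExtractedMemory` wrapping a validated `Person`. Note how Alice's aversion was stored as the preference `'hates coffee'`: schema-constrained extraction can bend phrasing this way, so results are worth a quick sanity check."
75 | ]
76 | },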
70 | {
71 | "cell_type": "code",
72 | "execution_count": 4,
73 | "metadata": {},
74 | "outputs": [
75 | {
76 | "name": "stdout",
77 | "output_type": "stream",
78 | "text": [
79 | "Error code: 500 - {'error': {'message': 'The model produced invalid content. Consider modifying your prompt if you are seeing this error persistently.', 'type': 'model_error', 'param': None, 'code': None}}\n"
80 | ]
81 | }
82 | ],
83 | "source": [
84 | "try:\n",
85 | " conversation_no_extraction = [\n",
86 | " {\n",
87 | " \"role\": \"user\",\n",
88 | " \"content\": (\n",
89 | " \"Today it rained for two hours, and then the sun came out.\"\n",
90 | " )\n",
91 | " }\n",
92 | " ]\n",
93 | "\n",
94 | " memories_no_extraction = manager.invoke({\"messages\": conversation_no_extraction})\n",
95 | " print(memories_no_extraction)\n",
96 | "except Exception as e:\n",
97 | " print(e)"
98 | ]
99 | },
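100 | {
101 | "cell_type": "markdown",
102 | "metadata": {},
103 | "source": [
104 | "With no people in the message, the model has nothing that fits the `Person` schema, and in this run that surfaced as a 500 `model_error` from the API. Whether you hit this exact error may vary by model and langmem version, which is why the call is wrapped in `try`/`except`."
105 | ]
106 | },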
100 | {
101 | "cell_type": "code",
102 | "execution_count": 5,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "name": "stdout",
107 | "output_type": "stream",
108 | "text": [
109 | "Exception: Called get_config outside of a runnable context\n"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "from langmem import create_memory_store_manager\n",
115 | "\n",
116 | "namespace=(\"memories\",)\n",
117 | "\n",
118 | "memory_manager = create_memory_store_manager(\n",
119 | " llm,\n",
120 | " namespace=namespace,\n",
121 | " instructions=\"Only save information related about food the user likes\"\n",
122 | ")\n",
123 | "try:\n",
124 | " memory_manager.invoke({\"messages\": [\"I like dogs. My dog's name is Fido.\"]})\n",
125 | "except Exception as e:\n",
126 | " print(\"Exception:\", e)"
127 | ]
128 | },
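129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "`create_memory_store_manager` resolves its store (and the rest of its runtime config) from the active LangGraph context, so a bare `invoke` fails with the `get_config` error above. The next cells fix this by running the manager inside an `@entrypoint` that carries an `InMemoryStore`."
134 | ]
135 | },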
129 | {
130 | "cell_type": "code",
131 | "execution_count": 10,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "from langmem import ReflectionExecutor\n",
136 | "\n",
137 | "executor = ReflectionExecutor(memory_manager)"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 11,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "from langgraph.store.memory import InMemoryStore\n",
147 | "from langgraph.func import entrypoint\n",
148 | "\n",
149 | "store = InMemoryStore(\n",
150 | " index={\n",
151 | " \"dims\": 1536,\n",
152 | " \"embed\": \"openai:text-embedding-3-small\",\n",
153 | " }\n",
154 | ")\n",
155 | "\n",
156 | "@entrypoint(store=store)\n",
157 | "async def chat(message: str):\n",
158 | " response = llm.invoke(message)\n",
159 | "\n",
160 | " to_process = {\"messages\": [{\"role\": \"user\", \"content\": message}] + [response]}\n",
161 | " # await memory_manager.ainvoke(to_process)\n",
162 | " executor.submit(to_process, after_seconds=1)\n",
163 | " return response.content"
164 | ]
165 | },
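166 | {
167 | "cell_type": "markdown",
168 | "metadata": {},
169 | "source": [
170 | "`executor.submit(..., after_seconds=1)` schedules the memory extraction to run in the background after a short delay instead of blocking the reply (the commented-out `ainvoke` is the inline alternative). Because the work is deferred, the store can briefly lag behind the conversation."
171 | ]
172 | },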
166 | {
167 | "cell_type": "code",
168 | "execution_count": 12,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "response = await chat.ainvoke(\n",
173 | " \"I like to eat Pizza\",\n",
174 | ")"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 13,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "try:\n",
184 | " response = await chat.ainvoke(\n",
185 | " \"I like dogs. My dog's name is Fido.\",\n",
186 | " )\n",
187 | "except Exception as e:\n",
188 | " print(\"Exception:\", e)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 14,
194 | "metadata": {},
195 | "outputs": [
196 | {
197 | "data": {
198 | "text/plain": [
199 | "[Item(namespace=['memories'], key='670da84c-c048-44fe-8bde-c4f5377ae95e', value={'kind': 'Memory', 'content': {'content': 'User likes pizza.'}}, created_at='2025-03-08T16:12:08.805877+00:00', updated_at='2025-03-08T16:12:08.805877+00:00', score=None)]"
200 | ]
201 | },
202 | "execution_count": 14,
203 | "metadata": {},
204 | "output_type": "execute_result"
205 | }
206 | ],
207 | "source": [
208 | "store.search(namespace)"
209 | ]
210 | },
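211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": [
215 | "Only the pizza preference made it into the store: the dog message was filtered out by the manager's `instructions`, which restrict it to food the user likes."
216 | ]
217 | },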
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": [
215 | "### Better approach - Tools!"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 16,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "from langgraph.store.memory import InMemoryStore\n",
225 | "from langgraph.prebuilt import create_react_agent\n",
226 | "from langmem import create_manage_memory_tool, create_search_memory_tool\n",
227 | "\n",
228 | "\n",
229 | "store = InMemoryStore(\n",
230 | " index={\n",
231 | " \"dims\": 1536,\n",
232 | " \"embed\": \"openai:text-embedding-3-small\",\n",
233 | " }\n",
234 | ")\n"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 17,
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
244 | "\n",
245 | "tools=[\n",
246 | " create_manage_memory_tool(namespace=(\"memories\", \"{user_id}\"), store=store),\n",
247 | " create_search_memory_tool(namespace=(\"memories\", \"{user_id}\"), store=store),\n",
248 | "]"
249 | ]
250 | },
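251 | {
252 | "cell_type": "markdown",
253 | "metadata": {},
254 | "source": [
255 | "The `{user_id}` placeholder in the namespace is a template that langmem fills in from `config[\"configurable\"][\"user_id\"]` at call time, so every user gets an isolated slice of the store, as the calls below show."
256 | ]
257 | },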
251 | {
252 | "cell_type": "code",
253 | "execution_count": 18,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "app = create_react_agent(llm, tools=tools)"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 19,
263 | "metadata": {},
264 | "outputs": [
265 | {
266 | "data": {
267 | "text/plain": [
268 | "{'messages': [HumanMessage(content='hi!', additional_kwargs={}, response_metadata={}, id='6fb5b4bc-f670-43e4-8975-df9c33635004'),\n",
269 | " AIMessage(content='Hello! How can I assist you today?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 11, 'prompt_tokens': 241, 'total_tokens': 252, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'stop', 'logprobs': None}, id='run-e7321c74-1828-4ce2-b22a-23f361fb597f-0', usage_metadata={'input_tokens': 241, 'output_tokens': 11, 'total_tokens': 252, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
270 | ]
271 | },
272 | "execution_count": 19,
273 | "metadata": {},
274 | "output_type": "execute_result"
275 | }
276 | ],
277 | "source": [
278 | "app.invoke({\"messages\": [{\"role\": \"user\", \"content\": \"hi!\"}]}, config={\"configurable\": {\"user_id\": \"alice\"}})"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 20,
284 | "metadata": {},
285 | "outputs": [
286 | {
287 | "data": {
288 | "text/plain": [
289 | "{'messages': [HumanMessage(content='what do you know about me?', additional_kwargs={}, response_metadata={}, id='0207f037-b34c-429b-975e-82ae56aadbaa'),\n",
290 | " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_HLqHXNhU3j7JeL8PNYK50GgZ', 'function': {'arguments': '{\"query\":\"user\"}', 'name': 'search_memory'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 246, 'total_tokens': 261, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-2b1b703b-5ff1-4996-ac3b-0781af78b824-0', tool_calls=[{'name': 'search_memory', 'args': {'query': 'user'}, 'id': 'call_HLqHXNhU3j7JeL8PNYK50GgZ', 'type': 'tool_call'}], usage_metadata={'input_tokens': 246, 'output_tokens': 15, 'total_tokens': 261, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
291 | " ToolMessage(content=[], name='search_memory', id='736d3040-8336-4e7f-a269-818816ba9fc5', tool_call_id='call_HLqHXNhU3j7JeL8PNYK50GgZ'),\n",
292 | " AIMessage(content=\"I don't have any specific information about you stored in my memory. If you'd like me to remember something or if you have any preferences you'd like to share, feel free to let me know!\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 40, 'prompt_tokens': 268, 'total_tokens': 308, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'stop', 'logprobs': None}, id='run-588fe1bf-5c81-432f-a8ad-02d3d63b28ea-0', usage_metadata={'input_tokens': 268, 'output_tokens': 40, 'total_tokens': 308, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
293 | ]
294 | },
295 | "execution_count": 20,
296 | "metadata": {},
297 | "output_type": "execute_result"
298 | }
299 | ],
300 | "source": [
301 | "app.invoke({\"messages\": [{\"role\": \"user\", \"content\": \"what do you know about me?\"}]}, config={\"configurable\": {\"user_id\": \"alice\"}})"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": 21,
307 | "metadata": {},
308 | "outputs": [
309 | {
310 | "data": {
311 | "text/plain": [
312 | "{'messages': [HumanMessage(content='I love spaghetti', additional_kwargs={}, response_metadata={}, id='9ac48e9d-5624-46bf-9a42-8bb44718403a'),\n",
313 | " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_IOMqnQdUMBahmicU2xYBC0cz', 'function': {'arguments': '{\"content\":\"User loves spaghetti\",\"action\":\"create\"}', 'name': 'manage_memory'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 21, 'prompt_tokens': 242, 'total_tokens': 263, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-f8126c89-28a0-4973-9261-81368a326eed-0', tool_calls=[{'name': 'manage_memory', 'args': {'content': 'User loves spaghetti', 'action': 'create'}, 'id': 'call_IOMqnQdUMBahmicU2xYBC0cz', 'type': 'tool_call'}], usage_metadata={'input_tokens': 242, 'output_tokens': 21, 'total_tokens': 263, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
314 | " ToolMessage(content='created memory 3320b2ed-06f0-498d-93c9-ac2b31ccb0e2', name='manage_memory', id='d95cea64-d0f0-42cf-9aec-6cfb8e8cda0a', tool_call_id='call_IOMqnQdUMBahmicU2xYBC0cz'),\n",
315 | " AIMessage(content=\"I've noted that you love spaghetti! If you have any favorite recipes or dishes, feel free to share!\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 23, 'prompt_tokens': 297, 'total_tokens': 320, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'stop', 'logprobs': None}, id='run-242072f9-9a86-460f-b497-3b3518ddafdf-0', usage_metadata={'input_tokens': 297, 'output_tokens': 23, 'total_tokens': 320, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
316 | ]
317 | },
318 | "execution_count": 21,
319 | "metadata": {},
320 | "output_type": "execute_result"
321 | }
322 | ],
323 | "source": [
324 | "app.invoke({\"messages\": [{\"role\": \"user\", \"content\": \"I love spaghetti\"}]}, config={\"configurable\": {\"user_id\": \"alice\"}})"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 22,
330 | "metadata": {},
331 | "outputs": [
332 | {
333 | "data": {
334 | "text/plain": [
335 | "{'messages': [HumanMessage(content='what do you know about me?', additional_kwargs={}, response_metadata={}, id='018b5b16-312e-480c-81d9-13e0418c0565'),\n",
336 | " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_uvGMWOgqZ8wxlzrcw8MyU9sj', 'function': {'arguments': '{\"query\":\"user\"}', 'name': 'search_memory'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 246, 'total_tokens': 261, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-d402dccc-26b2-43d4-acf7-1060d983dc80-0', tool_calls=[{'name': 'search_memory', 'args': {'query': 'user'}, 'id': 'call_uvGMWOgqZ8wxlzrcw8MyU9sj', 'type': 'tool_call'}], usage_metadata={'input_tokens': 246, 'output_tokens': 15, 'total_tokens': 261, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
337 | " ToolMessage(content='[{\"namespace\": [\"memories\", \"alice\"], \"key\": \"3320b2ed-06f0-498d-93c9-ac2b31ccb0e2\", \"value\": {\"content\": \"User loves spaghetti\"}, \"created_at\": \"2025-03-08T16:17:24.423373+00:00\", \"updated_at\": \"2025-03-08T16:17:24.423373+00:00\", \"score\": 0.27188566547773124}]', name='search_memory', id='0f3f29f5-f1e2-483b-9704-3514c9dec155', tool_call_id='call_uvGMWOgqZ8wxlzrcw8MyU9sj'),\n",
338 | " AIMessage(content=\"I know that you love spaghetti! If there's anything else you'd like me to remember or if you have any specific preferences, feel free to let me know!\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 33, 'prompt_tokens': 382, 'total_tokens': 415, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'stop', 'logprobs': None}, id='run-6dda93de-0522-422e-9437-bc6fd6b84be8-0', usage_metadata={'input_tokens': 382, 'output_tokens': 33, 'total_tokens': 415, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
339 | ]
340 | },
341 | "execution_count": 22,
342 | "metadata": {},
343 | "output_type": "execute_result"
344 | }
345 | ],
346 | "source": [
347 | "app.invoke({\"messages\": [{\"role\": \"user\", \"content\": \"what do you know about me?\"}]}, config={\"configurable\": {\"user_id\": \"alice\"}})"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": 23,
353 | "metadata": {},
354 | "outputs": [
355 | {
356 | "data": {
357 | "text/plain": [
358 | "{'messages': [HumanMessage(content='what do you know about me?', additional_kwargs={}, response_metadata={}, id='41727410-c9ab-45a1-a603-7b380112fe0a'),\n",
359 | " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_qyoIPTtyhE0PWkXB8RyxC1RD', 'function': {'arguments': '{\"query\":\"user\",\"limit\":5}', 'name': 'search_memory'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 19, 'prompt_tokens': 246, 'total_tokens': 265, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-fd1d3838-37ea-4751-b5e8-0c5b93d017e5-0', tool_calls=[{'name': 'search_memory', 'args': {'query': 'user', 'limit': 5}, 'id': 'call_qyoIPTtyhE0PWkXB8RyxC1RD', 'type': 'tool_call'}], usage_metadata={'input_tokens': 246, 'output_tokens': 19, 'total_tokens': 265, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),\n",
360 | " ToolMessage(content=[], name='search_memory', id='90868b30-115e-4c8f-9496-84a0eaf0d6b1', tool_call_id='call_qyoIPTtyhE0PWkXB8RyxC1RD'),\n",
361 | " AIMessage(content=\"I don't have any specific information about you stored in my memory. If you'd like me to remember something or if you have any preferences you'd like to share, feel free to let me know!\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 40, 'prompt_tokens': 272, 'total_tokens': 312, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_06737a9306', 'finish_reason': 'stop', 'logprobs': None}, id='run-ef76d6f5-513c-4910-988e-669ce20d2d4b-0', usage_metadata={'input_tokens': 272, 'output_tokens': 40, 'total_tokens': 312, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
362 | ]
363 | },
364 | "execution_count": 23,
365 | "metadata": {},
366 | "output_type": "execute_result"
367 | }
368 | ],
369 | "source": [
370 | "app.invoke({\"messages\": [{\"role\": \"user\", \"content\": \"what do you know about me?\"}]}, config={\"configurable\": {\"user_id\": \"max\"}})"
371 | ]
372 | },
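373 | {
374 | "cell_type": "markdown",
375 | "metadata": {},
376 | "source": [
377 | "Switching `user_id` to `max` returns nothing: the spaghetti memory lives under the `(\"memories\", \"alice\")` namespace, so per-user memories stay isolated. A quick way to confirm this against the store directly:"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {},
384 | "outputs": [],
385 | "source": [
386 | "# alice's namespace holds the memory; (\"memories\", \"max\") would be empty\n",
387 | "store.search((\"memories\", \"alice\"))"
388 | ]
389 | },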
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {},
376 | "source": [
377 | "### Prededual Memory: System Instructions"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": 24,
383 | "metadata": {},
384 | "outputs": [],
385 | "source": [
386 | "from langmem import create_prompt_optimizer\n",
387 | "\n",
388 | "optimizer = create_prompt_optimizer(\n",
389 | " llm,\n",
390 | " kind=\"metaprompt\",\n",
391 | " config={\"max_reflection_steps\": 3}\n",
392 | ")"
393 | ]
394 | },
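395 | {
396 | "cell_type": "markdown",
397 | "metadata": {},
398 | "source": [
399 | "The `metaprompt` optimizer reflects on past trajectories together with their feedback (here a `user_score` of 0 marking an unsatisfying exchange) and proposes a revised system prompt; `max_reflection_steps` caps how many reflection rounds it may spend. The next cell feeds it a single conversation in which the user had to ask for a practical example."
400 | ]
401 | },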
395 | {
396 | "cell_type": "code",
397 | "execution_count": 26,
398 | "metadata": {},
399 | "outputs": [
400 | {
401 | "name": "stdout",
402 | "output_type": "stream",
403 | "text": [
404 | "You are a helpful assistant. When users ask for explanations, especially in programming contexts, prioritize providing practical examples alongside theoretical explanations.\n"
405 | ]
406 | }
407 | ],
408 | "source": [
409 | "prompt = \"You are a helpful assistant.\"\n",
410 | "trajectory = [\n",
411 | " {\"role\": \"user\", \"content\": \"Explain inheritance in Python\"},\n",
412 | " {\"role\": \"assistant\", \"content\": \"Here's a detailed theoretical explanation...\"},\n",
413 | " {\"role\": \"user\", \"content\": \"Show me a practical example instead\"},\n",
414 | "]\n",
415 | "optimized = optimizer.invoke({\n",
416 | " \"trajectories\": [(trajectory, {\"user_score\": 0})],\n",
417 | " \"prompt\": prompt\n",
418 | "})\n",
419 | "print(optimized)"
420 | ]
421 | }
422 | ],
423 | "metadata": {
424 | "kernelspec": {
425 | "display_name": ".venv",
426 | "language": "python",
427 | "name": "python3"
428 | },
429 | "language_info": {
430 | "codemirror_mode": {
431 | "name": "ipython",
432 | "version": 3
433 | },
434 | "file_extension": ".py",
435 | "mimetype": "text/x-python",
436 | "name": "python",
437 | "nbconvert_exporter": "python",
438 | "pygments_lexer": "ipython3",
439 | "version": "3.11.0"
440 | }
441 | },
442 | "nbformat": 4,
443 | "nbformat_minor": 2
444 | }
445 |
--------------------------------------------------------------------------------