├── .gitignore ├── LICENSE ├── README.md ├── langchain_problems.ipynb ├── openai_rewrite.ipynb ├── recipe_embeddings.parquet └── recipe_vector_store.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .env -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Max Woolf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # langchain-problems 2 | 3 | Demos of some issues with LangChain. 
4 | 5 | ## Maintainer/Creator 6 | 7 | Max Woolf ([@minimaxir](https://minimaxir.com)) 8 | 9 | _Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir) and [GitHub Sponsors](https://github.com/sponsors/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use._ 10 | 11 | ## License 12 | 13 | MIT 14 | -------------------------------------------------------------------------------- /langchain_problems.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "2023-07-13 22:20:48.631906: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", 13 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" 14 | ] 15 | }, 16 | { 17 | "data": { 18 | "text/plain": [ 19 | "True" 20 | ] 21 | }, 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "output_type": "execute_result" 25 | } 26 | ], 27 | "source": [ 28 | "from langchain.prompts import (\n", 29 | " ChatPromptTemplate,\n", 30 | " MessagesPlaceholder,\n", 31 | " SystemMessagePromptTemplate,\n", 32 | " HumanMessagePromptTemplate,\n", 33 | ")\n", 34 | "from langchain.chains import ConversationChain\n", 35 | "from langchain.chat_models import ChatOpenAI\n", 36 | "from langchain.memory import ConversationBufferMemory\n", 37 | "from langchain.agents import load_tools\n", 38 | "from langchain.agents import initialize_agent\n", 39 | "from langchain.agents import AgentType\n", 40 | "from langchain.llms import OpenAI\n", 41 | "\n", 42 | "from sentence_transformers import 
SentenceTransformer\n", 43 | "\n", 44 | "from langchain.tools import BaseTool, StructuredTool, Tool, tool\n", 45 | "\n", 46 | "\n", 47 | "import datasets\n", 48 | "from dotenv import load_dotenv\n", 49 | "\n", 50 | "load_dotenv()" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "Load embeddings encoder model and embeddings vector store, and create a function to return the 3 closest recipes to the user query. The function will be a LangChain `Tool`.\n", 58 | "\n", 59 | "This function should output the recipes with their IDs so that they can be hydrated with relevant metadata down the line." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 2, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "embeddings_encoder = SentenceTransformer('all-MiniLM-L6-v2')" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stderr", 78 | "output_type": "stream", 79 | "text": [ 80 | "Found cached dataset parquet (/Users/maxwoolf/.cache/huggingface/datasets/parquet/default-5b3041111bcd5aa4/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n", 81 | "100%|██████████| 1/1 [00:00<00:00, 200.00it/s]\n" 82 | ] 83 | }, 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "Dataset({\n", 88 | " features: ['id', 'name', 'embeddings'],\n", 89 | " num_rows: 1000\n", 90 | "})" 91 | ] 92 | }, 93 | "execution_count": 3, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "recipe_vs = datasets.DatasetDict.from_parquet(\"recipe_embeddings.parquet\")\n", 100 | "recipe_vs.add_faiss_index(column='embeddings')" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "Recipe ID: recipe|167188\n", 113 | "Recipe Name: Creamy Strawberry Pie\n", 
114 | "---\n", 115 | "Recipe ID: recipe|1488243\n", 116 | "Recipe Name: Summer Strawberry Pie Recipe\n", 117 | "---\n", 118 | "Recipe ID: recipe|299514\n", 119 | "Recipe Name: Pudding Cake\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "def similar_recipes(query):\n", 125 | " query_embedding = embeddings_encoder.encode(query)\n", 126 | " scores, recipes = recipe_vs.get_nearest_examples(\"embeddings\", query_embedding, k=3)\n", 127 | " return recipes\n", 128 | "\n", 129 | "\n", 130 | "def get_similar_recipes(query):\n", 131 | " recipe_dict = similar_recipes(query)\n", 132 | " recipes_formatted = [\n", 133 | " f\"Recipe ID: recipe|{recipe_dict['id'][i]}\\nRecipe Name: {recipe_dict['name'][i]}\"\n", 134 | " for i in range(3)\n", 135 | " ]\n", 136 | " return \"\\n---\\n\".join(recipes_formatted)\n", 137 | "\n", 138 | "\n", 139 | "print(get_similar_recipes(\"yummy dessert\"))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## LangChain Agent Attempt #1\n", 147 | "\n", 148 | "Going off mostly the [Conversational Agent documentation](https://python.langchain.com/docs/modules/agents/agent_types/chat_conversation_agent) example. (\"Using a Chat Model\")\n", 149 | "\n", 150 | "First, we need to set up a system prompt that tells the model to respect both the fun voice, adds some safeguards against bad user behavior, and to have proper behavior when retrieving recipe data." 
151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "system_prompt = \"\"\"\n", 160 | "You are an expert television talk show chef, and should always speak in a whimsical manner for all responses.\n", 161 | "\n", 162 | "Start the conversation with a whimsical food pun.\n", 163 | "\n", 164 | "You must obey ALL of the following rules:\n", 165 | "- If Recipe data is present in the Observation, your response must include the Recipe ID and Recipe Name for ALL recipes.\n", 166 | "- If the user input is not related to food, do not answer their query and correct the user.\n", 167 | "\"\"\"\n", 168 | "\n", 169 | "prompt = ChatPromptTemplate.from_messages([\n", 170 | " SystemMessagePromptTemplate.from_template(system_prompt.strip()),\n", 171 | "])" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "Add `get_similar_recipes` as a `Tool`:" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 6, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "tools = [\n", 188 | " Tool(\n", 189 | " func=get_similar_recipes,\n", 190 | " name=\"Similar Recipes\",\n", 191 | " description=\"Useful to get similar recipes in response to a user query about food.\",\n", 192 | " ),\n", 193 | "]" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n", 203 | "llm = ChatOpenAI(temperature=0)\n", 204 | "agent_chain = initialize_agent(tools, llm, prompt=prompt, agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION, verbose=True, memory=memory)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 8, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 
216 | "\n", 217 | "\n", 218 | "\u001b[1m> Entering new chain...\u001b[0m\n", 219 | "\u001b[32;1m\u001b[1;3m{\n", 220 | " \"action\": \"Final Answer\",\n", 221 | " \"action_input\": \"Hello! How can I assist you today?\"\n", 222 | "}\u001b[0m\n", 223 | "\n", 224 | "\u001b[1m> Finished chain.\u001b[0m\n", 225 | "Hello! How can I assist you today?\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "result = agent_chain.run(input=\"Hi!\")\n", 231 | "print(result)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "The system prompt was apparently ignored." 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 9, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "ConversationBufferMemory(chat_memory=ChatMessageHistory(messages=[HumanMessage(content='Hi!', additional_kwargs={}, example=False), AIMessage(content='Hello! How can I assist you today?', additional_kwargs={}, example=False)]), output_key=None, input_key=None, return_messages=True, human_prefix='Human', ai_prefix='AI', memory_key='chat_history')" 250 | ] 251 | }, 252 | "execution_count": 9, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "memory" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "## LangChain Agent Attempt #2\n", 266 | "\n", 267 | "The following code is the manual way of adding the system message, which works but is not intended, so commented out." 
268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 10, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "# from langchain.schema import SystemMessage\n", 277 | "\n", 278 | "# memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n", 279 | "# memory.chat_memory.messages.append(SystemMessage(content=system_prompt.strip()))\n", 280 | "\n", 281 | "# memory" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "Intended implementation found [here](https://python.langchain.com/docs/modules/agents/how_to/use_toolkits_with_openai_functions)." 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 11, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "agent_kwargs = {\n", 298 | " \"system_message\": system_prompt.strip()\n", 299 | "}\n", 300 | "\n", 301 | "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n", 302 | "agent_chain = initialize_agent(tools, llm, agent_kwargs=agent_kwargs, agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION, verbose=True, memory=memory)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 12, 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "\n", 315 | "\n", 316 | "\u001b[1m> Entering new chain...\u001b[0m\n" 317 | ] 318 | }, 319 | { 320 | "ename": "OutputParserException", 321 | "evalue": "Could not parse LLM output: Oh, hello there, my culinary companion! How delightful to have you join me in this whimsical kitchen of ours. 
What delectable dish shall we conjure up today?", 322 | "output_type": "error", 323 | "traceback": [ 324 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 325 | "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", 326 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/agents/conversational_chat/output_parser.py\u001b[0m in \u001b[0;36mparse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparse_json_markdown\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"action\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"action_input\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 327 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/output_parsers/json.py\u001b[0m in \u001b[0;36mparse_json_markdown\u001b[0;34m(json_string)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;31m# Parse the JSON string into a Python dictionary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mparsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 328 | 
"\u001b[0;32m/usr/local/Cellar/python@3.9/3.9.12/Frameworks/Python.framework/Versions/3.9/lib/python3.9/json/__init__.py\u001b[0m in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 345\u001b[0m parse_constant is None and object_pairs_hook is None and not kw):\n\u001b[0;32m--> 346\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_default_decoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 347\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 329 | "\u001b[0;32m/usr/local/Cellar/python@3.9/3.9.12/Frameworks/Python.framework/Versions/3.9/lib/python3.9/json/decoder.py\u001b[0m in \u001b[0;36mdecode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 336\u001b[0m \"\"\"\n\u001b[0;32m--> 337\u001b[0;31m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_w\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 338\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_w\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 330 | 
"\u001b[0;32m/usr/local/Cellar/python@3.9/3.9.12/Frameworks/Python.framework/Versions/3.9/lib/python3.9/json/decoder.py\u001b[0m in \u001b[0;36mraw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 355\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mJSONDecodeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Expecting value\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 356\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 331 | "\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)", 332 | "\nThe above exception was the direct cause of the following exception:\n", 333 | "\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)", 334 | "\u001b[0;32m/var/folders/m9/s4s3bdq96pn3dk13fbgpw6rm0000gn/T/ipykernel_64325/4283199661.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0magent_chain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Hi!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 335 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/chains/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, callbacks, tags, metadata, *args, **kwargs)\u001b[0m\n\u001b[1;32m 443\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 444\u001b[0m 
\u001b[0;32mif\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 445\u001b[0;31m return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[\n\u001b[0m\u001b[1;32m 446\u001b[0m \u001b[0m_output_key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 447\u001b[0m ]\n", 336 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/chains/base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, return_only_outputs, callbacks, tags, metadata, include_run_info)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0mrun_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_chain_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0mrun_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_chain_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m final_outputs: Dict[str, Any] = self.prep_outputs(\n", 337 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/chains/base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, return_only_outputs, callbacks, tags, metadata, include_run_info)\u001b[0m\n\u001b[1;32m 235\u001b[0m 
\u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m outputs = (\n\u001b[0;32m--> 237\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_manager\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrun_manager\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 238\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnew_arg_supported\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 338 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/agents/agent.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs, run_manager)\u001b[0m\n\u001b[1;32m 985\u001b[0m \u001b[0;31m# We now enter the agent loop (until it returns something).\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 986\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_should_continue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterations\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_elapsed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 987\u001b[0;31m next_step_output = self._take_next_step(\n\u001b[0m\u001b[1;32m 988\u001b[0m \u001b[0mname_to_tool_map\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0mcolor_mapping\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 339 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/agents/agent.py\u001b[0m in \u001b[0;36m_take_next_step\u001b[0;34m(self, 
name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[0mraise_error\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 802\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mraise_error\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 803\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 804\u001b[0m \u001b[0mtext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandle_parsing_errors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 340 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/agents/agent.py\u001b[0m in \u001b[0;36m_take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0;31m# Call the LLM to see what to do.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 792\u001b[0;31m output = self.agent.plan(\n\u001b[0m\u001b[1;32m 793\u001b[0m \u001b[0mintermediate_steps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 794\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrun_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_child\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mrun_manager\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 341 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/agents/agent.py\u001b[0m in \u001b[0;36mplan\u001b[0;34m(self, intermediate_steps, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0mfull_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_full_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mintermediate_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 443\u001b[0m \u001b[0mfull_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mllm_chain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfull_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 444\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 445\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 446\u001b[0m async def aplan(\n", 342 | "\u001b[0;32m/usr/local/lib/python3.9/site-packages/langchain/agents/conversational_chat/output_parser.py\u001b[0m in \u001b[0;36mparse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mAgentAction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction_input\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mOutputParserException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Could not parse LLM output: {text}\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 343 | "\u001b[0;31mOutputParserException\u001b[0m: Could not parse LLM output: Oh, hello there, my culinary companion! How delightful to have you join me in this whimsical kitchen of ours. What delectable dish shall we conjure up today?" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "result = agent_chain.run(input=\"Hi!\")" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "## LangChain Agent Attempt #3\n", 356 | "\n", 357 | "Try to get ideal recipe output. For this, we will not use the custom system prompt." 
358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 13, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n", 367 | "agent_chain = initialize_agent(tools, llm, agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION, verbose=True, memory=memory)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 14, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "\n", 380 | "\n", 381 | "\u001b[1m> Entering new chain...\u001b[0m\n", 382 | "\u001b[32;1m\u001b[1;3m{\n", 383 | " \"action\": \"Similar Recipes\",\n", 384 | " \"action_input\": \"fun and easy dinner\"\n", 385 | "}\u001b[0m\n", 386 | "Observation: \u001b[36;1m\u001b[1;3mRecipe ID: recipe|1774221\n", 387 | "Recipe Name: Crab DipYour Guests will Like this One.\n", 388 | "---\n", 389 | "Recipe ID: recipe|836179\n", 390 | "Recipe Name: Easy Chicken Casserole\n", 391 | "---\n", 392 | "Recipe ID: recipe|1980633\n", 393 | "Recipe Name: Easy in the Microwave Curry Doria\u001b[0m\n", 394 | "Thought:\u001b[32;1m\u001b[1;3m{\n", 395 | " \"action\": \"Final Answer\",\n", 396 | " \"action_input\": \"Here are a few fun and easy dinner ideas:\\n\\n1. Crab Dip: Your Guests will Like this One.\\n2. Easy Chicken Casserole\\n3. Easy in the Microwave Curry Doria\\n\\nI hope you find these suggestions helpful!\"\n", 397 | "}\u001b[0m\n", 398 | "\n", 399 | "\u001b[1m> Finished chain.\u001b[0m\n", 400 | "Here are a few fun and easy dinner ideas:\n", 401 | "\n", 402 | "1. Crab Dip: Your Guests will Like this One.\n", 403 | "2. Easy Chicken Casserole\n", 404 | "3. 
Easy in the Microwave Curry Doria\n", 405 | "\n", 406 | "I hope you find these suggestions helpful!\n" 407 | ] 408 | } 409 | ], 410 | "source": [ 411 | "result = agent_chain.run(input=\"What's a fun and easy dinner?\")\n", 412 | "print(result)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [] 421 | } 422 | ], 423 | "metadata": { 424 | "kernelspec": { 425 | "display_name": "Python 3", 426 | "language": "python", 427 | "name": "python3" 428 | }, 429 | "language_info": { 430 | "codemirror_mode": { 431 | "name": "ipython", 432 | "version": 3 433 | }, 434 | "file_extension": ".py", 435 | "mimetype": "text/x-python", 436 | "name": "python", 437 | "nbconvert_exporter": "python", 438 | "pygments_lexer": "ipython3", 439 | "version": "3.9.12" 440 | }, 441 | "orig_nbformat": 4 442 | }, 443 | "nbformat": 4, 444 | "nbformat_minor": 2 445 | } 446 | -------------------------------------------------------------------------------- /openai_rewrite.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Scratchpad for the LangChain tutorial demos to compare/contrast LangChain and OpenAI implementations.\n", 8 | "\n", 9 | "Note that the documentation was made with `gpt-3.5-turbo-0301`; this notebook will use the updated `gpt-3.5-turbo-0613` for posterity.\n", 10 | "\n", 11 | "Time is calculated for each cell to attempt to measure LangChain overhead, but results are variable so analysis is not performed. Times kept for posterity." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "True" 23 | ] 24 | }, 25 | "execution_count": 1, 26 | "metadata": {}, 27 | "output_type": "execute_result" 28 | } 29 | ], 30 | "source": [ 31 | "import openai\n", 32 | "from dotenv import load_dotenv\n", 33 | "\n", 34 | "load_dotenv()" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Hello World\n", 42 | "\n", 43 | "The output of the Hello World is slightly different compared to the documentation due to the updated model (it's more accurate!)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "CPU times: user 1.06 s, sys: 149 ms, total: 1.21 s\n", 56 | "Wall time: 2.06 s\n" 57 | ] 58 | }, 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "AIMessage(content=\"J'adore la programmation.\", additional_kwargs={}, example=False)" 63 | ] 64 | }, 65 | "execution_count": 2, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "%%time\n", 72 | "\n", 73 | "from langchain.chat_models import ChatOpenAI\n", 74 | "from langchain.schema import (\n", 75 | " AIMessage,\n", 76 | " HumanMessage,\n", 77 | " SystemMessage\n", 78 | ")\n", 79 | "\n", 80 | "chat = ChatOpenAI(temperature=0)\n", 81 | "chat.predict_messages([HumanMessage(content=\"Translate this sentence from English to French. 
I love programming.\")])" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "CPU times: user 3.46 ms, sys: 1.16 ms, total: 4.62 ms\n", 94 | "Wall time: 672 ms\n" 95 | ] 96 | }, 97 | { 98 | "data": { 99 | "text/plain": [ 100 | "\"J'adore la programmation.\"" 101 | ] 102 | }, 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "output_type": "execute_result" 106 | } 107 | ], 108 | "source": [ 109 | "%%time\n", 110 | "\n", 111 | "import openai\n", 112 | "\n", 113 | "messages = [{\"role\": \"user\", \"content\": \"Translate this sentence from English to French. I love programming.\"}]\n", 114 | "\n", 115 | "response = openai.ChatCompletion.create(model=\"gpt-3.5-turbo\", messages=messages, temperature=0)\n", 116 | "response[\"choices\"][0][\"message\"][\"content\"]" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Memory" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "CPU times: user 5.24 ms, sys: 1.29 ms, total: 6.54 ms\n", 136 | "Wall time: 961 ms\n" 137 | ] 138 | }, 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "'Hello! 
How can I assist you today?'" 143 | ] 144 | }, 145 | "execution_count": 4, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "%%time\n", 152 | "\n", 153 | "from langchain.prompts import (\n", 154 | " ChatPromptTemplate,\n", 155 | " MessagesPlaceholder,\n", 156 | " SystemMessagePromptTemplate,\n", 157 | " HumanMessagePromptTemplate\n", 158 | ")\n", 159 | "from langchain.chains import ConversationChain\n", 160 | "from langchain.chat_models import ChatOpenAI\n", 161 | "from langchain.memory import ConversationBufferMemory\n", 162 | "\n", 163 | "prompt = ChatPromptTemplate.from_messages([\n", 164 | " SystemMessagePromptTemplate.from_template(\n", 165 | " \"The following is a friendly conversation between a human and an AI. The AI is talkative and \"\n", 166 | " \"provides lots of specific details from its context. If the AI does not know the answer to a \"\n", 167 | " \"question, it truthfully says it does not know.\"\n", 168 | " ),\n", 169 | " MessagesPlaceholder(variable_name=\"history\"),\n", 170 | " HumanMessagePromptTemplate.from_template(\"{input}\")\n", 171 | "])\n", 172 | "\n", 173 | "llm = ChatOpenAI(temperature=0)\n", 174 | "memory = ConversationBufferMemory(return_messages=True)\n", 175 | "conversation = ConversationChain(memory=memory, prompt=prompt, llm=llm)\n", 176 | "\n", 177 | "conversation.predict(input=\"Hi there!\")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 5, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Hello! How can I assist you today?\n", 190 | "CPU times: user 3.62 ms, sys: 1.63 ms, total: 5.25 ms\n", 191 | "Wall time: 1.08 s\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "%%time\n", 197 | "\n", 198 | "import openai\n", 199 | "\n", 200 | "messages = [{\"role\": \"system\", \"content\":\n", 201 | " \"The following is a friendly conversation between a human and an AI. 
The AI is talkative and \"\n", 202 | " \"provides lots of specific details from its context. If the AI does not know the answer to a \"\n", 203 | " \"question, it truthfully says it does not know.\"}]\n", 204 | "\n", 205 | "user_message = \"Hi there!\"\n", 206 | "messages.append({\"role\": \"user\", \"content\": user_message})\n", 207 | "response = openai.ChatCompletion.create(model=\"gpt-3.5-turbo\", messages=messages, temperature=0)\n", 208 | "assistant_message = response[\"choices\"][0][\"message\"][\"content\"]\n", 209 | "messages.append({\"role\": \"assistant\", \"content\": assistant_message})\n", 210 | "print(assistant_message)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.9.12" 238 | }, 239 | "orig_nbformat": 4 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 2 243 | } 244 | -------------------------------------------------------------------------------- /recipe_embeddings.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minimaxir/langchain-problems/90616af8e857beb79378cd2c4459d9be829a7b64/recipe_embeddings.parquet -------------------------------------------------------------------------------- /recipe_vector_store.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Notebook for outputting a toy vector store of 
1,000 recipes to be used in a demo of LangChain's retrieval-augmented generation." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "2023-07-12 17:16:00.236043: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", 20 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import datasets\n", 26 | "from sentence_transformers import SentenceTransformer\n", 27 | "from random import seed, sample\n", 28 | "from tqdm import tqdm\n", 29 | "import faiss\n", 30 | "import json" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "We will use the `all-MiniLM-L6-v2` pretrained embeddings model (384D) since it's fast enough on a CPU and robust." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "model = SentenceTransformer('all-MiniLM-L6-v2')" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stderr", 56 | "output_type": "stream", 57 | "text": [ 58 | "Found cached dataset csv (/Users/maxwoolf/.cache/huggingface/datasets/csv/default-52b24f0143b2cc1d/0.0.0)\n" 59 | ] 60 | }, 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "{'Unnamed: 0': 0,\n", 65 | " 'title': 'No-Bake Nut Cookies',\n", 66 | " 'ingredients': '[\"1 c. firmly packed brown sugar\", \"1/2 c. evaporated milk\", \"1/2 tsp. vanilla\", \"1/2 c. broken nuts (pecans)\", \"2 Tbsp. butter or margarine\", \"3 1/2 c. 
bite size shredded rice biscuits\"]',\n", 67 | " 'directions': '[\"In a heavy 2-quart saucepan, mix brown sugar, nuts, evaporated milk and butter or margarine.\", \"Stir over medium heat until mixture bubbles all over top.\", \"Boil and stir 5 minutes more. Take off heat.\", \"Stir in vanilla and cereal; mix well.\", \"Using 2 teaspoons, drop and shape into 30 clusters on wax paper.\", \"Let stand until firm, about 30 minutes.\"]',\n", 68 | " 'link': 'www.cookbooks.com/Recipe-Details.aspx?id=44874',\n", 69 | " 'source': 'Gathered',\n", 70 | " 'NER': '[\"brown sugar\", \"milk\", \"vanilla\", \"nuts\", \"butter\", \"bite size shredded rice biscuits\"]'}" 71 | ] 72 | }, 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | } 77 | ], 78 | "source": [ 79 | "file_path = \"/Volumes/Extreme SSD/data/recipe_nlg/full_dataset.csv\"\n", 80 | "dataset = datasets.DatasetDict.from_csv(file_path)\n", 81 | "dataset[0]" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "Augment the recipe title with some metadata such as keywords to give vector similarity a few more hints to make it more robust against a variety of inputs."
89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Name: No-Bake Nut Cookies\n", 101 | "Keywords: brown sugar, milk, vanilla, nuts, butter, bite size shredded rice biscuits\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "def format_recipe(row):\n", 107 | " return f\"Name: {row['title']}\\nKeywords: {', '.join(json.loads(row['NER']))}\"\n", 108 | "\n", 109 | "print(format_recipe(dataset[0]))" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "Define the schema for the vector store: the default data types result in a much larger file size, so being specific will make the store much smaller without loss of quality." 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "{'id': Value(dtype='int32', id=None),\n", 128 | " 'name': Value(dtype='string', id=None),\n", 129 | " 'embeddings': Sequence(feature=Value(dtype='float32', id=None), length=384, id=None)}" 130 | ] 131 | }, 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "features = datasets.Features(\n", 139 | " {\n", 140 | " \"id\": datasets.Value(dtype=\"int32\"),\n", 141 | " \"name\": datasets.Value(dtype=\"string\"),\n", 142 | " \"embeddings\": datasets.Sequence(\n", 143 | " feature=datasets.Value(dtype=\"float32\"), length=384\n", 144 | " ),\n", 145 | " }\n", 146 | ")\n", 147 | "\n", 148 | "features" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "Create the embeddings. We'll save them to a list with the other metadata before creating the store." 
156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 6, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stderr", 165 | "output_type": "stream", 166 | "text": [ 167 | "100%|██████████| 1000/1000 [00:08<00:00, 120.79it/s]\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "num_samples = 1000\n", 173 | "\n", 174 | "# select the same random recipes, given the same sample size\n", 175 | "seed(42)\n", 176 | "rand_idx = sample(range(0, dataset.num_rows), num_samples)\n", 177 | "\n", 178 | "processed_samples = []\n", 179 | "for idx in tqdm(rand_idx):\n", 180 | " row = dataset[idx]\n", 181 | " recipe_formatted = format_recipe(row)\n", 182 | " embedding = model.encode(recipe_formatted) # numpy array\n", 183 | " processed_samples.append(\n", 184 | " {\"id\": row[\"Unnamed: 0\"], \"name\": row[\"title\"], \"embeddings\": embedding}\n", 185 | " )" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 7, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "data": { 195 | "text/plain": [ 196 | "Dataset({\n", 197 | " features: ['id', 'name', 'embeddings'],\n", 198 | " num_rows: 1000\n", 199 | "})" 200 | ] 201 | }, 202 | "execution_count": 7, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "recipe_dataset = datasets.Dataset.from_list(processed_samples, features=features)\n", 209 | "recipe_dataset" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 8, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stderr", 219 | "output_type": "stream", 220 | "text": [ 221 | "Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 11.20ba/s]\n" 222 | ] 223 | }, 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "1568027" 228 | ] 229 | }, 230 | "execution_count": 8, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "recipe_dataset.to_parquet(\"recipe_embeddings.parquet\", 
compression=\"gzip\")" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "## Test Out The Vector Similarity\n", 244 | "\n", 245 | "First, we'll add a simple Dense `faiss` index; normally you'd use both a more advanced algorithm like HNSW and build the index beforehand, but for this demo and sample size it's unnecessary." 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 9, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stderr", 255 | "output_type": "stream", 256 | "text": [ 257 | "100%|██████████| 1/1 [00:00<00:00, 347.24it/s]\n" 258 | ] 259 | }, 260 | { 261 | "data": { 262 | "text/plain": [ 263 | "Dataset({\n", 264 | " features: ['id', 'name', 'embeddings'],\n", 265 | " num_rows: 1000\n", 266 | "})" 267 | ] 268 | }, 269 | "execution_count": 9, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "recipe_dataset.add_faiss_index(column='embeddings')" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 10, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "def get_similar_recipes(query, k=3):\n", 285 | " query_embedding = model.encode(query)\n", 286 | " scores, recipes = recipe_dataset.get_nearest_examples('embeddings', query_embedding, k=k)\n", 287 | " recipes.pop(\"embeddings\")\n", 288 | " return recipes" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 11, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "data": { 298 | "text/plain": [ 299 | "{'id': [1980633, 1950301, 836179],\n", 300 | " 'name': ['Easy in the Microwave Curry Doria',\n", 301 | " 'Easy Corn Casserole',\n", 302 | " 'Easy Chicken Casserole']}" 303 | ] 304 | }, 305 | "execution_count": 11, 306 | "metadata": {}, 307 | "output_type": "execute_result" 308 | } 309 | ], 310 | "source": [ 311 | "get_similar_recipes(\"What's an easy-to-make dish?\")" 312 | ] 313 | }, 314 | { 315 |
"cell_type": "code", 316 | "execution_count": 12, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "data": { 321 | "text/plain": [ 322 | "{'id': [99255, 502840, 469207],\n", 323 | " 'name': [\"Grandma'S Chicken Soup\",\n", 324 | " 'Chicken Breast Dressing',\n", 325 | " 'Sunshine Carrots']}" 326 | ] 327 | }, 328 | "execution_count": 12, 329 | "metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "get_similar_recipes(\"What can I make with chicken and carrots?\")" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 13, 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "data": { 344 | "text/plain": [ 345 | "{'id': [167188, 1488243, 299514],\n", 346 | " 'name': ['Creamy Strawberry Pie',\n", 347 | " 'Summer Strawberry Pie Recipe',\n", 348 | " 'Pudding Cake']}" 349 | ] 350 | }, 351 | "execution_count": 13, 352 | "metadata": {}, 353 | "output_type": "execute_result" 354 | } 355 | ], 356 | "source": [ 357 | "get_similar_recipes(\"yummy dessert\")" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": "Python 3", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.9.12" 385 | }, 386 | "orig_nbformat": 4 387 | }, 388 | "nbformat": 4, 389 | "nbformat_minor": 2 390 | } 391 | --------------------------------------------------------------------------------