├── notebooks ├── .gitkeep ├── chapter2 │ ├── sample_audio.mp3 │ ├── sample_image1.png │ ├── sample_image2.png │ ├── 02_langchain_introduction.ipynb │ ├── 01_openai_introduction.ipynb │ └── 03_gradio_introduction.ipynb ├── chapter3 │ ├── 04_memory.ipynb │ ├── 03_agent.ipynb │ ├── 01_knowledge.ipynb │ ├── 05_persona.ipynb │ └── 02_tools.ipynb └── chapter4 │ ├── 02_multi_agent_system_construction.ipynb │ └── 03_multi_agent_system_application.ipynb ├── tests └── __init__.py ├── llm_agent └── __init__.py ├── pyproject.toml ├── .gitignore └── README.md /notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_agent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/chapter2/sample_audio.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elith-co-jp/book-llm-agent/HEAD/notebooks/chapter2/sample_audio.mp3 -------------------------------------------------------------------------------- /notebooks/chapter2/sample_image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elith-co-jp/book-llm-agent/HEAD/notebooks/chapter2/sample_image1.png -------------------------------------------------------------------------------- /notebooks/chapter2/sample_image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elith-co-jp/book-llm-agent/HEAD/notebooks/chapter2/sample_image2.png 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "llm-agent" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["aRySt0cat "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.11" 10 | notebook = "^7.2.1" 11 | ipywidgets = "^8.1.3" 12 | openai = "^1.34.0" 13 | tiktoken = "^0.7.0" 14 | langchain = "^0.2.5" 15 | gradio = "^4.36.1" 16 | python-dotenv = "^1.0.1" 17 | pydantic = "^2.7.4" 18 | langchain-openai = "^0.1.17" 19 | 20 | 21 | [build-system] 22 | requires = ["poetry-core"] 23 | build-backend = "poetry.core.masonry.api" 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Installer logs 10 | pip-log.txt 11 | pip-delete-this-directory.txt 12 | 13 | # Jupyter Notebook 14 | .ipynb_checkpoints 15 | 16 | # IPython 17 | profile_default/ 18 | ipython_config.py 19 | 20 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 21 | __pypackages__/ 22 | 23 | # Environments 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | ENV/ 29 | env.bak/ 30 | venv.bak/ 31 | 32 | # mypy 33 | .mypy_cache/ 34 | .dmypy.json 35 | dmypy.json 36 | 37 | # Others 38 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # やさしく学ぶLLMエージェント 3 | 4 | 5 | 6 | このリポジトリは、書籍『やさしく学ぶLLMエージェント 基本からマルチエージェント構築まで』のスクリプトソースコードを提供するものです。 7 | 8 | 9 | 10 | ## 書籍情報 11 | 12 | 13 | 14 | - **タイトル**: やさしく学ぶLLMエージェント 基本からマルチエージェント構築まで 15 | 16 | - **著者**: 株式会社Elith 井上 顧基・下垣内 隆太・松山 純大・成木 太音 17 | 18 | - **ISBN**: 978-4-274-23316-6 19 | 20 | - **発行日**: 2025年2月15日 21 | 22 | - **出版社**: 株式会社オーム社 23 | 24 | - **書籍ページ**: [オーム社の書籍ページ](https://www.ohmsha.co.jp/book/9784274233166/) 25 | 26 | - **書籍イベント**: 2025年2月28日(土)東京秋葉原・書泉ブックタワーで本書の発行記念祭りを行います。申し込みは、https://connpass.com/event/343788/ 27 | 28 | 29 | ## 本書の内容 30 | 31 | 32 | 33 | 本書では、LLMエージェントの基本概念から実際の構築方法、さらに応用的なマルチエージェントシステムの開発までを解説します。 34 | 35 | 36 | 37 | ### 主な目次 38 | 39 | 40 | 41 | #### 第1章 LLMエージェントとは 42 | 43 | - 1.1 言語モデルとは何か 44 | 45 | - 1.2 LLMエージェントとは 46 | 47 | 48 | 49 | #### 第2章 エージェント作成のための基礎知識 50 | 51 | - 2.1 OpenAI API 52 | 53 | - 2.2 LangChain入門 54 | 55 | - 2.3 Gradioを用いたGUI作成 56 | 57 | 58 | 59 | #### 第3章 エージェント 60 | 61 | - 3.1 LLMに知識を与える 62 | 63 | - 3.2 LLMにツールを与える 64 | 65 | - 3.3 複雑なフローで推論するエージェント 66 | 67 | - 3.4 記憶を持つエージェント 68 | 69 | - 3.5 ペルソナのあるエージェント 70 | 71 | 72 | 73 | #### 第4章 マルチエージェント 74 | 75 | - 4.1 マルチエージェントとは 76 | 77 | - 4.2 マルチエージェントシステムの構築 78 | 79 | - 4.3 マルチエージェントの活用 80 | 81 | 82 | 83 | #### 第5章 LLMエージェント研究の最先端 84 | 85 | - 5.1 直近の研究動向 86 | 87 | - 5.2 ビジネスでの利用例 88 | 89 | 90 | 91 | 92 | --- 93 | 94 | #AIエージェント #LLMエージェント 95 | -------------------------------------------------------------------------------- /notebooks/chapter3/04_memory.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 3.4章 記憶を持つエージェント" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "colab": { 15 | "base_uri": "https://localhost:8080/" 16 | }, 17 | "id": "akxteNI1JM7i", 18 | "outputId": "999d99e2-5875-4590-ebac-f376aae2c70f" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install langchain\n", 23 | "!pip install langchain-openai\n", 24 | "\n", 25 | "!pip install serpapi\n", 26 | "!pip install google-search-results\n", 27 | "\n", 28 | "# load_toolsを利用するのに必要\n", 29 | "!pip install langchain_community" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "AZwOvtUQL39U" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "import os\n", 41 | "from google.colab import userdata\n", 42 | "\n", 43 | "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n", 44 | "os.environ['SERPAPI_API_KEY'] = userdata.get('SERPAPI_API_KEY')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "id": "1MD4CvRQKZnv" 51 | }, 52 | "source": [ 53 | "# 3.4章" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "e3KbFb6ed4Ub", 64 | "outputId": "4a404e23-2c61-4900-8a8c-0d31687c9e1d" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "from langchain_core.prompts import PromptTemplate\n", 69 | "\n", 70 | "input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools']\n", 71 | "template=\"\"\"\\\n", 72 | "Answer the following questions as best you can. 
You have access to the following tools:\n", 73 | "\n", 74 | "{tools}\n", 75 | "\n", 76 | "Use the following format:\n", 77 | "\n", 78 | "Question: the input question you must answer\n", 79 | "Thought: you should always think about what to do\n", 80 | "Action: the action to take, should be one of [{tool_names}]\n", 81 | "Action Input: the input to the action\n", 82 | "Observation: the result of the action\n", 83 | "... (this Thought/Action/Action Input/Observation can repeat N times)\n", 84 | "Thought: I now know the final answer\n", 85 | "Final Answer: the final answer to the original input question\n", 86 | "\n", 87 | "Begin!\n", 88 | "\n", 89 | "Previous conversation history: {chat_history}\n", 90 | "Question: {input}\n", 91 | "Thought:{agent_scratchpad}\"\"\"\n", 92 | "\n", 93 | "prompt = PromptTemplate(input_variables=input_variables, template=template)\n", 94 | "print(prompt.template)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "id": "HboQJ764GDcb" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "from langchain.memory import ChatMessageHistory\n", 106 | "\n", 107 | "store = {}\n", 108 | "\n", 109 | "def get_by_session_id(session_id: str) -> ChatMessageHistory:\n", 110 | " if session_id not in store:\n", 111 | " store[session_id] = ChatMessageHistory()\n", 112 | " return store[session_id]" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "colab": { 120 | "base_uri": "https://localhost:8080/" 121 | }, 122 | "id": "8JXcSlf7KYfm", 123 | "outputId": "26cd43f5-3f96-4ef1-c292-fb973e567d44" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "from langchain_openai import ChatOpenAI\n", 128 | "from langchain.agents import load_tools\n", 129 | "from langchain.agents import AgentExecutor, create_react_agent\n", 130 | "from langchain_core.runnables.history import\\\n", 131 | "RunnableWithMessageHistory\n", 132 | "\n", 133 | "model = 
ChatOpenAI(model=\"gpt-4o-mini\")\n", 134 | "tools = load_tools([\"serpapi\"], llm=model)\n", 135 | "agent = create_react_agent(model, tools, prompt)\n", 136 | "agent_executor = AgentExecutor(agent=agent, tools=tools,\\\n", 137 | "verbose=True)\n", 138 | "\n", 139 | "agent_with_chat_history = RunnableWithMessageHistory(\n", 140 | " agent_executor,\n", 141 | " get_by_session_id,\n", 142 | " input_messages_key=\"input\",\n", 143 | " history_messages_key=\"chat_history\",\n", 144 | ")\n", 145 | "\n", 146 | "response = agent_with_chat_history.invoke({\"input\": \"株式会社\\\n", 147 | "Elithの住所を教えてください。最新の公式情報として公開されているものを教え\\\n", 148 | "てください。\"},\n", 149 | " config={\"configurable\": {\"session_id\": \"test-session1\"}})" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "colab": { 157 | "base_uri": "https://localhost:8080/" 158 | }, 159 | "id": "plxI1w4baNG9", 160 | "outputId": "023bc94a-5cf9-44be-c4ac-f7d2e7188e78" 161 | }, 162 | "outputs": [], 163 | "source": [ 164 | "print(get_by_session_id(\"test-session1\"))" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "colab": { 172 | "base_uri": "https://localhost:8080/" 173 | }, 174 | "id": "PuC6bbJyXO5b", 175 | "outputId": "2e397b11-5a8c-44d6-98e8-66852cc5aa7d" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "response = agent_with_chat_history.invoke({\"input\": \"先ほど尋ね\\\n", 180 | "た会社は何の会社ですか?\"},\n", 181 | " config={\"configurable\": {\"session_id\": \"test-session1\"}})" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": { 188 | "colab": { 189 | "base_uri": "https://localhost:8080/" 190 | }, 191 | "id": "JAB-rBsflMFz", 192 | "outputId": "413f282d-900b-4e8a-82ec-079d61e9a4af" 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "print(get_by_session_id(\"test-session1\"))" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 
null, 202 | "metadata": { 203 | "colab": { 204 | "base_uri": "https://localhost:8080/" 205 | }, 206 | "id": "5Q-idfiqGzEY", 207 | "outputId": "9cac879a-8da6-4cfc-8506-d2b9f53ed2ac" 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "print(get_by_session_id(\"test-session2\"))" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": { 218 | "colab": { 219 | "base_uri": "https://localhost:8080/" 220 | }, 221 | "id": "ACk0_4HKlQkl", 222 | "outputId": "f8fd5554-5c23-4af1-ff80-0d24863e2073" 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "response = agent_with_chat_history.invoke({\"input\": \"先ほど尋ね\\\n", 227 | "た会社は何の会社ですか?\"},\n", 228 | " config={\"configurable\": {\"session_id\": \"test-session2\"}})" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "colab": { 236 | "base_uri": "https://localhost:8080/" 237 | }, 238 | "id": "z5AVO4FoGsUU", 239 | "outputId": "b102b599-f58a-443b-a520-cf979b5f5a13" 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "print(get_by_session_id(\"test-session2\"))" 244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "colab": { 249 | "provenance": [] 250 | }, 251 | "kernelspec": { 252 | "display_name": "Python 3", 253 | "name": "python3" 254 | }, 255 | "language_info": { 256 | "name": "python" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 0 261 | } 262 | -------------------------------------------------------------------------------- /notebooks/chapter3/03_agent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 3.3章 複雑なフローをこなすエージェント" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "colab": { 15 | "base_uri": "https://localhost:8080/" 16 | }, 17 | "id": "akxteNI1JM7i", 18 | "outputId": "d42491d3-6e2d-49a9-b8aa-2103b0040e43" 19 | 
}, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install langchain\n", 23 | "!pip install langchain-openai\n", 24 | "\n", 25 | "!pip install serpapi\n", 26 | "!pip install google-search-results\n", 27 | "\n", 28 | "# load_toolsを利用するのに必要\n", 29 | "!pip install langchain_community" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "AZwOvtUQL39U" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "import os\n", 41 | "from google.colab import userdata\n", 42 | "\n", 43 | "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n", 44 | "os.environ['SERPAPI_API_KEY'] = userdata.get('SERPAPI_API_KEY')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "id": "tfLJfsb7kXFI" 51 | }, 52 | "source": [ 53 | "# 3.3章" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "LBS07Mq2BC-l", 64 | "outputId": "b5316ff0-3cde-4dea-fa47-9541c4a142ee" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "!pip install langchainhub" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "colab": { 76 | "base_uri": "https://localhost:8080/" 77 | }, 78 | "id": "5qk8cVRpA-pS", 79 | "outputId": "ede2cfd1-f438-499f-c93e-28aa0e5fe60d" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "from langchain import hub\n", 84 | "\n", 85 | "prompt = hub.pull(\"hwchase17/react\")\n", 86 | "\n", 87 | "print(prompt.template)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "id": "wBsk5cY6fvnu" 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "# おみくじツール定義(再掲)\n", 99 | "\n", 100 | "# おみくじ関数\n", 101 | "\n", 102 | "import random\n", 103 | "from datetime import datetime\n", 104 | "\n", 105 | "def get_fortune(date_string):\n", 106 | " # 日付文字列を解析\n", 107 | " try:\n", 108 | " date = 
datetime.strptime(date_string, \"%m月%d日\")\n", 109 | " except ValueError:\n", 110 | " return \"無効な日付形式です。'X月X日'の形式で入力してくださ\\\n", 111 | "い。\"\n", 112 | "\n", 113 | " # 運勢のリスト\n", 114 | " fortunes = [\n", 115 | " \"大吉\", \"中吉\", \"小吉\", \"吉\", \"末吉\", \"凶\", \"大凶\"\n", 116 | " ]\n", 117 | "\n", 118 | " # 運勢の重み付け(大吉と大凶の確率を低くする)\n", 119 | " weights = [1, 3, 3, 4, 3, 2, 1]\n", 120 | "\n", 121 | " # 日付に基づいてシードを設定(同じ日付なら同じ運勢を返す)\n", 122 | " random.seed(date.month * 100 + date.day)\n", 123 | "\n", 124 | " # 運勢をランダムに選択\n", 125 | " fortune = random.choices(fortunes, weights=weights)[0]\n", 126 | "\n", 127 | " return f\"{date_string}の運勢は【{fortune}】です。\"\n", 128 | "\n", 129 | "# ツール作成\n", 130 | "\n", 131 | "from langchain.tools import BaseTool\n", 132 | "\n", 133 | "class Get_fortune(BaseTool):\n", 134 | " name: str = 'Get_fortune'\n", 135 | " description: str = (\n", 136 | " \"特定の日付の運勢を占う。インプットは 'date_string'です。\\\n", 137 | "'date_string' は、占いを行う日付で、mm月dd日 という形式です。「1月1日」\\\n", 138 | "のように入力し、「'1月1日'」のように余計な文字列を付けてはいけません。\"\n", 139 | " )\n", 140 | "\n", 141 | " def _run(self, date_string) -> str:\n", 142 | " return get_fortune(date_string)\n", 143 | "\n", 144 | "\n", 145 | " async def _arun(self, query: str) -> str:\n", 146 | " raise NotImplementedError(\"does not support async\")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "id": "eAhTEA4Yf8iM" 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "# 日付ツール定義(再掲)\n", 158 | "\n", 159 | "# 日付取得関数\n", 160 | "\n", 161 | "from datetime import timedelta\n", 162 | "from zoneinfo import ZoneInfo\n", 163 | "\n", 164 | "\n", 165 | "def get_date(date):\n", 166 | " date_now = datetime.now(ZoneInfo(\"Asia/Tokyo\"))\n", 167 | " if (\"今日\" in date):\n", 168 | " date_delta = 0\n", 169 | " elif (\"明日\" in date):\n", 170 | " date_delta = 1\n", 171 | " elif (\"明後日\" in date):\n", 172 | " date_delta = 2\n", 173 | " else:\n", 174 | " return \"サポートしていません\"\n", 175 | " return (date_now 
+ timedelta(days=date_delta)).strftime\\\n", 176 | "('%m月%d日')\n", 177 | "\n", 178 | "class Get_date(BaseTool):\n", 179 | " name: str = 'Get_date'\n", 180 | " description: str = (\n", 181 | " \"今日の日付を取得する。インプットは 'date'です。'date' は、日\\\n", 182 | "付を取得する対象の日で、'今日', '明日', '明後日' という3種類の文字列\\\n", 183 | "から指定します。「今日」のように入力し、「'今日'」のように余計な文字列を付\\\n", 184 | "けてはいけません。\"\n", 185 | " )\n", 186 | "\n", 187 | " def _run(self, date) -> str:\n", 188 | " return get_date(date)\n", 189 | "\n", 190 | " async def _arun(self, query: str) -> str:\n", 191 | " raise NotImplementedError(\"does not support async\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": { 198 | "colab": { 199 | "base_uri": "https://localhost:8080/" 200 | }, 201 | "id": "VX3VtomhAoHG", 202 | "outputId": "cc8bf29b-b6aa-4a15-b2e8-3044c5c465d1" 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "# エージェントの作成\n", 207 | "\n", 208 | "from langchain_openai import ChatOpenAI\n", 209 | "from langchain.schema import HumanMessage\n", 210 | "from langchain.agents import AgentExecutor, create_react_agent\n", 211 | "\n", 212 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 213 | "\n", 214 | "tools = [Get_date(), Get_fortune()]\n", 215 | "\n", 216 | "#1 エージェントの作成\n", 217 | "agent = create_react_agent(model, tools, prompt)\n", 218 | "\n", 219 | "#2 エージェントの実行準備\n", 220 | "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n", 221 | "\n", 222 | "\n", 223 | "response = agent_executor.invoke({\"input\": [HumanMessage\\\n", 224 | "(content=\"今日の運勢を教えてください。\")]})\n", 225 | "\n", 226 | "print(response)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "colab": { 234 | "base_uri": "https://localhost:8080/" 235 | }, 236 | "id": "KspH82xVKS_g", 237 | "outputId": "ed2ec95f-b73f-47a9-e7c1-a62e33dd6e75" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "# 検索できるエージェント作成\n", 242 | "\n", 243 | "from 
langchain.agents import load_tools\n", 244 | "\n", 245 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 246 | "tools = load_tools([\"serpapi\"], llm=model)\n", 247 | "agent = create_react_agent(model, tools, prompt)\n", 248 | "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n", 249 | "\n", 250 | "\n", 251 | "response = agent_executor.invoke({\"input\": [HumanMessage\\\n", 252 | " (content=\"株式会社Elithの住所を教えてください。最新の公式情報として公\\\n", 253 | " 開されているものを教えてください。\")]})\n", 254 | "\n", 255 | "print(response)" 256 | ] 257 | } 258 | ], 259 | "metadata": { 260 | "colab": { 261 | "provenance": [] 262 | }, 263 | "kernelspec": { 264 | "display_name": "Python 3", 265 | "name": "python3" 266 | }, 267 | "language_info": { 268 | "name": "python" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 0 273 | } 274 | -------------------------------------------------------------------------------- /notebooks/chapter3/01_knowledge.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "g8rO7Fp5nOAx" 7 | }, 8 | "source": [ 9 | "# 3.1章 LLMに知識を与える" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "srSBnZ_knwn4" 16 | }, 17 | "source": [ 18 | "## 3.1.1 LLM に知識を与える" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | }, 28 | "id": "oh7PG3aJWn6H", 29 | "outputId": "7f2ba520-7108-4a04-bcec-16e6c5c795d5" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "!pip install langchain\n", 34 | "!pip install langchain-openai" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "id": "SPMHGl6_ZX8s" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "import os\n", 46 | "from google.colab import userdata\n", 47 | "\n", 48 | "os.environ[\"OPENAI_API_KEY\"] = 
userdata.get('OPENAI_API_KEY')" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "colab": { 56 | "base_uri": "https://localhost:8080/" 57 | }, 58 | "id": "9Wb7oue_Y3tU", 59 | "outputId": "8ad41b9f-ef11-4361-a0cb-0839f8e89462" 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "from langchain_openai import ChatOpenAI\n", 64 | "from langchain.schema import HumanMessage\n", 65 | "\n", 66 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 67 | "result = model.invoke([HumanMessage(content=\"熊童子について教えて\\\n", 68 | "ください。\")])\n", 69 | "print(result.content)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "colab": { 77 | "base_uri": "https://localhost:8080/" 78 | }, 79 | "id": "0dHmKRPoZ7F1", 80 | "outputId": "f7f7abdc-8d91-481f-deaf-d314d9854238" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "from langchain_core.prompts import ChatPromptTemplate\n", 85 | "\n", 86 | "#1 プロンプトテンプレートの作成\n", 87 | "message = \"\"\"\n", 88 | "Answer this question using the provided context only.\n", 89 | "\n", 90 | "{question}\n", 91 | "\n", 92 | "Context:\n", 93 | "{context}\n", 94 | "\"\"\"\n", 95 | "\n", 96 | "prompt = ChatPromptTemplate.from_messages([(\"human\", message)])\n", 97 | "\n", 98 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 99 | "chain = prompt | model\n", 100 | "\n", 101 | "question_text = \" 熊童子について教えてください。\"\n", 102 | "information_text = \"\"\"\\\n", 103 | "熊童子はベンケイソウ科コチレドン属の多肉植物です。\n", 104 | "葉に丸みや厚みがあり、先端には爪のような突起があることから「熊の手」という\\\n", 105 | "愛称で人気を集めています。\n", 106 | "花はオレンジ色のベル型の花を咲かせることがあります。\"\"\"\n", 107 | "\n", 108 | "response = chain.invoke({\"context\": information_text, \"question\":\n", 109 | "question_text})\n", 110 | "print(response.content)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "pBp9Zkm7n6AG" 117 | }, 118 | "source": [ 119 | "## 3.1.2 文書の構造化" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | 
"execution_count": null, 125 | "metadata": { 126 | "colab": { 127 | "base_uri": "https://localhost:8080/" 128 | }, 129 | "id": "arof7KOwbFN8", 130 | "outputId": "f8f80c4a-8827-4025-bcb7-b9883b4b498f" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "from langchain_core.documents import Document\n", 135 | "\n", 136 | "#1 Documentクラスオブジェクトの作成\n", 137 | "document = Document(\n", 138 | " page_content=\"\"\"\\\n", 139 | "セダムはベンケイソウ科マンネングザ属で、日本にも自生しているポピュラーな多\\\n", 140 | "肉植物です。\n", 141 | "種類が多くて葉の大きさや形状、カラーバリエーションも豊富なので、組み合わせ\\\n", 142 | "て寄せ植えにしたり、庭のグランドカバーにしたりして楽しむことができます。\n", 143 | "とても丈夫で育てやすく、多肉植物を初めて育てる方にもおすすめです。\"\"\",\n", 144 | " metadata={\"source\": \"succulent-plants-doc\"},\n", 145 | " )\n", 146 | "\n", 147 | "print(document)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "colab": { 155 | "base_uri": "https://localhost:8080/" 156 | }, 157 | "id": "fHwh7QmddTbg", 158 | "outputId": "d778e279-a611-4f51-9118-3139ceff7351" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "!pip install langchain_chroma" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "id": "nSL4q-vwdZa0" 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "from langchain_chroma import Chroma\n", 174 | "from langchain_openai import OpenAIEmbeddings\n", 175 | "\n", 176 | "#1 Documentクラスオブジェクトの作成\n", 177 | "documents = [\n", 178 | " Document(\n", 179 | " page_content=\"\"\"\\\n", 180 | "セダムはベンケイソウ科マンネングザ属で、日本にも自生しているポピュラーな多\\\n", 181 | "肉植物です。\n", 182 | "種類が多くて葉の大きさや形状、カラーバリエーションも豊富なので、組み合わせ\\\n", 183 | "て寄せ植えにしたり、庭のグランドカバーにしたりして楽しむことができます。\n", 184 | "とても丈夫で育てやすく、多肉植物を初めて育てる方にもおすすめです。\"\"\",\n", 185 | " metadata={\"source\": \"succulent-plants-doc\"},\n", 186 | " ),\n", 187 | " Document(\n", 188 | " page_content=\"\"\"\\\n", 189 | "熊童子はベンケイソウ科コチレドン属の多肉植物です。\n", 190 | "葉に丸みや厚みがあり、先端には爪のような突起があることから「熊の手」という\\\n", 191 | "愛称で人気を集めています。\n", 192 | 
"花はオレンジ色のベル型の花を咲かせることがあります。\"\"\",\n", 193 | " metadata={\"source\": \"succulent-plants-doc\"},\n", 194 | " ),\n", 195 | " Document(\n", 196 | " page_content=\"\"\"\\\n", 197 | "エケベリアはベンケイソウ科エケベリア属の多肉植物で、メキシコなど中南米が原\\\n", 198 | "産です。\n", 199 | "まるで花びらのように広がる肉厚な葉が特徴で、秋には紅葉も楽しめます。\n", 200 | "品種が多く、室内でも気軽に育てられるので、人気のある多肉植物です。\"\"\",\n", 201 | " metadata={\"source\": \"succulent-plants-doc\"},\n", 202 | " ),\n", 203 | " Document(\n", 204 | " page_content=\"\"\"\\\n", 205 | "ハオルチアは、春と秋に成長するロゼット形の多肉植物です。\n", 206 | "密に重なった葉が放射状に展開し、幾何学的で整った株姿になるのが魅力です。\n", 207 | "室内でも育てやすく手頃なサイズの多肉植物です。\"\"\",\n", 208 | " metadata={\"source\": \"succulent-plants-doc\"},\n", 209 | " ),\n", 210 | "]\n", 211 | "\n", 212 | "#2 Chromaデータベースの作成\n", 213 | "vectorstore = Chroma.from_documents(\n", 214 | " documents,\n", 215 | " embedding=OpenAIEmbeddings(),\n", 216 | ")\n" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "colab": { 224 | "base_uri": "https://localhost:8080/" 225 | }, 226 | "id": "-WSzY5zXe5pL", 227 | "outputId": "6f37c049-c79c-43a6-d58c-f148cc703545" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "vectorstore.similarity_search(\"熊童子\")" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": { 238 | "colab": { 239 | "base_uri": "https://localhost:8080/" 240 | }, 241 | "id": "jQSBpAJDe_SW", 242 | "outputId": "8b590a62-642f-459f-a4cf-6ecc0a7414e0" 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "vectorstore.similarity_search_with_score(\"熊童子\")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "VuI7iHjRoAu5" 253 | }, 254 | "source": [ 255 | "## 3.1.3 文書検索機能を持つLLM" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": { 262 | "colab": { 263 | "base_uri": "https://localhost:8080/" 264 | }, 265 | "id": "OI__7lAOfP_g", 266 | "outputId": "6cb46eb6-acfe-4c6b-c6cd-ddfa8988e402" 267 | }, 
268 | "outputs": [], 269 | "source": [ 270 | "from langchain_core.runnables import RunnableLambda\n", 271 | "\n", 272 | "#1 Runnable オブジェクトの作成\n", 273 | "retriever = RunnableLambda(vectorstore.similarity_search).bind\\\n", 274 | "(k=1)\n", 275 | "retriever.invoke(\"熊童子\")" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": { 282 | "colab": { 283 | "base_uri": "https://localhost:8080/" 284 | }, 285 | "id": "Fe9qGmuEfnB9", 286 | "outputId": "7b8fc329-5eb3-490d-f92c-a0fadf32a31e" 287 | }, 288 | "outputs": [], 289 | "source": [ 290 | "from langchain_core.runnables import RunnablePassthrough\n", 291 | "\n", 292 | "message = \"\"\"\n", 293 | "Answer this question using the provided context only.\n", 294 | "\n", 295 | "{question}\n", 296 | "\n", 297 | "Context:\n", 298 | "{context}\n", 299 | "\"\"\"\n", 300 | "\n", 301 | "prompt = ChatPromptTemplate.from_messages([(\"human\", message)])\n", 302 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 303 | "\n", 304 | "#1 Chainの作成\n", 305 | "rag_chain = {\"context\": retriever,\n", 306 | " \"question\": RunnablePassthrough()} | prompt | model\n", 307 | "\n", 308 | "result = rag_chain.invoke(\"熊童子について教えてください。\")\n", 309 | "print(result.content)" 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "colab": { 315 | "provenance": [] 316 | }, 317 | "kernelspec": { 318 | "display_name": "Python 3", 319 | "name": "python3" 320 | }, 321 | "language_info": { 322 | "name": "python" 323 | } 324 | }, 325 | "nbformat": 4, 326 | "nbformat_minor": 0 327 | } 328 | -------------------------------------------------------------------------------- /notebooks/chapter2/02_langchain_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 準備" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | 
"source": [ 16 | "import getpass\n", 17 | "import os\n", 18 | "\n", 19 | "# OpenAI API キーの設定\n", 20 | "api_key = getpass.getpass(\"OpenAI API キーを入力してください: \")\n", 21 | "os.environ[\"OPENAI_API_KEY\"] = api_key" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# 2.2.2 チャットアプリ" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from langchain_openai.chat_models import ChatOpenAI\n", 38 | "from langchain_core.messages import HumanMessage\n", 39 | "\n", 40 | "# 1 モデルの定義\n", 41 | "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 42 | "\n", 43 | "history = []\n", 44 | "n = 10\n", 45 | "for i in range(10):\n", 46 | " user_input = input(\"ユーザ入力: \")\n", 47 | " if user_input == \"exit\":\n", 48 | " break\n", 49 | " # 2 HumanMessage の作成と表示\n", 50 | " human_message = HumanMessage(user_input)\n", 51 | " human_message.pretty_print()\n", 52 | " # 3 会話履歴の追加\n", 53 | " history.append(HumanMessage(user_input))\n", 54 | " # 4 応答の作成と表示\n", 55 | " ai_message = llm.invoke(history)\n", 56 | " ai_message.pretty_print()\n", 57 | " # 5 会話履歴の追加\n", 58 | " history.append(ai_message)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "# 2.2.3 翻訳アプリ" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "from langchain_openai.chat_models import ChatOpenAI\n", 75 | "from langchain_core.prompts import PromptTemplate\n", 76 | "\n", 77 | "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 78 | "\n", 79 | "# 1 テンプレートの作成\n", 80 | "TRANSLATION_PROMPT = \"\"\"\\\n", 81 | "以下の文章を {language} に翻訳し、翻訳結果のみを返してください。\n", 82 | "{source_text}\n", 83 | "\"\"\"\n", 84 | "prompt = PromptTemplate.from_template(TRANSLATION_PROMPT)\n", 85 | "\n", 86 | "# 2 Runnable の作成\n", 87 | "runnable = prompt | llm\n", 88 | "\n", 89 | "language = \"日本語\"\n", 90 | "source_text 
= \"\"\"\\\n", 91 | "cogito, ergo sum\n", 92 | "\"\"\"\n", 93 | "\n", 94 | "# 3 Runnable の実行と結果の表示\n", 95 | "response = runnable.invoke(dict(language=language, source_text=source_text))\n", 96 | "response.pretty_print()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "# 2.2.4 テーブル作成アプリ" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "from langchain_core.tools import tool\n", 113 | "from langchain_core.pydantic_v1 import BaseModel, Field\n", 114 | "from langchain_openai.chat_models import ChatOpenAI\n", 115 | "import csv\n", 116 | "\n", 117 | "\n", 118 | "# 1 入力形式の定義\n", 119 | "class CSVSaveToolInput(BaseModel):\n", 120 | " filename: str = Field(description=\"ファイル名\")\n", 121 | " csv_text: str = Field(description=\"CSVのテキスト\")\n", 122 | "\n", 123 | "\n", 124 | "# 2 ツール本体の定義\n", 125 | "@tool(\"csv-save-tool\", args_schema=CSVSaveToolInput)\n", 126 | "def csv_save(filename: str, csv_text: str) -> bool:\n", 127 | " \"\"\"CSV テキストをファイルに保存する\"\"\"\n", 128 | " # parse CSV text\n", 129 | " try:\n", 130 | " rows = list(csv.reader(csv_text.splitlines()))\n", 131 | " except Exception as e:\n", 132 | " return False\n", 133 | "\n", 134 | " # save to file\n", 135 | " with open(filename, \"w\") as f:\n", 136 | " writer = csv.writer(f)\n", 137 | " writer.writerows(rows)\n", 138 | "\n", 139 | " return True" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "# 3 ツールを LLM に紐づける\n", 149 | "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 150 | "tools = [csv_save]\n", 151 | "llm_with_tool = llm.bind_tools(tools=tools, tool_choice=\"csv-save-tool\")\n", 152 | "\n", 153 | "TABLE_PROMPT = \"\"\"\\\n", 154 | "{user_input}\n", 155 | "\n", 156 | "結果は CSV ファイルに保存してください。ただし、ファイル名は上記の内容から適切に決定してください。\n", 157 | "\"\"\"\n", 158 | "prompt = 
PromptTemplate.from_template(TABLE_PROMPT)\n", 159 | "get_tool_args = lambda x: x.tool_calls[0]\n", 160 | "\n", 161 | "# 4 Runnable の作成\n", 162 | "runnable = prompt | llm_with_tool | get_tool_args | csv_save\n", 163 | "\n", 164 | "user_input = \"フィボナッチ数列の番号と値を10番目まで表にまとめて、CSV ファイルに保存してください。\"\n", 165 | "\n", 166 | "# 5 Runnable の実行と結果の確認\n", 167 | "response = runnable.invoke(dict(user_input=user_input))\n", 168 | "print(response)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "# 2.2.5 Plan-and-Solve" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "from langchain_openai import ChatOpenAI\n", 185 | "\n", 186 | "llm = ChatOpenAI(model=\"gpt-4o-mini\")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "from pydantic import BaseModel, Field\n", 196 | "\n", 197 | "\n", 198 | "# ツール入力形式の定義\n", 199 | "class ActionItem(BaseModel):\n", 200 | " action_name: str = Field(description=\"アクション名\")\n", 201 | " action_description: str = Field(description=\"アクションの詳細\")\n", 202 | "\n", 203 | "\n", 204 | "class Plan(BaseModel):\n", 205 | " \"\"\"アクションプランを格納する\"\"\"\n", 206 | "\n", 207 | " problem: str = Field(description=\"問題の説明\")\n", 208 | " actions: list[ActionItem] = Field(description=\"実行すべきアクションリスト\")\n", 209 | "\n", 210 | "\n", 211 | "class ActionResult(BaseModel):\n", 212 | " \"\"\"実行時の考えと結果を格納する\"\"\"\n", 213 | "\n", 214 | " thoughts: str = Field(description=\"検討内容\")\n", 215 | " result: str = Field(description=\"結果\")" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# 単一アクションの実行\n", 225 | "from langchain_openai.output_parsers.tools import PydanticToolsParser\n", 226 | "from langchain_core.prompts import 
PromptTemplate\n", 227 | "\n", 228 | "\n", 229 | "ACTION_PROMPT = \"\"\"\\\n", 230 | "問題をアクションプランに分解して解いています。\n", 231 | "これまでのアクションの結果と、次に行うべきアクションを示すので、実際にアクションを実行してその結果を報告してください。\n", 232 | "# 問題\n", 233 | "{problem}\n", 234 | "# アクションプラン\n", 235 | "{action_items}\n", 236 | "# これまでのアクションの結果\n", 237 | "{action_results}\n", 238 | "# 次のアクション\n", 239 | "{next_action}\n", 240 | "\"\"\"\n", 241 | "\n", 242 | "llm_action = llm.bind_tools([ActionResult], tool_choice=\"ActionResult\")\n", 243 | "action_parser = PydanticToolsParser(tools=[ActionResult], first_tool_only=True)\n", 244 | "plan_parser = PydanticToolsParser(tools=[Plan], first_tool_only=True)\n", 245 | "\n", 246 | "action_prompt = PromptTemplate.from_template(ACTION_PROMPT)\n", 247 | "action_runnable = action_prompt | llm_action | action_parser" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "# プランに含まれるアクションを実行するループ\n", 257 | "from langchain_core.messages import AIMessage\n", 258 | "\n", 259 | "\n", 260 | "def action_loop(action_plan: Plan):\n", 261 | " problem = action_plan.problem\n", 262 | " actions = action_plan.actions\n", 263 | "\n", 264 | " action_items = \"\\n\".join([\"* \" + action.action_name for action in actions])\n", 265 | " action_results = []\n", 266 | " action_results_str = \"\"\n", 267 | " for i, action in enumerate(actions):\n", 268 | " print(\"=\" * 20)\n", 269 | " print(f\"[{i+1}/{len(actions)}]以下のアクションを実行します。\")\n", 270 | " print(action.action_name)\n", 271 | "\n", 272 | " next_action = f\"* {action.action_name} \\n{action.action_description}\"\n", 273 | " response = action_runnable.invoke(\n", 274 | " dict(\n", 275 | " problem=problem,\n", 276 | " action_items=action_items,\n", 277 | " action_results=action_results_str,\n", 278 | " next_action=next_action,\n", 279 | " )\n", 280 | " )\n", 281 | " action_results.append(response)\n", 282 | " action_results_str += f\"* {action.action_name} 
\\n{response.result}\\n\"\n", 283 | " print(\"-\" * 10 + \"検討内容\" + \"-\" * 10)\n", 284 | " print(response.thoughts)\n", 285 | " print(\"-\" * 10 + \"結果\" + \"-\" * 10)\n", 286 | " print(response.result)\n", 287 | "\n", 288 | " return AIMessage(action_results_str)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "# Plan-and-Solve を行うか否かの分岐\n", 298 | "def route(ai_message: AIMessage):\n", 299 | " if ai_message.response_metadata[\"finish_reason\"] == \"tool_calls\":\n", 300 | " return plan_parser | action_loop\n", 301 | " else:\n", 302 | " return ai_message" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "# 全体を通した Runnable 作成\n", 312 | "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n", 313 | "from langchain_core.messages import SystemMessage\n", 314 | "\n", 315 | "PLAN_AND_SOLVE_PROMPT = \"\"\"\\\n", 316 | "ユーザの質問が複雑な場合は、アクションプランを作成し、その後に1つずつ実行する Plan-and-Solve 形式をとります。\n", 317 | "これが必要と判断した場合は、Plan ツールによってアクションプランを保存してください。\n", 318 | "\"\"\"\n", 319 | "system_prompt = SystemMessage(PLAN_AND_SOLVE_PROMPT)\n", 320 | "chat_prompt = ChatPromptTemplate.from_messages(\n", 321 | " [system_prompt, MessagesPlaceholder(variable_name=\"history\")]\n", 322 | ")\n", 323 | "\n", 324 | "llm_plan = llm.bind_tools(tools=[Plan])\n", 325 | "planning_runnable = chat_prompt | llm_plan | route" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "# チャット部分の作成\n", 335 | "history = []\n", 336 | "n = 10\n", 337 | "for i in range(10):\n", 338 | " user_input = input(\"ユーザ入力: \")\n", 339 | " if user_input == \"exit\":\n", 340 | " break\n", 341 | " # 1 HumanMessage の作成と表示\n", 342 | " human_message = HumanMessage(user_input)\n", 343 | " human_message.pretty_print()\n", 
344 | " # 2 会話履歴の追加\n", 345 | " history.append(HumanMessage(user_input))\n", 346 | " # 3 応答の作成と表示\n", 347 | " ai_message = planning_runnable.invoke(dict(history=history))\n", 348 | " ai_message.pretty_print()\n", 349 | " # 4 会話履歴の追加\n", 350 | " history.append(ai_message)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "# 出力例\n", 360 | "\n", 361 | "# ================================ Human Message =================================\n", 362 | "#\n", 363 | "# ある製造工場では、1時間に200個の部品が生産されます。工場は1日8時間稼働し、1週間に5日間営業しています。生産された部品のうち5%は品質不良で廃棄されます。この工場では1ヶ月(4週間)に品質不良で廃棄される部品の総数を求めなさい。\n", 364 | "# ====================\n", 365 | "# [1/4]以下のアクションを実行します。\n", 366 | "# 部品の1日の生産量を求める\n", 367 | "# ----------検討内容----------\n", 368 | "# 部品の1日の生産量を求めた結果、1日の生産量は1600個である。\n", 369 | "# ----------結果----------\n", 370 | "# 部品の1日の生産量は1600個である。\n", 371 | "# ====================\n", 372 | "# [2/4]以下のアクションを実行します。\n", 373 | "# 部品の1週間の生産量を求める\n", 374 | "# ----------検討内容----------\n", 375 | "# 1週間の生産量を求めた。1日の生産量1600個に5営業日を掛けて計算した結果、8000個となった。この結果を記録する。\n", 376 | "# ----------結果----------\n", 377 | "# 部品の1週間の生産量は8000個である。\n", 378 | "# ====================\n", 379 | "# [3/4]以下のアクションを実行します。\n", 380 | "# 部品の1ヶ月の生産量を求める\n", 381 | "# ----------検討内容----------\n", 382 | "# 部品の1ヶ月の生産量は8000個 × 4週間 = 32000個と計算した。\n", 383 | "# ----------結果----------\n", 384 | "# 部品の1ヶ月の生産量は32000個である。\n", 385 | "# ====================\n", 386 | "# [4/4]以下のアクションを実行します。\n", 387 | "# 品質不良で廃棄される部品の数を求める\n", 388 | "# ----------検討内容----------\n", 389 | "# 品質不良で廃棄される部品の数を求めるために、1ヶ月の生産量32000個に5%を掛け算して1600個を算出する。\n", 390 | "# ----------結果----------\n", 391 | "# 品質不良で廃棄される部品の数は1600個である。\n", 392 | "# ================================== Ai Message ==================================\n", 393 | "#\n", 394 | "# * 部品の1日の生産量を求める\n", 395 | "# 部品の1日の生産量は1600個である。\n", 396 | "# * 部品の1週間の生産量を求める\n", 397 | "# 部品の1週間の生産量は8000個である。\n", 398 | "# * 
部品の1ヶ月の生産量を求める\n", 399 | "# 部品の1ヶ月の生産量は32000個である。\n", 400 | "# * 品質不良で廃棄される部品の数を求める\n", 401 | "# 品質不良で廃棄される部品の数は1600個である。" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [] 410 | } 411 | ], 412 | "metadata": { 413 | "kernelspec": { 414 | "display_name": ".venv", 415 | "language": "python", 416 | "name": "python3" 417 | }, 418 | "language_info": { 419 | "codemirror_mode": { 420 | "name": "ipython", 421 | "version": 3 422 | }, 423 | "file_extension": ".py", 424 | "mimetype": "text/x-python", 425 | "name": "python", 426 | "nbconvert_exporter": "python", 427 | "pygments_lexer": "ipython3", 428 | "version": "3.11.3" 429 | } 430 | }, 431 | "nbformat": 4, 432 | "nbformat_minor": 2 433 | } 434 | -------------------------------------------------------------------------------- /notebooks/chapter3/05_persona.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 3.5章 ペルソナのあるエージェント" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "colab": { 15 | "base_uri": "https://localhost:8080/" 16 | }, 17 | "id": "TCVITKYgV-og", 18 | "outputId": "b01863f3-1731-42e4-a8d0-f22257c2e21a" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install langchain\n", 23 | "!pip install langchain-openai\n", 24 | "\n", 25 | "!pip install serpapi\n", 26 | "!pip install google-search-results\n", 27 | "\n", 28 | "# load_toolsを利用するのに必要\n", 29 | "!pip install langchain_community" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "l-8zAJ6if-Jx" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "import os\n", 41 | "from google.colab import userdata\n", 42 | "\n", 43 | "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n", 44 | "os.environ['SERPAPI_API_KEY'] = 
userdata.get('SERPAPI_API_KEY')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "id": "HEhyVk_rcYDz" 51 | }, 52 | "source": [ 53 | "# 3.5\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "c59aBMgmojMn" 60 | }, 61 | "source": [ 62 | "## 3.5.2 ペルソナ付与のためのプロンプト技術" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "NYNy63_UxwEG", 73 | "outputId": "d1feaa24-7ad9-4663-e493-53af7f8d6fa0" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "from langchain_openai import ChatOpenAI\n", 78 | "from langchain_core.prompts import ChatPromptTemplate\n", 79 | "\n", 80 | "# プロンプトテンプレートの作成\n", 81 | "message = \"\"\"\n", 82 | "以下の質問に答えてください。\n", 83 | "\n", 84 | "{question}\n", 85 | "\"\"\"\n", 86 | "\n", 87 | "prompt = ChatPromptTemplate.from_messages([(\"human\",\\\n", 88 | "message)])\n", 89 | "\n", 90 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 91 | "chain = prompt | model\n", 92 | "\n", 93 | "question_text = \"LLMエージェントについて教えてください。\"\n", 94 | "\n", 95 | "\n", 96 | "response = chain.invoke({\"question\": question_text})\n", 97 | "print(response.content)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "colab": { 105 | "base_uri": "https://localhost:8080/" 106 | }, 107 | "id": "3C1D7Dxqxv9j", 108 | "outputId": "af3692da-c5a6-46e0-b924-e4e3511863b4" 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "# プロンプトテンプレートの作成\n", 113 | "message = \"\"\"\n", 114 | "あなたは「えりすちゃん」というキャラクターです。\n", 115 | "えりすちゃんは以下のような特徴のキャラクターです。\n", 116 | "- 株式会社Elithのマスコット\n", 117 | "- ペガサスの見た目をしている\n", 118 | "- 人懐っこい性格で、誰にでも優しく接する\n", 119 | "- ポジティブな性格で励ましの言葉を常に意識している\n", 120 | "- 「~エリ!」というのが口癖\n", 121 | " - 例:「今日も頑張るエリ!」\n", 122 | "\n", 123 | "「えりすちゃん」として以下の質問に答えてください。\n", 124 | "\n", 125 | "{question}\n", 126 | "\"\"\"\n", 127 | "\n", 128 | "prompt = 
ChatPromptTemplate.from_messages([(\"human\", message)])\n", 129 | "\n", 130 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 131 | "chain = prompt | model\n", 132 | "\n", 133 | "question_text = \"LLMエージェントについて教えてください。\"\n", 134 | "\n", 135 | "\n", 136 | "response = chain.invoke({\"question\": question_text})\n", 137 | "print(response.content)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "rjTfI549qa8p" 144 | }, 145 | "source": [ 146 | "## 3.5.3 ペルソナ付与のためのメモリ技術" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "colab": { 154 | "base_uri": "https://localhost:8080/" 155 | }, 156 | "id": "NafsFbAgca2a", 157 | "outputId": "5a3a4179-bc18-46c4-e5a7-a14d5ce08f6e" 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "!pip install mem0ai" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "id": "SfDfZi8-o2hv" 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "os.environ['MEM0_API_KEY'] = userdata.get('MEM0_API_KEY')" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "colab": { 180 | "base_uri": "https://localhost:8080/" 181 | }, 182 | "id": "xeFTL4pRfpfI", 183 | "outputId": "da930076-c4d8-4ba0-a332-045bd7b5c83a" 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "from mem0 import MemoryClient\n", 188 | "\n", 189 | "# Mem0 クライアントの初期化\n", 190 | "client = MemoryClient(api_key=os.environ['MEM0_API_KEY'])\n", 191 | "\n", 192 | "# 特定のユーザの全てのメモリを削除\n", 193 | "client.delete_all(user_id=\"elith_chan\")\n", 194 | "\n", 195 | "# 削除後、ユーザのメモリを取得\n", 196 | "user_memories = client.get_all(user_id=\"elith_chan\")\n", 197 | "print(user_memories)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": { 204 | "colab": { 205 | "base_uri": "https://localhost:8080/" 206 | }, 207 | "id": "KmkrFIxUxmSu", 208 | "outputId": 
"26c1d1d2-1842-4109-e8d3-0a2c82e72c1e" 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "# プロンプトテンプレートの作成\n", 213 | "message = \"\"\"\n", 214 | "Translate the following text into {language}.\n", 215 | "\n", 216 | "text:\n", 217 | "{text}\n", 218 | "\"\"\"\n", 219 | "\n", 220 | "prompt = ChatPromptTemplate.from_messages([(\"human\",\\\n", 221 | "message)])\n", 222 | "\n", 223 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 224 | "translate_chain = prompt | model\n", 225 | "\n", 226 | "text = \"私の名前は「えりすちゃん」です。私は、AI系スタートアップの\\\n", 227 | "Elithに所属しています。私はElithを象徴するキャラクターとして、知識と優\\\n", 228 | "しさを兼ね備えた存在です。\"\n", 229 | "language = \"English\"\n", 230 | "\n", 231 | "response = translate_chain.invoke({\"text\": text, \"language\": language})\n", 232 | "print(response.content)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": { 239 | "colab": { 240 | "base_uri": "https://localhost:8080/" 241 | }, 242 | "id": "6TVwuh7fg-fF", 243 | "outputId": "caab5951-eb8f-4c21-9258-c52af7298221" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "text = \"私の名前は「えりすちゃん」です。私は、AI系スタートアップの\\\n", 248 | "Elithに所属しています。私はElithを象徴するキャラクターとして、知識と優\\\n", 249 | "しさを兼ね備えた存在です。Elithのことを世の中に発信することが私の仕事で\\\n", 250 | "す。\"\n", 251 | "language = \"English\"\n", 252 | "\n", 253 | "text_en = translate_chain.invoke({\"text\": text, \"language\"\\\n", 254 | ": language}).content\n", 255 | "\n", 256 | "messages = [\n", 257 | " {\"role\": \"user\", \"content\":text_en},\n", 258 | "]\n", 259 | "client.add(messages, user_id=\"elith_chan\")" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "colab": { 267 | "base_uri": "https://localhost:8080/" 268 | }, 269 | "id": "DPjHH636h4-7", 270 | "outputId": "3dad6aa8-5525-4b6d-b4ef-21f31fa6b3f0" 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "user_memories = client.get_all(user_id=\"elith_chan\")\n", 275 | "print(user_memories)" 276 | ] 277 | },
278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": { 282 | "colab": { 283 | "base_uri": "https://localhost:8080/" 284 | }, 285 | "id": "6oJYvJrR0vpp", 286 | "outputId": "e10b70f4-6cfe-46aa-eb66-efad0e306860" 287 | }, 288 | "outputs": [], 289 | "source": [ 290 | "query_ja = \"あなたのお仕事は何ですか?\"\n", 291 | "language = \"English\"\n", 292 | "\n", 293 | "query_en = translate_chain.invoke({\"text\": query_ja, \"language\": language}).content\n", 294 | "client.search(query_en, user_id=\"elith_chan\")" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": { 300 | "id": "Mhio2drszHmS" 301 | }, 302 | "source": [ 303 | "## 3.5.4 mem0 を用いたエージェント作成" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": { 310 | "colab": { 311 | "base_uri": "https://localhost:8080/" 312 | }, 313 | "id": "3ZQD5BpD3VlA", 314 | "outputId": "0b8c0085-86ac-49bf-f0d9-18a098591ffa" 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "from langchain_core.prompts import PromptTemplate\n", 319 | "\n", 320 | "input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools']\n", 321 | "template=\"\"\"\\\n", 322 | "Answer the following questions as best you can. You have access to the following tools:\n", 323 | "\n", 324 | "{tools}\n", 325 | "\n", 326 | "Use the following format:\n", 327 | "\n", 328 | "Question: the input question you must answer\n", 329 | "Thought: you should always think about what to do\n", 330 | "Action: the action to take, should be one of [{tool_names}]\n", 331 | "Action Input: the input to the action\n", 332 | "Observation: the result of the action\n", 333 | "... 
(this Thought/Action/Action Input/Observation can repeat N times)\n", 334 | "Thought: I now know the final answer\n", 335 | "Final Answer: the final answer to the original input question\n", 336 | "\n", 337 | "Begin!\n", 338 | "\n", 339 | "Previous conversation history: {chat_history}\n", 340 | "Question: {input}\n", 341 | "Thought:{agent_scratchpad}\"\"\"\n", 342 | "\n", 343 | "prompt = PromptTemplate(input_variables=input_variables, template=template)\n", 344 | "print(prompt.template)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "colab": { 352 | "base_uri": "https://localhost:8080/" 353 | }, 354 | "id": "xycCm1vi2-CG", 355 | "outputId": "b69f4ed5-caaf-4e97-9fc7-f54d06500836" 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "from langchain.agents import load_tools\n", 360 | "from langchain.agents import AgentExecutor, create_react_agent\n", 361 | "\n", 362 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 363 | "tools = load_tools([\"serpapi\"], llm=model)\n", 364 | "agent = create_react_agent(model, tools, prompt)\n", 365 | "agent_executor = AgentExecutor(agent=agent, tools=tools,\\\n", 366 | "verbose=True, handle_parsing_errors=True)\n", 367 | "\n", 368 | "query_ja = \"あなたのお仕事は何ですか?\"\n", 369 | "language = \"English\"\n", 370 | "\n", 371 | "query_en = translate_chain.invoke({\"text\": query_ja, \"language\"\\\n", 372 | ": language}).content\n", 373 | "memory = client.search(query_en, user_id=\"elith_chan\")\n", 374 | "\n", 375 | "\n", 376 | "response = agent_executor.invoke({\"input\": query_ja,\\\n", 377 | "'chat_history':memory},)\n", 378 | "\n", 379 | "print(response[\"output\"])" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": { 386 | "colab": { 387 | "base_uri": "https://localhost:8080/" 388 | }, 389 | "id": "1bG11f2PTfpZ", 390 | "outputId": "d2bd5c6e-f3c5-41b4-982c-c5523fc7abee" 391 | }, 392 | "outputs": [], 393 | "source": [ 394 | 
"text = \"私、えりすちゃんは「〜エリ!」という語尾を使います。「今日も頑張る\\\n", 395 | "エリ!」が口癖です。\"\n", 396 | "language = \"English\"\n", 397 | "\n", 398 | "text_en = translate_chain.invoke({\"text\": text, \"language\"\\\n", 399 | ": language}).content\n", 400 | "\n", 401 | "messages = [\n", 402 | " {\"role\": \"user\", \"content\":text_en},\n", 403 | "]\n", 404 | "client.add(messages, user_id=\"elith_chan\")" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "metadata": { 411 | "colab": { 412 | "base_uri": "https://localhost:8080/" 413 | }, 414 | "id": "nVQs-m-_T3fU", 415 | "outputId": "763bb0a2-5662-4e11-95ec-f6043f78ac31" 416 | }, 417 | "outputs": [], 418 | "source": [ 419 | "from langchain.agents import load_tools\n", 420 | "from langchain.agents import AgentExecutor, create_react_agent\n", 421 | "\n", 422 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 423 | "tools = load_tools([\"serpapi\"], llm=model)\n", 424 | "agent = create_react_agent(model, tools, prompt)\n", 425 | "agent_executor = AgentExecutor(agent=agent, tools=tools,\\\n", 426 | "verbose=True, handle_parsing_errors=True)\n", 427 | "\n", 428 | "query_ja = \"あなたのお仕事は何ですか?\"\n", 429 | "language = \"English\"\n", 430 | "\n", 431 | "query_en = translate_chain.invoke({\"text\": query_ja, \"language\"\\\n", 432 | ": language}).content\n", 433 | "memory = client.search(query_en, user_id=\"elith_chan\")\n", 434 | "\n", 435 | "\n", 436 | "response = agent_executor.invoke({\"input\": query_ja,\\\n", 437 | "'chat_history':memory},)\n", 438 | "\n", 439 | "print(response[\"output\"])" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "metadata": { 446 | "colab": { 447 | "base_uri": "https://localhost:8080/" 448 | }, 449 | "id": "IO4hA5_O6uNy", 450 | "outputId": "dbfe7172-c225-4e0c-9b11-2f5288a348da" 451 | }, 452 | "outputs": [], 453 | "source": [ 454 | "# プロンプト定義\n", 455 | "\n", 456 | "from langchain_core.prompts import PromptTemplate\n", 457 | "\n", 458 | 
"input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools']\n", 459 | "\n", 460 | "\n", 461 | "template=\"\"\"\\\n", 462 | "あなたは「えりすちゃん」です。\n", 463 | "えりすちゃんは、AI系スタートアップのElithを象徴するキャラクターとして、知識と優しさを兼ね備えた存在です。\n", 464 | "えりすちゃんは「〜エリ!」という語尾を使います。\n", 465 | "例:「一緒に頑張るエリ!」\n", 466 | "\n", 467 | "えりすちゃんとして、以下の質問に最善を尽くして答えてください。\n", 468 | "\n", 469 | "You have access to the following tools:\n", 470 | "\n", 471 | "{tools}\n", 472 | "\n", 473 | "Use the following format:\n", 474 | "\n", 475 | "Question: the input question you must answer\n", 476 | "Thought: you should always think about what to do\n", 477 | "Action: the action to take, should be one of [{tool_names}]\n", 478 | "Action Input: the input to the action\n", 479 | "Observation: the result of the action\n", 480 | "... (this Thought/Action/Action Input/Observation can repeat N times)\n", 481 | "Thought: I now know the final answer\n", 482 | "Final Answer: the final answer to the original input question\n", 483 | "\n", 484 | "Begin!\n", 485 | "\n", 486 | "Previous conversation history: {chat_history}\n", 487 | "Question: {input}\n", 488 | "Thought:{agent_scratchpad}\"\"\"\n", 489 | "\n", 490 | "\n", 491 | "prompt = PromptTemplate(input_variables=input_variables, template=template)\n", 492 | "print(prompt)" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "metadata": { 499 | "colab": { 500 | "base_uri": "https://localhost:8080/" 501 | }, 502 | "id": "10b0WOGSFu0q", 503 | "outputId": "b07657f9-8d60-44b3-f1da-9d50ac932d14" 504 | }, 505 | "outputs": [], 506 | "source": [ 507 | "from langchain.agents import load_tools\n", 508 | "from langchain.agents import AgentExecutor, create_react_agent\n", 509 | "\n", 510 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 511 | "tools = load_tools([\"serpapi\"], llm=model)\n", 512 | "agent = create_react_agent(model, tools, prompt)\n", 513 | "agent_executor = AgentExecutor(agent=agent, tools=tools,\\\n", 514 | "verbose=True, 
handle_parsing_errors=True)\n", 515 | "\n", 516 | "query_ja = \"あなたのお仕事は何ですか?\"\n", 517 | "language = \"English\"\n", 518 | "\n", 519 | "query_en = translate_chain.invoke({\"text\": query_ja, \"language\"\\\n", 520 | ": language}).content\n", 521 | "memory = client.search(query_en, user_id=\"elith_chan\")\n", 522 | "\n", 523 | "\n", 524 | "response = agent_executor.invoke({\"input\": query_ja,\\\n", 525 | "'chat_history':memory},)\n", 526 | "\n", 527 | "print(response[\"output\"])" 528 | ] 529 | } 530 | ], 531 | "metadata": { 532 | "colab": { 533 | "provenance": [] 534 | }, 535 | "kernelspec": { 536 | "display_name": "Python 3", 537 | "name": "python3" 538 | }, 539 | "language_info": { 540 | "name": "python" 541 | } 542 | }, 543 | "nbformat": 4, 544 | "nbformat_minor": 0 545 | } 546 | -------------------------------------------------------------------------------- /notebooks/chapter2/01_openai_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 準備" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import getpass\n", 17 | "import os\n", 18 | "\n", 19 | "# OpenAI API キーの設定\n", 20 | "api_key = getpass.getpass(\"OpenAI API キーを入力してください: \")\n", 21 | "os.environ[\"OPENAI_API_KEY\"] = api_key" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# 2.1 OpenAI API" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## 2.1.1 テキスト生成の基礎" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# テキスト生成の基本的な流れ\n", 45 | "\n", 46 | "from openai import OpenAI\n", 47 | "\n", 48 | "client = OpenAI()\n", 49 | "\n", 50 | "response = client.chat.completions.create(\n", 51 | " 
temperature=0.0,\n", 52 | " model=\"gpt-4o-mini\",\n", 53 | " messages=[{\"role\": \"user\", \"content\": \"こんにちは\"}],\n", 54 | ")\n", 55 | "\n", 56 | "print(response.choices[0].message.content)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# シンプルな対話 AI の作成\n", 66 | "\n", 67 | "history = []\n", 68 | "n = 10 # 会話の上限\n", 69 | "model = \"gpt-4o-mini\"\n", 70 | "for _ in range(n):\n", 71 | " user_input = input(\"ユーザ入力: \")\n", 72 | " if user_input == \"exit\":\n", 73 | " break\n", 74 | " print(f\"ユーザ: {user_input}\")\n", 75 | " history.append({\"role\": \"user\", \"content\": user_input})\n", 76 | " response = client.chat.completions.create(model=model, messages=history)\n", 77 | " content = response.choices[0].message.content\n", 78 | " print(f\"AI: {content}\")\n", 79 | " history.append({\"role\": \"assistant\", \"content\": content})" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "## 2.1.2 テキスト生成の応用" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### Stream Generation" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# 生成できた部分から順に表示する\n", 103 | "\n", 104 | "history = []\n", 105 | "n = 10 # 会話の上限\n", 106 | "model = \"gpt-4o-mini\"\n", 107 | "for _ in range(n):\n", 108 | " user_input = input(\"ユーザ入力: \")\n", 109 | " if user_input == \"exit\":\n", 110 | " break\n", 111 | " print(f\"ユーザ: {user_input}\")\n", 112 | " history.append({\"role\": \"user\", \"content\": user_input})\n", 113 | " # stream=True でストリーミングを有効化\n", 114 | " stream = client.chat.completions.create(model=model, messages=history, stream=True)\n", 115 | " print(\"AI: \", end=\"\")\n", 116 | " # 応答を集める文字列\n", 117 | " ai_content = \"\"\n", 118 | " # ストリーミングの各チャンクを処理\n", 119 | " for chunk in stream:\n", 120 | " # message ではなく 
ChoiceDelta\n", 121 | " content = chunk.choices[0].delta.content\n", 122 | " # ChoiceDelta の finish_reason が stop なら生成完了\n", 123 | " if chunk.choices[0].finish_reason == \"stop\":\n", 124 | " break\n", 125 | " print(content, end=\"\")\n", 126 | " ai_content += content\n", 127 | " print()\n", 128 | " history.append({\"role\": \"assistant\", \"content\": ai_content})" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### Function Calling" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# 最大公約数を求めるツールの利用\n", 145 | "\n", 146 | "gcd_function = {\n", 147 | " \"name\": \"gcd\",\n", 148 | " \"description\": \"最大公約数を求める\",\n", 149 | " \"parameters\": {\n", 150 | " \"type\": \"object\",\n", 151 | " \"properties\": {\n", 152 | " \"num1\": {\"type\": \"number\", \"description\": \"整数1\"},\n", 153 | " \"num2\": {\"type\": \"number\", \"description\": \"整数2\"},\n", 154 | " },\n", 155 | " \"required\": [\"num1\", \"num2\"],\n", 156 | " },\n", 157 | "}\n", 158 | "tools = [{\"type\": \"function\", \"function\": gcd_function}]\n", 159 | "\n", 160 | "messages = [\n", 161 | " {\"role\": \"user\", \"content\": \"50141 と 53599 の最大公約数を求めてください。\"}\n", 162 | "]\n", 163 | "\n", 164 | "response = client.chat.completions.create(\n", 165 | " model=\"gpt-4o-mini\", messages=messages, tools=tools\n", 166 | ")\n", 167 | "print(response.choices[0].message.content) # None\n", 168 | "print(response.choices[0].finish_reason) # tool_calls\n", 169 | "print(response.choices[0].message.tool_calls) # [ChatCompletionMessageToolCall(...)]\n" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 18, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "# 関数情報を抽出\n", 179 | "\n", 180 | "import json\n", 181 | "\n", 182 | "function_info = response.choices[0].message.tool_calls[0].function\n", 183 | "name = 
function_info.name\n", 184 | "args = json.loads(function_info.arguments)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "# 最大公約数の計算\n", 194 | "\n", 195 | "import math\n", 196 | "\n", 197 | "print(math.gcd(args[\"num1\"], args[\"num2\"]))" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 20, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "# Pydantic を用いた関数の定義\n", 207 | "\n", 208 | "from pydantic import BaseModel, Field\n", 209 | "\n", 210 | "\n", 211 | "class GCD(BaseModel):\n", 212 | " num1: int = Field(description=\"整数1\")\n", 213 | " num2: int = Field(description=\"整数2\")\n", 214 | "\n", 215 | "\n", 216 | "gcd_function = {\n", 217 | " \"name\": \"gcd\",\n", 218 | " \"description\": \"最大公約数を求める\",\n", 219 | " \"parameters\": GCD.model_json_schema(),\n", 220 | "}" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 21, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "tools = [{\"type\": \"function\", \"function\": gcd_function}]\n", 230 | "\n", 231 | "messages = [\n", 232 | " {\"role\": \"user\", \"content\": \"50141 と 53599 の最大公約数を求めてください。\"}\n", 233 | "]\n", 234 | "\n", 235 | "response = client.chat.completions.create(\n", 236 | " model=\"gpt-4o-mini\", messages=messages, tools=tools\n", 237 | ")" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "# Pydantic を用いた引数の取得\n", 247 | "\n", 248 | "parsed_result = GCD.model_validate_json(\n", 249 | " response.choices[0].message.tool_calls[0].function.arguments\n", 250 | ")\n", 251 | "print(parsed_result)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "# ツール利用全体の流れ\n", 261 | "\n", 262 | "\n", 263 | "class LCM(BaseModel):\n", 264 
| " num1: int = Field(description=\"整数1\")\n", 265 | " num2: int = Field(description=\"整数2\")\n", 266 | "\n", 267 | "\n", 268 | "lcm_function = {\n", 269 | " \"name\": \"lcm\",\n", 270 | " \"description\": \"最小公倍数を求める\",\n", 271 | " \"parameters\": LCM.model_json_schema(),\n", 272 | "}\n", 273 | "\n", 274 | "tools = [\n", 275 | " {\"type\": \"function\", \"function\": gcd_function},\n", 276 | " {\"type\": \"function\", \"function\": lcm_function},\n", 277 | "]\n", 278 | "\n", 279 | "messages = [\n", 280 | " {\n", 281 | " \"role\": \"user\",\n", 282 | " \"content\": \"50141 と 53599 の最大公約数と最小公倍数を求めてください。\",\n", 283 | " }\n", 284 | "]\n", 285 | "\n", 286 | "response = client.chat.completions.create(\n", 287 | " model=\"gpt-4o-mini\", messages=messages, tools=tools\n", 288 | ")\n", 289 | "choice = response.choices[0]\n", 290 | "if choice.finish_reason == \"tool_calls\":\n", 291 | " for tool in choice.message.tool_calls:\n", 292 | " if tool.function.name == \"gcd\":\n", 293 | " gcd_args = GCD.model_validate_json(tool.function.arguments)\n", 294 | " print(f\"最大公約数: {math.gcd(gcd_args.num1, gcd_args.num2)}\")\n", 295 | " elif tool.function.name == \"lcm\":\n", 296 | " lcm_args = LCM.model_validate_json(tool.function.arguments)\n", 297 | " print(f\"最小公倍数: {math.lcm(lcm_args.num1, lcm_args.num2)}\")\n", 298 | "elif choice.finish_reason == \"stop\":\n", 299 | " print(\"AI: \", choice.message.content)" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "### response_format" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "# response_format の利用例\n", 316 | "\n", 317 | "\n", 318 | "class Translations(BaseModel):\n", 319 | " english: str = Field(description=\"英語の文章\")\n", 320 | " french: str = Field(description=\"フランス語の文章\")\n", 321 | " chinese: str = Field(description=\"中国語の文章\")\n", 322 | "\n", 323 | "\n", 324 | "prompt = 
f\"\"\"\\\n", 325 | "以下に示す文章を英語・フランス語・中国語に翻訳してください。\n", 326 | "ただし、アウトプットは後述するフォーマットの JSON 形式で出力してください。\n", 327 | "\n", 328 | "# 文章\n", 329 | "吾輩は猫である。名前はまだない。\n", 330 | "\n", 331 | "# 出力フォーマット\n", 332 | "以下に JSON Schema 形式のフォーマットを示します。このフォーマットに従うオブジェクトの形で出力してください。\n", 333 | "{Translations.model_json_schema()}\n", 334 | "\"\"\"\n", 335 | "\n", 336 | "response = client.chat.completions.create(\n", 337 | " temperature=0.0,\n", 338 | " model=\"gpt-4o-mini\",\n", 339 | " messages=[{\"role\": \"user\", \"content\": prompt}],\n", 340 | " response_format={\"type\": \"json_object\"},\n", 341 | ")\n", 342 | "\n", 343 | "translations = Translations.model_validate_json(response.choices[0].message.content)\n", 344 | "print(\"英語:\", translations.english)\n", 345 | "print(\"フランス語:\", translations.french)\n", 346 | "print(\"中国語:\", translations.chinese)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "prompt = \"\"\"\\\n", 356 | "以下に示す文章を英語・フランス語・中国語に翻訳してください。\n", 357 | "ただし、アウトプットは後述するフォーマットの JSON 形式で出力してください。\n", 358 | "\n", 359 | "# 文章\n", 360 | "吾輩は猫である。名前はまだない。\n", 361 | "\n", 362 | "# 出力フォーマット\n", 363 | "JSON Schema に従う形式で出力してください。\n", 364 | "\"\"\"\n", 365 | "\n", 366 | "response = client.beta.chat.completions.parse(\n", 367 | " temperature=0.0,\n", 368 | " model=\"gpt-4o-mini\",\n", 369 | " messages=[{\"role\": \"user\", \"content\": prompt}],\n", 370 | " response_format=Translations,\n", 371 | ")\n", 372 | "translations = response.choices[0].message.parsed\n", 373 | "\n", 374 | "print(\"英語:\", translations.english)\n", 375 | "print(\"フランス語:\", translations.french)\n", 376 | "print(\"中国語:\", translations.chinese)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "# 2.1.3 画像を入力する" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": 
[ 392 | "import base64\n", 393 | "from pathlib import Path\n", 394 | "from typing import Any\n", 395 | "\n", 396 | "from openai import OpenAI\n", 397 | "\n", 398 | "client = OpenAI()\n", 399 | "\n", 400 | "\n", 401 | "def image2content(image_path: Path) -> dict[str, Any]:\n", 402 | " # base64 エンコード\n", 403 | " with image_path.open(\"rb\") as f:\n", 404 | " image_base64 = base64.b64encode(f.read()).decode(\"utf-8\")\n", 405 | "\n", 406 | " # content の作成\n", 407 | " content = {\n", 408 | " \"type\": \"image_url\",\n", 409 | " \"image_url\": {\"url\": f\"data:image/png;base64,{image_base64}\", \"detail\": \"low\"},\n", 410 | " }\n", 411 | " return content" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "prompt = \"これは何の画像ですか?\"\n", 421 | "image_path = Path(\"./sample_image1.png\")\n", 422 | "contents = [{\"type\": \"text\", \"text\": prompt}, image2content(image_path)]\n", 423 | "\n", 424 | "response = client.chat.completions.create(\n", 425 | " model=\"gpt-4o-mini\",\n", 426 | " temperature=0.0,\n", 427 | " messages=[{\"role\": \"user\", \"content\": contents}],\n", 428 | ")\n", 429 | "\n", 430 | "print(response.choices[0].message.content)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "image_path2 = Path(\"./sample_image2.png\")\n", 440 | "\n", 441 | "prompt = \"2枚の画像の違いを教えてください。\"\n", 442 | "contents = [\n", 443 | " {\"type\": \"text\", \"text\": prompt},\n", 444 | " image2content(image_path),\n", 445 | " image2content(image_path2),\n", 446 | "]\n", 447 | "response = client.chat.completions.create(\n", 448 | " model=\"gpt-4o-mini\",\n", 449 | " temperature=0.0,\n", 450 | " messages=[{\"role\": \"user\", \"content\": contents}],\n", 451 | ")\n", 452 | "\n", 453 | "print(response.choices[0].message.content)" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 
458 | "metadata": {}, 459 | "source": [ 460 | "# 2.1.4 音声を扱う" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": {}, 466 | "source": [ 467 | "音声を文字起こしする。" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "from pathlib import Path\n", 477 | "\n", 478 | "from openai import OpenAI\n", 479 | "\n", 480 | "\n", 481 | "client = OpenAI()\n", 482 | "audio_path = Path(\"./sample_audio.mp3\")\n", 483 | "\n", 484 | "with audio_path.open(\"rb\") as f:\n", 485 | " transcription = client.audio.transcriptions.create(\n", 486 | " model=\"whisper-1\", file=f, temperature=0.0\n", 487 | " )\n", 488 | "print(transcription.text)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "プロンプトを用いて文字起こしする。" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "prompt = \"下垣内\"\n", 505 | "\n", 506 | "with audio_path.open(\"rb\") as f:\n", 507 | " transcription = client.audio.transcriptions.create(\n", 508 | " model=\"whisper-1\",\n", 509 | " file=f,\n", 510 | " prompt=prompt,\n", 511 | " response_format=\"text\",\n", 512 | " temperature=0.0,\n", 513 | " )\n", 514 | "print(transcription)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": {}, 520 | "source": [ 521 | "音声を合成する。" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | "audio_output_path = Path(\"output.mp3\")\n", 531 | "with client.audio.speech.with_streaming_response.create(\n", 532 | " model=\"tts-1\",\n", 533 | " voice=\"alloy\",\n", 534 | " input=\"こんにちは。私は AI アシスタントです!\",\n", 535 | ") as response:\n", 536 | " response.stream_to_file(audio_output_path)" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": {}, 542 | "source": [ 
543 | "# 2.1.5 画像を生成する" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "from openai import OpenAI\n", 553 | "import requests\n", 554 | "\n", 555 | "client = OpenAI()\n", 556 | "\n", 557 | "prompt = \"\"\"\\\n", 558 | "メタリックな球体\n", 559 | "\"\"\"\n", 560 | "\n", 561 | "response = client.images.generate(\n", 562 | " model=\"dall-e-3\", prompt=prompt, n=1, size=\"1024x1024\"\n", 563 | ")\n", 564 | "\n", 565 | "image_url = response.data[0].url\n", 566 | "image = requests.get(image_url).content\n", 567 | "with open(\"output1.png\", \"wb\") as f:\n", 568 | " f.write(image)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": {}, 575 | "outputs": [], 576 | "source": [ 577 | "response = client.images.generate(\n", 578 | " model=\"dall-e-3\", prompt=prompt, n=1, size=\"1024x1024\", response_format=\"b64_json\"\n", 579 | ")\n", 580 | "\n", 581 | "image = response.data[0].b64_json\n", 582 | "\n", 583 | "with open(\"output2.png\", \"wb\") as f:\n", 584 | " f.write(base64.b64decode(image))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "print(response.data[0].revised_prompt)" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": null, 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [] 602 | } 603 | ], 604 | "metadata": { 605 | "kernelspec": { 606 | "display_name": ".venv", 607 | "language": "python", 608 | "name": "python3" 609 | }, 610 | "language_info": { 611 | "codemirror_mode": { 612 | "name": "ipython", 613 | "version": 3 614 | }, 615 | "file_extension": ".py", 616 | "mimetype": "text/x-python", 617 | "name": "python", 618 | "nbconvert_exporter": "python", 619 | "pygments_lexer": "ipython3", 620 | "version": "3.11.3" 621 | } 622 | }, 623 | "nbformat": 4, 624 | "nbformat_minor": 2 625 | } 
626 | -------------------------------------------------------------------------------- /notebooks/chapter3/02_tools.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "IekT2f3zZXSq" 7 | }, 8 | "source": [ 9 | "# 3.2章 LLM にツールを与える" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "akxteNI1JM7i", 20 | "outputId": "85ff2ec1-6b33-4538-84d9-3b5b6716067d" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "!pip install langchain\n", 25 | "!pip install langchain-openai" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "id": "oWpcctbXuU6u" 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "import os\n", 37 | "from google.colab import userdata\n", 38 | "\n", 39 | "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n", 40 | "os.environ['SERPAPI_API_KEY'] = userdata.get('SERPAPI_\\\n", 41 | "API_KEY')" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": { 47 | "id": "WPvPr1gGZ4_S" 48 | }, 49 | "source": [ 50 | "## 3.2.1 検索ツール" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "colab": { 58 | "base_uri": "https://localhost:8080/" 59 | }, 60 | "id": "TSDvWNv7TMPO", 61 | "outputId": "0368e9ff-779a-4bdc-d4a7-e5134ac33a36" 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "from langchain_openai import ChatOpenAI\n", 66 | "from langchain.schema import HumanMessage\n", 67 | "\n", 68 | "question = \"株式会社Elithの住所を教えてください。最新の公式情報として\\\n", 69 | "公開されているものを教えてください。\"\n", 70 | "\n", 71 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 72 | "result = model.invoke([HumanMessage(content=question)])\n", 73 | "\n", 74 | "print(result.content)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 
| "metadata": { 81 | "colab": { 82 | "base_uri": "https://localhost:8080/" 83 | }, 84 | "id": "fiwKwICdU3jz", 85 | "outputId": "15ba7bc6-9ae8-4ccf-912c-c48b69653be6" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "!pip install serpapi\n", 90 | "!pip install google-search-results" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "colab": { 98 | "base_uri": "https://localhost:8080/" 99 | }, 100 | "id": "vrTTNmINeDQr", 101 | "outputId": "2ec28ca2-77de-451d-9742-ebb35eb3ee4c" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "# load_toolsを利用するのに必要\n", 106 | "!pip install langchain_community" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "colab": { 114 | "base_uri": "https://localhost:8080/" 115 | }, 116 | "id": "8ORWNe8lTTCv", 117 | "outputId": "d75185df-51da-4569-c57e-a985d634fc0f" 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "from langchain.agents import load_tools\n", 122 | "\n", 123 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 124 | "\n", 125 | "#1 ツールをロード\n", 126 | "tools = load_tools([\"serpapi\"], llm=model)\n", 127 | "\n", 128 | "#2 LLMにツールを紐付け\n", 129 | "model_with_tools = model.bind_tools(tools)\n", 130 | "\n", 131 | "question = \"株式会社Elithの住所を教えてください。最新の公式情報として\\\n", 132 | "公開されているものを教えてください。\"\n", 133 | "\n", 134 | "response = model_with_tools.invoke([HumanMessage(content=\\\n", 135 | "question)])\n", 136 | "\n", 137 | "print(f\"ContentString: {response.content}\")\n", 138 | "print(f\"ToolCalls: {response.tool_calls}\")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "colab": { 146 | "base_uri": "https://localhost:8080/", 147 | "height": 164 148 | }, 149 | "id": "IAjdBqnATfW8", 150 | "outputId": "df5c0005-d40b-4152-e7e3-8baca3b4bdb2" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "tools = load_tools([\"serpapi\"], llm=model)\n", 155 | "search_tool = 
tools[0]\n", 156 | "search_tool.invoke(response.tool_calls[0][\"args\"])" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": { 162 | "id": "WmTQGidwaYfc" 163 | }, 164 | "source": [ 165 | "## 3.2.2 プログラム実行ツール" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "colab": { 173 | "base_uri": "https://localhost:8080/" 174 | }, 175 | "id": "K9K0NYnokt4i", 176 | "outputId": "261aca54-9af3-40bf-d2f6-55dc1949eb63" 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "question = \"以下をPythonで実行した場合の結果を教えてください。print\\\n", 181 | "(1873648+9285928+3759182+2398597)\"\n", 182 | "\n", 183 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 184 | "result = model.invoke([HumanMessage(content=question)])\n", 185 | "\n", 186 | "print(result.content)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "colab": { 194 | "base_uri": "https://localhost:8080/" 195 | }, 196 | "id": "yPHB1WtUlPrn", 197 | "outputId": "f9f08b99-dde2-42aa-8d3e-d5994185e6ea" 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "print(1873648+9285928+3759182+2398597)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": { 208 | "colab": { 209 | "base_uri": "https://localhost:8080/" 210 | }, 211 | "id": "NW0KKgK7jncx", 212 | "outputId": "1068c13d-013c-4323-a200-8f2ec059b0b2" 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "!pip install langchain_experimental" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "colab": { 224 | "base_uri": "https://localhost:8080/" 225 | }, 226 | "id": "e91GFj4mjddI", 227 | "outputId": "31233bb4-cb6d-46fb-f32d-bccfa70683ae" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "from langchain_experimental.tools.python.tool import\\\n", 232 | "PythonREPLTool\n", 233 | "\n", 234 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 235 | "tools = 
[PythonREPLTool()]\n", 236 | "model_with_tools = model.bind_tools(tools)\n", 237 | "\n", 238 | "question = \"以下をPythonで実行した場合の結果を教えてください。print\\\n", 239 | "(1873648+9285928+3759182+2398597)\"\n", 240 | "\n", 241 | "response = model_with_tools.invoke([HumanMessage(content=\\\n", 242 | "question)])\n", 243 | "\n", 244 | "print(f\"ContentString: {response.content}\")\n", 245 | "print(f\"ToolCalls: {response.tool_calls}\")" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "colab": { 253 | "base_uri": "https://localhost:8080/", 254 | "height": 54 255 | }, 256 | "id": "uW2zoV4QkiIa", 257 | "outputId": "9caf695d-1e7b-4206-eb55-dcb6e2598d54" 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "pythonrepltool = PythonREPLTool()\n", 262 | "pythonrepltool.invoke(response.tool_calls[0][\"args\"])" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": { 268 | "id": "V5x-zrmobPS6" 269 | }, 270 | "source": [ 271 | "## 3.2.3 ツールを自作する" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": { 278 | "colab": { 279 | "base_uri": "https://localhost:8080/", 280 | "height": 35 281 | }, 282 | "id": "Jf1QurbrbcSk", 283 | "outputId": "aa581735-03f3-4b4f-b6e0-7e09d0b1808c" 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "# おみくじ関数\n", 288 | "\n", 289 | "import random\n", 290 | "from datetime import datetime\n", 291 | "\n", 292 | "def get_fortune(date_string):\n", 293 | " # 日付文字列を解析\n", 294 | " try:\n", 295 | " date = datetime.strptime(date_string, \"%m月%d日\")\n", 296 | " except ValueError:\n", 297 | " return \"無効な日付形式です。'X月X日'の形式で入力してくださ\\\n", 298 | "い。\"\n", 299 | "\n", 300 | " # 運勢のリスト\n", 301 | " fortunes = [\n", 302 | " \"大吉\", \"中吉\", \"小吉\", \"吉\", \"末吉\", \"凶\", \"大凶\"\n", 303 | " ]\n", 304 | "\n", 305 | " # 運勢の重み付け(大吉と大凶の確率を低くする)\n", 306 | " weights = [1, 3, 3, 4, 3, 2, 1]\n", 307 | "\n", 308 | " # 日付に基づいてシードを設定(同じ日付なら同じ運勢を返す)\n", 309 | " 
random.seed(date.month * 100 + date.day)\n", 310 | "\n", 311 | " # 運勢をランダムに選択\n", 312 | " fortune = random.choices(fortunes, weights=weights)[0]\n", 313 | "\n", 314 | " return f\"{date_string}の運勢は【{fortune}】です。\"\n", 315 | "\n", 316 | "# 出力例\n", 317 | "get_fortune(\"10月22日\")" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": { 324 | "id": "D3G7mUfOdUVO" 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "from langchain.tools import BaseTool\n", 329 | "\n", 330 | "#1 ツールの定義\n", 331 | "class Get_fortune(BaseTool):\n", 332 | " name: str = 'Get_fortune'\n", 333 | " description: str = (\n", 334 | " \"特定の日付の運勢を占う。インプットは 'date_string'です。\\\n", 335 | "'date_string' は、占いを行う日付で、mm月dd日 という形式です。「1月1日」\\\n", 336 | "のように入力し、「'1月1日'」のように余計な文字列を付けてはいけません。\"\n", 337 | " )\n", 338 | "\n", 339 | " def _run(self, date_string) -> str:\n", 340 | " return get_fortune(date_string)\n", 341 | "\n", 342 | "\n", 343 | " async def _arun(self, query: str) -> str:\n", 344 | " raise NotImplementedError(\"does not support async\")" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "colab": { 352 | "base_uri": "https://localhost:8080/" 353 | }, 354 | "id": "E-wz2idMd96E", 355 | "outputId": "721b7a70-6961-4438-b8eb-08aaebafdc9d" 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "tools = [Get_fortune()]\n", 360 | "\n", 361 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 362 | "model_with_tools = model.bind_tools(tools)\n", 363 | "\n", 364 | "question = \"10月22日の運勢を教えてください。\"\n", 365 | "\n", 366 | "response = model_with_tools.invoke([HumanMessage(content=\\\n", 367 | "question)])\n", 368 | "\n", 369 | "print(f\"ContentString: {response.content}\")\n", 370 | "print(f\"ToolCalls: {response.tool_calls}\")" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": { 377 | "colab": { 378 | "base_uri": "https://localhost:8080/" 379 | }, 380 | "id": 
"q4f_qEFOee51", 381 | "outputId": "e320f18b-4690-4546-89ed-92781036b900" 382 | }, 383 | "outputs": [], 384 | "source": [ 385 | "tool = Get_fortune()\n", 386 | "tool.invoke(response.tool_calls[0])" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": { 393 | "colab": { 394 | "base_uri": "https://localhost:8080/" 395 | }, 396 | "id": "ci5RimKbwLak", 397 | "outputId": "35efab9d-a3de-473a-91fe-6ae9b7f60997" 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "tools = [Get_fortune()]\n", 402 | "\n", 403 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 404 | "model_with_tools = model.bind_tools(tools)\n", 405 | "\n", 406 | "question = \"今日の運勢を教えてください。\"\n", 407 | "\n", 408 | "response = model_with_tools.invoke([HumanMessage(content=\\\n", 409 | "question)])\n", 410 | "\n", 411 | "print(f\"ContentString: {response.content}\")\n", 412 | "print(f\"ToolCalls: {response.tool_calls}\")" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": { 419 | "colab": { 420 | "base_uri": "https://localhost:8080/" 421 | }, 422 | "id": "xRF1G72iedPF", 423 | "outputId": "4e30eb86-978b-4869-ce88-0ed17611936b" 424 | }, 425 | "outputs": [], 426 | "source": [ 427 | "from datetime import timedelta\n", 428 | "from zoneinfo import ZoneInfo\n", 429 | "\n", 430 | "# 日付取得関数\n", 431 | "\n", 432 | "def get_date(date):\n", 433 | " date_now = datetime.now(ZoneInfo(\"Asia/Tokyo\"))\n", 434 | " if (\"今日\" in date):\n", 435 | " date_delta = 0\n", 436 | " elif (\"明日\" in date):\n", 437 | " date_delta = 1\n", 438 | " elif (\"明後日\" in date):\n", 439 | " date_delta = 2\n", 440 | " else:\n", 441 | " return \"サポートしていません\"\n", 442 | " return (date_now + timedelta(days=date_delta)).strftime\\\n", 443 | "('%m月%d日')\n", 444 | "\n", 445 | "# 出力例\n", 446 | "print(get_date(\"今日\"))" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "metadata": { 453 | "id": "h0iQgMS8eLih" 454 | }, 455 | 
"outputs": [], 456 | "source": [ 457 | "class Get_date(BaseTool):\n", 458 | " name: str = 'Get_date'\n", 459 | " description: str = (\n", 460 | " \"今日の日付を取得する。インプットは 'date'です。'date' は、日\\\n", 461 | "付を取得する対象の日で、'今日', '明日', '明後日' という3種類の文字列\\\n", 462 | "から指定します。「今日」のように入力し、「'今日'」のように余計な文字列を付\\\n", 463 | "けてはいけません。\"\n", 464 | " )\n", 465 | "\n", 466 | " def _run(self, date) -> str:\n", 467 | " return get_date(date)\n", 468 | "\n", 469 | " async def _arun(self, query: str) -> str:\n", 470 | " raise NotImplementedError(\"does not support async\")" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "metadata": { 477 | "colab": { 478 | "base_uri": "https://localhost:8080/" 479 | }, 480 | "id": "rJ_rseXRfyu9", 481 | "outputId": "5176665b-f645-4b8e-bb50-8c6ab1847925" 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "tools = [Get_date()]\n", 486 | "\n", 487 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 488 | "model_with_tools = model.bind_tools(tools)\n", 489 | "\n", 490 | "question = \"今日の日付を教えてください。。\"\n", 491 | "\n", 492 | "response = model_with_tools.invoke([HumanMessage(content=\\\n", 493 | "question)])\n", 494 | "\n", 495 | "print(f\"ContentString: {response.content}\")\n", 496 | "print(f\"ToolCalls: {response.tool_calls}\")" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "metadata": { 503 | "colab": { 504 | "base_uri": "https://localhost:8080/" 505 | }, 506 | "id": "DD5MmTjOf77F", 507 | "outputId": "35f8115e-b2aa-4219-df7c-68e938a63019" 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "tool = Get_date()\n", 512 | "tool.invoke(response.tool_calls[0])" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "metadata": { 519 | "colab": { 520 | "base_uri": "https://localhost:8080/" 521 | }, 522 | "id": "Bj46Qgv5rBP0", 523 | "outputId": "2c568f9f-f393-4301-da93-b4ab85ca5a9c" 524 | }, 525 | "outputs": [], 526 | "source": [ 527 | "tools = 
[Get_fortune(), Get_date()]\n", 528 | "\n", 529 | "model = ChatOpenAI(model=\"gpt-4o-mini\")\n", 530 | "model_with_tools = model.bind_tools(tools)\n", 531 | "\n", 532 | "question = \"今日の運勢を教えてください。。\"\n", 533 | "\n", 534 | "response = model_with_tools.invoke([HumanMessage(content=question)])\n", 535 | "\n", 536 | "print(f\"ContentString: {response.content}\")\n", 537 | "print(f\"ToolCalls: {response.tool_calls}\")" 538 | ] 539 | }, 540 | { 541 | "cell_type": "markdown", 542 | "metadata": { 543 | "id": "x3rfQs0rfHx8" 544 | }, 545 | "source": [ 546 | "## appendix" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "metadata": { 553 | "colab": { 554 | "base_uri": "https://localhost:8080/" 555 | }, 556 | "id": "eMJ7jxZ3fGcI", 557 | "outputId": "1661b771-83d2-4580-f833-f3e9929595d3" 558 | }, 559 | "outputs": [], 560 | "source": [ 561 | "# ツール定義(別パターン)\n", 562 | "\n", 563 | "from langchain_core.tools import Tool\n", 564 | "\n", 565 | "get_date_tool = Tool(\n", 566 | " name=\"Get_date\",\n", 567 | " description =\"今日の日付を取得する。インプットは 'date'です。'date' は、日付を取得する対象の日で、'今日', '明日', '明後日' という3種類の文字列から指定します。今日の日付を知りたい際は'今日'を入力します\",\n", 568 | " func=get_date\n", 569 | ")\n", 570 | "get_fortune_tool = Tool(\n", 571 | " name=\"Get_fortune\",\n", 572 | " description = \"特定の日付の運勢を占う。インプットは 'date_string'です。'date_string' は、占いを行う日付で、mm月dd日 という形式です。1月1日の占いを行う際は'1月1日'を入力します\",\n", 573 | " func=get_fortune\n", 574 | ")\n", 575 | "\n", 576 | "# 出力例\n", 577 | "print(get_date_tool.invoke(\"今日\"))\n", 578 | "print(get_fortune_tool.invoke(\"10月23日\"))" 579 | ] 580 | } 581 | ], 582 | "metadata": { 583 | "colab": { 584 | "provenance": [] 585 | }, 586 | "kernelspec": { 587 | "display_name": "Python 3", 588 | "name": "python3" 589 | }, 590 | "language_info": { 591 | "name": "python" 592 | } 593 | }, 594 | "nbformat": 4, 595 | "nbformat_minor": 0 596 | } 597 | -------------------------------------------------------------------------------- 
/notebooks/chapter2/03_gradio_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 準備" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import getpass\n", 17 | "import os\n", 18 | "\n", 19 | "import gradio as gr\n", 20 | "\n", 21 | "# OpenAI API キーの設定\n", 22 | "api_key = getpass.getpass(\"OpenAI API キーを入力してください: \")\n", 23 | "os.environ[\"OPENAI_API_KEY\"] = api_key" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Gradio の基礎" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## 簡単なインターフェースの実装" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def text2text(text):\n", 47 | " text = \"<<\" + text + \">>\"\n", 48 | " return text" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "input_text = gr.Text(label=\"入力\")\n", 58 | "output_text = gr.Text(label=\"出力\")\n", 59 | "\n", 60 | "demo = gr.Interface(inputs=input_text, outputs=output_text, fn=text2text)\n", 61 | "demo.launch(debug=True)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## ブロックの実装" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def text2text_rich(text):\n", 78 | " top = \"^\" * len(text)\n", 79 | " bottom = \"v\" * len(text)\n", 80 | " text = f\" {top}\\n<{text}>\\n {bottom}\"\n", 81 | " return text" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "with gr.Blocks() as 
demo:\n", 91 | " input_text = gr.Text(label=\"入力\")\n", 92 | " button1 = gr.Button(value=\"Normal\")\n", 93 | " button2 = gr.Button(value=\"Rich\")\n", 94 | " output_text = gr.Text(label=\"出力\")\n", 95 | "\n", 96 | " button1.click(inputs=input_text, outputs=output_text, fn=text2text)\n", 97 | " button2.click(inputs=input_text, outputs=output_text, fn=text2text_rich)\n", 98 | "demo.launch()" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "## 重要なコンポーネント" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "def audio_upload(audio):\n", 115 | " return audio\n", 116 | "\n", 117 | "\n", 118 | "with gr.Blocks() as demo:\n", 119 | " # Audio\n", 120 | " audio = gr.Audio(label=\"音声\", type=\"filepath\")\n", 121 | " # Checkbox\n", 122 | " checkbox = gr.Checkbox(label=\"チェックボックス\")\n", 123 | " # File\n", 124 | " file = gr.File(label=\"ファイル\", file_types=[\"image\"])\n", 125 | " # Number\n", 126 | " number = gr.Number(label=\"数値\")\n", 127 | " # Markdown\n", 128 | " markdown = gr.Markdown(label=\"Markdown\", value=\"# タイトル\\n## サブタイトル\\n本文\")\n", 129 | " # Slider\n", 130 | " slider = gr.Slider(\n", 131 | " label=\"スライダー\", minimum=-10, maximum=10, step=0.5, interactive=True\n", 132 | " )\n", 133 | " # Textbox\n", 134 | " textbox = gr.Textbox(label=\"テキストボックス\")\n", 135 | "\n", 136 | "demo.launch(height=1200)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "## UI の工夫" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "with gr.Blocks() as demo:\n", 153 | " # Accordion\n", 154 | " with gr.Accordion(label=\"アコーディオン\"):\n", 155 | " gr.Text(value=\"アコーディオンの中身\")\n", 156 | " with gr.Row():\n", 157 | " gr.Text(value=\"左\")\n", 158 | " gr.Text(value=\"右\")\n", 159 | "\n", 160 | " with 
gr.Row():\n", 161 | " with gr.Column():\n", 162 | " gr.Text(value=\"(0, 0)\")\n", 163 | " gr.Text(value=\"(1, 0)\")\n", 164 | " with gr.Column():\n", 165 | " gr.Text(value=\"(0, 1)\")\n", 166 | " gr.Text(value=\"(1, 1)\")\n", 167 | "\n", 168 | " with gr.Tab(label=\"タブ1\"):\n", 169 | " gr.Text(value=\"コンテンツ1\")\n", 170 | " with gr.Tab(label=\"タブ2\"):\n", 171 | " gr.Text(value=\"コンテンツ2\")\n", 172 | "\n", 173 | "demo.launch(height=800)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "with gr.Blocks() as demo:\n", 183 | " slider = gr.Slider(label=\"個数\", minimum=0, maximum=10, step=1)\n", 184 | "\n", 185 | " @gr.render(inputs=slider)\n", 186 | " def render_blocks(value):\n", 187 | " for i in range(value):\n", 188 | " gr.Text(value=f\"Block {i}\")\n", 189 | "\n", 190 | "\n", 191 | "demo.launch()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "import time\n", 201 | "\n", 202 | "\n", 203 | "def iterative_output():\n", 204 | " for i in range(10):\n", 205 | " time.sleep(0.5)\n", 206 | " yield str(i)\n", 207 | "\n", 208 | "\n", 209 | "with gr.Blocks() as demo:\n", 210 | " button = gr.Button(\"実行\")\n", 211 | " output = gr.Text(label=\"出力\")\n", 212 | " button.click(outputs=output, fn=iterative_output)\n", 213 | "\n", 214 | "demo.launch()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "## 状態を保持する" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "with gr.Blocks() as demo:\n", 231 | " username = gr.State(\"\")\n", 232 | " text_input = gr.Text(label=\"ユーザ名\")\n", 233 | " button1 = gr.Button(\"決定\")\n", 234 | " button2 = gr.Button(\"自分の名前を表示\")\n", 235 | " text_output = gr.Text(label=\"出力\")\n", 236 | " 
button1.click(inputs=text_input, outputs=username, fn=lambda x: x)\n", 237 |     "    button2.click(inputs=username, outputs=text_output, fn=lambda x: x)\n", 238 |     "\n", 239 |     "demo.launch()" 240 |    ] 241 |   }, 242 |   { 243 |    "cell_type": "markdown", 244 |    "metadata": {}, 245 |    "source": [ 246 |     "## チャット UI を作る" 247 |    ] 248 |   }, 249 |   { 250 |    "cell_type": "code", 251 |    "execution_count": null, 252 |    "metadata": {}, 253 |    "outputs": [], 254 |    "source": [ 255 |     "from langchain_openai.chat_models import ChatOpenAI\n", 256 |     "\n", 257 |     "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 258 |     "\n", 259 |     "\n", 260 |     "def history2messages(history):\n", 261 |     "    messages = []\n", 262 |     "    for user, assistant in history:\n", 263 |     "        messages.append({\"role\": \"user\", \"content\": user})\n", 264 |     "        messages.append({\"role\": \"assistant\", \"content\": assistant})\n", 265 |     "    return messages\n", 266 |     "\n", 267 |     "\n", 268 |     "def chat(message, history):\n", 269 |     "    messages = history2messages(history)\n", 270 |     "    messages.append({\"role\": \"user\", \"content\": message})\n", 271 |     "    response = llm.invoke(messages)\n", 272 |     "    return response.content\n", 273 |     "\n", 274 |     "\n", 275 |     "demo = gr.ChatInterface(chat)\n", 276 |     "\n", 277 |     "demo.launch()" 278 |    ] 279 |   }, 280 |   { 281 |    "cell_type": "markdown", 282 |    "metadata": {}, 283 |    "source": [ 284 |     "## Stream チャットボット" 285 |    ] 286 |   }, 287 |   { 288 |    "cell_type": "code", 289 |    "execution_count": null, 290 |    "metadata": {}, 291 |    "outputs": [], 292 |    "source": [ 293 |     "from langchain_openai.chat_models import ChatOpenAI\n", 294 |     "\n", 295 |     "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 296 |     "\n", 297 |     "\n", 298 |     "def chat(message, history):\n", 299 |     "    messages = history2messages(history)\n", 300 |     "    messages.append({\"role\": \"user\", \"content\": message})\n", 301 |     "    output = \"\"\n", 302 |     "    for chunk in llm.stream(messages):\n", 303 |     "        output += chunk.content\n", 304 |     "        yield output\n", 305 |     "\n", 306 |     "\n", 307 |     "demo = 
gr.ChatInterface(chat)\n", 308 | "\n", 309 | "demo.launch(debug=True)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "# Gradio の応用" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "## 翻訳アプリケーション" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 13, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "# 2.2.3 の Runnable\n", 333 | "from langchain_core.prompts import PromptTemplate\n", 334 | "from langchain_openai.chat_models import ChatOpenAI\n", 335 | "\n", 336 | "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 337 | "\n", 338 | "# 1 テンプレートの作成\n", 339 | "TRANSLATION_PROMPT = \"\"\"\\\n", 340 | "以下の文章を {language} に翻訳し、翻訳結果のみを返してください。\n", 341 | "{source_text}\n", 342 | "\"\"\"\n", 343 | "prompt = PromptTemplate.from_template(TRANSLATION_PROMPT)\n", 344 | "\n", 345 | "# 2 Runnable の作成\n", 346 | "runnable = prompt | llm" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "languages = [\"日本語\", \"英語\", \"中国語\", \"ラテン語\", \"ギリシャ語\"]\n", 356 | "\n", 357 | "\n", 358 | "def translate(source_text, language):\n", 359 | " # 3 Runnable の実行\n", 360 | " response = runnable.invoke(dict(source_text=source_text, language=language))\n", 361 | " return response.content\n", 362 | "\n", 363 | "\n", 364 | "with gr.Blocks() as demo:\n", 365 | " # 入力テキスト\n", 366 | " source_text = gr.Textbox(label=\"翻訳元の文章\")\n", 367 | " # 言語を選択\n", 368 | " language = gr.Dropdown(label=\"言語\", choices=languages)\n", 369 | " button = gr.Button(\"翻訳\")\n", 370 | " # 出力テキスト\n", 371 | " translated_text = gr.Textbox(label=\"翻訳結果\")\n", 372 | "\n", 373 | " button.click(inputs=[source_text, language], outputs=translated_text, fn=translate)\n", 374 | "\n", 375 | "demo.launch()" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 
# 2.2.4 の Runnable — chain that asks the LLM for CSV output and converts it to JSON.
import csv

import pandas as pd
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_openai.chat_models import ChatOpenAI


# 1 Input schema for the tool.
class CSV2DFToolInput(BaseModel):
    csv_text: str = Field(description="CSVのテキスト")


# 2 Tool body: converts CSV text to a JSON string.
#   FIX: the original annotated the return type as `pd.DataFrame` and its
#   docstring claimed a DataFrame result, but the function returns
#   `df.to_json()`, which is a `str` (JSON).
@tool("csv2json-tool", args_schema=CSV2DFToolInput, return_direct=True)
def csv2json(csv_text: str) -> str:
    """CSV テキストを JSON 文字列に変換する"""
    try:
        rows = list(csv.reader(csv_text.splitlines()))
        # First row is the header, remaining rows are the data.
        df = pd.DataFrame(rows[1:], columns=rows[0])
    except Exception:
        # Malformed/empty CSV: fall back to an empty table instead of failing.
        df = pd.DataFrame()
    return df.to_json()


# 3 Bind the tool to the LLM and force the model to call it.
llm = ChatOpenAI(model="gpt-4o-mini")
tools = [csv2json]
llm_with_tool = llm.bind_tools(tools=tools, tool_choice="csv2json-tool")

# Prompt that asks the model to produce CSV and convert it via the tool.
TABLE_PROMPT = """\
{user_input}
結果は CSV で作成し、csv2json-tool を利用して json に変換してください。
"""
prompt = PromptTemplate.from_template(TABLE_PROMPT)


# 4 Build the Runnable chain.
def get_tool_args(x):
    # Extract the first ToolCall from the AIMessage so it can be fed to the tool.
    return x.tool_calls[0]


runnable = prompt | llm_with_tool | get_tool_args | csv2json
from io import StringIO


def create_df(user_input):
    """Run the table-generating chain and parse its JSON output into a DataFrame."""
    response = runnable.invoke(dict(user_input=user_input))
    json_str = response.content
    # FIX: wrap the literal JSON string in StringIO — passing a plain string
    # to pd.read_json is deprecated since pandas 2.1; the result is identical.
    df = pd.read_json(StringIO(json_str))
    return df


with gr.Blocks() as demo:
    # Input textbox describing the desired table.
    user_input = gr.Textbox(label="テーブルを作成したい内容のテキスト")
    button = gr.Button("実行")
    # Output rendered as an interactive table.
    output_table = gr.DataFrame()

    button.click(inputs=user_input, outputs=output_table, fn=create_df)

demo.launch(height=1000)
# Single-action execution runnable.
from langchain_openai.output_parsers.tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate


ACTION_PROMPT = """\
問題をアクションプランに分解して解いています。
これまでのアクションの結果と、次に行うべきアクションを示すので、実際にアクションを実行してその結果を報告してください。
# 問題
{problem}
# アクションプラン
{action_items}
# これまでのアクションの結果
{action_results}
# 次のアクション
{next_action}
"""

# Force the LLM to report via the ActionResult schema, and parse it back out.
llm_action = llm.bind_tools([ActionResult], tool_choice="ActionResult")
action_parser = PydanticToolsParser(tools=[ActionResult], first_tool_only=True)
plan_parser = PydanticToolsParser(tools=[Plan], first_tool_only=True)

action_prompt = PromptTemplate.from_template(ACTION_PROMPT)
action_runnable = action_prompt | llm_action | action_parser


# Loop that executes each action in the plan, streaming intermediate results.
def action_loop(action_plan: Plan):
    """Execute every action in `action_plan`, yielding (thoughts, result) per step."""
    problem = action_plan.problem
    actions = action_plan.actions

    action_items = "\n".join(["* " + action.action_name for action in actions])
    action_results_str = ""
    # FIX: the index from enumerate() was unused, and the `action_results`
    # list was appended to but never read — both removed.
    for action in actions:
        next_action = f"* {action.action_name} \n{action.action_description}"
        response = action_runnable.invoke(
            dict(
                problem=problem,
                action_items=action_items,
                action_results=action_results_str,
                next_action=next_action,
            )
        )
        # Accumulate the textual results for the next prompt.
        action_results_str += f"* {action.action_name} \n{response.result}\n"
        yield (
            response.thoughts,
            response.result,
        )  # yield intermediate results so the UI can stream progress
= plan_parser.invoke(response)\n", 624 | "\n", 625 | " # アクション名を表示\n", 626 | " action_items = \"\\n\".join(\n", 627 | " [\"* \" + action.action_name for action in action_plan.actions]\n", 628 | " )\n", 629 | " messages.append(\n", 630 | " ChatMessage(\n", 631 | " role=\"assistant\",\n", 632 | " content=action_items,\n", 633 | " metadata={\"title\": \"実行されるアクション\"},\n", 634 | " )\n", 635 | " )\n", 636 | " # プランの段階で一度描画する\n", 637 | " yield \"\", messages, history\n", 638 | "\n", 639 | " # アクションプランを実行\n", 640 | " action_results_str = \"\"\n", 641 | " for i, (thoughts, result) in enumerate(action_loop(action_plan)):\n", 642 | " action_name = action_plan.actions[i].action_name\n", 643 | " action_results_str += f\"* {action_name} \\n{result}\\n\"\n", 644 | " text = f\"## {action_name}\\n### 思考過程\\n{thoughts}\\n### 結果\\n{result}\"\n", 645 | " messages.append(ChatMessage(role=\"assistant\", content=text))\n", 646 | " # 実行結果を描画する\n", 647 | " yield \"\", messages, history\n", 648 | "\n", 649 | " history.append(AIMessage(content=action_results_str))\n", 650 | " # LangChain 用の履歴を更新する\n", 651 | " yield \"\", messages, history" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "metadata": {}, 658 | "outputs": [], 659 | "source": [ 660 | "with gr.Blocks() as demo:\n", 661 | " chatbot = gr.Chatbot(label=\"Assistant\", type=\"messages\", height=800)\n", 662 | " history = gr.State([])\n", 663 | " with gr.Row():\n", 664 | " with gr.Column(scale=9):\n", 665 | " user_input = gr.Textbox(lines=1, label=\"Chat Message\")\n", 666 | " with gr.Column(scale=1):\n", 667 | " submit = gr.Button(\"Submit\")\n", 668 | " clear = gr.ClearButton([user_input, chatbot, history])\n", 669 | " submit.click(\n", 670 | " chat,\n", 671 | " inputs=[user_input, chatbot, history],\n", 672 | " outputs=[user_input, chatbot, history],\n", 673 | " )\n", 674 | "demo.launch(height=1000)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": null, 680 | 
"metadata": {}, 681 | "outputs": [], 682 | "source": [] 683 | } 684 | ], 685 | "metadata": { 686 | "kernelspec": { 687 | "display_name": ".venv", 688 | "language": "python", 689 | "name": "python3" 690 | }, 691 | "language_info": { 692 | "codemirror_mode": { 693 | "name": "ipython", 694 | "version": 3 695 | }, 696 | "file_extension": ".py", 697 | "mimetype": "text/x-python", 698 | "name": "python", 699 | "nbconvert_exporter": "python", 700 | "pygments_lexer": "ipython3", 701 | "version": "3.11.3" 702 | } 703 | }, 704 | "nbformat": 4, 705 | "nbformat_minor": 2 706 | } 707 | -------------------------------------------------------------------------------- /notebooks/chapter4/02_multi_agent_system_construction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "eGQrvXao8Ue5" 7 | }, 8 | "source": [ 9 | "# 準備" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "gmgwxvC071EL", 20 | "outputId": "a70e637f-f7dd-4a9f-f6f5-9dce992bcee5" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "from google.colab import userdata\n", 25 | "import os\n", 26 | "\n", 27 | "# OpenAI API キーの設定\n", 28 | "os.environ[\"OPENAI_API_KEY\"] = userdata.get('OPENAI_API_KEY')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "colab": { 36 | "base_uri": "https://localhost:8080/" 37 | }, 38 | "id": "BzCRyhnb9Kf6", 39 | "outputId": "88a3a00c-94db-4f9b-e9bc-532b018f772f" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "!pip install -q langchain langgraph langchain-openai langchain-community" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "5vXW4rC68Xd0" 50 | }, 51 | "source": [ 52 | "# 4.1 マルチエージェントシステムの構築" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": { 
from typing_extensions import TypedDict
from typing import Annotated

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI

# FIX: pass `temperature` as an explicit parameter. langchain-openai treats
# temperature as a standard parameter and warns (newer versions raise) when
# it is smuggled in via `model_kwargs`.
llm = ChatOpenAI(model="gpt-4o", temperature=0)

class State(TypedDict):
    # Number of chatbot turns taken so far.
    count: int
    # Conversation history; add_messages appends instead of overwriting.
    messages: Annotated[list, add_messages]

def chatbot(state: State):
    """Single LLM turn: answer the current conversation and bump the turn counter."""
    messages = [llm.invoke(state["messages"])]
    count = state["count"] + 1
    return {
        "messages": messages,
        "count": count,
    }

# Minimal one-node graph: START -> chatbot -> END.
graph_builder = StateGraph(State)

graph_builder.add_node("chatbot", chatbot)

graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)

graph = graph_builder.compile()
| "source": [ 133 | "from langchain_core.messages import HumanMessage\n", 134 | "\n", 135 | "\n", 136 | "human_message = HumanMessage(\"こんにちは\")\n", 137 | "\n", 138 | "for event in graph.stream({\"messages\": [human_message], \"count\": 0}):\n", 139 | " for value in event.values():\n", 140 | " print(f\"### ターン{value['count']} ###\")\n", 141 | " value[\"messages\"][-1].pretty_print()" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "id": "SlU1rU3W9CTS" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "# ペルソナの設定\n", 153 | "\n", 154 | "from langchain_core.messages import SystemMessage\n", 155 | "\n", 156 | "\n", 157 | "def chatbot(state: State):\n", 158 | " system_message = SystemMessage(\"あなたは、元気なエンジニアです。元気に返答してください。\")\n", 159 | " messages = [llm.invoke([system_message] + state[\"messages\"])]\n", 160 | " count = state[\"count\"] + 1\n", 161 | " return {\n", 162 | " \"messages\": messages,\n", 163 | " \"count\": count,\n", 164 | " }\n", 165 | "\n", 166 | "graph_builder = StateGraph(State)\n", 167 | "\n", 168 | "graph_builder.add_node(\"chatbot\", chatbot)\n", 169 | "\n", 170 | "graph_builder.add_edge(START, \"chatbot\")\n", 171 | "graph_builder.add_edge(\"chatbot\", END)\n", 172 | "\n", 173 | "graph = graph_builder.compile()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "colab": { 181 | "base_uri": "https://localhost:8080/" 182 | }, 183 | "id": "MMyO_d8z9CQq", 184 | "outputId": "5c8adb6c-465c-4523-ad36-bb774507e6fb" 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "from langchain_core.messages import HumanMessage\n", 189 | "\n", 190 | "\n", 191 | "human_message = HumanMessage(\"上手くデバッグができません\")\n", 192 | "\n", 193 | "for event in graph.stream({\"messages\": [human_message], \"count\": 0}):\n", 194 | " for value in event.values():\n", 195 | " print(f\"### ターン{value['count']} ###\")\n", 196 | " value[\"messages\"][-1].pretty_print()" 
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.prompts import SystemMessagePromptTemplate
import functools
from langchain_openai import ChatOpenAI


llm = ChatOpenAI(model="gpt-4o")

def agent_with_persona(state: State, name: str, traits: str):
    """Reply to the conversation as the persona described by `traits`.

    The LLM's answer is wrapped in a HumanMessage tagged with `name`, so the
    other agents in the graph see it as an utterance from that participant.
    """
    template = SystemMessagePromptTemplate.from_template(
        "あなたの名前は{name}です。\nあなたの性格は以下のとおりです。\n\n{traits}"
    )
    persona_message = template.format(name=name, traits=traits)

    # Prepend the persona system message to the running conversation.
    reply = llm.invoke([persona_message, *state["messages"]])
    message = HumanMessage(content=reply.content, name=name)

    return {
        "messages": [message],
    }

# Persona descriptions for the three participants.
kenta_traits = """\
- アクティブで冒険好き
- 新しい経験を求める
- アウトドア活動を好む
- SNSでの共有を楽しむ
- エネルギッシュで社交的"""

mari_traits = """\
- 穏やかでリラックス志向
- 家族を大切にする
- 静かな趣味を楽しむ
- 心身の休養を重視
- 丁寧な生活を好む"""

yuta_traits = """\
- バランス重視
- 柔軟性がある
- 自己啓発に熱心
- 伝統と現代の融合を好む
- 多様な経験を求める"""

# Bind each persona into a ready-to-use graph node.
kenta = functools.partial(agent_with_persona, name="kenta", traits=kenta_traits)
mari = functools.partial(agent_with_persona, name="mari", traits=mari_traits)
yuta = functools.partial(agent_with_persona, name="yuta", traits=yuta_traits)
"code", 344 | "execution_count": null, 345 | "metadata": { 346 | "colab": { 347 | "base_uri": "https://localhost:8080/" 348 | }, 349 | "id": "cIM6amN9AdHA", 350 | "outputId": "23a08647-d075-4f0b-a3d5-71127e9a13ef" 351 | }, 352 | "outputs": [], 353 | "source": [ 354 | "from langchain_core.messages import HumanMessage\n", 355 | "\n", 356 | "human_message = HumanMessage(\"休日の過ごし方について、建設的に議論してください。\")\n", 357 | "\n", 358 | "for event in graph.stream({\"messages\": [human_message]}):\n", 359 | " for value in event.values():\n", 360 | " value[\"messages\"][-1].pretty_print()" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "id": "_1k-Vxu08kdN" 367 | }, 368 | "source": [ 369 | "### 3つのエージェントが一斉に回答するシステム" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "metadata": { 376 | "id": "F5KR9ZG7AiFt" 377 | }, 378 | "outputs": [], 379 | "source": [ 380 | "from langgraph.graph import StateGraph, START, END\n", 381 | "\n", 382 | "graph_builder = StateGraph(State)\n", 383 | "\n", 384 | "graph_builder.add_node(\"kenta\", kenta)\n", 385 | "graph_builder.add_node(\"mari\", mari)\n", 386 | "graph_builder.add_node(\"yuta\", yuta)\n", 387 | "\n", 388 | "\n", 389 | "graph_builder.add_edge(START, \"kenta\")\n", 390 | "graph_builder.add_edge(START, \"mari\")\n", 391 | "graph_builder.add_edge(START, \"yuta\")\n", 392 | "graph_builder.add_edge(\"kenta\", END)\n", 393 | "graph_builder.add_edge(\"mari\", END)\n", 394 | "graph_builder.add_edge(\"yuta\", END)\n", 395 | "\n", 396 | "graph = graph_builder.compile()" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": { 403 | "colab": { 404 | "base_uri": "https://localhost:8080/", 405 | "height": 251 406 | }, 407 | "id": "678fnvYyAnis", 408 | "outputId": "39f36816-1890-46b1-be0f-ba50779f24bd" 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "from IPython.display import display, Image\n", 413 | "\n", 414 | 
from pydantic import BaseModel, Field
from langchain.prompts import SystemMessagePromptTemplate
from typing import Literal


class State(TypedDict):
    # Conversation history; add_messages appends instead of overwriting.
    messages: Annotated[list, add_messages]
    # Name of the member chosen to speak next.
    next: str

member_dict = {
    "kenta": kenta_traits,
    "mari": mari_traits,
    "yuta": yuta_traits,
}

#1 Structured-output schema: restricts the choice to the known members.
class RouteSchema(BaseModel):
    next: Literal["kenta", "mari", "yuta"] = Field(..., description="次に発言する人")

#2 Supervisor node: picks which member should speak next.
def supervisor(state: State):
    system_message = SystemMessagePromptTemplate.from_template(
        "あなたは以下の作業者間の会話を管理する監督者です:{members}。" "各メンバーの性格は以下の通りです。" "{traits_description}" "与えられたユーザーリクエストに対して、次に発言する人を選択してください。" )

    members = ", ".join(list(member_dict.keys()))
    traits_description = "\n".join([f"**{name}**\n{traits}" for name, traits in member_dict.items()])

    system_message = system_message.format(members=members, traits_description=traits_description)

    llm_with_format = llm.with_structured_output(RouteSchema)

    # FIX: renamed the local from `next` to `next_speaker` — it shadowed the
    # builtin `next()`. The returned state key is unchanged.
    next_speaker = llm_with_format.invoke([system_message] + state["messages"]).next
    return {"next": next_speaker}
"https://localhost:8080/" 543 | }, 544 | "id": "MZL-CGsZBiWU", 545 | "outputId": "36639c7b-343e-4a0a-a61f-c3544b1bb349" 546 | }, 547 | "outputs": [], 548 | "source": [ 549 | "from langchain_core.messages import HumanMessage\n", 550 | "\n", 551 | "human_message = HumanMessage(\"休日のまったりした過ごし方を教えて\")\n", 552 | "for event in graph.stream({\"messages\": [human_message]}):\n", 553 | " for value in event.values():\n", 554 | " if \"next\" in value:\n", 555 | " print(f\"次に発言する人: {value['next']}\")\n", 556 | " elif \"messages\" in value:\n", 557 | " value[\"messages\"][-1].pretty_print()" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "metadata": { 563 | "id": "sek8AwOc8kWU" 564 | }, 565 | "source": [ 566 | "## 4.2.4 ツールの使用" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": { 573 | "id": "JVq61UlQ_ePY" 574 | }, 575 | "outputs": [], 576 | "source": [ 577 | "from google.colab import userdata\n", 578 | "import os\n", 579 | "\n", 580 | "# Tavily API キーの設定\n", 581 | "os.environ[\"TAVILY_API_KEY\"] = userdata.get('TAVILY_API_KEY')" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": null, 587 | "metadata": { 588 | "id": "6Jf8Lmrh78vS" 589 | }, 590 | "outputs": [], 591 | "source": [ 592 | "from langchain_community.tools.tavily_search import TavilySearchResults\n", 593 | "from langchain_openai import ChatOpenAI\n", 594 | "from typing_extensions import TypedDict\n", 595 | "from typing import Annotated\n", 596 | "from langgraph.graph.message import add_messages\n", 597 | "\n", 598 | "class State(TypedDict):\n", 599 | " messages: Annotated[list, add_messages]\n", 600 | "\n", 601 | "#1 ツールの作成\n", 602 | "tavily_tool = TavilySearchResults(max_results=2)\n", 603 | "\n", 604 | "#2 ツールの紐づけ\n", 605 | "llm = ChatOpenAI(model=\"gpt-4o\")\n", 606 | "llm_with_tool = llm.bind_tools([tavily_tool])\n", 607 | "\n", 608 | "#3 ツールを使ったチャットボットの作成\n", 609 | "def chatbot(state: State):\n", 610 | " messages = 
import json

from langchain_core.messages import ToolMessage


class ToolNode:
    """Graph node that executes the tool calls requested by the last AI message."""

    def __init__(self, tools: list) -> None:
        # Index the tools by name for O(1) dispatch.
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, state: State):
        #1 Grab the most recent message; there must be at least one.
        messages = state.get("messages", [])
        if not messages:
            raise ValueError("入力にメッセージが見つかりません")
        message = messages[-1]

        #2 Run every requested tool call.
        tool_messages = []
        for tool_call in message.tool_calls:
            #2.1 Dispatch on the name/args the agent specified.
            selected_tool = self.tools_by_name[tool_call["name"]]
            tool_result = selected_tool.invoke(tool_call["args"])
            #2.2 Wrap the result in a ToolMessage so it joins the conversation.
            tool_messages.append(
                ToolMessage(
                    content=json.dumps(tool_result, ensure_ascii=False),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )

        return {
            "messages": tool_messages,
        }

tool_node = ToolNode([tavily_tool])
ValueError(f\"stateにツールに関するメッセージが見つかりませんでした: {state}\")\n", 680 | "\n", 681 | " if hasattr(ai_message, \"tool_calls\") and len(ai_message.tool_calls) > 0:\n", 682 | " return \"tools\"\n", 683 | " return \"__end__\"" 684 | ] 685 | }, 686 | { 687 | "cell_type": "code", 688 | "execution_count": null, 689 | "metadata": { 690 | "id": "_q1MFbrl-gVS" 691 | }, 692 | "outputs": [], 693 | "source": [ 694 | "from langgraph.graph import StateGraph, START, END\n", 695 | "\n", 696 | "graph_builder = StateGraph(State)\n", 697 | "\n", 698 | "graph_builder.add_node(\"chatbot\", chatbot)\n", 699 | "graph_builder.add_node(\"tools\", tool_node)\n", 700 | "\n", 701 | "graph_builder.add_conditional_edges(\n", 702 | " \"chatbot\",\n", 703 | " route_tools,\n", 704 | " [\"tools\", \"__end__\"],\n", 705 | ")\n", 706 | "\n", 707 | "graph_builder.add_edge(\"tools\", \"chatbot\")\n", 708 | "graph_builder.add_edge(START, \"chatbot\")\n", 709 | "graph = graph_builder.compile()" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": null, 715 | "metadata": { 716 | "colab": { 717 | "base_uri": "https://localhost:8080/", 718 | "height": 266 719 | }, 720 | "id": "mhSROt22-hy8", 721 | "outputId": "ca9483b8-1846-4b65-cc9e-87f0310d4a78" 722 | }, 723 | "outputs": [], 724 | "source": [ 725 | "from IPython.display import display, Image\n", 726 | "\n", 727 | "display(Image(graph.get_graph().draw_mermaid_png()))" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": { 734 | "colab": { 735 | "base_uri": "https://localhost:8080/" 736 | }, 737 | "id": "TvvyAiHo-jJu", 738 | "outputId": "54514511-8621-4e7c-829d-fb141b688b15" 739 | }, 740 | "outputs": [], 741 | "source": [ 742 | "from langchain_core.messages import HumanMessage\n", 743 | "\n", 744 | "human_message = {\n", 745 | " \"messages\": [HumanMessage(\"今日の東京の天気を教えて\")],\n", 746 | " \"count\": 0,\n", 747 | "}\n", 748 | "\n", 749 | "for event in graph.stream(human_message):\n", 750 | " for 
value in event.values():\n", 751 | " value[\"messages\"][-1].pretty_print()" 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": null, 757 | "metadata": { 758 | "id": "zVWAraWu-kge" 759 | }, 760 | "outputs": [], 761 | "source": [] 762 | } 763 | ], 764 | "metadata": { 765 | "colab": { 766 | "provenance": [] 767 | }, 768 | "kernelspec": { 769 | "display_name": "Python 3", 770 | "name": "python3" 771 | }, 772 | "language_info": { 773 | "name": "python" 774 | } 775 | }, 776 | "nbformat": 4, 777 | "nbformat_minor": 0 778 | } 779 | -------------------------------------------------------------------------------- /notebooks/chapter4/03_multi_agent_system_application.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "F2e1J3p4Dqu2" 7 | }, 8 | "source": [ 9 | "# 準備" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "eFUMuI4bEGcj", 20 | "outputId": "4354b1d2-2843-49f1-fcdb-f9ce435b1e6c" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "from google.colab import userdata\n", 25 | "import os\n", 26 | "\n", 27 | "# OpenAI API キーの設定\n", 28 | "os.environ[\"OPENAI_API_KEY\"] = userdata.get('OPENAI_API_KEY')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "colab": { 36 | "base_uri": "https://localhost:8080/" 37 | }, 38 | "id": "BgTjtbZuEVxn", 39 | "outputId": "751f2469-f062-4818-c4ff-dab98be54d16" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "!pip install -q langchain langgraph langchain-openai langchain-community langchain-experimental" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "WKUsbzjKDjYg" 50 | }, 51 | "source": [ 52 | "# 4.3. 
マルチエージェントの活用\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": { 58 | "id": "MEh3ONBaDjSk" 59 | }, 60 | "source": [ 61 | "## 4.3.1. 数学の問題を解かせよう" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "id": "Piv2lKP6JK0v" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "!pip install -q sympy" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": { 79 | "id": "5pCY8g12EoPe" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "from typing_extensions import TypedDict\n", 84 | "from typing import Annotated\n", 85 | "from langgraph.graph.message import add_messages\n", 86 | "\n", 87 | "class State(TypedDict):\n", 88 | " messages: Annotated[list, add_messages]\n", 89 | " problem: str\n", 90 | " first_flag: bool\n", 91 | " end_flag: bool" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 68, 97 | "metadata": { 98 | "id": "BigURL1yEoNc" 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "import re\n", 103 | "from langchain_experimental.utilities import PythonREPL\n", 104 | "from langchain_core.messages import HumanMessage\n", 105 | "\n", 106 | "#1 Python実行用のツール\n", 107 | "repl = PythonREPL()\n", 108 | "\n", 109 | "#2 コード部分を抜き出す関数\n", 110 | "def extract_code(input_string: str):\n", 111 | " pattern = r\"```(.*?)```\"\n", 112 | " match = re.findall(pattern, input_string, flags=re.DOTALL)\n", 113 | "\n", 114 | " queries = \"\"\n", 115 | " for m in match:\n", 116 | " query = m.replace(\"python\", \"\").strip()\n", 117 | " queries += query + \"\\n\"\n", 118 | " return queries\n", 119 | "\n", 120 | "#3 ユーザープロキシエージェントの定義\n", 121 | "INITIAL_PROMPT = \"\"\"\\\n", 122 | "Pythonを使って数学の問題を解いてみましょう。\n", 123 | "\n", 124 | "クエリ要件:\n", 125 | "常に出力には'print'関数を使用し、小数ではなく分数や根号形式を使用してください。\n", 126 | "sympyなどのパッケージを利用しても構いません。\n", 127 | "以下のフォーマットに従ってコードを書いてください。\n", 128 | "```python\n", 129 | "# あなたのコード\n", 130 | "```\n", 131 | "\n", 132 | 
"まず、問題を解くための主な考え方を述べてください。問題を解くためには以下の3つの方法から選択できます:\n", 133 | "ケース1:問題が直接Pythonコードで解決できる場合、プログラムを書いて解決してください。必要に応じてすべての可能な配置を列挙しても構いません。\n", 134 | "ケース2:問題が主に推論で解決できる場合、自分で直接解決してください。\n", 135 | "ケース3:上記の2つの方法では対処できない場合、次のプロセスに従ってください:\n", 136 | "1. 問題をステップバイステップで解決する(ステップを過度に細分化しないでください)。\n", 137 | "2. Pythonを使って問い合わせることができるクエリ(計算や方程式など)を取り出します。\n", 138 | "3. 結果を私に教えてください。\n", 139 | "4. 結果が正しいと思う場合は続行してください。結果が無効または予期しない場合は、クエリまたは推論を修正してください。\n", 140 | "\n", 141 | "すべてのクエリが実行され、答えを得た後、答えを \\\\boxed{{}} に入れてください。\n", 142 | "答え以外、例えば変数を\\\\boxed{{}}に入れたり、\\\\boxed{{}}を単体で使用しないで下さい。\n", 143 | "\\\\boxed{{}}の有無で答えが出たかを管理しています。最終的な答えが出た時以外は、\\\\boxed{{}}を使用しないでください。\n", 144 | "回答が得られた場合は、シンプルに表示して下さい。追加の出力などはしないでください。\n", 145 | "\n", 146 | "問題文:{problem}\n", 147 | "\"\"\"\n", 148 | "\n", 149 | "def user_proxy_agent(state: State):\n", 150 | "    if state[\"first_flag\"]:\n", 151 | "        message = INITIAL_PROMPT.format(problem=state[\"problem\"])\n", 152 | "    else:\n", 153 | "        last_message = state[\"messages\"][-1].content\n", 154 | "        code = extract_code(last_message)\n", 155 | "        if code:\n", 156 | "            message = repl.run(code)\n", 157 | "        else:\n", 158 | "            message = \"続けてください。クエリが必要になるまで問題を解き続けてください。(答えが出た場合は、\\\\boxed{} に入れてください。)\"\n", 159 | "    message = HumanMessage(message)\n", 160 | "    return {\"messages\": [message], \"first_flag\": False}" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 69, 166 | "metadata": { 167 | "id": "4m5kr76zEoLW" 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "from langchain_openai import ChatOpenAI\n", 172 | "\n", 173 | "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", 174 | "\n", 175 | "# 回答を抜き出す関数\n", 176 | "def extract_boxed(input_string: str):\n", 177 | "    pattern = r\"\\\\boxed\\{.*?\\}\"\n", 178 | "    matches = re.findall(pattern, input_string)\n", 179 | "    return [m.replace(\"\\\\boxed{\", \"\").replace(\"}\", \"\") for m in matches]\n", 180 | "\n", 181 | "# LLMエージェントを定義した関数\n", 182 | "def llm_agent(state: 
State):\n", 183 | " message = llm.invoke(state[\"messages\"])\n", 184 | " content = message.content\n", 185 | " boxed = extract_boxed(content)\n", 186 | " end_flag = False\n", 187 | " if boxed:\n", 188 | " end_flag = True\n", 189 | " return {\"messages\": [message], \"end_flag\": end_flag}" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 70, 195 | "metadata": { 196 | "id": "45wPtyjeEoIq" 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "from langgraph.graph import StateGraph, START, END\n", 201 | "\n", 202 | "graph_builder = StateGraph(State)\n", 203 | "\n", 204 | "graph_builder.add_node(\"llm_agent\", llm_agent)\n", 205 | "graph_builder.add_node(\"user_proxy_agent\", user_proxy_agent)\n", 206 | "\n", 207 | "graph_builder.add_edge(START, \"user_proxy_agent\")\n", 208 | "graph_builder.add_conditional_edges(\n", 209 | " \"llm_agent\",\n", 210 | " lambda state: state[\"end_flag\"],\n", 211 | " {True: END, False: \"user_proxy_agent\"}\n", 212 | ")\n", 213 | "graph_builder.add_edge(\"user_proxy_agent\", \"llm_agent\")\n", 214 | "\n", 215 | "graph = graph_builder.compile()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": { 222 | "colab": { 223 | "base_uri": "https://localhost:8080/", 224 | "height": 398 225 | }, 226 | "id": "uewP3Yh7EoGS", 227 | "outputId": "706ba059-dc72-4138-b4d2-77b8cd978c03" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "from IPython.display import display, Image\n", 232 | "\n", 233 | "display(Image(graph.get_graph().draw_mermaid_png()))" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": { 240 | "colab": { 241 | "base_uri": "https://localhost:8080/" 242 | }, 243 | "id": "-GTu5pCeEoB5", 244 | "outputId": "f76a6c5a-198f-40d8-a7d8-af48916f7956" 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "problem = \"\"\"\\\n", 249 | "問題: 偽の金塊は、コンクリートの立方体を金色のペイントで覆うことによって作られます。\n", 250 | 
"ペイントのコストは立方体の表面積に比例し、コンクリートのコストは体積に比例します。\n", 251 | "1インチの立方体を作るコストが130円であり、2インチの立方体を作るコストが680円であるとき、3インチの立方体を作るコストはいくらになりますか?\"\"\"\n", 252 | "\n", 253 | "\n", 254 | "for event in graph.stream({\"problem\": problem, \"first_flag\": True}):\n", 255 | " for value in event.values():\n", 256 | " value[\"messages\"][-1].pretty_print()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": { 262 | "id": "403Y2hADDjNO" 263 | }, 264 | "source": [ 265 | "## 4.3.2. 議論させてみよう" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 73, 271 | "metadata": { 272 | "id": "XehsDu1VHwCO" 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "from typing_extensions import TypedDict\n", 277 | "from typing import Annotated\n", 278 | "from langgraph.graph.message import add_messages\n", 279 | "\n", 280 | "class State(TypedDict):\n", 281 | " messages: Annotated[list, add_messages]\n", 282 | " debate_topic: str\n", 283 | " judged: bool\n", 284 | " round: int" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 74, 290 | "metadata": { 291 | "id": "mmFesZ2EIUWj" 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "from langchain_core.messages import HumanMessage, SystemMessage\n", 296 | "from langchain_openai import ChatOpenAI\n", 297 | "\n", 298 | "llm = ChatOpenAI(model=\"gpt-4o\")\n", 299 | "\n", 300 | "def cot_agent(\n", 301 | " state: State,\n", 302 | "):\n", 303 | " system_message = (\n", 304 | " \"与えられた議題に対し、ステップバイステップで考えてから回答してください。\"\n", 305 | " \"議題:{debate_topic}\"\n", 306 | " )\n", 307 | " system_message = SystemMessage(\n", 308 | " system_message.format(debate_topic=state[\"debate_topic\"])\n", 309 | " )\n", 310 | " message = HumanMessage(\n", 311 | " content=llm.invoke([system_message]).content, name=\"CoT\"\n", 312 | " )\n", 313 | "\n", 314 | " return {\"messages\": [message]}\n" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 75, 320 | "metadata": { 321 | "id": "nfBdd_h5Hv9k" 
322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "from langchain_core.messages import HumanMessage, SystemMessage\n", 326 | "import functools\n", 327 | "\n", 328 | "def debater(\n", 329 | " state: State,\n", 330 | " name: str,\n", 331 | " position: str,\n", 332 | "):\n", 333 | " system_message = (\n", 334 | " \"あなたはディベーターです。ディベート大会へようこそ。\"\n", 335 | " \"私たちの目的は正しい答えを見つけることですので、お互いの視点に完全に同意する必要はありません。\"\n", 336 | " \"ディベートのテーマは以下の通りです:{debate_topic}\"\n", 337 | " \"\"\n", 338 | " \"{position}\"\n", 339 | " )\n", 340 | "\n", 341 | " debate_topic = state[\"debate_topic\"]\n", 342 | " system_message = SystemMessage(\n", 343 | " system_message.format(debate_topic=debate_topic, position=position)\n", 344 | " )\n", 345 | " message = HumanMessage(\n", 346 | " content=llm.invoke([system_message, *state[\"messages\"]]).content,\n", 347 | " name=name,\n", 348 | " )\n", 349 | " return {\"messages\": [message]}\n", 350 | "\n", 351 | "\n", 352 | "affirmative_debator = functools.partial(\n", 353 | " debater,\n", 354 | " name=\"Affirmative_Debater\",\n", 355 | " position=\"あなたは肯定側です。あなたの見解を簡潔に述べてください。否定側の意見が与えられた場合は、それに反対して理由を簡潔に述べてください。\"\n", 356 | ")\n", 357 | "negative_debator = functools.partial(\n", 358 | " debater,\n", 359 | " name=\"Negative_Debater\",\n", 360 | " position=\"あなたは否定側です。肯定側の意見に反対し、あなたの理由を簡潔に説明してください。\"\n", 361 | ")" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 77, 367 | "metadata": { 368 | "id": "V85r_fHIHv6a" 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "from pydantic import BaseModel, Field\n", 373 | "from langchain_core.messages import AIMessage\n", 374 | "\n", 375 | "\n", 376 | "class JudgeSchema(BaseModel):\n", 377 | " judged: bool = Field(..., description=\"勝者が決まったかどうか\")\n", 378 | " answer: str = Field(description=\"議題に対する結論とその理由\")\n", 379 | "\n", 380 | "\n", 381 | "def judger(state: State):\n", 382 | " system_message = (\n", 383 | " \"あなたは司会者です。\"\n", 384 | " \"ディベート大会に2名のディベーターが参加します。\"\n", 385 | " 
\"彼らは{debate_topic}について自分の回答を発表し、それぞれの視点について議論します。\"\n", 386 | " \"各ラウンドの終わりに、あなたは両者の回答を評価していき、ディベートの勝者を判断します。\"\n", 387 | " \"判定が難しい場合は、次のラウンドで判断してください。\"\n", 388 | " )\n", 389 | " system_message = SystemMessage(\n", 390 | " system_message.format(debate_topic=state[\"debate_topic\"])\n", 391 | " )\n", 392 | "\n", 393 | " llm_with_format = llm.with_structured_output(JudgeSchema)\n", 394 | " res = llm_with_format.invoke([system_message, *state[\"messages\"]])\n", 395 | " messages = []\n", 396 | "\n", 397 | " if res.judged:\n", 398 | " message = HumanMessage(res.answer)\n", 399 | " messages.append(message)\n", 400 | " return {\n", 401 | " \"messages\": messages,\n", 402 | " \"judged\": res.judged\n", 403 | " }\n" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 78, 409 | "metadata": { 410 | "id": "TBKEZudkHv3d" 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "def round_monitor(state: State, max_round: int):\n", 415 | " round = state[\"round\"] + 1\n", 416 | " if state[\"round\"] < max_round:\n", 417 | " return {\"round\": round}\n", 418 | " else:\n", 419 | " return {\n", 420 | " \"messages\": [HumanMessage(\n", 421 | " \"最終ラウンドなので、勝者を決定し、議題に対する結論とその理由を述べてください。\"\n", 422 | " )],\n", 423 | " \"round\": round,\n", 424 | " }\n", 425 | "\n", 426 | "round_monitor = functools.partial(round_monitor, max_round=3)" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 79, 432 | "metadata": { 433 | "id": "FBA4XLs3Hv0z" 434 | }, 435 | "outputs": [], 436 | "source": [ 437 | "from langgraph.graph import StateGraph, START, END\n", 438 | "\n", 439 | "graph_builder = StateGraph(State)\n", 440 | "\n", 441 | "graph_builder.add_node(\"cot_agent\", cot_agent)\n", 442 | "graph_builder.add_node(\"affirmative_debator\", affirmative_debator)\n", 443 | "graph_builder.add_node(\"negative_debator\", negative_debator)\n", 444 | "graph_builder.add_node(\"judger\", judger)\n", 445 | "graph_builder.add_node(\"round_monitor\", 
round_monitor)\n", 446 | "\n", 447 | "graph_builder.add_edge(START, \"cot_agent\")\n", 448 | "graph_builder.add_edge(\"cot_agent\", \"affirmative_debator\")\n", 449 | "graph_builder.add_edge(\"affirmative_debator\", \"negative_debator\")\n", 450 | "graph_builder.add_edge(\"negative_debator\", \"round_monitor\")\n", 451 | "graph_builder.add_edge(\"round_monitor\", \"judger\")\n", 452 | "graph_builder.add_conditional_edges(\n", 453 | " \"judger\",\n", 454 | " lambda state: state[\"judged\"],\n", 455 | " {True: END, False: \"affirmative_debator\"}\n", 456 | ")\n", 457 | "\n", 458 | "graph = graph_builder.compile()" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "colab": { 466 | "base_uri": "https://localhost:8080/", 467 | "height": 695 468 | }, 469 | "id": "OOCyn8e0HvyW", 470 | "outputId": "6227e80d-46ca-477b-d96f-e295f6546e91" 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "from IPython.display import display, Image\n", 475 | "\n", 476 | "display(Image(graph.get_graph().draw_mermaid_png()))" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "metadata": { 483 | "colab": { 484 | "base_uri": "https://localhost:8080/" 485 | }, 486 | "id": "dQuwDrrSHvvh", 487 | "outputId": "53072ae3-b75e-4a2e-ddd4-87c97a8b8cb9" 488 | }, 489 | "outputs": [], 490 | "source": [ 491 | "inputs = {\n", 492 | " \"messages\": [],\n", 493 | " \"debate_topic\": \"戦争は必要か?\",\n", 494 | " \"judged\": False,\n", 495 | " \"round\": 0,\n", 496 | "}\n", 497 | "\n", 498 | "for event in graph.stream(inputs):\n", 499 | " for value in event.values():\n", 500 | " try:\n", 501 | " value[\"messages\"][-1].pretty_print()\n", 502 | " except:\n", 503 | " pass" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": { 509 | "id": "GaSUFgy-DlMW" 510 | }, 511 | "source": [ 512 | "## 4.3.3. 
回答を洗練させよう" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "metadata": { 519 | "colab": { 520 | "base_uri": "https://localhost:8080/" 521 | }, 522 | "id": "PlIsAPgnDSo-", 523 | "outputId": "e2ccb2a7-4f38-4203-e7ad-d53883e99490" 524 | }, 525 | "outputs": [], 526 | "source": [ 527 | "!pip install -q langchain-google-genai\n", 528 | "!pip install -q langchain-anthropic" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": { 535 | "colab": { 536 | "base_uri": "https://localhost:8080/" 537 | }, 538 | "id": "_DY6aQ7pJbTk", 539 | "outputId": "bbdf1f95-1d5b-447a-dcb4-11bcb1b5936f" 540 | }, 541 | "outputs": [], 542 | "source": [ 543 | "from google.colab import userdata\n", 544 | "import os\n", 545 | "\n", 546 | "# Google API キーの設定\n", 547 | "os.environ[\"GOOGLE_API_KEY\"] = userdata.get('GOOGLE_API_KEY')\n", 548 | "\n", 549 | "# Anthropic API キーの設定\n", 550 | "os.environ[\"ANTHROPIC_API_KEY\"] = userdata.get('ANTHROPIC_API_KEY')" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 84, 556 | "metadata": { 557 | "id": "OKbA3y37JTmv" 558 | }, 559 | "outputs": [], 560 | "source": [ 561 | "from langchain_openai import ChatOpenAI\n", 562 | "from langchain_anthropic import ChatAnthropic\n", 563 | "from langchain_google_genai import ChatGoogleGenerativeAI\n", 564 | "\n", 565 | "llm_openai = ChatOpenAI(model=\"gpt-4o-mini\")\n", 566 | "llm_anthropic = ChatAnthropic(model=\"claude-3-5-sonnet-20240620\")\n", 567 | "llm_google = ChatGoogleGenerativeAI(model=\"gemini-1.5-pro\")" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 86, 573 | "metadata": { 574 | "id": "wdBjWZVXJ4S8" 575 | }, 576 | "outputs": [], 577 | "source": [ 578 | "from typing_extensions import TypedDict\n", 579 | "from typing import Annotated\n", 580 | "\n", 581 | "from langgraph.graph.message import add_messages\n", 582 | "\n", 583 | "from langchain_core.messages import 
HumanMessage, AIMessage, SystemMessage\n", 584 | "\n", 585 | "\n", 586 | "class State(TypedDict):\n", 587 | " human_message: HumanMessage\n", 588 | " messages: Annotated[list, add_messages]\n", 589 | " prev_messages: list[AIMessage]\n", 590 | " layer_cnt: int" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": 87, 596 | "metadata": { 597 | "id": "sq3iVg26J6sc" 598 | }, 599 | "outputs": [], 600 | "source": [ 601 | "from functools import partial\n", 602 | "from typing import Union\n", 603 | "\n", 604 | "\n", 605 | "aggregater_system_message_template = \"\"\"\\\n", 606 | "最新のユーザーの質問に対して、さまざまなLLMからの回答が提供されています。あなたの任務は、これらの回答を統合して、単一の高品質な回答を作成することです。\n", 607 | "提供された回答に含まれる情報を批判的に評価し、一部の情報が偏っていたり誤っていたりする可能性があることを認識することが重要です。\n", 608 | "回答を単に複製するのではなく、正確で包括的な返答を提供してください。\n", 609 | "回答が良く構造化され、一貫性があり、最高の精度と信頼性の基準を満たすようにしてください。\n", 610 | "\n", 611 | "{prev_messages}\"\"\"\n", 612 | "\n", 613 | "def agent(state: State, llm: Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI], name: str):\n", 614 | " input_messages = []\n", 615 | " if len(state[\"prev_messages\"]) > 0:\n", 616 | " prev_messages = [f\"{i+1}. 
{message.content}\" for i, message in enumerate(state[\"prev_messages\"])]\n", 617 | " prev_messages = \"\\n\".join(prev_messages)\n", 618 | "\n", 619 | " aggregater_system_message = SystemMessage(\n", 620 | " aggregater_system_message_template.format(prev_messages=prev_messages),\n", 621 | " )\n", 622 | "\n", 623 | " input_messages.append(aggregater_system_message)\n", 624 | "\n", 625 | " input_messages.append(state[\"human_message\"])\n", 626 | "\n", 627 | " message = llm.invoke(input_messages)\n", 628 | " message.name = name\n", 629 | "\n", 630 | " return {\"messages\": [message]}\n", 631 | "\n", 632 | "agent_openai = partial(agent, llm=llm_openai, name=\"openai\")\n", 633 | "agent_anthropic = partial(agent, llm=llm_anthropic, name=\"anthropic\")\n", 634 | "agent_google = partial(agent, llm=llm_google, name=\"google\")" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": 88, 640 | "metadata": { 641 | "id": "_Yql6MGSKAIg" 642 | }, 643 | "outputs": [], 644 | "source": [ 645 | "def update(state: State, num_agents: int):\n", 646 | " return {\n", 647 | " \"prev_messages\": state[\"messages\"][-num_agents:],\n", 648 | " \"layer_cnt\": state[\"layer_cnt\"] + 1\n", 649 | " }" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": 89, 655 | "metadata": { 656 | "id": "D2hYclt1KBlC" 657 | }, 658 | "outputs": [], 659 | "source": [ 660 | "def router(\n", 661 | " state: State,\n", 662 | " num_layers: int,\n", 663 | " agent_name_list: list[str]\n", 664 | "):\n", 665 | " if state[\"layer_cnt\"] < num_layers:\n", 666 | " return agent_name_list\n", 667 | " else:\n", 668 | " return \"final_agent\"" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 90, 674 | "metadata": { 675 | "id": "IFd-9LuQKDCz" 676 | }, 677 | "outputs": [], 678 | "source": [ 679 | "from langgraph.graph import StateGraph, START, END\n", 680 | "\n", 681 | "num_layers = 3\n", 682 | "\n", 683 | "graph_builder = StateGraph(State)\n", 684 | 
"\n", 685 | "agent_dict = {\n", 686 | " \"openai\": agent_openai,\n", 687 | " \"anthropic\": agent_anthropic,\n", 688 | " \"google\": agent_google,\n", 689 | "}\n", 690 | "\n", 691 | "graph_builder.add_node(\n", 692 | " \"update\",\n", 693 | " partial(update, num_agents=len(agent_dict))\n", 694 | ")\n", 695 | "graph_builder.add_node(\"final_agent\", agent_dict[\"openai\"])\n", 696 | "\n", 697 | "for agent_name, agent in agent_dict.items():\n", 698 | " graph_builder.add_node(agent_name, agent)\n", 699 | " graph_builder.add_edge(START, agent_name)\n", 700 | " graph_builder.add_edge(agent_name, \"update\")\n", 701 | "\n", 702 | "agent_name_list = list(agent_dict.keys())\n", 703 | "graph_builder.add_conditional_edges(\n", 704 | " \"update\",\n", 705 | " partial(router, num_layers=num_layers, agent_name_list=agent_name_list),\n", 706 | " agent_name_list + [\"final_agent\"]\n", 707 | ")\n", 708 | "graph_builder.add_edge(\"final_agent\", END)\n", 709 | "\n", 710 | "graph = graph_builder.compile()" 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "execution_count": null, 716 | "metadata": { 717 | "colab": { 718 | "base_uri": "https://localhost:8080/", 719 | "height": 449 720 | }, 721 | "id": "LOaI8azYKEcD", 722 | "outputId": "3966c76a-2804-4cf8-c1bf-807fa391e61d" 723 | }, 724 | "outputs": [], 725 | "source": [ 726 | "from IPython.display import display, Image\n", 727 | "\n", 728 | "display(Image(graph.get_graph().draw_mermaid_png()))" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": null, 734 | "metadata": { 735 | "colab": { 736 | "base_uri": "https://localhost:8080/" 737 | }, 738 | "id": "hdYcAMfEKGM2", 739 | "outputId": "df42b40e-8fa3-4209-cd00-56fa257b150f" 740 | }, 741 | "outputs": [], 742 | "source": [ 743 | "human_message = HumanMessage(\"マルチエージェントについて教えて\")\n", 744 | "\n", 745 | "state = {\n", 746 | " \"human_message\": human_message,\n", 747 | " \"messages\": [],\n", 748 | " \"prev_messages\": [],\n", 749 | " \"layer_cnt\": 
1\n", 750 | "}\n", 751 | "\n", 752 | "print(\"#################### Layer 1 ####################\")\n", 753 | "for event in graph.stream(state):\n", 754 | " for value in event.values():\n", 755 | " if \"messages\" in value:\n", 756 | " value[\"messages\"][-1].pretty_print()\n", 757 | " if \"layer_cnt\" in value:\n", 758 | " print(f\"\\n\\n#################### Layer {value['layer_cnt']} ####################\")" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": null, 764 | "metadata": { 765 | "id": "E1kkzZKjKHhq" 766 | }, 767 | "outputs": [], 768 | "source": [] 769 | } 770 | ], 771 | "metadata": { 772 | "colab": { 773 | "provenance": [] 774 | }, 775 | "kernelspec": { 776 | "display_name": "Python 3", 777 | "name": "python3" 778 | }, 779 | "language_info": { 780 | "name": "python" 781 | } 782 | }, 783 | "nbformat": 4, 784 | "nbformat_minor": 0 785 | } 786 | --------------------------------------------------------------------------------