├── .github
│   ├── actions
│   │   └── uv_setup
│   │       └── action.yml
│   └── workflows
│       ├── _lint.yml
│       ├── _test.yml
│       ├── ci.yml
│       └── release.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── langgraph_bigtool
│   ├── __init__.py
│   ├── graph.py
│   ├── tools.py
│   └── utils.py
├── pyproject.toml
├── static
│   └── img
│       └── graph.png
├── tests
│   ├── __init__.py
│   ├── integration_tests
│   │   ├── __init__.py
│   │   └── test_end_to_end.py
│   └── unit_tests
│       ├── __init__.py
│       └── test_end_to_end.py
└── uv.lock
/.github/actions/uv_setup/action.yml:
--------------------------------------------------------------------------------
1 | # TODO: https://docs.astral.sh/uv/guides/integration/github/#caching
2 | 
3 | name: uv-install
4 | description: Set up Python and uv
5 | 
6 | inputs:
7 |   python-version:
8 |     description: Python version, supporting MAJOR.MINOR only
9 |     required: true
10 | 
11 | env:
12 |   UV_VERSION: "0.5.25"
13 | 
14 | runs:
15 |   using: composite
16 |   steps:
17 |     - name: Install uv and set the python version
18 |       uses: astral-sh/setup-uv@v5
19 |       with:
20 |         version: ${{ env.UV_VERSION }}
21 |         python-version: ${{ inputs.python-version }}
22 | 
--------------------------------------------------------------------------------
/.github/workflows/_lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 | 
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       working-directory:
7 |         required: true
8 |         type: string
9 |         description: "From which folder this pipeline executes"
10 |       python-version:
11 |         required: true
12 |         type: string
13 |         description: "Python version to use"
14 | 
15 | env:
16 |   WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
17 | 
18 |   # This env var allows us to get inline annotations when ruff has complaints.
19 | RUFF_OUTPUT_FORMAT: github 20 | 21 | UV_FROZEN: "true" 22 | 23 | jobs: 24 | build: 25 | name: "make lint #${{ inputs.python-version }}" 26 | runs-on: ubuntu-latest 27 | timeout-minutes: 20 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Set up Python ${{ inputs.python-version }} + uv 32 | uses: "./.github/actions/uv_setup" 33 | with: 34 | python-version: ${{ inputs.python-version }} 35 | 36 | - name: Install dependencies 37 | working-directory: ${{ inputs.working-directory }} 38 | run: | 39 | uv sync --group test 40 | 41 | - name: Analysing the code with our lint 42 | working-directory: ${{ inputs.working-directory }} 43 | run: | 44 | make lint 45 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | python-version: 11 | required: true 12 | type: string 13 | description: "Python version to use" 14 | 15 | env: 16 | UV_FROZEN: "true" 17 | UV_NO_SYNC: "true" 18 | 19 | jobs: 20 | build: 21 | defaults: 22 | run: 23 | working-directory: ${{ inputs.working-directory }} 24 | runs-on: ubuntu-latest 25 | timeout-minutes: 20 26 | name: "make test #${{ inputs.python-version }}" 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - name: Set up Python ${{ inputs.python-version }} + uv 31 | uses: "./.github/actions/uv_setup" 32 | id: setup-python 33 | with: 34 | python-version: ${{ inputs.python-version }} 35 | - name: Install dependencies 36 | shell: bash 37 | run: uv sync --group test 38 | 39 | - name: Run core tests 40 | shell: bash 41 | run: | 42 | make test 43 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Run CI Tests 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI 9 | 10 | # If another push to the same PR or branch happens while this workflow is still running, 11 | # cancel the earlier run in favor of the next run. 12 | # 13 | # There's no point in testing an outdated version of the code. GitHub only allows 14 | # a limited number of job runners to be active at the same time, so it's better to cancel 15 | # pointless jobs early so that more useful jobs can run sooner. 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | lint: 22 | strategy: 23 | matrix: 24 | # Only lint on the min and max supported Python versions. 25 | # It's extremely unlikely that there's a lint issue on any version in between 26 | # that doesn't show up on the min or max versions. 27 | # 28 | # GitHub rate-limits how many jobs can be running at any one time. 29 | # Starting new jobs is also relatively slow, 30 | # so linting on fewer versions makes CI faster. 31 | python-version: 32 | - "3.12" 33 | uses: 34 | ./.github/workflows/_lint.yml 35 | with: 36 | working-directory: . 37 | python-version: ${{ matrix.python-version }} 38 | secrets: inherit 39 | test: 40 | strategy: 41 | matrix: 42 | # Only lint on the min and max supported Python versions. 
43 | # It's extremely unlikely that there's a lint issue on any version in between 44 | # that doesn't show up on the min or max versions. 45 | # 46 | # GitHub rate-limits how many jobs can be running at any one time. 47 | # Starting new jobs is also relatively slow, 48 | # so linting on fewer versions makes CI faster. 49 | python-version: 50 | - "3.10" 51 | - "3.12" 52 | uses: 53 | ./.github/workflows/_test.yml 54 | with: 55 | working-directory: . 56 | python-version: ${{ matrix.python-version }} 57 | secrets: inherit 58 | 59 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }} 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | description: "From which folder this pipeline executes" 14 | default: "." 15 | dangerous-nonmain-release: 16 | required: false 17 | type: boolean 18 | default: false 19 | description: "Release from a non-main branch (danger!)" 20 | 21 | env: 22 | PYTHON_VERSION: "3.11" 23 | UV_FROZEN: "true" 24 | UV_NO_SYNC: "true" 25 | 26 | jobs: 27 | build: 28 | if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmain-release 29 | environment: Scheduled testing 30 | runs-on: ubuntu-latest 31 | 32 | outputs: 33 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 34 | version: ${{ steps.check-version.outputs.version }} 35 | 36 | steps: 37 | - uses: actions/checkout@v4 38 | 39 | - name: Set up Python + uv 40 | uses: "./.github/actions/uv_setup" 41 | with: 42 | python-version: ${{ env.PYTHON_VERSION }} 43 | 44 | # We want to keep this build stage *separate* from the release stage, 45 | # so that there's no sharing of permissions between them. 46 | # The release stage has trusted publishing and GitHub repo contents write access, 47 | # and we want to keep the scope of that access limited just to the release job. 48 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 49 | # could get access to our GitHub or PyPI credentials. 50 | # 51 | # Per the trusted publishing GitHub Action: 52 | # > It is strongly advised to separate jobs for building [...] 53 | # > from the publish job. 
54 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 55 | - name: Build project for distribution 56 | run: uv build 57 | - name: Upload build 58 | uses: actions/upload-artifact@v4 59 | with: 60 | name: dist 61 | path: ${{ inputs.working-directory }}/dist/ 62 | 63 | - name: Check Version 64 | id: check-version 65 | shell: python 66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | import os 69 | import tomllib 70 | with open("pyproject.toml", "rb") as f: 71 | data = tomllib.load(f) 72 | pkg_name = data["project"]["name"] 73 | version = data["project"]["version"] 74 | with open(os.environ["GITHUB_OUTPUT"], "a") as f: 75 | f.write(f"pkg-name={pkg_name}\n") 76 | f.write(f"version={version}\n") 77 | publish: 78 | needs: 79 | - build 80 | runs-on: ubuntu-latest 81 | permissions: 82 | # This permission is used for trusted publishing: 83 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 84 | # 85 | # Trusted publishing has to also be configured on PyPI for each package: 86 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 87 | id-token: write 88 | 89 | defaults: 90 | run: 91 | working-directory: ${{ inputs.working-directory }} 92 | 93 | steps: 94 | - uses: actions/checkout@v4 95 | 96 | - name: Set up Python + uv 97 | uses: "./.github/actions/uv_setup" 98 | with: 99 | python-version: ${{ env.PYTHON_VERSION }} 100 | 101 | - uses: actions/download-artifact@v4 102 | with: 103 | name: dist 104 | path: ${{ inputs.working-directory }}/dist/ 105 | 106 | - name: Publish package distributions to PyPI 107 | uses: pypa/gh-action-pypi-publish@release/v1 108 | with: 109 | packages-dir: ${{ inputs.working-directory }}/dist/ 110 | verbose: true 111 | print-hash: true 112 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0 113 | attestations: false 114 | 115 | mark-release: 116 | needs: 117 | - build 118 | - publish 119 | runs-on: ubuntu-latest 120 | permissions: 121 | # This permission is needed by `ncipollo/release-action` to 122 | # create the GitHub release. 
123 | contents: write 124 | 125 | defaults: 126 | run: 127 | working-directory: ${{ inputs.working-directory }} 128 | 129 | steps: 130 | - uses: actions/checkout@v4 131 | 132 | - name: Set up Python + uv 133 | uses: "./.github/actions/uv_setup" 134 | with: 135 | python-version: ${{ env.PYTHON_VERSION }} 136 | 137 | - uses: actions/download-artifact@v4 138 | with: 139 | name: dist 140 | path: ${{ inputs.working-directory }}/dist/ 141 | 142 | - name: Create Tag 143 | uses: ncipollo/release-action@v1 144 | with: 145 | artifacts: "dist/*" 146 | token: ${{ secrets.GITHUB_TOKEN }} 147 | generateReleaseNotes: true 148 | tag: ${{needs.build.outputs.pkg-name}}==${{ needs.build.outputs.version }} 149 | body: ${{ needs.release-notes.outputs.release-body }} 150 | commit: main 151 | makeLatest: true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Pyenv 2 | .python-version 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # Environments 33 | .venv 34 | .env 35 | 36 | # mypy 37 | .mypy_cache/ 38 | .dmypy.json 39 | dmypy.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 LangChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all lint format test help 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | ###################### 7 | # TESTING AND COVERAGE 8 | ###################### 9 | 10 | # Define a variable for the test file path. 
11 | TEST_FILE ?= tests/unit_tests 12 | 13 | integration_test integration_tests: TEST_FILE=tests/integration_tests/ 14 | 15 | test: 16 | uv run --group test pytest --disable-socket --allow-unix-socket $(TEST_FILE) 17 | 18 | integration_test integration_tests: 19 | uv run --group test --group test_integration pytest $(TEST_FILE) 20 | 21 | test_watch: 22 | uv run ptw . -- $(TEST_FILE) 23 | 24 | 25 | ###################### 26 | # LINTING AND FORMATTING 27 | ###################### 28 | 29 | # Define a variable for Python and notebook files. 30 | lint format: PYTHON_FILES=. 31 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=. --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 32 | 33 | lint lint_diff: 34 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) --diff 35 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff check $(PYTHON_FILES) --diff 36 | # [ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES) 37 | 38 | format format_diff: 39 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff check --fix $(PYTHON_FILES) 40 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) 41 | 42 | 43 | 44 | ###################### 45 | # HELP 46 | ###################### 47 | 48 | help: 49 | @echo '====================' 50 | @echo '-- LINTING --' 51 | @echo 'format - run code formatters' 52 | @echo 'lint - run linters' 53 | @echo '-- TESTS --' 54 | @echo 'test - run unit tests' 55 | @echo 'test TEST_FILE= - run all tests in file' 56 | @echo '-- DOCUMENTATION tasks are from the top-level Makefile --' 57 | 58 | 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # langgraph-bigtool 2 | 3 | `langgraph-bigtool` is a Python library for creating 4 | [LangGraph](https://langchain-ai.github.io/langgraph/) agents that can access large 5 | numbers of tools. It leverages LangGraph's long-term 6 | [memory store](https://langchain-ai.github.io/langgraph/how-tos/memory/semantic-search/) 7 | to allow an agent to search for and retrieve relevant tools for a given problem. 8 | 9 | ## Features 10 | 11 | - 🧰 **Scalable access to tools**: Equip agents with hundreds or thousands of tools. 12 | - 📝 **Storage of tool metadata**: Control storage of tool descriptions, namespaces, 13 | and other information through LangGraph's built-in 14 | [persistence layer](https://langchain-ai.github.io/langgraph/concepts/persistence/). 15 | Includes support for 16 | [in-memory](https://langchain-ai.github.io/langgraph/how-tos/cross-thread-persistence/) 17 | and 18 | [Postgres](https://langchain-ai.github.io/langgraph/reference/store/#langgraph.store.postgres.PostgresStore) 19 | backends. 20 | - 💡 **Customization of tool retrieval**: Optionally define custom functions for tool retrieval. 21 | 22 | This library is built on top of [LangGraph](https://github.com/langchain-ai/langgraph), a powerful framework for building agent applications, and comes with out-of-box support for [streaming](https://langchain-ai.github.io/langgraph/how-tos/#streaming), [short-term and long-term memory](https://langchain-ai.github.io/langgraph/concepts/memory/) and [human-in-the-loop](https://langchain-ai.github.io/langgraph/concepts/human_in_the_loop/). 23 | 24 | ## Installation 25 | 26 | ```bash 27 | pip install langgraph-bigtool 28 | ``` 29 | 30 | ## Quickstart 31 | 32 | We demonstrate `langgraph-bigtool` by equipping an agent with all functions from 33 | Python's built-in `math` library. 
34 | 35 | > [!NOTE] 36 | > This includes about 50 tools. Some LLMs can handle this number of tools together in 37 | a single invocation without issue. This example is for demonstration purposes. 38 | 39 | ```bash 40 | pip install langgraph-bigtool "langchain[openai]" 41 | 42 | export OPENAI_API_KEY= 43 | ``` 44 | 45 | ```python 46 | import math 47 | import types 48 | import uuid 49 | 50 | from langchain.chat_models import init_chat_model 51 | from langchain.embeddings import init_embeddings 52 | from langgraph.store.memory import InMemoryStore 53 | 54 | from langgraph_bigtool import create_agent 55 | from langgraph_bigtool.utils import ( 56 | convert_positional_only_function_to_tool 57 | ) 58 | 59 | # Collect functions from `math` built-in 60 | all_tools = [] 61 | for function_name in dir(math): 62 | function = getattr(math, function_name) 63 | if not isinstance( 64 | function, types.BuiltinFunctionType 65 | ): 66 | continue 67 | # This is an idiosyncrasy of the `math` library 68 | if tool := convert_positional_only_function_to_tool( 69 | function 70 | ): 71 | all_tools.append(tool) 72 | 73 | # Create registry of tools. This is a dict mapping 74 | # identifiers to tool instances. 75 | tool_registry = { 76 | str(uuid.uuid4()): tool 77 | for tool in all_tools 78 | } 79 | 80 | # Index tool names and descriptions in the LangGraph 81 | # Store. Here we use a simple in-memory store. 82 | embeddings = init_embeddings("openai:text-embedding-3-small") 83 | 84 | store = InMemoryStore( 85 | index={ 86 | "embed": embeddings, 87 | "dims": 1536, 88 | "fields": ["description"], 89 | } 90 | ) 91 | for tool_id, tool in tool_registry.items(): 92 | store.put( 93 | ("tools",), 94 | tool_id, 95 | { 96 | "description": f"{tool.name}: {tool.description}", 97 | }, 98 | ) 99 | 100 | # Initialize agent 101 | llm = init_chat_model("openai:gpt-4o-mini") 102 | 103 | builder = create_agent(llm, tool_registry) 104 | agent = builder.compile(store=store) 105 | agent 106 | ``` 107 | ![Graph diagram](static/img/graph.png) 108 | ```python 109 | query = "Use available tools to calculate arc cosine of 0.5." 110 | 111 | # Test it out 112 | for step in agent.stream( 113 | {"messages": query}, 114 | stream_mode="updates", 115 | ): 116 | for _, update in step.items(): 117 | for message in update.get("messages", []): 118 | message.pretty_print() 119 | ``` 120 | ``` 121 | ================================== Ai Message ================================== 122 | Tool Calls: 123 | retrieve_tools (call_nYZy6waIhivg94ZFhz3ju4K0) 124 | Call ID: call_nYZy6waIhivg94ZFhz3ju4K0 125 | Args: 126 | query: arc cosine calculation 127 | ================================= Tool Message ================================= 128 | 129 | Available tools: ['cos', 'acos'] 130 | ================================== Ai Message ================================== 131 | Tool Calls: 132 | acos (call_ynI4zBlJqXg4jfR21fVKDTTD) 133 | Call ID: call_ynI4zBlJqXg4jfR21fVKDTTD 134 | Args: 135 | x: 0.5 136 | ================================= Tool Message ================================= 137 | Name: acos 138 | 139 | 1.0471975511965976 140 | ================================== Ai Message ================================== 141 | 142 | The arc cosine of 0.5 is approximately 1.0472 radians. 143 | ``` 144 | 145 | ### Customizing tool retrieval 146 | 147 | `langgraph-bigtool` equips an agent with a tool that is used to retrieve tools in 148 | the registry. You can customize the retrieval by passing `retrieve_tools_function` 149 | and / or `retrieve_tools_coroutine` into `create_agent`. 
These functions are expected 150 | to return a list of IDs as output. 151 | ```python 152 | from langgraph.prebuilt import InjectedStore 153 | from langgraph.store.base import BaseStore 154 | from typing_extensions import Annotated 155 | 156 | 157 | def retrieve_tools( 158 | query: str, 159 | # Add custom arguments here... 160 | *, 161 | store: Annotated[BaseStore, InjectedStore], 162 | ) -> list[str]: 163 | """Retrieve a tool to use, given a search query.""" 164 | results = store.search(("tools",), query=query, limit=2) 165 | tool_ids = [result.key for result in results] 166 | # Insert your custom logic here... 167 | return tool_ids 168 | 169 | builder = create_agent( 170 | llm, tool_registry, retrieve_tools_function=retrieve_tools 171 | ) 172 | agent = builder.compile(store=store) 173 | ``` 174 | 175 | #### Retrieving tools without LangGraph Store 176 | You can implement arbitrary logic for the tool retrieval, which does not have to run 177 | semantic search against a query. Below, we return collections of tools corresponding 178 | to categories: 179 | ```python 180 | tool_registry = { 181 | "id_1": get_balance, 182 | "id_2": get_history, 183 | "id_3": create_ticket, 184 | } 185 | 186 | def retrieve_tools( 187 | category: Literal["billing", "service"], 188 | ) -> list[str]: 189 | """Get tools for a category.""" 190 | if category == "billing": 191 | return ["id_1", "id_2"] 192 | else: 193 | return ["id_3"] 194 | ``` 195 | > [!TIP] 196 | > Because the argument schema is inferred from type hints, type hinting the function 197 | argument as a `Literal` will signal that the LLM should populate a categorical value. 198 | 199 | ## Related work 200 | 201 | - [Toolshed: Scale Tool-Equipped Agents with Advanced RAG-Tool Fusion and Tool Knowledge Bases]( 202 | https://doi.org/10.48550/arXiv.2410.14594) - Lumer, E., Subbiah, V.K., Burke, J.A., 203 | Basavaraju, P.H. & Huber, A. (2024). arXiv:2410.14594. 204 | 205 | - [Graph RAG-Tool Fusion]( 206 | https://doi.org/10.48550/arXiv.2502.07223) - Lumer, E., Basavaraju, P.H., Mason, M., 207 | Burke, J.A. & Subbiah, V.K. (2025). arXiv:2502.07223. 208 | 209 | - https://github.com/quchangle1/LLM-Tool-Survey 210 | 211 | - [Retrieval Models Aren't Tool-Savvy: Benchmarking Tool Retrieval for Large Language Models]( 212 | https://doi.org/10.48550/arXiv.2503.01763) - Shi, Z., Wang, Y., Yan, L., Ren, P., 213 | Wang, S., Yin, D. & Ren, Z. arXiv:2503.01763. 
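## Async tool retrieval

`create_agent` also accepts a `retrieve_tools_coroutine`, which is useful when the agent runs in an async application. The sketch below is illustrative rather than canonical: it reuses the `llm`, `tool_registry`, and `store` objects from the Quickstart above, and mirrors the default retriever while swapping in the Store's async search.

```python
from langgraph.prebuilt import InjectedStore
from langgraph.store.base import BaseStore
from typing_extensions import Annotated

from langgraph_bigtool import create_agent


async def aretrieve_tools(
    query: str,
    *,
    store: Annotated[BaseStore, InjectedStore],
) -> list[str]:
    """Retrieve a tool to use, given a search query."""
    # Same semantics as the sync retriever, but using the async Store API.
    results = await store.asearch(("tools",), query=query, limit=2)
    return [result.key for result in results]


builder = create_agent(
    llm, tool_registry, retrieve_tools_coroutine=aretrieve_tools
)
agent = builder.compile(store=store)

# Run inside an async function (or a notebook cell that supports top-level await).
result = await agent.ainvoke(
    {"messages": "Use available tools to calculate arc cosine of 0.5."}
)
```

If you supply only a coroutine (no synchronous `retrieve_tools_function`), use the agent's async methods (`ainvoke` / `astream`); synchronous invocation will raise, since the tool-retrieval node is then async-only.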
214 | 
--------------------------------------------------------------------------------
/langgraph_bigtool/__init__.py:
--------------------------------------------------------------------------------
1 | from langgraph_bigtool.graph import create_agent
2 | 
3 | __all__ = ["create_agent"]
4 | 
--------------------------------------------------------------------------------
/langgraph_bigtool/graph.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Any, Callable
2 | 
3 | from langchain_core.language_models import LanguageModelLike
4 | from langchain_core.messages import AIMessage, ToolMessage
5 | from langchain_core.runnables import RunnableConfig
6 | from langchain_core.tools import BaseTool, StructuredTool
7 | from langgraph.graph import END, MessagesState, StateGraph
8 | from langgraph.prebuilt import ToolNode
9 | from langgraph.store.base import BaseStore
10 | from langgraph.types import Send
11 | from langgraph.utils.runnable import RunnableCallable
12 | 
13 | from langgraph_bigtool.tools import get_default_retrieval_tool, get_store_arg
14 | 
15 | 
16 | def _add_new(left: list, right: list) -> list:
17 |     """Extend left list with new items from right list."""
18 |     return left + [item for item in right if item not in set(left)]
19 | 
20 | 
21 | class State(MessagesState):
22 |     selected_tool_ids: Annotated[list[str], _add_new]
23 | 
24 | 
25 | def _format_selected_tools(
26 |     selected_tools: dict, tool_registry: dict[str, BaseTool]
27 | ) -> tuple[list[ToolMessage], list[str]]:
28 |     tool_messages = []
29 |     tool_ids = []
30 |     for tool_call_id, batch in selected_tools.items():
31 |         tool_names = []
32 |         for result in batch:
33 |             if isinstance(tool_registry[result], BaseTool):
34 |                 tool_names.append(tool_registry[result].name)
35 |             else:
36 |                 tool_names.append(tool_registry[result].__name__)
37 |         tool_messages.append(
38 |             ToolMessage(f"Available tools: {tool_names}", tool_call_id=tool_call_id)
39 |         )
40 |         tool_ids.extend(batch)
41 | 
42 |     return tool_messages, tool_ids
43 | 
44 | 
45 | def create_agent(
46 |     llm: LanguageModelLike,
47 |     tool_registry: dict[str, BaseTool | Callable],
48 |     *,
49 |     limit: int = 2,
50 |     filter: dict[str, Any] | None = None,
51 |     namespace_prefix: tuple[str, ...] = ("tools",),
52 |     retrieve_tools_function: Callable | None = None,
53 |     retrieve_tools_coroutine: Callable | None = None,
54 | ) -> StateGraph:
55 |     """Create an agent with a registry of tools.
56 | 
57 |     The agent will function as a typical ReAct agent, but is equipped with a tool
58 |     for retrieving tools from a registry. The agent will start with only this tool.
59 |     As it is executed, retrieved tools will be bound to the model.
60 | 
61 |     Args:
62 |         llm: Language model to use for the agent.
63 |         tool_registry: a dict mapping string IDs to tools or callables.
64 |         limit: Maximum number of tools to retrieve with each tool selection step.
65 |         filter: Optional key-value pairs with which to filter results.
66 |         namespace_prefix: Hierarchical path prefix to search within the Store. Defaults
67 |             to ("tools",).
68 |         retrieve_tools_function: Optional function to use for retrieving tools. This
69 |             function should return a list of tool IDs. If not specified, uses semantic
70 |             search against the Store with limit, filter, and namespace_prefix set above.
71 |         retrieve_tools_coroutine: Optional coroutine to use for retrieving tools. This
72 |             function should return a list of tool IDs. If not specified, uses semantic
73 |             search against the Store with limit, filter, and namespace_prefix set above.
74 |     """
75 |     if retrieve_tools_function is None and retrieve_tools_coroutine is None:
76 |         retrieve_tools_function, retrieve_tools_coroutine = get_default_retrieval_tool(
77 |             namespace_prefix, limit=limit, filter=filter
78 |         )
79 |     retrieve_tools = StructuredTool.from_function(
80 |         func=retrieve_tools_function, coroutine=retrieve_tools_coroutine
81 |     )
82 |     # If needed, get argument name to inject Store
83 |     store_arg = get_store_arg(retrieve_tools)
84 | 
85 |     def call_model(state: State, config: RunnableConfig, *, store: BaseStore) -> State:
86 |         selected_tools = [tool_registry[id] for id in state["selected_tool_ids"]]
87 |         llm_with_tools = llm.bind_tools([retrieve_tools, *selected_tools])
88 |         response = llm_with_tools.invoke(state["messages"])
89 |         return {"messages": [response]}
90 | 
91 |     async def acall_model(
92 |         state: State, config: RunnableConfig, *, store: BaseStore
93 |     ) -> State:
94 |         selected_tools = [tool_registry[id] for id in state["selected_tool_ids"]]
95 |         llm_with_tools = llm.bind_tools([retrieve_tools, *selected_tools])
96 |         response = await llm_with_tools.ainvoke(state["messages"])
97 |         return {"messages": [response]}
98 | 
99 |     tool_node = ToolNode(tool for tool in tool_registry.values())
100 | 
101 |     def select_tools(
102 |         tool_calls: list[dict], config: RunnableConfig, *, store: BaseStore
103 |     ) -> State:
104 |         selected_tools = {}
105 |         for tool_call in tool_calls:
106 |             kwargs = {**tool_call["args"]}
107 |             if store_arg:
108 |                 kwargs[store_arg] = store
109 |             result = retrieve_tools.invoke(kwargs)
110 |             selected_tools[tool_call["id"]] = result
111 | 
112 |         tool_messages, tool_ids = _format_selected_tools(selected_tools, tool_registry)
113 |         return {"messages": tool_messages, "selected_tool_ids": tool_ids}
114 | 
115 |     async def aselect_tools(
116 |         tool_calls: list[dict], config: RunnableConfig, *, store: BaseStore
117 |     ) -> State:
118 |         selected_tools = {}
119 |         for tool_call in tool_calls:
120 |             kwargs = {**tool_call["args"]}
121 |             if store_arg:
122 |                 kwargs[store_arg] = store
123 |             result = await retrieve_tools.ainvoke(kwargs)
124 |             selected_tools[tool_call["id"]] = result
125 | 
126 |         tool_messages, tool_ids = _format_selected_tools(selected_tools, tool_registry)
127 |         return {"messages": tool_messages, "selected_tool_ids": tool_ids}
128 | 
129 |     def should_continue(state: State, *, store: BaseStore):
130 |         messages = state["messages"]
131 |         last_message = messages[-1]
132 |         if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
133 |             return END
134 |         else:
135 |             destinations = []
136 |             for call in last_message.tool_calls:
137 |                 if call["name"] == retrieve_tools.name:
138 |                     destinations.append(Send("select_tools", [call]))
139 |                 else:
140 |                     tool_call = tool_node.inject_tool_args(call, state, store)
141 |                     destinations.append(Send("tools", [tool_call]))
142 | 
143 |             return destinations
144 | 
145 |     builder = StateGraph(State)
146 | 
147 |     if retrieve_tools_function is not None and retrieve_tools_coroutine is not None:
148 |         select_tools_node = RunnableCallable(select_tools, aselect_tools)
149 |     elif retrieve_tools_function is not None and retrieve_tools_coroutine is None:
150 |         select_tools_node = select_tools
151 |     elif retrieve_tools_coroutine is not None and retrieve_tools_function is None:
152 |         select_tools_node = aselect_tools
153 |     else:
154 |         raise ValueError(
155 |             "One of retrieve_tools_function or retrieve_tools_coroutine must be "
156 |             "provided."
157 |         )
158 | 
159 |     builder.add_node("agent", RunnableCallable(call_model, acall_model))
160 |     builder.add_node("select_tools", select_tools_node)
161 |     builder.add_node("tools", tool_node)
162 | 
163 |     builder.set_entry_point("agent")
164 | 
165 |     builder.add_conditional_edges(
166 |         "agent",
167 |         should_continue,
168 |         path_map=["select_tools", "tools", END],
169 |     )
170 |     builder.add_edge("tools", "agent")
171 |     builder.add_edge("select_tools", "agent")
172 | 
173 |     return builder
174 | 
--------------------------------------------------------------------------------
/langgraph_bigtool/tools.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Type, Union
2 | 
3 | from langchain_core.tools.base import (
4 |     BaseTool,
5 |     get_all_basemodel_annotations,
6 | )
7 | from langgraph.prebuilt import InjectedState, InjectedStore
8 | from langgraph.store.base import BaseStore
9 | from typing_extensions import Annotated, get_args, get_origin
10 | 
11 | ToolId = str
12 | 
13 | 
14 | def get_default_retrieval_tool(
15 |     namespace_prefix: tuple[str, ...],
16 |     *,
17 |     limit: int = 2,
18 |     filter: dict[str, Any] | None = None,
19 | ):
20 |     """Get default sync and async functions for tool retrieval."""
21 | 
22 |     def retrieve_tools(
23 |         query: str,
24 |         *,
25 |         store: Annotated[BaseStore, InjectedStore],
26 |     ) -> list[ToolId]:
27 |         """Retrieve a tool to use, given a search query."""
28 |         results = store.search(
29 |             namespace_prefix,
30 |             query=query,
31 |             limit=limit,
32 |             filter=filter,
33 |         )
34 |         return [result.key for result in results]
35 | 
36 |     async def aretrieve_tools(
37 |         query: str,
38 |         *,
39 |         store: Annotated[BaseStore, InjectedStore],
40 |     ) -> list[ToolId]:
41 |         """Retrieve a tool to use, given a search query."""
42 |         results = await store.asearch(
43 |             namespace_prefix,
44 |             query=query,
45 |             limit=limit,
46 |             filter=filter,
47 |         )
48 |         return [result.key for result in results]
49 | 
50 |     return retrieve_tools, aretrieve_tools
51 | 
52 | 
53 | def _is_injection(
54 |     type_arg: Any, injection_type: Union[Type[InjectedState], Type[InjectedStore]]
55 | ) -> bool:
56 |     if isinstance(type_arg, injection_type) or (
57 |         isinstance(type_arg, type) and issubclass(type_arg, injection_type)
58 |     ):
59 |         return True
60 |     origin_ = get_origin(type_arg)
61 |     if origin_ is Union or origin_ is Annotated:
62 |         return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
63 |     return False
64 | 
65 | 
66 | def get_store_arg(tool: BaseTool) -> str | None:
67 |     full_schema = tool.get_input_schema()
68 |     for name, type_ in get_all_basemodel_annotations(full_schema).items():
69 |         injections = [
70 |             type_arg
71 |             for type_arg in get_args(type_)
72 |             if _is_injection(type_arg, InjectedStore)
73 |         ]
74 |         if len(injections) > 1:
75 |             raise ValueError(
76 |                 "A tool argument should not be annotated with InjectedStore more than "
77 |                 f"once. Received arg {name} with annotations {injections}."
78 | ) 79 | elif len(injections) == 1: 80 | return name 81 | else: 82 | pass 83 | 84 | return None 85 | -------------------------------------------------------------------------------- /langgraph_bigtool/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from functools import wraps 3 | from typing import Callable 4 | 5 | from langchain_core._api import beta 6 | from langchain_core.tools import tool 7 | 8 | 9 | @beta() 10 | def convert_positional_only_function_to_tool(func: Callable): 11 | """Handle tool creation for functions with positional-only args.""" 12 | try: 13 | original_signature = inspect.signature(func) 14 | except ValueError: # no signature 15 | return None 16 | new_params = [] 17 | 18 | # Convert any POSITIONAL_ONLY parameters into POSITIONAL_OR_KEYWORD 19 | for param in original_signature.parameters.values(): 20 | if param.kind == inspect.Parameter.VAR_POSITIONAL: 21 | return None 22 | if param.kind == inspect.Parameter.POSITIONAL_ONLY: 23 | new_params.append( 24 | param.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) 25 | ) 26 | else: 27 | new_params.append(param) 28 | 29 | updated_signature = inspect.Signature(new_params) 30 | 31 | @wraps(func) 32 | def wrapper(*args, **kwargs): 33 | bound = updated_signature.bind(*args, **kwargs) 34 | bound.apply_defaults() 35 | return func(*bound.args, **bound.kwargs) 36 | 37 | wrapper.__signature__ = updated_signature 38 | 39 | return tool(wrapper) 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [project] 6 | name = "langgraph-bigtool" 7 | version = "0.0.3" 8 | description = "Build LangGraph agents with large numbers of tools" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | dependencies = [ 12 | "langgraph>=0.3.0", 13 | ] 14 | 15 | [dependency-groups] 16 | test = [ 17 | "numpy>=1", 18 | "pytest>=8.0.0", 19 | "ruff>=0.9.4", 20 | "mypy>=1.8.0", 21 | "pytest-socket>=0.7.0", 22 | "pytest-asyncio>=0.21.1", 23 | "types-setuptools>=69.0.0", 24 | ] 25 | test_integration = [ 26 | "langchain[openai]>=0.3.20", 27 | ] 28 | 29 | [tool.pytest.ini_options] 30 | minversion = "8.0" 31 | addopts = "-ra -q -v" 32 | testpaths = [ 33 | "tests", 34 | ] 35 | python_files = ["test_*.py"] 36 | python_functions = ["test_*"] 37 | asyncio_mode = "auto" 38 | 39 | [tool.ruff] 40 | line-length = 88 41 | target-version = "py310" 42 | 43 | [tool.ruff.lint] 44 | select = [ 45 | "E", # pycodestyle errors 46 | "W", # pycodestyle warnings 47 | "F", # pyflakes 48 | "I", # isort 49 | "B", # flake8-bugbear 50 | ] 51 | ignore = [ 52 | "E501" # line-length 53 | ] 54 | 55 | 56 | [tool.mypy] 57 | python_version = "3.11" 58 | warn_return_any = true 59 | warn_unused_configs = true 60 | disallow_untyped_defs = true 61 | check_untyped_defs = true 62 | -------------------------------------------------------------------------------- /static/img/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langgraph-bigtool/42119f4664b100678b2a39f82ca8bf58aadeef19/static/img/graph.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/langchain-ai/langgraph-bigtool/42119f4664b100678b2a39f82ca8bf58aadeef19/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langgraph-bigtool/42119f4664b100678b2a39f82ca8bf58aadeef19/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | from langchain.chat_models import init_chat_model 2 | from langchain.embeddings import init_embeddings 3 | 4 | from tests.unit_tests.test_end_to_end import run_end_to_end_test 5 | 6 | 7 | def test_end_to_end() -> None: 8 | llm = init_chat_model("openai:gpt-4o") 9 | embeddings = init_embeddings("openai:text-embedding-3-small") 10 | run_end_to_end_test(llm, embeddings) 11 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langgraph-bigtool/42119f4664b100678b2a39f82ca8bf58aadeef19/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | import types 4 | import uuid 5 | from typing import Callable 6 | from unittest.mock import patch 7 | 8 | import pytest 9 | from langchain_core.embeddings import Embeddings 10 | from langchain_core.embeddings.fake import DeterministicFakeEmbedding 11 | from langchain_core.language_models import GenericFakeChatModel, LanguageModelLike 12 | from langchain_core.messages import AIMessage, ToolMessage 13 | from langchain_core.tools import BaseTool 14 | from langgraph.prebuilt import InjectedStore 15 | from langgraph.store.base import BaseStore 16 | from langgraph.store.memory import InMemoryStore 17 | from typing_extensions import Annotated 18 | 19 | from langgraph_bigtool import create_agent 20 | from langgraph_bigtool.graph import State 21 | from langgraph_bigtool.utils import convert_positional_only_function_to_tool 22 | 23 | EMBEDDING_SIZE = 1536 24 | 25 | 26 | # Create a list of all the functions in the math module 27 | all_names = dir(math) 28 | 29 | math_functions = [ 30 | getattr(math, name) 31 | for name in all_names 32 | if isinstance(getattr(math, name), types.BuiltinFunctionType) 33 | ] 34 | 35 | # Convert to tools, handling positional-only arguments (idiosyncrasy of math module) 36 | all_tools = [] 37 | for function in math_functions: 38 | if wrapper := convert_positional_only_function_to_tool(function): 39 | all_tools.append(wrapper) 40 | 41 | # Store tool objects in registry 42 | tool_registry = {str(uuid.uuid4()): tool for tool in all_tools} 43 | 44 | 45 | class FakeModel(GenericFakeChatModel): 46 | def bind_tools(self, *args, **kwargs) -> "FakeModel": 47 | """Do nothing for now.""" 48 | return self 49 | 50 | 51 | def _get_fake_llm_and_embeddings(retriever_tool_name: str = "retrieve_tools"): 52 | fake_embeddings = DeterministicFakeEmbedding(size=EMBEDDING_SIZE) 53 | 54 | acos_tool = next(tool for tool in tool_registry.values() if tool.name == "acos") 55 | initial_query = ( 56 | 
f"{acos_tool.name}: {acos_tool.description}" # make same as embedding 57 | ) 58 | fake_llm = FakeModel( 59 | messages=iter( 60 | [ 61 | AIMessage( 62 | "", 63 | tool_calls=[ 64 | { 65 | "name": retriever_tool_name, 66 | "args": {"query": initial_query}, 67 | "id": "abc123", 68 | "type": "tool_call", 69 | } 70 | ], 71 | ), 72 | AIMessage( 73 | "", 74 | tool_calls=[ 75 | { 76 | "name": "acos", 77 | "args": {"x": 0.5}, 78 | "id": "abc234", 79 | "type": "tool_call", 80 | } 81 | ], 82 | ), 83 | AIMessage("The arc cosine of 0.5 is approximately 1.047 radians."), 84 | ] 85 | ) 86 | ) 87 | 88 | return fake_llm, fake_embeddings 89 | 90 | 91 | def _validate_result(result: State, tool_registry=tool_registry) -> None: 92 | assert set(result.keys()) == {"messages", "selected_tool_ids"} 93 | selected = [] 94 | for tool_id in result["selected_tool_ids"]: 95 | if isinstance(tool_registry[tool_id], BaseTool): 96 | selected.append(tool_registry[tool_id].name) 97 | else: 98 | selected.append(tool_registry[tool_id].__name__) 99 | assert "acos" in selected 100 | assert set(message.type for message in result["messages"]) == { 101 | "human", 102 | "ai", 103 | "tool", 104 | } 105 | tool_calls = [ 106 | tool_call 107 | for message in result["messages"] 108 | if isinstance(message, AIMessage) 109 | for tool_call in message.tool_calls 110 | ] 111 | assert tool_calls 112 | tool_call_names = [tool_call["name"] for tool_call in tool_calls] 113 | assert "retrieve_tools" in tool_call_names 114 | math_tool_calls = [ 115 | tool_call for tool_call in tool_calls if tool_call["name"] == "acos" 116 | ] 117 | assert len(math_tool_calls) == 1 118 | math_tool_call = math_tool_calls[0] 119 | tool_messages = [ 120 | message 121 | for message in result["messages"] 122 | if isinstance(message, ToolMessage) 123 | and message.tool_call_id == math_tool_call["id"] 124 | ] 125 | assert len(tool_messages) == 1 126 | tool_message = tool_messages[0] 127 | assert round(float(tool_message.content), 4) == 1.0472 128 | reply = result["messages"][-1] 129 | assert isinstance(reply, AIMessage) 130 | assert not reply.tool_calls 131 | assert reply.content 132 | 133 | 134 | def run_end_to_end_test( 135 | llm: LanguageModelLike, 136 | embeddings: Embeddings, 137 | retrieve_tools_function: Callable | None = None, 138 | retrieve_tools_coroutine: Callable | None = None, 139 | ) -> None: 140 | # Store tool descriptions in store 141 | store = InMemoryStore( 142 | index={ 143 | "embed": embeddings, 144 | "dims": EMBEDDING_SIZE, 145 | "fields": ["description"], 146 | } 147 | ) 148 | for tool_id, tool in tool_registry.items(): 149 | store.put( 150 | ("tools",), 151 | tool_id, 152 | { 153 | "description": f"{tool.name}: {tool.description}", 154 | }, 155 | ) 156 | 157 | builder = create_agent( 158 | llm, 159 | tool_registry, 160 | retrieve_tools_function=retrieve_tools_function, 161 | retrieve_tools_coroutine=retrieve_tools_coroutine, 162 | ) 163 | agent = builder.compile(store=store) 164 | 165 | result = agent.invoke( 166 | {"messages": "Use available tools to calculate arc cosine of 0.5."} 167 | ) 168 | _validate_result(result) 169 | 170 | 171 | async def run_end_to_end_test_async( 172 | llm: LanguageModelLike, 173 | embeddings: Embeddings, 174 | retrieve_tools_function: Callable | None = None, 175 | retrieve_tools_coroutine: Callable | None = None, 176 | ) -> None: 177 | # Store tool descriptions in store 178 | store = InMemoryStore( 179 | index={ 180 | "embed": embeddings, 181 | "dims": EMBEDDING_SIZE, 182 | "fields": ["description"], 183 | } 184 | ) 185 | 
for tool_id, tool in tool_registry.items(): 186 | await store.aput( 187 | ("tools",), 188 | tool_id, 189 | { 190 | "description": f"{tool.name}: {tool.description}", 191 | }, 192 | ) 193 | 194 | builder = create_agent( 195 | llm, 196 | tool_registry, 197 | retrieve_tools_function=retrieve_tools_function, 198 | retrieve_tools_coroutine=retrieve_tools_coroutine, 199 | ) 200 | agent = builder.compile(store=store) 201 | 202 | result = await agent.ainvoke( 203 | {"messages": "Use available tools to calculate arc cosine of 0.5."} 204 | ) 205 | _validate_result(result) 206 | 207 | 208 | class CustomError(Exception): 209 | pass 210 | 211 | 212 | def custom_retrieve_tools_store( 213 | query: str, 214 | *, 215 | store: Annotated[BaseStore, InjectedStore], 216 | ) -> list[str]: 217 | """Custom retrieve tools.""" 218 | raise CustomError 219 | 220 | 221 | async def acustom_retrieve_tools_store( 222 | query: str, 223 | *, 224 | store: Annotated[BaseStore, InjectedStore], 225 | ) -> list[str]: 226 | """Custom retrieve tools.""" 227 | raise CustomError 228 | 229 | 230 | def custom_retrieve_tools_no_store(query: str) -> list[str]: 231 | """Custom retrieve tools.""" 232 | raise CustomError 233 | 234 | 235 | async def acustom_retrieve_tools_no_store(query: str) -> list[str]: 236 | """Custom retrieve tools.""" 237 | raise CustomError 238 | 239 | 240 | @pytest.mark.parametrize( 241 | "custom_retrieve_tools, acustom_retrieve_tools", 242 | [ 243 | (custom_retrieve_tools_store, acustom_retrieve_tools_store), 244 | (custom_retrieve_tools_no_store, acustom_retrieve_tools_no_store), 245 | ], 246 | ) 247 | def test_end_to_end(custom_retrieve_tools, acustom_retrieve_tools) -> None: 248 | retriever_tool_name = custom_retrieve_tools.__name__ 249 | retriever_tool_name_async = acustom_retrieve_tools.__name__ 250 | # Default 251 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings() 252 | run_end_to_end_test(fake_llm, fake_embeddings) 253 | 254 | # Custom 255 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings( 256 | retriever_tool_name=retriever_tool_name_async 257 | ) 258 | with pytest.raises(TypeError): 259 | # No sync function provided 260 | run_end_to_end_test( 261 | fake_llm, 262 | fake_embeddings, 263 | retrieve_tools_coroutine=acustom_retrieve_tools, 264 | ) 265 | 266 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings( 267 | retriever_tool_name=retriever_tool_name 268 | ) 269 | with pytest.raises(CustomError): 270 | # Calls custom sync function 271 | run_end_to_end_test( 272 | fake_llm, 273 | fake_embeddings, 274 | retrieve_tools_function=custom_retrieve_tools, 275 | retrieve_tools_coroutine=acustom_retrieve_tools, 276 | ) 277 | 278 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings( 279 | retriever_tool_name=retriever_tool_name 280 | ) 281 | with pytest.raises(CustomError): 282 | # Calls custom sync function 283 | run_end_to_end_test( 284 | fake_llm, 285 | fake_embeddings, 286 | retrieve_tools_function=custom_retrieve_tools, 287 | ) 288 | 289 | 290 | @pytest.mark.parametrize( 291 | "custom_retrieve_tools, acustom_retrieve_tools", 292 | [ 293 | (custom_retrieve_tools_store, acustom_retrieve_tools_store), 294 | (custom_retrieve_tools_no_store, acustom_retrieve_tools_no_store), 295 | ], 296 | ) 297 | async def test_end_to_end_async(custom_retrieve_tools, acustom_retrieve_tools) -> None: 298 | retriever_tool_name = custom_retrieve_tools.__name__ 299 | retriever_tool_name_async = acustom_retrieve_tools.__name__ 300 | # Default 301 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings() 
302 | await run_end_to_end_test_async(fake_llm, fake_embeddings) 303 | 304 | # Custom 305 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings( 306 | retriever_tool_name=retriever_tool_name 307 | ) 308 | with pytest.raises(CustomError): 309 | # Calls custom sync function 310 | await run_end_to_end_test_async( 311 | fake_llm, 312 | fake_embeddings, 313 | retrieve_tools_function=custom_retrieve_tools, 314 | ) 315 | 316 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings( 317 | retriever_tool_name=retriever_tool_name 318 | ) 319 | with pytest.raises(CustomError): 320 | # Calls custom sync function 321 | await run_end_to_end_test_async( 322 | fake_llm, 323 | fake_embeddings, 324 | retrieve_tools_function=custom_retrieve_tools, 325 | retrieve_tools_coroutine=acustom_retrieve_tools, 326 | ) 327 | 328 | fake_llm, fake_embeddings = _get_fake_llm_and_embeddings( 329 | retriever_tool_name=retriever_tool_name_async 330 | ) 331 | with pytest.raises(CustomError): 332 | # Calls custom sync function 333 | await run_end_to_end_test_async( 334 | fake_llm, 335 | fake_embeddings, 336 | retrieve_tools_coroutine=acustom_retrieve_tools, 337 | ) 338 | 339 | 340 | def test_duplicate_tools() -> None: 341 | fake_embeddings = DeterministicFakeEmbedding(size=EMBEDDING_SIZE) 342 | 343 | acos_tool = next(tool for tool in tool_registry.values() if tool.name == "acos") 344 | initial_query = ( 345 | f"{acos_tool.name}: {acos_tool.description}" # make same as embedding 346 | ) 347 | 348 | fake_llm = FakeModel( 349 | messages=iter( 350 | [ 351 | AIMessage( 352 | "", 353 | tool_calls=[ 354 | { 355 | "name": "retrieve_tools", 356 | "args": {"query": initial_query}, 357 | "id": "abc123", 358 | "type": "tool_call", 359 | } 360 | ], 361 | ), 362 | AIMessage( 363 | "", 364 | tool_calls=[ 365 | { 366 | "name": "acos", 367 | "args": {"x": 0.5}, 368 | "id": "abc234", 369 | "type": "tool_call", 370 | } 371 | ], 372 | ), 373 | AIMessage( 374 | "", 375 | tool_calls=[ 376 | { 377 | "name": "retrieve_tools", 378 | "args": {"query": "another tool"}, 379 | "id": "abc345", 380 | "type": "tool_call", 381 | }, 382 | # Retrieval can return the same tool multiple times. Force this 383 | # by adding the same tool call twice. 
384 | { 385 | "name": "retrieve_tools", 386 | "args": {"query": initial_query}, 387 | "id": "abc456", 388 | "type": "tool_call", 389 | }, 390 | ], 391 | ), 392 | AIMessage("The arc cosine of 0.5 is approximately 1.047 radians."), 393 | ] 394 | ) 395 | ) 396 | with patch.object( 397 | FakeModel, "bind_tools", wraps=fake_llm.bind_tools 398 | ) as mock_bind_tools: 399 | run_end_to_end_test(fake_llm, fake_embeddings) 400 | mock_bind_tools.assert_called() 401 | for args, _ in mock_bind_tools.call_args_list: 402 | tool_names = [tool.name for tool in args[0] if isinstance(tool, BaseTool)] 403 | assert len(tool_names) == len(set(tool_names)) 404 | 405 | 406 | def test_functions_in_registry() -> None: 407 | tool_registry = {str(uuid.uuid4()): tool.func for tool in all_tools} 408 | fake_embeddings = DeterministicFakeEmbedding(size=EMBEDDING_SIZE) 409 | 410 | acos_tool = next(tool for tool in tool_registry.values() if tool.__name__ == "acos") 411 | initial_query = ( 412 | f"{acos_tool.__name__}: {inspect.getdoc(acos_tool)}" # make same as embedding 413 | ) 414 | fake_llm = FakeModel( 415 | messages=iter( 416 | [ 417 | AIMessage( 418 | "", 419 | tool_calls=[ 420 | { 421 | "name": "retrieve_tools", 422 | "args": {"query": initial_query}, 423 | "id": "abc123", 424 | "type": "tool_call", 425 | } 426 | ], 427 | ), 428 | AIMessage( 429 | "", 430 | tool_calls=[ 431 | { 432 | "name": "acos", 433 | "args": {"x": 0.5}, 434 | "id": "abc234", 435 | "type": "tool_call", 436 | } 437 | ], 438 | ), 439 | AIMessage("The arc cosine of 0.5 is approximately 1.047 radians."), 440 | ] 441 | ) 442 | ) 443 | store = InMemoryStore( 444 | index={ 445 | "embed": fake_embeddings, 446 | "dims": EMBEDDING_SIZE, 447 | "fields": ["description"], 448 | } 449 | ) 450 | for tool_id, tool in tool_registry.items(): 451 | store.put( 452 | ("tools",), 453 | tool_id, 454 | { 455 | "description": f"{tool.__name__}: {inspect.getdoc(tool)}", 456 | }, 457 | ) 458 | 459 | builder = create_agent( 460 | fake_llm, 461 | tool_registry, 462 | ) 463 | agent = builder.compile(store=store) 464 | 465 | result = agent.invoke( 466 | {"messages": "Use available tools to calculate arc cosine of 0.5."} 467 | ) 468 | _validate_result(result, tool_registry=tool_registry) 469 | --------------------------------------------------------------------------------