├── pyrightconfig.json ├── src └── mcp_langgraph_tools │ ├── __init__.py │ ├── __main__.py │ └── mcp_tool_node.py ├── .gitattributes ├── LICENSE ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── pyproject.toml ├── ruff.toml └── Makefile /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "include": [ 3 | "src/**/*.py" 4 | ], 5 | "exclude": [ 6 | "**/node_modules", 7 | "**/__pycache__", 8 | "**/output", 9 | "**/build" 10 | ], 11 | "ignore": [ 12 | "**/.venv" 13 | ], 14 | "defineConstant": { 15 | "DEBUG": true 16 | }, 17 | "venvPath": ".", 18 | "venv": ".venv", 19 | "reportMissingImports": true, 20 | "reportMissingTypeStubs": false, 21 | "pythonVersion": "3.11", 22 | "typeCheckingMode": "basic" 23 | } 24 | -------------------------------------------------------------------------------- /src/mcp_langgraph_tools/__init__.py: -------------------------------------------------------------------------------- 1 | """Mcp Langgraph Tools.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | import warnings 7 | 8 | from langchain._api import LangChainDeprecationWarning 9 | from langchain_core._api import LangChainBetaWarning 10 | 11 | warnings.simplefilter("ignore", category=LangChainDeprecationWarning) 12 | warnings.simplefilter("ignore", category=LangChainBetaWarning) 13 | 14 | 15 | __author__ = "Paul Robello" 16 | __credits__ = ["Paul Robello"] 17 | __maintainer__ = "Paul Robello" 18 | __email__ = "probello@gmail.com" 19 | __version__ = "0.1.0" 20 | __application_title__ = "Mcp Langgraph Tools" 21 | __application_binary__ = "mcp_langgraph_tools" 22 | __licence__ = "MIT" 23 | 24 | 25 | os.environ["USER_AGENT"] = f"{__application_title__} {__version__}" 26 | 27 | 28 | __all__: list[str] = [ 29 | "__author__", 30 | "__credits__", 31 | "__maintainer__", 32 | "__email__", 33 | "__version__", 34 | "__application_binary__", 35 | "__licence__", 36 | "__application_title__", 37 | ] 38 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Autodetect text files 2 | * text=auto 3 | 4 | # ...Unless the name matches the following 5 | # overriding patterns 6 | 7 | # Definitively text files 8 | *.txt text 9 | *.json text 10 | *.js text 11 | *.ts text 12 | .env text 13 | .env-* text 14 | *.sh text 15 | *.sql text 16 | *.yml text 17 | *.py text 18 | *.js text 19 | *.ts text 20 | *.ini text 21 | *.jq text 22 | Dockerfile text 23 | Dockerfile.* text 24 | makefile text 25 | makefile.* text 26 | Makefile text 27 | Makefile.* text 28 | 29 | # Ensure those won't be messed up with 30 | *.jpg binary 31 | *.gif binary 32 | *.png binary 33 | 34 | # force line endings to be lf so db container does not blow up 35 | **/*.sh text eol=lf 36 | **/*.sql text eol=lf 37 | **/.env text eol=lf 38 | **/.env-* text eol=lf 39 | **/Dockerfile text eol=lf 40 | **/Dockerfile.* text eol=lf 41 | **/*.py text eol=lf 42 | **/*.js text eol=lf 43 | **/*.ts text eol=lf 44 | **/*.jq text eol=lf 45 | **/*.json text eol=lf 46 | **/*.yml text eol=lf 47 | **/Makefile text eol=lf 48 | **/Makefile.* text eol=lf 49 | **/makefile text eol=lf 50 | **/makefile.* text eol=lf 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Paul Robello 4 | 5 | Permission is hereby granted, free 
of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### PythonVanilla template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # Installer logs 30 | pip-log.txt 31 | pip-delete-this-directory.txt 32 | 33 | # Unit test / coverage reports 34 | htmlcov/ 35 | .tox/ 36 | .nox/ 37 | .coverage 38 | .coverage.* 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | *.cover 43 | *.py,cover 44 | .hypothesis/ 45 | .pytest_cache/ 46 | cover/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # pyenv 53 | # For a library or package, you might want to ignore these files since the code is 54 | # intended to run in multiple environments; otherwise, check them in: 55 | # .python-version 56 | 57 | # pipenv 58 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 59 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 60 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 61 | # install all needed dependencies. 62 | #Pipfile.lock 63 | 64 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
65 | __pypackages__/
66 | .ruff_cache
67 | 
68 | .aider*
69 | **/venv
70 | **/.venv
71 | **/.env
72 | **/.idea
73 | /config.json
74 | /output/
75 | /history.json
76 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_stages: [pre-commit, pre-push]
2 | default_language_version:
3 |   python: python3.11
4 | fail_fast: false
5 | repos:
6 |   - repo: https://github.com/pre-commit/pre-commit-hooks
7 |     rev: v4.6.0
8 |     hooks:
9 |       - id: check-merge-conflict
10 |       - id: detect-private-key
11 |       - id: end-of-file-fixer
12 |       - id: mixed-line-ending
13 |       - id: trailing-whitespace
14 |         args: [--markdown-linebreak-ext=md]
15 |       - id: check-docstring-first
16 |       - id: check-toml
17 |       - id: check-yaml
18 |       - id: check-json
19 |       - id: pretty-format-json
20 |         args: [--autofix, --no-sort-keys]
21 |         exclude: tests(/\w*)*/functional/|tests/input|tests(/.*)+/conftest.py|doc/data/messages|tests(/\w*)*data/|Pipfile.lock|output/.*
22 | 
23 |   - repo: local
24 |     hooks:
25 |       - id: pyright
26 |         name: pyright
27 |         entry: make
28 |         language: system
29 |         pass_filenames: false
30 |         args:
31 |           [typecheck]
32 |         exclude: tests(/\w*)*/functional/|tests/input|tests(/\w*)*data/|doc/|output/.*
33 | 
34 |   - repo: local
35 |     hooks:
36 |       - id: format
37 |         name: format
38 |         entry: make
39 |         language: system
40 |         pass_filenames: false
41 |         args:
42 |           [format]
43 |         exclude: tests(/\w*)*/functional/|tests/input|tests(/\w*)*data/|doc/|output/.*
44 | 
45 |   - repo: local
46 |     hooks:
47 |       - id: lint
48 |         name: lint
49 |         entry: make
50 |         language: system
51 |         pass_filenames: false
52 |         args:
53 |           [lint]
54 |         exclude: tests(/\w*)*/functional/|tests/input|tests(/\w*)*data/|doc/|output/.*
55 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MCP Tool Langgraph Integration
2 | 
3 | ## Description
4 | Example project showing how to integrate MCP endpoint tools into a LangGraph tool node.
5 | 
6 | The graph consists of only two nodes, `assistant` and `tools`.
7 | 
8 | ## Prerequisites
9 | To use this project, make sure you have Python 3.11 or newer.
10 | 
11 | ### [uv](https://pypi.org/project/uv/) is recommended
12 | 
13 | #### Linux and Mac
14 | ```bash
15 | curl -LsSf https://astral.sh/uv/install.sh | sh
16 | ```
17 | 
18 | #### Windows
19 | ```bash
20 | powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
21 | ```
22 | 
23 | ### MCP Server requirements
24 | - This example uses the MCP server sample `@modelcontextprotocol/server-brave-search` to add Brave Search tools. This requires that you have `node` and `npx` installed.
25 | 
26 | ### API Keys
27 | - The MCP server sample used is Brave Search; you can get a free API key from https://brave.com/search/api/
28 | - You will need an API key for the chosen AI provider, which defaults to Anthropic but can be changed by editing the `__main__.py` file.
29 | - Put all API keys in a `.env` file in the repository root.
30 | 
31 | ## Usage from source
32 | ```shell
33 | uv run mcp_langgraph_tools
34 | ```
35 | 
36 | ## Multiple MCP servers at one time
37 | Check the multi_server branch for a more advanced example of how to use multiple MCP servers at once.
38 | 
39 | ## What's New
40 | 
41 | - Version 0.1.0:
42 |   - Initial release
43 | 
44 | ## Contributing
45 | 
46 | Contributions are welcome!
Please feel free to submit a Pull Request. 47 | 48 | ## License 49 | 50 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 51 | 52 | ## Author 53 | 54 | Paul Robello - probello@gmail.com 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mcp_langgraph_tools" 3 | version = "0.1.0" 4 | description = "Mcp Langgraph Tools" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "asyncio>=3.4.3", 9 | "langchain-anthropic>=0.3.10", 10 | "langchain-community>=0.3.20", 11 | "langchain-core>=0.3.46", 12 | "langchain-google-genai>=2.0.10", 13 | "langchain-groq>=0.3.0", 14 | "langchain-ollama>=0.2.3", 15 | "langchain-openai>=0.3.9", 16 | "langchain>=0.3.21", 17 | "langgraph>=0.3.18", 18 | "mcp>=1.0.0", 19 | "pydantic>=2.10.6", 20 | "python-dotenv>=1.0.1", 21 | "rich>=13.9.4", 22 | ] 23 | packages = [ 24 | "src/mcp_langgraph_tools", 25 | ] 26 | 27 | [project.scripts] 28 | mcp_langgraph_tools = "mcp_langgraph_tools.__main__:main" 29 | 30 | [tool.setuptools.package-data] 31 | mcp_langgraph_tools = [ 32 | "py.typed", 33 | "*/*.png", 34 | "*/*.md", 35 | "*/*.tcss", 36 | "*.png", 37 | "*.md", 38 | "*.tcss", 39 | ] 40 | 41 | [tool.uv] 42 | dev-dependencies = [ 43 | "build>=1.2.1", 44 | "twine>=5.1.1", 45 | "pyright>=1.1.379", 46 | "pre-commit>=3.8.0", 47 | "ruff>=0.7.0", 48 | "types-orjson>=3.6.2", 49 | "pyinstrument>=5.0.0", 50 | ] 51 | 52 | [tool.hatch.version] 53 | path = "src/mcp_langgraph_tools/__init__.py" 54 | 55 | [tool.hatch.build.targets.wheel] 56 | packages = [ 57 | "src/mcp_langgraph_tools", 58 | ] 59 | include = [ 60 | "*.py", 61 | "py.typed", 62 | "*.png", 63 | "*.md", 64 | "*.tcss", 65 | "*.png", 66 | "*.md", 67 | "*.tcss", 68 | ] 69 | 70 | [tool.hatch.build.targets.sdist] 71 | include = [ 72 | "src/mcp_langgraph_tools", 73 | "LICENSE", 74 | "README.md", 75 | "pyproject.toml", 76 | ] 77 | exclude = [ 78 | "*.pyc", 79 | "__pycache__", 80 | "*.so", 81 | "*.dylib", 82 | ] 83 | 84 | [build-system] 85 | requires = [ 86 | "hatchling", 87 | "wheel", 88 | ] 89 | build-backend = "hatchling.build" 90 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".ipynb_checkpoints", 10 | ".mypy_cache", 11 | ".nox", 12 | ".pants.d", 13 | ".pyenv", 14 | ".pytest_cache", 15 | ".pytype", 16 | ".ruff_cache", 17 | ".svn", 18 | ".tox", 19 | ".venv", 20 | ".vscode", 21 | "__pypackages__", 22 | "_build", 23 | "buck-out", 24 | "build", 25 | "dist", 26 | "node_modules", 27 | "site-packages", 28 | "venv", 29 | ] 30 | 31 | # Same as Black. 32 | line-length = 120 33 | indent-width = 4 34 | 35 | # Assume Python 3.11 36 | target-version = "py311" 37 | 38 | [lint] 39 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 40 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 41 | # McCabe complexity (`C901`) by default. 42 | select = ["E4", "E5", "E7", "E9", "F", "W", "UP"] 43 | ignore = ["E501"] 44 | 45 | # Allow fix for all enabled rules (when `--fix`) is provided. 46 | fixable = ["ALL"] 47 | unfixable = [] 48 | 49 | # Allow unused variables when underscore-prefixed. 
50 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
51 | 
52 | [format]
53 | # Like Black, use double quotes for strings.
54 | quote-style = "double"
55 | 
56 | # Like Black, indent with spaces, rather than tabs.
57 | indent-style = "space"
58 | 
59 | # Like Black, respect magic trailing commas.
60 | skip-magic-trailing-comma = false
61 | 
62 | # Like Black, automatically detect the appropriate line ending.
63 | line-ending = "auto"
64 | 
65 | # Enable auto-formatting of code examples in docstrings. Markdown,
66 | # reStructuredText code/literal blocks and doctests are all supported.
67 | #
68 | # This is currently disabled by default, but it is planned for this
69 | # to be opt-out in the future.
70 | docstring-code-format = true
71 | 
72 | # Set the line length limit used when formatting code snippets in
73 | # docstrings.
74 | #
75 | # This only has an effect when the `docstring-code-format` setting is
76 | # enabled.
77 | docstring-code-line-length = "dynamic"
78 | 
79 | [lint.isort]
80 | combine-as-imports = true
81 | 
--------------------------------------------------------------------------------
/src/mcp_langgraph_tools/__main__.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | 
4 | from mcp import ClientSession, StdioServerParameters
5 | from mcp.client.stdio import stdio_client
6 | from rich.console import Console
7 | from mcp.types import InitializeResult
8 | from dotenv import load_dotenv
9 | from langchain_anthropic import ChatAnthropic
10 | from langgraph.graph import MessagesState
11 | from langchain_core.messages import HumanMessage, SystemMessage
12 | from langgraph.graph import START, StateGraph
13 | from langgraph.prebuilt import tools_condition
14 | 
15 | from .mcp_tool_node import mcp_tool_list, McpToolNode
16 | 
17 | load_dotenv()
18 | 
19 | console = Console()
20 | 
21 | # Define MCP Server Parameters
22 | server_params = StdioServerParameters(
23 |     command="npx",
24 |     args=["-y", "@modelcontextprotocol/server-brave-search"],
25 |     env={
26 |         "BRAVE_API_KEY": os.environ.get("BRAVE_API_KEY"),  # get a free key from Brave Search
27 |         "PATH": os.environ.get("PATH"),  # passing PATH helps the MCP-spawned process find executables on your path
28 |     },
29 | )
30 | 
31 | # Works with any tool-capable LLM
32 | llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")
33 | # llm = ChatOpenAI(model="gpt-4o")
34 | # llm = ChatOllama(model="llama3.2:latest")
35 | 
36 | 
37 | async def amain():
38 |     """Async main function to connect to MCP."""
39 |     async with stdio_client(server_params) as (read, write):
40 |         async with ClientSession(read, write) as session:
41 |             res: InitializeResult = await session.initialize()
42 |             try:
43 |                 llm_tools = await mcp_tool_list(session)
44 |                 console.print("MCP Tools:", llm_tools)
45 |             except Exception as _:
46 |                 llm_tools = []
47 |                 console.print("MCP Server reports no tools available.")
48 | 
49 |             llm_with_tools = llm.bind_tools(llm_tools)
50 |             sys_msg = SystemMessage(content="You are a helpful assistant. Use available tools to assist the user.")
51 | 
52 |             # Graph Node
53 |             def assistant(state: MessagesState):
54 |                 return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
55 | 
56 |             # Build graph
57 |             builder = StateGraph(MessagesState)
58 |             # Define nodes: these do the work
59 |             builder.add_node("assistant", assistant)
60 |             builder.add_node("tools", await McpToolNode(session, handle_tool_errors=True).init_funcs())
61 | 
62 |             # Define edges: these determine how the control flow moves
63 |             builder.add_edge(START, "assistant")
64 |             builder.add_conditional_edges(
65 |                 "assistant",
66 |                 # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
67 |                 # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END
68 |                 tools_condition,
69 |             )
70 |             builder.add_edge("tools", "assistant")
71 |             react_graph = builder.compile()
72 | 
73 |             messages = [HumanMessage(content="Search for Paul Robello the Principal Solution Architect")]
74 |             # Invoke the graph with initial messages
75 |             messages = await react_graph.ainvoke({"messages": messages})
76 |             for m in messages["messages"]:
77 |                 m.pretty_print()
78 | 
79 | 
80 | def main() -> None:
81 |     asyncio.run(amain())
82 | 
83 | 
84 | if __name__ == "__main__":
85 |     main()
86 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Common make values.
3 | lib := mcp_langgraph_tools
4 | run := uv run
5 | python := $(run) python
6 | pyright := $(run) pyright
7 | ruff := $(run) ruff
8 | twine := $(run) twine
9 | #build := $(python) -m build
10 | build := uvx --from build pyproject-build --installer uv
11 | 
12 | export UV_LINK_MODE=copy
13 | export PIPENV_VERBOSITY=-1
14 | ##############################################################################
15 | # Run the app.
16 | .PHONY: run
17 | run: # Run app
18 | 	$(run) $(lib)
19 | 
20 | .PHONY: app_help
21 | app_help: # Show app help
22 | 	$(run) $(lib) --help
23 | 
24 | 
25 | ##############################################################################
26 | .PHONY: uv-lock
27 | uv-lock:
28 | 	uv lock
29 | 
30 | .PHONY: uv-sync
31 | uv-sync:
32 | 	uv sync
33 | 
34 | .PHONY: setup
35 | setup: uv-lock uv-sync # use this for first time run
36 | 
37 | .PHONY: resetup
38 | resetup: remove-venv setup # Recreate the virtual environment from scratch
39 | 
40 | .PHONY: remove-venv
41 | remove-venv: # Remove the virtual environment
42 | 	rm -rf .venv
43 | 
44 | .PHONY: depsupdate
45 | depsupdate: # Update all dependencies
46 | 	uv sync -U
47 | 
48 | .PHONY: depsshow
49 | depsshow: # Show the dependency graph
50 | 	uv tree
51 | 
52 | .PHONY: shell
53 | shell: # Start shell inside of .venv
54 | 	$(run) bash
55 | ##############################################################################
56 | # Checking/testing/linting/etc.
57 | 
58 | .PHONY: format
59 | format: # Reformat the code with ruff.
60 | $(ruff) format src/$(lib) 61 | 62 | .PHONY: lint 63 | lint: # Run ruff over the library 64 | $(ruff) check src/$(lib) --fix 65 | 66 | .PHONY: typecheck 67 | typecheck: # Perform static type checks with pyright 68 | $(pyright) 69 | 70 | .PHONY: typecheck-stats 71 | typecheck-stats: # Perform static type checks with pyright and print stats 72 | $(pyright) --stats 73 | 74 | .PHONY: checkall 75 | checkall: typecheck lint # Check all the things 76 | 77 | .PHONY: pre-commit # run pre-commit checks on all files 78 | pre-commit: 79 | pre-commit run --all-files 80 | 81 | .PHONY: pre-commit-update # run pre-commit and update hooks 82 | pre-commit-update: 83 | pre-commit autoupdate 84 | 85 | ############################################################################## 86 | # Package/publish. 87 | .PHONY: package 88 | package: # Package the library 89 | $(build) -w 90 | 91 | .PHONY: spackage 92 | spackage: # Create a source package for the library 93 | $(build) -s 94 | 95 | .PHONY: packagecheck 96 | packagecheck: clean package spackage # Check the packaging. 97 | $(twine) check dist/* 98 | 99 | .PHONY: testdist 100 | testdist: packagecheck # Perform a test distribution 101 | $(twine) upload --repository testpypi dist/* 102 | #$(twine) upload --skip-existing --repository testpypi dist/* 103 | 104 | .PHONY: dist 105 | dist: packagecheck # Upload to pypi 106 | $(twine) upload --skip-existing dist/* 107 | 108 | ############################################################################## 109 | # Utility. 110 | 111 | 112 | .PHONY: repl 113 | repl: # Start a Python REPL 114 | $(python) 115 | 116 | .PHONY: clean 117 | clean: # Clean the build directories 118 | rm -rf build dist $(lib).egg-info 119 | 120 | .PHONY: help 121 | help: # Display this help 122 | @grep -Eh "^[a-z]+:.+# " $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.+# "}; {printf "%-20s %s\n", $$1, $$2}' 123 | 124 | ############################################################################## 125 | # Housekeeping tasks. 
126 | .PHONY: housekeeping 127 | housekeeping: # Perform some git housekeeping 128 | git fsck 129 | git gc --aggressive 130 | git remote update --prune 131 | -------------------------------------------------------------------------------- /src/mcp_langgraph_tools/mcp_tool_node.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from typing import ( 5 | Literal, 6 | cast, 7 | ) 8 | from collections.abc import Callable 9 | from rich.console import Console 10 | from langgraph.prebuilt.tool_node import ( 11 | msg_content_output, 12 | INVALID_TOOL_NAME_ERROR_TEMPLATE, 13 | _handle_tool_error, 14 | _infer_handled_types, 15 | ) 16 | from langchain_core.messages import ( 17 | AIMessage, 18 | AnyMessage, 19 | ToolCall, 20 | ToolMessage, 21 | ) 22 | from langchain_core.runnables import RunnableConfig 23 | from langchain_core.runnables.utils import Input 24 | 25 | from langgraph.errors import GraphInterrupt 26 | from langgraph.store.base import BaseStore 27 | from langgraph.utils.runnable import RunnableCallable 28 | from mcp import ClientSession 29 | 30 | from pydantic import BaseModel 31 | from typing import Any 32 | 33 | console = Console() 34 | 35 | 36 | def mcp_tool_node_basic(session: ClientSession): 37 | """Basic tool node that makes calls to MCP tools.""" 38 | 39 | async def my_tool_node(state: dict): 40 | result = [] 41 | # console.print("state:", state) 42 | for tool_call in state["messages"][-1].tool_calls: 43 | # console.print("Tool calls:", tool_call) 44 | res = await session.call_tool(tool_call["name"], arguments=tool_call["args"]) 45 | tool_message: ToolMessage = ToolMessage( 46 | name=tool_call["name"], 47 | tool_call_id=tool_call["id"], 48 | content=res.content, 49 | status="error" if res.isError else "success", 50 | ) 51 | result.append(tool_message) 52 | return {"messages": result} 53 | 54 | return my_tool_node 55 | 56 | 57 | async def mcp_tool_list(session: ClientSession) -> list[dict[str, Any]]: 58 | """Gets list of tools from MCP and converts to OpenAI standard schema.""" 59 | try: 60 | mcp_tools = (await session.list_tools()).tools 61 | except Exception as _: 62 | mcp_tools = [] 63 | # map mcp tools to openai spec dict 64 | llm_tools = [ 65 | { 66 | "name": tool.name, 67 | "description": tool.description, 68 | "parameters": tool.inputSchema, 69 | } 70 | for tool in mcp_tools 71 | if isinstance(tool, BaseModel) 72 | ] 73 | return llm_tools 74 | 75 | 76 | class McpToolNode(RunnableCallable): 77 | """A node that runs the tools called in the last AIMessage. 78 | 79 | It can be used either in StateGraph with a "messages" state key (or a custom key passed via ToolNode's 'messages_key'). 80 | If multiple tool calls are requested, they will be run in parallel. The output will be 81 | a list of ToolMessages, one for each tool call. 82 | 83 | 84 | Args: 85 | mcp_session: An initialized MCP ClientSession. 86 | whitelisted_tools: A list of tool names that can be run. Defaults to None = Allow all 87 | blacklisted_tools: A list of tool names that should not be run. Defaults to None = Allow all 88 | name: The name of the ToolNode in the graph. Defaults to "tools". 89 | tags: Optional tags to associate with the node. Defaults to None. 90 | handle_tool_errors: How to handle tool errors raised by tools inside the node. Defaults to True. 
91 |             Must be one of the following:
92 | 
93 |             - True: all errors will be caught and
94 |               a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
95 |             - str: all errors will be caught and
96 |               a ToolMessage with the string value of 'handle_tool_errors' will be returned.
97 |             - tuple[type[Exception], ...]: exceptions in the tuple will be caught and
98 |               a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
99 |             - Callable[..., str]: exceptions from the signature of the callable will be caught and
100 |               a ToolMessage with the string value of the result of the 'handle_tool_errors' callable will be returned.
101 |             - False: none of the errors raised by the tools will be caught
102 |         messages_key: The state key in the input that contains the list of messages.
103 |             The same key will be used for the output from the ToolNode.
104 |             Defaults to "messages".
105 | 
106 |     Important:
107 |         - This node must be used in an async graph (graph.ainvoke()).
108 |         - init_funcs() must be awaited before the first invocation to populate the tools_by_name dictionary.
109 |         - The state MUST contain a list of messages.
110 |         - The last message MUST be an `AIMessage`.
111 |         - The `AIMessage` MUST have `tool_calls` populated.
112 |     """
113 | 
114 |     name: str = "ToolNode"
115 | 
116 |     def __init__(
117 |         self,
118 |         mcp_session: ClientSession,
119 |         *,
120 |         whitelisted_tools: list[str] | None = None,
121 |         blacklisted_tools: list[str] | None = None,
122 |         name: str = "tools",
123 |         tags: list[str] | None = None,
124 |         handle_tool_errors: bool | str | Callable[..., str] | tuple[type[Exception], ...] = True,
125 |         messages_key: str = "messages",
126 |     ) -> None:
127 |         super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
128 |         self.tools_by_name: dict[str, dict] = {}
129 |         self.handle_tool_errors = handle_tool_errors
130 |         self.messages_key = messages_key
131 |         self.mcp_session = mcp_session
132 |         self.whitelisted_tools = whitelisted_tools
133 |         self.blacklisted_tools = blacklisted_tools
134 | 
135 |     async def init_funcs(self) -> McpToolNode:
136 |         """Must be called before the first invocation to populate the tools_by_name dictionary."""
137 |         llm_tools = await mcp_tool_list(self.mcp_session)
138 | 
139 |         for tool in llm_tools:
140 |             if self.whitelisted_tools is not None and tool["name"] not in self.whitelisted_tools:
141 |                 continue
142 |             if self.blacklisted_tools is not None and tool["name"] in self.blacklisted_tools:
143 |                 continue
144 |             self.tools_by_name[tool["name"]] = tool
145 |         return self
146 | 
147 |     def _func(
148 |         self,
149 |         input: list[AnyMessage] | dict[str, Any] | BaseModel,
150 |         config: RunnableConfig,
151 |         *,
152 |         store: BaseStore,
153 |     ) -> Any:
154 |         raise NotImplementedError("You must use _afunc")
155 | 
156 |     def invoke(self, input: Input, config: RunnableConfig | None = None, **kwargs: Any) -> Any:
157 |         raise NotImplementedError("You must use ainvoke")
158 | 
159 |     async def ainvoke(self, input: Input, config: RunnableConfig | None = None, **kwargs: Any) -> Any:
160 |         if "store" not in kwargs:
161 |             kwargs["store"] = None
162 |         return await super().ainvoke(input, config, **kwargs)
163 | 
164 |     async def _afunc(
165 |         self,
166 |         input: list[AnyMessage] | dict[str, Any] | BaseModel,
167 |         config: RunnableConfig,
168 |         *,
169 |         store: BaseStore,
170 |     ) -> Any:
171 |         tool_calls, output_type = self._parse_input(input, store)
172 |         outputs = await asyncio.gather(*(self._arun_one(call, config) for call in tool_calls))
173 |         # TypedDict,
pydantic, dataclass, etc. should all be able to load from dict 174 | return outputs if output_type == "list" else {self.messages_key: outputs} 175 | 176 | def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: 177 | raise NotImplementedError("You must use _arun_one") 178 | 179 | async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: 180 | if invalid_tool_message := self._validate_tool_call(call): 181 | return invalid_tool_message 182 | 183 | try: 184 | # console.print(call["args"]) 185 | res = await self.mcp_session.call_tool(call["name"], arguments=call["args"]) 186 | if res.isError: 187 | raise Exception(res.content) 188 | tool_message: ToolMessage = ToolMessage(name=call["name"], tool_call_id=call["id"], content=res.content) 189 | 190 | tool_message.content = cast(str | list, msg_content_output(tool_message.content)) 191 | return tool_message 192 | # GraphInterrupt is a special exception that will always be raised. 193 | # It can be triggered in the following scenarios: 194 | # (1) a NodeInterrupt is raised inside a tool 195 | # (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool 196 | # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool 197 | # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) 198 | except GraphInterrupt as e: 199 | raise e 200 | except Exception as e: 201 | if isinstance(self.handle_tool_errors, tuple): 202 | handled_types: tuple = self.handle_tool_errors 203 | elif callable(self.handle_tool_errors): 204 | handled_types = _infer_handled_types(self.handle_tool_errors) 205 | else: 206 | # default behavior is catching all exceptions 207 | handled_types = (Exception,) 208 | 209 | # Unhandled 210 | if not self.handle_tool_errors or not isinstance(e, handled_types): 211 | raise e 212 | # Handled 213 | else: 214 | content = _handle_tool_error(e, flag=self.handle_tool_errors) 215 | 216 | return ToolMessage(content=content, name=call["name"], tool_call_id=call["id"], status="error") 217 | 218 | def _parse_input( 219 | self, 220 | input: list[AnyMessage] | dict[str, Any] | BaseModel, 221 | store: BaseStore, 222 | ) -> tuple[list[ToolCall], Literal["list", "dict"]]: 223 | if isinstance(input, list): 224 | output_type = "list" 225 | message: AnyMessage = input[-1] 226 | elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])): 227 | output_type = "dict" 228 | message = messages[-1] 229 | elif messages := getattr(input, self.messages_key, None): 230 | # Assume dataclass-like state that can coerce from dict 231 | output_type = "dict" 232 | message = messages[-1] 233 | else: 234 | raise ValueError("No message found in input") 235 | 236 | if not isinstance(message, AIMessage): 237 | raise ValueError("Last message is not an AIMessage") 238 | return message.tool_calls, output_type 239 | 240 | def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None: 241 | if (requested_tool := call["name"]) not in self.tools_by_name: 242 | content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( 243 | requested_tool=requested_tool, 244 | available_tools=", ".join(self.tools_by_name.keys()), 245 | ) 246 | return ToolMessage(content, name=requested_tool, tool_call_id=call["id"], status="error") 247 | else: 248 | return None 249 | --------------------------------------------------------------------------------
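
The `McpToolNode` docstring above describes whitelist/blacklist filtering and several `handle_tool_errors` modes, but `__main__.py` only exercises the defaults. Below is a minimal, hypothetical sketch of the other options; it assumes `session` is an already-initialized `mcp.ClientSession` and that the MCP server exposes a tool named `brave_web_search` (both assumptions, not part of the repository code):

```python
# Hypothetical configuration sketch for McpToolNode -- not part of the repository.
from mcp import ClientSession

from mcp_langgraph_tools.mcp_tool_node import McpToolNode


def describe_error(error: Exception) -> str:
    """Returned string becomes the ToolMessage content when a tool call fails."""
    return f"Tool call failed: {error}"


async def build_tool_node(session: ClientSession) -> McpToolNode:
    node = McpToolNode(
        session,
        whitelisted_tools=["brave_web_search"],  # only allow this tool (name is an assumption)
        handle_tool_errors=describe_error,  # callable mode described in the class docstring
    )
    # init_funcs() must be awaited so tools_by_name is populated before the graph runs.
    return await node.init_funcs()
```

Passing a plain string, a tuple of exception classes, or `False` instead of a callable selects the other `handle_tool_errors` behaviors listed in the docstring.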