├── tests └── __init__.py ├── copilotkit ├── demos │ ├── __init__.py │ ├── qa │ │ ├── __init__.py │ │ ├── demo.py │ │ └── agent.py │ ├── starter │ │ ├── __init__.py │ │ ├── demo.py │ │ └── agent.py │ ├── ai_researcher │ │ ├── __init__.py │ │ ├── demo.py │ │ ├── state.py │ │ ├── agent.py │ │ ├── extract.py │ │ ├── search.py │ │ ├── summarize.py │ │ └── steps.py │ ├── autotale_ai │ │ ├── __init__.py │ │ ├── story │ │ │ ├── __init__.py │ │ │ ├── style.py │ │ │ ├── outline.py │ │ │ ├── characters.py │ │ │ └── story.py │ │ ├── demo.py │ │ ├── state.py │ │ ├── agent.py │ │ └── chatbot.py │ ├── multi_agent │ │ ├── __init__.py │ │ ├── pirate_agent.py │ │ ├── demo.py │ │ ├── joke_agent.py │ │ └── email_agent.py │ ├── research_canvas │ │ ├── __init__.py │ │ ├── state.py │ │ ├── delete.py │ │ ├── model.py │ │ ├── demo.py │ │ ├── agent.py │ │ ├── download.py │ │ ├── search.py │ │ └── chat.py │ ├── wait_user_input │ │ ├── __init__.py │ │ ├── demo.py │ │ └── agent.py │ └── multi_agent_anthropic │ │ ├── __init__.py │ │ ├── pirate_agent.py │ │ ├── demo.py │ │ ├── joke_agent.py │ │ └── email_agent.py ├── integrations │ ├── __init__.py │ └── fastapi.py ├── state.py ├── __init__.py ├── logging.py ├── types.py ├── exc.py ├── agent.py ├── action.py ├── parameter.py ├── sdk.py ├── langchain.py ├── langgraph_cloud_agent.py └── langgraph_agent.py ├── README.md ├── .vscode ├── settings.json └── cspell.json ├── langgraph.json ├── pyproject.toml └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/qa/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /copilotkit/demos/starter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CopilotKit python SDK (alpha) 2 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/story/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /copilotkit/demos/wait_user_input/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/copilotkit/demos/multi_agent_anthropic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.exclude": { 3 | "**/__pycache__": true 4 | }, 5 | "python.analysis.typeCheckingMode": "basic" 6 | } 7 | -------------------------------------------------------------------------------- /langgraph.json: -------------------------------------------------------------------------------- 1 | { 2 | "python_version": "3.11", 3 | "dockerfile_lines": [], 4 | "dependencies": ["."], 5 | "graphs": { 6 | "autotale_ai": "./copilotkit/demos/autotale_ai/agent.py:graph" 7 | }, 8 | "env": ".env" 9 | } 10 | -------------------------------------------------------------------------------- /copilotkit/state.py: -------------------------------------------------------------------------------- 1 | """CopilotKit state""" 2 | 3 | from typing import List, Any, TypedDict 4 | from langgraph.graph import MessagesState 5 | 6 | class CopilotKitProperties(TypedDict): 7 | """CopilotKit-specific properties stored under the `copilotkit` state key.""" 8 | actions: List[Any] 9 | 10 | class CopilotKitState(MessagesState): 11 | """LangGraph MessagesState extended with a `copilotkit` properties entry.""" 12 | copilotkit: CopilotKitProperties 13 | -------------------------------------------------------------------------------- /.vscode/cspell.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2", 3 | "language": "en", 4 | "words": [ 5 | "coagent", 6 | "langgraph", 7 | "copilotkit", 8 | "checkpointer", 9 | "fastapi", 10 | "uvicorn", 11 | "serializeable", 12 | "astream", 13 | "langchain", 14 | "openai", 15 | "ainvoke", 16 | "readables", 17 | "autotale", 18 | "pydantic", 19 | "childrens", 20 | "partialjson", 21 | "dotenv" 22 | ] 23 | } 24 | 
--------------------------------------------------------------------------------
/copilotkit/__init__.py: -------------------------------------------------------------------------------- 1 | """CopilotKit SDK""" 2 | from .sdk import CopilotKitSDK 3 | from .action import Action 4 | from .langgraph_agent import LangGraphAgent 5 | # from .langgraph_cloud_agent import LangGraphCloudAgent 6 | from .state import CopilotKitState 7 | from .parameter import Parameter 8 | __all__ = [ 9 | 'CopilotKitSDK', 10 | 'Action', 11 | 'LangGraphAgent', 12 | # 'LangGraphCloudAgent', 13 | 'CopilotKitState', 14 | 'Parameter' 15 | ] 16 | -------------------------------------------------------------------------------- /copilotkit/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging setup for CopilotKit. 3 | """ 4 | 5 | import logging 6 | import os 7 | import sys 8 | 9 | def get_logger(name: str): 10 | """ 11 | Get a logger with the given name. If the LOG_LEVEL environment variable is set, its upper-cased value is applied as the logger's level. 12 | """ 13 | logger = logging.getLogger(name) 14 | log_level = os.getenv('LOG_LEVEL') 15 | if log_level: 16 | logger.setLevel(log_level.upper()) 17 | return logger 18 | 19 | def bold(text: str) -> str: 20 | """ 21 | Wrap the given text in ANSI bold escape codes when stdout is a TTY; otherwise return it unchanged. 22 | """ 23 | if sys.stdout.isatty(): 24 | return f"\033[1m{text}\033[0m" 25 | return text 26 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/story/style.py: -------------------------------------------------------------------------------- 1 | """ 2 | Style node. 3 | """ 4 | 5 | import json 6 | from langchain_core.tools import tool 7 | from copilotkit.demos.autotale_ai.state import AgentState 8 | 9 | @tool 10 | def set_style(style: str): 11 | """Sets the graphical style of the story.""" 12 | return style 13 | 14 | 15 | async def style_node(state: AgentState): 16 | """ 17 | Store the `style` field parsed from the JSON content of the last message in the state.
18 | """ 19 | last_message = state["messages"][-1] 20 | return { 21 | "style": json.loads(last_message.content)["style"] 22 | } 23 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/story/outline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Outline node. 3 | """ 4 | 5 | import json 6 | from langchain_core.tools import tool 7 | from copilotkit.demos.autotale_ai.state import AgentState 8 | 9 | @tool 10 | def set_outline(outline: str): 11 | """Sets the outline of the story.""" 12 | return outline 13 | 14 | 15 | async def outline_node(state: AgentState): 16 | """ 17 | Store the `outline` field parsed from the JSON content of the last message in the state. 18 | """ 19 | last_message = state["messages"][-1] 20 | return { 21 | "outline": json.loads(last_message.content)["outline"] 22 | } 23 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/demo.py: -------------------------------------------------------------------------------- 1 | """Demo""" 2 | 3 | from fastapi import FastAPI 4 | import uvicorn 5 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 6 | from copilotkit import CopilotKitSDK, LangGraphAgent 7 | from copilotkit.demos.ai_researcher.agent import graph 8 | 9 | app = FastAPI() 10 | sdk = CopilotKitSDK( 11 | agents=[ 12 | LangGraphAgent( 13 | name="search_agent", 14 | description="Search agent.", 15 | agent=graph, 16 | ) 17 | ], 18 | ) 19 | 20 | add_fastapi_endpoint(app, sdk, "/copilotkit") 21 | 22 | def main(): 23 | """Run the uvicorn server.""" 24 | uvicorn.run("copilotkit.demos.ai_researcher.demo:app", host="127.0.0.1", port=8000, reload=True) 25 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/demo.py: -------------------------------------------------------------------------------- 1 | """Demo""" 2 | 3 | from fastapi
import FastAPI 4 | import uvicorn 5 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 6 | from copilotkit import CopilotKitSDK, LangGraphAgent 7 | from .agent import graph 8 | 9 | app = FastAPI() 10 | sdk = CopilotKitSDK( 11 | agents=[ 12 | LangGraphAgent( 13 | name="childrensBookAgent", 14 | description="Write a children's book.", 15 | agent=graph, 16 | ) 17 | ], 18 | ) 19 | 20 | add_fastapi_endpoint(app, sdk, "/copilotkit") 21 | 22 | def main(): 23 | """Run the uvicorn server.""" 24 | uvicorn.run( 25 | "copilotkit.demos.autotale_ai.demo:app", 26 | host="127.0.0.1", 27 | port=8000, 28 | reload=True 29 | ) 30 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/state.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the state definition for the AI. 3 | It defines the state of the agent and the state of the conversation. 4 | """ 5 | 6 | from typing import List, TypedDict, Optional 7 | from langgraph.graph import MessagesState 8 | 9 | class Step(TypedDict): 10 | """ 11 | Represents a step taken in the research process. 12 | """ 13 | id: str 14 | description: str 15 | status: str 16 | type: str 17 | description: str 18 | search_result: Optional[str] 19 | result: Optional[str] 20 | updates: Optional[str] 21 | 22 | class AgentState(MessagesState): 23 | """ 24 | This is the state of the agent. 25 | It is a subclass of the MessagesState class from langgraph. 
26 | """ 27 | steps: List[Step] 28 | answer: Optional[str] 29 | -------------------------------------------------------------------------------- /copilotkit/demos/qa/demo.py: -------------------------------------------------------------------------------- 1 | """Demo""" 2 | 3 | from dotenv import load_dotenv 4 | load_dotenv() # pylint: disable=wrong-import-position 5 | 6 | from fastapi import FastAPI 7 | import uvicorn 8 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 9 | from copilotkit import CopilotKitSDK, LangGraphAgent 10 | from copilotkit.demos.qa.agent import graph 11 | 12 | 13 | app = FastAPI() 14 | sdk = CopilotKitSDK( 15 | agents=[ 16 | LangGraphAgent( 17 | name="email_agent", 18 | description="This agent sends emails", 19 | agent=graph, 20 | ) 21 | ], 22 | ) 23 | 24 | add_fastapi_endpoint(app, sdk, "/copilotkit") 25 | 26 | def main(): 27 | """Run the uvicorn server.""" 28 | uvicorn.run("copilotkit.demos.ai_researcher.demo:app", host="127.0.0.1", port=8000, reload=True) 29 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/state.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the state definition for the autotale AI. 3 | It defines the state of the agent and the state of the conversation. 4 | """ 5 | 6 | from typing import List, TypedDict 7 | from langgraph.graph import MessagesState 8 | 9 | 10 | 11 | class Character(TypedDict): 12 | """ 13 | Represents a character in the tale. 14 | """ 15 | name: str 16 | appearance: str 17 | traits: List[str] 18 | 19 | 20 | class Page(TypedDict): 21 | """ 22 | Represents a page in the children's story with an image url. 23 | """ 24 | content: str 25 | 26 | class AgentState(MessagesState): 27 | """ 28 | This is the state of the agent. 29 | It is a subclass of the MessagesState class from langgraph. 
30 | """ 31 | outline: str 32 | characters: List[Character] 33 | story: List[Page] 34 | -------------------------------------------------------------------------------- /copilotkit/demos/wait_user_input/demo.py: -------------------------------------------------------------------------------- 1 | """Demo""" 2 | 3 | from dotenv import load_dotenv 4 | load_dotenv() # pylint: disable=wrong-import-position 5 | 6 | from fastapi import FastAPI 7 | import uvicorn 8 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 9 | from copilotkit import CopilotKitSDK, LangGraphAgent 10 | from copilotkit.demos.wait_user_input.agent import graph 11 | 12 | 13 | app = FastAPI() 14 | sdk = CopilotKitSDK( 15 | agents=[ 16 | LangGraphAgent( 17 | name="weather_agent", 18 | description="This agent deals with everything weather related", 19 | agent=graph, 20 | ) 21 | ], 22 | ) 23 | 24 | add_fastapi_endpoint(app, sdk, "/copilotkit") 25 | 26 | def main(): 27 | """Run the uvicorn server.""" 28 | uvicorn.run("copilotkit.demos.wait_user_input.demo:app", host="127.0.0.1", port=8000, reload=True) 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "copilotkit" 3 | version = "0.1.27" 4 | description = "CopilotKit python SDK" 5 | authors = ["Markus Ecker "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://copilotkit.ai" 9 | repository = "https://github.com/CopilotKit/sdk-python" 10 | keywords = ["copilot", "copilotkit", "langgraph", "langchain", "ai", "langsmith", "langserve"] 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.9,<4.0" 14 | langgraph = "^0.2.35" 15 | httpx = "^0.27.2" 16 | fastapi = "^0.115.0" 17 | langchain = "^0.3.3" 18 | langchain-openai = "^0.2.2" 19 | langchain-anthropic = "^0.2.3" 20 | partialjson = "^0.0.8" 21 | langgraph-sdk = "^0.1.32" 22 | toml = "^0.10.2" 23 | 24 | 
24 | [build-system] 25 | requires = ["poetry-core"] 26 | build-backend = "poetry.core.masonry.api" 27 | 28 | [tool.poetry.scripts] 29 | demo = "copilotkit.demos.research_canvas.demo:main" 30 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/story/characters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Characters node. 3 | """ 4 | 5 | from typing import List 6 | import json 7 | from langchain_core.tools import tool 8 | 9 | from copilotkit.demos.autotale_ai.state import AgentState, Character 10 | 11 | 12 | 13 | @tool 14 | def set_characters(characters: List[Character]): 15 | """ 16 | Extract the book's main characters from the conversation. 17 | The traits should be short: 3-4 adjectives. 18 | The appearance should be as detailed as possible. What they look like, their clothes, etc. 19 | """ 20 | return characters 21 | 22 | 23 | def characters_node(state: AgentState): 24 | """ 25 | Store the `characters` field parsed from the JSON content of the last message in the state. 26 | """ 27 | last_message = state["messages"][-1] 28 | return { 29 | "characters": json.loads(last_message.content)["characters"] 30 | } 31 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/state.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the state definition for the AI. 3 | It defines the state of the agent and the state of the conversation. 4 | """ 5 | 6 | from typing import List, TypedDict 7 | from langgraph.graph import MessagesState 8 | 9 | class Resource(TypedDict): 10 | """ 11 | Represents a resource. Give it a good title and a short description. 12 | """ 13 | url: str 14 | title: str 15 | description: str 16 | 17 | class Log(TypedDict): 18 | """ 19 | Represents a log of an action performed by the agent.
20 | """ 21 | message: str 22 | done: bool 23 | 24 | class AgentState(MessagesState): 25 | """ 26 | This is the state of the agent. 27 | It is a subclass of the MessagesState class from langgraph. 28 | """ 29 | model: str 30 | research_question: str 31 | report: str 32 | resources: List[Resource] 33 | logs: List[Log] 34 | -------------------------------------------------------------------------------- /copilotkit/types.py: -------------------------------------------------------------------------------- 1 | """Message and configuration types for CopilotKit""" 2 | 3 | from typing import TypedDict 4 | from enum import Enum 5 | from typing_extensions import NotRequired 6 | 7 | class MessageRole(Enum): 8 | """Message role""" 9 | ASSISTANT = "assistant" 10 | SYSTEM = "system" 11 | USER = "user" 12 | 13 | class Message(TypedDict): 14 | """Message""" 15 | id: str 16 | createdAt: str 17 | 18 | class TextMessage(Message): 19 | """Text message""" 20 | role: MessageRole 21 | content: str 22 | 23 | class ActionExecutionMessage(Message): 24 | """Action execution message""" 25 | name: str 26 | arguments: dict 27 | scope: str 28 | 29 | class ResultMessage(Message): 30 | """Result message""" 31 | actionExecutionId: str 32 | actionName: str 33 | result: str 34 | 35 | class IntermediateStateConfig(TypedDict): 36 | """Intermediate state config""" 37 | state_key: str 38 | tool: str 39 | tool_argument: NotRequired[str] 40 | -------------------------------------------------------------------------------- /copilotkit/demos/starter/demo.py: -------------------------------------------------------------------------------- 1 | """Demo""" 2 | 3 | from dotenv import load_dotenv 4 | load_dotenv() # pylint: disable=wrong-import-position 5 | 6 | from fastapi import FastAPI 7 | import uvicorn 8 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 9 | from copilotkit import CopilotKitSDK, LangGraphAgent 10 | from copilotkit.demos.starter.agent import graph 11 | from copilotkit.langchain import
copilotkit_customize_config 12 | 13 | app = FastAPI() 14 | sdk = CopilotKitSDK( 15 | agents=[ 16 | LangGraphAgent( 17 | name="translate_agent", 18 | description="Translate agent that translates text.", 19 | agent=graph, 20 | config=copilotkit_customize_config( 21 | base_config={ 22 | "recursion_limit": 10, 23 | }, 24 | emit_messages=True, 25 | ), 26 | ) 27 | ], 28 | ) 29 | 30 | add_fastapi_endpoint(app, sdk, "/copilotkit") 31 | 32 | def main(): 33 | """Run the uvicorn server.""" 34 | uvicorn.run("copilotkit.demos.starter.demo:app", host="127.0.0.1", port=8000, reload=True) 35 | -------------------------------------------------------------------------------- /copilotkit/exc.py: -------------------------------------------------------------------------------- 1 | """Exceptions for CopilotKit.""" 2 | 3 | class ActionNotFoundException(Exception): 4 | """Exception raised when an action is not found.""" 5 | 6 | def __init__(self, name: str): 7 | self.name = name 8 | super().__init__(f"Action '{name}' not found.") 9 | 10 | class AgentNotFoundException(Exception): 11 | """Exception raised when an agent is not found.""" 12 | 13 | def __init__(self, name: str): 14 | self.name = name 15 | super().__init__(f"Agent '{name}' not found.") 16 | 17 | class ActionExecutionException(Exception): 18 | """Exception raised when an action fails to execute.""" 19 | 20 | def __init__(self, name: str, error: Exception): 21 | self.name = name 22 | self.error = error 23 | super().__init__(f"Action '{name}' failed to execute: {error}") 24 | 25 | class AgentExecutionException(Exception): 26 | """Exception raised when an agent fails to execute.""" 27 | 28 | def __init__(self, name: str, error: Exception): 29 | self.name = name 30 | self.error = error 31 | super().__init__(f"Agent '{name}' failed to execute: {error}") 32 | -------------------------------------------------------------------------------- /copilotkit/agent.py:
-------------------------------------------------------------------------------- 1 | """Agents""" 2 | 3 | from typing import Optional, List, TypedDict 4 | from abc import ABC, abstractmethod 5 | from .types import Message 6 | from .action import ActionDict 7 | 8 | class AgentDict(TypedDict): 9 | """Agent dictionary""" 10 | name: str 11 | description: Optional[str] 12 | 13 | class Agent(ABC): 14 | """Agent class for CopilotKit""" 15 | def __init__( 16 | self, 17 | *, 18 | name: str, 19 | description: Optional[str] = None, 20 | ): 21 | self.name = name 22 | self.description = description 23 | 24 | @abstractmethod 25 | def execute( # pylint: disable=too-many-arguments 26 | self, 27 | *, 28 | state: dict, 29 | messages: List[Message], 30 | thread_id: Optional[str] = None, 31 | node_name: Optional[str] = None, 32 | actions: Optional[List[ActionDict]] = None, 33 | ): 34 | """Execute the agent""" 35 | 36 | def dict_repr(self) -> AgentDict: 37 | """Dict representation of the agent (description defaults to an empty string)""" 38 | return { 39 | 'name': self.name, 40 | 'description': self.description or '' 41 | } 42 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/delete.py: -------------------------------------------------------------------------------- 1 | """Delete Resources""" 2 | 3 | import json 4 | from typing import cast 5 | from langchain_core.runnables import RunnableConfig 6 | from langchain_core.messages import ToolMessage, AIMessage 7 | from copilotkit.demos.research_canvas.state import AgentState 8 | 9 | async def delete_node(state: AgentState, config: RunnableConfig): # pylint: disable=unused-argument 10 | """ 11 | Pass-through node that returns the state unchanged. 12 | """ 13 | return state 14 | 15 | async def perform_delete_node(state: AgentState, config: RunnableConfig): # pylint: disable=unused-argument 16 | """ 17 | Remove the resources whose URLs appear in the preceding AI message's tool call, if the tool message content is YES. 18 | """ 19 | ai_message = cast(AIMessage, state["messages"][-2]) 20 | tool_message = cast(ToolMessage, state["messages"][-1]) 21 | if
tool_message.content == "YES": 22 | if ai_message.tool_calls: 23 | urls = ai_message.tool_calls[0]["args"]["urls"] 24 | else: 25 | parsed_tool_call = json.loads(ai_message.additional_kwargs["function_call"]["arguments"]) 26 | urls = parsed_tool_call["urls"] 27 | state["resources"] = [ 28 | resource for resource in state["resources"] if resource["url"] not in urls 29 | ] 30 | 31 | return state 32 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides a function to get a model based on the configuration. 3 | """ 4 | import os 5 | from typing import cast, Any 6 | from langchain_core.language_models.chat_models import BaseChatModel 7 | from copilotkit.demos.research_canvas.state import AgentState 8 | 9 | def get_model(state: AgentState) -> BaseChatModel: 10 | """ 11 | Get a chat model selected by the MODEL environment variable or, if it is unset, by the state's `model` field.
12 | """ 13 | 14 | state_model = state.get("model") 15 | model = os.getenv("MODEL", state_model) # the MODEL env var, when set, takes precedence over the state 16 | 17 | print(f"Using model: {model}") 18 | 19 | if model == "openai": 20 | from langchain_openai import ChatOpenAI 21 | return ChatOpenAI(temperature=0, model="gpt-4o-mini") 22 | if model == "anthropic": 23 | from langchain_anthropic import ChatAnthropic 24 | return ChatAnthropic( 25 | temperature=0, 26 | model_name="claude-3-5-sonnet-20240620", 27 | timeout=None, 28 | stop=None 29 | ) 30 | if model == "google_genai": 31 | from langchain_google_genai import ChatGoogleGenerativeAI 32 | return ChatGoogleGenerativeAI( 33 | temperature=0, 34 | model="gemini-1.5-pro", 35 | api_key=cast(Any, os.getenv("GOOGLE_API_KEY")) or None 36 | ) 37 | 38 | raise ValueError("Invalid model specified") 39 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent_anthropic/pirate_agent.py: -------------------------------------------------------------------------------- 1 | """Test Pirate Agent""" 2 | 3 | from typing import Any, cast 4 | from langgraph.graph import StateGraph, END 5 | from langgraph.graph import MessagesState 6 | from langgraph.checkpoint.memory import MemorySaver 7 | from langchain_core.runnables import RunnableConfig 8 | from copilotkit.langchain import copilotkit_emit_message 9 | 10 | class PirateAgentState(MessagesState): 11 | """Pirate Agent State""" 12 | 13 | async def pirate_node(state: PirateAgentState, config: RunnableConfig): # pylint: disable=unused-argument 14 | """ 15 | Emit a fixed pirate greeting and return the existing messages unchanged. 16 | """ 17 | 18 | await copilotkit_emit_message(config, "Arr!!!") 19 | 20 | # system_message = "You speak like a pirate.
Your name is Captain Copilot" 21 | 22 | # pirate_model = ChatOpenAI(model="gpt-4o") 23 | 24 | # response = await pirate_model.ainvoke([ 25 | # *state["messages"], 26 | # SystemMessage( 27 | # content=system_message 28 | # ) 29 | # ], config) 30 | 31 | 32 | return { 33 | "messages": state["messages"], 34 | } 35 | 36 | workflow = StateGraph(PirateAgentState) 37 | workflow.add_node("pirate_node", cast(Any, pirate_node)) 38 | workflow.set_entry_point("pirate_node") 39 | 40 | workflow.add_edge("pirate_node", END) 41 | memory = MemorySaver() 42 | pirate_graph = workflow.compile(checkpointer=memory) 43 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/demo.py: -------------------------------------------------------------------------------- 1 | """Demo""" 2 | 3 | import os 4 | from dotenv import load_dotenv 5 | load_dotenv() 6 | 7 | # pylint: disable=wrong-import-position 8 | from fastapi import FastAPI 9 | import uvicorn 10 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 11 | from copilotkit import CopilotKitSDK, LangGraphAgent 12 | from copilotkit.demos.research_canvas.agent import graph 13 | from copilotkit.langchain import copilotkit_messages_to_langchain 14 | 15 | app = FastAPI() 16 | sdk = CopilotKitSDK( 17 | agents=[ 18 | LangGraphAgent( 19 | name="research_agent", 20 | description="Research agent.", 21 | agent=graph, 22 | ), 23 | LangGraphAgent( 24 | name="research_agent_google_genai", 25 | description="Research agent.", 26 | agent=graph, 27 | copilotkit_config={ 28 | "convert_messages": copilotkit_messages_to_langchain(use_function_call=True) 29 | } 30 | ) 31 | ], 32 | ) 33 | 34 | add_fastapi_endpoint(app, sdk, "/copilotkit") 35 | 36 | # health-check route; always returns a static OK payload 37 | @app.get("/health") 38 | def health(): 39 | """Health check.""" 40 | return {"status": "ok"} 41 | 42 | 43 | def main(): 44 | """Run the uvicorn server on 0.0.0.0, using the PORT env var (default 8000).""" 45 | port = int(os.getenv("PORT", "8000")) 46 |
uvicorn.run("copilotkit.demos.research_canvas.demo:app", host="0.0.0.0", port=port, reload=True) 47 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent/pirate_agent.py: -------------------------------------------------------------------------------- 1 | """Test Pirate Agent""" 2 | 3 | from typing import Any, cast 4 | # from langchain_openai import ChatOpenAI 5 | from langgraph.graph import StateGraph, END 6 | from langgraph.graph import MessagesState 7 | from langgraph.checkpoint.memory import MemorySaver 8 | from langchain_core.runnables import RunnableConfig 9 | # from langchain_core.messages import SystemMessage 10 | from copilotkit.langchain import copilotkit_emit_message 11 | 12 | class PirateAgentState(MessagesState): 13 | """Pirate Agent State""" 14 | 15 | async def pirate_node(state: PirateAgentState, config: RunnableConfig): # pylint: disable=unused-argument 16 | """ 17 | Speaks like a pirate 18 | """ 19 | 20 | await copilotkit_emit_message("Arr!!!", config) 21 | 22 | # system_message = "You speak like a pirate. 
Your name is Captain Copilot" 23 | 24 | # pirate_model = ChatOpenAI(model="gpt-4o") 25 | 26 | # response = await pirate_model.ainvoke([ 27 | # *state["messages"], 28 | # SystemMessage( 29 | # content=system_message 30 | # ) 31 | # ], config) 32 | 33 | 34 | return { 35 | "messages": state["messages"], 36 | } 37 | 38 | workflow = StateGraph(PirateAgentState) 39 | workflow.add_node("pirate_node", cast(Any, pirate_node)) 40 | workflow.set_entry_point("pirate_node") 41 | 42 | workflow.add_edge("pirate_node", END) 43 | memory = MemorySaver() 44 | pirate_graph = workflow.compile(checkpointer=memory) 45 | -------------------------------------------------------------------------------- /copilotkit/action.py: -------------------------------------------------------------------------------- 1 | """Actions""" 2 | 3 | 4 | from inspect import iscoroutinefunction 5 | from typing import Optional, List, Callable, TypedDict, Any, cast 6 | from .parameter import Parameter, normalize_parameters 7 | 8 | class ActionDict(TypedDict): 9 | """Dict representation of an action""" 10 | name: str 11 | description: str 12 | parameters: List[Parameter] 13 | 14 | class ActionResultDict(TypedDict): 15 | """Dict representation of an action result""" 16 | result: Any 17 | 18 | class Action: # pylint: disable=too-few-public-methods 19 | """Action class for CopilotKit""" 20 | def __init__( 21 | self, 22 | *, 23 | name: str, 24 | handler: Callable, 25 | description: Optional[str] = None, 26 | parameters: Optional[List[Parameter]] = None, 27 | ): 28 | self.name = name 29 | self.description = description 30 | self.parameters = parameters 31 | self.handler = handler 32 | 33 | async def execute( 34 | self, 35 | *, 36 | arguments: dict 37 | ) -> ActionResultDict: 38 | """Execute the action: call the handler with the arguments as keyword arguments, awaiting the result when the handler is a coroutine function.""" 39 | result = self.handler(**arguments) 40 | 41 | return { 42 | "result": await result if iscoroutinefunction(self.handler) else result 43 | } 44 | 45 | def dict_repr(self) -> ActionDict: 46 | """Dict representation of
the action""" 47 | return { 48 | 'name': self.name, 49 | 'description': self.description or '', 50 | 'parameters': normalize_parameters(cast(Any, self.parameters)), 51 | } 52 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the main entry point for the AI. 3 | It defines the workflow graph and the entry point for the agent. 4 | """ 5 | # pylint: disable=line-too-long, unused-import 6 | import json 7 | 8 | from langgraph.graph import StateGraph, END 9 | from langgraph.checkpoint.memory import MemorySaver 10 | 11 | from copilotkit.demos.ai_researcher.state import AgentState 12 | from copilotkit.demos.ai_researcher.steps import steps_node 13 | from copilotkit.demos.ai_researcher.search import search_node 14 | from copilotkit.demos.ai_researcher.summarize import summarize_node 15 | from copilotkit.demos.ai_researcher.extract import extract_node 16 | 17 | def route(state): 18 | """Pick the next node: END when there are no steps, search_node for the first pending search step, summarize_node when no step is pending; raise ValueError for any other step type.""" 19 | if not state.get("steps", None): 20 | return END 21 | 22 | current_step = next((step for step in state["steps"] if step["status"] == "pending"), None) 23 | 24 | if not current_step: 25 | return "summarize_node" 26 | 27 | if current_step["type"] == "search": 28 | return "search_node" 29 | 30 | raise ValueError(f"Unknown step type: {current_step['type']}") 31 | 32 | # Define a new graph 33 | workflow = StateGraph(AgentState) 34 | workflow.add_node("steps_node", steps_node) 35 | workflow.add_node("search_node", search_node) 36 | workflow.add_node("summarize_node", summarize_node) 37 | workflow.add_node("extract_node", extract_node) 38 | # Chatbot 39 | workflow.set_entry_point("steps_node") 40 | 41 | workflow.add_conditional_edges( 42 | "steps_node", 43 | route, 44 | ["summarize_node", "search_node", END] 45 | ) 46 | 47 | workflow.add_edge("search_node", "extract_node") 48 |
workflow.add_conditional_edges( 50 | "extract_node", 51 | route, 52 | ["summarize_node", "search_node"] 53 | ) 54 | 55 | workflow.add_edge("summarize_node", END) 56 | 57 | memory = MemorySaver() 58 | graph = workflow.compile(checkpointer=memory) 59 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a demo of the CopilotKit SDK. 3 | """ 4 | 5 | from fastapi import FastAPI 6 | import uvicorn 7 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 8 | from copilotkit import CopilotKitSDK, Action, LangGraphAgent 9 | from copilotkit.demos.multi_agent.joke_agent import joke_graph 10 | from copilotkit.demos.multi_agent.email_agent import email_graph 11 | from copilotkit.demos.multi_agent.pirate_agent import pirate_graph 12 | 13 | def greet_user(name): 14 | """Greet the user.""" 15 | print(f"Hello, {name}!") 16 | return "The user has been greeted. Tell them to check the console." 
17 | 18 | app = FastAPI() 19 | sdk = CopilotKitSDK( 20 | actions=[ 21 | Action( 22 | name="greet_user", 23 | description="Greet the user.", 24 | handler=greet_user, 25 | parameters=[ 26 | { 27 | "name": "name", 28 | "description": "The name of the user to greet.", 29 | "type": "string", 30 | } 31 | ] 32 | ), 33 | ], 34 | agents=[ 35 | LangGraphAgent( 36 | name="joke_agent", 37 | description="Make a joke.", 38 | agent=joke_graph, 39 | ), 40 | LangGraphAgent( 41 | name="email_agent", 42 | description="Write an email.", 43 | agent=email_graph, 44 | ), 45 | LangGraphAgent( 46 | name="pirate_agent", 47 | description="Speak like a pirate.", 48 | agent=pirate_graph, 49 | ) 50 | ], 51 | ) 52 | 53 | add_fastapi_endpoint(app, sdk, "/copilotkit") 54 | 55 | 56 | def main(): 57 | """Run the uvicorn server.""" 58 | uvicorn.run( 59 | "copilotkit.demos.multi_agent.demo:app", 60 | host="127.0.0.1", 61 | port=8000, 62 | reload=True 63 | ) 64 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent_anthropic/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a demo of the CopilotKit SDK. 3 | """ 4 | 5 | from fastapi import FastAPI 6 | import uvicorn 7 | from copilotkit.integrations.fastapi import add_fastapi_endpoint 8 | from copilotkit import CopilotKitSDK, Action, LangGraphAgent 9 | from copilotkit.demos.multi_agent_anthropic.joke_agent import joke_graph 10 | from copilotkit.demos.multi_agent_anthropic.email_agent import email_graph 11 | from copilotkit.demos.multi_agent_anthropic.pirate_agent import pirate_graph 12 | 13 | def greet_user(name): 14 | """Greet the user.""" 15 | print(f"Hello, {name}!") 16 | return "The user has been greeted. Tell them to check the console." 
17 | 18 | app = FastAPI() 19 | sdk = CopilotKitSDK( 20 | actions=[ 21 | Action( 22 | name="greet_user", 23 | description="Greet the user.", 24 | handler=greet_user, 25 | parameters=[ 26 | { 27 | "name": "name", 28 | "description": "The name of the user to greet.", 29 | "type": "string", 30 | } 31 | ] 32 | ), 33 | ], 34 | agents=[ 35 | LangGraphAgent( 36 | name="joke_agent", 37 | description="Make a joke.", 38 | agent=joke_graph, 39 | ), 40 | LangGraphAgent( 41 | name="email_agent", 42 | description="Write an email.", 43 | agent=email_graph, 44 | ), 45 | LangGraphAgent( 46 | name="pirate_agent", 47 | description="Speak like a pirate.", 48 | agent=pirate_graph, 49 | ) 50 | ], 51 | ) 52 | 53 | add_fastapi_endpoint(app, sdk, "/copilotkit") 54 | 55 | 56 | def main(): 57 | """Run the uvicorn server.""" 58 | uvicorn.run( 59 | "copilotkit.demos.multi_agent_anthropic.demo:app", 60 | host="127.0.0.1", 61 | port=8000, 62 | reload=True 63 | ) 64 | -------------------------------------------------------------------------------- /copilotkit/parameter.py: -------------------------------------------------------------------------------- 1 | """Parameter classes for CopilotKit""" 2 | 3 | from typing import TypedDict, Optional, Literal, List, Union, cast, Any 4 | from typing_extensions import NotRequired 5 | 6 | class SimpleParameter(TypedDict): 7 | """Simple parameter class""" 8 | name: str 9 | description: NotRequired[str] 10 | required: NotRequired[bool] 11 | type: NotRequired[Literal[ 12 | "number", 13 | "boolean", 14 | "number[]", 15 | "boolean[]" 16 | ]] 17 | 18 | class ObjectParameter(TypedDict): 19 | """Object parameter class""" 20 | name: str 21 | description: NotRequired[str] 22 | required: NotRequired[bool] 23 | type: Literal["object", "object[]"] 24 | attributes: List['Parameter'] 25 | 26 | class StringParameter(TypedDict): 27 | """String parameter class""" 28 | name: str 29 | description: NotRequired[str] 30 | required: NotRequired[bool] 31 | type: Literal["string", 
"string[]"] 32 | enum: NotRequired[List[str]] 33 | 34 | Parameter = Union[SimpleParameter, ObjectParameter, StringParameter] 35 | 36 | def normalize_parameters(parameters: Optional[List[Parameter]]) -> List[Parameter]: 37 | """Normalize the parameters to ensure they have the correct type and format.""" 38 | if parameters is None: 39 | return [] 40 | return [_normalize_parameter(parameter) for parameter in parameters] 41 | 42 | def _normalize_parameter(parameter: Parameter) -> Parameter: 43 | """Normalize a parameter to ensure it has the correct type and format.""" 44 | if not "type" in parameter: 45 | cast(Any, parameter)['type'] = 'string' 46 | if not 'required' in parameter: 47 | parameter['required'] = True 48 | if not 'description' in parameter: 49 | parameter['description'] = '' 50 | 51 | if 'type' in parameter and (parameter['type'] == 'object' or parameter['type'] == 'object[]'): 52 | cast(Any, parameter)['attributes'] = normalize_parameters(parameter.get('attributes')) 53 | return parameter 54 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/extract.py: -------------------------------------------------------------------------------- 1 | """ 2 | The extract node is responsible for extracting information from a tavily search. 3 | """ 4 | import json 5 | 6 | from langchain_openai import ChatOpenAI 7 | from langchain_core.messages import SystemMessage 8 | 9 | from langchain_core.runnables import RunnableConfig 10 | 11 | from copilotkit.demos.ai_researcher.state import AgentState 12 | 13 | async def extract_node(state: AgentState, config: RunnableConfig): 14 | """ 15 | The extract node is responsible for extracting information from a tavily search. 
16 | """ 17 | 18 | current_step = next((step for step in state["steps"] if step["status"] == "pending"), None) 19 | 20 | if current_step is None: 21 | raise ValueError("No current step") 22 | 23 | if current_step["type"] != "search": 24 | raise ValueError("Current step is not of type search") 25 | 26 | system_message = f""" 27 | This step was just executed: {json.dumps(current_step)} 28 | 29 | This is the result of the search: 30 | 31 | Please summarize ONLY the result of the search and include all relevant information from the search and reference links. 32 | DO NOT INCLUDE ANY EXTRA INFORMATION. ALL OF THE INFORMATION YOU ARE LOOKING FOR IS IN THE SEARCH RESULTS. 33 | 34 | DO NOT answer the user's query yet. Just summarize the search results. 35 | 36 | Use markdown formatting and put the references inline and the links at the end. 37 | Like this: 38 | This is a sentence with a reference to a source [source 1][1] and another reference [source 2][2]. 39 | [1]: http://example.com/source1 "Title of Source 1" 40 | [2]: http://example.com/source2 "Title of Source 2" 41 | """ 42 | 43 | response = await ChatOpenAI(model="gpt-4o").ainvoke([ 44 | *state["messages"], 45 | SystemMessage(content=system_message) 46 | ], config) 47 | 48 | current_step["result"] = response.content 49 | current_step["search_result"] = None 50 | current_step["status"] = "complete" 51 | current_step["updates"] = None 52 | 53 | return state 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | 
*.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyderworkspace 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | .dmypy.json 115 | dmypy.json 116 | 117 | # Pyre type checker 118 | .pyre/ 119 | 120 | # pytype static type analyzer 121 | .pytype/ 122 | 123 | # Cython debug symbols 124 | cython_debug/ -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the main entry point for the AI. 3 | It defines the workflow graph and the entry point for the agent. 
4 | """ 5 | # pylint: disable=line-too-long, unused-import 6 | import json 7 | from typing import cast 8 | 9 | from langchain_core.messages import AIMessage, ToolMessage 10 | from langgraph.graph import StateGraph, END 11 | from langgraph.checkpoint.memory import MemorySaver 12 | from copilotkit.demos.research_canvas.state import AgentState 13 | from copilotkit.demos.research_canvas.download import download_node 14 | from copilotkit.demos.research_canvas.chat import chat_node 15 | from copilotkit.demos.research_canvas.search import search_node 16 | from copilotkit.demos.research_canvas.delete import delete_node, perform_delete_node 17 | 18 | # Define a new graph 19 | workflow = StateGraph(AgentState) 20 | workflow.add_node("download", download_node) 21 | workflow.add_node("chat_node", chat_node) 22 | workflow.add_node("search_node", search_node) 23 | workflow.add_node("delete_node", delete_node) 24 | workflow.add_node("perform_delete_node", perform_delete_node) 25 | 26 | def route(state): 27 | """Route after the chat node.""" 28 | 29 | messages = state.get("messages", []) 30 | if messages and isinstance(messages[-1], AIMessage): 31 | ai_message = cast(AIMessage, messages[-1]) 32 | 33 | if ai_message.tool_calls and ai_message.tool_calls[0]["name"] == "Search": 34 | return "search_node" 35 | if ai_message.tool_calls and ai_message.tool_calls[0]["name"] == "DeleteResources": 36 | return "delete_node" 37 | if messages and isinstance(messages[-1], ToolMessage): 38 | return "chat_node" 39 | 40 | return END 41 | 42 | 43 | memory = MemorySaver() 44 | workflow.set_entry_point("download") 45 | workflow.add_edge("download", "chat_node") 46 | workflow.add_conditional_edges("chat_node", route, ["search_node", "chat_node", "delete_node", END]) 47 | workflow.add_edge("delete_node", "perform_delete_node") 48 | workflow.add_edge("perform_delete_node", "chat_node") 49 | workflow.add_edge("search_node", "download") 50 | graph = workflow.compile(checkpointer=memory, 
interrupt_after=["delete_node"]) 51 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the main entry point for the autotale AI. 3 | It defines the workflow graph and the entry point for the agent. 4 | """ 5 | # pylint: disable=line-too-long, unused-import 6 | 7 | from typing import Any, cast 8 | 9 | from langgraph.graph import StateGraph, END 10 | from langgraph.checkpoint.memory import MemorySaver 11 | 12 | from langchain_core.messages import ToolMessage 13 | 14 | from copilotkit.demos.autotale_ai.state import AgentState 15 | from copilotkit.demos.autotale_ai.chatbot import chatbot_node 16 | from copilotkit.demos.autotale_ai.story.outline import outline_node 17 | from copilotkit.demos.autotale_ai.story.characters import characters_node 18 | from copilotkit.demos.autotale_ai.story.story import story_node 19 | from copilotkit.demos.autotale_ai.story.style import style_node 20 | 21 | 22 | 23 | 24 | def route_story_writing(state): 25 | """Route to story writing nodes.""" 26 | last_message = state["messages"][-1] 27 | 28 | if isinstance(last_message, ToolMessage): 29 | return last_message.name 30 | return END 31 | 32 | # Define a new graph 33 | workflow = StateGraph(AgentState) 34 | workflow.add_node("chatbot_node", cast(Any, chatbot_node)) 35 | workflow.add_node("outline_node", outline_node) 36 | workflow.add_node("characters_node", characters_node) 37 | workflow.add_node("style_node", style_node) 38 | workflow.add_node("story_node", cast(Any, story_node)) 39 | 40 | # Chatbot 41 | workflow.set_entry_point("chatbot_node") 42 | 43 | workflow.add_conditional_edges( 44 | "chatbot_node", 45 | route_story_writing, 46 | { 47 | "set_outline": "outline_node", 48 | "set_characters": "characters_node", 49 | "set_story": "story_node", 50 | "set_style": "style_node", 51 | END: END, 52 | } 53 | ) 54 | 
workflow.add_edge( 55 | "outline_node", 56 | "chatbot_node" 57 | ) 58 | 59 | workflow.add_edge( 60 | "characters_node", 61 | "chatbot_node" 62 | ) 63 | 64 | workflow.add_edge( 65 | "story_node", 66 | "chatbot_node" 67 | ) 68 | 69 | workflow.add_edge( 70 | "style_node", 71 | "chatbot_node" 72 | ) 73 | 74 | memory = MemorySaver() 75 | 76 | graph = workflow.compile(checkpointer=memory) 77 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the implementation of the download_node function. 3 | """ 4 | 5 | import aiohttp 6 | import html2text 7 | from copilotkit.langchain import copilotkit_emit_state 8 | from langchain_core.runnables import RunnableConfig 9 | from copilotkit.demos.research_canvas.state import AgentState 10 | 11 | _RESOURCE_CACHE = {} 12 | 13 | def get_resource(url: str): 14 | """ 15 | Get a resource from the cache. 16 | """ 17 | return _RESOURCE_CACHE.get(url, "") 18 | 19 | 20 | _USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" # pylint: disable=line-too-long 21 | 22 | async def _download_resource(url: str): 23 | """ 24 | Download a resource from the internet asynchronously. 
25 | """ 26 | try: 27 | async with aiohttp.ClientSession() as session: 28 | async with session.get( 29 | url, 30 | headers={"User-Agent": _USER_AGENT}, 31 | timeout=aiohttp.ClientTimeout(total=10) 32 | ) as response: 33 | response.raise_for_status() 34 | html_content = await response.text() 35 | markdown_content = html2text.html2text(html_content) 36 | _RESOURCE_CACHE[url] = markdown_content 37 | return markdown_content 38 | except Exception as e: # pylint: disable=broad-except 39 | _RESOURCE_CACHE[url] = "ERROR" 40 | return f"Error downloading resource: {e}" 41 | 42 | async def download_node(state: AgentState, config: RunnableConfig): 43 | """ 44 | Download resources from the internet. 45 | """ 46 | state["resources"] = state.get("resources", []) 47 | state["logs"] = state.get("logs", []) 48 | resources_to_download = [] 49 | 50 | logs_offset = len(state["logs"]) 51 | 52 | # Find resources that are not downloaded 53 | for resource in state["resources"]: 54 | if not get_resource(resource["url"]): 55 | resources_to_download.append(resource) 56 | state["logs"].append({ 57 | "message": f"Downloading {resource['url']}", 58 | "done": False 59 | }) 60 | 61 | # Emit the state to let the UI update 62 | await copilotkit_emit_state(config, state) 63 | 64 | # Download the resources 65 | for i, resource in enumerate(resources_to_download): 66 | await _download_resource(resource["url"]) 67 | state["logs"][logs_offset + i]["done"] = True 68 | 69 | # update UI 70 | await copilotkit_emit_state(config, state) 71 | 72 | return state 73 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent/joke_agent.py: -------------------------------------------------------------------------------- 1 | """Test Joker Agent""" 2 | 3 | from typing import Any, cast 4 | 5 | from langchain_openai import ChatOpenAI 6 | from langgraph.graph import StateGraph, END 7 | from langgraph.graph import MessagesState 8 | from langgraph.checkpoint.memory 
import MemorySaver 9 | from langchain_core.runnables import RunnableConfig 10 | from langchain_core.messages import SystemMessage, ToolMessage 11 | 12 | 13 | from copilotkit.langchain import copilotkit_customize_config, copilotkit_exit 14 | 15 | class JokeAgentState(MessagesState): 16 | """Joke Agent State""" 17 | joke: str 18 | 19 | async def joke_node(state: JokeAgentState, config: RunnableConfig): 20 | """ 21 | Make a joke. 22 | """ 23 | 24 | config = copilotkit_customize_config( 25 | config, 26 | emit_messages=True, 27 | emit_intermediate_state=[ 28 | { 29 | "state_key": "joke", 30 | "tool": "make_joke", 31 | "tool_argument": "the_joke" 32 | }, 33 | ] 34 | ) 35 | 36 | system_message = "You make funny jokes." 37 | 38 | joke_tool = { 39 | 'name': 'make_joke', 40 | 'description': """Make a funny joke.""", 41 | 'parameters': { 42 | 'type': 'object', 43 | 'properties': { 44 | 'the_joke': { 45 | 'description': """The joke""", 46 | 'type': 'string', 47 | } 48 | }, 49 | 'required': ['the_joke'] 50 | } 51 | } 52 | 53 | joke_model = ChatOpenAI(model="gpt-4o").bind_tools( 54 | [joke_tool], 55 | parallel_tool_calls=False, 56 | tool_choice="make_joke" 57 | ) 58 | 59 | response = await joke_model.ainvoke([ 60 | *state["messages"], 61 | SystemMessage( 62 | content=system_message 63 | ) 64 | ], config) 65 | 66 | tool_calls = getattr(response, "tool_calls") 67 | 68 | joke = tool_calls[0]["args"]["the_joke"] 69 | 70 | await copilotkit_exit(config) 71 | 72 | return { 73 | "messages": [ 74 | response, 75 | ToolMessage( 76 | name=tool_calls[0]["name"], 77 | content=joke, 78 | tool_call_id=tool_calls[0]["id"] 79 | ) 80 | ], 81 | "joke": joke, 82 | } 83 | 84 | workflow = StateGraph(JokeAgentState) 85 | workflow.add_node("joke_node", cast(Any, joke_node)) 86 | workflow.set_entry_point("joke_node") 87 | 88 | workflow.add_edge("joke_node", END) 89 | memory = MemorySaver() 90 | joke_graph = workflow.compile(checkpointer=memory) 91 | 
-------------------------------------------------------------------------------- /copilotkit/demos/multi_agent/email_agent.py: -------------------------------------------------------------------------------- 1 | """Test Joker Agent""" 2 | 3 | from typing import Any, cast 4 | from langchain_openai import ChatOpenAI 5 | from langgraph.graph import StateGraph, END 6 | from langgraph.graph import MessagesState 7 | from langgraph.checkpoint.memory import MemorySaver 8 | from langchain_core.runnables import RunnableConfig 9 | from langchain_core.messages import SystemMessage, ToolMessage 10 | 11 | 12 | from copilotkit.langchain import copilotkit_customize_config, copilotkit_exit 13 | 14 | class EmailAgentState(MessagesState): 15 | """Email Agent State""" 16 | email: str 17 | 18 | async def email_node(state: EmailAgentState, config: RunnableConfig): 19 | """ 20 | Make a joke. 21 | """ 22 | 23 | config = copilotkit_customize_config( 24 | config, 25 | emit_messages=True, 26 | emit_intermediate_state=[ 27 | { 28 | "state_key": "email", 29 | "tool": "write_email", 30 | "tool_argument": "the_email" 31 | }, 32 | ] 33 | ) 34 | 35 | system_message = "You write emails." 
36 | 37 | email_tool = { 38 | 'name': 'write_email', 39 | 'description': """Write an email.""", 40 | 'parameters': { 41 | 'type': 'object', 42 | 'properties': { 43 | 'the_email': { 44 | 'description': """The email""", 45 | 'type': 'string', 46 | } 47 | }, 48 | 'required': ['the_email'] 49 | } 50 | } 51 | 52 | email_model = ChatOpenAI(model="gpt-4o").bind_tools( 53 | [email_tool], 54 | parallel_tool_calls=False, 55 | tool_choice="write_email" 56 | ) 57 | 58 | response = await email_model.ainvoke([ 59 | *state["messages"], 60 | SystemMessage( 61 | content=system_message 62 | ) 63 | ], config) 64 | 65 | tool_calls = getattr(response, "tool_calls") 66 | 67 | email = tool_calls[0]["args"]["the_email"] 68 | 69 | await copilotkit_exit(config) 70 | 71 | return { 72 | "messages": [ 73 | response, 74 | ToolMessage( 75 | name=tool_calls[0]["name"], 76 | content=email, 77 | tool_call_id=tool_calls[0]["id"] 78 | ) 79 | ], 80 | "email": email, 81 | } 82 | 83 | workflow = StateGraph(EmailAgentState) 84 | workflow.add_node("email_node", cast(Any, email_node)) 85 | workflow.set_entry_point("email_node") 86 | 87 | workflow.add_edge("email_node", END) 88 | memory = MemorySaver() 89 | email_graph = workflow.compile(checkpointer=memory) 90 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent_anthropic/joke_agent.py: -------------------------------------------------------------------------------- 1 | """Test Joker Agent""" 2 | 3 | from typing import Any, cast 4 | 5 | from langchain_anthropic import ChatAnthropic 6 | from langgraph.graph import StateGraph, END 7 | from langgraph.graph import MessagesState 8 | from langgraph.checkpoint.memory import MemorySaver 9 | from langchain_core.runnables import RunnableConfig 10 | from langchain_core.messages import SystemMessage, ToolMessage 11 | 12 | 13 | from copilotkit.langchain import copilotkit_customize_config, copilotkit_exit 14 | 15 | class JokeAgentState(MessagesState): 16 | 
"""Joke Agent State""" 17 | joke: str 18 | 19 | async def joke_node(state: JokeAgentState, config: RunnableConfig): 20 | """ 21 | Make a joke. 22 | """ 23 | 24 | config = copilotkit_customize_config( 25 | config, 26 | emit_messages=True, 27 | emit_intermediate_state=[ 28 | { 29 | "state_key": "joke", 30 | "tool": "make_joke", 31 | "tool_argument": "the_joke" 32 | }, 33 | ] 34 | ) 35 | 36 | system_message = "You make funny jokes." 37 | 38 | joke_tool = { 39 | 'name': 'make_joke', 40 | 'description': """Make a funny joke.""", 41 | 'parameters': { 42 | 'type': 'object', 43 | 'properties': { 44 | 'the_joke': { 45 | 'description': """The joke""", 46 | 'type': 'string', 47 | } 48 | }, 49 | 'required': ['the_joke'] 50 | } 51 | } 52 | 53 | joke_model = ChatAnthropic( 54 | model_name="claude-3-5-sonnet-20240620", 55 | timeout=None, 56 | stop=None 57 | ).bind_tools( 58 | [joke_tool], 59 | tool_choice="make_joke" 60 | ) 61 | 62 | response = await joke_model.ainvoke([ 63 | SystemMessage( 64 | content=system_message 65 | ), 66 | *state["messages"] 67 | ], config) 68 | 69 | tool_calls = getattr(response, "tool_calls") 70 | 71 | joke = tool_calls[0]["args"]["the_joke"] 72 | 73 | await copilotkit_exit(config) 74 | 75 | return { 76 | "messages": [ 77 | response, 78 | ToolMessage( 79 | name=tool_calls[0]["name"], 80 | content=joke, 81 | tool_call_id=tool_calls[0]["id"] 82 | ) 83 | ], 84 | "joke": joke, 85 | } 86 | 87 | workflow = StateGraph(JokeAgentState) 88 | workflow.add_node("joke_node", cast(Any, joke_node)) 89 | workflow.set_entry_point("joke_node") 90 | 91 | workflow.add_edge("joke_node", END) 92 | memory = MemorySaver() 93 | joke_graph = workflow.compile(checkpointer=memory) 94 | -------------------------------------------------------------------------------- /copilotkit/demos/multi_agent_anthropic/email_agent.py: -------------------------------------------------------------------------------- 1 | """Test Joker Agent""" 2 | 3 | from typing import Any, cast 4 | from 
langgraph.graph import StateGraph, END 5 | from langgraph.graph import MessagesState 6 | from langgraph.checkpoint.memory import MemorySaver 7 | from langchain_core.runnables import RunnableConfig 8 | from langchain_core.messages import SystemMessage, ToolMessage 9 | from langchain_anthropic import ChatAnthropic 10 | 11 | from copilotkit.langchain import copilotkit_customize_config, copilotkit_exit 12 | 13 | class EmailAgentState(MessagesState): 14 | """Email Agent State""" 15 | email: str 16 | 17 | async def email_node(state: EmailAgentState, config: RunnableConfig): 18 | """ 19 | Make a joke. 20 | """ 21 | 22 | config = copilotkit_customize_config( 23 | config, 24 | emit_messages=True, 25 | emit_intermediate_state=[ 26 | { 27 | "state_key": "email", 28 | "tool": "write_email", 29 | "tool_argument": "the_email" 30 | }, 31 | ] 32 | ) 33 | 34 | system_message = "You write emails." 35 | 36 | email_tool = { 37 | 'name': 'write_email', 38 | 'description': """Write an email.""", 39 | 'parameters': { 40 | 'type': 'object', 41 | 'properties': { 42 | 'the_email': { 43 | 'description': """The email""", 44 | 'type': 'string', 45 | } 46 | }, 47 | 'required': ['the_email'] 48 | } 49 | } 50 | 51 | print("Before email model") 52 | 53 | email_model = ChatAnthropic( 54 | model_name="claude-3-5-sonnet-20240620", 55 | timeout=None, 56 | stop=None 57 | ).bind_tools( 58 | [email_tool], 59 | tool_choice="write_email" 60 | ) 61 | 62 | response = await email_model.ainvoke([ 63 | SystemMessage( 64 | content=system_message 65 | ), 66 | *state["messages"], 67 | ], config) 68 | 69 | tool_calls = getattr(response, "tool_calls") 70 | 71 | email = tool_calls[0]["args"]["the_email"] 72 | 73 | await copilotkit_exit(config) 74 | 75 | return { 76 | "messages": [ 77 | response, 78 | ToolMessage( 79 | name=tool_calls[0]["name"], 80 | content=email, 81 | tool_call_id=tool_calls[0]["id"] 82 | ) 83 | ], 84 | "email": email, 85 | } 86 | 87 | workflow = StateGraph(EmailAgentState) 88 | 
workflow.add_node("email_node", cast(Any, email_node)) 89 | workflow.set_entry_point("email_node") 90 | 91 | workflow.add_edge("email_node", END) 92 | memory = MemorySaver() 93 | email_graph = workflow.compile(checkpointer=memory) 94 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/search.py: -------------------------------------------------------------------------------- 1 | """ 2 | The search node is responsible for searching the internet for information. 3 | """ 4 | import json 5 | 6 | from langchain_openai import ChatOpenAI 7 | from langchain_core.messages import SystemMessage 8 | 9 | from langchain_core.runnables import RunnableConfig 10 | from langchain_community.tools import TavilySearchResults 11 | 12 | from copilotkit.demos.ai_researcher.state import AgentState 13 | 14 | async def search_node(state: AgentState, config: RunnableConfig): 15 | """ 16 | The search node is responsible for searching the internet for information. 17 | """ 18 | tavily_tool = TavilySearchResults( 19 | max_results=10, 20 | search_depth="advanced", 21 | include_answer=True, 22 | include_raw_content=True, 23 | include_images=True, 24 | ) 25 | 26 | current_step = next((step for step in state["steps"] if step["status"] == "pending"), None) 27 | 28 | if current_step is None: 29 | raise ValueError("No step to search for") 30 | 31 | if current_step["type"] != "search": 32 | raise ValueError("Current step is not a search step") 33 | 34 | system_message = f""" 35 | This is a step in a series of steps that are being executed to answer the user's query. 
36 | These are all of the steps: {json.dumps(state["steps"])} 37 | 38 | You are responsible for carrying out the step: {json.dumps(current_step)} 39 | 40 | This is what you need to search for, please come up with a good search query: {current_step["description"]} 41 | """ 42 | model = ChatOpenAI(model="gpt-4o").bind_tools( 43 | [tavily_tool], 44 | parallel_tool_calls=False, 45 | tool_choice=tavily_tool.name 46 | ) 47 | 48 | response = await model.ainvoke([ 49 | *state["messages"], 50 | SystemMessage( 51 | content=system_message 52 | ) 53 | ], config) 54 | 55 | tool_msg = tavily_tool.invoke(response.tool_calls[0]) 56 | 57 | 58 | # system_message = f""" 59 | # This task was just executed: {json.dumps(current_step)} 60 | 61 | # This is the result of the search: 62 | 63 | # {tool_msg} 64 | 65 | # Please summarize the ONLY the result of the search and include all relevant information from the search and reference links. 66 | 67 | # DO NOT INCLUDE ANY EXTRA INFORMATION. ALL OF THE INFORMATION YOU ARE LOOKING FOR IS IN THE SEARCH RESULTS. 68 | 69 | # DO NOT answer the user's query yet. Just summarize the search results. 70 | 71 | # Use markdown formatting and put the references inline and the links at the end. 72 | # Like this: 73 | # This is a sentence with a reference to a source [source 1][1] and another reference [source 2][2]. 74 | # [1]: http://example.com/source1 "Title of Source 1" 75 | # [2]: http://example.com/source2 "Title of Source 2" 76 | # """ 77 | 78 | # response = await ChatOpenAI(model="gpt-4o").ainvoke([ 79 | # *state["messages"], 80 | # SystemMessage(content=system_message) 81 | # ], config) 82 | 83 | current_step["search_result"] = json.loads(tool_msg.content) 84 | current_step["updates"] = "Extracting information..." 
85 | 86 | return state 87 | -------------------------------------------------------------------------------- /copilotkit/demos/starter/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the main entry point for the AI. 3 | It defines the workflow graph and the entry point for the agent. 4 | """ 5 | # pylint: disable=line-too-long, unused-import 6 | 7 | from typing import cast, TypedDict, Any 8 | from langchain_openai import ChatOpenAI 9 | from langchain_core.messages import SystemMessage, ToolMessage, AIMessage, HumanMessage 10 | from langchain_core.runnables import RunnableConfig 11 | from langgraph.graph import StateGraph, END 12 | from langgraph.checkpoint.memory import MemorySaver 13 | from langgraph.graph import MessagesState 14 | from copilotkit.langchain import copilotkit_customize_config 15 | 16 | class Translations(TypedDict): 17 | """Contains the translations in four different languages.""" 18 | translation_es: str 19 | translation_fr: str 20 | translation_de: str 21 | 22 | class AgentState(MessagesState): 23 | """Contains the state of the agent.""" 24 | translations: Translations 25 | input: str 26 | 27 | async def translate_node(state: AgentState, config: RunnableConfig): 28 | """Chatbot that translates text""" 29 | 30 | config = copilotkit_customize_config( 31 | config, 32 | # config emits messages by default, so this is not needed: 33 | ## emit_messages=True, 34 | emit_intermediate_state=[ 35 | { 36 | "state_key": "translations", 37 | "tool": "translate" 38 | } 39 | ] 40 | ) 41 | 42 | model = ChatOpenAI(model="gpt-4o").bind_tools( 43 | [Translations], 44 | parallel_tool_calls=False, 45 | tool_choice=( 46 | None if state["messages"] and 47 | isinstance(state["messages"][-1], HumanMessage) 48 | else "Translations" 49 | ) 50 | ) 51 | 52 | response = await model.ainvoke([ 53 | SystemMessage( 54 | content=f""" 55 | You are a helpful assistant that translates text to different languages 56 | 
(Spanish, French and German). 57 | Don't ask for confirmation before translating. 58 | { 59 | 'The user is currently working on translating this text: "' + 60 | state["input"] + '"' if state.get("input") else "" 61 | } 62 | """ 63 | ), 64 | *state["messages"], 65 | ], config) 66 | 67 | if hasattr(response, "tool_calls") and len(getattr(response, "tool_calls")) > 0: 68 | ai_message = cast(AIMessage, response) 69 | return { 70 | "messages": [ 71 | response, 72 | ToolMessage( 73 | content="Translated!", 74 | tool_call_id=ai_message.tool_calls[0]["id"] 75 | ) 76 | ], 77 | "translations": cast(AIMessage, response).tool_calls[0]["args"], 78 | } 79 | 80 | return { 81 | "messages": [ 82 | response, 83 | ], 84 | } 85 | 86 | workflow = StateGraph(AgentState) 87 | workflow.add_node("translate_node", cast(Any, translate_node)) 88 | workflow.set_entry_point("translate_node") 89 | workflow.add_edge("translate_node", END) 90 | memory = MemorySaver() 91 | graph = workflow.compile(checkpointer=memory) 92 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/summarize.py: -------------------------------------------------------------------------------- 1 | """ 2 | The summarize node is responsible for summarizing the information. 3 | """ 4 | 5 | import json 6 | from langchain_openai import ChatOpenAI 7 | from langchain_core.messages import SystemMessage, ToolMessage 8 | from langchain_core.runnables import RunnableConfig 9 | from copilotkit.langchain import copilotkit_customize_config 10 | 11 | from copilotkit.demos.ai_researcher.state import AgentState 12 | 13 | async def summarize_node(state: AgentState, config: RunnableConfig): 14 | """ 15 | The summarize node is responsible for summarizing the information. 
16 | """ 17 | 18 | config = copilotkit_customize_config( 19 | config, 20 | emit_messages=True, 21 | emit_intermediate_state=[ 22 | { 23 | "state_key": "answer", 24 | "tool": "summarize", 25 | }, 26 | ] 27 | ) 28 | 29 | system_message = f""" 30 | The system has performed a series of steps to answer the user's query. 31 | These are all of the steps: {json.dumps(state["steps"])} 32 | 33 | Please summarize the final result and include all relevant information and reference links. 34 | """ 35 | 36 | summarize_tool = { 37 | 'name': 'summarize', 38 | 'description': """ 39 | Summarize the final result. Make sure that the summary is complete and includes all relevant information and reference links. 40 | """, 41 | 'parameters': { 42 | 'type': 'object', 43 | 'properties': { 44 | 'markdown': { 45 | 'description': 'The markdown formatted summary of the final result.', 46 | 'type': 'string' 47 | }, 48 | 'references': { 49 | 'description': """A list of references.""", 50 | 'type': 'array', 51 | 'items': { 52 | 'type': 'object', 53 | 'properties': { 54 | 55 | 'title': { 56 | 'description': 'The title of the reference.', 57 | 'type': 'string' 58 | }, 59 | 'url': { 60 | 'description': 'The url of the reference.', 61 | 'type': 'string' 62 | }, 63 | }, 64 | 'required': ['title', 'url'] 65 | } 66 | } 67 | }, 68 | 'required': ['markdown', 'references'] 69 | } 70 | } 71 | 72 | response = await ChatOpenAI(model="gpt-4o").bind_tools([summarize_tool], parallel_tool_calls=False, tool_choice="summarize").ainvoke([ 73 | *state["messages"], 74 | SystemMessage( 75 | content=system_message 76 | ) 77 | ], config) 78 | 79 | return { 80 | "messages": [ 81 | response, 82 | ToolMessage( 83 | name=response.tool_calls[0]["name"], 84 | content="summarized.", 85 | tool_call_id=response.tool_calls[0]["id"] 86 | ) 87 | ], 88 | "answer": response.tool_calls[0]["args"], 89 | } 90 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/story/story.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Story node. 3 | """ 4 | 5 | from typing import List 6 | import json 7 | import asyncio 8 | 9 | from langchain_core.tools import tool 10 | from langchain_core.messages import SystemMessage 11 | from langchain_core.runnables import RunnableConfig 12 | from langchain_core.pydantic_v1 import BaseModel, Field 13 | from langchain_openai import ChatOpenAI 14 | 15 | from copilotkit.demos.autotale_ai.state import AgentState, Character 16 | 17 | 18 | class ImageDescription(BaseModel): 19 | """ 20 | Represents the description of an image of a character in the story. 21 | """ 22 | description: str 23 | 24 | async def _generate_page_image_description( 25 | messages: list, 26 | page_content: str, 27 | characters: List[Character], 28 | style: str, 29 | config: RunnableConfig 30 | ): 31 | """ 32 | Generate a description of the image of a character. 33 | """ 34 | 35 | system_message = SystemMessage( 36 | content= f""" 37 | The user and the AI are having a conversation about writing a children's story. 38 | It's your job to generate a vivid description of a page in the story. 39 | Make the description as detailed as possible. 40 | 41 | These are the characters in the story: 42 | {characters} 43 | 44 | This is the page content: 45 | {page_content} 46 | 47 | This is the graphical style of the story: 48 | {style} 49 | 50 | Imagine an image of the page. Describe the looks of the page in great detail. 51 | Also describe the setting in which the image is taken. 52 | Make sure to include the name of the characters and full description of the characters in your output. 53 | Describe the style in detail, it's very important for image generation. 
54 | """ 55 | ) 56 | model = ChatOpenAI(model="gpt-4o").with_structured_output(ImageDescription) 57 | response = await model.ainvoke([ 58 | *messages, 59 | system_message 60 | ], config) 61 | 62 | return response.description 63 | 64 | class StoryPage(BaseModel): 65 | """ 66 | Represents a page in the children's story. Keep it simple, 3-4 sentences per page. 67 | """ 68 | content: str = Field(..., description="A single page in the story") 69 | 70 | @tool 71 | def set_story(pages: List[StoryPage]): 72 | """ 73 | Considering the outline and characters, write a story. 74 | Keep it simple, 3-4 sentences per page. 75 | 5 pages max. 76 | (If the user mentions "chapters" in the conversation they mean pages, treat it as such) 77 | """ 78 | return pages 79 | 80 | async def story_node(state: AgentState, config: RunnableConfig): 81 | """ 82 | The story node is responsible for extracting the story from the conversation. 83 | """ 84 | last_message = state["messages"][-1] 85 | pages = json.loads(last_message.content)["pages"] 86 | characters = state.get("characters", []) 87 | style = state.get("style", "Pixar movies style 3D images") 88 | 89 | async def generate_page(page): 90 | description = await _generate_page_image_description( 91 | state["messages"], 92 | page["content"], 93 | characters, 94 | style, 95 | config 96 | ) 97 | return { 98 | "content": page["content"], 99 | "image_description": description 100 | } 101 | 102 | tasks = [generate_page(page) for page in pages] 103 | story = await asyncio.gather(*tasks) 104 | 105 | return { 106 | "story": story 107 | } 108 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/search.py: -------------------------------------------------------------------------------- 1 | """ 2 | The search node is responsible for searching the internet for information. 
3 | """ 4 | 5 | import os 6 | from typing import cast, List 7 | from pydantic import BaseModel, Field 8 | from langchain_core.runnables import RunnableConfig 9 | from langchain_core.messages import AIMessage, ToolMessage, SystemMessage 10 | from langchain.tools import tool 11 | from tavily import TavilyClient 12 | from copilotkit.langchain import copilotkit_emit_state, copilotkit_customize_config 13 | from copilotkit.demos.research_canvas.state import AgentState 14 | from copilotkit.demos.research_canvas.model import get_model 15 | 16 | class ResourceInput(BaseModel): 17 | """A resource with a short description""" 18 | url: str = Field(description="The URL of the resource") 19 | title: str = Field(description="The title of the resource") 20 | description: str = Field(description="A short description of the resource") 21 | 22 | @tool 23 | def ExtractResources(resources: List[ResourceInput]): # pylint: disable=invalid-name,unused-argument 24 | """Extract the 3-5 most relevant resources from a search result.""" 25 | 26 | tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) 27 | 28 | async def search_node(state: AgentState, config: RunnableConfig): 29 | """ 30 | The search node is responsible for searching the internet for resources. 
31 | """ 32 | ai_message = cast(AIMessage, state["messages"][-1]) 33 | 34 | state["resources"] = state.get("resources", []) 35 | state["logs"] = state.get("logs", []) 36 | queries = ai_message.tool_calls[0]["args"]["queries"] 37 | 38 | for query in queries: 39 | state["logs"].append({ 40 | "message": f"Search for {query}", 41 | "done": False 42 | }) 43 | 44 | await copilotkit_emit_state(config, state) 45 | 46 | search_results = [] 47 | 48 | for i, query in enumerate(queries): 49 | response = tavily_client.search(query) 50 | search_results.append(response) 51 | state["logs"][i]["done"] = True 52 | await copilotkit_emit_state(config, state) 53 | 54 | config = copilotkit_customize_config( 55 | config, 56 | emit_intermediate_state=[{ 57 | "state_key": "resources", 58 | "tool": "ExtractResources", 59 | "tool_argument": "resources", 60 | }], 61 | ) 62 | 63 | 64 | model = get_model(state) 65 | ainvoke_kwargs = {} 66 | if model.__class__.__name__ in ["ChatOpenAI"]: 67 | ainvoke_kwargs["parallel_tool_calls"] = False 68 | 69 | # figure out which resources to use 70 | response = await get_model(state).bind_tools( 71 | [ExtractResources], 72 | tool_choice="ExtractResources", 73 | **ainvoke_kwargs 74 | ).ainvoke([ 75 | SystemMessage( 76 | content=""" 77 | You need to extract the 3-5 most relevant resources from the following search results. 
78 | """ 79 | ), 80 | *state["messages"], 81 | ToolMessage( 82 | tool_call_id=ai_message.tool_calls[0]["id"], 83 | content=f"Performed search: {search_results}" 84 | ) 85 | ], config) 86 | 87 | state["logs"] = [] 88 | await copilotkit_emit_state(config, state) 89 | 90 | ai_message_response = cast(AIMessage, response) 91 | resources = ai_message_response.tool_calls[0]["args"]["resources"] 92 | 93 | state["resources"].extend(resources) 94 | 95 | state["messages"].append(ToolMessage( 96 | tool_call_id=ai_message.tool_calls[0]["id"], 97 | content=f"Added the following resources: {resources}" 98 | )) 99 | 100 | return state 101 | -------------------------------------------------------------------------------- /copilotkit/demos/ai_researcher/steps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main chatbot node. 3 | """ 4 | 5 | 6 | from langchain_openai import ChatOpenAI 7 | from langchain_core.messages import SystemMessage, ToolMessage 8 | from langchain_core.runnables import RunnableConfig 9 | from copilotkit.langchain import copilotkit_customize_config 10 | 11 | from copilotkit.demos.ai_researcher.state import AgentState 12 | 13 | # pylint: disable=line-too-long 14 | 15 | async def steps_node(state: AgentState, config: RunnableConfig): 16 | """ 17 | The steps node is responsible for building the steps in the research process. 18 | """ 19 | 20 | config = copilotkit_customize_config( 21 | config, 22 | emit_messages=True, 23 | emit_intermediate_state=[ 24 | { 25 | "state_key": "steps", 26 | "tool": "search", 27 | "tool_argument": "steps" 28 | }, 29 | ] 30 | ) 31 | 32 | system_message = """ 33 | You are a search assistant. Your task is to help the user with complex search queries by breaking the down into smaller steps. 34 | 35 | These steps are then executed serially. In the end, a final answer is produced in markdown format. 
36 | """ 37 | 38 | # use the openai tool format to get access to enums 39 | search_tool = { 40 | 'name': 'search', 41 | 'description': """ 42 | Break the user's query into smaller steps. 43 | 44 | Use step type "search" to search the web for information. 45 | 46 | Make sure to add all the steps needed to answer the user's query. 47 | """, 48 | 'parameters': { 49 | 'type': 'object', 50 | 'properties': { 51 | 'steps': { 52 | 'description': """The steps to be executed.""", 53 | 'type': 'array', 54 | 'items': { 55 | 'type': 'object', 56 | 'properties': { 57 | 'id': { 58 | 'description': 'The id of the step. This is used to identify the step in the state. Just make sure it is unique.', 59 | 'type': 'string' 60 | }, 61 | 'description': { 62 | 'description': 'The description of the step, i.e. "search for information about the latest AI news"', 63 | 'type': 'string' 64 | }, 65 | 'status': { 66 | 'description': 'The status of the step. Always "pending".', 67 | 'type': 'string', 68 | 'enum': ['pending'] 69 | }, 70 | 'type': { 71 | 'description': 'The type of step.', 72 | 'type': 'string', 73 | 'enum': ['search'] 74 | } 75 | }, 76 | 'required': ['id', 'description', 'status', 'type'] 77 | } 78 | } 79 | }, 80 | 'required': ['steps'] 81 | } 82 | } 83 | 84 | response = await ChatOpenAI(model="gpt-4o").bind_tools([search_tool], parallel_tool_calls=False, tool_choice="search").ainvoke([ 85 | *state["messages"], 86 | SystemMessage( 87 | content=system_message 88 | ) 89 | ], config) 90 | 91 | steps = response.tool_calls[0]["args"]["steps"] 92 | 93 | if len(steps): 94 | steps[0]["updates"] = "Searching the web..." 
95 | 96 | return { 97 | "messages": [ 98 | response, 99 | ToolMessage( 100 | name=response.tool_calls[0]["name"], 101 | content="executing steps...", 102 | tool_call_id=response.tool_calls[0]["id"] 103 | ) 104 | ], 105 | "steps": steps.copy(), 106 | } 107 | -------------------------------------------------------------------------------- /copilotkit/demos/autotale_ai/chatbot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main chatbot node. 3 | """ 4 | 5 | import json 6 | 7 | from langchain_openai import ChatOpenAI 8 | from langchain_core.messages import SystemMessage 9 | from langchain_core.runnables import RunnableConfig 10 | from langchain_core.messages import ToolMessage, AIMessage 11 | 12 | from copilotkit.demos.autotale_ai.state import AgentState 13 | from copilotkit.demos.autotale_ai.story.outline import set_outline 14 | from copilotkit.demos.autotale_ai.story.characters import set_characters 15 | from copilotkit.demos.autotale_ai.story.story import set_story 16 | from copilotkit.demos.autotale_ai.story.style import set_style 17 | from copilotkit.langchain import copilotkit_customize_config 18 | # pylint: disable=line-too-long 19 | 20 | async def chatbot_node(state: AgentState, config: RunnableConfig): 21 | """ 22 | The chatbot is responsible for answering the user's questions and selecting 23 | the next route. 
24 | """ 25 | 26 | 27 | config = copilotkit_customize_config( 28 | config, 29 | emit_messages=True, 30 | emit_intermediate_state= [ 31 | { 32 | "state_key": "outline", 33 | "tool": "set_outline", 34 | "tool_argument": "outline" 35 | }, 36 | { 37 | "state_key": "characters", 38 | "tool": "set_characters", 39 | "tool_argument": "characters" 40 | }, 41 | { 42 | "state_key": "story", 43 | "tool": "set_story", 44 | "tool_argument": "story" 45 | } 46 | ] 47 | ) 48 | 49 | tools = [set_outline, set_style] 50 | 51 | if state.get("outline") is not None: 52 | tools.append(set_characters) 53 | 54 | if state.get("characters") is not None: 55 | tools.append(set_story) 56 | 57 | system_message = """ 58 | You help the user write a children's story. Please assist the user by either having a conversation or by 59 | taking the appropriate actions to advance the story writing process. Do not repeat the whole story again. 60 | 61 | Your state consists of the following concepts: 62 | 63 | - Outline: The outline of the story. Should be short, 2-3 sentences. 64 | - Characters: The characters that make up the story (depends on outline) 65 | - Story: The final story result. (depends on outline & characters) 66 | 67 | If the user asks you to make changes to any of these, 68 | you MUST take into account dependencies and make the changes accordingly. 69 | 70 | Example: If after coming up with the characters, the user requires changes in the outline, you must first 71 | regenerate the outline. 72 | 73 | Dont bother the user too often, just call the tools. 74 | Especially, dont' repeat the story and so on, just call the tools. 
75 | """ 76 | if state.get("outline") is not None: 77 | system_message += f"\n\nThe current outline is: {state['outline']}" 78 | 79 | if state.get("characters") is not None: 80 | system_message += f"\n\nThe current characters are: {json.dumps(state['characters'])}" 81 | 82 | if state.get("story") is not None: 83 | system_message += f"\n\nThe current story is: {json.dumps(state['story'])}" 84 | 85 | last_message = state["messages"][-1] if state["messages"] else None 86 | 87 | if last_message and isinstance(last_message, AIMessage): 88 | system_message += """ 89 | The user did not submit the last message. This means they probably changed the state of the story by 90 | in the UI. Figure out if you need to regenerate the outline, characters or story and call the appropriate 91 | tool. If not, just respond to the user. 92 | """ 93 | 94 | 95 | response = await ChatOpenAI(model="gpt-4o").bind_tools(tools, parallel_tool_calls=False).ainvoke([ 96 | *state["messages"], 97 | SystemMessage( 98 | content=system_message 99 | ) 100 | ], config) 101 | 102 | tool_calls = getattr(response, "tool_calls", None) 103 | 104 | if not tool_calls: 105 | return { 106 | "messages": response, 107 | } 108 | 109 | return { 110 | "messages": [ 111 | response, 112 | ToolMessage( 113 | name=tool_calls[0]["name"], 114 | content=json.dumps(tool_calls[0]["args"]), 115 | tool_call_id=tool_calls[0]["id"] 116 | ) 117 | ], 118 | } 119 | -------------------------------------------------------------------------------- /copilotkit/demos/research_canvas/chat.py: -------------------------------------------------------------------------------- 1 | """Chat Node""" 2 | 3 | from typing import List, cast 4 | from langchain_core.runnables import RunnableConfig 5 | from langchain_core.messages import SystemMessage, AIMessage, ToolMessage 6 | from langchain.tools import tool 7 | from copilotkit.langchain import copilotkit_customize_config 8 | from copilotkit.demos.research_canvas.state import AgentState 9 | from 
copilotkit.demos.research_canvas.model import get_model 10 | from copilotkit.demos.research_canvas.download import get_resource 11 | 12 | @tool 13 | def Search(queries: List[str]): # pylint: disable=invalid-name,unused-argument 14 | """A list of one or more search queries to find good resources to support the research.""" 15 | 16 | @tool 17 | def WriteReport(report: str): # pylint: disable=invalid-name,unused-argument 18 | """Write the research report.""" 19 | 20 | @tool 21 | def WriteResearchQuestion(research_question: str): # pylint: disable=invalid-name,unused-argument 22 | """Write the research question.""" 23 | 24 | @tool 25 | def DeleteResources(urls: List[str]): # pylint: disable=invalid-name,unused-argument 26 | """Delete the URLs from the resources.""" 27 | 28 | 29 | async def chat_node(state: AgentState, config: RunnableConfig): 30 | """ 31 | Chat Node 32 | """ 33 | 34 | config = copilotkit_customize_config( 35 | config, 36 | emit_intermediate_state=[{ 37 | "state_key": "report", 38 | "tool": "WriteReport", 39 | "tool_argument": "report", 40 | }, { 41 | "state_key": "research_question", 42 | "tool": "WriteResearchQuestion", 43 | "tool_argument": "research_question", 44 | }], 45 | emit_tool_calls="DeleteResources" 46 | ) 47 | 48 | state["resources"] = state.get("resources", []) 49 | research_question = state.get("research_question", "") 50 | report = state.get("report", "") 51 | 52 | resources = [] 53 | 54 | for resource in state["resources"]: 55 | content = get_resource(resource["url"]) 56 | if content == "ERROR": 57 | continue 58 | resources.append({ 59 | **resource, 60 | "content": content 61 | }) 62 | 63 | model = get_model(state) 64 | # Prepare the kwargs for the ainvoke method 65 | ainvoke_kwargs = {} 66 | if model.__class__.__name__ in ["ChatOpenAI"]: 67 | ainvoke_kwargs["parallel_tool_calls"] = False 68 | 69 | response = await model.bind_tools( 70 | [ 71 | Search, 72 | WriteReport, 73 | WriteResearchQuestion, 74 | DeleteResources, 75 | ], 76 | 
**ainvoke_kwargs 77 | ).ainvoke([ 78 | SystemMessage( 79 | content=f""" 80 | You are a research assistant. You help the user with writing a research report. 81 | Do not recite the resources, instead use them to answer the user's question. 82 | You should use the search tool to get resources before answering the user's question. 83 | If you finished writing the report, ask the user proactively for next steps, changes etc, make it engaging. 84 | To write the report, you should use the WriteReport tool. Never EVER respond with the report, only use the tool. 85 | 86 | This is the research question: 87 | {research_question} 88 | 89 | This is the research report: 90 | {report} 91 | 92 | Here are the resources that you have available: 93 | {resources} 94 | """ 95 | ), 96 | *state["messages"], 97 | ], config) 98 | 99 | ai_message = cast(AIMessage, response) 100 | 101 | if ai_message.tool_calls: 102 | if ai_message.tool_calls[0]["name"] == "WriteReport": 103 | return { 104 | "report": ai_message.tool_calls[0]["args"]["report"], 105 | "messages": [ai_message, ToolMessage( 106 | tool_call_id=ai_message.tool_calls[0]["id"], 107 | content="Report written." 108 | )] 109 | } 110 | if ai_message.tool_calls[0]["name"] == "WriteResearchQuestion": 111 | return { 112 | "research_question": ai_message.tool_calls[0]["args"]["research_question"], 113 | "messages": [ai_message, ToolMessage( 114 | tool_call_id=ai_message.tool_calls[0]["id"], 115 | content="Research question written." 
116 | )] 117 | } 118 | 119 | return { 120 | "messages": response 121 | } 122 | -------------------------------------------------------------------------------- /copilotkit/demos/wait_user_input/agent.py: -------------------------------------------------------------------------------- 1 | # Set up the state 2 | from langgraph.graph import MessagesState, START 3 | 4 | # Set up the tool 5 | # We will have one real tool - a search tool 6 | # We'll also have one "fake" tool - a "ask_human" tool 7 | # Here we define any ACTUAL tools 8 | from langchain_core.tools import tool 9 | from langgraph.prebuilt import ToolNode 10 | from langchain_core.messages import AIMessage 11 | from copilotkit.langchain import copilotkit_customize_config 12 | 13 | 14 | @tool 15 | def search(query: str): 16 | """Call to surf the web.""" 17 | # This is a placeholder for the actual implementation 18 | # Don't let the LLM know this though 😊 19 | return f"I looked up: {query}. Result: It's sunny in San Francisco, but you better look out if you're a Gemini 😈." 
# --- model and tools --------------------------------------------------------
# These imports stay here, in the middle of the demo script, as in the original.
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

# The single real tool, plus the prebuilt node that executes it.
tools = [search]
tool_node = ToolNode(tools)


# A schema-only "fake" tool: the model can request it, but it is never run by
# the tool node; the graph routes it to `ask_human` instead.
class AskHuman(BaseModel):
    """Ask the human a question"""

    question: str


# `bind_tools` accepts bare schemas as well as real tools, so AskHuman can be
# offered next to `search`.
# (Alternative provider: from langchain_anthropic import ChatAnthropic, then
#  ChatAnthropic(model="claude-3-5-sonnet-20240620").)
model = ChatOpenAI(model="gpt-4o").bind_tools([*tools, AskHuman])


def should_continue(state):
    """
    Pick the branch to follow after the `agent` node.

    Returns "end" when the last message carries no tool calls, "ask_human"
    when its first tool call is AskHuman, and "continue" otherwise.
    """
    tool_calls = state["messages"][-1].tool_calls
    if not tool_calls:
        return "end"
    # This is also the place to notify some external system that human input
    # is needed (e.g. send a Slack message).
    if tool_calls[0]["name"] == "AskHuman":
        return "ask_human"
    return "continue"


def call_model(state, config):
    """Invoke the tool-bound model on the conversation and return its reply."""
    # Ask CopilotKit to emit AskHuman tool calls.
    emitting_config = copilotkit_customize_config(
        config,
        emit_tool_calls="AskHuman",
    )
    reply = model.invoke(state["messages"], config=emitting_config)
    # Wrapped in a list so it gets added to the existing message list.
    return {"messages": [reply]}


def ask_human(state):  # pylint: disable=unused-argument
    """Placeholder node for the human turn: it does nothing and returns None."""
    return None


# Build the graph

from langgraph.graph import END, StateGraph

# Define a new graph
workflow = StateGraph(MessagesState)

# Define the three nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
workflow.add_node("ask_human", ask_human)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.add_edge(START, "agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally we pass in a mapping.
    # The keys are strings, and the values are other nodes.
    # END is a special node marking that the graph should finish.
    # What will happen is we will call `should_continue`, and then the output of that
    # will be matched against the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If `continue`, then we call the `action` node, which runs the tools.
        "continue": "action",
        # We may ask the human
        "ask_human": "ask_human",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` (the tool node) to `agent`.
# This means that after the tools have run, the `agent` node is called next.
workflow.add_edge("action", "agent")

# After we get back the human response, we go back to the agent
workflow.add_edge("ask_human", "agent")

# Set up memory
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()

# Finally, we compile it!
137 | # This compiles it into a LangChain Runnable, 138 | # meaning you can use it as you would any other runnable 139 | # We add a breakpoint BEFORE the `ask_human` node so it never executes 140 | graph = workflow.compile(checkpointer=memory, interrupt_after=["ask_human"]) 141 | -------------------------------------------------------------------------------- /copilotkit/sdk.py: -------------------------------------------------------------------------------- 1 | """CopilotKit SDK""" 2 | 3 | from pprint import pformat 4 | from typing import List, Callable, Union, Optional, TypedDict, Any, Coroutine 5 | from .agent import Agent, AgentDict 6 | from .action import Action, ActionDict, ActionResultDict 7 | from .types import Message 8 | from .exc import ( 9 | ActionNotFoundException, 10 | AgentNotFoundException, 11 | ActionExecutionException, 12 | AgentExecutionException 13 | ) 14 | from .logging import get_logger, bold 15 | 16 | 17 | COPILOTKIT_SDK_VERSION = "0.1.22" 18 | 19 | logger = get_logger(__name__) 20 | 21 | class InfoDict(TypedDict): 22 | """Info dictionary""" 23 | sdkVersion: str 24 | actions: List[ActionDict] 25 | agents: List[AgentDict] 26 | 27 | class CopilotKitSDKContext(TypedDict): 28 | """CopilotKit SDK Context""" 29 | properties: Any 30 | frontend_url: Optional[str] 31 | 32 | class CopilotKitSDK: 33 | """CopilotKit SDK""" 34 | 35 | def __init__( 36 | self, 37 | *, 38 | actions: Optional[ 39 | Union[ 40 | List[Action], 41 | Callable[[CopilotKitSDKContext], List[Action]] 42 | ] 43 | ] = None, 44 | agents: Optional[ 45 | Union[ 46 | List[Agent], 47 | Callable[[CopilotKitSDKContext], List[Agent]] 48 | ] 49 | ] = None, 50 | ): 51 | self.agents = agents or [] 52 | self.actions = actions or [] 53 | 54 | def info( 55 | self, 56 | *, 57 | context: CopilotKitSDKContext 58 | ) -> InfoDict: 59 | """Returns information about available actions and agents""" 60 | 61 | actions = self.actions(context) if callable(self.actions) else self.actions 62 | agents = 
self.agents(context) if callable(self.agents) else self.agents 63 | 64 | actions_list = [action.dict_repr() for action in actions] 65 | agents_list = [agent.dict_repr() for agent in agents] 66 | 67 | logger.debug(bold("Handling info request:")) 68 | logger.debug("--------------------------") 69 | logger.debug(bold("Context:")) 70 | logger.debug(pformat(context)) 71 | logger.debug(bold("Actions:")) 72 | logger.debug(pformat(actions_list)) 73 | logger.debug(bold("Agents:")) 74 | logger.debug(pformat(agents_list)) 75 | logger.debug("--------------------------") 76 | 77 | return { 78 | "actions": actions_list, 79 | "agents": agents_list, 80 | "sdkVersion": COPILOTKIT_SDK_VERSION 81 | } 82 | 83 | def _get_action( 84 | self, 85 | *, 86 | context: CopilotKitSDKContext, 87 | name: str, 88 | ) -> Action: 89 | """Get an action by name""" 90 | actions = self.actions(context) if callable(self.actions) else self.actions 91 | action = next((action for action in actions if action.name == name), None) 92 | if action is None: 93 | raise ActionNotFoundException(name) 94 | return action 95 | 96 | def execute_action( 97 | self, 98 | *, 99 | context: CopilotKitSDKContext, 100 | name: str, 101 | arguments: dict, 102 | ) -> Coroutine[Any, Any, ActionResultDict]: 103 | """Execute an action""" 104 | 105 | action = self._get_action(context=context, name=name) 106 | 107 | logger.info(bold("Handling execute action request:")) 108 | logger.info("--------------------------") 109 | logger.info(bold("Context:")) 110 | logger.info(pformat(context)) 111 | logger.info(bold("Action:")) 112 | logger.info(pformat(action.dict_repr())) 113 | logger.info(bold("Arguments:")) 114 | logger.info(pformat(arguments)) 115 | logger.info("--------------------------") 116 | 117 | try: 118 | result = action.execute(arguments=arguments) 119 | return result 120 | except Exception as error: 121 | raise ActionExecutionException(name, error) from error 122 | 123 | def execute_agent( # pylint: disable=too-many-arguments 
124 | self, 125 | *, 126 | context: CopilotKitSDKContext, 127 | name: str, 128 | thread_id: str, 129 | node_name: str, 130 | state: dict, 131 | messages: List[Message], 132 | actions: List[ActionDict], 133 | ) -> Any: 134 | """Execute an agent""" 135 | agents = self.agents(context) if callable(self.agents) else self.agents 136 | agent = next((agent for agent in agents if agent.name == name), None) 137 | if agent is None: 138 | raise AgentNotFoundException(name) 139 | 140 | logger.info(bold("Handling execute agent request:")) 141 | logger.info("--------------------------") 142 | logger.info(bold("Context:")) 143 | logger.info(pformat(context)) 144 | logger.info(bold("Agent:")) 145 | logger.info(pformat(agent.dict_repr())) 146 | logger.info(bold("Thread ID:")) 147 | logger.info(thread_id) 148 | logger.info(bold("Node Name:")) 149 | logger.info(node_name) 150 | logger.info(bold("State:")) 151 | logger.info(pformat(state)) 152 | logger.info(bold("Messages:")) 153 | logger.info(pformat(messages)) 154 | logger.info(bold("Actions:")) 155 | logger.info(pformat(actions)) 156 | logger.info("--------------------------") 157 | 158 | try: 159 | return agent.execute( 160 | thread_id=thread_id, 161 | node_name=node_name, 162 | state=state, 163 | messages=messages, 164 | actions=actions, 165 | ) 166 | except Exception as error: 167 | raise AgentExecutionException(name, error) from error 168 | -------------------------------------------------------------------------------- /copilotkit/integrations/fastapi.py: -------------------------------------------------------------------------------- 1 | """FastAPI integration""" 2 | 3 | import logging 4 | 5 | from typing import List, Any, cast 6 | from fastapi import FastAPI, Request, HTTPException 7 | from fastapi.responses import JSONResponse, StreamingResponse 8 | from ..sdk import CopilotKitSDK, CopilotKitSDKContext 9 | from ..types import Message 10 | from ..exc import ( 11 | ActionNotFoundException, 12 | ActionExecutionException, 13 | 
AgentNotFoundException, 14 | AgentExecutionException, 15 | ) 16 | from ..action import ActionDict 17 | 18 | logging.basicConfig(level=logging.ERROR) 19 | logger = logging.getLogger(__name__) 20 | 21 | def add_fastapi_endpoint(fastapi_app: FastAPI, sdk: CopilotKitSDK, prefix: str): 22 | """Add FastAPI endpoint""" 23 | async def make_handler(request: Request): 24 | return await handler(request, sdk) 25 | 26 | # Ensure the prefix starts with a slash and remove trailing slashes 27 | normalized_prefix = '/' + prefix.strip('/') 28 | 29 | fastapi_app.add_api_route( 30 | f"{normalized_prefix}/{{path:path}}", 31 | make_handler, 32 | methods=['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS'], 33 | ) 34 | 35 | def body_get_or_raise(body: Any, key: str): 36 | """Get value from body or raise an error""" 37 | value = body.get(key) 38 | if value is None: 39 | raise HTTPException(status_code=400, detail=f"{key} is required") 40 | return value 41 | 42 | 43 | async def handler(request: Request, sdk: CopilotKitSDK): 44 | """Handle FastAPI request""" 45 | 46 | try: 47 | body = await request.json() 48 | except Exception as exc: 49 | raise HTTPException(status_code=400, detail="Request body is required") from exc 50 | 51 | path = request.path_params.get('path') 52 | method = request.method 53 | context = cast( 54 | CopilotKitSDKContext, 55 | { 56 | "properties": body.get("properties", {}), 57 | "frontend_url": body.get("frontendUrl", None) 58 | } 59 | ) 60 | 61 | if method == 'POST' and path == 'info': 62 | return await handle_info(sdk=sdk, context=context) 63 | 64 | if method == 'POST' and path == 'actions/execute': 65 | name = body_get_or_raise(body, "name") 66 | arguments = body.get("arguments", {}) 67 | 68 | return await handle_execute_action( 69 | sdk=sdk, 70 | context=context, 71 | name=name, 72 | arguments=arguments, 73 | ) 74 | 75 | if method == 'POST' and path == 'agents/execute': 76 | thread_id = body.get("threadId") 77 | node_name = body.get("nodeName") 78 | 79 | name = 
body_get_or_raise(body, "name") 80 | state = body_get_or_raise(body, "state") 81 | messages = body_get_or_raise(body, "messages") 82 | actions = cast(List[ActionDict], body.get("actions", [])) 83 | 84 | return handle_execute_agent( 85 | sdk=sdk, 86 | context=context, 87 | thread_id=thread_id, 88 | node_name=node_name, 89 | name=name, 90 | state=state, 91 | messages=messages, 92 | actions=actions, 93 | ) 94 | 95 | 96 | raise HTTPException(status_code=404, detail="Not found") 97 | 98 | 99 | async def handle_info(*, sdk: CopilotKitSDK, context: CopilotKitSDKContext): 100 | """Handle info request with FastAPI""" 101 | result = sdk.info(context=context) 102 | return JSONResponse(content=result) 103 | 104 | async def handle_execute_action( 105 | *, 106 | sdk: CopilotKitSDK, 107 | context: CopilotKitSDKContext, 108 | name: str, 109 | arguments: dict, 110 | ): 111 | """Handle execute action request with FastAPI""" 112 | try: 113 | result = await sdk.execute_action( 114 | context=context, 115 | name=name, 116 | arguments=arguments 117 | ) 118 | return JSONResponse(content=result) 119 | except ActionNotFoundException as exc: 120 | logger.error("Action not found: %s", exc) 121 | return JSONResponse(content={"error": str(exc)}, status_code=404) 122 | except ActionExecutionException as exc: 123 | logger.error("Action execution error: %s", exc) 124 | return JSONResponse(content={"error": str(exc)}, status_code=500) 125 | except Exception as exc: # pylint: disable=broad-except 126 | logger.error("Action execution error: %s", exc) 127 | return JSONResponse(content={"error": str(exc)}, status_code=500) 128 | 129 | def handle_execute_agent( # pylint: disable=too-many-arguments 130 | *, 131 | sdk: CopilotKitSDK, 132 | context: CopilotKitSDKContext, 133 | thread_id: str, 134 | node_name: str, 135 | name: str, 136 | state: dict, 137 | messages: List[Message], 138 | actions: List[ActionDict], 139 | ): 140 | """Handle continue agent execution request with FastAPI""" 141 | try: 142 | 
events = sdk.execute_agent( 143 | context=context, 144 | thread_id=thread_id, 145 | name=name, 146 | node_name=node_name, 147 | state=state, 148 | messages=messages, 149 | actions=actions, 150 | ) 151 | return StreamingResponse(events, media_type="application/json") 152 | except AgentNotFoundException as exc: 153 | logger.error("Agent not found: %s", exc, exc_info=True) 154 | return JSONResponse(content={"error": str(exc)}, status_code=404) 155 | except AgentExecutionException as exc: 156 | logger.error("Agent execution error: %s", exc, exc_info=True) 157 | return JSONResponse(content={"error": str(exc)}, status_code=500) 158 | except Exception as exc: # pylint: disable=broad-except 159 | logger.error("Agent execution error: %s", exc, exc_info=True) 160 | return JSONResponse(content={"error": str(exc)}, status_code=500) 161 | -------------------------------------------------------------------------------- /copilotkit/demos/qa/agent.py: -------------------------------------------------------------------------------- 1 | """Test Human in the Loop Agent""" 2 | 3 | from typing import Any, cast 4 | from langgraph.graph import StateGraph, END 5 | from langgraph.graph import MessagesState 6 | from langgraph.checkpoint.memory import MemorySaver 7 | from langchain_core.runnables import RunnableConfig 8 | from langchain_core.messages import HumanMessage, ToolMessage, AIMessage 9 | # from langchain_google_genai import ChatGoogleGenerativeAI 10 | from langchain_openai import ChatOpenAI 11 | 12 | from pydantic import BaseModel, Field 13 | 14 | from copilotkit.langchain import ( 15 | copilotkit_customize_config, copilotkit_exit, copilotkit_emit_message 16 | ) 17 | 18 | 19 | def get_model(): 20 | """ 21 | Get a model based on the environment variable. 
22 | """ 23 | # model = os.getenv("MODEL", "openai") 24 | return ChatOpenAI(temperature=0, model="gpt-4o") 25 | # return ChatGoogleGenerativeAI(temperature=0, model="gemini-1.5-pro") 26 | 27 | 28 | # if model == "openai": 29 | # return ChatOpenAI(temperature=0, model="gpt-4o") 30 | # if model == "anthropic": 31 | # return ChatAnthropic( 32 | # temperature=0, 33 | # model_name="claude-3-5-sonnet-20240620", 34 | # timeout=None, 35 | # stop=None 36 | # ) 37 | 38 | # raise ValueError("Invalid model specified") 39 | 40 | 41 | class EmailAgentState(MessagesState): 42 | """Email Agent State""" 43 | email: str 44 | 45 | class EmailTool(BaseModel): 46 | """ 47 | Write an email. 48 | """ 49 | email_draft: str = Field(description="The draft of the email to be written.") 50 | 51 | 52 | async def draft_email_node(state: EmailAgentState, config: RunnableConfig): 53 | """ 54 | Write an email. 55 | """ 56 | 57 | config = copilotkit_customize_config( 58 | config, 59 | emit_tool_calls=True, 60 | ) 61 | 62 | instructions = "You write emails." 63 | 64 | email_model = get_model().bind_tools( 65 | [cast(Any, EmailTool)], 66 | tool_choice="EmailTool" 67 | ) 68 | 69 | messages = state["messages"] 70 | # if len(messages) > 2: 71 | # messages = messages[:-4] 72 | 73 | # print("MESSAGES:") 74 | # for message in messages: 75 | # print(type(message)) 76 | # print(message) 77 | # print("----") 78 | 79 | 80 | response = await email_model.ainvoke([ 81 | *messages, 82 | HumanMessage( 83 | content=instructions 84 | ) 85 | ], config) 86 | 87 | # content='' additional_kwargs={'function_call': {'name': 'EmailTool', 'arguments': '{"email_draft": "Dear Sam Altman,\\\\n\\\\nI hope this email finds you well.\\\\n\\\\nI am writing to request a meeting with you to discuss [topic of discussion]. I am [your title/position] at [your company/organization] and I am particularly interested in [area of interest related to OpenAI].\\\\n\\\\nI am available on [list of dates/times]. 
Please let me know if any of these times work for you or suggest an alternative.\\\\n\\\\nThank you for your time and consideration.\\\\n\\\\nSincerely,\\\\n[Your Name]"}'}} response_metadata={'safety_ratings': [], 'finish_reason': 'STOP'} id='run-69254734-7c90-4743-adab-5e1d8cfe3099' tool_calls=[{'name': 'EmailTool', 'args': {'email_draft': 'Dear Sam Altman,\\n\\nI hope this email finds you well.\\n\\nI am writing to request a meeting with you to discuss [topic of discussion]. I am [your title/position] at [your company/organization] and I am particularly interested in [area of interest related to OpenAI].\\n\\nI am available on [list of dates/times]. Please let me know if any of these times work for you or suggest an alternative.\\n\\nThank you for your time and consideration.\\n\\nSincerely,\\n[Your Name]'}, 'id': '87607866-4201-4d52-bf8e-0cfd182e55bb', 'type': 'tool_call'}] usage_metadata={'input_tokens': 66, 'output_tokens': 122, 'total_tokens': 188, 'input_token_details': {'cache_read': 0}} 88 | # content='' additional_kwargs={} response_metadata={} id='run-1152f360-5a0b-4cb8-8065-0bb14c40f01f' tool_calls=[{'name': 'EmailTool', 'args': {'email_draft': 'Dear Sam Altman,\\n\\nI hope this email finds you well.\\n\\nI am writing to request a meeting with you to discuss [topic of discussion]. I am [your title/position] at [your company/organization] and I believe that a meeting between us would be mutually beneficial.\\n\\nI am available to meet at your earliest convenience. 
Please let me know what time works best for you.\\n\\nThank you for your time and consideration.\\n\\nSincerely,\\n[Your Name]'}, 'id': 'run-1152f360-5a0b-4cb8-8065-0bb14c40f01f', 'type': 'tool_call'}] 89 | 90 | print(response) 91 | 92 | tool_calls = cast(Any, response).tool_calls 93 | 94 | # the email content is the argument passed to the email tool 95 | email = tool_calls[0]["args"]["email_draft"] 96 | 97 | return { 98 | "email": email, 99 | } 100 | 101 | async def send_email_node(state: EmailAgentState, config: RunnableConfig): 102 | """ 103 | Send an email. 104 | """ 105 | 106 | config = copilotkit_customize_config( 107 | config, 108 | emit_messages=True, 109 | ) 110 | 111 | await copilotkit_exit(config) 112 | 113 | # get the last message and cast to ToolMessage 114 | last_message = cast(ToolMessage, state["messages"][-1]) 115 | message_to_add = "" 116 | if last_message.content == "CANCEL": 117 | message_to_add = "❌ Cancelled sending email." 118 | else: 119 | message_to_add = "✅ Sent email." 
120 | 121 | await copilotkit_emit_message(config, message_to_add) 122 | return { 123 | "messages": state["messages"] + [AIMessage(content=message_to_add)], 124 | } 125 | 126 | 127 | workflow = StateGraph(EmailAgentState) 128 | workflow.add_node("draft_email_node", draft_email_node) 129 | workflow.add_node("send_email_node", send_email_node) 130 | workflow.set_entry_point("draft_email_node") 131 | 132 | workflow.add_edge("draft_email_node", "send_email_node") 133 | workflow.add_edge("send_email_node", END) 134 | memory = MemorySaver() 135 | graph = workflow.compile(checkpointer=memory, interrupt_after=["draft_email_node"]) 136 | -------------------------------------------------------------------------------- /copilotkit/langchain.py: -------------------------------------------------------------------------------- 1 | """ 2 | LangChain specific utilities for CopilotKit 3 | """ 4 | 5 | import uuid 6 | import json 7 | from typing import List, Optional, Any, Union, Dict, Callable 8 | 9 | from langchain_core.messages import ( 10 | HumanMessage, 11 | SystemMessage, 12 | BaseMessage, 13 | AIMessage, 14 | ToolMessage 15 | ) 16 | from langchain_core.runnables import RunnableConfig, RunnableGenerator 17 | 18 | from .types import Message, IntermediateStateConfig 19 | 20 | def copilotkit_messages_to_langchain( 21 | use_function_call: bool = False 22 | ) -> Callable[[List[Message]], List[BaseMessage]]: 23 | """ 24 | Convert CopilotKit messages to LangChain messages 25 | """ 26 | def _copilotkit_messages_to_langchain(messages: List[Message]) -> List[BaseMessage]: 27 | result = [] 28 | for message in messages: 29 | if "content" in message: 30 | if message["role"] == "user": 31 | result.append(HumanMessage(content=message["content"], id=message["id"])) 32 | elif message["role"] == "system": 33 | result.append(SystemMessage(content=message["content"], id=message["id"])) 34 | elif message["role"] == "assistant": 35 | result.append(AIMessage(content=message["content"], 
id=message["id"])) 36 | elif "arguments" in message: 37 | tool_call = { 38 | "name": message["name"], 39 | "args": message["arguments"], 40 | "id": message["id"], 41 | } 42 | additional_kwargs = { 43 | 'function_call':{ 44 | 'name': message["name"], 45 | 'arguments': json.dumps(message["arguments"]), 46 | } 47 | } 48 | if not use_function_call: 49 | ai_message = AIMessage( 50 | id=message["id"], 51 | content="", 52 | tool_calls=[tool_call] 53 | ) 54 | else: 55 | ai_message = AIMessage( 56 | id=message["id"], 57 | content="", 58 | additional_kwargs=additional_kwargs 59 | ) 60 | result.append(ai_message) 61 | 62 | elif "actionExecutionId" in message: 63 | result.append(ToolMessage( 64 | id=message["id"], 65 | content=message["result"], 66 | name=message["actionName"], 67 | tool_call_id=message["actionExecutionId"] 68 | )) 69 | return result 70 | 71 | return _copilotkit_messages_to_langchain 72 | 73 | def copilotkit_customize_config( 74 | base_config: Optional[RunnableConfig] = None, 75 | *, 76 | emit_tool_calls: Optional[Union[bool, str, List[str]]] = None, 77 | emit_messages: Optional[bool] = None, 78 | emit_all: Optional[bool] = None, 79 | emit_intermediate_state: Optional[List[IntermediateStateConfig]] = None 80 | ) -> RunnableConfig: 81 | """ 82 | Configure for LangChain for use in CopilotKit 83 | """ 84 | metadata = base_config.get("metadata", {}) if base_config else {} 85 | 86 | if emit_all is True: 87 | metadata["copilotkit:emit-tool-calls"] = True 88 | metadata["copilotkit:emit-messages"] = True 89 | else: 90 | if emit_tool_calls is not None: 91 | metadata["copilotkit:emit-tool-calls"] = emit_tool_calls 92 | if emit_messages is not None: 93 | metadata["copilotkit:emit-messages"] = emit_messages 94 | 95 | if emit_intermediate_state: 96 | metadata["copilotkit:emit-intermediate-state"] = emit_intermediate_state 97 | 98 | base_config = base_config or {} 99 | 100 | return { 101 | **base_config, 102 | "metadata": metadata 103 | } 104 | 105 | async def 
_exit_copilotkit_generator(state): # pylint: disable=unused-argument 106 | yield "Exit" 107 | 108 | 109 | async def copilotkit_exit(config: RunnableConfig): 110 | """ 111 | Exit CopilotKit 112 | """ 113 | # For some reason, we need to use this workaround to get custom events to work 114 | # dispatch_custom_event and friends don't seem to do anything 115 | gen = RunnableGenerator(_exit_copilotkit_generator).with_config( 116 | metadata={ 117 | "copilotkit:exit": True 118 | }, 119 | callbacks=config.get( 120 | "callbacks", [] 121 | ), 122 | ) 123 | async for _message in gen.astream({}): 124 | pass 125 | 126 | return True 127 | 128 | def _emit_copilotkit_state_generator(state): 129 | async def emit_state(_state: Any): # pylint: disable=unused-argument 130 | yield state 131 | return emit_state 132 | 133 | 134 | async def copilotkit_emit_state(config: RunnableConfig, state: Any): 135 | """ 136 | Emit CopilotKit state 137 | """ 138 | gen = RunnableGenerator(_emit_copilotkit_state_generator(state)).with_config( 139 | metadata={ 140 | "copilotkit:force-emit-intermediate-state": True 141 | }, 142 | callbacks=config.get( 143 | "callbacks", [] 144 | ), 145 | ) 146 | async for _message in gen.astream({}): 147 | pass 148 | 149 | return True 150 | 151 | def _emit_copilotkit_message_generator(message: str): 152 | async def emit_message(_message: Any): # pylint: disable=unused-argument 153 | yield message 154 | return emit_message 155 | 156 | async def copilotkit_emit_message(config: RunnableConfig, message: str): 157 | """ 158 | Emit CopilotKit message 159 | """ 160 | gen = RunnableGenerator(_emit_copilotkit_message_generator(message)).with_config( 161 | metadata={ 162 | "copilotkit:manually-emit-message": True 163 | }, 164 | callbacks=config.get( 165 | "callbacks", [] 166 | ), 167 | ) 168 | async for _message in gen.astream({}): 169 | pass 170 | 171 | return True 172 | 173 | def _emit_copilotkit_tool_call_generator(name: str, args: Dict[str, Any]): 174 | async def 
emit_tool_call(_tool_call: Any): # pylint: disable=unused-argument 175 | yield { 176 | "name": name, 177 | "args": args, 178 | "id": str(uuid.uuid4()) 179 | } 180 | return emit_tool_call 181 | 182 | async def copilotkit_emit_tool_call(config: RunnableConfig, *, name: str, args: Dict[str, Any]): 183 | """ 184 | Emit CopilotKit tool call 185 | """ 186 | gen = RunnableGenerator(_emit_copilotkit_tool_call_generator(name, args)).with_config( 187 | metadata={ 188 | "copilotkit:manually-emit-tool-call": True 189 | }, 190 | callbacks=config.get( 191 | "callbacks", [] 192 | ), 193 | ) 194 | async for _message in gen.astream({}): 195 | pass 196 | 197 | return True 198 | -------------------------------------------------------------------------------- /copilotkit/langgraph_cloud_agent.py: -------------------------------------------------------------------------------- 1 | # """LangGraph agent for CopilotKit""" 2 | 3 | # from typing import Optional, List, Callable, Any, cast 4 | # import uuid 5 | # from langchain.load.dump import dumps as langchain_dumps 6 | # from langchain.schema import BaseMessage, SystemMessage 7 | # from langgraph_sdk import get_client 8 | 9 | # from partialjson.json_parser import JSONParser 10 | 11 | # from .types import Message 12 | # from .langchain import copilotkit_messages_to_langchain 13 | # from .action import ActionDict 14 | # from .agent import Agent 15 | 16 | # def langgraph_default_merge_state( # pylint: disable=unused-argument 17 | # *, 18 | # state: dict, 19 | # messages: List[BaseMessage], 20 | # actions: List[Any] 21 | # ): 22 | # """Default merge state for LangGraph""" 23 | # if len(messages) > 0 and isinstance(messages[0], SystemMessage): 24 | # # remove system message 25 | # messages = messages[1:] 26 | 27 | # # merge with existing messages 28 | # merged_messages = state.get("messages", []) 29 | # existing_message_ids = {message.id for message in merged_messages} 30 | 31 | # for message in messages: 32 | # if message.id not in 
existing_message_ids: 33 | # merged_messages.append(message) 34 | 35 | # return { 36 | # **state, 37 | # "messages": merged_messages, 38 | # "copilotkit": { 39 | # "actions": actions 40 | # } 41 | # } 42 | 43 | # class LangGraphCloudAgent(Agent): 44 | # """LangGraph agent class for CopilotKit""" 45 | # def __init__( 46 | # self, 47 | # *, 48 | # name: str, 49 | # description: Optional[str] = None, 50 | # assistant_id: Optional[str] = None, 51 | # merge_state: Optional[Callable] = langgraph_default_merge_state 52 | # ): 53 | # super().__init__( 54 | # name=name, 55 | # description=description, 56 | # merge_state=merge_state 57 | # ) 58 | # self.assistant_id = assistant_id or name 59 | 60 | # def _emit_state_sync_event( 61 | # self, 62 | # *, 63 | # thread_id: str, 64 | # run_id: str, 65 | # node_name: str, 66 | # state: dict, 67 | # running: bool, 68 | # active: bool 69 | # ): 70 | # state_without_messages = { 71 | # k: v for k, v in state.items() if k != "messages" 72 | # } 73 | # return langchain_dumps({ 74 | # "event": "on_copilotkit_state_sync", 75 | # "thread_id": thread_id, 76 | # "run_id": run_id, 77 | # "agent_name": self.name, 78 | # "node_name": node_name, 79 | # "active": active, 80 | # "state": state_without_messages, 81 | # "running": running, 82 | # "role": "assistant" 83 | # }) 84 | 85 | # def execute( # pylint: disable=too-many-arguments 86 | # self, 87 | # *, 88 | # state: dict, 89 | # messages: List[Message], 90 | # thread_id: Optional[str] = None, 91 | # node_name: Optional[str] = None, 92 | # actions: Optional[List[ActionDict]] = None, 93 | # ): 94 | # return self._stream_events( 95 | # messages=messages, 96 | # state=state, 97 | # thread_id=thread_id, 98 | # node_name=node_name, 99 | # actions=actions 100 | # ) 101 | 102 | # async def _stream_events( 103 | # self, 104 | # *, 105 | # state: dict, 106 | # messages: List[Message], 107 | # thread_id: Optional[str] = None, 108 | # node_name: Optional[str] = None, 109 | # actions: 
Optional[List[ActionDict]] = None, 110 | # ): 111 | 112 | # client = get_client() 113 | # agent_state = {} 114 | # if thread_id: 115 | # agent_state = await client.threads.get_state( 116 | # thread_id=thread_id, 117 | # ) 118 | 119 | # state["messages"] = agent_state.get("values", {}).get("messages", []) 120 | 121 | # langchain_messages = copilotkit_messages_to_langchain(messages) 122 | # state = cast(Callable, self.merge_state)( 123 | # state=state, 124 | # messages=langchain_messages, 125 | # actions=actions 126 | # ) 127 | 128 | # mode = "continue" if thread_id and node_name != "__end__" else "start" 129 | # thread_id = thread_id or str(uuid.uuid4()) 130 | 131 | # if mode == "continue": 132 | # await client.threads.update_state( 133 | # thread_id=thread_id, 134 | # values=state, 135 | # as_node=node_name 136 | # ) 137 | 138 | # streaming_state_extractor = _StreamingStateExtractor([]) 139 | # initial_state = state if mode == "start" else None 140 | # prev_node_name = None 141 | # emit_intermediate_state_until_end = None 142 | # should_exit = False 143 | 144 | # graph_info = await client.assistants.get_graph( 145 | # assistant_id=self.assistant_id 146 | # ) 147 | 148 | # async for event in client.runs.stream( 149 | # thread_id, 150 | # self.assistant_id, 151 | # input=initial_state, 152 | # stream_mode="values" 153 | # ): 154 | # current_node_name = event.get("name") 155 | # event_type = event.get("event") 156 | # run_id = event.get("run_id") 157 | # tags = event.get("tags", []) 158 | # metadata = event.get("metadata", {}) 159 | 160 | # should_exit = should_exit or "copilotkit:exit" in tags 161 | 162 | # emit_intermediate_state = metadata.get("copilotkit:emit-intermediate-state") 163 | # force_emit_intermediate_state = "copilotkit:force-emit-intermediate-state" in tags 164 | 165 | # # we only want to update the node name under certain conditions 166 | # # since we don't need any internal node names to be sent to the frontend 167 | # if current_node_name in 
{node["id"]: node for node in graph_info["nodes"]}: 168 | # node_name = current_node_name 169 | 170 | # # we don't have a node name yet, so we can't update the state 171 | # if node_name is None: 172 | # continue 173 | 174 | # exiting_node = node_name == current_node_name and event_type == "on_chain_end" 175 | 176 | # if force_emit_intermediate_state: 177 | # if event_type == "on_chain_end": 178 | # state = cast(Any, event["data"])["output"] 179 | # yield self._emit_state_sync_event( 180 | # thread_id=thread_id, 181 | # run_id=run_id, 182 | # node_name=node_name, 183 | # state=state, 184 | # running=True, 185 | # active=True 186 | # ) + "\n" 187 | # continue 188 | 189 | # if emit_intermediate_state and emit_intermediate_state_until_end is None: 190 | # emit_intermediate_state_until_end = node_name 191 | 192 | # if emit_intermediate_state and event_type == "on_chat_model_start": 193 | # # reset the streaming state extractor 194 | # streaming_state_extractor = _StreamingStateExtractor(emit_intermediate_state) 195 | 196 | # updated_state = await client.threads.get_state(thread_id=thread_id)["values"] 197 | 198 | # if emit_intermediate_state and event_type == "on_chat_model_stream": 199 | # streaming_state_extractor.buffer_tool_calls(event) 200 | 201 | # if emit_intermediate_state_until_end is not None: 202 | # updated_state = { 203 | # **updated_state, 204 | # **streaming_state_extractor.extract_state() 205 | # } 206 | 207 | # if (not emit_intermediate_state and 208 | # current_node_name == emit_intermediate_state_until_end and 209 | # event_type == "on_chain_end"): 210 | # # stop emitting function call state 211 | # emit_intermediate_state_until_end = None 212 | 213 | # # we send state sync events when: 214 | # # a) the state has changed 215 | # # b) the node has changed 216 | # # c) the node is ending 217 | # if updated_state != state or prev_node_name != node_name or exiting_node: 218 | # state = updated_state 219 | # prev_node_name = node_name 220 | # yield 
self._emit_state_sync_event( 221 | # thread_id=thread_id, 222 | # run_id=run_id, 223 | # node_name=node_name, 224 | # state=state, 225 | # running=True, 226 | # active=not exiting_node 227 | # ) + "\n" 228 | 229 | # yield langchain_dumps(event) + "\n" 230 | 231 | # state = await client.threads.get_state(thread_id=thread_id) 232 | # is_end_node = state["next"] == () 233 | 234 | # node_name = list(state["metadata"]["writes"].keys())[0] 235 | 236 | # yield self._emit_state_sync_event( 237 | # thread_id=thread_id, 238 | # run_id=run_id, 239 | # node_name=cast(str, node_name) if not is_end_node else "__end__", 240 | # state=state["values"], 241 | # running=not should_exit, 242 | # # at this point, the node is ending so we set active to false 243 | # active=False 244 | # ) + "\n" 245 | 246 | 247 | 248 | # def dict_repr(self): 249 | # super_repr = super().dict_repr() 250 | # return { 251 | # **super_repr, 252 | # 'type': 'langgraph' 253 | # } 254 | 255 | # class _StreamingStateExtractor: 256 | # def __init__(self, emit_intermediate_state: List[dict]): 257 | # self.emit_intermediate_state = emit_intermediate_state 258 | # self.tool_call_buffer = {} 259 | # self.current_tool_call = None 260 | 261 | # self.previously_parsable_state = {} 262 | 263 | # def buffer_tool_calls(self, event: Any): 264 | # """Buffer the tool calls""" 265 | # if len(event["data"]["chunk"].tool_call_chunks) > 0: 266 | # chunk = event["data"]["chunk"].tool_call_chunks[0] 267 | # if chunk["name"] is not None: 268 | # self.current_tool_call = chunk["name"] 269 | # self.tool_call_buffer[self.current_tool_call] = chunk["args"] 270 | # elif self.current_tool_call is not None: 271 | # self.tool_call_buffer[self.current_tool_call] = ( 272 | # self.tool_call_buffer[self.current_tool_call] + chunk["args"] 273 | # ) 274 | 275 | # def get_emit_state_config(self, current_tool_name): 276 | # """Get the emit state config""" 277 | 278 | # for config in self.emit_intermediate_state: 279 | # state_key = 
# config.get("state_key")
#             tool = config.get("tool")
#             tool_argument = config.get("tool_argument")

#             if current_tool_name == tool:
#                 return (tool_argument, state_key)

#         return (None, None)


#     def extract_state(self):
#         """Extract the streaming state"""
#         parser = JSONParser()

#         state = {}

#         for key, value in self.tool_call_buffer.items():
#             argument_name, state_key = self.get_emit_state_config(key)

#             if state_key is None:
#                 continue

#             try:
#                 parsed_value = parser.parse(value)
#             except Exception as _exc: # pylint: disable=broad-except
#                 if key in self.previously_parsable_state:
#                     parsed_value = self.previously_parsable_state[key]
#                 else:
#                     continue

#             self.previously_parsable_state[key] = parsed_value

#             if argument_name is None:
#                 state[state_key] = parsed_value
#             else:
#                 state[state_key] = parsed_value.get(argument_name)

#         return state

# --------------------------------------------------------------------------------
# /copilotkit/langgraph_agent.py:
# --------------------------------------------------------------------------------
"""LangGraph agent for CopilotKit"""

import uuid
import json
from typing import Optional, List, Callable, Any, cast, Union, TypedDict
from typing_extensions import NotRequired

from langgraph.graph.graph import CompiledGraph
from langchain.load.dump import dumps as langchain_dumps
from langchain.schema import BaseMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.messages import AIMessage, ToolMessage

from partialjson.json_parser import JSONParser

from .types import Message
from .langchain import copilotkit_messages_to_langchain
from .action import ActionDict
from .agent import Agent
from .logging import get_logger

logger = get_logger(__name__)

class CopilotKitConfig(TypedDict):
    """CopilotKit config"""
    # called as merge_state(state=..., messages=..., actions=..., agent_name=...)
    merge_state: NotRequired[Callable]
    # converts CopilotKit messages into LangChain messages
    convert_messages: NotRequired[Callable]

def langgraph_default_merge_state( # pylint: disable=unused-argument
    *,
    state: dict,
    messages: List[BaseMessage],
    actions: List[Any],
    agent_name: str
):
    """
    Default merge state for LangGraph.

    Merges the frontend `messages` into state["messages"] and stores `actions`
    under state["copilotkit"]["actions"]:

    - a leading SystemMessage from the frontend is dropped;
    - tool calls that activate this agent (named `agent_name`) and their
      results are dropped;
    - messages with a new id are appended, except ToolMessages answering a
      tool call that already has a result (logged and skipped);
    - messages whose id already exists replace the stored one. For
      AIMessages, the stored tool_calls/additional_kwargs are kept when the
      stored message had them together with content;
    - a ToolMessage directly following an AIMessage with tool calls gets its
      tool_call_id aligned to that AIMessage's first tool call id.

    Returns a new dict. The input state dict and its messages list are not
    mutated, although individual message objects may be.
    """
    if len(messages) > 0 and isinstance(messages[0], SystemMessage):
        # remove system message
        messages = messages[1:]


    # merge with existing messages (copy so the caller's list is untouched)
    merged_messages = list(state.get("messages", []))
    existing_message_ids = {message.id for message in merged_messages}
    existing_tool_call_results = {
        message.tool_call_id
        for message in merged_messages
        if isinstance(message, ToolMessage)
    }

    for message in messages:
        # filter tool calls to activate the agent itself
        if (
            isinstance(message, AIMessage) and
            message.tool_calls and
            message.tool_calls[0]["name"] == agent_name
        ):
            continue

        # filter results from activating the agent
        if (
            isinstance(message, ToolMessage) and
            message.name == agent_name
        ):
            continue

        if message.id not in existing_message_ids:

            # skip duplicate tool call results
            if (isinstance(message, ToolMessage) and
                    message.tool_call_id in existing_tool_call_results):
                logger.warning(
                    "Warning: Duplicate tool call result, skipping: %s",
                    message.tool_call_id
                )
                continue

            merged_messages.append(message)

            # Track what was just added so duplicates later in this same
            # batch are detected too (id-less messages are never deduplicated).
            if message.id is not None:
                existing_message_ids.add(message.id)
            if isinstance(message, ToolMessage):
                existing_tool_call_results.add(message.tool_call_id)
        else:
            # Replace the message with the existing one
            for i, existing_message in enumerate(merged_messages):
                if existing_message.id == message.id:
                    # if both are AIMessages, keep the stored tool calls and
                    # additional kwargs when the stored message had them
                    # together with content (only AIMessage has tool_calls)
                    if (
                        isinstance(message, AIMessage) and
                        isinstance(existing_message, AIMessage) and
                        (existing_message.tool_calls or
                         existing_message.additional_kwargs) and
                        existing_message.content
                    ):
                        message.tool_calls = existing_message.tool_calls
                        message.additional_kwargs = existing_message.additional_kwargs
                    merged_messages[i] = message

    # fix wrong tool call ids: a ToolMessage right after an AIMessage with
    # tool calls is aligned to that AIMessage's first tool call id
    for current_message, next_message in zip(merged_messages, merged_messages[1:]):
        if (not isinstance(current_message, AIMessage) or
                not isinstance(next_message, ToolMessage)):
            continue

        if current_message.tool_calls and current_message.tool_calls[0]["id"]:
            next_message.tool_call_id = current_message.tool_calls[0]["id"]



    return {
        **state,
        "messages": merged_messages,
        "copilotkit": {
            "actions": actions
        }
    }

class LangGraphAgent(Agent):
    """LangGraph agent class for CopilotKit"""
    def __init__(
        self,
        *,
        name: str,
        description: Optional[str] = None,
        graph: Optional[CompiledGraph] = None,
        langgraph_config: Union[Optional[RunnableConfig], dict] = None,
        copilotkit_config: Optional[CopilotKitConfig] = None,

        # deprecated - use langgraph_config instead
        config: Union[Optional[RunnableConfig], dict] = None,
        # deprecated - use graph instead
        agent: Optional[CompiledGraph] = None,
        # deprecated - use copilotkit_config instead
        merge_state: Optional[Callable] = None,

    ):
        """
        Create a LangGraph-backed agent.

        `graph` (or the deprecated `agent`) is required. The merge function is
        taken from copilotkit_config["merge_state"], then the deprecated
        `merge_state` argument, then langgraph_default_merge_state.

        Raises ValueError when neither `graph` nor `agent` is given.
        """
        if config is not None:
            logger.warning("Warning: config is deprecated, use langgraph_config instead")

        if agent is not None:
            logger.warning("Warning: agent is deprecated, use graph instead")

        # warn only when the deprecated argument is actually used
        if merge_state is not None:
            logger.warning("Warning: merge_state is deprecated, use copilotkit_config instead")

        if graph is None and agent is None:
            raise ValueError("graph must be provided")

        super().__init__(
            name=name,
            description=description,
        )

        self.merge_state = (
            (copilotkit_config.get("merge_state") if copilotkit_config is not None else None)
            or merge_state
            or langgraph_default_merge_state
        )

        self.convert_messages = (
            copilotkit_config.get("convert_messages")
            if copilotkit_config
            else None
        ) or copilotkit_messages_to_langchain(use_function_call=False)

        self.langgraph_config = langgraph_config or config

        self.graph = cast(CompiledGraph, graph or agent)

    def _emit_state_sync_event(
        self,
        *,
        thread_id: str,
        run_id: str,
        node_name: str,
        state: dict,
        running: bool,
        active: bool
    ):
        """Serialize an "on_copilotkit_state_sync" event (state minus messages)."""
        state_without_messages = {
            k: v for k, v in state.items() if k != "messages"
        }
        return langchain_dumps({
            "event": "on_copilotkit_state_sync",
            "thread_id": thread_id,
            "run_id": run_id,
            "agent_name": self.name,
            "node_name": node_name,
            "active": active,
            "state": state_without_messages,
            "running": running,
            "role": "assistant"
        })

    def execute( # pylint: disable=too-many-arguments
        self,
        *,
        state: dict,
        messages: List[Message],
        thread_id: Optional[str] = None,
        node_name: Optional[str] = None,
        actions: Optional[List[ActionDict]] = None,
    ):
        """
        Merge the frontend input into the graph state and start streaming.

        The run "continues" (state is written with as_node=node_name) when a
        thread_id is given and node_name is not "__end__"; otherwise it
        "starts" with a new or existing thread id. Returns the async
        generator from _stream_events. The caller's `state` dict is not
        mutated.
        """
        config = ensure_config(cast(Any, self.langgraph_config.copy()) if self.langgraph_config else {}) # pylint: disable=line-too-long
        # Own copy of "configurable": the shallow copy above may still share it
        # with self.langgraph_config, and thread_id must not leak between runs.
        config["configurable"] = dict(config.get("configurable") or {})
        config["configurable"]["thread_id"] = thread_id

        agent_state = self.graph.get_state(config)
        state = {**state, "messages": agent_state.values.get("messages", [])}

        langchain_messages = self.convert_messages(messages)
        state = cast(Callable, self.merge_state)(
            state=state,
            messages=langchain_messages,
            actions=actions,
            agent_name=self.name
        )

        mode = "continue" if thread_id and node_name != "__end__" else "start"
        thread_id = thread_id or str(uuid.uuid4())
        config["configurable"]["thread_id"] = thread_id

        if mode == "continue":
            self.graph.update_state(config, state, as_node=node_name)

        return self._stream_events(
            mode=mode,
            config=config,
            state=state,
            node_name=node_name
        )

    async def _stream_events( # pylint: disable=too-many-locals
        self,
        *,
        mode: str,
        config: RunnableConfig,
        state: Any,
        node_name: Optional[str] = None
    ):

        streaming_state_extractor = _StreamingStateExtractor([])
        initial_state = state if mode == "start" else None
        prev_node_name = None
        emit_intermediate_state_until_end = None
        should_exit = False
        thread_id = cast(Any, config)["configurable"]["thread_id"]

        async for event in self.graph.astream_events(initial_state, config, version="v1"):
            current_node_name = event.get("name")
            event_type = event.get("event")
            run_id = event.get("run_id")
            metadata = event.get("metadata", {})

            should_exit = should_exit or metadata.get("copilotkit:exit", False)

            emit_intermediate_state = metadata.get("copilotkit:emit-intermediate-state")
            force_emit_intermediate_state = metadata.get("copilotkit:force-emit-intermediate-state", False) # pylint: disable=line-too-long
            manually_emit_message = metadata.get("copilotkit:manually-emit-message", False)
            manually_emit_tool_call = metadata.get("copilotkit:manually-emit-tool-call", False)

            # we only want to update the node name under certain conditions
            #
since we don't need any internal node names to be sent to the frontend 266 | if current_node_name in self.graph.nodes.keys(): 267 | node_name = current_node_name 268 | 269 | # we don't have a node name yet, so we can't update the state 270 | if node_name is None: 271 | continue 272 | 273 | exiting_node = node_name == current_node_name and event_type == "on_chain_end" 274 | 275 | if force_emit_intermediate_state: 276 | if event_type == "on_chain_end": 277 | state = cast(Any, event["data"])["output"] 278 | yield self._emit_state_sync_event( 279 | thread_id=thread_id, 280 | run_id=run_id, 281 | node_name=node_name, 282 | state=state, 283 | running=True, 284 | active=True 285 | ) + "\n" 286 | continue 287 | 288 | if manually_emit_message: 289 | if event_type == "on_chain_end": 290 | yield json.dumps( 291 | { 292 | "event": "on_copilotkit_emit_message", 293 | "message": cast(Any, event["data"])["output"], 294 | "message_id": str(uuid.uuid4()), 295 | "role": "assistant" 296 | } 297 | ) + "\n" 298 | continue 299 | 300 | if manually_emit_tool_call: 301 | if event_type == "on_chain_end": 302 | yield json.dumps( 303 | { 304 | "event": "on_copilotkit_emit_tool_call", 305 | "name": cast(Any, event["data"])["output"]["name"], 306 | "args": cast(Any, event["data"])["output"]["args"], 307 | "id": cast(Any, event["data"])["output"]["id"] 308 | } 309 | ) + "\n" 310 | continue 311 | 312 | if emit_intermediate_state and emit_intermediate_state_until_end is None: 313 | emit_intermediate_state_until_end = node_name 314 | 315 | if emit_intermediate_state and event_type == "on_chat_model_start": 316 | # reset the streaming state extractor 317 | streaming_state_extractor = _StreamingStateExtractor(emit_intermediate_state) 318 | 319 | updated_state = self.graph.get_state(config).values 320 | 321 | if emit_intermediate_state and event_type == "on_chat_model_stream": 322 | streaming_state_extractor.buffer_tool_calls(event) 323 | 324 | if emit_intermediate_state_until_end is not None: 325 | 
updated_state = { 326 | **updated_state, 327 | **streaming_state_extractor.extract_state() 328 | } 329 | 330 | if (not emit_intermediate_state and 331 | current_node_name == emit_intermediate_state_until_end and 332 | event_type == "on_chain_end"): 333 | # stop emitting function call state 334 | emit_intermediate_state_until_end = None 335 | 336 | # we send state sync events when: 337 | # a) the state has changed 338 | # b) the node has changed 339 | # c) the node is ending 340 | if updated_state != state or prev_node_name != node_name or exiting_node: 341 | state = updated_state 342 | prev_node_name = node_name 343 | yield self._emit_state_sync_event( 344 | thread_id=thread_id, 345 | run_id=run_id, 346 | node_name=node_name, 347 | state=state, 348 | running=True, 349 | active=not exiting_node 350 | ) + "\n" 351 | 352 | yield langchain_dumps(event) + "\n" 353 | 354 | state = self.graph.get_state(config) 355 | is_end_node = state.next == () 356 | 357 | node_name = list(state.metadata["writes"].keys())[0] 358 | 359 | yield self._emit_state_sync_event( 360 | thread_id=thread_id, 361 | run_id=run_id, 362 | node_name=cast(str, node_name) if not is_end_node else "__end__", 363 | state=state.values, 364 | running=not should_exit, 365 | # at this point, the node is ending so we set active to false 366 | active=False 367 | ) + "\n" 368 | 369 | 370 | 371 | def dict_repr(self): 372 | super_repr = super().dict_repr() 373 | return { 374 | **super_repr, 375 | 'type': 'langgraph' 376 | } 377 | 378 | class _StreamingStateExtractor: 379 | def __init__(self, emit_intermediate_state: List[dict]): 380 | self.emit_intermediate_state = emit_intermediate_state 381 | self.tool_call_buffer = {} 382 | self.current_tool_call = None 383 | 384 | self.previously_parsable_state = {} 385 | 386 | def buffer_tool_calls(self, event: Any): 387 | """Buffer the tool calls""" 388 | if len(event["data"]["chunk"].tool_call_chunks) > 0: 389 | chunk = event["data"]["chunk"].tool_call_chunks[0] 390 | if 
chunk["name"] is not None: 391 | self.current_tool_call = chunk["name"] 392 | self.tool_call_buffer[self.current_tool_call] = chunk["args"] 393 | elif self.current_tool_call is not None: 394 | self.tool_call_buffer[self.current_tool_call] = ( 395 | self.tool_call_buffer[self.current_tool_call] + chunk["args"] 396 | ) 397 | 398 | def get_emit_state_config(self, current_tool_name): 399 | """Get the emit state config""" 400 | 401 | for config in self.emit_intermediate_state: 402 | state_key = config.get("state_key") 403 | tool = config.get("tool") 404 | tool_argument = config.get("tool_argument") 405 | 406 | if current_tool_name == tool: 407 | return (tool_argument, state_key) 408 | 409 | return (None, None) 410 | 411 | 412 | def extract_state(self): 413 | """Extract the streaming state""" 414 | parser = JSONParser() 415 | 416 | state = {} 417 | 418 | for key, value in self.tool_call_buffer.items(): 419 | argument_name, state_key = self.get_emit_state_config(key) 420 | 421 | if state_key is None: 422 | continue 423 | 424 | try: 425 | parsed_value = parser.parse(value) 426 | except Exception as _exc: # pylint: disable=broad-except 427 | if key in self.previously_parsable_state: 428 | parsed_value = self.previously_parsable_state[key] 429 | else: 430 | continue 431 | 432 | self.previously_parsable_state[key] = parsed_value 433 | 434 | if argument_name is None: 435 | state[state_key] = parsed_value 436 | else: 437 | state[state_key] = parsed_value.get(argument_name) 438 | 439 | return state 440 | --------------------------------------------------------------------------------