├── .gitignore ├── README.md ├── abc_task_manager.py ├── agent.py ├── agentpartner.py ├── card_resolver.py ├── client.py ├── custom_types.py ├── google_host_agent.py ├── host_agent.py ├── in_memory_cache.py ├── mcp_app.py ├── push_notification_auth.py ├── requirements.txt ├── server.py ├── task_manager.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | . 2 | __init__.py 3 | __pycache__ 4 | .env 5 | .venv -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A2A + MCP Example 2 | 3 | This project demonstrates communication between agents using the **Agent-to-Agent (A2A)** protocol in combination with the **Model Context Protocol (MCP)**. 4 | 5 | --- 6 | 7 | ## 🔧 Components 8 | 9 | - **mcp_app.py** 10 | MCP server providing tools (functions/endpoints) that can be used by agents. 11 | 12 | - **agentpartner.py** 13 | Agent B — uses tools exposed by the MCP server. 14 | 15 | - **host_agent.py** 16 | Agent A — communicates with Agent B using the A2A protocol. 17 | 18 | --- 19 | 20 | ## ▶️ How to Run 21 | 22 | ### 1. Install dependencies 23 | 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ### 2. Start the MCP Server 29 | 30 | ```bash 31 | python mcp_app.py 32 | ``` 33 | 34 | This starts the MCP server that exposes the tool endpoints; Agent B connects to it at `http://127.0.0.1:3000/sse`. 35 | 36 | ### 3. Run Agent B (agentpartner) 37 | 38 | ```bash 39 | python agentpartner.py 40 | ``` 41 | 42 | Agent B starts an A2A server (by default on `localhost:8000`) and waits for requests from Agent A. It requires the `OPENAI_API_KEY` environment variable, which can also be set in a `.env` file. 43 | 44 | ### 4. Run Agent A (host agent) 45 | 46 | ```bash 47 | python host_agent.py 48 | ``` 49 | 50 | Agent A initiates communication with Agent B using the A2A protocol and calls MCP tools via Agent B. 51 | 52 | --- 53 | 54 | ## ✅ Expected Result 55 | 56 | Agent A sends a request to Agent B via A2A.
57 | Agent B uses the MCP protocol to invoke tools and returns the result. 58 | 59 | --- 60 | 61 | Feel free to extend this setup with more tools or agents! 62 | -------------------------------------------------------------------------------- /abc_task_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from abc import ABC, abstractmethod 4 | from typing import AsyncIterable, List, Union 5 | 6 | from custom_types import ( 7 | Artifact, 8 | CancelTaskRequest, 9 | CancelTaskResponse, 10 | GetTaskPushNotificationRequest, 11 | GetTaskPushNotificationResponse, 12 | GetTaskRequest, 13 | GetTaskResponse, 14 | InternalError, 15 | JSONRPCError, 16 | JSONRPCResponse, 17 | PushNotificationConfig, 18 | SendTaskRequest, 19 | SendTaskResponse, 20 | SendTaskStreamingRequest, 21 | SendTaskStreamingResponse, 22 | SetTaskPushNotificationRequest, 23 | SetTaskPushNotificationResponse, 24 | Task, 25 | TaskIdParams, 26 | TaskNotCancelableError, 27 | TaskNotFoundError, 28 | TaskPushNotificationConfig, 29 | TaskQueryParams, 30 | TaskResubscriptionRequest, 31 | TaskSendParams, 32 | TaskState, 33 | TaskStatus, 34 | TaskStatusUpdateEvent, 35 | ) 36 | from utils import new_not_implemented_error 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | class TaskManager(ABC): 42 | @abstractmethod 43 | async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: 44 | pass 45 | 46 | @abstractmethod 47 | async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: 48 | pass 49 | 50 | @abstractmethod 51 | async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: 52 | pass 53 | 54 | @abstractmethod 55 | async def on_send_task_subscribe( 56 | self, request: SendTaskStreamingRequest 57 | ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: 58 | pass 59 | 60 | @abstractmethod 61 | async def on_set_task_push_notification( 62 | self, request: 
SetTaskPushNotificationRequest 63 | ) -> SetTaskPushNotificationResponse: 64 | pass 65 | 66 | @abstractmethod 67 | async def on_get_task_push_notification( 68 | self, request: GetTaskPushNotificationRequest 69 | ) -> GetTaskPushNotificationResponse: 70 | pass 71 | 72 | @abstractmethod 73 | async def on_resubscribe_to_task( 74 | self, request: TaskResubscriptionRequest 75 | ) -> Union[AsyncIterable[SendTaskResponse], JSONRPCResponse]: 76 | pass 77 | 78 | 79 | class InMemoryTaskManager(TaskManager): 80 | def __init__(self): 81 | self.tasks: dict[str, Task] = {} 82 | self.push_notification_infos: dict[str, PushNotificationConfig] = {} 83 | self.lock = asyncio.Lock() 84 | self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {} 85 | self.subscriber_lock = asyncio.Lock() 86 | 87 | async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: 88 | logger.info(f"Getting task {request.params.id}") 89 | task_query_params: TaskQueryParams = request.params 90 | 91 | async with self.lock: 92 | task = self.tasks.get(task_query_params.id) 93 | if task is None: 94 | return GetTaskResponse(id=request.id, error=TaskNotFoundError()) 95 | 96 | task_result = self.append_task_history( 97 | task, task_query_params.historyLength 98 | ) 99 | 100 | return GetTaskResponse(id=request.id, result=task_result) 101 | 102 | async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: 103 | logger.info(f"Cancelling task {request.params.id}") 104 | task_id_params: TaskIdParams = request.params 105 | 106 | async with self.lock: 107 | task = self.tasks.get(task_id_params.id) 108 | if task is None: 109 | return CancelTaskResponse(id=request.id, error=TaskNotFoundError()) 110 | 111 | return CancelTaskResponse(id=request.id, error=TaskNotCancelableError()) 112 | 113 | @abstractmethod 114 | async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: 115 | pass 116 | 117 | @abstractmethod 118 | async def on_send_task_subscribe( 119 | self, request: 
SendTaskStreamingRequest 120 | ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: 121 | pass 122 | 123 | async def set_push_notification_info( 124 | self, task_id: str, notification_config: PushNotificationConfig 125 | ): 126 | async with self.lock: 127 | task = self.tasks.get(task_id) 128 | if task is None: 129 | raise ValueError(f"Task not found for {task_id}") 130 | 131 | self.push_notification_infos[task_id] = notification_config 132 | 133 | return 134 | 135 | async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig: 136 | async with self.lock: 137 | task = self.tasks.get(task_id) 138 | if task is None: 139 | raise ValueError(f"Task not found for {task_id}") 140 | 141 | return self.push_notification_infos[task_id] 142 | 143 | return 144 | 145 | async def has_push_notification_info(self, task_id: str) -> bool: 146 | async with self.lock: 147 | return task_id in self.push_notification_infos 148 | 149 | async def on_set_task_push_notification( 150 | self, request: SetTaskPushNotificationRequest 151 | ) -> SetTaskPushNotificationResponse: 152 | logger.info(f"Setting task push notification {request.params.id}") 153 | task_notification_params: TaskPushNotificationConfig = request.params 154 | 155 | try: 156 | await self.set_push_notification_info( 157 | task_notification_params.id, 158 | task_notification_params.pushNotificationConfig, 159 | ) 160 | except Exception as e: 161 | logger.error(f"Error while setting push notification info: {e}") 162 | return JSONRPCResponse( 163 | id=request.id, 164 | error=InternalError( 165 | message="An error occurred while setting push notification info" 166 | ), 167 | ) 168 | 169 | return SetTaskPushNotificationResponse( 170 | id=request.id, result=task_notification_params 171 | ) 172 | 173 | async def on_get_task_push_notification( 174 | self, request: GetTaskPushNotificationRequest 175 | ) -> GetTaskPushNotificationResponse: 176 | logger.info(f"Getting task push notification 
{request.params.id}") 177 | task_params: TaskIdParams = request.params 178 | 179 | try: 180 | notification_info = await self.get_push_notification_info(task_params.id) 181 | except Exception as e: 182 | logger.error(f"Error while getting push notification info: {e}") 183 | return GetTaskPushNotificationResponse( 184 | id=request.id, 185 | error=InternalError( 186 | message="An error occurred while getting push notification info" 187 | ), 188 | ) 189 | 190 | return GetTaskPushNotificationResponse( 191 | id=request.id, 192 | result=TaskPushNotificationConfig( 193 | id=task_params.id, pushNotificationConfig=notification_info 194 | ), 195 | ) 196 | 197 | async def upsert_task(self, task_send_params: TaskSendParams) -> Task: 198 | logger.info(f"Upserting task {task_send_params.id}") 199 | async with self.lock: 200 | task = self.tasks.get(task_send_params.id) 201 | if task is None: 202 | task = Task( 203 | id=task_send_params.id, 204 | sessionId=task_send_params.sessionId, 205 | messages=[task_send_params.message], 206 | status=TaskStatus(state=TaskState.SUBMITTED), 207 | history=[task_send_params.message], 208 | ) 209 | self.tasks[task_send_params.id] = task 210 | else: 211 | task.history.append(task_send_params.message) 212 | 213 | return task 214 | 215 | async def on_resubscribe_to_task( 216 | self, request: TaskResubscriptionRequest 217 | ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: 218 | return new_not_implemented_error(request.id) 219 | 220 | async def update_store( 221 | self, task_id: str, status: TaskStatus, artifacts: list[Artifact] 222 | ) -> Task: 223 | async with self.lock: 224 | try: 225 | task = self.tasks[task_id] 226 | except KeyError: 227 | logger.error(f"Task {task_id} not found for updating the task") 228 | raise ValueError(f"Task {task_id} not found") 229 | 230 | task.status = status 231 | 232 | if status.message is not None: 233 | task.history.append(status.message) 234 | 235 | if artifacts is not None: 236 | if 
task.artifacts is None: 237 | task.artifacts = [] 238 | task.artifacts.extend(artifacts) 239 | 240 | return task 241 | 242 | def append_task_history(self, task: Task, historyLength: int | None): 243 | new_task = task.model_copy() 244 | if historyLength is not None and historyLength > 0: 245 | new_task.history = new_task.history[-historyLength:] 246 | else: 247 | new_task.history = [] 248 | 249 | return new_task 250 | 251 | async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False): 252 | async with self.subscriber_lock: 253 | if task_id not in self.task_sse_subscribers: 254 | if is_resubscribe: 255 | raise ValueError("Task not found for resubscription") 256 | else: 257 | self.task_sse_subscribers[task_id] = [] 258 | 259 | sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited 260 | self.task_sse_subscribers[task_id].append(sse_event_queue) 261 | return sse_event_queue 262 | 263 | async def enqueue_events_for_sse(self, task_id, task_update_event): 264 | async with self.subscriber_lock: 265 | if task_id not in self.task_sse_subscribers: 266 | return 267 | 268 | current_subscribers = self.task_sse_subscribers[task_id] 269 | for subscriber in current_subscribers: 270 | await subscriber.put(task_update_event) 271 | 272 | async def dequeue_events_for_sse( 273 | self, request_id, task_id, sse_event_queue: asyncio.Queue 274 | ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: 275 | try: 276 | while True: 277 | event = await sse_event_queue.get() 278 | if isinstance(event, JSONRPCError): 279 | yield SendTaskStreamingResponse(id=request_id, error=event) 280 | break 281 | 282 | yield SendTaskStreamingResponse(id=request_id, result=event) 283 | if isinstance(event, TaskStatusUpdateEvent) and event.final: 284 | break 285 | finally: 286 | async with self.subscriber_lock: 287 | if task_id in self.task_sse_subscribers: 288 | self.task_sse_subscribers[task_id].remove(sse_event_queue) 289 | 
-------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, AsyncIterable, Dict, Literal 3 | 4 | from langchain_core.messages import AIMessage, ToolMessage 5 | from langchain_mcp_adapters.client import MultiServerMCPClient 6 | from langchain_openai import ChatOpenAI 7 | from langgraph.checkpoint.memory import MemorySaver 8 | from langgraph.prebuilt import create_react_agent 9 | from pydantic import BaseModel 10 | 11 | memory = MemorySaver() 12 | 13 | 14 | def _fetch_mcp_tools_sync() -> list: 15 | """ 16 | Helper function: runs the async MultiServerMCPClient code in a synchronous manner. 17 | Fetches the remote tools from your MCP server(s). 18 | """ 19 | servers_config = { 20 | "currency_server": { 21 | "transport": "sse", 22 | "url": "http://127.0.0.1:3000/sse", 23 | } 24 | } 25 | 26 | async def _fetch_tools(): 27 | async with MultiServerMCPClient(servers_config) as client: 28 | return client.get_tools() 29 | 30 | # Run the async method in a blocking (sync) fashion 31 | return asyncio.run(_fetch_tools()) 32 | 33 | 34 | class ResponseFormat(BaseModel): 35 | """Respond to the user in this format.""" 36 | 37 | status: Literal["input_required", "completed", "error"] = "input_required" 38 | message: str 39 | 40 | 41 | class CurrencyAgent: 42 | SYSTEM_INSTRUCTION = ( 43 | "You are a specialized assistant for currency conversions. " 44 | "Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates. " 45 | "If the user asks about anything other than currency conversion or exchange rates, " 46 | "politely state that you cannot help with that topic and can only assist with currency-related queries. " 47 | "Do not attempt to answer unrelated questions or use tools for other purposes." 
48 | "Set response status to input_required if the user needs to provide more information." 49 | "Set response status to error if there is an error while processing the request." 50 | "Set response status to completed if the request is complete." 51 | ) 52 | 53 | def __init__(self): 54 | # Instead of a local @tool, fetch remote tools from MCP 55 | self.tools = _fetch_mcp_tools_sync() 56 | 57 | self.model = ChatOpenAI(model="gpt-4o-mini") 58 | self.graph = create_react_agent( 59 | self.model, 60 | tools=self.tools, 61 | checkpointer=memory, 62 | prompt=self.SYSTEM_INSTRUCTION, 63 | response_format=ResponseFormat, 64 | ) 65 | 66 | def invoke(self, query, sessionId) -> str: 67 | config = {"configurable": {"thread_id": sessionId}} 68 | self.graph.invoke({"messages": [("user", query)]}, config) 69 | return self.get_agent_response(config) 70 | 71 | async def stream(self, query, sessionId) -> AsyncIterable[Dict[str, Any]]: 72 | inputs = {"messages": [("user", query)]} 73 | config = {"configurable": {"thread_id": sessionId}} 74 | 75 | for item in self.graph.stream(inputs, config, stream_mode="values"): 76 | message = item["messages"][-1] 77 | if ( 78 | isinstance(message, AIMessage) 79 | and message.tool_calls 80 | and len(message.tool_calls) > 0 81 | ): 82 | yield { 83 | "is_task_complete": False, 84 | "require_user_input": False, 85 | "content": "Looking up the exchange rates...", 86 | } 87 | elif isinstance(message, ToolMessage): 88 | yield { 89 | "is_task_complete": False, 90 | "require_user_input": False, 91 | "content": "Processing the exchange rates..", 92 | } 93 | 94 | yield self.get_agent_response(config) 95 | 96 | def get_agent_response(self, config): 97 | current_state = self.graph.get_state(config) 98 | structured_response = current_state.values.get("structured_response") 99 | if structured_response and isinstance(structured_response, ResponseFormat): 100 | if structured_response.status == "input_required": 101 | return { 102 | "is_task_complete": False, 103 | 
"require_user_input": True, 104 | "content": structured_response.message, 105 | } 106 | elif structured_response.status == "error": 107 | return { 108 | "is_task_complete": False, 109 | "require_user_input": True, 110 | "content": structured_response.message, 111 | } 112 | elif structured_response.status == "completed": 113 | return { 114 | "is_task_complete": True, 115 | "require_user_input": False, 116 | "content": structured_response.message, 117 | } 118 | 119 | return { 120 | "is_task_complete": False, 121 | "require_user_input": True, 122 | "content": "We are unable to process your request at the moment. Please try again.", 123 | } 124 | 125 | SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] 126 | -------------------------------------------------------------------------------- /agentpartner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import click 5 | from dotenv import load_dotenv 6 | 7 | from agent import CurrencyAgent 8 | from custom_types import AgentCapabilities, AgentCard, AgentSkill, MissingAPIKeyError 9 | from push_notification_auth import PushNotificationSenderAuth 10 | from server import A2AServer 11 | from task_manager import AgentTaskManager 12 | 13 | load_dotenv() 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @click.command() 20 | @click.option("--host", "host", default="localhost") 21 | @click.option("--port", "port", default=8000) 22 | def main(host, port): 23 | """Starts the Currency Agent server.""" 24 | try: 25 | if not os.getenv("OPENAI_API_KEY"): 26 | raise MissingAPIKeyError("OPENAI_API_KEY environment variable not set.") 27 | 28 | capabilities = AgentCapabilities(streaming=True, pushNotifications=False) 29 | skill = AgentSkill( 30 | id="convert_currency", 31 | name="Currency Exchange Rates Tool", 32 | description="Helps with exchange values between various currencies", 33 | tags=["currency 
conversion", "currency exchange"], 34 | examples=["What is exchange rate between USD and GBP?"], 35 | ) 36 | agent_card = AgentCard( 37 | name="Currency Agent", 38 | description="Helps with exchange rates for currencies", 39 | url=f"http://{host}:{port}/", 40 | version="1.0.0", 41 | defaultInputModes=CurrencyAgent.SUPPORTED_CONTENT_TYPES, 42 | defaultOutputModes=CurrencyAgent.SUPPORTED_CONTENT_TYPES, 43 | capabilities=capabilities, 44 | skills=[skill], 45 | ) 46 | 47 | notification_sender_auth = PushNotificationSenderAuth() 48 | notification_sender_auth.generate_jwk() 49 | server = A2AServer( 50 | agent_card=agent_card, 51 | task_manager=AgentTaskManager( 52 | agent=CurrencyAgent(), notification_sender_auth=notification_sender_auth 53 | ), 54 | host=host, 55 | port=port, 56 | ) 57 | 58 | server.app.add_route( 59 | "/.well-known/jwks.json", 60 | notification_sender_auth.handle_jwks_endpoint, 61 | methods=["GET"], 62 | ) 63 | 64 | logger.info(f"Starting server on {host}:{port}") 65 | server.start() 66 | except MissingAPIKeyError as e: 67 | logger.error(f"Error: {e}") 68 | exit(1) 69 | except Exception as e: 70 | logger.error(f"An error occurred during server startup: {e}") 71 | exit(1) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /card_resolver.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import httpx 4 | 5 | from custom_types import A2AClientJSONError, AgentCard 6 | 7 | 8 | class A2ACardResolver: 9 | def __init__(self, base_url, agent_card_path="/.well-known/agent.json"): 10 | self.base_url = base_url.rstrip("/") 11 | self.agent_card_path = agent_card_path.lstrip("/") 12 | 13 | def get_agent_card(self) -> AgentCard: 14 | with httpx.Client() as client: 15 | response = client.get(self.base_url + "/" + self.agent_card_path) 16 | response.raise_for_status() 17 | try: 18 | return AgentCard(**response.json()) 19 | 
except json.JSONDecodeError as e: 20 | raise A2AClientJSONError(str(e)) from e 21 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, AsyncIterable 3 | 4 | import httpx 5 | from httpx_sse import connect_sse 6 | 7 | from custom_types import ( 8 | A2AClientHTTPError, 9 | A2AClientJSONError, 10 | AgentCard, 11 | CancelTaskRequest, 12 | CancelTaskResponse, 13 | GetTaskPushNotificationRequest, 14 | GetTaskPushNotificationResponse, 15 | GetTaskRequest, 16 | GetTaskResponse, 17 | JSONRPCRequest, 18 | SendTaskRequest, 19 | SendTaskResponse, 20 | SendTaskStreamingRequest, 21 | SendTaskStreamingResponse, 22 | SetTaskPushNotificationRequest, 23 | SetTaskPushNotificationResponse, 24 | ) 25 | 26 | 27 | class A2AClient: 28 | def __init__(self, agent_card: AgentCard = None, url: str = None): 29 | if agent_card: 30 | self.url = agent_card.url 31 | elif url: 32 | self.url = url 33 | else: 34 | raise ValueError("Must provide either agent_card or url") 35 | 36 | async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse: 37 | request = SendTaskRequest(params=payload) 38 | return SendTaskResponse(**await self._send_request(request)) 39 | 40 | async def send_task_streaming( 41 | self, payload: dict[str, Any] 42 | ) -> AsyncIterable[SendTaskStreamingResponse]: 43 | request = SendTaskStreamingRequest(params=payload) 44 | with httpx.Client(timeout=None) as client: 45 | with connect_sse( 46 | client, "POST", self.url, json=request.model_dump() 47 | ) as event_source: 48 | try: 49 | for sse in event_source.iter_sse(): 50 | yield SendTaskStreamingResponse(**json.loads(sse.data)) 51 | except json.JSONDecodeError as e: 52 | raise A2AClientJSONError(str(e)) from e 53 | except httpx.RequestError as e: 54 | raise A2AClientHTTPError(400, str(e)) from e 55 | 56 | async def _send_request(self, request: JSONRPCRequest) -> 
dict[str, Any]: 57 | async with httpx.AsyncClient() as client: 58 | try: 59 | # Image generation could take time, adding timeout 60 | response = await client.post( 61 | self.url, json=request.model_dump(), timeout=30 62 | ) 63 | response.raise_for_status() 64 | return response.json() 65 | except httpx.HTTPStatusError as e: 66 | raise A2AClientHTTPError(e.response.status_code, str(e)) from e 67 | except json.JSONDecodeError as e: 68 | raise A2AClientJSONError(str(e)) from e 69 | 70 | async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse: 71 | request = GetTaskRequest(params=payload) 72 | return GetTaskResponse(**await self._send_request(request)) 73 | 74 | async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse: 75 | request = CancelTaskRequest(params=payload) 76 | return CancelTaskResponse(**await self._send_request(request)) 77 | 78 | async def set_task_callback( 79 | self, payload: dict[str, Any] 80 | ) -> SetTaskPushNotificationResponse: 81 | request = SetTaskPushNotificationRequest(params=payload) 82 | return SetTaskPushNotificationResponse(**await self._send_request(request)) 83 | 84 | async def get_task_callback( 85 | self, payload: dict[str, Any] 86 | ) -> GetTaskPushNotificationResponse: 87 | request = GetTaskPushNotificationRequest(params=payload) 88 | return GetTaskPushNotificationResponse(**await self._send_request(request)) 89 | -------------------------------------------------------------------------------- /custom_types.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from enum import Enum 3 | from typing import Annotated, Any, List, Literal, Optional, Union 4 | from uuid import uuid4 5 | 6 | from pydantic import ( 7 | BaseModel, 8 | ConfigDict, 9 | Field, 10 | TypeAdapter, 11 | field_serializer, 12 | model_validator, 13 | ) 14 | from typing_extensions import Self 15 | 16 | 17 | class TaskState(str, Enum): 18 | SUBMITTED = "submitted" 19 | 
WORKING = "working" 20 | INPUT_REQUIRED = "input-required" 21 | COMPLETED = "completed" 22 | CANCELED = "canceled" 23 | FAILED = "failed" 24 | UNKNOWN = "unknown" 25 | 26 | 27 | class TextPart(BaseModel): 28 | type: Literal["text"] = "text" 29 | text: str 30 | metadata: dict[str, Any] | None = None 31 | 32 | 33 | class FileContent(BaseModel): 34 | name: str | None = None 35 | mimeType: str | None = None 36 | bytes: str | None = None 37 | uri: str | None = None 38 | 39 | @model_validator(mode="after") 40 | def check_content(self) -> Self: 41 | if not (self.bytes or self.uri): 42 | raise ValueError("Either 'bytes' or 'uri' must be present in the file data") 43 | if self.bytes and self.uri: 44 | raise ValueError( 45 | "Only one of 'bytes' or 'uri' can be present in the file data" 46 | ) 47 | return self 48 | 49 | 50 | class FilePart(BaseModel): 51 | type: Literal["file"] = "file" 52 | file: FileContent 53 | metadata: dict[str, Any] | None = None 54 | 55 | 56 | class DataPart(BaseModel): 57 | type: Literal["data"] = "data" 58 | data: dict[str, Any] 59 | metadata: dict[str, Any] | None = None 60 | 61 | 62 | Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")] 63 | 64 | 65 | class Message(BaseModel): 66 | role: Literal["user", "agent"] 67 | parts: List[Part] 68 | metadata: dict[str, Any] | None = None 69 | 70 | 71 | class TaskStatus(BaseModel): 72 | state: TaskState 73 | message: Message | None = None 74 | timestamp: datetime = Field(default_factory=datetime.now) 75 | 76 | @field_serializer("timestamp") 77 | def serialize_dt(self, dt: datetime, _info): 78 | return dt.isoformat() 79 | 80 | 81 | class Artifact(BaseModel): 82 | name: str | None = None 83 | description: str | None = None 84 | parts: List[Part] 85 | metadata: dict[str, Any] | None = None 86 | index: int = 0 87 | append: bool | None = None 88 | lastChunk: bool | None = None 89 | 90 | 91 | class Task(BaseModel): 92 | id: str 93 | sessionId: str | None = None 94 | status: TaskStatus 
95 | artifacts: List[Artifact] | None = None 96 | history: List[Message] | None = None 97 | metadata: dict[str, Any] | None = None 98 | 99 | 100 | class TaskStatusUpdateEvent(BaseModel): 101 | id: str 102 | status: TaskStatus 103 | final: bool = False 104 | metadata: dict[str, Any] | None = None 105 | 106 | 107 | class TaskArtifactUpdateEvent(BaseModel): 108 | id: str 109 | artifact: Artifact 110 | metadata: dict[str, Any] | None = None 111 | 112 | 113 | class AuthenticationInfo(BaseModel): 114 | model_config = ConfigDict(extra="allow") 115 | 116 | schemes: List[str] 117 | credentials: str | None = None 118 | 119 | 120 | class PushNotificationConfig(BaseModel): 121 | url: str 122 | token: str | None = None 123 | authentication: AuthenticationInfo | None = None 124 | 125 | 126 | class TaskIdParams(BaseModel): 127 | id: str 128 | metadata: dict[str, Any] | None = None 129 | 130 | 131 | class TaskQueryParams(TaskIdParams): 132 | historyLength: int | None = None 133 | 134 | 135 | class TaskSendParams(BaseModel): 136 | id: str 137 | sessionId: str = Field(default_factory=lambda: uuid4().hex) 138 | message: Message 139 | acceptedOutputModes: Optional[List[str]] = None 140 | pushNotification: PushNotificationConfig | None = None 141 | historyLength: int | None = None 142 | metadata: dict[str, Any] | None = None 143 | 144 | 145 | class TaskPushNotificationConfig(BaseModel): 146 | id: str 147 | pushNotificationConfig: PushNotificationConfig 148 | 149 | 150 | ## RPC Messages 151 | 152 | 153 | class JSONRPCMessage(BaseModel): 154 | jsonrpc: Literal["2.0"] = "2.0" 155 | id: int | str | None = Field(default_factory=lambda: uuid4().hex) 156 | 157 | 158 | class JSONRPCRequest(JSONRPCMessage): 159 | method: str 160 | params: dict[str, Any] | None = None 161 | 162 | 163 | class JSONRPCError(BaseModel): 164 | code: int 165 | message: str 166 | data: Any | None = None 167 | 168 | 169 | class JSONRPCResponse(JSONRPCMessage): 170 | result: Any | None = None 171 | error: JSONRPCError | 
None = None 172 | 173 | 174 | class SendTaskRequest(JSONRPCRequest): 175 | method: Literal["tasks/send"] = "tasks/send" 176 | params: TaskSendParams 177 | 178 | 179 | class SendTaskResponse(JSONRPCResponse): 180 | result: Task | None = None 181 | 182 | 183 | class SendTaskStreamingRequest(JSONRPCRequest): 184 | method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe" 185 | params: TaskSendParams 186 | 187 | 188 | class SendTaskStreamingResponse(JSONRPCResponse): 189 | result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None 190 | 191 | 192 | class GetTaskRequest(JSONRPCRequest): 193 | method: Literal["tasks/get"] = "tasks/get" 194 | params: TaskQueryParams 195 | 196 | 197 | class GetTaskResponse(JSONRPCResponse): 198 | result: Task | None = None 199 | 200 | 201 | class CancelTaskRequest(JSONRPCRequest): 202 | method: Literal["tasks/cancel",] = "tasks/cancel" 203 | params: TaskIdParams 204 | 205 | 206 | class CancelTaskResponse(JSONRPCResponse): 207 | result: Task | None = None 208 | 209 | 210 | class SetTaskPushNotificationRequest(JSONRPCRequest): 211 | method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set" 212 | params: TaskPushNotificationConfig 213 | 214 | 215 | class SetTaskPushNotificationResponse(JSONRPCResponse): 216 | result: TaskPushNotificationConfig | None = None 217 | 218 | 219 | class GetTaskPushNotificationRequest(JSONRPCRequest): 220 | method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get" 221 | params: TaskIdParams 222 | 223 | 224 | class GetTaskPushNotificationResponse(JSONRPCResponse): 225 | result: TaskPushNotificationConfig | None = None 226 | 227 | 228 | class TaskResubscriptionRequest(JSONRPCRequest): 229 | method: Literal["tasks/resubscribe",] = "tasks/resubscribe" 230 | params: TaskIdParams 231 | 232 | 233 | A2ARequest = TypeAdapter( 234 | Annotated[ 235 | Union[ 236 | SendTaskRequest, 237 | GetTaskRequest, 238 | CancelTaskRequest, 239 | SetTaskPushNotificationRequest, 240 | 
GetTaskPushNotificationRequest, 241 | TaskResubscriptionRequest, 242 | SendTaskStreamingRequest, 243 | ], 244 | Field(discriminator="method"), 245 | ] 246 | ) 247 | 248 | ## Error types 249 | 250 | 251 | class JSONParseError(JSONRPCError): 252 | code: int = -32700 253 | message: str = "Invalid JSON payload" 254 | data: Any | None = None 255 | 256 | 257 | class InvalidRequestError(JSONRPCError): 258 | code: int = -32600 259 | message: str = "Request payload validation error" 260 | data: Any | None = None 261 | 262 | 263 | class MethodNotFoundError(JSONRPCError): 264 | code: int = -32601 265 | message: str = "Method not found" 266 | data: None = None 267 | 268 | 269 | class InvalidParamsError(JSONRPCError): 270 | code: int = -32602 271 | message: str = "Invalid parameters" 272 | data: Any | None = None 273 | 274 | 275 | class InternalError(JSONRPCError): 276 | code: int = -32603 277 | message: str = "Internal error" 278 | data: Any | None = None 279 | 280 | 281 | class TaskNotFoundError(JSONRPCError): 282 | code: int = -32001 283 | message: str = "Task not found" 284 | data: None = None 285 | 286 | 287 | class TaskNotCancelableError(JSONRPCError): 288 | code: int = -32002 289 | message: str = "Task cannot be canceled" 290 | data: None = None 291 | 292 | 293 | class PushNotificationNotSupportedError(JSONRPCError): 294 | code: int = -32003 295 | message: str = "Push Notification is not supported" 296 | data: None = None 297 | 298 | 299 | class UnsupportedOperationError(JSONRPCError): 300 | code: int = -32004 301 | message: str = "This operation is not supported" 302 | data: None = None 303 | 304 | 305 | class ContentTypeNotSupportedError(JSONRPCError): 306 | code: int = -32005 307 | message: str = "Incompatible content types" 308 | data: None = None 309 | 310 | 311 | class AgentProvider(BaseModel): 312 | organization: str 313 | url: str | None = None 314 | 315 | 316 | class AgentCapabilities(BaseModel): 317 | streaming: bool = False 318 | pushNotifications: bool = 
False 319 | stateTransitionHistory: bool = False 320 | 321 | 322 | class AgentAuthentication(BaseModel): 323 | schemes: List[str] 324 | credentials: str | None = None 325 | 326 | 327 | class AgentSkill(BaseModel): 328 | id: str 329 | name: str 330 | description: str | None = None 331 | tags: List[str] | None = None 332 | examples: List[str] | None = None 333 | inputModes: List[str] | None = None 334 | outputModes: List[str] | None = None 335 | 336 | 337 | class AgentCard(BaseModel): 338 | name: str 339 | description: str | None = None 340 | url: str 341 | provider: AgentProvider | None = None 342 | version: str 343 | documentationUrl: str | None = None 344 | capabilities: AgentCapabilities 345 | authentication: AgentAuthentication | None = None 346 | defaultInputModes: List[str] = ["text"] 347 | defaultOutputModes: List[str] = ["text"] 348 | skills: List[AgentSkill] 349 | 350 | 351 | class A2AClientError(Exception): 352 | pass 353 | 354 | 355 | class A2AClientHTTPError(A2AClientError): 356 | def __init__(self, status_code: int, message: str): 357 | self.status_code = status_code 358 | self.message = message 359 | super().__init__(f"HTTP Error {status_code}: {message}") 360 | 361 | 362 | class A2AClientJSONError(A2AClientError): 363 | def __init__(self, message: str): 364 | self.message = message 365 | super().__init__(f"JSON Error: {message}") 366 | 367 | 368 | class MissingAPIKeyError(Exception): 369 | """Exception for missing API key.""" 370 | 371 | pass 372 | -------------------------------------------------------------------------------- /google_host_agent.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import json 4 | import uuid 5 | from typing import Callable, List 6 | 7 | from google.adk import Agent 8 | from google.adk.agents.callback_context import CallbackContext 9 | from google.adk.agents.invocation_context import InvocationContext 10 | from google.adk.agents.readonly_context 
import ReadonlyContext 11 | from google.adk.tools.tool_context import ToolContext 12 | from google.genai import types 13 | 14 | from card_resolver import A2ACardResolver 15 | from client import A2AClient 16 | from custom_types import ( 17 | AgentCard, 18 | DataPart, 19 | Message, 20 | Part, 21 | Task, 22 | TaskArtifactUpdateEvent, 23 | TaskSendParams, 24 | TaskState, 25 | TaskStatus, 26 | TaskStatusUpdateEvent, 27 | TextPart, 28 | ) 29 | 30 | TaskCallbackArg = Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent 31 | TaskUpdateCallback = Callable[[TaskCallbackArg], Task] 32 | 33 | 34 | class RemoteAgentConnections: 35 | """A class to hold the connections to the remote agents.""" 36 | 37 | def __init__(self, agent_card: AgentCard): 38 | self.agent_client = A2AClient(agent_card) 39 | self.card = agent_card 40 | 41 | self.conversation_name = None 42 | self.conversation = None 43 | self.pending_tasks = set() 44 | 45 | def get_agent(self) -> AgentCard: 46 | return self.card 47 | 48 | async def send_task( 49 | self, 50 | request: TaskSendParams, 51 | task_callback: TaskUpdateCallback | None, 52 | ) -> Task | None: 53 | if self.card.capabilities.streaming: 54 | task = None 55 | if task_callback: 56 | task_callback( 57 | Task( 58 | id=request.id, 59 | sessionId=request.sessionId, 60 | status=TaskStatus( 61 | state=TaskState.SUBMITTED, 62 | message=request.message, 63 | ), 64 | history=[request.message], 65 | ) 66 | ) 67 | async for response in self.agent_client.send_task_streaming( 68 | request.model_dump() 69 | ): 70 | merge_metadata(response.result, request) 71 | # For task status updates, we need to propagate metadata and provide 72 | # a unique message id. 
73 | if ( 74 | hasattr(response.result, "status") 75 | and hasattr(response.result.status, "message") 76 | and response.result.status.message 77 | ): 78 | merge_metadata(response.result.status.message, request.message) 79 | m = response.result.status.message 80 | if not m.metadata: 81 | m.metadata = {} 82 | if "message_id" in m.metadata: 83 | m.metadata["last_message_id"] = m.metadata["message_id"] 84 | m.metadata["message_id"] = str(uuid.uuid4()) 85 | if task_callback: 86 | task = task_callback(response.result) 87 | if hasattr(response.result, "final") and response.result.final: 88 | break 89 | return task 90 | else: # Non-streaming 91 | response = await self.agent_client.send_task(request.model_dump()) 92 | merge_metadata(response.result, request) 93 | # For task status updates, we need to propagate metadata and provide 94 | # a unique message id. 95 | if ( 96 | hasattr(response.result, "status") 97 | and hasattr(response.result.status, "message") 98 | and response.result.status.message 99 | ): 100 | merge_metadata(response.result.status.message, request.message) 101 | m = response.result.status.message 102 | if not m.metadata: 103 | m.metadata = {} 104 | if "message_id" in m.metadata: 105 | m.metadata["last_message_id"] = m.metadata["message_id"] 106 | m.metadata["message_id"] = str(uuid.uuid4()) 107 | 108 | if task_callback: 109 | task_callback(response.result) 110 | return response.result 111 | 112 | 113 | def merge_metadata(target, source): 114 | if not hasattr(target, "metadata") or not hasattr(source, "metadata"): 115 | return 116 | if target.metadata and source.metadata: 117 | target.metadata.update(source.metadata) 118 | elif source.metadata: 119 | target.metadata = dict(**source.metadata) 120 | 121 | 122 | class HostAgent: 123 | """The host agent. 124 | 125 | This is the agent responsible for choosing which remote agents to send 126 | tasks to and coordinate their work. 
127 | """ 128 | 129 | def __init__( 130 | self, 131 | remote_agent_addresses: List[str], 132 | task_callback: TaskUpdateCallback | None = None, 133 | ): 134 | self.task_callback = task_callback 135 | self.remote_agent_connections: dict[str, RemoteAgentConnections] = {} 136 | self.cards: dict[str, AgentCard] = {} 137 | for address in remote_agent_addresses: 138 | card_resolver = A2ACardResolver(address) 139 | card = card_resolver.get_agent_card() 140 | remote_connection = RemoteAgentConnections(card) 141 | self.remote_agent_connections[card.name] = remote_connection 142 | self.cards[card.name] = card 143 | agent_info = [] 144 | for ra in self.list_remote_agents(): 145 | agent_info.append(json.dumps(ra)) 146 | self.agents = "\n".join(agent_info) 147 | 148 | def register_agent_card(self, card: AgentCard): 149 | remote_connection = RemoteAgentConnections(card) 150 | self.remote_agent_connections[card.name] = remote_connection 151 | self.cards[card.name] = card 152 | agent_info = [] 153 | for ra in self.list_remote_agents(): 154 | agent_info.append(json.dumps(ra)) 155 | self.agents = "\n".join(agent_info) 156 | 157 | def create_agent(self) -> Agent: 158 | agent = Agent( 159 | model="gemini-2.0-flash-001", 160 | name="host_agent", 161 | instruction=self.root_instruction, 162 | before_model_callback=self.before_model_callback, 163 | description=( 164 | "This agent orchestrates the decomposition of the user request into" 165 | " tasks that can be performed by the child agents." 166 | ), 167 | tools=[ 168 | self.list_remote_agents, 169 | self.send_task, 170 | ], 171 | ) 172 | agent 173 | return agent 174 | 175 | def root_instruction(self, context: ReadonlyContext) -> str: 176 | current_agent = self.check_state(context) 177 | return f"""You are a expert delegator that can delegate the user request to the 178 | appropriate remote agents. 179 | 180 | Discovery: 181 | - You can use `list_remote_agents` to list the available remote agents you 182 | can use to delegate the task. 
183 | 184 | Execution: 185 | - For actionable tasks, you can use `create_task` to assign tasks to remote agents to perform. 186 | Be sure to include the remote agent name when you response to the user. 187 | 188 | You can use `check_pending_task_states` to check the states of the pending 189 | tasks. 190 | 191 | Please rely on tools to address the request, don't make up the response. If you are not sure, please ask the user for more details. 192 | Focus on the most recent parts of the conversation primarily. 193 | 194 | If there is an active agent, send the request to that agent with the update task tool. 195 | 196 | Agents: 197 | {self.agents} 198 | 199 | Current agent: {current_agent["active_agent"]} 200 | """ 201 | 202 | def check_state(self, context: ReadonlyContext): 203 | state = context.state 204 | if ( 205 | "session_id" in state 206 | and "session_active" in state 207 | and state["session_active"] 208 | and "agent" in state 209 | ): 210 | return {"active_agent": f"{state['agent']}"} 211 | return {"active_agent": "None"} 212 | 213 | def before_model_callback(self, callback_context: CallbackContext, llm_request): 214 | state = callback_context.state 215 | if "session_active" not in state or not state["session_active"]: 216 | if "session_id" not in state: 217 | state["session_id"] = str(uuid.uuid4()) 218 | state["session_active"] = True 219 | 220 | def list_remote_agents(self): 221 | """List the available remote agents you can use to delegate the task.""" 222 | if not self.remote_agent_connections: 223 | return [] 224 | 225 | remote_agent_info = [] 226 | for card in self.cards.values(): 227 | remote_agent_info.append( 228 | {"name": card.name, "description": card.description} 229 | ) 230 | return remote_agent_info 231 | 232 | async def send_task(self, agent_name: str, message: str, tool_context: ToolContext): 233 | """Sends a task either streaming (if supported) or non-streaming. 234 | 235 | This will send a message to the remote agent named agent_name. 
236 | 237 | Args: 238 | agent_name: The name of the agent to send the task to. 239 | message: The message to send to the agent for the task. 240 | tool_context: The tool context this method runs in. 241 | 242 | Yields: 243 | A dictionary of JSON data. 244 | """ 245 | if agent_name not in self.remote_agent_connections: 246 | raise ValueError(f"Agent {agent_name} not found") 247 | state = tool_context.state 248 | state["agent"] = agent_name 249 | card = self.cards[agent_name] 250 | client = self.remote_agent_connections[agent_name] 251 | if not client: 252 | raise ValueError(f"Client not available for {agent_name}") 253 | if "task_id" in state: 254 | taskId = state["task_id"] 255 | else: 256 | taskId = str(uuid.uuid4()) 257 | sessionId = state["session_id"] 258 | task: Task 259 | messageId = "" 260 | metadata = {} 261 | if "input_message_metadata" in state: 262 | metadata.update(**state["input_message_metadata"]) 263 | if "message_id" in state["input_message_metadata"]: 264 | messageId = state["input_message_metadata"]["message_id"] 265 | if not messageId: 266 | messageId = str(uuid.uuid4()) 267 | metadata.update(**{"conversation_id": sessionId, "message_id": messageId}) 268 | request: TaskSendParams = TaskSendParams( 269 | id=taskId, 270 | sessionId=sessionId, 271 | message=Message( 272 | role="user", 273 | parts=[TextPart(text=message)], 274 | metadata=metadata, 275 | ), 276 | acceptedOutputModes=["text", "text/plain", "image/png"], 277 | # pushNotification=None, 278 | metadata={"conversation_id": sessionId}, 279 | ) 280 | task = await client.send_task(request, self.task_callback) 281 | # Assume completion unless a state returns that isn't complete 282 | state["session_active"] = task.status.state not in [ 283 | TaskState.COMPLETED, 284 | TaskState.CANCELED, 285 | TaskState.FAILED, 286 | TaskState.UNKNOWN, 287 | ] 288 | if task.status.state == TaskState.INPUT_REQUIRED: 289 | # Force user input back 290 | tool_context.actions.skip_summarization = True 291 | 
tool_context.actions.escalate = True 292 | elif task.status.state == TaskState.CANCELED: 293 | # Open question, should we return some info for cancellation instead 294 | raise ValueError(f"Agent {agent_name} task {task.id} is cancelled") 295 | elif task.status.state == TaskState.FAILED: 296 | # Raise error for failure 297 | raise ValueError(f"Agent {agent_name} task {task.id} failed") 298 | response = [] 299 | if task.status.message: 300 | # Assume the information is in the task message. 301 | response.extend(convert_parts(task.status.message.parts, tool_context)) 302 | if task.artifacts: 303 | for artifact in task.artifacts: 304 | response.extend(convert_parts(artifact.parts, tool_context)) 305 | return response 306 | 307 | 308 | def convert_parts(parts: list[Part], tool_context: ToolContext): 309 | rval = [] 310 | for p in parts: 311 | rval.append(convert_part(p, tool_context)) 312 | return rval 313 | 314 | 315 | def convert_part(part: Part, tool_context: ToolContext): 316 | if part.type == "text": 317 | return part.text 318 | elif part.type == "data": 319 | return part.data 320 | elif part.type == "file": 321 | # Repackage A2A FilePart to google.genai Blob 322 | # Currently not considering plain text as files 323 | file_id = part.file.name 324 | file_bytes = base64.b64decode(part.file.bytes) 325 | file_part = types.Part( 326 | inline_data=types.Blob(mime_type=part.file.mimeType, data=file_bytes) 327 | ) 328 | tool_context.save_artifact(file_id, file_part) 329 | tool_context.actions.skip_summarization = True 330 | tool_context.actions.escalate = True 331 | return DataPart(data={"artifact-file-id": file_id}) 332 | return f"Unknown type: {p.type}" 333 | 334 | 335 | from google.adk.agents.run_config import RunConfig 336 | from google.adk.sessions.in_memory_session_service import InMemorySessionService 337 | 338 | # --------------------------------------------------------- 339 | # Your HostAgent code (from your snippet) 340 | # 
--------------------------------------------------------- 341 | host_agent = HostAgent(["http://localhost:8000"]) 342 | root_agent = host_agent.create_agent() 343 | 344 | # --------------------------------------------------------- 345 | # 1. Create an in-memory session service 346 | # --------------------------------------------------------- 347 | session_service = InMemorySessionService() 348 | 349 | # --------------------------------------------------------- 350 | # 2. Create a session with the required fields 351 | # --------------------------------------------------------- 352 | my_session = session_service.create_session( 353 | app_name="test_app", user_id="test_user", session_id="session-123" 354 | ) 355 | 356 | # --------------------------------------------------------- 357 | # 3. Provide a basic RunConfig with response_modalities 358 | # --------------------------------------------------------- 359 | run_config = RunConfig(response_modalities=["text"]) 360 | 361 | # --------------------------------------------------------- 362 | # 4. Build the InvocationContext 363 | # --------------------------------------------------------- 364 | context = InvocationContext( 365 | session_service=session_service, 366 | memory_service=None, 367 | artifact_service=None, 368 | session=my_session, # We just created 369 | agent=root_agent, 370 | invocation_id=str(uuid.uuid4()), 371 | run_config=run_config, 372 | ) 373 | 374 | 375 | # --------------------------------------------------------- 376 | # 5. Run your agent with run_async(...) 
377 | # --------------------------------------------------------- 378 | async def main(): 379 | async for event in root_agent.run_async(context): 380 | if event.content: 381 | print("Agent output:", event.content.text()) 382 | 383 | 384 | asyncio.run(main()) 385 | -------------------------------------------------------------------------------- /host_agent.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List, Optional 3 | 4 | import requests 5 | import typer 6 | from langchain_core.messages import AIMessage 7 | from langchain_core.tools import tool 8 | from langchain_openai import ChatOpenAI 9 | from langgraph.checkpoint.memory import MemorySaver 10 | from langgraph.prebuilt import create_react_agent 11 | 12 | 13 | class AgentCapabilities: 14 | def __init__( 15 | self, streaming=False, pushNotifications=False, stateTransitionHistory=False 16 | ): 17 | self.streaming = streaming 18 | self.pushNotifications = pushNotifications 19 | self.stateTransitionHistory = stateTransitionHistory 20 | 21 | 22 | class AgentCard: 23 | def __init__( 24 | self, 25 | name: str, 26 | url: str, 27 | version: str, 28 | capabilities: AgentCapabilities, 29 | description: Optional[str] = None, 30 | ): 31 | self.name = name 32 | self.url = url 33 | self.version = version 34 | self.capabilities = capabilities 35 | self.description = description or "No description." 
36 | 37 | 38 | class TaskState: 39 | SUBMITTED = "submitted" 40 | COMPLETED = "completed" 41 | FAILED = "failed" 42 | CANCELED = "canceled" 43 | UNKNOWN = "unknown" 44 | INPUT_REQUIRED = "input-required" 45 | 46 | 47 | ############################################################################### 48 | # 2) Synchronous RemoteAgentClient 49 | ############################################################################### 50 | class RemoteAgentClient: 51 | """Communicates with a single remote agent (A2A) in synchronous mode.""" 52 | 53 | def __init__(self, base_url: str): 54 | self.base_url = base_url 55 | self.agent_card: Optional[AgentCard] = None 56 | 57 | def fetch_agent_card(self) -> AgentCard: 58 | """GET /.well-known/agent.json to retrieve the remote agent's card.""" 59 | url = f"{self.base_url}/.well-known/agent.json" 60 | resp = requests.get(url, timeout=10) 61 | resp.raise_for_status() 62 | data = resp.json() 63 | 64 | caps_data = data["capabilities"] 65 | caps = AgentCapabilities(**caps_data) 66 | 67 | card = AgentCard( 68 | name=data["name"], 69 | url=self.base_url, 70 | version=data["version"], 71 | capabilities=caps, 72 | description=data.get("description", ""), 73 | ) 74 | self.agent_card = card 75 | return card 76 | 77 | def send_task(self, task_id: str, session_id: str, message_text: str) -> dict: 78 | """POST / with JSON-RPC request: method=tasks/send.""" 79 | payload = { 80 | "jsonrpc": "2.0", 81 | "id": str(uuid.uuid4()), 82 | "method": "tasks/send", 83 | "params": { 84 | "id": task_id, 85 | "sessionId": session_id, 86 | "message": { 87 | "role": "user", 88 | "parts": [{"type": "text", "text": message_text}], 89 | }, 90 | }, 91 | } 92 | r = requests.post(self.base_url, json=payload, timeout=30) 93 | r.raise_for_status() 94 | resp = r.json() 95 | if "error" in resp and resp["error"] is not None: 96 | raise RuntimeError(f"Remote agent error: {resp['error']}") 97 | return resp.get("result", {}) 98 | 99 | 100 | class HostAgent: 101 | """Holds 
references to multiple RemoteAgentClients, one per address.""" 102 | 103 | def __init__(self, remote_addresses: List[str]): 104 | self.clients = {} 105 | for addr in remote_addresses: 106 | self.clients[addr] = RemoteAgentClient(addr) 107 | 108 | def initialize(self): 109 | """Fetch agent cards for all addresses (synchronously).""" 110 | for addr, client in self.clients.items(): 111 | client.fetch_agent_card() 112 | 113 | def list_agents_info(self) -> list: 114 | """Return a list of {name, description, url, streaming} for each loaded agent.""" 115 | infos = [] 116 | for addr, c in self.clients.items(): 117 | card = c.agent_card 118 | if card: 119 | infos.append( 120 | { 121 | "name": card.name, 122 | "description": card.description, 123 | "url": card.url, 124 | "streaming": card.capabilities.streaming, 125 | } 126 | ) 127 | else: 128 | infos.append( 129 | { 130 | "name": "Unknown", 131 | "description": "Not loaded", 132 | "url": addr, 133 | "streaming": False, 134 | } 135 | ) 136 | return infos 137 | 138 | def get_client_by_name(self, agent_name: str) -> Optional[RemoteAgentClient]: 139 | """Find a client whose AgentCard name matches `agent_name`.""" 140 | for c in self.clients.values(): 141 | if c.agent_card and c.agent_card.name == agent_name: 142 | return c 143 | return None 144 | 145 | def send_task(self, agent_name: str, message: str) -> str: 146 | """ 147 | Actually send the user's request to the remote agent via tasks/send JSON-RPC. 148 | Returns a textual summary or error message. 149 | """ 150 | client = self.get_client_by_name(agent_name) 151 | if not client or not client.agent_card: 152 | return f"Error: No agent card found for '{agent_name}'." 
153 | 154 | task_id = str(uuid.uuid4()) 155 | session_id = "session-xyz" 156 | 157 | try: 158 | result = client.send_task(task_id, session_id, message) 159 | # Check final state 160 | state = result.get("status", {}).get("state", "unknown") 161 | if state == TaskState.COMPLETED: 162 | return f"Task {task_id} completed with message: {result}" 163 | elif state == TaskState.INPUT_REQUIRED: 164 | return f"Task {task_id} needs more input: {result}" 165 | else: 166 | return f"Task {task_id} ended with state={state}, result={result}" 167 | except Exception as exc: 168 | return f"Remote agent call failed: {exc}" 169 | 170 | 171 | def make_list_agents_tool(host_agent: HostAgent): 172 | """Return a synchronous tool function that calls host_agent.list_agents_info().""" 173 | 174 | @tool 175 | def list_remote_agents_tool() -> list: 176 | """List available remote agents (name, url, streaming).""" 177 | return host_agent.list_agents_info() 178 | 179 | return list_remote_agents_tool 180 | 181 | 182 | def make_send_task_tool(host_agent: HostAgent): 183 | """Return a synchronous tool function that calls host_agent.send_task(...).""" 184 | 185 | @tool 186 | def send_task_tool(agent_name: str, message: str) -> str: 187 | """ 188 | Synchronous tool: sends 'message' to 'agent_name' 189 | via JSON-RPC and returns the result. 190 | """ 191 | return host_agent.send_task(agent_name, message) 192 | 193 | return send_task_tool 194 | 195 | 196 | def build_react_agent(host_agent: HostAgent): 197 | # Create the top-level LLM 198 | llm = ChatOpenAI(model="gpt-4o") 199 | memory = MemorySaver() 200 | 201 | # Make the two tools referencing our host_agent 202 | list_tool = make_list_agents_tool(host_agent) 203 | send_tool = make_send_task_tool(host_agent) 204 | 205 | system_prompt = """ 206 | You are a Host Agent that delegates requests to known remote agents. 207 | You have two tools: 208 | 1) list_remote_agents_tool(): Lists the remote agents (their name, URL, streaming). 
209 | 2) send_task_tool(agent_name, message): Sends a text request to the agent. 210 | 211 | If the user wants currency conversion, call 'send_task_tool("some_agent_name", "5 USD to EUR")'. 212 | If the user wants weather info, call 'send_task_tool("some_weather_agent", "Weather in city")'. 213 | 214 | Return the final result to the user. 215 | """ 216 | 217 | agent = create_react_agent( 218 | model=llm, 219 | tools=[list_tool, send_tool], 220 | checkpointer=memory, 221 | prompt=system_prompt, 222 | ) 223 | return agent 224 | 225 | 226 | app = typer.Typer() 227 | 228 | 229 | @app.command() 230 | def run_agent(remote_url: str = "http://localhost:8000"): 231 | """ 232 | Start a synchronous HostAgent pointing at 'remote_url' 233 | and run a simple conversation loop. 234 | """ 235 | # 1) Build the HostAgent 236 | host_agent = HostAgent([remote_url]) 237 | 238 | host_agent.initialize() 239 | react_agent = build_react_agent(host_agent) 240 | 241 | typer.echo(f"Host agent ready. Connected to: {remote_url}") 242 | typer.echo("Type 'quit' or 'exit' to stop.") 243 | 244 | while True: 245 | user_msg = typer.prompt("\nUser") 246 | if user_msg.strip().lower() in ["quit", "exit", "bye"]: 247 | typer.echo("Goodbye!") 248 | break 249 | 250 | raw_result = react_agent.invoke( 251 | {"messages": [{"role": "user", "content": user_msg}]}, 252 | config={"configurable": {"thread_id": "cli-session"}}, 253 | ) 254 | 255 | final_text = None 256 | 257 | # If 'raw_result' is a dictionary with "messages", try to find the last AIMessage 258 | if isinstance(raw_result, dict) and "messages" in raw_result: 259 | all_msgs = raw_result["messages"] 260 | for msg in reversed(all_msgs): 261 | if isinstance(msg, AIMessage): 262 | final_text = msg.content 263 | break 264 | else: 265 | # Otherwise, it's likely a plain string 266 | if isinstance(raw_result, str): 267 | final_text = raw_result 268 | else: 269 | # fallback: convert whatever it is to string 270 | final_text = str(raw_result) 271 | 272 | # Now 
print only the final AIMessage content 273 | typer.echo(f"HostAgent: {final_text}") 274 | 275 | 276 | def main(): 277 | """ 278 | Entry point for 'python sync_host_agent_cli.py run-agent --remote-url http://whatever' 279 | """ 280 | app() 281 | 282 | 283 | if __name__ == "__main__": 284 | main() 285 | -------------------------------------------------------------------------------- /in_memory_cache.py: -------------------------------------------------------------------------------- 1 | """In Memory Cache utility.""" 2 | 3 | import threading 4 | import time 5 | from typing import Any, Dict, Optional 6 | 7 | 8 | class InMemoryCache: 9 | """A thread-safe Singleton class to manage cache data. 10 | 11 | Ensures only one instance of the cache exists across the application. 12 | """ 13 | 14 | _instance: Optional["InMemoryCache"] = None 15 | _lock: threading.Lock = threading.Lock() 16 | _initialized: bool = False 17 | 18 | def __new__(cls): 19 | """Override __new__ to control instance creation (Singleton pattern). 20 | 21 | Uses a lock to ensure thread safety during the first instantiation. 22 | 23 | Returns: 24 | The singleton instance of InMemoryCache. 25 | """ 26 | if cls._instance is None: 27 | with cls._lock: 28 | if cls._instance is None: 29 | cls._instance = super().__new__(cls) 30 | return cls._instance 31 | 32 | def __init__(self): 33 | """Initialize the cache storage. 34 | 35 | Uses a flag (_initialized) to ensure this logic runs only on the very first 36 | creation of the singleton instance. 37 | """ 38 | if not self._initialized: 39 | with self._lock: 40 | if not self._initialized: 41 | # print("Initializing SessionCache storage") 42 | self._cache_data: Dict[str, Dict[str, Any]] = {} 43 | self._ttl: Dict[str, float] = {} 44 | self._data_lock: threading.Lock = threading.Lock() 45 | self._initialized = True 46 | 47 | def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: 48 | """Set a key-value pair. 49 | 50 | Args: 51 | key: The key for the data. 
52 | value: The data to store. 53 | ttl: Time to live in seconds. If None, data will not expire. 54 | """ 55 | with self._data_lock: 56 | self._cache_data[key] = value 57 | 58 | if ttl is not None: 59 | self._ttl[key] = time.time() + ttl 60 | else: 61 | if key in self._ttl: 62 | del self._ttl[key] 63 | 64 | def get(self, key: str, default: Any = None) -> Any: 65 | """Get the value associated with a key. 66 | 67 | Args: 68 | key: The key for the data within the session. 69 | default: The value to return if the session or key is not found. 70 | 71 | Returns: 72 | The cached value, or the default value if not found. 73 | """ 74 | with self._data_lock: 75 | if key in self._ttl and time.time() > self._ttl[key]: 76 | del self._cache_data[key] 77 | del self._ttl[key] 78 | return default 79 | return self._cache_data.get(key, default) 80 | 81 | def delete(self, key: str) -> None: 82 | """Delete a specific key-value pair from a cache. 83 | 84 | Args: 85 | key: The key to delete. 86 | 87 | Returns: 88 | True if the key was found and deleted, False otherwise. 89 | """ 90 | 91 | with self._data_lock: 92 | if key in self._cache_data: 93 | del self._cache_data[key] 94 | if key in self._ttl: 95 | del self._ttl[key] 96 | return True 97 | return False 98 | 99 | def clear(self) -> bool: 100 | """Remove all data. 101 | 102 | Returns: 103 | True if the data was cleared, False otherwise. 
104 | """ 105 | with self._data_lock: 106 | self._cache_data.clear() 107 | self._ttl.clear() 108 | return True 109 | return False 110 | -------------------------------------------------------------------------------- /mcp_app.py: -------------------------------------------------------------------------------- 1 | from mcp.server.fastmcp import FastMCP 2 | 3 | mcp = FastMCP(name="MinimalServer", host="0.0.0.0", port=3000) 4 | 5 | 6 | @mcp.tool() 7 | def get_exchange_rate( 8 | currency_from: str = "USD", 9 | currency_to: str = "EUR", 10 | currency_date: str = "latest", 11 | ): 12 | """Dummy-Tool, das 1:1 eine statische Antwort zurückgibt, anstelle eines echten API-Calls. 13 | 14 | Args: 15 | currency_from: Die Quellwährung (z.B. "USD"). 16 | currency_to: Die Zielwährung (z.B. "EUR"). 17 | currency_date: Das Datum für den Wechselkurs oder "latest". Standard "latest". 18 | 19 | Returns: 20 | Ein Dictionary mit statischen Placeholder-Daten. 21 | """ 22 | return { 23 | "amount": 1, 24 | "base": currency_from, 25 | "date": currency_date, 26 | "rates": {currency_to: 0.85}, # Beispiel-Rate 27 | } 28 | 29 | 30 | if __name__ == "__main__": 31 | mcp.run(transport="sse") 32 | -------------------------------------------------------------------------------- /push_notification_auth.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import logging 4 | import time 5 | import uuid 6 | from typing import Any 7 | 8 | import httpx 9 | import jwt 10 | from jwcrypto import jwk 11 | from jwt import PyJWK, PyJWKClient 12 | from starlette.requests import Request 13 | from starlette.responses import JSONResponse 14 | 15 | logger = logging.getLogger(__name__) 16 | AUTH_HEADER_PREFIX = "Bearer " 17 | 18 | 19 | class PushNotificationAuth: 20 | def _calculate_request_body_sha256(self, data: dict[str, Any]): 21 | """Calculates the SHA256 hash of a request body. 
22 | 23 | This logic needs to be same for both the agent who signs the payload and the client verifier. 24 | """ 25 | body_str = json.dumps( 26 | data, 27 | ensure_ascii=False, 28 | allow_nan=False, 29 | indent=None, 30 | separators=(",", ":"), 31 | ) 32 | return hashlib.sha256(body_str.encode()).hexdigest() 33 | 34 | 35 | class PushNotificationSenderAuth(PushNotificationAuth): 36 | def __init__(self): 37 | self.public_keys = [] 38 | self.private_key_jwk: PyJWK = None 39 | 40 | @staticmethod 41 | async def verify_push_notification_url(url: str) -> bool: 42 | async with httpx.AsyncClient(timeout=10) as client: 43 | try: 44 | validation_token = str(uuid.uuid4()) 45 | response = await client.get( 46 | url, params={"validationToken": validation_token} 47 | ) 48 | response.raise_for_status() 49 | is_verified = response.text == validation_token 50 | 51 | logger.info(f"Verified push-notification URL: {url} => {is_verified}") 52 | return is_verified 53 | except Exception as e: 54 | logger.warning( 55 | f"Error during sending push-notification for URL {url}: {e}" 56 | ) 57 | 58 | return False 59 | 60 | def generate_jwk(self): 61 | key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig") 62 | self.public_keys.append(key.export_public(as_dict=True)) 63 | self.private_key_jwk = PyJWK.from_json(key.export_private()) 64 | 65 | def handle_jwks_endpoint(self, _request: Request): 66 | """Allow clients to fetch public keys.""" 67 | return JSONResponse({"keys": self.public_keys}) 68 | 69 | def _generate_jwt(self, data: dict[str, Any]): 70 | """JWT is generated by signing both the request payload SHA digest and time of token generation. 71 | 72 | Payload is signed with private key and it ensures the integrity of payload for client. 73 | Including iat prevents from replay attack. 
74 | """ 75 | 76 | iat = int(time.time()) 77 | 78 | return jwt.encode( 79 | { 80 | "iat": iat, 81 | "request_body_sha256": self._calculate_request_body_sha256(data), 82 | }, 83 | key=self.private_key_jwk, 84 | headers={"kid": self.private_key_jwk.key_id}, 85 | algorithm="RS256", 86 | ) 87 | 88 | async def send_push_notification(self, url: str, data: dict[str, Any]): 89 | jwt_token = self._generate_jwt(data) 90 | headers = {"Authorization": f"Bearer {jwt_token}"} 91 | async with httpx.AsyncClient(timeout=10) as client: 92 | try: 93 | response = await client.post(url, json=data, headers=headers) 94 | response.raise_for_status() 95 | logger.info(f"Push-notification sent for URL: {url}") 96 | except Exception as e: 97 | logger.warning( 98 | f"Error during sending push-notification for URL {url}: {e}" 99 | ) 100 | 101 | 102 | class PushNotificationReceiverAuth(PushNotificationAuth): 103 | def __init__(self): 104 | self.public_keys_jwks = [] 105 | self.jwks_client = None 106 | 107 | async def load_jwks(self, jwks_url: str): 108 | self.jwks_client = PyJWKClient(jwks_url) 109 | 110 | async def verify_push_notification(self, request: Request) -> bool: 111 | auth_header = request.headers.get("Authorization") 112 | if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX): 113 | print("Invalid authorization header") 114 | return False 115 | 116 | token = auth_header[len(AUTH_HEADER_PREFIX) :] 117 | signing_key = self.jwks_client.get_signing_key_from_jwt(token) 118 | 119 | decode_token = jwt.decode( 120 | token, 121 | signing_key, 122 | options={"require": ["iat", "request_body_sha256"]}, 123 | algorithms=["RS256"], 124 | ) 125 | 126 | actual_body_sha256 = self._calculate_request_body_sha256(await request.json()) 127 | if actual_body_sha256 != decode_token["request_body_sha256"]: 128 | # Payload signature does not match the digest in signed token. 
129 | raise ValueError("Invalid request body") 130 | 131 | if time.time() - decode_token["iat"] > 60 * 5: 132 | # Do not allow push-notifications older than 5 minutes. 133 | # This is to prevent replay attack. 134 | raise ValueError("Token is expired") 135 | 136 | return True 137 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.7.0 2 | anyio==4.9.0 3 | Authlib==1.5.2 4 | cachetools==5.5.2 5 | certifi==2025.1.31 6 | cffi==1.17.1 7 | charset-normalizer==3.4.1 8 | click==8.1.8 9 | colorama==0.4.6 10 | cryptography==44.0.2 11 | Deprecated==1.2.18 12 | distro==1.9.0 13 | docstring_parser==0.16 14 | fastapi==0.115.12 15 | google-adk==0.1.0 16 | google-api-core==2.24.2 17 | google-api-python-client==2.166.0 18 | google-auth==2.38.0 19 | google-auth-httplib2==0.2.0 20 | google-cloud-aiplatform==1.88.0 21 | google-cloud-bigquery==3.31.0 22 | google-cloud-core==2.4.3 23 | google-cloud-resource-manager==1.14.2 24 | google-cloud-secret-manager==2.23.2 25 | google-cloud-speech==2.31.1 26 | google-cloud-storage==2.19.0 27 | google-cloud-trace==1.16.1 28 | google-crc32c==1.7.1 29 | google-genai==1.10.0 30 | google-resumable-media==2.7.2 31 | googleapis-common-protos==1.69.2 32 | graphviz==0.20.3 33 | greenlet==3.1.1 34 | grpc-google-iam-v1==0.14.2 35 | grpcio==1.71.0 36 | grpcio-status==1.71.0 37 | h11==0.14.0 38 | httpcore==1.0.8 39 | httplib2==0.22.0 40 | httpx==0.28.1 41 | httpx-sse==0.4.0 42 | idna==3.10 43 | importlib_metadata==8.6.1 44 | jiter==0.9.0 45 | jsonpatch==1.33 46 | jsonpointer==3.0.0 47 | jwcrypto==1.5.6 48 | langchain==0.3.23 49 | langchain-core==0.3.51 50 | langchain-mcp-adapters==0.0.8 51 | langchain-openai==0.3.12 52 | langchain-text-splitters==0.3.8 53 | langgraph==0.3.29 54 | langgraph-checkpoint==2.0.24 55 | langgraph-prebuilt==0.1.8 56 | langgraph-sdk==0.1.61 57 | langsmith==0.3.30 58 | 
markdown-it-py==3.0.0 59 | mcp==1.6.0 60 | mdurl==0.1.2 61 | numpy==2.2.4 62 | openai==1.73.0 63 | opentelemetry-api==1.32.0 64 | opentelemetry-exporter-gcp-trace==1.9.0 65 | opentelemetry-resourcedetector-gcp==1.9.0a0 66 | opentelemetry-sdk==1.32.0 67 | opentelemetry-semantic-conventions==0.53b0 68 | orjson==3.10.16 69 | ormsgpack==1.9.1 70 | packaging==24.2 71 | proto-plus==1.26.1 72 | protobuf==5.29.4 73 | pyasn1==0.6.1 74 | pyasn1_modules==0.4.2 75 | pycparser==2.22 76 | pydantic==2.11.3 77 | pydantic-settings==2.8.1 78 | pydantic_core==2.33.1 79 | Pygments==2.19.1 80 | PyJWT==2.1.0 81 | pyparsing==3.2.3 82 | python-dateutil==2.9.0.post0 83 | python-dotenv==1.1.0 84 | PyYAML==6.0.2 85 | regex==2024.11.6 86 | requests==2.32.3 87 | requests-toolbelt==1.0.0 88 | rich==14.0.0 89 | rsa==4.9 90 | ruff==0.11.5 91 | shapely==2.1.0 92 | shellingham==1.5.4 93 | six==1.17.0 94 | sniffio==1.3.1 95 | SQLAlchemy==2.0.40 96 | sse-starlette==2.2.1 97 | starlette==0.46.2 98 | tenacity==9.1.2 99 | tiktoken==0.9.0 100 | tqdm==4.67.1 101 | typer==0.15.2 102 | typing-inspection==0.4.0 103 | typing_extensions==4.13.2 104 | tzdata==2025.2 105 | tzlocal==5.3.1 106 | uritemplate==4.1.1 107 | urllib3==2.4.0 108 | uvicorn==0.34.1 109 | websockets==15.0.1 110 | wrapt==1.17.2 111 | xxhash==3.5.0 112 | zipp==3.21.0 113 | zstandard==0.23.0 114 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Any, AsyncIterable, Union 4 | 5 | from fastapi import FastAPI, Request 6 | from fastapi.responses import JSONResponse 7 | from pydantic import ValidationError 8 | from sse_starlette.sse import EventSourceResponse 9 | 10 | from abc_task_manager import TaskManager 11 | from custom_types import ( 12 | A2ARequest, 13 | AgentCard, 14 | CancelTaskRequest, 15 | GetTaskPushNotificationRequest, 16 | GetTaskRequest, 17 | 
InternalError, 18 | InvalidRequestError, 19 | JSONParseError, 20 | JSONRPCResponse, 21 | SendTaskRequest, 22 | SendTaskStreamingRequest, 23 | SetTaskPushNotificationRequest, 24 | TaskResubscriptionRequest, 25 | ) 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class A2AServer: 31 | def __init__( 32 | self, 33 | host: str = "0.0.0.0", 34 | port: int = 5000, 35 | endpoint: str = "/", 36 | agent_card: AgentCard = None, 37 | task_manager: TaskManager = None, 38 | ): 39 | self.host = host 40 | self.port = port 41 | self.endpoint = endpoint 42 | self.task_manager = task_manager 43 | self.agent_card = agent_card 44 | 45 | # Erstelle eine FastAPI-App für automatische Dokumentation (/docs, /redoc, etc.) 46 | self.app = FastAPI( 47 | title="A2A Server", description="A2A Protocol JSON-RPC API", version="1.0.0" 48 | ) 49 | # JSON-RPC-Endpunkt (POST) - automatische Response Modell Generierung deaktiviert 50 | self.app.add_api_route( 51 | self.endpoint, self._process_request, methods=["POST"], response_model=None 52 | ) 53 | # AgentCard-Endpunkt unter .well-known 54 | self.app.add_api_route( 55 | "/.well-known/agent.json", 56 | self._get_agent_card, 57 | methods=["GET"], 58 | response_model=None, 59 | ) 60 | 61 | def start(self): 62 | if self.agent_card is None: 63 | raise ValueError("agent_card is not defined") 64 | if self.task_manager is None: 65 | raise ValueError("task_manager is not defined") 66 | import uvicorn 67 | 68 | uvicorn.run(self.app, host=self.host, port=self.port) 69 | 70 | async def _get_agent_card(self, request: Request) -> JSONResponse: 71 | # Liefert die AgentCard als JSON zurück. 
72 | return JSONResponse(self.agent_card.model_dump(exclude_none=True)) 73 | 74 | async def _process_request( 75 | self, request: Request 76 | ) -> Union[JSONResponse, EventSourceResponse]: 77 | try: 78 | body = await request.json() 79 | json_rpc_request = A2ARequest.validate_python(body) 80 | 81 | if isinstance(json_rpc_request, GetTaskRequest): 82 | result = await self.task_manager.on_get_task(json_rpc_request) 83 | elif isinstance(json_rpc_request, SendTaskRequest): 84 | result = await self.task_manager.on_send_task(json_rpc_request) 85 | elif isinstance(json_rpc_request, SendTaskStreamingRequest): 86 | result = await self.task_manager.on_send_task_subscribe( 87 | json_rpc_request 88 | ) 89 | elif isinstance(json_rpc_request, CancelTaskRequest): 90 | result = await self.task_manager.on_cancel_task(json_rpc_request) 91 | elif isinstance(json_rpc_request, SetTaskPushNotificationRequest): 92 | result = await self.task_manager.on_set_task_push_notification( 93 | json_rpc_request 94 | ) 95 | elif isinstance(json_rpc_request, GetTaskPushNotificationRequest): 96 | result = await self.task_manager.on_get_task_push_notification( 97 | json_rpc_request 98 | ) 99 | elif isinstance(json_rpc_request, TaskResubscriptionRequest): 100 | result = await self.task_manager.on_resubscribe_to_task( 101 | json_rpc_request 102 | ) 103 | else: 104 | logger.warning(f"Unexpected request type: {type(json_rpc_request)}") 105 | raise ValueError(f"Unexpected request type: {type(json_rpc_request)}") 106 | 107 | return self._create_response(result) 108 | 109 | except Exception as e: 110 | return self._handle_exception(e) 111 | 112 | def _handle_exception(self, e: Exception) -> JSONResponse: 113 | if isinstance(e, json.decoder.JSONDecodeError): 114 | json_rpc_error = JSONParseError() 115 | elif isinstance(e, ValidationError): 116 | json_rpc_error = InvalidRequestError(data=json.loads(e.json())) 117 | else: 118 | logger.error(f"Unhandled exception: {e}") 119 | json_rpc_error = InternalError() 120 
| 121 | response = JSONRPCResponse(id=None, error=json_rpc_error) 122 | return JSONResponse(response.model_dump(exclude_none=True), status_code=400) 123 | 124 | def _create_response(self, result: Any) -> Union[JSONResponse, EventSourceResponse]: 125 | if isinstance(result, AsyncIterable): 126 | 127 | async def event_generator( 128 | result: AsyncIterable, 129 | ) -> AsyncIterable[dict[str, str]]: 130 | async for item in result: 131 | yield {"data": item.model_dump_json(exclude_none=True)} 132 | 133 | return EventSourceResponse(event_generator(result)) 134 | elif isinstance(result, JSONRPCResponse): 135 | return JSONResponse(result.model_dump(exclude_none=True)) 136 | else: 137 | logger.error(f"Unexpected result type: {type(result)}") 138 | raise ValueError(f"Unexpected result type: {type(result)}") 139 | -------------------------------------------------------------------------------- /task_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import traceback 4 | from typing import AsyncIterable, Union 5 | 6 | import utils as utils 7 | from abc_task_manager import InMemoryTaskManager 8 | from agent import CurrencyAgent 9 | from custom_types import ( 10 | Artifact, 11 | InternalError, 12 | InvalidParamsError, 13 | JSONRPCResponse, 14 | Message, 15 | PushNotificationConfig, 16 | SendTaskRequest, 17 | SendTaskResponse, 18 | SendTaskStreamingRequest, 19 | SendTaskStreamingResponse, 20 | Task, 21 | TaskArtifactUpdateEvent, 22 | TaskIdParams, 23 | TaskSendParams, 24 | TaskState, 25 | TaskStatus, 26 | TaskStatusUpdateEvent, 27 | TextPart, 28 | ) 29 | from push_notification_auth import PushNotificationSenderAuth 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class AgentTaskManager(InMemoryTaskManager): 35 | def __init__( 36 | self, agent: CurrencyAgent, notification_sender_auth: PushNotificationSenderAuth 37 | ): 38 | super().__init__() 39 | self.agent = agent 40 | 
self.notification_sender_auth = notification_sender_auth 41 | 42 | async def _run_streaming_agent(self, request: SendTaskStreamingRequest): 43 | task_send_params: TaskSendParams = request.params 44 | query = self._get_user_query(task_send_params) 45 | 46 | try: 47 | async for item in self.agent.stream(query, task_send_params.sessionId): 48 | is_task_complete = item["is_task_complete"] 49 | require_user_input = item["require_user_input"] 50 | artifact = None 51 | message = None 52 | parts = [{"type": "text", "text": item["content"]}] 53 | end_stream = False 54 | 55 | if not is_task_complete and not require_user_input: 56 | task_state = TaskState.WORKING 57 | message = Message(role="agent", parts=parts) 58 | elif require_user_input: 59 | task_state = TaskState.INPUT_REQUIRED 60 | message = Message(role="agent", parts=parts) 61 | end_stream = True 62 | else: 63 | task_state = TaskState.COMPLETED 64 | artifact = Artifact(parts=parts, index=0, append=False) 65 | end_stream = True 66 | 67 | task_status = TaskStatus(state=task_state, message=message) 68 | latest_task = await self.update_store( 69 | task_send_params.id, 70 | task_status, 71 | None if artifact is None else [artifact], 72 | ) 73 | await self.send_task_notification(latest_task) 74 | 75 | if artifact: 76 | task_artifact_update_event = TaskArtifactUpdateEvent( 77 | id=task_send_params.id, artifact=artifact 78 | ) 79 | await self.enqueue_events_for_sse( 80 | task_send_params.id, task_artifact_update_event 81 | ) 82 | 83 | task_update_event = TaskStatusUpdateEvent( 84 | id=task_send_params.id, status=task_status, final=end_stream 85 | ) 86 | await self.enqueue_events_for_sse( 87 | task_send_params.id, task_update_event 88 | ) 89 | 90 | except Exception as e: 91 | logger.error(f"An error occurred while streaming the response: {e}") 92 | await self.enqueue_events_for_sse( 93 | task_send_params.id, 94 | InternalError( 95 | message=f"An error occurred while streaming the response: {e}" 96 | ), 97 | ) 98 | 99 | def 
_validate_request( 100 | self, request: Union[SendTaskRequest, SendTaskStreamingRequest] 101 | ) -> JSONRPCResponse | None: 102 | task_send_params: TaskSendParams = request.params 103 | if not utils.are_modalities_compatible( 104 | task_send_params.acceptedOutputModes, CurrencyAgent.SUPPORTED_CONTENT_TYPES 105 | ): 106 | logger.warning( 107 | "Unsupported output mode. Received %s, Support %s", 108 | task_send_params.acceptedOutputModes, 109 | CurrencyAgent.SUPPORTED_CONTENT_TYPES, 110 | ) 111 | return utils.new_incompatible_types_error(request.id) 112 | 113 | if ( 114 | task_send_params.pushNotification 115 | and not task_send_params.pushNotification.url 116 | ): 117 | logger.warning("Push notification URL is missing") 118 | return JSONRPCResponse( 119 | id=request.id, 120 | error=InvalidParamsError(message="Push notification URL is missing"), 121 | ) 122 | 123 | return None 124 | 125 | async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: 126 | """Handles the 'send task' request.""" 127 | validation_error = self._validate_request(request) 128 | if validation_error: 129 | return SendTaskResponse(id=request.id, error=validation_error.error) 130 | 131 | if request.params.pushNotification: 132 | if not await self.set_push_notification_info( 133 | request.params.id, request.params.pushNotification 134 | ): 135 | return SendTaskResponse( 136 | id=request.id, 137 | error=InvalidParamsError( 138 | message="Push notification URL is invalid" 139 | ), 140 | ) 141 | 142 | await self.upsert_task(request.params) 143 | task = await self.update_store( 144 | request.params.id, TaskStatus(state=TaskState.WORKING), None 145 | ) 146 | await self.send_task_notification(task) 147 | 148 | task_send_params: TaskSendParams = request.params 149 | query = self._get_user_query(task_send_params) 150 | try: 151 | agent_response = self.agent.invoke(query, task_send_params.sessionId) 152 | except Exception as e: 153 | logger.error(f"Error invoking agent: {e}") 154 | raise 
ValueError(f"Error invoking agent: {e}") 155 | return await self._process_agent_response(request, agent_response) 156 | 157 | async def on_send_task_subscribe( 158 | self, request: SendTaskStreamingRequest 159 | ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: 160 | try: 161 | error = self._validate_request(request) 162 | if error: 163 | return error 164 | 165 | await self.upsert_task(request.params) 166 | 167 | if request.params.pushNotification: 168 | if not await self.set_push_notification_info( 169 | request.params.id, request.params.pushNotification 170 | ): 171 | return JSONRPCResponse( 172 | id=request.id, 173 | error=InvalidParamsError( 174 | message="Push notification URL is invalid" 175 | ), 176 | ) 177 | 178 | task_send_params: TaskSendParams = request.params 179 | sse_event_queue = await self.setup_sse_consumer(task_send_params.id, False) 180 | 181 | asyncio.create_task(self._run_streaming_agent(request)) 182 | 183 | return self.dequeue_events_for_sse( 184 | request.id, task_send_params.id, sse_event_queue 185 | ) 186 | except Exception as e: 187 | logger.error(f"Error in SSE stream: {e}") 188 | print(traceback.format_exc()) 189 | return JSONRPCResponse( 190 | id=request.id, 191 | error=InternalError( 192 | message="An error occurred while streaming the response" 193 | ), 194 | ) 195 | 196 | async def _process_agent_response( 197 | self, request: SendTaskRequest, agent_response: dict 198 | ) -> SendTaskResponse: 199 | """Processes the agent's response and updates the task store.""" 200 | task_send_params: TaskSendParams = request.params 201 | task_id = task_send_params.id 202 | history_length = task_send_params.historyLength 203 | task_status = None 204 | 205 | parts = [{"type": "text", "text": agent_response["content"]}] 206 | artifact = None 207 | if agent_response["require_user_input"]: 208 | task_status = TaskStatus( 209 | state=TaskState.INPUT_REQUIRED, 210 | message=Message(role="agent", parts=parts), 211 | ) 212 | else: 213 | 
task_status = TaskStatus(state=TaskState.COMPLETED) 214 | artifact = Artifact(parts=parts) 215 | task = await self.update_store( 216 | task_id, task_status, None if artifact is None else [artifact] 217 | ) 218 | task_result = self.append_task_history(task, history_length) 219 | await self.send_task_notification(task) 220 | return SendTaskResponse(id=request.id, result=task_result) 221 | 222 | def _get_user_query(self, task_send_params: TaskSendParams) -> str: 223 | part = task_send_params.message.parts[0] 224 | if not isinstance(part, TextPart): 225 | raise ValueError("Only text parts are supported") 226 | return part.text 227 | 228 | async def send_task_notification(self, task: Task): 229 | if not await self.has_push_notification_info(task.id): 230 | logger.info(f"No push notification info found for task {task.id}") 231 | return 232 | push_info = await self.get_push_notification_info(task.id) 233 | 234 | logger.info(f"Notifying for task {task.id} => {task.status.state}") 235 | await self.notification_sender_auth.send_push_notification( 236 | push_info.url, data=task.model_dump(exclude_none=True) 237 | ) 238 | 239 | async def on_resubscribe_to_task( 240 | self, request 241 | ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: 242 | task_id_params: TaskIdParams = request.params 243 | try: 244 | sse_event_queue = await self.setup_sse_consumer(task_id_params.id, True) 245 | return self.dequeue_events_for_sse( 246 | request.id, task_id_params.id, sse_event_queue 247 | ) 248 | except Exception as e: 249 | logger.error(f"Error while reconnecting to SSE stream: {e}") 250 | return JSONRPCResponse( 251 | id=request.id, 252 | error=InternalError( 253 | message=f"An error occurred while reconnecting to stream: {e}" 254 | ), 255 | ) 256 | 257 | async def set_push_notification_info( 258 | self, task_id: str, push_notification_config: PushNotificationConfig 259 | ): 260 | # Verify the ownership of notification URL by issuing a challenge request. 
261 | is_verified = await self.notification_sender_auth.verify_push_notification_url( 262 | push_notification_config.url 263 | ) 264 | if not is_verified: 265 | return False 266 | 267 | await super().set_push_notification_info(task_id, push_notification_config) 268 | return True 269 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from custom_types import ( 4 | ContentTypeNotSupportedError, 5 | JSONRPCResponse, 6 | UnsupportedOperationError, 7 | ) 8 | 9 | 10 | def are_modalities_compatible( 11 | server_output_modes: List[str], client_output_modes: List[str] 12 | ): 13 | """Modalities are compatible if they are both non-empty 14 | and there is at least one common element.""" 15 | if client_output_modes is None or len(client_output_modes) == 0: 16 | return True 17 | 18 | if server_output_modes is None or len(server_output_modes) == 0: 19 | return True 20 | 21 | return any(x in server_output_modes for x in client_output_modes) 22 | 23 | 24 | def new_incompatible_types_error(request_id): 25 | return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError()) 26 | 27 | 28 | def new_not_implemented_error(request_id): 29 | return JSONRPCResponse(id=request_id, error=UnsupportedOperationError()) 30 | --------------------------------------------------------------------------------