├── tests ├── __init__.py ├── agents │ ├── __init__.py │ ├── test_openai_agent.py │ └── test_anthropic_agent.py ├── mock_anthropic_client.py ├── mock_client.py ├── test_utils.py └── test_workflow.py ├── examples ├── pizza │ ├── __init__.py │ ├── README.md │ ├── run_worker.py │ ├── send_messages.py │ ├── agents.py │ └── functions.py ├── weather │ ├── __init__.py │ ├── README.md │ ├── run_worker.py │ └── run_workflow.py └── mcp_weather │ ├── __init__.py │ ├── README.md │ ├── main.py │ └── mcp_weather_server.py ├── rojak ├── mcp │ ├── __init__.py │ └── mcp_client.py ├── __init__.py ├── types │ ├── __init__.py │ └── types.py ├── utils │ ├── __init__.py │ └── helpers.py ├── retrievers │ ├── __init__.py │ ├── qdrant_retriever.py │ └── retriever.py ├── workflows │ ├── __init__.py │ ├── orchestrator_workflow.py │ └── agent_workflow.py ├── agents │ ├── __init__.py │ ├── openai_agent.py │ ├── anthropic_agent.py │ └── agent.py └── client.py ├── assets └── rojak_diagram.png ├── .vscode └── settings.json ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── test.yml │ ├── lint.yml │ └── publish-to-pypi.yml ├── pyproject.toml ├── .gitignore └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/pizza/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/weather/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/mcp_weather/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rojak/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcp_client import MCPClient 2 | 3 | __all__ = ["MCPClient"] 4 | -------------------------------------------------------------------------------- /assets/rojak_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StreetLamb/rojak/HEAD/assets/rojak_diagram.png -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": ["tests"], 3 | "python.testing.unittestEnabled": false, 4 | "python.testing.pytestEnabled": true 5 | } 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.8.4 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix ] 9 | # Run the formatter. 
Demonstrate how a weather agent can call tools from an MCP server.
In another terminal, run the client: 23 | ```shell 24 | python send_messages.py 25 | ```
4 | 5 | Make sure temporal is running: 6 | ```shell 7 | temporal server start-dev 8 | ``` 9 | 10 | Ensure you have `OPENAI_API_KEY` in the .env file: 11 | ``` 12 | OPENAI_API_KEY= 13 | ``` 14 | 15 | Run the worker: 16 | ```shell 17 | python run_worker.py 18 | ``` 19 | 20 | On another terminal, run script to start session and send a message: 21 | ``` 22 | python run_workflow.py 23 | ``` -------------------------------------------------------------------------------- /rojak/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from .retriever import Retriever, RetrieverActivities 2 | 3 | try: 4 | from .qdrant_retriever import ( # noqa: F401 5 | QdrantRetriever, 6 | QdrantRetrieverActivities, 7 | QdrantRetrieverOptions, 8 | ) 9 | 10 | _QDRANT_AVAILABLE_ = True 11 | except ImportError: 12 | _QDRANT_AVAILABLE_ = False 13 | 14 | __all__ = ["Retriever", "RetrieverActivities"] 15 | 16 | 17 | if _QDRANT_AVAILABLE_: 18 | __all__.extend( 19 | ["QdrantRetriever", "QdrantRetrieverActivities", "QdrantRetrieverOptions"] 20 | ) 21 | -------------------------------------------------------------------------------- /rojak/workflows/__init__.py: -------------------------------------------------------------------------------- 1 | from .orchestrator_workflow import ( 2 | OrchestratorParams, 3 | OrchestratorResponse, 4 | UpdateConfigParams, 5 | OrchestratorWorkflow, 6 | GetConfigResponse, 7 | TaskParams, 8 | ) 9 | from .agent_workflow import ( 10 | AgentWorkflowRunParams, 11 | ToolResponse, 12 | AgentWorkflowResponse, 13 | AgentWorkflow, 14 | AgentTypes, 15 | ) 16 | 17 | __all__ = [ 18 | "OrchestratorParams", 19 | "OrchestratorResponse", 20 | "UpdateConfigParams", 21 | "OrchestratorWorkflow", 22 | "AgentWorkflowRunParams", 23 | "ToolResponse", 24 | "AgentWorkflowResponse", 25 | "AgentWorkflow", 26 | "AgentTypes", 27 | "GetConfigResponse", 28 | "TaskParams", 29 | ] 30 | 
- name: Run lint checks 35 | run: | 36 | poetry run ruff check .
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "rojak" 3 | version = "1.0.0" 4 | description = "Durable and scalable multi-agent orchestration framework." 5 | authors = ["StreetLamb "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.12" 10 | temporalio = "^1.8.0" 11 | openai = {version = "^1.55.3", optional = true} 12 | qdrant-client = {extras = ["fastembed"], version = "^1.12.1", optional = true} 13 | anthropic = {extras = ["bedrock"], version = "^0.42.0", optional = true} 14 | mcp = "^1.2.0" 15 | 16 | [tool.poetry.extras] 17 | openai = ["openai"] 18 | anthropic = ["anthropic"] 19 | qdrant-client = ["qdrant-client"] 20 | 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | ruff = "^0.8.4" 24 | pytest = "^8.3.4" 25 | pytest-asyncio = "^0.25.0" 26 | pre-commit = "^4.0.1" 27 | 28 | [build-system] 29 | requires = ["poetry-core"] 30 | build-backend = "poetry.core.masonry.api" 31 | -------------------------------------------------------------------------------- /examples/weather/run_worker.py: -------------------------------------------------------------------------------- 1 | from temporalio.client import Client 2 | import asyncio 3 | from rojak import Rojak 4 | from rojak.agents import ( 5 | OpenAIAgentActivities, 6 | OpenAIAgentOptions, 7 | AgentExecuteFnResult, 8 | ) 9 | import json 10 | 11 | 12 | def get_weather(location: str, time="now"): 13 | """Get the current weather in a given location. 
Location MUST be a city.""" 14 | return json.dumps({"location": location, "temperature": "65", "time": time}) 15 | 16 | 17 | def send_email(recipient: str, subject: str, body: str, context_variables: dict): 18 | """Send an email to a recipient.""" 19 | print("Sending email...") 20 | print(f"To: {recipient}") 21 | print(f"Subject: {subject}") 22 | print(f"Body: {body}") 23 | context_variables["status"] = "sent" 24 | return AgentExecuteFnResult( 25 | agent=None, context_variables=context_variables, output="Sent!" 26 | ) 27 | 28 | 29 | async def main(): 30 | temporal_client: Client = await Client.connect("localhost:7233") 31 | 32 | openai_activities = OpenAIAgentActivities( 33 | OpenAIAgentOptions( 34 | all_functions=[get_weather, send_email], 35 | ) 36 | ) 37 | 38 | rojak = Rojak(temporal_client, task_queue="weather-tasks") 39 | worker = await rojak.create_worker(agent_activities=[openai_activities]) 40 | await worker.run() 41 | 42 | 43 | if __name__ == "__main__": 44 | asyncio.run(main()) 45 | -------------------------------------------------------------------------------- /rojak/retrievers/qdrant_retriever.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from rojak.retrievers import RetrieverActivities, Retriever 3 | from temporalio import activity 4 | from qdrant_client import QdrantClient 5 | from qdrant_client.qdrant_fastembed import QueryResponse 6 | 7 | 8 | @dataclass 9 | class QdrantRetrieverOptions: 10 | url: str 11 | collection_name: str 12 | 13 | 14 | @dataclass 15 | class QdrantRetriever(Retriever): 16 | type: str = field(init=False, default="qdrant") 17 | 18 | 19 | class QdrantRetrieverActivities(RetrieverActivities): 20 | def __init__(self, options: QdrantRetrieverOptions): 21 | super().__init__(options) 22 | self.client = QdrantClient(url=options.url) 23 | self.collection_name = options.collection_name 24 | 25 | async def retrieve(self, text: str) -> 
list[QueryResponse]: 26 | retrieval_results = self.client.query( 27 | collection_name=self.collection_name, query_text=text 28 | ) 29 | return retrieval_results 30 | 31 | @activity.defn(name="qdrant_retrieve_and_combine_results") 32 | async def retrieve_and_combine_results(self, text: str) -> str: 33 | retrieval_results = await self.retrieve(text) 34 | return self.combine_retreival_results(retrieval_results) 35 | 36 | @staticmethod 37 | def combine_retreival_results(results: list[QueryResponse]) -> str: 38 | return "\n".join([result.document for result in results]) 39 | -------------------------------------------------------------------------------- /rojak/mcp/mcp_client.py: -------------------------------------------------------------------------------- 1 | from mcp import ClientSession, StdioServerParameters 2 | from contextlib import AsyncExitStack 3 | from mcp.client.stdio import stdio_client 4 | from mcp.client.sse import sse_client 5 | from rojak.types import MCPServerConfig 6 | 7 | 8 | class MCPClient: 9 | def __init__(self): 10 | self.session: ClientSession | None = None 11 | self.exit_stack = AsyncExitStack() 12 | 13 | async def connect_to_server(self, config: MCPServerConfig): 14 | """Connect to an MCP server""" 15 | 16 | if config.type == "stdio": 17 | params = StdioServerParameters(command=config.command, args=config.args) 18 | stdio_transport = await self.exit_stack.enter_async_context( 19 | stdio_client(params) 20 | ) 21 | self.stdio, self.write = stdio_transport 22 | self.session = await self.exit_stack.enter_async_context( 23 | ClientSession(self.stdio, self.write) 24 | ) 25 | else: 26 | stdio_transport = await self.exit_stack.enter_async_context( 27 | sse_client(config.url) 28 | ) 29 | self.sse, self.write = stdio_transport 30 | self.session = await self.exit_stack.enter_async_context( 31 | ClientSession(self.sse, self.write) 32 | ) 33 | await self.session.initialize() 34 | 35 | async def cleanup(self): 36 | """Clean up resources""" 37 | await 
self.exit_stack.aclose() 38 | -------------------------------------------------------------------------------- /examples/weather/run_workflow.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from temporalio import client 3 | from rojak.agents import OpenAIAgent 4 | from rojak.types import RetryOptions, RetryPolicy 5 | from rojak import Rojak 6 | from rojak.workflows import OrchestratorResponse, TaskParams 7 | 8 | 9 | async def main() -> None: 10 | temporal_client = await client.Client.connect("localhost:7233") 11 | rojak = Rojak(temporal_client, task_queue="weather-tasks") 12 | 13 | weather_agent = OpenAIAgent( 14 | name="Weather Assistant", 15 | instructions="Help provide the weather forecast.", 16 | functions=["get_weather", "send_email"], 17 | retry_options=RetryOptions( 18 | retry_policy=RetryPolicy(maximum_attempts=5), 19 | timeout_in_seconds=20, 20 | ), 21 | parallel_tool_calls=False, 22 | ) 23 | 24 | response = await rojak.run( 25 | "weather-session", 26 | task=TaskParams( 27 | agent=weather_agent, 28 | messages=[ 29 | { 30 | "role": "user", 31 | "content": "What is the weather like in Malaysia and Singapore? 
Send an email to john@example.com", 32 | } 33 | ], 34 | ), 35 | max_turns=30, 36 | debug=True, 37 | type="persistent", 38 | ) 39 | 40 | assert isinstance(response.result, OrchestratorResponse) 41 | 42 | print(response.result.messages[-1].content) 43 | 44 | 45 | if __name__ == "__main__": 46 | asyncio.run(main()) 47 | -------------------------------------------------------------------------------- /examples/pizza/run_worker.py: -------------------------------------------------------------------------------- 1 | # main.py 2 | from temporalio.client import Client 3 | from rojak import Rojak 4 | from rojak.agents import OpenAIAgentActivities, OpenAIAgentOptions 5 | import asyncio 6 | from examples.pizza.functions import ( 7 | to_food_order, 8 | to_payment, 9 | to_feedback, 10 | to_triage, 11 | get_menu, 12 | add_to_cart, 13 | remove_from_cart, 14 | get_cart, 15 | process_payment, 16 | get_receipt, 17 | provide_feedback, 18 | food_ordering_instructions, 19 | ) 20 | 21 | 22 | async def main(): 23 | # Create client connected to server at the given address 24 | client = await Client.connect("localhost:7233") 25 | 26 | openai_activities = OpenAIAgentActivities( 27 | OpenAIAgentOptions( 28 | all_functions=[ 29 | to_food_order, 30 | to_payment, 31 | to_feedback, 32 | to_triage, 33 | get_menu, 34 | add_to_cart, 35 | remove_from_cart, 36 | get_cart, 37 | process_payment, 38 | get_receipt, 39 | provide_feedback, 40 | food_ordering_instructions, 41 | ] 42 | ) 43 | ) 44 | 45 | rojak = Rojak(client, task_queue="tasks") 46 | worker = await rojak.create_worker([openai_activities]) 47 | await worker.run() 48 | 49 | 50 | if __name__ == "__main__": 51 | print("Starting worker") 52 | print("Then run 'python send_messages.py' to start sending messages.") 53 | 54 | asyncio.run(main()) 55 | -------------------------------------------------------------------------------- /rojak/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.agent import ( 2 | Agent, 3 | AgentActivities, 4 | AgentCallParams, 5 | AgentExecuteFnResult, 6 | AgentInstructionOptions, 7 | AgentResponse, 8 | AgentOptions, 9 | AgentToolCall, 10 | ExecuteFunctionParams, 11 | ExecuteInstructionsParams, 12 | Interrupt, 13 | ResumeRequest, 14 | ResumeResponse, 15 | ) 16 | 17 | try: 18 | from .openai_agent import OpenAIAgent, OpenAIAgentOptions, OpenAIAgentActivities # noqa: F401 19 | 20 | _OPENAI_AVAILABLE_ = True 21 | except ImportError: 22 | _OPENAI_AVAILABLE_ = False 23 | 24 | try: 25 | from .anthropic_agent import ( # noqa: F401 26 | AnthropicAgent, 27 | AnthropicAgentOptions, 28 | AnthropicAgentActivities, 29 | ) 30 | 31 | _ANTHROPIC_AVAILABLE_ = True 32 | except ImportError: 33 | _ANTHROPIC_AVAILABLE_ = False 34 | 35 | __all__ = [ 36 | "Agent", 37 | "AgentActivities", 38 | "AgentOptions", 39 | "AgentCallParams", 40 | "AgentExecuteFnResult", 41 | "AgentInstructionOptions", 42 | "AgentResponse", 43 | "AgentToolCall", 44 | "ExecuteFunctionParams", 45 | "ExecuteInstructionsParams", 46 | "Interrupt", 47 | "ResumeRequest", 48 | "ResumeResponse", 49 | ] 50 | 51 | if _OPENAI_AVAILABLE_: 52 | __all__.extend( 53 | [ 54 | "OpenAIAgent", 55 | "OpenAIAgentOptions", 56 | "OpenAIAgentActivities", 57 | ] 58 | ) 59 | 60 | if _ANTHROPIC_AVAILABLE_: 61 | __all__.extend( 62 | [ 63 | "AnthropicAgent", 64 | "AnthropicAgentOptions", 65 | "AnthropicAgentActivities", 66 | ] 67 | ) 68 | -------------------------------------------------------------------------------- /rojak/retrievers/retriever.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class Retriever(ABC): 7 | type: str 8 | """The prefix of the activity name.""" 9 | 10 | 11 | class RetrieverActivities(ABC): 12 | """ 13 | Abstract base class for Retriever implementations. 14 | This class provides a common structure for different types of retrievers. 
text (str): The input text to guide the retrieval process.
48 | """ 49 | pass 50 | -------------------------------------------------------------------------------- /examples/mcp_weather/main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import uuid 3 | from temporalio.client import Client 4 | from rojak.agents import OpenAIAgent, OpenAIAgentActivities 5 | from rojak.client import Rojak 6 | from rojak.types import MCPServerConfig, RetryOptions, RetryPolicy 7 | from rojak.workflows import TaskParams, OrchestratorResponse 8 | 9 | 10 | async def main(): 11 | client = await Client.connect("localhost:7233") 12 | rojak = Rojak(client=client, task_queue="tasks") 13 | 14 | agent = OpenAIAgent( 15 | name="Weather Agent", 16 | retry_options=RetryOptions(retry_policy=RetryPolicy(maximum_attempts=5)), 17 | ) 18 | openai_activities = OpenAIAgentActivities() 19 | worker = await rojak.create_worker( 20 | [openai_activities], 21 | mcp_servers={ 22 | "weather": MCPServerConfig("stdio", "python", ["mcp_weather_server.py"]) 23 | }, 24 | ) 25 | try: 26 | async with worker: 27 | response = await rojak.run( 28 | id=str(uuid.uuid4()), 29 | task=TaskParams( 30 | agent=agent, 31 | messages=[ 32 | { 33 | "role": "user", 34 | "content": "Weather like in San Francisco?", 35 | } 36 | ], 37 | ), 38 | type="stateless", 39 | debug=True, 40 | ) 41 | 42 | assert isinstance(response.result, OrchestratorResponse) 43 | 44 | print(response.result.messages[-1].content) 45 | finally: 46 | await rojak.cleanup_mcp() 47 | 48 | 49 | if __name__ == "__main__": 50 | asyncio.run(main()) 51 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package to PyPI when a release is created 2 | # For more information see: 
https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | release-build: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: "3.x" 28 | 29 | - name: Install Poetry 30 | run: | 31 | curl -sSL https://install.python-poetry.org | python - -y 32 | 33 | - name: Build release distributions 34 | run: | 35 | poetry install --all-extras --no-interaction 36 | poetry build 37 | 38 | - name: Upload distributions 39 | uses: actions/upload-artifact@v4 40 | with: 41 | name: release-dists 42 | path: dist/ 43 | 44 | pypi-publish: 45 | runs-on: ubuntu-latest 46 | needs: 47 | - release-build 48 | permissions: 49 | id-token: write 50 | 51 | environment: 52 | name: pypi 53 | url: https://pypi.org/p/rojak 54 | 55 | steps: 56 | - name: Retrieve release distributions 57 | uses: actions/download-artifact@v4 58 | with: 59 | name: release-dists 60 | path: dist/ 61 | 62 | - name: Publish release distributions to PyPI 63 | uses: pypa/gh-action-pypi-publish@release/v1 64 | with: 65 | packages-dir: dist/ 66 | -------------------------------------------------------------------------------- /tests/mock_anthropic_client.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | from anthropic.types import Message, TextBlock, ToolUseBlock, Usage 3 | 4 | 5 | def create_mock_response(message, function_calls=[], model="gpt-4o"): 6 | role = message.get("role", "assistant") 7 | content = message.get("content", "") 
8 | 9 | tool_calls = [ 10 | ToolUseBlock( 11 | type="tool_use", 12 | id="mock_tc_id", 13 | name=call.get("name", ""), 14 | input=call.get("args", {}), 15 | ).model_dump() 16 | for call in function_calls 17 | ] 18 | 19 | return Message( 20 | id="mock_cc_id", 21 | role=role, 22 | model=model, 23 | content=[TextBlock(text=content, type="text"), *tool_calls], 24 | type="message", 25 | usage=Usage( 26 | cache_creation_input_tokens=0, 27 | cache_read_input_tokens=0, 28 | input_tokens=0, 29 | output_tokens=0, 30 | ), 31 | ) 32 | 33 | 34 | class MockAnthropicClient: 35 | def __init__(self): 36 | self.messages = MagicMock() 37 | 38 | def set_response(self, response: Message): 39 | """ 40 | Set the mock to return a specific response. 41 | :param response: A ChatCompletion response to return. 42 | """ 43 | self.messages.create.return_value = response 44 | 45 | def set_sequential_responses(self, responses: list[Message]): 46 | """ 47 | Set the mock to return different responses sequentially. 48 | :param responses: A list of ChatCompletion responses to return in order. 
# This should return the first mock response 76 | first_response = client.messages.create() 77 | print(first_response)  # Outputs: role='assistant' content='First response' 78 | 79 | # This should return the second mock response 80 | second_response = client.messages.create() 81 | print(second_response)  # Outputs: role='assistant' content='Second'
tool_id = messages[-1].tool_calls[-1]["id"] 31 | tool_name = messages[-1].tool_calls[-1]["name"] 32 | except Exception: 33 | pass 34 | 35 | while True: 36 | if state == "Response": 37 | prompt = input("Enter your message (or 'exit' to quit): ") 38 | else: 39 | prompt = input( 40 | f"Resume '{tool_name}'? Enter 'approve' or state why you reject: " 41 | ) 42 | 43 | if prompt.lower() == "exit": 44 | break 45 | 46 | if state == "Response": 47 | response = await rojak.run( 48 | id=SESSION_ID, 49 | type="persistent", 50 | task=TaskParams( 51 | messages=[{"role": "user", "content": prompt}], agent=agent 52 | ), 53 | context_variables={ 54 | "name": "John", 55 | "cart": {}, 56 | "preferences": "Loves healthy food, allergic to nuts.", 57 | }, 58 | debug=True, 59 | ) 60 | else: 61 | if prompt == "approve": 62 | resume = ResumeResponse(action=prompt, tool_id=tool_id) 63 | else: 64 | resume = ResumeResponse( 65 | action="reject", tool_id=tool_id, content=prompt 66 | ) 67 | 68 | response = await rojak.run( 69 | SESSION_ID, 70 | resume=resume, 71 | ) 72 | 73 | if isinstance(response.result, OrchestratorResponse): 74 | state = "Response" 75 | agent = response.result.agent 76 | print(response.result.messages[-1].content) 77 | 78 | elif isinstance(response.result, ResumeRequest): 79 | print(response.result) 80 | state = "Resume" 81 | tool_id = response.result.tool_id 82 | tool_name = response.result.tool_name 83 | else: 84 | print(response) 85 | 86 | 87 | if __name__ == "__main__": 88 | asyncio.run(main()) 89 | -------------------------------------------------------------------------------- /tests/mock_client.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall 3 | from openai.types.chat.chat_completion_message_tool_call import Function 4 | from openai.types.chat.chat_completion import ChatCompletion, Choice 5 | import json 6 
def create_mock_response(message, function_calls=None, model="gpt-4o"):
    """Build a mock ChatCompletion whose single choice echoes *message*.

    Args:
        message: Dict with optional "role" (default "assistant") and
            "content" (default "") for the mocked assistant message.
        function_calls: Optional list of {"name": ..., "args": ...} dicts
            rendered as tool calls on the mocked message.
        model: Model name recorded on the completion.

    Returns:
        A fully-populated ChatCompletion suitable for mocking
        `client.chat.completions.create`.
    """
    # The previous signature used a mutable default (`function_calls=[]`);
    # `None` avoids the shared-mutable-default pitfall, and the falsy check
    # below keeps behaviour identical for both [] and None.
    role = message.get("role", "assistant")
    content = message.get("content", "")
    tool_calls = (
        [
            ChatCompletionMessageToolCall(
                id="mock_tc_id",
                type="function",
                function=Function(
                    name=call.get("name", ""),
                    arguments=json.dumps(call.get("args", {})),
                ),
            )
            for call in function_calls
        ]
        if function_calls
        else None
    )

    return ChatCompletion(
        id="mock_cc_id",
        created=1234567890,
        model=model,
        object="chat.completion",
        choices=[
            Choice(
                message=ChatCompletionMessage(
                    role=role, content=content, tool_calls=tool_calls
                ),
                finish_reason="stop",
                index=0,
            )
        ],
    )


class MockOpenAIClient:
    """Stand-in for `openai.OpenAI` exposing a mocked chat-completions API."""

    def __init__(self):
        self.chat = MagicMock()
        self.chat.completions = MagicMock()

    def set_response(self, response: ChatCompletion):
        """
        Set the mock to return a specific response.
        :param response: A ChatCompletion response to return.
        """
        self.chat.completions.create.return_value = response

    def set_sequential_responses(self, responses: list[ChatCompletion]):
        """
        Set the mock to return different responses sequentially.
        :param responses: A list of ChatCompletion responses to return in order.
        """
        self.chat.completions.create.side_effect = responses

    def assert_create_called_with(self, **kwargs):
        # Delegates to MagicMock's assertion on the most recent call.
        self.chat.completions.create.assert_called_with(**kwargs)


# Initialize the mock client
client = MockOpenAIClient()

# Set a sequence of mock responses
client.set_sequential_responses(
    [
        create_mock_response(
            {"role": "assistant", "content": "First response"},
            [
                {
                    "name": "process_refund",
                    "args": {"item_id": "item_123", "reason": "too expensive"},
                }
            ],
        ),
        create_mock_response({"role": "assistant", "content": "Second"}),
    ]
)

# This should return the first mock response
first_response = client.chat.completions.create()
print(
    first_response.choices[0].message
)  # Outputs: role='assistant' content='First response'

# This should return the second mock response
second_response = client.chat.completions.create()
print(
    second_response.choices[0].message
)  # Outputs: role='assistant' content='Second'
retry_options = RetryOptions(
    retry_policy=RetryPolicy(non_retryable_error_types=["TypeError"])
)

# Handoff function names exposed by the triage agent; each one is also
# wrapped in an Interrupt so a human can approve or reject the transfer
# before it executes.
_TRIAGE_HANDOFFS = ["to_food_order", "to_payment", "to_feedback"]

triage_agent = OpenAIAgent(
    name="Triage Agent",
    instructions="""
    You are the Triage Agent. Your role is to assist customers by identifying their needs and routing them to the correct agent:
    - **Food Ordering** (`to_food_order`): For menu recommendations, adding/removing items, viewing or modifying the cart.
    - **Payment** (`to_payment`): For payments, payment method queries, receipts, or payment issues.
    - **Feedback** (`to_feedback`): For reviews, ratings, comments, or complaints.
    If unsure, guide customers by explaining options (ordering, payment, feedback). For multi-step needs, start with the immediate priority and redirect after.
    Always ensure clear, polite, and accurate communication during handoffs.
    """,
    functions=list(_TRIAGE_HANDOFFS),
    tool_choice="required",
    retry_options=retry_options,
    interrupts=[Interrupt(handoff) for handoff in _TRIAGE_HANDOFFS],
)

# This agent's instructions are produced dynamically by the
# `food_ordering_instructions` function so the customer's preferences can be
# injected at call time.
food_order_agent = OpenAIAgent(
    name="Food Ordering Agent",
    instructions={"type": "function", "name": "food_ordering_instructions"},
    functions=[
        "to_triage",
        "add_to_cart",
        "remove_from_cart",
        "get_cart",
        "get_menu",
    ],
    retry_options=retry_options,
)

payment_agent = OpenAIAgent(
    name="Payment Agent",
    instructions="""
    You are the Payment Agent. Your role is to securely and efficiently handle the payment process for customers.
    Start by confirming the payment amount and presenting the available payment methods (e.g., credit card, mobile wallet, or cash).
    Use the `process_payment` function to finalize the transaction and provide a receipt.
    In case of a failed transaction, assist customers by suggesting alternative payment options or troubleshooting the issue.
    Always maintain a courteous and professional tone, ensuring customers feel supported throughout the payment process.
    Redirect customers to the Triage Agent if they need help with non-payment tasks.
    """,
    functions=[
        "to_triage",
        "process_payment",
        "get_receipt",
        "get_cart",
    ],
    retry_options=retry_options,
)

feedback_agent = OpenAIAgent(
    name="Feedback Agent",
    instructions="""
    You are the Feedback Agent. Your role is to collect and manage customer feedback to improve the overall experience.
    Ask customers to rate their experience and provide detailed comments about the food, service, or satisfaction level.
    Use the `provide_feedback` function to log their input.
    If the customer has complaints, acknowledge them empathetically and offer to escalate the issue to the relevant team.
    Encourage constructive feedback and thank customers for their time and insights.
    Redirect customers to the Triage Agent if they wish to engage with other services after providing feedback.
    """,
    functions=["to_triage", "provide_feedback"],
    retry_options=retry_options,
)
from rojak.utils import function_to_json, function_to_json_anthropic

# Shared expectation for the annotated fixture function below: typed
# parameters map to JSON-schema types, defaulted ones are not "required".
_COMPLEX_PROPERTIES = {
    "arg1": {"type": "integer"},
    "arg2": {"type": "string"},
    "arg3": {"type": "number"},
    "arg4": {"type": "boolean"},
}


def _make_complex_function():
    """Return the annotated fixture function used by the complex tests."""

    def complex_function_with_types_and_descriptions(
        arg1: int, arg2: str, arg3: float = 3.14, arg4: bool = False
    ):
        """This is a complex function with a docstring."""
        pass

    return complex_function_with_types_and_descriptions


def test_basic_openai_function():
    """An unannotated function maps every parameter to a required string."""

    def basic_function(arg1, arg2):
        return arg1 + arg2

    expected = {
        "type": "function",
        "function": {
            "name": "basic_function",
            "description": "",
            "parameters": {
                "type": "object",
                "properties": {
                    "arg1": {"type": "string"},
                    "arg2": {"type": "string"},
                },
                "required": ["arg1", "arg2"],
            },
        },
    }
    assert function_to_json(basic_function) == expected


def test_basic_anthropic_function():
    """The Anthropic format nests the same parameter map under input_schema."""

    def basic_function(arg1, arg2):
        return arg1 + arg2

    expected = {
        "name": "basic_function",
        "description": "",
        "input_schema": {
            "type": "object",
            "properties": {
                "arg1": {"type": "string"},
                "arg2": {"type": "string"},
            },
            "required": ["arg1", "arg2"],
        },
    }
    assert function_to_json_anthropic(basic_function) == expected


def test_complex_function():
    """Annotations and docstring flow into the OpenAI tool definition."""
    fn = _make_complex_function()
    expected = {
        "type": "function",
        "function": {
            "name": "complex_function_with_types_and_descriptions",
            "description": "This is a complex function with a docstring.",
            "parameters": {
                "type": "object",
                "properties": _COMPLEX_PROPERTIES,
                "required": ["arg1", "arg2"],
            },
        },
    }
    assert function_to_json(fn) == expected


def test_complex_anthropic_function():
    """Annotations and docstring flow into the Anthropic tool definition."""
    fn = _make_complex_function()
    expected = {
        "name": "complex_function_with_types_and_descriptions",
        "description": "This is a complex function with a docstring.",
        "input_schema": {
            "type": "object",
            "properties": _COMPLEX_PROPERTIES,
            "required": ["arg1", "arg2"],
        },
    }
    assert function_to_json_anthropic(fn) == expected
| share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /rojak/types/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Literal, Sequence 3 | from typing import TYPE_CHECKING 4 | 5 | if TYPE_CHECKING: 6 | from rojak.mcp.mcp_client import MCPClient 7 | from mcp import Tool 8 | 9 | 10 | @dataclass 11 | class ConversationMessage: 12 | role: Literal["user", "assistant", "tool", "system"] 13 | """The role of the messages author.""" 14 | 15 | content: str | None = None 16 | """The contents of the assistant message. Required unless tool_calls is specified.""" 17 | 18 | tool_calls: list[Any] | None = None 19 | """The tool calls generated by the model, such as function calls.""" 20 | 21 | tool_call_id: str | None = None 22 | """Unique identifier of the tool call.""" 23 | 24 | sender: str | None = None 25 | """Indicate which agent the message originated from.""" 26 | 27 | 28 | ContextVariables = dict[str, Any] 29 | 30 | 31 | @dataclass 32 | class RetryPolicy: 33 | """Options for retrying agent activities.""" 34 | 35 | initial_interval_in_seconds: int = 1 36 | """Backoff interval for the first retry. Default 1s.""" 37 | 38 | backoff_coefficient: float = 2.0 39 | """Coefficient to multiply previous backoff interval by to get new 40 | interval. Default 2.0. 41 | """ 42 | 43 | maximum_interval_in_seconds: int | None = None 44 | """Maximum backoff interval between retries. Default 100x 45 | :py:attr:`initial_interval`. 46 | """ 47 | 48 | maximum_attempts: int = 0 49 | """Maximum number of attempts. 50 | 51 | If 0, the default, there is no maximum. 
52 | """ 53 | 54 | non_retryable_error_types: Sequence[str] | None = None 55 | """List of error types that are not retryable.""" 56 | 57 | def __post_init__(self): 58 | # Validation taken from Go SDK's test suite 59 | if self.maximum_attempts == 1: 60 | # Ignore other validation if disabling retries 61 | return 62 | if self.initial_interval_in_seconds < 0: 63 | raise ValueError("Initial interval cannot be negative") 64 | if self.backoff_coefficient < 1: 65 | raise ValueError("Backoff coefficient cannot be less than 1") 66 | if self.maximum_interval_in_seconds: 67 | if self.maximum_interval_in_seconds < 0: 68 | raise ValueError("Maximum interval cannot be negative") 69 | if self.maximum_interval_in_seconds < self.initial_interval_in_seconds: 70 | raise ValueError( 71 | "Maximum interval cannot be less than initial interval" 72 | ) 73 | if self.maximum_attempts < 0: 74 | raise ValueError("Maximum attempts cannot be negative") 75 | 76 | 77 | @dataclass 78 | class RetryOptions: 79 | timeout_in_seconds: int = 60 80 | """Maximum time allowed for an agent to complete its tasks.""" 81 | 82 | retry_policy: RetryPolicy | None = None 83 | """Options for retrying agent activities.""" 84 | 85 | def __post_init__(self): 86 | if self.timeout_in_seconds < 1: 87 | raise ValueError("Timeout cannot be less than one second") 88 | 89 | 90 | @dataclass 91 | class MCPServerConfig: 92 | """Configuration options for the MCP server""" 93 | 94 | type: Literal["sse", "stdio"] 95 | """Connection type to MCP server.""" 96 | 97 | command: str | None = None 98 | """(For `stdio` type) The command or executable to run to start the MCP server.""" 99 | 100 | args: list[str] | None = None 101 | """(For `stdio` type) Command line arguments to pass to the `command`.""" 102 | 103 | url: str | None = None 104 | """(For `websocket` or `sse` type) The URL to connect to the MCP server.""" 105 | 106 | 107 | @dataclass 108 | class InitMcpResult: 109 | """Result from initialising MCP servers""" 110 | 111 | 
from dataclasses import dataclass, field
from typing import Any, Literal
from openai import NotGiven, OpenAI
from temporalio import activity
from rojak.utils import function_to_json, mcp_to_openai_tool
from rojak.agents import (
    Agent,
    AgentActivities,
    AgentCallParams,
    AgentExecuteFnResult,
    AgentResponse,
    AgentOptions,
    AgentToolCall,
    ExecuteFunctionParams,
    ExecuteInstructionsParams,
)
from openai.types.chat import ChatCompletion
import os


@dataclass
class OpenAIAgentOptions(AgentOptions):
    """Configuration for constructing `OpenAIAgentActivities`."""

    api_key: str | None = None
    client: OpenAI | None = None
    base_url: str | None = None
    inference_config: dict[str, Any] = field(
        default_factory=lambda: {
            "max_tokens": 1000,
            "temperature": 0.0,
            "top_p": 0.9,
            "stop_sequences": [],
        }
    )


@dataclass
class OpenAIAgent(Agent):
    model: str = "gpt-4o-mini"

    type: Literal["openai"] = field(default="openai")
    """Type of agent. Must be `"openai"`."""

    inference_config: dict[str, Any] | None = None
    """Inference configuration for OpenAI models"""


class OpenAIAgentActivities(AgentActivities):
    """Temporal activity implementations backed by the OpenAI chat API."""

    def __init__(self, options: OpenAIAgentOptions | None = None):
        # A fresh default is built per instance: the previous signature
        # (`options: OpenAIAgentOptions = OpenAIAgentOptions()`) evaluated the
        # default once at import time, so all default-constructed workers
        # shared one mutable inference_config dict.
        if options is None:
            options = OpenAIAgentOptions()
        super().__init__(options)

        # Client resolution order: explicit client > explicit api_key >
        # OPENAI_API_KEY environment variable.
        if options.client:
            self.client = options.client
        elif options.api_key:
            self.client = OpenAI(api_key=options.api_key, base_url=options.base_url)
        elif os.environ.get("OPENAI_API_KEY"):
            self.client = OpenAI(
                api_key=os.environ.get("OPENAI_API_KEY"), base_url=options.base_url
            )
        else:
            raise ValueError("OpenAI API key is required")
        self.inference_config = options.inference_config

    @staticmethod
    def handle_model_response(response: ChatCompletion) -> AgentResponse:
        """Convert model response to AgentResponse.

        Raises:
            ValueError: if the message carries neither tool calls nor content.
        """
        message = response.choices[0].message
        if message.tool_calls:
            tool_calls = [
                AgentToolCall(**dict(tool_call)) for tool_call in message.tool_calls
            ]
            return AgentResponse(
                content=message.content,
                tool_calls=tool_calls,
                type="tool",
            )
        elif message.content:
            return AgentResponse(content=message.content, type="text")
        else:
            raise ValueError("Unknown message type")

    @activity.defn(name="openai_call")
    async def call(self, params: AgentCallParams) -> AgentResponse:
        """Invoke the chat-completions endpoint with tools and config."""
        # Create list of messages. NOTE(review): tool_calls/tool_call_id are
        # forwarded even when None — confirm the API tolerates null fields.
        messages = [
            {
                "role": msg.role,
                "content": msg.content,
                "tool_calls": msg.tool_calls,
                "tool_call_id": msg.tool_call_id,
            }
            for msg in params.messages
        ]

        # Merge per-call overrides into a LOCAL copy. Previously the merge was
        # assigned back to self.inference_config, so one request's overrides
        # leaked into every later request handled by this long-lived activity
        # instance.
        inference_config = self.inference_config
        if params.inference_config:
            inference_config = {**self.inference_config, **params.inference_config}

        # Build tool schemas for local functions, dropping the injected
        # `context_variables` parameter — it is supplied by the runtime, not
        # chosen by the model.
        functions = [
            self.function_map[name]
            for name in params.function_names
            if name in self.function_map
        ]
        tools = [function_to_json(f) for f in functions]
        for tool in tools:
            fn_params = tool["function"]["parameters"]
            fn_params["properties"].pop("context_variables", None)
            if "context_variables" in fn_params["required"]:
                fn_params["required"].remove("context_variables")

        # MCP-registered tools are appended alongside local functions.
        tools += [mcp_to_openai_tool(tool) for tool in self.mcp_result.tools.values()]

        response = self.client.chat.completions.create(
            model=params.model,
            messages=messages,
            tools=tools or None,
            tool_choice=params.tool_choice,
            parallel_tool_calls=params.parallel_tool_calls if tools else NotGiven(),
            max_tokens=inference_config["max_tokens"],
            temperature=inference_config["temperature"],
            top_p=inference_config["top_p"],
            stop=inference_config["stop_sequences"],
        )

        return self.handle_model_response(response)

    @activity.defn(name="openai_execute_instructions")
    async def execute_instructions(self, params: ExecuteInstructionsParams) -> str:
        # Delegates to the provider-agnostic base implementation.
        return await super().execute_instructions(params)

    @activity.defn(name="openai_execute_function")
    async def execute_function(
        self, params: ExecuteFunctionParams
    ) -> str | OpenAIAgent | AgentExecuteFnResult:
        # Delegates to the provider-agnostic base implementation.
        return await super().execute_function(params)
def food_ordering_instructions(context_variables):
    """Render the Food Ordering Agent's system prompt.

    The prompt embeds the customer's stored preferences (from
    `context_variables`) so recommendations can be personalised.
    """
    return f"""
    You are the Food Ordering Agent. Your role is to assist customers in selecting and managing their food orders.
    - **Recommendations**: Start by retrieving the available menu using the `get_menu` function. Guide customers in choosing items based on their preferences or suggest popular options. Provide clear and concise descriptions of menu items.
    - **Adding to Order**: When customers request to add items, confirm their selection against the menu. Use the `get_menu` function to ensure the item is available. If the item exists, proceed to add it using the `add_to_cart` function. For unavailable items, politely inform the customer and suggest alternatives from the menu.
    - **Order Management**: Customers may also want to modify their order. Use the `get_cart` function to show the current order details and confirm requested changes. Use `remove_from_cart` to delete specific items as needed and provide an updated summary of the cart.
    - **Guiding and Redirecting**: If customers are unsure of what they want, offer assistance by highlighting popular dishes or providing recommendations. Redirect customers to the Triage Agent if they need help beyond ordering or recommendations.
    - **Tone and Accuracy**: Maintain a polite and professional tone, ensuring order details are accurate at every step. Always summarize the current state of the cart after any action.
    Redirect customers to the Triage Agent if they need help with non-ordering tasks.
    Remember: Customers can only order items listed in the menu. Validate all item requests before proceeding with any order changes.
    Customer Preferences: {context_variables.get("preferences", "N/A")}
    """


def to_food_order():
    """Route to the food order agent."""
    return food_order_agent


def to_payment():
    """Route to the payment agent."""
    return payment_agent


def to_feedback():
    """Route to the feedback agent."""
    return feedback_agent


def to_triage():
    """Route to the triage agent."""
    return triage_agent


def get_menu():
    """Return a list of menu items for a pizza place."""
    menu = [
        {"name": "Margherita Pizza", "cost": 12.99},
        {"name": "Pepperoni Pizza", "cost": 14.49},
        {"name": "BBQ Chicken Pizza", "cost": 15.99},
        {"name": "Veggie Supreme Pizza", "cost": 13.99},
        {"name": "Meat Lovers Pizza", "cost": 16.99},
        {"name": "Garlic Knots", "cost": 5.99},
        {"name": "Cheesy Breadsticks", "cost": 6.49},
        {"name": "Classic Caesar Salad", "cost": 9.49},
        {"name": "Buffalo Wings (8 pieces)", "cost": 10.99},
        {"name": "Mozzarella Sticks", "cost": 7.99},
        {"name": "Chocolate Chip Cannoli", "cost": 4.99},
        {"name": "Tiramisu Slice", "cost": 5.99},
        {"name": "Fountain Soda", "cost": 2.49},
        {"name": "Sparkling Lemonade", "cost": 3.49},
        {"name": "Iced Tea", "cost": 2.99},
        {"name": "Craft Beer (Pint)", "cost": 6.99},
        {"name": "House Red Wine (Glass)", "cost": 7.99},
    ]
    return menu


def _format_items(items: dict) -> str:
    """Render {name: {"quantity", "unit_cost"}} entries as
    "Name (xN) - $total", comma-separated. Shared by cart and receipt views.
    """
    return ", ".join(
        f"{item} (x{details['quantity']}) - ${details['quantity'] * details['unit_cost']:.2f}"
        for item, details in items.items()
    )


def add_to_cart(item: str, cost: float, quantity: int, context_variables: dict):
    """Add `quantity` of `item` (at `cost` per unit) to the cart.

    Note: `cost` was previously annotated `int`, but menu prices are floats.
    """
    cart = context_variables["cart"]
    if item in cart:
        cart[item]["quantity"] += quantity
    else:
        cart[item] = {"quantity": quantity, "unit_cost": cost}
    return f"Added {quantity} of {item} to your cart."


def remove_from_cart(item: str, quantity: int, context_variables: dict):
    """Remove an item from the cart by quantity."""
    cart = context_variables["cart"]
    if item not in cart:
        return f"{item} is not in your cart."
    if quantity >= cart[item]["quantity"]:
        # Removing at least the stored quantity drops the line item entirely.
        del cart[item]
        return f"Removed {item} from your cart."
    cart[item]["quantity"] -= quantity
    return f"Decreased quantity of {item} by {quantity}."


def get_cart(context_variables: dict):
    """Return a human-readable summary of the cart contents."""
    cart = context_variables.get("cart", {})
    if not cart:
        return "Your cart is empty."
    return f"Your cart contains: {_format_items(cart)}"


def process_payment(context_variables: dict):
    """Process the payment: snapshot the cart as the receipt, then clear it."""
    from copy import deepcopy  # local import keeps these helpers self-contained

    context_variables["receipt"] = deepcopy(context_variables["cart"])
    context_variables["cart"] = {}
    return "Payment processed successfully!"


def get_receipt(context_variables: dict):
    """Return a human-readable summary of the last receipt."""
    receipt = context_variables.get("receipt", {})
    if not receipt:
        return "No receipt available."
    return f"Your receipt contains: {_format_items(receipt)}"


def provide_feedback():
    """Submit a review."""
    return "Review submitted successfully!"
from datetime import datetime, timedelta
import inspect
from typing import Any
from rojak.types import RetryPolicy
from temporalio.common import RetryPolicy as TRetryPolicy
from mcp import Tool

# Maps Python annotations to JSON-schema type names; unannotated or unknown
# parameters fall back to "string".
_TYPE_MAP = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
    type(None): "null",
}


def _signature_schema(func) -> dict:
    """Build the JSON-schema parameter object shared by the OpenAI and
    Anthropic tool formats.

    Args:
        func: The function to introspect.

    Returns:
        {"type": "object", "properties": ..., "required": ...} describing
        the function's parameters.

    Raises:
        ValueError: if `inspect.signature` cannot introspect *func*.
    """
    try:
        signature = inspect.signature(func)
    except ValueError as e:
        raise ValueError(
            f"Failed to get signature for function {func.__name__}: {str(e)}"
        ) from e

    # dict.get never raises KeyError, so the per-parameter try/except that
    # previously wrapped this lookup was dead code.
    properties = {
        param.name: {"type": _TYPE_MAP.get(param.annotation, "string")}
        for param in signature.parameters.values()
    }
    required = [
        param.name
        for param in signature.parameters.values()
        # Public Parameter.empty (identity check) instead of comparing
        # against the private inspect._empty with ==.
        if param.default is inspect.Parameter.empty
    ]
    return {"type": "object", "properties": properties, "required": required}


def function_to_json(func) -> dict:
    """
    Converts a Python function into a JSON-serializable dictionary
    that describes the function's signature, including its name,
    description, and parameters.

    Compatible with OpenAI tool definition.

    Args:
        func: The function to be converted.

    Returns:
        A dictionary representing the function's signature in JSON format.
    """
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func.__doc__ or "",
            "parameters": _signature_schema(func),
        },
    }


def function_to_json_anthropic(func) -> dict:
    """
    Converts a Python function into a JSON-serializable dictionary
    that describes the function's signature, including its name,
    description, and parameters.

    Compatible with Anthropic tool definition.

    Args:
        func: The function to be converted.

    Returns:
        A dictionary representing the function's signature in JSON format.
    """
    return {
        "name": func.__name__,
        "description": func.__doc__ or "",
        "input_schema": _signature_schema(func),
    }


def mcp_to_openai_tool(tool: Tool) -> dict[str, Any]:
    """Convert MCP tool to openai format.

    Args:
        tool (Tool): MCP Tool object.

    Returns:
        dict[str, Any]: A dictionary representing tool in JSON format.
    """
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.inputSchema,
        },
    }


def mcp_to_anthropic_tool(tool: Tool) -> dict[str, Any]:
    """Convert MCP tool to anthropic format.

    Args:
        tool (Tool): MCP Tool object.

    Returns:
        dict[str, Any]: A dictionary representing tool in JSON format.
    """
    return {
        "name": tool.name,
        "description": tool.description,
        "input_schema": tool.inputSchema,
    }


def create_retry_policy(retry_policy: RetryPolicy | None) -> TRetryPolicy | None:
    """Convert serialisable retry policy to Temporal RetryPolicy.

    Returns None when no policy is provided.
    """
    if retry_policy is None:
        return None

    # NOTE(review): a 0-second initial interval is falsy and silently becomes
    # 1s here — confirm that is the intended interpretation of 0.
    initial_interval = (
        timedelta(seconds=retry_policy.initial_interval_in_seconds)
        if retry_policy.initial_interval_in_seconds
        else timedelta(seconds=1)
    )

    maximum_interval = (
        timedelta(seconds=retry_policy.maximum_interval_in_seconds)
        if retry_policy.maximum_interval_in_seconds
        else None
    )

    return TRetryPolicy(
        initial_interval=initial_interval,
        backoff_coefficient=retry_policy.backoff_coefficient,
        maximum_interval=maximum_interval,
        maximum_attempts=retry_policy.maximum_attempts,
        non_retryable_error_types=retry_policy.non_retryable_error_types,
    )


def debug_print(debug: bool, timestamp: datetime, *args: str) -> None:
    """Print a timestamped debug line (grey/white ANSI) when `debug` is on."""
    if not debug:
        return
    message = " ".join(map(str, args))
    print(f"\033[97m[\033[90m{timestamp}\033[97m]\033[90m {message}\033[0m")
import pytest
from rojak import Rojak
from temporalio.testing import WorkflowEnvironment
from rojak.agents import (
    OpenAIAgent,
    OpenAIAgentActivities,
    OpenAIAgentOptions,
)
from rojak.workflows import UpdateConfigParams, TaskParams
from tests.mock_client import MockOpenAIClient, create_mock_response

# Canned assistant reply returned by the mocked OpenAI client.
DEFAULT_RESPONSE_CONTENT = "sample response content"

# NOTE(review): not referenced in this module — confirm it is used elsewhere
# (e.g. imported by another test file) before removing.
DEFAULT_RESPONSE_CONTENT_2 = "sample response content 2"


@pytest.fixture
def mock_openai_client():
    """Mocked OpenAI client that always answers with DEFAULT_RESPONSE_CONTENT."""
    m = MockOpenAIClient()
    m.set_response(
        create_mock_response({"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT})
    )
    return m


@pytest.mark.asyncio
async def test_max_turns(mock_openai_client: MockOpenAIClient):
    """max_turns=2 must stop the run before the agent-2 handoff completes."""
    task_queue_name = str(uuid.uuid4())

    # First model response triggers a tool-call handoff to agent 2; the second
    # would be agent 2's plain-text reply (never reached under the turn cap).
    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": ""},
                function_calls=[{"name": "transfer_to_agent2"}],
            ),
            create_mock_response(
                {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT}
            ),
        ]
    )

    def transfer_to_agent2():
        # Returning an Agent from a tool function signals a handoff.
        return agent2

    agent1 = OpenAIAgent(name="Test Agent 1", functions=["transfer_to_agent2"])
    agent2 = OpenAIAgent(name="Test Agent 2")

    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        openai_activities = OpenAIAgentActivities(
            OpenAIAgentOptions(
                client=mock_openai_client,
                all_functions=[transfer_to_agent2],
            )
        )
        worker = await rojak.create_worker([openai_activities])
        async with worker:
            agent = agent1

            response = await rojak.run(
                id=str(uuid.uuid4()),
                type="persistent",
                task=TaskParams(
                    agent=agent,
                    messages=[{"role": "user", "content": "Hello how are you?"}],
                ),
                max_turns=2,
            )

            # Should not reach agent 2.
            assert response.result.messages[-1].sender != "Test Agent 2"


@pytest.mark.asyncio
async def test_history_size(mock_openai_client: MockOpenAIClient):
    """history_size=1 caps the retained history even though the result has both turns."""
    task_queue_name = str(uuid.uuid4())
    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        openai_activities = OpenAIAgentActivities(
            OpenAIAgentOptions(client=mock_openai_client)
        )
        worker = await rojak.create_worker([openai_activities])
        async with worker:
            agent = OpenAIAgent(name="assistant")

            response = await rojak.run(
                id=str(uuid.uuid4()),
                type="persistent",
                task=TaskParams(
                    agent=agent,
                    messages=[{"role": "user", "content": "Hello how are you?"}],
                ),
                history_size=1,
            )

            # The task result still contains user message + assistant reply...
            assert len(response.result.messages) == 2

            config = await rojak.get_config(response.id)

            # ...but the persisted workflow history is truncated to 1 message.
            assert len(config.messages) == 1


@pytest.mark.asyncio
async def test_continue_as_new(mock_openai_client: MockOpenAIClient):
    """When Temporal history limits are exceeded, the workflow should
    continue-as-new while carrying over its full configuration."""
    task_queue_name = str(uuid.uuid4())

    # Report sizes just above the thresholds so continue-as-new triggers.
    mock_workflow_info = Mock()
    mock_workflow_info.get_current_history_size.return_value = 10_001
    mock_workflow_info.get_current_history_length.return_value = 20_000_001

    async with await WorkflowEnvironment.start_time_skipping() as env:
        with patch("temporalio.workflow.info", return_value=mock_workflow_info):
            rojak = Rojak(client=env.client, task_queue=task_queue_name)
            openai_activities = OpenAIAgentActivities(
                OpenAIAgentOptions(client=mock_openai_client)
            )
            worker = await rojak.create_worker([openai_activities])
            async with worker:
                agent = OpenAIAgent(name="assistant")
                configs = {
                    "agent": agent,
                    "max_turns": 30,
                    "context_variables": {"hello": "world"},
                    "history_size": 10,
                    "debug": True,
                }

                response = await rojak.run(
                    id=str(uuid.uuid4()),
                    type="persistent",
                    task=TaskParams(
                        agent=configs["agent"],
                        messages=[{"role": "user", "content": "Hello how are you?"}],
                    ),
                    max_turns=configs["max_turns"],
                    context_variables=configs["context_variables"],
                    history_size=configs["history_size"],
                    debug=configs["debug"],
                )

                # Give the continued-as-new run a moment to start.
                await asyncio.sleep(1)

                mock_workflow_info.get_current_history_size.assert_called_once()
                mock_workflow_info.get_current_history_length.assert_called_once()

                # Config must survive the continue-as-new boundary intact.
                response = await rojak.get_config(response.id)
                assert response.max_turns == configs["max_turns"]
                assert response.context_variables == configs["context_variables"]
                assert response.history_size == configs["history_size"]
                assert response.debug == configs["debug"]
                assert len(response.messages) == 2


@pytest.mark.asyncio
async def test_get_result(mock_openai_client: MockOpenAIClient):
    """get_result should return the stored outcome for a completed task."""
    task_queue_name = str(uuid.uuid4())
    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        openai_activities = OpenAIAgentActivities(
            OpenAIAgentOptions(client=mock_openai_client)
        )
        worker = await rojak.create_worker([openai_activities])
        async with worker:
            agent = OpenAIAgent(name="assistant")

            response = await rojak.run(
                id=str(uuid.uuid4()),
                type="persistent",
                task=TaskParams(
                    agent=agent,
                    messages=[{"role": "user", "content": "Hello how are you?"}],
                ),
            )

            response = await rojak.get_result(response.id, response.task_id)

            assert response.agent == agent
            assert response.messages[-1].role == "assistant"
            assert response.messages[-1].content == DEFAULT_RESPONSE_CONTENT


@pytest.mark.asyncio
async def test_update_config(mock_openai_client: MockOpenAIClient):
    """update_config should overwrite workflow configuration on a live workflow."""
    task_queue_name = str(uuid.uuid4())
    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        openai_activities = OpenAIAgentActivities(
            OpenAIAgentOptions(client=mock_openai_client)
        )
        worker = await rojak.create_worker([openai_activities])
        async with worker:
            agent = OpenAIAgent(name="assistant")

            response = await rojak.run(
                id=str(uuid.uuid4()),
                type="persistent",
                task=TaskParams(
                    agent=agent,
                    messages=[{"role": "user", "content": "Hello how are you?"}],
                ),
            )

            await rojak.update_config(
                response.id,
                UpdateConfigParams(
                    context_variables={"hello": "world"},
                    max_turns=100,
                    debug=True,
                    messages=[{"role": "user", "content": "Hello"}],
                ),
            )

            response = await rojak.get_config(response.id)

            assert response.debug is True
            assert response.max_turns == 100
            assert response.context_variables.get("hello") == "world"
            assert response.messages[-1].content == "Hello"
async def make_nws_request(
    client: httpx.AsyncClient, url: str
) -> dict[str, Any] | None:
    """Fetch *url* from the NWS API, returning parsed JSON or None on failure."""
    request_headers = {
        "User-Agent": USER_AGENT,
        "Accept": "application/geo+json",
    }

    try:
        resp = await client.get(url, headers=request_headers, timeout=30.0)
        resp.raise_for_status()
        return resp.json()
    except Exception:
        # Deliberate best-effort: any network/HTTP/parse error maps to None.
        return None


def format_alert(feature: dict) -> str:
    """Render a single NWS alert feature as a short multi-line summary."""
    props = feature["properties"]
    # (label, property key, fallback) triples, in display order.
    display_fields = [
        ("Event", "event", "Unknown"),
        ("Area", "areaDesc", "Unknown"),
        ("Severity", "severity", "Unknown"),
        ("Status", "status", "Unknown"),
        ("Headline", "headline", "No headline"),
    ]
    summary_lines = [
        f"{label}: {props.get(key, fallback)}"
        for label, key, fallback in display_fields
    ]
    summary_lines.append("---")
    return "\n".join(summary_lines)
    """
    # Arguments are schema-validated by MCP but may still be absent/empty.
    if not arguments:
        raise ValueError("Missing arguments")

    if name == "get-alerts":
        state = arguments.get("state")
        if not state:
            raise ValueError("Missing state parameter")

        # Convert state to uppercase to ensure consistent format
        state = state.upper()
        if len(state) != 2:
            raise ValueError("State must be a two-letter code (e.g. CA, NY)")

        async with httpx.AsyncClient() as client:
            alerts_url = f"{NWS_API_BASE}/alerts?area={state}"
            alerts_data = await make_nws_request(client, alerts_url)

            # make_nws_request returns None on any network/HTTP failure.
            if not alerts_data:
                return [
                    types.TextContent(
                        type="text", text="Failed to retrieve alerts data"
                    )
                ]

            features = alerts_data.get("features", [])
            if not features:
                return [
                    types.TextContent(type="text", text=f"No active alerts for {state}")
                ]

            # Format each alert into a concise string
            formatted_alerts = [format_alert(feature) for feature in features]
            alerts_text = f"Active alerts for {state}:\n\n" + "\n".join(
                formatted_alerts
            )

            return [types.TextContent(type="text", text=alerts_text)]
    elif name == "get-forecast":
        # float() raises TypeError on None (missing key) and ValueError on
        # non-numeric strings; both map to the same user-facing error.
        try:
            latitude = float(arguments.get("latitude"))
            longitude = float(arguments.get("longitude"))
        except (TypeError, ValueError):
            return [
                types.TextContent(
                    type="text",
                    text="Invalid coordinates. Please provide valid numbers for latitude and longitude.",
                )
            ]

        # Basic coordinate validation
        if not (-90 <= latitude <= 90) or not (-180 <= longitude <= 180):
            return [
                types.TextContent(
                    type="text",
                    text="Invalid coordinates. Latitude must be between -90 and 90, longitude between -180 and 180.",
                )
            ]

        async with httpx.AsyncClient() as client:
            # First get the grid point
            lat_str = f"{latitude}"
            lon_str = f"{longitude}"
            points_url = f"{NWS_API_BASE}/points/{lat_str},{lon_str}"
            points_data = await make_nws_request(client, points_url)

            if not points_data:
                return [
                    types.TextContent(
                        type="text",
                        text=f"Failed to retrieve grid point data for coordinates: {latitude}, {longitude}. This location may not be supported by the NWS API (only US locations are supported).",
                    )
                ]

            # Extract forecast URL from the response
            properties = points_data.get("properties", {})
            forecast_url = properties.get("forecast")

            if not forecast_url:
                return [
                    types.TextContent(
                        type="text",
                        text="Failed to get forecast URL from grid point data",
                    )
                ]

            # Get the forecast
            forecast_data = await make_nws_request(client, forecast_url)

            if not forecast_data:
                return [
                    types.TextContent(
                        type="text", text="Failed to retrieve forecast data"
                    )
                ]

            # Format the forecast periods
            periods = forecast_data.get("properties", {}).get("periods", [])
            if not periods:
                return [
                    types.TextContent(type="text", text="No forecast periods available")
                ]

            # Format each period into a concise string
            formatted_forecast = []
            for period in periods:
                forecast_text = (
                    f"{period.get('name', 'Unknown')}:\n"
                    f"Temperature: {period.get('temperature', 'Unknown')}°{period.get('temperatureUnit', 'F')}\n"
                    f"Wind: {period.get('windSpeed', 'Unknown')} {period.get('windDirection', '')}\n"
                    f"{period.get('shortForecast', 'No forecast available')}\n"
                    "---"
                )
                formatted_forecast.append(forecast_text)

            forecast_text = f"Forecast for {latitude}, {longitude}:\n\n" + "\n".join(
                formatted_forecast
            )

            return [types.TextContent(type="text", text=forecast_text)]
    else:
        raise ValueError(f"Unknown tool: {name}")


async def main():
    """Run the MCP weather server over stdio until the client disconnects."""
    # Run the server using stdin/stdout streams
    async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            InitializationOptions(
                server_name="weather",
                server_version="0.1.0",
                capabilities=server.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                ),
            ),
        )


if __name__ == "__main__":
    asyncio.run(main())
            "temperature": 0.0,
            "top_p": 0.9,
            "stop_sequences": [],
        }
    )


@dataclass
class AnthropicAgent(Agent):
    # Default Anthropic model used when the caller does not specify one.
    model: str = "claude-3-5-haiku-20241022"

    type: Literal["anthropic"] = field(default="anthropic")
    """Type of agent. Must be `"anthropic"`"""

    inference_config: dict[str, any] | None = None
    """Inference configuration for Anthropic models"""


class AnthropicAgentActivities(AgentActivities):
    """Temporal activities backed by the Anthropic Messages API."""

    def __init__(self, options: AnthropicAgentOptions = AnthropicAgentOptions()):
        # NOTE(review): mutable default argument — `AnthropicAgentOptions()` is
        # created once at definition time and shared by all no-arg callers.
        # Harmless only if options are never mutated; prefer `options=None`
        # with construction inside the body. Confirm before relying on this.
        super().__init__(options)

        # Client resolution precedence: explicit client > explicit api_key >
        # ANTHROPIC_API_KEY environment variable.
        if options.client:
            self.client = options.client
        elif options.api_key:
            self.client = Anthropic(api_key=options.api_key)
        elif os.environ.get("ANTHROPIC_API_KEY"):
            self.client = Anthropic()
        else:
            raise ValueError("Anthropic API key is required")

        self.inference_config = options.inference_config

    @staticmethod
    def handle_model_response(response: Message) -> AgentResponse:
        """Convert model response to AgentResponse.

        Text blocks are concatenated into `content`; each tool-use block
        becomes an AgentToolCall with JSON-encoded arguments.
        """
        content = ""
        tool_calls = []
        for block in response.content:
            if isinstance(block, TextBlock):
                content += block.text
            elif isinstance(block, ToolUseBlock):
                tool_calls.append(
                    AgentToolCall(
                        id=block.id,
                        function=ToolCallFunction(
                            name=block.name,
                            arguments=json.dumps(block.input),
                        ),
                    )
                )

        # Response type is "tool" iff at least one tool-use block was present.
        if tool_calls:
            return AgentResponse(type="tool", content=content, tool_calls=tool_calls)
        else:
            return AgentResponse(type="text", tool_calls=tool_calls, content=content)

    @staticmethod
    def convert_messages(
        messages: list[ConversationMessage],
    ) -> tuple[list[MessageParam], str | None]:
        """Convert messages to be Anthropic compatible.

        1. System messages become the returned system string (only one).
        2. User/assistant text messages remain single messages.
        3. Assistant messages that contain tool calls become a single message
           with a list of tool-use blocks.
        4. Consecutive tool-result messages are combined into one user message,
           where each result is a separate 'tool_result' block.
        """
        converted_messages: list[MessageParam] = []
        # NOTE: if several system messages appear, the last one wins.
        system_message: str | None = None

        # Temporary storage for consecutive tool-result blocks
        tool_result_buffer: list[ToolResultBlockParam] = []

        def flush_tool_results():
            # If we have accumulated any tool results,
            # append them as a single user message and clear the buffer.
            nonlocal tool_result_buffer
            if tool_result_buffer:
                converted_messages.append(
                    MessageParam(role="user", content=tool_result_buffer)
                )
                tool_result_buffer = []

        for msg in messages:
            if msg.role == "system":
                system_message = msg.content

            elif msg.tool_calls:
                # Whenever we hit an assistant message with tool calls,
                # first flush any pending tool results (belonging to 'tool' role messages).
                flush_tool_results()

                # Convert each tool call into a tool-use block.
                tool_blocks = []
                for tool_call_dict in msg.tool_calls:
                    tool_call = AgentToolCall(**tool_call_dict)
                    tool_blocks.append(
                        ToolUseBlock(
                            type="tool_use",
                            id=tool_call.id,
                            input=json.loads(tool_call.function.arguments),
                            name=tool_call.function.name,
                        ).model_dump()
                    )

                # Append as a single assistant message containing multiple tool calls.
                converted_messages.append(
                    MessageParam(role="assistant", content=tool_blocks)
                )

            elif msg.role == "tool":
                # We accumulate tool results into a buffer until we reach a non-tool message.
                tool_result_buffer.append(
                    ToolResultBlockParam(
                        type="tool_result",
                        tool_use_id=msg.tool_call_id,
                        content=[TextBlockParam(type="text", text=msg.content)],
                    )
                )

            else:
                # For normal user/assistant messages, first flush any pending tool results.
                flush_tool_results()

                # Then add this user/assistant message as-is.
                # Messages with any other role are silently dropped.
                if msg.role in ["user", "assistant"]:
                    converted_messages.append(
                        MessageParam(role=msg.role, content=msg.content)
                    )

        # If conversation ended with tool results, flush them.
        flush_tool_results()

        return converted_messages, system_message

    @activity.defn(name="anthropic_call")
    async def call(self, params: AgentCallParams) -> AgentResponse:
        """Invoke the Anthropic Messages API and normalise the response."""
        # Create list of messages
        messages, system_message = self.convert_messages(params.messages)

        # NOTE(review): this merge mutates self.inference_config, so a
        # per-call override persists for all later activity invocations on
        # this worker — confirm that is intended.
        if params.inference_config:
            self.inference_config = {**self.inference_config, **params.inference_config}

        # Create tool call json
        functions = [
            self.function_map[name]
            for name in params.function_names
            if name in self.function_map
        ]
        tools = [function_to_json_anthropic(f) for f in functions]
        # Strip the framework-injected `context_variables` parameter so it is
        # never exposed to the model.
        for tool in tools:
            fn_params = tool["input_schema"]["properties"]
            fn_params.pop("context_variables", None)
            required_list = tool["input_schema"]["required"]
            if "context_variables" in required_list:
                required_list.remove("context_variables")

        # Append MCP-provided tools (mcp_result is presumably populated by the
        # AgentActivities base class — confirm).
        tools += [
            mcp_to_anthropic_tool(tool) for tool in self.mcp_result.tools.values()
        ]

        response: Message = self.client.messages.create(
            model=params.model,
            messages=messages,
            # NotGiven() lets the SDK omit optional fields entirely.
            system=system_message or NotGiven(),
            tools=tools or NotGiven(),
            tool_choice=params.tool_choice
            if tools and params.tool_choice
            else NotGiven(),
            max_tokens=self.inference_config["max_tokens"],
            temperature=self.inference_config["temperature"],
            top_p=self.inference_config["top_p"],
            stop_sequences=self.inference_config["stop_sequences"],
        )

        return self.handle_model_response(response)

    @activity.defn(name="anthropic_execute_instructions")
    async def execute_instructions(self, params: ExecuteInstructionsParams) -> str:
        # Registered under an Anthropic-specific activity name; behavior is
        # the generic base-class implementation.
        return await super().execute_instructions(params)

    @activity.defn(name="anthropic_execute_function")
    async def execute_function(
        self, params: ExecuteFunctionParams
    ) -> str | AnthropicAgent | AgentExecuteFnResult:
        # Registered under an Anthropic-specific activity name; behavior is
        # the generic base-class implementation.
        return await super().execute_function(params)
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /rojak/workflows/orchestrator_workflow.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass, field 3 | from typing import Literal 4 | from temporalio import workflow 5 | from temporalio.exceptions import ActivityError 6 | from rojak.types import ConversationMessage, ContextVariables 7 | from collections import deque 8 | import asyncio 9 | from rojak.utils import debug_print 10 | from rojak.workflows.agent_workflow import ( 11 | AgentWorkflow, 12 | AgentWorkflowRunParams, 13 | ResumeRequest, 14 | ResumeResponse, 15 | ToolResponse, 16 | AgentTypes, 17 | ) 18 | from rojak.agents import Agent, Interrupt 19 | 20 | 21 | @dataclass 22 | class OrchestratorParams: 23 | type: Literal["stateless", "persistent"] 24 | """Specify if it is stateless or persistent workflow.""" 25 | 26 | context_variables: ContextVariables = field(default_factory=dict) 
27 | """A dictionary of additional context variables, available to functions and Agent instructions.""" 28 | 29 | max_turns: int | float = field(default=float("inf")) 30 | """The maximum number of conversational turns allowed.""" 31 | 32 | debug: bool = False 33 | """If True, enables debug logging""" 34 | 35 | history_size: int = field(default=10) 36 | """The maximum number of messages retained in the list before older messages are removed.""" 37 | 38 | messages: list[ConversationMessage] = field(default_factory=list) 39 | """List of conversation messages to initialise workflow with.""" 40 | 41 | tasks: deque[tuple[str, "TaskParams"]] = field(default_factory=deque) 42 | """Tasks queue to initialise workflow with.""" 43 | 44 | 45 | @dataclass 46 | class OrchestratorResponse: 47 | """The response object from containing the updated state.""" 48 | 49 | messages: list[ConversationMessage] 50 | """The list of updated messages.""" 51 | 52 | context_variables: ContextVariables 53 | """The dictionary of the updated context variables.""" 54 | 55 | agent: AgentTypes | None = None 56 | """The last agent to be called.""" 57 | 58 | interrupt: Interrupt | None = None 59 | """The object surfaced to the client when the interupt is triggered.""" 60 | 61 | 62 | @dataclass 63 | class TaskParams: 64 | messages: list[ConversationMessage] 65 | """List of message object.""" 66 | 67 | agent: AgentTypes 68 | """The agent to be called.""" 69 | 70 | 71 | @dataclass 72 | class UpdateConfigParams: 73 | messages: list[ConversationMessage] | None = None 74 | """A list of message objects.""" 75 | 76 | context_variables: ContextVariables | None = None 77 | """The dictionary of the updated context variables.""" 78 | 79 | max_turns: int | float | None = None 80 | """The maximum number of conversational turns allowed.""" 81 | 82 | history_size: int | None = None 83 | """The maximum number of messages retained in the list before older messages are removed.""" 84 | 85 | debug: bool | None = None 86 | 
"""If True, enables debug logging""" 87 | 88 | 89 | @dataclass 90 | class GetConfigResponse: 91 | messages: list[ConversationMessage] 92 | """A list of message objects.""" 93 | 94 | context_variables: ContextVariables 95 | """The dictionary of the updated context variables.""" 96 | 97 | max_turns: int | float 98 | """The maximum number of conversational turns allowed.""" 99 | 100 | history_size: int 101 | """The maximum number of messages retained in the list before older messages are removed.""" 102 | 103 | debug: bool 104 | """If True, enables debug logging""" 105 | 106 | 107 | @workflow.defn 108 | class OrchestratorWorkflow: 109 | @workflow.init 110 | def __init__(self, params: OrchestratorParams) -> None: 111 | self.lock = asyncio.Lock() # Prevent concurrent update handler executions 112 | self.tasks: deque[tuple[str, TaskParams]] = params.tasks 113 | self.responses: dict[str, OrchestratorResponse | ResumeRequest] = {} 114 | self.latest_response: OrchestratorResponse | ResumeRequest | None = None 115 | self.max_turns = params.max_turns 116 | self.debug = params.debug 117 | self.context_variables = params.context_variables 118 | self.current_agent_workflow: AgentWorkflow | None = None 119 | self.task_id: str | None = None 120 | self.history_size = params.history_size 121 | self.type = params.type 122 | self.messages: list[ConversationMessage] = params.messages 123 | 124 | @workflow.run 125 | async def run(self, params: OrchestratorParams) -> OrchestratorResponse: 126 | while True: 127 | await workflow.wait_condition(lambda: bool(self.tasks)) 128 | task_id, task = self.tasks.popleft() 129 | self.task_id = task_id 130 | self.messages += task.messages 131 | self.agent = task.agent # Keep track of the last to be called 132 | 133 | message = self.messages[-1] 134 | debug_print( 135 | self.debug, workflow.now(), f"{message.role}: {message.content}" 136 | ) 137 | 138 | active_agent = self.agent 139 | init_len = len(self.messages) 140 | past_message_state = 
copy.deepcopy(self.messages) 141 | 142 | try: 143 | while len(self.messages) - init_len < self.max_turns and active_agent: 144 | active_agent = await self.process(active_agent) 145 | 146 | response = OrchestratorResponse( 147 | messages=self.messages, 148 | agent=self.agent, 149 | context_variables=self.context_variables, 150 | ) 151 | 152 | self.reply(self.task_id, response) 153 | 154 | await workflow.wait_condition(lambda: workflow.all_handlers_finished()) 155 | 156 | if self.type == "stateless": 157 | return self.responses[self.task_id] 158 | else: 159 | if len(self.messages) > self.history_size: 160 | messages = deque(self.messages[-self.history_size :]) 161 | while messages and messages[0].role == "tool": 162 | messages.popleft() 163 | self.messages = list(messages) 164 | 165 | workflow_history_size = workflow.info().get_current_history_size() 166 | workflow_history_length = ( 167 | workflow.info().get_current_history_length() 168 | ) 169 | if ( 170 | workflow_history_length > 10_000 171 | or workflow_history_size > 20_000_000 172 | ): 173 | debug_print( 174 | self.debug, 175 | workflow.now(), 176 | "Continue as new due to prevent workflow event history from exceeding limit.", 177 | ) 178 | workflow.continue_as_new( 179 | args=[ 180 | OrchestratorParams( 181 | type=params.type, 182 | context_variables=self.context_variables, 183 | max_turns=self.max_turns, 184 | debug=self.debug, 185 | history_size=self.history_size, 186 | messages=self.messages, 187 | tasks=self.tasks, 188 | ) 189 | ] 190 | ) 191 | except ActivityError as e: 192 | # Return messages to previous state and wait for new messages 193 | workflow.logger.error(f"Agent failed to complete. 
Error: {e}") 194 | self.messages = past_message_state 195 | active_agent = None 196 | self.pending = False 197 | continue 198 | 199 | def reply(self, task_id: str, response: OrchestratorResponse | ResumeRequest): 200 | """Return response back""" 201 | self.responses[task_id] = response 202 | self.latest_response = response 203 | 204 | async def process(self, active_agent: Agent) -> Agent | None: 205 | params = AgentWorkflowRunParams( 206 | agent=active_agent, 207 | messages=self.messages, 208 | context_variables=self.context_variables, 209 | debug=self.debug, 210 | orchestrator=self, 211 | task_id=self.task_id, 212 | ) 213 | agent_workflow = AgentWorkflow(params) 214 | self.current_agent_workflow = agent_workflow 215 | response, updated_messages = await agent_workflow.run() 216 | 217 | self.messages = updated_messages 218 | 219 | if isinstance(response.output, ToolResponse): 220 | fn_result = response.output.output 221 | if fn_result.agent is not None: 222 | debug_print( 223 | self.debug, 224 | workflow.now(), 225 | f"{active_agent.name}: Transferred to '{fn_result.agent.name}'.", 226 | ) 227 | self.agent = active_agent = fn_result.agent 228 | if fn_result.context_variables is not None: 229 | self.context_variables = fn_result.context_variables 230 | elif isinstance(response.output, str): 231 | debug_print( 232 | self.debug, 233 | workflow.now(), 234 | f"\n{active_agent.name}: {response.output}", 235 | ) 236 | active_agent = None 237 | 238 | return active_agent 239 | 240 | def resume(self, params: ResumeResponse): 241 | """Resumes an interrupted agent workflow for a specific tool ID.""" 242 | if not self.current_agent_workflow: 243 | raise ValueError("Cannot resume: No active agent workflow available.") 244 | 245 | tool_id = params.tool_id 246 | if tool_id in self.current_agent_workflow.interrupted: 247 | self.current_agent_workflow.interrupted.remove(tool_id) 248 | self.current_agent_workflow.resumed[tool_id] = params 249 | else: 250 | raise KeyError( 251 | 
f"Cannot resume: Tool ID '{tool_id}' not found in the approval queue." 252 | ) 253 | 254 | @workflow.query 255 | def get_messages(self) -> list[ConversationMessage]: 256 | return self.messages 257 | 258 | @workflow.update 259 | async def add_task( 260 | self, 261 | params: tuple[str, TaskParams | ResumeResponse], 262 | ) -> OrchestratorResponse | ResumeRequest: 263 | task_id, task = params 264 | async with self.lock: 265 | self.task_id = task_id 266 | if isinstance(task, TaskParams): 267 | self.tasks.append(params) 268 | else: 269 | self.resume(task) 270 | await workflow.wait_condition(lambda: task_id in self.responses) 271 | return self.responses[task_id] 272 | 273 | @workflow.query 274 | def get_result(self, task_id: str) -> OrchestratorResponse: 275 | return self.responses[task_id] 276 | 277 | @workflow.query 278 | def get_latest_result(self) -> OrchestratorResponse | ResumeRequest | None: 279 | return self.latest_response 280 | 281 | @workflow.query 282 | def get_config(self) -> GetConfigResponse: 283 | return GetConfigResponse( 284 | messages=self.messages, 285 | context_variables=self.context_variables, 286 | max_turns=self.max_turns, 287 | history_size=self.history_size, 288 | debug=self.debug, 289 | ) 290 | 291 | @workflow.signal 292 | def update_config(self, params: UpdateConfigParams): 293 | if params.messages is not None: 294 | self.messages = params.messages 295 | if params.context_variables is not None: 296 | self.context_variables = params.context_variables 297 | if params.max_turns is not None: 298 | self.max_turns = params.max_turns 299 | if params.history_size is not None: 300 | self.history_size = params.history_size 301 | if params.debug is not None: 302 | self.debug = params.debug 303 | -------------------------------------------------------------------------------- /rojak/agents/agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import 
dataclass, field 3 | from typing import Any, Callable, Literal 4 | from mcp.types import TextContent 5 | from rojak.retrievers import Retriever 6 | from rojak.types import ( 7 | ContextVariables, 8 | ConversationMessage, 9 | RetryOptions, 10 | MCPServerConfig, 11 | InitMcpResult, 12 | ) 13 | from temporalio.exceptions import ApplicationError 14 | 15 | AgentFunction = Callable[[], str] 16 | 17 | 18 | @dataclass 19 | class AgentOptions: 20 | all_functions: list[AgentFunction] = field(default_factory=list) 21 | """List of functions that an agent can execute.""" 22 | 23 | mcp_servers: dict[str, MCPServerConfig] = field(default_factory=dict) 24 | """List of MCP servers to connect to.""" 25 | 26 | 27 | @dataclass 28 | class AgentCallParams: 29 | """Parameters for generating response from model.""" 30 | 31 | messages: list[ConversationMessage] 32 | """List of message objects.""" 33 | 34 | model: str 35 | """The LLM model to use.""" 36 | 37 | function_names: list[str] = field(default_factory=list) 38 | """List of tool call function names that the agent can select.""" 39 | 40 | inference_config: dict[str, Any] | None = None 41 | """Additional configurations for the inference.""" 42 | 43 | parallel_tool_calls: bool = True 44 | """Whether model should perform multiple tool calls together.""" 45 | 46 | tool_choice: Any | None = None 47 | """The tool choice for the agent, if any.""" 48 | 49 | 50 | @dataclass 51 | class ExecuteFunctionParams: 52 | """Parameters for executing tool call function.""" 53 | 54 | name: str 55 | """The name of the tool call function.""" 56 | 57 | args: dict[str, Any] 58 | """The arguments for the tool call function.""" 59 | 60 | context_variables: ContextVariables 61 | """A dictionary of additional context variables, available to functions and Agent instructions.""" 62 | 63 | 64 | @dataclass 65 | class ExecuteInstructionsParams: 66 | instructions: "AgentInstructionOptions" 67 | """Options for the callable instructions.""" 68 | 69 | context_variables: 
ContextVariables 70 | """A dictionary of additional context variables, available to functions and Agent instructions.""" 71 | 72 | 73 | @dataclass 74 | class ToolCallFunction: 75 | arguments: str 76 | """String representation of the arguments for the tool call function.""" 77 | 78 | name: str 79 | """The name of the tool call function.""" 80 | 81 | 82 | @dataclass 83 | class AgentToolCall: 84 | id: str 85 | """Unique identifier of the tool call.""" 86 | 87 | function: ToolCallFunction 88 | """Function that the model called.""" 89 | 90 | type: str = "function" 91 | """The type of the tool.""" 92 | 93 | index: int | None = None 94 | """Identifies which function call the delta is for.""" 95 | 96 | def __post_init__(self): 97 | if isinstance(self.function, dict): 98 | self.function = ToolCallFunction(**self.function) 99 | 100 | 101 | @dataclass 102 | class AgentResponse: 103 | """Response object from generating response from model.""" 104 | 105 | type: Literal["text", "tool"] 106 | """Specify if it is a natural language or a tool call response.""" 107 | 108 | content: str | None = None 109 | """String output from the model""" 110 | 111 | tool_calls: list[AgentToolCall] | None = None 112 | """List of tool call objects.""" 113 | 114 | def __post_init__(self): 115 | if self.tool_calls: 116 | self.tool_calls = [ 117 | tool_call 118 | if isinstance(tool_call, AgentToolCall) 119 | else AgentToolCall(**tool_call) 120 | for tool_call in self.tool_calls 121 | ] 122 | 123 | 124 | @dataclass 125 | class AgentInstructionOptions: 126 | """Information of the callable instructions.""" 127 | 128 | type: Literal["function"] 129 | """The type of the instruction. 
Only `function` is supported.""" 130 | 131 | name: str 132 | """The name of the function.""" 133 | 134 | 135 | @dataclass 136 | class Interrupt: 137 | tool_name: str 138 | """The name of the tool to interrupt.""" 139 | 140 | question: str = "" 141 | """The question to ask the user.""" 142 | 143 | when: Literal["before"] = "before" 144 | """When the interrupt should be triggered.""" 145 | 146 | 147 | @dataclass 148 | class ResumeRequest: 149 | """Request to resume the interrupted agent.""" 150 | 151 | tool_id: str 152 | """The ID of the tool that is interrupted.""" 153 | 154 | tool_arguments: str 155 | """Arguments that will be passed to the tool that was interrupted.""" 156 | 157 | task_id: str 158 | """Unique identifier of the request that triggered the interrupt.""" 159 | 160 | tool_name: str 161 | """The name of the tool to interrupt.""" 162 | 163 | question: str = "" 164 | """The question to ask the user.""" 165 | 166 | when: Literal["before"] = "before" 167 | """When the interrupt should be triggered.""" 168 | 169 | 170 | @dataclass 171 | class ResumeResponse: 172 | """Response to resume the interrupted agent.""" 173 | 174 | action: Literal["approve", "reject"] 175 | """Action to take on the interrupt.""" 176 | 177 | tool_id: str 178 | """Tool call id to resume.""" 179 | 180 | content: str | None = None 181 | """Feedback to pass to Agent. Only for 'rejected' action.""" 182 | 183 | 184 | @dataclass 185 | class Agent(ABC): 186 | model: str 187 | """The LLM model to use.""" 188 | 189 | type: str 190 | """The prefix of the activity name.""" 191 | 192 | name: str = "Agent" 193 | """The name of the agent.""" 194 | 195 | instructions: str | AgentInstructionOptions = "You are a helpful assistant." 
196 | """Instructions for the agent, can be a string or a callable returning a string.""" 197 | 198 | functions: list[str] = field(default_factory=list) 199 | """A list of functions that the agent can call.""" 200 | 201 | tool_choice: Any | None = None 202 | """The tool choice for the agent, if any.""" 203 | 204 | parallel_tool_calls: bool = True 205 | """Whether model should perform multiple tool calls together.""" 206 | 207 | interrupts: list[Interrupt] = field(default_factory=list) 208 | """List of interrupts for reviewing tool use.""" 209 | 210 | retriever: Retriever | None = None 211 | """Specify which retriever to use.""" 212 | 213 | retry_options: RetryOptions = field(default_factory=RetryOptions) 214 | """Options for timeout and retries.""" 215 | 216 | 217 | @dataclass 218 | class AgentExecuteFnResult: 219 | """Result object from executing tool call function.""" 220 | 221 | output: str = "" 222 | """String output to pass as message content.""" 223 | 224 | agent: Agent | None = None 225 | """The agent to call next.""" 226 | 227 | context_variables: ContextVariables = field(default_factory=dict) 228 | """A dictionary of additional context variables, available to functions and Agent instructions.""" 229 | 230 | 231 | class AgentActivities(ABC): 232 | """ 233 | Abstract base class for Agent implementations. 234 | This class provides a common structure for different types of agents. 235 | """ 236 | 237 | def __init__(self, options: AgentOptions): 238 | self.function_map = {f.__name__: f for f in options.all_functions} 239 | self.mcp_result: InitMcpResult | None = None 240 | 241 | def _add_mcp_configs(self, mcp_result: InitMcpResult): 242 | """Add MCP configurations""" 243 | self.mcp_result = mcp_result 244 | 245 | @abstractmethod 246 | async def call(self, params: AgentCallParams) -> AgentResponse: 247 | """Generate response from the LLM model. 248 | 249 | Args: 250 | params (AgentCallParams): Parameters for response generation. 
251 | 252 | Returns: 253 | AgentResponse: Generated response from the model. 254 | """ 255 | pass 256 | 257 | @abstractmethod 258 | async def execute_instructions(self, params: ExecuteInstructionsParams) -> str: 259 | """Execute the instruction callable. 260 | 261 | Args: 262 | params (ExecuteInstructionsParams): Parameters containing information for executing callable. 263 | 264 | Raises: 265 | ApplicationError: Error occurred while executing instructions. 266 | 267 | Returns: 268 | str: Instructions as a string. 269 | """ 270 | instructions = params.instructions 271 | if instructions.name not in self.function_map: 272 | raise ApplicationError( 273 | f"Function {instructions.name} not found", 274 | type="FunctionNotFound", 275 | non_retryable=True, 276 | ) 277 | 278 | fn = self.function_map[instructions.name] 279 | args = {} 280 | 281 | if "context_variables" in fn.__code__.co_varnames: 282 | args["context_variables"] = params.context_variables 283 | 284 | res = fn(**args) 285 | return str(res) 286 | 287 | def handle_function_result( 288 | self, 289 | result: str | Agent | AgentExecuteFnResult, 290 | context_variables: ContextVariables, 291 | ) -> AgentExecuteFnResult: 292 | match result: 293 | case str(): 294 | return AgentExecuteFnResult( 295 | output=result, 296 | context_variables=context_variables, 297 | ) 298 | case Agent(): 299 | return AgentExecuteFnResult( 300 | output=f"Transferred to '{result.name}'", 301 | agent=result, 302 | context_variables=context_variables, 303 | ) 304 | case AgentExecuteFnResult(): 305 | return result 306 | case _: 307 | try: 308 | return AgentExecuteFnResult( 309 | output=str(result), context_variables=context_variables 310 | ) 311 | except Exception as e: 312 | raise TypeError( 313 | f"Unknown function result type: {type(result)}. 
Error: {str(e)}" 314 | ) 315 | 316 | @staticmethod 317 | async def execute_mcp_tool( 318 | mcp_result: InitMcpResult, tool_name: str, args: dict 319 | ) -> str: 320 | """Get tool response from MCP server 321 | 322 | Args: 323 | mcp_result (InitMcpResult): The result from initialising MCP servers. 324 | tool_name (str): Name of the tool. 325 | args (dict): Tool arguments. 326 | 327 | Returns: 328 | str: The tool response. 329 | """ 330 | server_name = mcp_result.tool_client_mapping[tool_name] 331 | client = mcp_result.clients[server_name] 332 | response = await client.session.call_tool(tool_name, args) 333 | texts = [] 334 | for content in response.content: 335 | if isinstance(content, TextContent): 336 | texts.append(content.text) 337 | return "\n".join(texts) 338 | 339 | @abstractmethod 340 | async def execute_function( 341 | self, params: ExecuteFunctionParams 342 | ) -> str | Agent | AgentExecuteFnResult: 343 | """Execute the tool call function 344 | 345 | Args: 346 | params (ExecuteFunctionParams): Parameters for executing tool call function. 347 | 348 | Raises: 349 | ApplicationError: Error executing tool call function. 350 | 351 | Returns: 352 | str | Agent | AgentExecuteFnResult: Response from the tool call function. 
353 | """ 354 | if params.name in self.function_map: 355 | fn = self.function_map[params.name] 356 | 357 | if "context_variables" in fn.__code__.co_varnames: 358 | params.args["context_variables"] = params.context_variables 359 | 360 | result = fn(**params.args) 361 | elif self.mcp_result and params.name in self.mcp_result.tools: 362 | result = await self.execute_mcp_tool( 363 | self.mcp_result, params.name, params.args 364 | ) 365 | else: 366 | raise ApplicationError( 367 | f"Function {params.name} not found", 368 | type="FunctionNotFound", 369 | non_retryable=True, 370 | ) 371 | 372 | return self.handle_function_result(result, params.context_variables) 373 | -------------------------------------------------------------------------------- /rojak/workflows/agent_workflow.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import timedelta 3 | import json 4 | from typing import TYPE_CHECKING, Union 5 | from temporalio import workflow 6 | from temporalio.exceptions import ActivityError 7 | from rojak.retrievers import Retriever 8 | from rojak.types import ContextVariables, ConversationMessage 9 | from rojak.utils import create_retry_policy, debug_print 10 | from rojak.agents import ( 11 | AgentCallParams, 12 | ExecuteFunctionParams, 13 | ExecuteInstructionsParams, 14 | AgentInstructionOptions, 15 | AgentToolCall, 16 | AgentResponse, 17 | AgentExecuteFnResult, 18 | Agent, 19 | ResumeResponse, 20 | ResumeRequest, 21 | ) 22 | 23 | if TYPE_CHECKING: 24 | from rojak.workflows.orchestrator_workflow import OrchestratorWorkflow 25 | 26 | try: 27 | from rojak.agents import OpenAIAgent 28 | except ImportError: 29 | OpenAIAgent = None 30 | 31 | try: 32 | from rojak.agents import AnthropicAgent 33 | except ImportError: 34 | AnthropicAgent = None 35 | 36 | AgentTypes = ( 37 | Union[*(agent for agent in (OpenAIAgent, AnthropicAgent) if agent is not None)] 38 | | Agent # typing fallback 39 | ) 
40 | 41 | 42 | @dataclass 43 | class AgentWorkflowRunParams: 44 | agent: AgentTypes 45 | """The agent to be called.""" 46 | 47 | messages: list[ConversationMessage] 48 | """List of message objects.""" 49 | 50 | context_variables: ContextVariables 51 | """A dictionary of additional context variables, available to functions and Agent instructions.""" 52 | 53 | orchestrator: "OrchestratorWorkflow" 54 | """The parent orchestrator that called this workflow.""" 55 | 56 | task_id: str 57 | """Unique identifier for this request.""" 58 | 59 | debug: bool = False 60 | """If True, enables debug logging.""" 61 | 62 | 63 | @dataclass 64 | class ToolResponse: 65 | tool_call_id: str 66 | """Unique identifier for the tool call the response is for.""" 67 | 68 | output: AgentExecuteFnResult 69 | """Result from tool call function.""" 70 | 71 | 72 | @dataclass 73 | class AgentWorkflowResponse: 74 | output: str | ToolResponse 75 | """Agent Workflow output.""" 76 | 77 | sender: str 78 | """Indicate which agent the message originated from.""" 79 | 80 | 81 | class ToolRejectedError(Exception): 82 | """Error raised when a tool call is rejected.""" 83 | 84 | def __init__(self, message: str, cause: Exception = None): 85 | super().__init__(message) 86 | self.__cause__ = cause 87 | 88 | 89 | class AgentWorkflow: 90 | def __init__(self, params: AgentWorkflowRunParams): 91 | self.orchestrator = params.orchestrator 92 | self.task_id = params.task_id 93 | self.agent = params.agent 94 | self.retry_policy = create_retry_policy(self.agent.retry_options.retry_policy) 95 | self.start_to_close_timeout = timedelta( 96 | seconds=params.agent.retry_options.timeout_in_seconds 97 | ) 98 | self.debug = params.debug 99 | self.messages = params.messages 100 | self.context_variables = params.context_variables 101 | 102 | # Handle interrupts 103 | self.interrupt_map = { 104 | interrupt.tool_name: interrupt for interrupt in params.agent.interrupts 105 | } 106 | self.interrupted: set[str] = set() # tool call ids 
for approval 107 | self.resumed: dict[ 108 | str, ResumeResponse 109 | ] = {} # tool call ids that resumed, pending actions 110 | 111 | async def run(self) -> tuple[AgentWorkflowResponse, list[ConversationMessage]]: 112 | # process instructions 113 | instructions = self.agent.instructions 114 | try: 115 | if isinstance(instructions, AgentInstructionOptions): 116 | instructions: str = await workflow.execute_activity( 117 | f"{self.agent.type}_execute_instructions", 118 | ExecuteInstructionsParams( 119 | instructions, 120 | self.context_variables, 121 | ), 122 | result_type=str, 123 | start_to_close_timeout=self.start_to_close_timeout, 124 | retry_policy=self.retry_policy, 125 | ) 126 | except ActivityError as e: 127 | workflow.logger.error(f"Failed to execute instructions: {e}") 128 | raise 129 | 130 | # augment instructions 131 | if isinstance(self.agent.retriever, Retriever): 132 | context_prompt = await self.retrieve_context(self.messages[-1]) 133 | instructions += context_prompt 134 | 135 | # execute call model activity 136 | response: AgentResponse = await workflow.execute_activity( 137 | f"{self.agent.type}_call", 138 | AgentCallParams( 139 | messages=[ 140 | ConversationMessage(role="system", content=instructions), 141 | *self.messages, 142 | ], 143 | model=self.agent.model, 144 | function_names=self.agent.functions, 145 | parallel_tool_calls=self.agent.parallel_tool_calls, 146 | tool_choice=self.agent.tool_choice, 147 | ), 148 | result_type=AgentResponse, 149 | start_to_close_timeout=self.start_to_close_timeout, 150 | retry_policy=self.retry_policy, 151 | ) 152 | 153 | # dont use isinstance to check as response output type different for different llm providers 154 | if response.type == "tool": 155 | tool_calls = response.tool_calls 156 | self.messages.append( 157 | ConversationMessage( 158 | role="assistant", 159 | content=response.content, 160 | tool_calls=tool_calls, 161 | sender=self.agent.name, 162 | ), 163 | ) 164 | debug_print(self.debug, 
workflow.now(), f"{self.agent.name}: {tool_calls}") 165 | 166 | # TODO: Figure out how to handle concurrent tool calls without race conditions in context_variables 167 | results: list[AgentWorkflowResponse] = [] 168 | for tool_call in tool_calls: 169 | result = await self.handle_tool_call(tool_call, self.context_variables) 170 | assert isinstance(result.output, ToolResponse) 171 | self.context_variables = result.output.output.context_variables 172 | results.append(result) 173 | 174 | final_result: AgentWorkflowResponse | None = None 175 | for result in results: 176 | assert isinstance(result.output, ToolResponse) 177 | fn_result = result.output.output 178 | debug_print( 179 | self.debug, 180 | workflow.now(), 181 | f"{self.agent.name}: {fn_result.output}", 182 | ) 183 | self.messages.append( 184 | ConversationMessage( 185 | role="tool", 186 | content=fn_result.output, 187 | sender=self.agent.name, 188 | tool_call_id=result.output.tool_call_id, 189 | ) 190 | ) 191 | # Send the last tool call response back to orchestrator to be used for next call to agent 192 | # Check if any tool call response is an agent. If so, send it back to orchestrator for next call to agent 193 | # If there are multiple tool call returning an agent, the last one will be used. 
194 | if fn_result.agent: 195 | final_result = result 196 | 197 | if not final_result: 198 | final_result = results[-1] 199 | 200 | else: 201 | assert isinstance(response.content, str) 202 | self.messages.append( 203 | ConversationMessage( 204 | role="assistant", 205 | content=response.content, 206 | sender=self.agent.name, 207 | ) 208 | ) 209 | debug_print( 210 | self.debug, workflow.now(), f"{self.agent.name}: {response.content}" 211 | ) 212 | final_result = AgentWorkflowResponse( 213 | output=response.content, sender=self.agent.name 214 | ) 215 | 216 | return final_result, self.messages 217 | 218 | async def retrieve_context(self, message: ConversationMessage) -> str: 219 | try: 220 | retriever_result: str = await workflow.execute_activity( 221 | f"{self.agent.retriever.type}_retrieve_and_combine_results", 222 | message.content, 223 | result_type=str, 224 | start_to_close_timeout=self.start_to_close_timeout, 225 | retry_policy=self.retry_policy, 226 | ) 227 | debug_print( 228 | self.debug, workflow.now(), f"Retriever context: {retriever_result}" 229 | ) 230 | context_prompt = f"\nHere is the context to use to answer the user's question:\n{retriever_result}" 231 | return context_prompt 232 | except ActivityError as e: 233 | workflow.logger.error( 234 | f"Failed to retrieve context from retriever. Context will not be added. Error: {e.cause}" 235 | ) 236 | return "" 237 | 238 | async def handle_interrupt(self, tool_call: AgentToolCall): 239 | """ 240 | Handles interrupts at a specified point ("before" or "after"). 
241 | """ 242 | if tool_call.function.name in self.interrupt_map: 243 | self.interrupted.add(tool_call.id) 244 | interrupt = self.interrupt_map[tool_call.function.name] 245 | debug_print( 246 | self.debug, 247 | workflow.now(), 248 | f"Interrupt: {interrupt.question}", 249 | ) 250 | self.orchestrator.reply( 251 | self.task_id, 252 | ResumeRequest( 253 | tool_id=tool_call.id, 254 | tool_name=tool_call.function.name, 255 | question=interrupt.question, 256 | when=interrupt.when, 257 | tool_arguments=tool_call.function.arguments, 258 | task_id=self.task_id, 259 | ), 260 | ) 261 | 262 | await workflow.wait_condition(lambda: tool_call.id in self.resumed) 263 | 264 | debug_print( 265 | self.debug, 266 | workflow.now(), 267 | f"Interrupt: Resuming tool call '{tool_call.function.name}'", 268 | ) 269 | 270 | resume_params = self.resumed[tool_call.id] 271 | 272 | if resume_params.action == "reject": 273 | raise ToolRejectedError("Tool rejected.") from ValueError( 274 | f"Rejected by user. Reason: '{resume_params.content}'" 275 | ) 276 | del self.resumed[tool_call.id] 277 | 278 | async def handle_tool_call( 279 | self, 280 | tool_call: AgentToolCall, 281 | context_variables: ContextVariables, 282 | ) -> AgentWorkflowResponse: 283 | """ 284 | Handles a tool call, checks for interrupts before and after execution, 285 | and handles approval/rejection. 
286 | """ 287 | name = tool_call.function.name 288 | args = json.loads(tool_call.function.arguments) 289 | try: 290 | await self.handle_interrupt(tool_call) 291 | 292 | # Execute function in activity 293 | debug_print( 294 | self.debug, 295 | workflow.now(), 296 | f"Processing tool call: '{name}' with args: {args}", 297 | ) 298 | result: AgentExecuteFnResult = await workflow.execute_activity( 299 | f"{self.agent.type}_execute_function", 300 | ExecuteFunctionParams(name, args, context_variables), 301 | result_type=AgentExecuteFnResult, 302 | start_to_close_timeout=self.start_to_close_timeout, 303 | retry_policy=self.retry_policy, 304 | ) 305 | 306 | tool_response = ToolResponse( 307 | tool_call_id=tool_call.id, 308 | output=result, 309 | ) 310 | return AgentWorkflowResponse(output=tool_response, sender=self.agent.name) 311 | except (ActivityError, ToolRejectedError) as e: 312 | # If error, let the model know by sending error message in tool response 313 | workflow.logger.error( 314 | f"Failed to process tool call '{name}'. " 315 | f"Error will be sent to agent to reassess. 
Error: {e}" 316 | ) 317 | result = AgentExecuteFnResult( 318 | output=str(e.__cause__), context_variables=context_variables 319 | ) 320 | tool_response = ToolResponse(tool_call_id=tool_call.id, output=result) 321 | return AgentWorkflowResponse(output=tool_response, sender=self.agent.name) 322 | -------------------------------------------------------------------------------- /rojak/client.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from dataclasses import dataclass 3 | from typing import AsyncIterator, Literal, overload 4 | from temporalio.client import ( 5 | Client, 6 | Schedule, 7 | ScheduleActionStartWorkflow, 8 | ScheduleSpec, 9 | ScheduleHandle, 10 | WithStartWorkflowOperation, 11 | WorkflowHandle, 12 | ) 13 | from temporalio import workflow, common 14 | from temporalio.worker import Worker 15 | from rojak.types import MCPServerConfig, InitMcpResult 16 | 17 | with workflow.unsafe.imports_passed_through(): 18 | from mcp import Tool 19 | from rojak.retrievers import RetrieverActivities 20 | from rojak.agents import AgentActivities 21 | from rojak.mcp import MCPClient 22 | from rojak.workflows import ( 23 | OrchestratorResponse, 24 | OrchestratorParams, 25 | OrchestratorWorkflow, 26 | GetConfigResponse, 27 | UpdateConfigParams, 28 | TaskParams, 29 | ) 30 | from rojak.agents import ResumeRequest, ResumeResponse 31 | from uuid import uuid4 32 | 33 | 34 | @dataclass 35 | class RunResponse: 36 | id: str 37 | result: OrchestratorResponse | ResumeRequest 38 | task_id: str 39 | workflow_handle: WorkflowHandle 40 | 41 | 42 | class Rojak: 43 | def __init__(self, client: Client, task_queue: str): 44 | self.client: Client = client 45 | self.task_queue: str = task_queue 46 | self.mcp_result: InitMcpResult | None = None 47 | 48 | async def _init_mcp(self, servers: dict[str, MCPServerConfig]) -> None: 49 | """Initialise MCP servers. 
50 | 51 | Args: 52 | servers (dict[str, MCPServerConfig]): List of MCP servers. 53 | 54 | Returns: 55 | tuple[dict[str, MCPClient], dict[str, Tool], dict[str, str]]: Response as tuple. 56 | """ 57 | mcp_clients: dict[str, MCPClient] = {} 58 | mcp_tools: dict[str, Tool] = {} 59 | tool_client_mapping: dict[str, str] = {} 60 | for server_name, config in servers.items(): 61 | try: 62 | mcp_client = MCPClient() 63 | await mcp_client.connect_to_server(config) 64 | list_tools_result = await mcp_client.session.list_tools() 65 | mcp_clients[server_name] = mcp_client 66 | for tool in list_tools_result.tools: 67 | mcp_tools[tool.name] = tool 68 | tool_client_mapping[tool.name] = server_name 69 | except Exception as e: 70 | print(f"Unable to connect to MCP server. Skipping. Error: {e}") 71 | self.mcp_result = InitMcpResult(mcp_clients, mcp_tools, tool_client_mapping) 72 | if mcp_tools: 73 | print(f"MCP tools loaded: {list(mcp_tools.keys())}") 74 | 75 | async def cleanup_mcp(self): 76 | """Cleanup MCP connections.""" 77 | for client in list(self.mcp_result.clients.values())[::-1]: 78 | await client.cleanup() 79 | 80 | async def create_worker( 81 | self, 82 | agent_activities: list[AgentActivities], 83 | retriever_activities: list[RetrieverActivities] = [], 84 | mcp_servers: dict[str, MCPServerConfig] = {}, 85 | ) -> Worker: 86 | """Create a worker. 87 | 88 | Args: 89 | agent_activities (list[AgentActivities]): List of activity classes that can be called. 90 | retriever_activities (list[RetrieverActivities], optional): List of retriever activity classes that can be called. Defaults to []. 91 | mcp_servers (dict[str, MCPServerConfig], optional): Dictionary of MCP server configurations. Each key represents the server name, and the value is its corresponding MCPServerConfig object. Defaults to {}. 92 | 93 | Returns: 94 | Worker: A worker object that can be used to start the worker. 
95 | """ 96 | await self._init_mcp(mcp_servers) 97 | activities = [] 98 | for activity in agent_activities: 99 | if self.mcp_result: 100 | activity._add_mcp_configs(self.mcp_result) 101 | activities.append(activity.call) 102 | activities.append(activity.execute_function) 103 | activities.append(activity.execute_instructions) 104 | 105 | for retriever in retriever_activities: 106 | activities.append(retriever.retrieve_and_combine_results) 107 | 108 | worker: Worker = Worker( 109 | self.client, 110 | task_queue=self.task_queue, 111 | workflows=[OrchestratorWorkflow], 112 | activities=activities, 113 | ) 114 | 115 | return worker 116 | 117 | async def list_scheduled_runs( 118 | self, 119 | schedule_id: str, 120 | statuses: list[ 121 | Literal[ 122 | "Running", "Completed", "Failed", "Cancelled", "Terminated", "TimedOut" 123 | ] 124 | ] 125 | | None = None, 126 | limit: int = 10, 127 | page_size: int = 1000, 128 | next_page_token: bytes | None = None, 129 | ) -> AsyncIterator[str]: 130 | """List the ID of orchestrators associated with the schedule. 131 | 132 | Args: 133 | schedule_id (str): Unique identifier of the schedule. 134 | statuses (list[ Literal[ 'Running', 'Completed', 'Failed', 'Cancelled', 'Terminated', 'TimedOut' ] ] | None, optional): List of statuses to filter the runs. Defaults to None. 135 | limit (int, optional): Maximum number of IDs to return. Defaults to 10. 136 | page_size (int, optional): Maximum number of results per page. Defaults to 1000. 137 | next_page_token (bytes | None, optional): A previously obtained next page token if doing pagination. Usually not needed as the iterator automatically starts from the beginning. Defaults to None. 138 | 139 | Returns: 140 | AsyncIterator[str]: An async iterator that can be used with `async for`. 
141 | """ 142 | status_filter = ( 143 | " OR ".join(f'ExecutionStatus="{status}"' for status in statuses) 144 | if statuses 145 | else "" 146 | ) 147 | query = f'TemporalScheduledById="{schedule_id}"' 148 | if status_filter: 149 | query += f" AND ({status_filter})" 150 | 151 | async for workflow_execution in self.client.list_workflows( 152 | query=query, 153 | limit=limit, 154 | page_size=page_size, 155 | next_page_token=next_page_token, 156 | ): 157 | yield workflow_execution.id 158 | 159 | async def create_schedule( 160 | self, 161 | schedule_id: str, 162 | schedule_spec: ScheduleSpec, 163 | task: TaskParams, 164 | context_variables: dict = {}, 165 | max_turns: int = float("inf"), 166 | history_size: int = 10, 167 | debug: bool = False, 168 | ) -> ScheduleHandle: 169 | """ 170 | Create a schedule that periodically executes a workflow. 171 | 172 | The schedule periodically executes the equivalent of the `run()` method with the provided task, context variables, and configuration. 173 | 174 | Args: 175 | schedule_id (str): Unique identifier for the schedule. 176 | schedule_spec (ScheduleSpec): Specifies when the schedule executes, such as a cron schedule or interval. 177 | task (TaskParams): Encapsulates the agent, messages, and parameters for the workflow to run. 178 | context_variables (dict, optional): Additional variables available to functions and agent instructions. Defaults to {}. 179 | max_turns (int, optional): The maximum number of conversational turns allowed in the workflow. Defaults to float("inf"). 180 | history_size (int, optional): The maximum number of messages retained in the conversation history. Defaults to 10. 181 | debug (bool, optional): Enables debug logging if True. Defaults to False. 182 | 183 | Returns: 184 | ScheduleHandle: A handle to the created schedule, allowing management such as pausing, resuming, or deleting. 
185 | """ 186 | task_id = str(uuid4()) 187 | data = OrchestratorParams( 188 | context_variables=context_variables, 189 | max_turns=max_turns, 190 | tasks=deque([(task_id, task)]), 191 | debug=debug, 192 | type="stateless", 193 | history_size=history_size, 194 | ) 195 | 196 | return await self.client.create_schedule( 197 | schedule_id, 198 | Schedule( 199 | action=ScheduleActionStartWorkflow( 200 | OrchestratorWorkflow.run, 201 | data, 202 | id=schedule_id, 203 | task_queue=self.task_queue, 204 | ), 205 | spec=schedule_spec, 206 | ), 207 | ) 208 | 209 | @overload 210 | async def run( 211 | self, 212 | id: str, 213 | type: Literal["stateless", "persistent"], 214 | task: TaskParams, 215 | context_variables: dict = {}, 216 | max_turns: int = float("inf"), 217 | history_size: int = 10, 218 | debug: bool = False, 219 | ) -> RunResponse: ... 220 | 221 | @overload 222 | async def run(self, id: str, resume: ResumeResponse) -> RunResponse: ... 223 | 224 | async def run( 225 | self, 226 | id: str, 227 | type: Literal["stateless", "persistent"] | None = None, 228 | task: TaskParams | None = None, 229 | resume: ResumeResponse | None = None, 230 | context_variables: dict = {}, 231 | max_turns: int = float("inf"), 232 | history_size: int = 10, 233 | debug: bool = False, 234 | ) -> RunResponse: 235 | """ 236 | Initialize and execute an orchestrator with the provided inputs, handling tasks, resuming workflows, 237 | and waiting for completion. 238 | 239 | Requires a running worker. 240 | 241 | Args: 242 | id (str): Unique identifier of the orchestrator. 243 | type (Literal["stateless", "persistent"] | None, optional): Whether to keep track of prior conversations through long-running workflows. 244 | task (TaskParams | None, optional): A task to be executed in the orchestrator. Defaults to None. 245 | resume (ResumeResponse | None, optional): A resume object for continuing a paused workflow. Defaults to None. 
246 | context_variables (dict, optional): A dictionary of additional context variables available to functions 247 | and agent instructions. Defaults to an empty dictionary. 248 | max_turns (int, optional): The maximum number of conversational turns allowed. Defaults to infinity. 249 | history_size (int, optional): The maximum number of messages retained in the list before older messages are 250 | removed. When this limit is exceeded, older messages are removed. Defaults to 10. 251 | debug (bool, optional): If True, enables debug logging for the orchestrator. Defaults to False. 252 | 253 | Returns: 254 | RunResponse: 255 | - OrchestratorResponse: A response object containing updated messages, context variables, and agent 256 | information, if the workflow completes successfully. 257 | - ResumeRequest: A request to resume the workflow if the current state requires further inputs or actions. 258 | - WorkflowHandle: A handle to the orchestrator workflow. 259 | 260 | Notes: 261 | - If a `task` is provided, the method starts a new orchestrator workflow. 262 | - If a `resume` is provided, the method resumes the specified workflow. 263 | - A new workflow is initialized with `context_variables`, `max_turns`, and `debug` settings. 
264 | """ 265 | start_op = None 266 | task_id = str(uuid4()) 267 | if task: 268 | start_op = WithStartWorkflowOperation( 269 | OrchestratorWorkflow.run, 270 | OrchestratorParams( 271 | context_variables=context_variables, 272 | max_turns=max_turns, 273 | debug=debug, 274 | history_size=history_size, 275 | type=type, 276 | ), 277 | id=id, 278 | id_conflict_policy=common.WorkflowIDConflictPolicy.USE_EXISTING, 279 | task_queue=self.task_queue, 280 | ) 281 | result = await self.client.execute_update_with_start_workflow( 282 | OrchestratorWorkflow.add_task, 283 | (task_id, task), 284 | start_workflow_operation=start_op, 285 | result_type=OrchestratorResponse | ResumeRequest, 286 | ) 287 | workflow_handle = await start_op.workflow_handle() 288 | else: 289 | workflow_handle = self.client.get_workflow_handle_for( 290 | workflow=OrchestratorWorkflow, workflow_id=id 291 | ) 292 | 293 | result: ( 294 | OrchestratorResponse | ResumeRequest 295 | ) = await workflow_handle.execute_update( 296 | OrchestratorWorkflow.add_task, 297 | (task_id, resume), 298 | result_type=OrchestratorResponse | ResumeRequest, 299 | ) 300 | 301 | return RunResponse( 302 | id=workflow_handle.id, 303 | result=result, 304 | task_id=task_id, 305 | workflow_handle=workflow_handle, 306 | ) 307 | 308 | async def get_result( 309 | self, id: str, task_id: str | None 310 | ) -> OrchestratorResponse | ResumeRequest | None: 311 | """ 312 | Retrieve the latest or specific task result for a workflow. 313 | 314 | Requires a running worker. If `task_id` is provided, the result of that specific task 315 | is fetched; otherwise, the latest result of the workflow is returned. 316 | 317 | Args: 318 | id (str): The unique identifier of the workflow. 319 | task_id (str | None): The ID of the specific task to retrieve the result for. If None, retrieves the latest result. 320 | 321 | Returns: 322 | OrchestratorResponse: An object containing updated messages and context variables from the workflow. 
323 | """ 324 | workflow_handle = self.client.get_workflow_handle(id) 325 | if task_id is None: 326 | return await workflow_handle.query(OrchestratorWorkflow.get_latest_result) 327 | else: 328 | return await workflow_handle.query(OrchestratorWorkflow.get_result, task_id) 329 | 330 | async def get_config(self, id: str) -> GetConfigResponse: 331 | """ 332 | Retrieve the current configuration of a workflow session. 333 | 334 | Requires a running worker. 335 | 336 | Args: 337 | id (str): The unique identifier of the workflow session. 338 | 339 | Returns: 340 | GetConfigResponse: An object containing the current configuration values of the session. 341 | """ 342 | return await self.client.get_workflow_handle(id).query( 343 | OrchestratorWorkflow.get_config, result_type=GetConfigResponse 344 | ) 345 | 346 | async def update_config(self, id: str, params: UpdateConfigParams): 347 | """ 348 | Update the configuration of a workflow session. 349 | 350 | Requires a running worker. 351 | 352 | Args: 353 | id (str): The unique identifier of the workflow session. 354 | params (UpdateConfigParams): Configuration parameters to update. Only the values specified in `params` will be updated. 
355 | """ 356 | await self.client.get_workflow_handle(id).signal( 357 | OrchestratorWorkflow.update_config, params 358 | ) 359 | 360 | async def cancel(self, id: str): 361 | """Cancel the session.""" 362 | return await self.client.get_workflow_handle(id).cancel() 363 | -------------------------------------------------------------------------------- /tests/agents/test_openai_agent.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | import uuid 3 | import pytest 4 | from temporalio.testing import WorkflowEnvironment 5 | 6 | from rojak import Rojak 7 | from rojak.agents import ( 8 | OpenAIAgent, 9 | OpenAIAgentActivities, 10 | OpenAIAgentOptions, 11 | AgentExecuteFnResult, 12 | AgentInstructionOptions, 13 | Interrupt, 14 | ResumeRequest, 15 | ResumeResponse, 16 | ) 17 | from rojak.client import RunResponse 18 | from rojak.types import ( 19 | RetryOptions, 20 | RetryPolicy, 21 | ) 22 | from rojak.workflows import OrchestratorResponse, TaskParams 23 | from tests.mock_client import MockOpenAIClient, create_mock_response 24 | 25 | DEFAULT_RESPONSE_CONTENT = "sample response content" 26 | DEFAULT_RESPONSE_CONTENT_2 = "sample response content 2" 27 | 28 | 29 | @pytest.fixture 30 | def mock_openai_client(): 31 | m = MockOpenAIClient() 32 | m.set_response( 33 | create_mock_response({"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT}) 34 | ) 35 | return m 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_run_with_messages(mock_openai_client: MockOpenAIClient): 40 | task_queue_name = str(uuid.uuid4()) 41 | async with await WorkflowEnvironment.start_time_skipping() as env: 42 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 43 | openai_activities = OpenAIAgentActivities( 44 | OpenAIAgentOptions(client=mock_openai_client) 45 | ) 46 | worker = await rojak.create_worker([openai_activities]) 47 | 48 | async with worker: 49 | agent = OpenAIAgent(name="assistant") 50 | task = TaskParams( 
51 | agent=agent, 52 | messages=[{"role": "user", "content": "Hello how are you?"}], 53 | ) 54 | 55 | run_response: RunResponse = await rojak.run( 56 | id=str(uuid.uuid4()), 57 | type="stateless", 58 | task=task, 59 | ) 60 | 61 | # The final result can be an OrchestratorResponse or ResumeRequest 62 | assert isinstance(run_response.result, OrchestratorResponse) 63 | response: OrchestratorResponse = run_response.result 64 | assert response.messages[-1].role == "assistant" 65 | assert response.messages[-1].content == DEFAULT_RESPONSE_CONTENT 66 | 67 | 68 | @pytest.mark.asyncio 69 | async def test_get_result(mock_openai_client: MockOpenAIClient): 70 | """ 71 | Demonstrates that we can re-run the same workflow with no new TaskParams 72 | and still retrieve the last state (OrchestratorResponse). 73 | """ 74 | task_queue_name = str(uuid.uuid4()) 75 | async with await WorkflowEnvironment.start_time_skipping() as env: 76 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 77 | openai_activities = OpenAIAgentActivities( 78 | OpenAIAgentOptions(client=mock_openai_client) 79 | ) 80 | worker = await rojak.create_worker([openai_activities]) 81 | 82 | async with worker: 83 | agent = OpenAIAgent(name="assistant") 84 | workflow_id = str(uuid.uuid4()) 85 | 86 | # First run with an initial task 87 | task = TaskParams( 88 | agent=agent, 89 | messages=[{"role": "user", "content": "Hello how are you?"}], 90 | ) 91 | run_response: RunResponse = await rojak.run( 92 | id=workflow_id, 93 | type="stateless", 94 | task=task, 95 | ) 96 | 97 | response = await rojak.get_result( 98 | id=run_response.id, task_id=run_response.task_id 99 | ) 100 | 101 | assert isinstance(response, OrchestratorResponse) 102 | first_response: OrchestratorResponse = response 103 | assert first_response.messages[-1].content == DEFAULT_RESPONSE_CONTENT 104 | 105 | 106 | @pytest.mark.asyncio 107 | async def test_callable_instructions(mock_openai_client: MockOpenAIClient): 108 | task_queue_name = 
str(uuid.uuid4()) 109 | 110 | instruct_fn_mock = Mock() 111 | 112 | def instruct_fn(context_variables): 113 | res = f"My name is {context_variables.get('name')}" 114 | instruct_fn_mock(context_variables) 115 | instruct_fn_mock.return_value = res 116 | return res 117 | 118 | async with await WorkflowEnvironment.start_time_skipping() as env: 119 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 120 | openai_activities = OpenAIAgentActivities( 121 | OpenAIAgentOptions(client=mock_openai_client, all_functions=[instruct_fn]) 122 | ) 123 | worker = await rojak.create_worker([openai_activities]) 124 | 125 | async with worker: 126 | agent = OpenAIAgent( 127 | name="assistant", 128 | instructions=AgentInstructionOptions( 129 | type="function", name="instruct_fn" 130 | ), 131 | ) 132 | context_variables = {"name": "John"} 133 | task = TaskParams( 134 | agent=agent, 135 | messages=[{"role": "user", "content": "Hello how are you?"}], 136 | ) 137 | run_response: RunResponse = await rojak.run( 138 | id=str(uuid.uuid4()), 139 | type="stateless", 140 | task=task, 141 | context_variables=context_variables, 142 | ) 143 | 144 | assert isinstance(run_response.result, OrchestratorResponse) 145 | instruct_fn_mock.assert_called_once_with(context_variables) 146 | assert instruct_fn_mock.return_value == ( 147 | f"My name is {context_variables.get('name')}" 148 | ) 149 | 150 | 151 | @pytest.mark.asyncio 152 | async def test_failed_tool_call(mock_openai_client: MockOpenAIClient): 153 | """Context variable should be updated by first tool call only since 2nd tool call fails.""" 154 | task_queue_name = str(uuid.uuid4()) 155 | 156 | get_weather_mock = Mock() 157 | get_air_quality_mock = Mock() 158 | 159 | def get_weather(context_variables: dict): 160 | get_weather_mock() 161 | context_variables["seen"].append("get_weather") 162 | raise Exception("Something went wrong!") 163 | 164 | def get_air_quality(context_variables: dict): 165 | get_air_quality_mock() 166 | 
context_variables["seen"].append("get_air_quality") 167 | return AgentExecuteFnResult( 168 | output="Air quality is great!", context_variables=context_variables 169 | ) 170 | 171 | messages = [ 172 | { 173 | "role": "user", 174 | "content": "What's the weather and air quality like in San Francisco?", 175 | } 176 | ] 177 | 178 | # set mock to return a response that triggers function calls 179 | mock_openai_client.set_sequential_responses( 180 | [ 181 | create_mock_response( 182 | message={"role": "assistant", "content": ""}, 183 | function_calls=[{"name": "get_air_quality", "args": {}}], 184 | ), 185 | create_mock_response( 186 | message={"role": "assistant", "content": ""}, 187 | function_calls=[{"name": "get_weather", "args": {}}], 188 | ), 189 | create_mock_response( 190 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 191 | ), 192 | ] 193 | ) 194 | 195 | async with await WorkflowEnvironment.start_time_skipping() as env: 196 | agent = OpenAIAgent( 197 | name="Test Agent", 198 | functions=["get_weather", "get_air_quality"], 199 | retry_options=RetryOptions(retry_policy=RetryPolicy(maximum_attempts=5)), 200 | ) 201 | openai_activities = OpenAIAgentActivities( 202 | OpenAIAgentOptions( 203 | client=mock_openai_client, all_functions=[get_weather, get_air_quality] 204 | ) 205 | ) 206 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 207 | worker = await rojak.create_worker([openai_activities]) 208 | 209 | async with worker: 210 | context_vars = {"seen": ["test"]} 211 | task = TaskParams(agent=agent, messages=messages) 212 | run_response: RunResponse = await rojak.run( 213 | id=str(uuid.uuid4()), 214 | type="stateless", 215 | task=task, 216 | context_variables=context_vars, 217 | ) 218 | 219 | assert isinstance(run_response.result, OrchestratorResponse) 220 | orchestrator_resp: OrchestratorResponse = run_response.result 221 | 222 | get_weather_mock.assert_called() 223 | get_air_quality_mock.assert_called_once() 224 | assert 
orchestrator_resp.context_variables["seen"] == [ 225 | "test", 226 | "get_air_quality", 227 | ] 228 | assert orchestrator_resp.messages[-1].role == "assistant" 229 | assert orchestrator_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT 230 | 231 | 232 | @pytest.mark.asyncio 233 | async def test_multiple_tool_calls(mock_openai_client: MockOpenAIClient): 234 | task_queue_name = str(uuid.uuid4()) 235 | 236 | expected_location = "San Francisco" 237 | 238 | get_weather_mock = Mock() 239 | get_air_quality_mock = Mock() 240 | 241 | def get_weather(location: str, context_variables: dict): 242 | get_weather_mock(location=location) 243 | context_variables["seen"].append("get_weather") 244 | res = f"It's sunny today in {location}" 245 | return AgentExecuteFnResult(output=res, context_variables=context_variables) 246 | 247 | def get_air_quality(location: str, context_variables: dict): 248 | get_air_quality_mock(location=location) 249 | context_variables["seen"].append("get_air_quality") 250 | res = f"Air quality in {location} is good!" 
251 | return AgentExecuteFnResult(output=res, context_variables=context_variables) 252 | 253 | messages = [ 254 | { 255 | "role": "user", 256 | "content": "What's the weather and air quality like in San Francisco?", 257 | } 258 | ] 259 | 260 | mock_openai_client.set_sequential_responses( 261 | [ 262 | create_mock_response( 263 | message={"role": "assistant", "content": ""}, 264 | function_calls=[ 265 | {"name": "get_weather", "args": {"location": expected_location}}, 266 | { 267 | "name": "get_air_quality", 268 | "args": {"location": expected_location}, 269 | }, 270 | ], 271 | ), 272 | create_mock_response( 273 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 274 | ), 275 | ] 276 | ) 277 | 278 | async with await WorkflowEnvironment.start_time_skipping() as env: 279 | agent = OpenAIAgent( 280 | name="Test Agent", 281 | functions=["get_weather", "get_air_quality"], 282 | ) 283 | openai_activities = OpenAIAgentActivities( 284 | OpenAIAgentOptions( 285 | client=mock_openai_client, all_functions=[get_weather, get_air_quality] 286 | ) 287 | ) 288 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 289 | worker = await rojak.create_worker([openai_activities]) 290 | 291 | async with worker: 292 | context_vars = {"location": expected_location, "seen": []} 293 | task = TaskParams(agent=agent, messages=messages) 294 | run_response: RunResponse = await rojak.run( 295 | id=str(uuid.uuid4()), 296 | type="stateless", 297 | task=task, 298 | context_variables=context_vars, 299 | ) 300 | 301 | assert isinstance(run_response.result, OrchestratorResponse) 302 | orchestrator_resp: OrchestratorResponse = run_response.result 303 | 304 | get_weather_mock.assert_called_once_with(location=expected_location) 305 | get_air_quality_mock.assert_called_once_with(location=expected_location) 306 | assert "get_weather" in orchestrator_resp.context_variables["seen"] 307 | assert "get_air_quality" in orchestrator_resp.context_variables["seen"] 308 | assert 
orchestrator_resp.messages[-1].role == "assistant" 309 | assert orchestrator_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT 310 | 311 | 312 | @pytest.mark.asyncio 313 | async def test_handoff(mock_openai_client: MockOpenAIClient): 314 | task_queue_name = str(uuid.uuid4()) 315 | 316 | def transfer_to_agent2(context_variables: dict): 317 | # Transfer to another agent 318 | return AgentExecuteFnResult( 319 | output="Handoff to agent2", 320 | context_variables=context_variables, 321 | agent=agent2, 322 | ) 323 | 324 | agent1 = OpenAIAgent(name="Test Agent 1", functions=["transfer_to_agent2"]) 325 | agent2 = OpenAIAgent(name="Test Agent 2") 326 | 327 | # mock that triggers the handoff 328 | mock_openai_client.set_sequential_responses( 329 | [ 330 | create_mock_response( 331 | message={"role": "assistant", "content": ""}, 332 | function_calls=[{"name": "transfer_to_agent2"}], 333 | ), 334 | create_mock_response( 335 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 336 | ), 337 | ] 338 | ) 339 | 340 | async with await WorkflowEnvironment.start_time_skipping() as env: 341 | openai_activities = OpenAIAgentActivities( 342 | OpenAIAgentOptions( 343 | client=mock_openai_client, all_functions=[transfer_to_agent2] 344 | ) 345 | ) 346 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 347 | worker = await rojak.create_worker([openai_activities]) 348 | 349 | async with worker: 350 | task = TaskParams( 351 | agent=agent1, 352 | messages=[{"role": "user", "content": "I want to talk to agent 2"}], 353 | ) 354 | run_response: RunResponse = await rojak.run( 355 | id=str(uuid.uuid4()), 356 | type="stateless", 357 | task=task, 358 | ) 359 | assert isinstance(run_response.result, OrchestratorResponse) 360 | orchestrator_resp: OrchestratorResponse = run_response.result 361 | assert orchestrator_resp.agent == agent2 362 | assert orchestrator_resp.messages[-1].role == "assistant" 363 | assert orchestrator_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT 


@pytest.mark.asyncio
async def test_send_multiple_messages(mock_openai_client: MockOpenAIClient):
    """
    Demonstrates sending multiple user messages in separate calls to the same workflow.
    """
    task_queue_name = str(uuid.uuid4())

    # We want two different assistant replies in sequence
    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT},
            ),
            create_mock_response(
                message={"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2},
            ),
        ]
    )

    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        openai_activities = OpenAIAgentActivities(
            OpenAIAgentOptions(client=mock_openai_client)
        )
        worker = await rojak.create_worker([openai_activities])

        async with worker:
            agent = OpenAIAgent(name="assistant")
            workflow_id = str(uuid.uuid4())

            # First user message
            task_1 = TaskParams(
                agent=agent,
                messages=[{"role": "user", "content": "Hello how are you?"}],
            )
            run_response_1: RunResponse = await rojak.run(
                id=workflow_id,
                type="persistent",
                task=task_1,
            )
            assert isinstance(run_response_1.result, OrchestratorResponse)
            response_1: OrchestratorResponse = run_response_1.result
            assert response_1.messages[-1].role == "assistant"
            assert response_1.messages[-1].content == DEFAULT_RESPONSE_CONTENT

            # Second user message (same workflow_id)
            task_2 = TaskParams(
                agent=agent,
                messages=[{"role": "user", "content": "What's new today?"}],
            )
            run_response_2: RunResponse = await rojak.run(
                id=workflow_id,
                type="persistent",
                task=task_2,
            )
            assert isinstance(run_response_2.result, OrchestratorResponse)
            response_2: OrchestratorResponse = run_response_2.result
            assert response_2.messages[-1].role == "assistant"
            assert response_2.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2


@pytest.mark.asyncio
async def test_result(mock_openai_client: MockOpenAIClient):
    task_queue_name = str(uuid.uuid4())

    def transfer_agent_b(context_variables: dict):
        # Flag that the handoff function ran, then switch to Agent B.
        context_variables["seen"] = True
        return AgentExecuteFnResult(
            output="Transferred to Agent B",
            context_variables=context_variables,
            agent=agent_b,
        )

    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": ""},
                function_calls=[{"name": "transfer_agent_b", "args": {}}],
            ),
            create_mock_response(
                message={"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2},
            ),
        ]
    )

    agent_a = OpenAIAgent(name="Agent A", functions=["transfer_agent_b"])
    agent_b = OpenAIAgent(name="Agent B")

    async with await WorkflowEnvironment.start_time_skipping() as env:
        openai_activities = OpenAIAgentActivities(
            OpenAIAgentOptions(
                client=mock_openai_client, all_functions=[transfer_agent_b]
            )
        )
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        worker = await rojak.create_worker([openai_activities])

        async with worker:
            context_vars = {"seen": False}
            task = TaskParams(
                agent=agent_a,
                messages=[{"role": "user", "content": "I want to talk to agent B"}],
            )
            run_response: RunResponse = await rojak.run(
                id=str(uuid.uuid4()),
                type="persistent",
                task=task,
                context_variables=context_vars,
            )

            assert isinstance(run_response.result, OrchestratorResponse)
            orchestrator_resp: OrchestratorResponse = run_response.result
            assert orchestrator_resp.context_variables["seen"] is True
            assert orchestrator_resp.agent == agent_b
            assert orchestrator_resp.messages[-1].role == "assistant"
            assert orchestrator_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2


@pytest.mark.asyncio
async def test_interrupt_and_approve(mock_openai_client: MockOpenAIClient):
    """
    Demonstrates how the orchestrator interrupts a function call,
    returns a ResumeRequest, and how we "approve" the call
    by calling rojak.run(..., resume=ResumeResponse(...)).
    """
    task_queue_name = str(uuid.uuid4())

    def say_hello():
        say_hello_mock()
        return "Hello!"

    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": ""},
                function_calls=[{"name": "say_hello", "args": {}}],
            ),
            create_mock_response(
                {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2}
            ),
        ]
    )

    agent = OpenAIAgent(
        functions=["say_hello"],
        interrupts=[Interrupt("say_hello")],
    )

    say_hello_mock = Mock()

    openai_activities = OpenAIAgentActivities(
        OpenAIAgentOptions(
            client=mock_openai_client,
            all_functions=[say_hello],
        )
    )
    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        worker = await rojak.create_worker([openai_activities])

        async with worker:
            workflow_id = str(uuid.uuid4())
            first_task = TaskParams(
                agent=agent,
                messages=[{"role": "user", "content": "Hello"}],
            )
            run_resp_1 = await rojak.run(
                id=workflow_id, type="persistent", task=first_task
            )

            # We expect a ResumeRequest because we have an interrupt
            assert isinstance(run_resp_1.result, ResumeRequest)
            resume_req = run_resp_1.result
            assert resume_req.tool_name == "say_hello"
            tool_id = resume_req.tool_id  # We'll need this to resume

            approve_resume = ResumeResponse(action="approve", tool_id=tool_id)
            run_resp_2 = await rojak.run(
                id=workflow_id,
                resume=approve_resume,
            )

            # This time, we should get an OrchestratorResponse
            assert isinstance(run_resp_2.result, OrchestratorResponse)
            orch_resp = run_resp_2.result

            say_hello_mock.assert_called_once()
            assert orch_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2


@pytest.mark.asyncio
async def test_interrupt_and_reject(mock_openai_client: MockOpenAIClient):
    """
    Demonstrates how we 'reject' an interrupted function call.
    The orchestrator will skip calling the function and continue.
    """
    task_queue_name = str(uuid.uuid4())

    def say_hello():
        say_hello_mock()
        return "Hello!"

    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": ""},
                function_calls=[{"name": "say_hello", "args": {}}],
            ),
            create_mock_response(
                {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2}
            ),
        ]
    )

    agent = OpenAIAgent(
        functions=["say_hello"],
        interrupts=[Interrupt("say_hello")],
    )

    say_hello_mock = Mock()

    openai_activities = OpenAIAgentActivities(
        OpenAIAgentOptions(
            client=mock_openai_client,
            all_functions=[say_hello],
        )
    )
    async with await WorkflowEnvironment.start_time_skipping() as env:
        rojak = Rojak(client=env.client, task_queue=task_queue_name)
        worker = await rojak.create_worker([openai_activities])

        async with worker:
            workflow_id = str(uuid.uuid4())
            first_task = TaskParams(
                agent=agent,
                messages=[{"role": "user", "content": "Hello"}],
            )
            run_resp_1 = await rojak.run(
                id=workflow_id, type="persistent", task=first_task
            )

            # We expect a ResumeRequest because we have an interrupt
            assert isinstance(run_resp_1.result, ResumeRequest)
            resume_req = run_resp_1.result
            assert resume_req.tool_name == "say_hello"
            tool_id = resume_req.tool_id  # We'll need this to resume

            # Reject the function call
            reject_reason = "User does not want this."
            reject_resume = ResumeResponse(
                action="reject", tool_id=tool_id, content=reject_reason
            )
            run_resp_2 = await rojak.run(
                id=workflow_id,
                resume=reject_resume,
            )

            # This time, we should get an OrchestratorResponse
            assert isinstance(run_resp_2.result, OrchestratorResponse)
            orch_resp = run_resp_2.result

            # The function should not have been called
            say_hello_mock.assert_not_called()
            assert reject_reason in orch_resp.messages[-2].content
            assert orch_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2

# --- tests/agents/test_anthropic_agent.py ---

from unittest.mock import Mock
import uuid
import pytest
from temporalio.testing import WorkflowEnvironment

from rojak import Rojak
from rojak.agents import (
    AgentExecuteFnResult,
    AgentInstructionOptions,
    AnthropicAgentActivities,
    AnthropicAgentOptions,
    AnthropicAgent,
    Interrupt,
    ResumeRequest,
    ResumeResponse,
)
from rojak.client import RunResponse
from rojak.types import (
    ConversationMessage,
    RetryOptions,
    RetryPolicy,
)
from rojak.workflows import OrchestratorResponse, TaskParams

from tests.mock_anthropic_client import (
    MockAnthropicClient,
    create_mock_response,
)

DEFAULT_RESPONSE_CONTENT = "sample response content"
DEFAULT_RESPONSE_CONTENT_2 = "sample response content 2" 32 | 33 | 34 | @pytest.fixture 35 | def mock_anthropic_client(): 36 | m = MockAnthropicClient() 37 | m.set_response( 38 | create_mock_response({"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT}) 39 | ) 40 | return m 41 | 42 | 43 | @pytest.mark.asyncio 44 | async def test_run_with_messages(mock_anthropic_client: MockAnthropicClient): 45 | """ 46 | Test a single run call with user messages and verify assistant reply. 47 | """ 48 | task_queue_name = str(uuid.uuid4()) 49 | async with await WorkflowEnvironment.start_time_skipping() as env: 50 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 51 | anthropic_activities = AnthropicAgentActivities( 52 | AnthropicAgentOptions(client=mock_anthropic_client) 53 | ) 54 | worker = await rojak.create_worker([anthropic_activities]) 55 | 56 | async with worker: 57 | agent = AnthropicAgent(name="assistant") 58 | task = TaskParams( 59 | agent=agent, 60 | messages=[{"role": "user", "content": "Hello how are you?"}], 61 | ) 62 | run_response: RunResponse = await rojak.run( 63 | id=str(uuid.uuid4()), 64 | type="stateless", 65 | task=task, 66 | ) 67 | # Verify result is an OrchestratorResponse 68 | assert isinstance(run_response.result, OrchestratorResponse) 69 | response: OrchestratorResponse = run_response.result 70 | assert response.messages[-1].role == "assistant" 71 | assert response.messages[-1].content == DEFAULT_RESPONSE_CONTENT 72 | 73 | 74 | @pytest.mark.asyncio 75 | async def test_get_result(mock_anthropic_client: MockAnthropicClient): 76 | """ 77 | Demonstrates calling run again with no new task to retrieve the final state. 
78 | """ 79 | task_queue_name = str(uuid.uuid4()) 80 | async with await WorkflowEnvironment.start_time_skipping() as env: 81 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 82 | anthropic_activities = AnthropicAgentActivities( 83 | AnthropicAgentOptions(client=mock_anthropic_client) 84 | ) 85 | worker = await rojak.create_worker([anthropic_activities]) 86 | 87 | async with worker: 88 | agent = AnthropicAgent(name="assistant") 89 | workflow_id = str(uuid.uuid4()) 90 | 91 | # First call with a user message 92 | task = TaskParams( 93 | agent=agent, 94 | messages=[{"role": "user", "content": "Hello how are you?"}], 95 | ) 96 | run_response = await rojak.run( 97 | id=workflow_id, 98 | type="stateless", 99 | task=task, 100 | ) 101 | assert isinstance(run_response.result, OrchestratorResponse) 102 | 103 | response = await rojak.get_result( 104 | id=run_response.id, task_id=run_response.task_id 105 | ) 106 | 107 | assert isinstance(response, OrchestratorResponse) 108 | first_response: OrchestratorResponse = response 109 | assert first_response.messages[-1].content == DEFAULT_RESPONSE_CONTENT 110 | 111 | 112 | @pytest.mark.asyncio 113 | async def test_callable_instructions(mock_anthropic_client: MockAnthropicClient): 114 | """ 115 | Test agent instructions invoked via a Python callable. 
116 | """ 117 | task_queue_name = str(uuid.uuid4()) 118 | instruct_fn_mock = Mock() 119 | 120 | def instruct_fn(context_variables): 121 | res = f"My name is {context_variables.get('name')}" 122 | instruct_fn_mock(context_variables) 123 | instruct_fn_mock.return_value = res 124 | return res 125 | 126 | async with await WorkflowEnvironment.start_time_skipping() as env: 127 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 128 | anthropic_activities = AnthropicAgentActivities( 129 | AnthropicAgentOptions( 130 | client=mock_anthropic_client, all_functions=[instruct_fn] 131 | ) 132 | ) 133 | worker = await rojak.create_worker([anthropic_activities]) 134 | 135 | async with worker: 136 | agent = AnthropicAgent( 137 | name="assistant", 138 | instructions=AgentInstructionOptions( 139 | type="function", name="instruct_fn" 140 | ), 141 | ) 142 | context_variables = {"name": "John"} 143 | task = TaskParams( 144 | agent=agent, 145 | messages=[{"role": "user", "content": "Hello how are you?"}], 146 | ) 147 | run_response = await rojak.run( 148 | id=str(uuid.uuid4()), 149 | type="stateless", 150 | task=task, 151 | context_variables=context_variables, 152 | ) 153 | assert isinstance(run_response.result, OrchestratorResponse) 154 | 155 | instruct_fn_mock.assert_called_once_with(context_variables) 156 | assert instruct_fn_mock.return_value == "My name is John" 157 | 158 | 159 | @pytest.mark.asyncio 160 | async def test_failed_tool_call(mock_anthropic_client: MockAnthropicClient): 161 | """ 162 | Context variable should be updated by first tool call only since 2nd tool call fails. 
163 | """ 164 | task_queue_name = str(uuid.uuid4()) 165 | get_weather_mock = Mock() 166 | get_air_quality_mock = Mock() 167 | 168 | def get_weather(context_variables: dict): 169 | get_weather_mock() 170 | context_variables["seen"].append("get_weather") 171 | raise Exception("Something went wrong!") 172 | 173 | def get_air_quality(context_variables: dict): 174 | get_air_quality_mock() 175 | context_variables["seen"].append("get_air_quality") 176 | return AgentExecuteFnResult( 177 | output="Air quality is great!", context_variables=context_variables 178 | ) 179 | 180 | # set mock to return a response that triggers function calls 181 | mock_anthropic_client.set_sequential_responses( 182 | [ 183 | create_mock_response( 184 | message={"role": "assistant", "content": ""}, 185 | function_calls=[{"name": "get_air_quality", "args": {}}], 186 | ), 187 | create_mock_response( 188 | message={"role": "assistant", "content": ""}, 189 | function_calls=[{"name": "get_weather", "args": {}}], 190 | ), 191 | create_mock_response( 192 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 193 | ), 194 | ] 195 | ) 196 | 197 | async with await WorkflowEnvironment.start_time_skipping() as env: 198 | agent = AnthropicAgent( 199 | name="Test Agent", 200 | functions=["get_weather", "get_air_quality"], 201 | retry_options=RetryOptions(retry_policy=RetryPolicy(maximum_attempts=5)), 202 | ) 203 | anthropic_activities = AnthropicAgentActivities( 204 | AnthropicAgentOptions( 205 | client=mock_anthropic_client, 206 | all_functions=[get_weather, get_air_quality], 207 | ) 208 | ) 209 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 210 | worker = await rojak.create_worker([anthropic_activities]) 211 | 212 | async with worker: 213 | context_variables = {"seen": ["test"]} 214 | task = TaskParams( 215 | agent=agent, 216 | messages=[ 217 | { 218 | "role": "user", 219 | "content": "What's the weather and air quality like in San Francisco?", 220 | } 221 | ], 222 | ) 223 | run_response 
= await rojak.run( 224 | id=str(uuid.uuid4()), 225 | type="stateless", 226 | task=task, 227 | context_variables=context_variables, 228 | ) 229 | assert isinstance(run_response.result, OrchestratorResponse) 230 | resp: OrchestratorResponse = run_response.result 231 | 232 | get_weather_mock.assert_called() 233 | get_air_quality_mock.assert_called_once() 234 | assert resp.context_variables["seen"] == ["test", "get_air_quality"] 235 | assert resp.messages[-1].role == "assistant" 236 | assert resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT 237 | 238 | 239 | @pytest.mark.asyncio 240 | async def test_multiple_tool_calls(mock_anthropic_client: MockAnthropicClient): 241 | """ 242 | Multiple tool calls returned in a single Anthropic response. 243 | """ 244 | task_queue_name = str(uuid.uuid4()) 245 | expected_location = "San Francisco" 246 | get_weather_mock = Mock() 247 | get_air_quality_mock = Mock() 248 | 249 | def get_weather(location: str, context_variables: dict): 250 | get_weather_mock(location=location) 251 | context_variables["seen"].append("get_weather") 252 | res = f"It's sunny today in {location}" 253 | return AgentExecuteFnResult(output=res, context_variables=context_variables) 254 | 255 | def get_air_quality(location: str, context_variables: dict): 256 | get_air_quality_mock(location=location) 257 | context_variables["seen"].append("get_air_quality") 258 | res = f"Air quality in {location} is good!" 
259 | return AgentExecuteFnResult(output=res, context_variables=context_variables) 260 | 261 | # mock that triggers both function calls 262 | mock_anthropic_client.set_sequential_responses( 263 | [ 264 | create_mock_response( 265 | message={"role": "assistant", "content": ""}, 266 | function_calls=[ 267 | {"name": "get_weather", "args": {"location": expected_location}}, 268 | { 269 | "name": "get_air_quality", 270 | "args": {"location": expected_location}, 271 | }, 272 | ], 273 | ), 274 | create_mock_response( 275 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 276 | ), 277 | ] 278 | ) 279 | 280 | async with await WorkflowEnvironment.start_time_skipping() as env: 281 | agent = AnthropicAgent( 282 | name="Test Agent", 283 | functions=["get_weather", "get_air_quality"], 284 | ) 285 | anthropic_activities = AnthropicAgentActivities( 286 | AnthropicAgentOptions( 287 | client=mock_anthropic_client, 288 | all_functions=[get_weather, get_air_quality], 289 | ) 290 | ) 291 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 292 | worker = await rojak.create_worker([anthropic_activities]) 293 | 294 | async with worker: 295 | context_vars = {"location": expected_location, "seen": []} 296 | task = TaskParams( 297 | agent=agent, 298 | messages=[ 299 | { 300 | "role": "user", 301 | "content": "What's the weather and air quality like in San Francisco?", 302 | } 303 | ], 304 | ) 305 | run_response = await rojak.run( 306 | id=str(uuid.uuid4()), 307 | type="stateless", 308 | task=task, 309 | context_variables=context_vars, 310 | ) 311 | assert isinstance(run_response.result, OrchestratorResponse) 312 | resp: OrchestratorResponse = run_response.result 313 | 314 | get_weather_mock.assert_called_once_with(location=expected_location) 315 | get_air_quality_mock.assert_called_once_with(location=expected_location) 316 | assert "get_weather" in resp.context_variables["seen"] 317 | assert "get_air_quality" in resp.context_variables["seen"] 318 | assert 
resp.messages[-1].role == "assistant" 319 | assert resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT 320 | 321 | 322 | @pytest.mark.asyncio 323 | async def test_handoff(mock_anthropic_client: MockAnthropicClient): 324 | """ 325 | Agent A calls a function that returns agent B, verifying a handoff. 326 | """ 327 | task_queue_name = str(uuid.uuid4()) 328 | 329 | def transfer_to_agent2(context_variables: dict): 330 | return AgentExecuteFnResult( 331 | output="handoff to agent2", 332 | context_variables=context_variables, 333 | agent=agent2, 334 | ) 335 | 336 | agent1 = AnthropicAgent(name="Test Agent 1", functions=["transfer_to_agent2"]) 337 | agent2 = AnthropicAgent(name="Test Agent 2") 338 | 339 | mock_anthropic_client.set_sequential_responses( 340 | [ 341 | create_mock_response( 342 | message={"role": "assistant", "content": ""}, 343 | function_calls=[{"name": "transfer_to_agent2"}], 344 | ), 345 | create_mock_response( 346 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 347 | ), 348 | ] 349 | ) 350 | 351 | async with await WorkflowEnvironment.start_time_skipping() as env: 352 | anthropic_activities = AnthropicAgentActivities( 353 | AnthropicAgentOptions( 354 | client=mock_anthropic_client, all_functions=[transfer_to_agent2] 355 | ) 356 | ) 357 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 358 | worker = await rojak.create_worker([anthropic_activities]) 359 | 360 | async with worker: 361 | task = TaskParams( 362 | agent=agent1, 363 | messages=[{"role": "user", "content": "I want to talk to agent 2"}], 364 | ) 365 | run_response = await rojak.run( 366 | id=str(uuid.uuid4()), 367 | type="stateless", 368 | task=task, 369 | ) 370 | assert isinstance(run_response.result, OrchestratorResponse) 371 | resp: OrchestratorResponse = run_response.result 372 | assert resp.agent == agent2 373 | assert resp.messages[-1].role == "assistant" 374 | assert resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT 375 | 376 | 377 | @pytest.mark.asyncio 378 
| async def test_send_multiple_messages(mock_anthropic_client: MockAnthropicClient): 379 | """ 380 | Demonstrates sending multiple user messages by calling run repeatedly. 381 | """ 382 | task_queue_name = str(uuid.uuid4()) 383 | 384 | # Two distinct assistant responses in sequence 385 | mock_anthropic_client.set_sequential_responses( 386 | [ 387 | create_mock_response( 388 | message={"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT}, 389 | ), 390 | create_mock_response( 391 | message={"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2}, 392 | ), 393 | ] 394 | ) 395 | 396 | async with await WorkflowEnvironment.start_time_skipping() as env: 397 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 398 | anthropic_activities = AnthropicAgentActivities( 399 | AnthropicAgentOptions(client=mock_anthropic_client) 400 | ) 401 | worker = await rojak.create_worker([anthropic_activities]) 402 | 403 | async with worker: 404 | agent = AnthropicAgent(name="assistant") 405 | workflow_id = str(uuid.uuid4()) 406 | 407 | # First user message 408 | task_1 = TaskParams( 409 | agent=agent, 410 | messages=[{"role": "user", "content": "Hello how are you?"}], 411 | ) 412 | run_response_1 = await rojak.run( 413 | id=workflow_id, 414 | type="persistent", 415 | task=task_1, 416 | ) 417 | assert isinstance(run_response_1.result, OrchestratorResponse) 418 | resp_1: OrchestratorResponse = run_response_1.result 419 | assert resp_1.agent == agent 420 | assert resp_1.messages[-1].role == "assistant" 421 | assert resp_1.messages[-1].content == DEFAULT_RESPONSE_CONTENT 422 | 423 | # Second user message (same workflow_id) 424 | task_2 = TaskParams( 425 | agent=agent, 426 | messages=[{"role": "user", "content": "Hello again?"}], 427 | ) 428 | run_response_2 = await rojak.run( 429 | id=workflow_id, 430 | task=task_2, 431 | ) 432 | assert isinstance(run_response_2.result, OrchestratorResponse) 433 | resp_2: OrchestratorResponse = run_response_2.result 434 | assert resp_2.agent 
== agent 435 | assert resp_2.messages[-1].role == "assistant" 436 | assert resp_2.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2 437 | 438 | 439 | @pytest.mark.asyncio 440 | async def test_persistent_state_across_calls( 441 | mock_anthropic_client: MockAnthropicClient, 442 | ): 443 | """ 444 | Shows how we can accumulate context over multiple run() calls, effectively 445 | replacing 'session' tests from earlier versions of Rojak. The second call 446 | uses `rojak.get_result(...)` rather than `run(..., task=None)`. 447 | """ 448 | task_queue_name = str(uuid.uuid4()) 449 | 450 | def transfer_agent_b(context_variables: dict): 451 | context_variables["seen"] = True 452 | return AgentExecuteFnResult( 453 | output="Transferred to Agent B", 454 | context_variables=context_variables, 455 | agent=agent_b, 456 | ) 457 | 458 | # The mock response triggers a function call, then a final assistant message 459 | mock_anthropic_client.set_sequential_responses( 460 | [ 461 | create_mock_response( 462 | message={"role": "assistant", "content": ""}, 463 | function_calls=[{"name": "transfer_agent_b", "args": {}}], 464 | ), 465 | create_mock_response( 466 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT} 467 | ), 468 | ] 469 | ) 470 | 471 | async with await WorkflowEnvironment.start_time_skipping() as env: 472 | agent_a = AnthropicAgent(name="Agent A", functions=["transfer_agent_b"]) 473 | agent_b = AnthropicAgent(name="Agent B") 474 | 475 | anthropic_activities = AnthropicAgentActivities( 476 | AnthropicAgentOptions( 477 | client=mock_anthropic_client, 478 | all_functions=[transfer_agent_b], 479 | ) 480 | ) 481 | 482 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 483 | worker = await rojak.create_worker([anthropic_activities]) 484 | 485 | async with worker: 486 | workflow_id = str(uuid.uuid4()) 487 | 488 | # --- First call: start a persistent workflow with a user message --- 489 | task_1 = TaskParams( 490 | agent=agent_a, 491 | messages=[{"role": "user", 
"content": "I want to talk to agent B"}], 492 | ) 493 | run_response_1: RunResponse = await rojak.run( 494 | id=workflow_id, 495 | type="persistent", 496 | task=task_1, 497 | context_variables={"seen": False}, 498 | ) 499 | assert isinstance(run_response_1.result, OrchestratorResponse) 500 | resp_1: OrchestratorResponse = run_response_1.result 501 | 502 | # Verify that the agent was handed off to agent_b and the context updated 503 | assert resp_1.context_variables["seen"] is True 504 | assert resp_1.agent == agent_b 505 | assert resp_1.messages[-1].content == DEFAULT_RESPONSE_CONTENT 506 | 507 | # --- Second call: retrieve the final state (no new task) --- 508 | # We need the task_id from run_response_1 to query that same orchestrator state. 509 | final_response: OrchestratorResponse = await rojak.get_result( 510 | id=workflow_id, 511 | task_id=run_response_1.task_id, 512 | ) 513 | 514 | # Confirm the final state matches what we expect 515 | assert final_response.context_variables["seen"] is True 516 | assert final_response.agent == agent_b 517 | assert final_response.messages[-1].content == DEFAULT_RESPONSE_CONTENT 518 | 519 | 520 | @pytest.mark.asyncio 521 | async def test_interrupt_and_approve(mock_anthropic_client: MockAnthropicClient): 522 | """ 523 | Demonstrates how the orchestrator interrupts a function call, 524 | returns a ResumeRequest, and how we "approve" the call 525 | by calling rojak.run(..., resume=ResumeResponse(...)). 526 | """ 527 | task_queue_name = str(uuid.uuid4()) 528 | 529 | def say_hello(): 530 | say_hello_mock() 531 | return "Hello!" 
532 | 533 | mock_anthropic_client.set_sequential_responses( 534 | [ 535 | create_mock_response( 536 | message={"role": "assistant", "content": ""}, 537 | function_calls=[{"name": "say_hello", "args": {}}], 538 | ), 539 | create_mock_response( 540 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2} 541 | ), 542 | ] 543 | ) 544 | 545 | agent = AnthropicAgent( 546 | functions=["say_hello"], 547 | interrupts=[Interrupt("say_hello")], 548 | ) 549 | 550 | say_hello_mock = Mock() 551 | 552 | openai_activities = AnthropicAgentActivities( 553 | AnthropicAgentOptions( 554 | client=mock_anthropic_client, 555 | all_functions=[say_hello], 556 | ) 557 | ) 558 | async with await WorkflowEnvironment.start_time_skipping() as env: 559 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 560 | worker = await rojak.create_worker([openai_activities]) 561 | 562 | async with worker: 563 | workflow_id = str(uuid.uuid4()) 564 | first_task = TaskParams( 565 | agent=agent, 566 | messages=[{"role": "user", "content": "Hello"}], 567 | ) 568 | run_resp_1 = await rojak.run( 569 | id=workflow_id, type="persistent", task=first_task 570 | ) 571 | 572 | # We expect a ResumeRequest because we have an interrupt 573 | assert isinstance(run_resp_1.result, ResumeRequest) 574 | resume_req = run_resp_1.result 575 | assert resume_req.tool_name == "say_hello" 576 | tool_id = resume_req.tool_id # We'll need this to resume 577 | 578 | approve_resume = ResumeResponse(action="approve", tool_id=tool_id) 579 | run_resp_2 = await rojak.run( 580 | id=workflow_id, 581 | resume=approve_resume, 582 | ) 583 | 584 | # This time, we should get an OrchestratorResponse 585 | assert isinstance(run_resp_2.result, OrchestratorResponse) 586 | orch_resp = run_resp_2.result 587 | 588 | say_hello_mock.assert_called_once() 589 | assert orch_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2 590 | 591 | 592 | @pytest.mark.asyncio 593 | async def test_interrupt_and_reject(mock_anthropic_client: 
MockAnthropicClient): 594 | """ 595 | Demonstrates how we 'reject' an interrupted function call. 596 | The orchestrator will skip calling the function and continue. 597 | """ 598 | task_queue_name = str(uuid.uuid4()) 599 | 600 | def say_hello(): 601 | say_hello_mock() 602 | return "Hello!" 603 | 604 | mock_anthropic_client.set_sequential_responses( 605 | [ 606 | create_mock_response( 607 | message={"role": "assistant", "content": ""}, 608 | function_calls=[{"name": "say_hello", "args": {}}], 609 | ), 610 | create_mock_response( 611 | {"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT_2} 612 | ), 613 | ] 614 | ) 615 | 616 | agent = AnthropicAgent( 617 | functions=["say_hello"], 618 | interrupts=[Interrupt("say_hello")], 619 | ) 620 | 621 | say_hello_mock = Mock() 622 | 623 | openai_activities = AnthropicAgentActivities( 624 | AnthropicAgentOptions( 625 | client=mock_anthropic_client, 626 | all_functions=[say_hello], 627 | ) 628 | ) 629 | async with await WorkflowEnvironment.start_time_skipping() as env: 630 | rojak = Rojak(client=env.client, task_queue=task_queue_name) 631 | worker = await rojak.create_worker([openai_activities]) 632 | 633 | async with worker: 634 | workflow_id = str(uuid.uuid4()) 635 | first_task = TaskParams( 636 | agent=agent, 637 | messages=[{"role": "user", "content": "Hello"}], 638 | ) 639 | run_resp_1 = await rojak.run( 640 | id=workflow_id, type="persistent", task=first_task 641 | ) 642 | 643 | # We expect a ResumeRequest because we have an interrupt 644 | assert isinstance(run_resp_1.result, ResumeRequest) 645 | resume_req = run_resp_1.result 646 | assert resume_req.tool_name == "say_hello" 647 | tool_id = resume_req.tool_id # We'll need this to resume 648 | 649 | # Reject the function call 650 | reject_reason = "User does not want this." 
651 | reject_resume = ResumeResponse( 652 | action="reject", tool_id=tool_id, content=reject_reason 653 | ) 654 | run_resp_2 = await rojak.run( 655 | id=workflow_id, 656 | resume=reject_resume, 657 | ) 658 | 659 | # This time, we should get an OrchestratorResponse 660 | assert isinstance(run_resp_2.result, OrchestratorResponse) 661 | orch_resp = run_resp_2.result 662 | 663 | # The function should not have been called 664 | say_hello_mock.assert_not_called() 665 | assert reject_reason in orch_resp.messages[-2].content 666 | assert orch_resp.messages[-1].content == DEFAULT_RESPONSE_CONTENT_2 667 | 668 | 669 | # 670 | # Below tests only concern the .convert_messages() utility 671 | # in AnthropicAgentActivities; no workflow code is involved. 672 | # 673 | 674 | 675 | def test_convert_messages_with_parallel_tool_calls(): 676 | conversation_messages = [ 677 | ConversationMessage( 678 | **{ 679 | "content": "Help provide the weather forecast.", 680 | "role": "system", 681 | "sender": None, 682 | "tool_call_id": None, 683 | "tool_calls": None, 684 | } 685 | ), 686 | ConversationMessage( 687 | **{ 688 | "content": "What is the weather like in Malaysia and Singapore?", 689 | "role": "user", 690 | "sender": None, 691 | "tool_call_id": None, 692 | "tool_calls": None, 693 | } 694 | ), 695 | ConversationMessage( 696 | **{ 697 | "content": ( 698 | "I'll help you check the weather for both Malaysia and Singapore. 
" 699 | "I'll retrieve the current weather information for each location.\n\n" 700 | "Let's start with Malaysia:" 701 | ), 702 | "role": "assistant", 703 | "sender": "Weather Assistant", 704 | "tool_call_id": None, 705 | "tool_calls": [ 706 | { 707 | "function": { 708 | "arguments": '{"location": "Kuala Lumpur"}', 709 | "name": "get_weather", 710 | }, 711 | "id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 712 | "type": "function", 713 | }, 714 | { 715 | "function": { 716 | "arguments": '{"location": "Singapore"}', 717 | "name": "get_weather", 718 | }, 719 | "id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 720 | "type": "function", 721 | }, 722 | ], 723 | } 724 | ), 725 | ConversationMessage( 726 | **{ 727 | "content": '{"location": "Kuala Lumpur", "temperature": "65", "time": "now"}', 728 | "role": "tool", 729 | "sender": "Weather Assistant", 730 | "tool_call_id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 731 | "tool_calls": None, 732 | } 733 | ), 734 | ConversationMessage( 735 | **{ 736 | "content": '{"location": "Singapore", "temperature": "65", "time": "now"}', 737 | "role": "tool", 738 | "sender": "Weather Assistant", 739 | "tool_call_id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 740 | "tool_calls": None, 741 | } 742 | ), 743 | ] 744 | 745 | assert AnthropicAgentActivities.convert_messages(conversation_messages) == ( 746 | [ 747 | { 748 | "role": "user", 749 | "content": "What is the weather like in Malaysia and Singapore?", 750 | }, 751 | { 752 | "role": "assistant", 753 | "content": [ 754 | { 755 | "id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 756 | "input": {"location": "Kuala Lumpur"}, 757 | "name": "get_weather", 758 | "type": "tool_use", 759 | }, 760 | { 761 | "id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 762 | "input": {"location": "Singapore"}, 763 | "name": "get_weather", 764 | "type": "tool_use", 765 | }, 766 | ], 767 | }, 768 | { 769 | "role": "user", 770 | "content": [ 771 | { 772 | "type": "tool_result", 773 | "tool_use_id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 774 | "content": [ 775 | { 776 | 
"type": "text", 777 | "text": '{"location": "Kuala Lumpur", "temperature": "65", "time": "now"}', 778 | }, 779 | ], 780 | }, 781 | { 782 | "type": "tool_result", 783 | "tool_use_id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 784 | "content": [ 785 | { 786 | "type": "text", 787 | "text": '{"location": "Singapore", "temperature": "65", "time": "now"}', 788 | }, 789 | ], 790 | }, 791 | ], 792 | }, 793 | ], 794 | "Help provide the weather forecast.", 795 | ) 796 | 797 | 798 | def test_convert_messages_with_nonparallel_tool_call(): 799 | conversation_messages = [ 800 | ConversationMessage( 801 | **{ 802 | "content": "Help provide the weather forecast.", 803 | "role": "system", 804 | "sender": None, 805 | "tool_call_id": None, 806 | "tool_calls": None, 807 | } 808 | ), 809 | ConversationMessage( 810 | **{ 811 | "content": "What is the weather like in Malaysia and Singapore?", 812 | "role": "user", 813 | "sender": None, 814 | "tool_call_id": None, 815 | "tool_calls": None, 816 | } 817 | ), 818 | ConversationMessage( 819 | **{ 820 | "content": "I'll help you check the weather for Malaysia", 821 | "role": "assistant", 822 | "sender": "Weather Assistant", 823 | "tool_call_id": None, 824 | "tool_calls": [ 825 | { 826 | "function": { 827 | "arguments": '{"location": "Kuala Lumpur"}', 828 | "name": "get_weather", 829 | }, 830 | "id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 831 | "type": "function", 832 | }, 833 | ], 834 | } 835 | ), 836 | ConversationMessage( 837 | **{ 838 | "content": '{"location": "Kuala Lumpur", "temperature": "65", "time": "now"}', 839 | "role": "tool", 840 | "sender": "Weather Assistant", 841 | "tool_call_id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 842 | "tool_calls": None, 843 | } 844 | ), 845 | ConversationMessage( 846 | **{ 847 | "content": "I'll help you check the weather for Singapore.", 848 | "role": "assistant", 849 | "sender": "Weather Assistant", 850 | "tool_call_id": None, 851 | "tool_calls": [ 852 | { 853 | "function": { 854 | "arguments": '{"location": 
"Singapore"}', 855 | "name": "get_weather", 856 | }, 857 | "id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 858 | "type": "function", 859 | }, 860 | ], 861 | } 862 | ), 863 | ConversationMessage( 864 | **{ 865 | "content": '{"location": "Singapore", "temperature": "65", "time": "now"}', 866 | "role": "tool", 867 | "sender": "Weather Assistant", 868 | "tool_call_id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 869 | "tool_calls": None, 870 | } 871 | ), 872 | ] 873 | 874 | assert AnthropicAgentActivities.convert_messages(conversation_messages) == ( 875 | [ 876 | { 877 | "role": "user", 878 | "content": "What is the weather like in Malaysia and Singapore?", 879 | }, 880 | { 881 | "role": "assistant", 882 | "content": [ 883 | { 884 | "id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 885 | "input": {"location": "Kuala Lumpur"}, 886 | "name": "get_weather", 887 | "type": "tool_use", 888 | } 889 | ], 890 | }, 891 | { 892 | "role": "user", 893 | "content": [ 894 | { 895 | "type": "tool_result", 896 | "tool_use_id": "toolu_01Qz54ujndhYL3cGXKY1UukD", 897 | "content": [ 898 | { 899 | "type": "text", 900 | "text": '{"location": "Kuala Lumpur", "temperature": "65", "time": "now"}', 901 | } 902 | ], 903 | } 904 | ], 905 | }, 906 | { 907 | "role": "assistant", 908 | "content": [ 909 | { 910 | "id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 911 | "input": {"location": "Singapore"}, 912 | "name": "get_weather", 913 | "type": "tool_use", 914 | } 915 | ], 916 | }, 917 | { 918 | "role": "user", 919 | "content": [ 920 | { 921 | "type": "tool_result", 922 | "tool_use_id": "toolu_01AUvuz1d7UoUrs7SzhpCqnF", 923 | "content": [ 924 | { 925 | "type": "text", 926 | "text": '{"location": "Singapore", "temperature": "65", "time": "now"}', 927 | } 928 | ], 929 | } 930 | ], 931 | }, 932 | ], 933 | "Help provide the weather forecast.", 934 | ) 935 | --------------------------------------------------------------------------------