├── src
    ├── service
    │   ├── __init__.py
    │   ├── utils.py
    │   └── service.py
    ├── client
    │   ├── __init__.py
    │   └── client.py
    ├── core
    │   ├── __init__.py
    │   ├── llm.py
    │   └── settings.py
    ├── agents
    │   ├── __init__.py
    │   ├── utils.py
    │   ├── agents.py
    │   ├── tools.py
    │   ├── chatbot.py
    │   ├── command_agent.py
    │   ├── bg_task_agent
    │   │   ├── task.py
    │   │   └── bg_task_agent.py
    │   ├── llama_guard.py
    │   └── research_assistant.py
    ├── memory
    │   ├── sqlite.py
    │   ├── __init__.py
    │   └── postgres.py
    ├── schema
    │   ├── __init__.py
    │   ├── models.py
    │   ├── task_data.py
    │   └── schema.py
    ├── run_agent.py
    ├── run_service.py
    ├── run_client.py
    └── streamlit_app.py
├── media
    ├── agent_diagram.png
    ├── app_screenshot.png
    └── agent_architecture.png
├── .dockerignore
├── langgraph.json
├── tests
    ├── client
    │   ├── conftest.py
    │   └── test_client.py
    ├── conftest.py
    ├── app
    │   ├── conftest.py
    │   └── test_streamlit_app.py
    ├── integration
    │   └── test_docker_e2e.py
    ├── service
    │   ├── test_auth.py
    │   ├── conftest.py
    │   ├── test_utils.py
    │   ├── test_service_e2e.py
    │   └── test_service.py
    └── core
    │   ├── test_llm.py
    │   └── test_settings.py
├── codecov.yml
├── .pre-commit-config.yaml
├── docker
    ├── Dockerfile.app
    └── Dockerfile.service
├── LICENSE
├── compose.yaml
├── .github
    └── workflows
    │   ├── deploy.yml
    │   └── test.yml
├── .env.example
├── pyproject.toml
├── .gitignore
└── README.md

/src/service/__init__.py:
--------------------------------------------------------------------------------
from service.service import app

__all__ = ["app"]

--------------------------------------------------------------------------------
/media/agent_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishpatel26/agent-service-toolkit/main/media/agent_diagram.png
--------------------------------------------------------------------------------
/media/app_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishpatel26/agent-service-toolkit/main/media/app_screenshot.png
--------------------------------------------------------------------------------
/media/agent_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishpatel26/agent-service-toolkit/main/media/agent_architecture.png
--------------------------------------------------------------------------------
/src/client/__init__.py:
--------------------------------------------------------------------------------
from client.client import AgentClient, AgentClientError

__all__ = ["AgentClient", "AgentClientError"]

--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
from core.llm import get_model
from core.settings import settings

__all__ = ["settings", "get_model"]

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
.git
.gitignore
.env
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env
venv
.venv
*.db

--------------------------------------------------------------------------------
/src/agents/__init__.py:
--------------------------------------------------------------------------------
from agents.agents import DEFAULT_AGENT, get_agent, get_all_agent_info

__all__ = ["get_agent", "get_all_agent_info", "DEFAULT_AGENT"]

--------------------------------------------------------------------------------
/langgraph.json:
--------------------------------------------------------------------------------
{
  "python_version": "3.12",
  "dependencies": ["."],
  "graphs": {
    "research_assistant": "./src/agents/research_assistant.py:research_assistant"
  },
  "env": "./.env"
}

--------------------------------------------------------------------------------
/tests/client/conftest.py:
--------------------------------------------------------------------------------
import pytest

from client import AgentClient


@pytest.fixture
def agent_client(mock_env):
    """Fixture for creating a test client with a clean environment."""
    ac = AgentClient(base_url="http://test", get_info=False)
    ac.update_agent("test-agent", verify=False)
    return ac

--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
coverage:
  status:
    # Fail PRs that reduce total coverage by more than 2%
    project:
      default:
        target: auto
        threshold: 2%
    # Treat patch coverage as informational only
    patch:
      default:
        informational: true
comment:
  hide_project_coverage: false

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.3
    hooks:
      - id: ruff
        args: [ --fix ]
      - id: ruff-format

--------------------------------------------------------------------------------
/src/memory/sqlite.py:
--------------------------------------------------------------------------------
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

from core.settings import settings


def get_sqlite_saver() -> BaseCheckpointSaver:
    """Initialize and return a SQLite saver instance."""
    return AsyncSqliteSaver.from_conn_string(settings.SQLITE_DB_PATH)

--------------------------------------------------------------------------------
/docker/Dockerfile.app:
--------------------------------------------------------------------------------
FROM python:3.12.3-slim

WORKDIR /app

ENV UV_PROJECT_ENVIRONMENT="/usr/local/"
ENV UV_COMPILE_BYTECODE=1

COPY pyproject.toml .
COPY uv.lock .
RUN pip install --no-cache-dir uv
# Only install the client dependencies
RUN uv sync --frozen --only-group client

COPY src/client/ ./client/
COPY src/schema/ ./schema/
COPY src/streamlit_app.py .

CMD ["streamlit", "run", "streamlit_app.py"]

--------------------------------------------------------------------------------
/docker/Dockerfile.service:
--------------------------------------------------------------------------------
FROM python:3.12.3-slim

WORKDIR /app

ENV UV_PROJECT_ENVIRONMENT="/usr/local/"
ENV UV_COMPILE_BYTECODE=1

COPY pyproject.toml .
COPY uv.lock .
RUN pip install --no-cache-dir uv
RUN uv sync --frozen --no-install-project --no-dev

COPY src/agents/ ./agents/
COPY src/core/ ./core/
COPY src/memory/ ./memory/
COPY src/schema/ ./schema/
COPY src/service/ ./service/
COPY src/run_service.py .

CMD ["python", "run_service.py"]

--------------------------------------------------------------------------------
/src/schema/__init__.py:
--------------------------------------------------------------------------------
from schema.models import AllModelEnum
from schema.schema import (
    AgentInfo,
    ChatHistory,
    ChatHistoryInput,
    ChatMessage,
    Feedback,
    FeedbackResponse,
    ServiceMetadata,
    StreamInput,
    UserInput,
)

__all__ = [
    "AgentInfo",
    "AllModelEnum",
    "UserInput",
    "ChatMessage",
    "ServiceMetadata",
    "StreamInput",
    "Feedback",
    "FeedbackResponse",
    "ChatHistoryInput",
    "ChatHistory",
]

--------------------------------------------------------------------------------
/src/memory/__init__.py:
--------------------------------------------------------------------------------
from langgraph.checkpoint.base import BaseCheckpointSaver

from core.settings import DatabaseType, settings
from memory.postgres import get_postgres_saver
from memory.sqlite import get_sqlite_saver


def initialize_database() -> BaseCheckpointSaver:
    """
    Initialize the appropriate database checkpointer based on configuration.
    Returns an initialized AsyncCheckpointer instance.
    """
    if settings.DATABASE_TYPE == DatabaseType.POSTGRES:
        return get_postgres_saver()
    else:  # Default to SQLite
        return get_sqlite_saver()


__all__ = ["initialize_database"]

--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
import os
from unittest.mock import patch

import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--run-docker", action="store_true", default=False, help="run docker integration tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "docker: mark test as requiring docker containers")


def pytest_collection_modifyitems(config, items):
    if not config.getoption("--run-docker"):
        skip_docker = pytest.mark.skip(reason="need --run-docker option to run")
        for item in items:
            if "docker" in item.keywords:
                item.add_marker(skip_docker)


@pytest.fixture
def mock_env():
    """Fixture to ensure environment is clean for each test."""
    with patch.dict(os.environ, {}, clear=True):
        yield

--------------------------------------------------------------------------------
/tests/app/conftest.py:
--------------------------------------------------------------------------------
from unittest.mock import patch

import pytest

from schema import AgentInfo, ServiceMetadata
from schema.models import OpenAIModelName


@pytest.fixture
def mock_agent_client(mock_env):
    """Fixture for creating a mock AgentClient with a clean environment."""

    mock_info = ServiceMetadata(
        default_agent="test-agent",
        agents=[
            AgentInfo(key="test-agent", description="Test agent"),
            AgentInfo(key="chatbot", description="Chatbot"),
        ],
        default_model=OpenAIModelName.GPT_4O,
        models=[OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI],
    )

    with patch("client.AgentClient") as mock_agent_client:
        mock_agent_client_instance = mock_agent_client.return_value
        mock_agent_client_instance.info = mock_info
        yield mock_agent_client_instance

--------------------------------------------------------------------------------
/src/run_agent.py:
--------------------------------------------------------------------------------
import asyncio
from uuid import uuid4

from dotenv import load_dotenv
from langchain_core.runnables import RunnableConfig

load_dotenv()

from agents import DEFAULT_AGENT, get_agent  # noqa: E402

agent = get_agent(DEFAULT_AGENT)


async def main() -> None:
    inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]}
    result = await agent.ainvoke(
        inputs,
        config=RunnableConfig(configurable={"thread_id": uuid4()}),
    )
    result["messages"][-1].pretty_print()

    # Draw the agent graph as png
    # requires:
    # brew install graphviz
    # export CFLAGS="-I $(brew --prefix graphviz)/include"
    # export LDFLAGS="-L $(brew --prefix graphviz)/lib"
    # pip install pygraphviz
    #
    # agent.get_graph().draw_png("agent_diagram.png")


if __name__ == "__main__":
    asyncio.run(main())

--------------------------------------------------------------------------------
/src/run_service.py:
--------------------------------------------------------------------------------
import asyncio
import sys

import uvicorn
from dotenv import load_dotenv

from core import settings

load_dotenv()

if __name__ == "__main__":
    # Set Compatible event loop policy on Windows Systems.
    # On Windows systems, the default ProactorEventLoop can cause issues with
    # certain async database drivers like psycopg (PostgreSQL driver).
    # The WindowsSelectorEventLoopPolicy provides better compatibility and prevents
    # "RuntimeError: Event loop is closed" errors when working with database connections.
    # This needs to be set before running the application server.
    # Refer to the documentation for more information.
    # https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-operations
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    uvicorn.run("service:app", host=settings.HOST, port=settings.PORT, reload=settings.is_dev())

--------------------------------------------------------------------------------
/src/agents/utils.py:
--------------------------------------------------------------------------------
from typing import Any

from langchain_core.callbacks import adispatch_custom_event
from langchain_core.messages import ChatMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import merge_configs
from pydantic import BaseModel, Field


class CustomData(BaseModel):
    "Custom data being sent by an agent"

    type: str = Field(
        description="The type of custom data, used in dispatch events",
        default="custom_data",
    )
    data: dict[str, Any] = Field(description="The custom data")

    def to_langchain(self) -> ChatMessage:
        return ChatMessage(content=[self.data], role="custom")

    async def adispatch(self, config: RunnableConfig | None = None) -> None:
        dispatch_config = RunnableConfig(
            tags=["custom_data_dispatch"],
        )
        await adispatch_custom_event(
            name=self.type,
            data=self.to_langchain(),
            config=merge_configs(config, dispatch_config),
        )

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Joshua Carroll

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/agents/agents.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from langgraph.graph.state import CompiledStateGraph 4 | 5 | from agents.bg_task_agent.bg_task_agent import bg_task_agent 6 | from agents.chatbot import chatbot 7 | from agents.command_agent import command_agent 8 | from agents.research_assistant import research_assistant 9 | from schema import AgentInfo 10 | 11 | DEFAULT_AGENT = "research-assistant" 12 | 13 | 14 | @dataclass 15 | class Agent: 16 | description: str 17 | graph: CompiledStateGraph 18 | 19 | 20 | agents: dict[str, Agent] = { 21 | "chatbot": Agent(description="A simple chatbot.", graph=chatbot), 22 | "research-assistant": Agent( 23 | description="A research assistant with web search and calculator.", graph=research_assistant 24 | ), 25 | "command-agent": Agent(description="A command agent.", graph=command_agent), 26 | "bg-task-agent": Agent(description="A background task agent.", graph=bg_task_agent), 27 | } 28 | 29 | 30 | def get_agent(agent_id: str) -> CompiledStateGraph: 31 | return agents[agent_id].graph 32 | 33 | 34 | def get_all_agent_info() -> list[AgentInfo]: 35 | return [ 36 | AgentInfo(key=agent_id, description=agent.description) for agent_id, agent in agents.items() 37 | ] 38 | -------------------------------------------------------------------------------- /src/agents/tools.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | 4 | import numexpr 5 | from langchain_core.tools import BaseTool, tool 6 | 7 | 8 | def calculator_func(expression: str) -> str: 9 | """Calculates a math expression using numexpr. 10 | 11 | Useful for when you need to answer questions about math using numexpr. 12 | This tool is only for math questions and nothing else. Only input 13 | math expressions. 14 | 15 | Args: 16 | expression (str): A valid numexpr formatted math expression. 17 | 18 | Returns: 19 | str: The result of the math expression. 20 | """ 21 | 22 | try: 23 | local_dict = {"pi": math.pi, "e": math.e} 24 | output = str( 25 | numexpr.evaluate( 26 | expression.strip(), 27 | global_dict={}, # restrict access to globals 28 | local_dict=local_dict, # add common mathematical functions 29 | ) 30 | ) 31 | return re.sub(r"^\[|\]$", "", output) 32 | except Exception as e: 33 | raise ValueError( 34 | f'calculator("{expression}") raised error: {e}.' 35 | " Please try again with a valid numerical expression" 36 | ) 37 | 38 | 39 | calculator: BaseTool = tool(calculator_func) 40 | calculator.name = "Calculator" 41 | -------------------------------------------------------------------------------- /compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | agent_service: 3 | build: 4 | context: . 5 | dockerfile: docker/Dockerfile.service 6 | ports: 7 | - "8080:8080" 8 | env_file: 9 | - .env 10 | develop: 11 | watch: 12 | - path: src/agents/ 13 | action: sync+restart 14 | target: /app/agents/ 15 | - path: src/schema/ 16 | action: sync+restart 17 | target: /app/schema/ 18 | - path: src/service/ 19 | action: sync+restart 20 | target: /app/service/ 21 | - path: src/core/ 22 | action: sync+restart 23 | target: /app/core/ 24 | - path: src/memory/ 25 | action: sync+restart 26 | target: /app/memory/ 27 | 28 | streamlit_app: 29 | build: 30 | context: . 
31 | dockerfile: docker/Dockerfile.app 32 | ports: 33 | - "8501:8501" 34 | depends_on: 35 | - agent_service 36 | environment: 37 | - AGENT_URL=http://agent_service:8080 38 | develop: 39 | watch: 40 | - path: src/client/ 41 | action: sync+restart 42 | target: /app/client/ 43 | - path: src/schema/ 44 | action: sync+restart 45 | target: /app/schema/ 46 | - path: src/streamlit_app.py 47 | action: sync+restart 48 | target: /app/streamlit_app.py 49 | -------------------------------------------------------------------------------- /tests/integration/test_docker_e2e.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from streamlit.testing.v1 import AppTest 3 | 4 | from client import AgentClient 5 | 6 | 7 | @pytest.mark.docker 8 | def test_service_with_fake_model(): 9 | """Test the service using the fake model. 10 | 11 | This test requires the service container to be running with USE_FAKE_MODEL=true 12 | """ 13 | client = AgentClient("http://0.0.0.0", agent="chatbot") 14 | response = client.invoke("Tell me a joke?", model="fake") 15 | assert response.type == "ai" 16 | assert response.content == "This is a test response from the fake model." 17 | 18 | 19 | @pytest.mark.docker 20 | def test_service_with_app(): 21 | """Test the service using the app. 22 | 23 | This test requires the service container to be running with USE_FAKE_MODEL=true 24 | """ 25 | at = AppTest.from_file("../../src/streamlit_app.py").run() 26 | assert at.chat_message[0].avatar == "assistant" 27 | welcome = at.chat_message[0].markdown[0].value 28 | assert welcome.startswith("Hello! I'm an AI-powered research assistant") 29 | assert not at.exception 30 | 31 | at.sidebar.selectbox[1].set_value("chatbot") 32 | at.chat_input[0].set_value("What is the weather in Tokyo?").run() 33 | assert at.chat_message[0].avatar == "user" 34 | assert at.chat_message[0].markdown[0].value == "What is the weather in Tokyo?" 35 | assert at.chat_message[1].avatar == "assistant" 36 | assert at.chat_message[1].markdown[0].value == "This is a test response from the fake model." 37 | assert not at.exception 38 | -------------------------------------------------------------------------------- /src/agents/chatbot.py: -------------------------------------------------------------------------------- 1 | from langchain_core.language_models.chat_models import BaseChatModel 2 | from langchain_core.messages import AIMessage 3 | from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable 4 | from langgraph.checkpoint.memory import MemorySaver 5 | from langgraph.graph import END, MessagesState, StateGraph 6 | 7 | from core import get_model, settings 8 | 9 | 10 | class AgentState(MessagesState, total=False): 11 | """`total=False` is PEP589 specs. 
12 | 13 | documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality 14 | """ 15 | 16 | 17 | def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]: 18 | preprocessor = RunnableLambda( 19 | lambda state: state["messages"], 20 | name="StateModifier", 21 | ) 22 | return preprocessor | model 23 | 24 | 25 | async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: 26 | m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL)) 27 | model_runnable = wrap_model(m) 28 | response = await model_runnable.ainvoke(state, config) 29 | 30 | # We return a list, because this will get added to the existing list 31 | return {"messages": [response]} 32 | 33 | 34 | # Define the graph 35 | agent = StateGraph(AgentState) 36 | agent.add_node("model", acall_model) 37 | agent.set_entry_point("model") 38 | 39 | # Always END after blocking unsafe content 40 | agent.add_edge("model", END) 41 | 42 | chatbot = agent.compile( 43 | checkpointer=MemorySaver(), 44 | ) 45 | -------------------------------------------------------------------------------- /src/memory/postgres.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from langgraph.checkpoint.base import BaseCheckpointSaver 4 | from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver 5 | 6 | from core.settings import settings 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def validate_postgres_config() -> None: 12 | """ 13 | Validate that all required PostgreSQL configuration is present. 14 | Raises ValueError if any required configuration is missing. 15 | """ 16 | required_vars = [ 17 | "POSTGRES_USER", 18 | "POSTGRES_PASSWORD", 19 | "POSTGRES_HOST", 20 | "POSTGRES_PORT", 21 | "POSTGRES_DB", 22 | ] 23 | 24 | missing = [var for var in required_vars if not getattr(settings, var, None)] 25 | if missing: 26 | raise ValueError( 27 | f"Missing required PostgreSQL configuration: {', '.join(missing)}. " 28 | "These environment variables must be set to use PostgreSQL persistence." 29 | ) 30 | 31 | 32 | def get_postgres_connection_string() -> str: 33 | """Build and return the PostgreSQL connection string from settings.""" 34 | return ( 35 | f"postgresql://{settings.POSTGRES_USER}:" 36 | f"{settings.POSTGRES_PASSWORD.get_secret_value()}@" 37 | f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/" 38 | f"{settings.POSTGRES_DB}" 39 | ) 40 | 41 | 42 | def get_postgres_saver() -> BaseCheckpointSaver: 43 | """Initialize and return a PostgreSQL saver instance.""" 44 | validate_postgres_config() 45 | return AsyncPostgresSaver.from_conn_string(get_postgres_connection_string()) 46 | -------------------------------------------------------------------------------- /src/agents/command_agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Literal 3 | 4 | from langchain_core.messages import AIMessage 5 | from langgraph.graph import START, MessagesState, StateGraph 6 | from langgraph.types import Command 7 | 8 | 9 | class AgentState(MessagesState, total=False): 10 | """`total=False` is PEP589 specs. 
11 | 12 | documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality 13 | """ 14 | 15 | 16 | # Define the nodes 17 | 18 | 19 | def node_a(state: AgentState) -> Command[Literal["node_b", "node_c"]]: 20 | print("Called A") 21 | value = random.choice(["a", "b"]) 22 | # this is a replacement for a conditional edge function 23 | if value == "a": 24 | goto = "node_b" 25 | else: 26 | goto = "node_c" 27 | 28 | # note how Command allows you to BOTH update the graph state AND route to the next node 29 | return Command( 30 | # this is the state update 31 | update={"messages": [AIMessage(content=f"Hello {value}")]}, 32 | # this is a replacement for an edge 33 | goto=goto, 34 | ) 35 | 36 | 37 | def node_b(state: AgentState): 38 | print("Called B") 39 | return {"messages": [AIMessage(content="Hello B")]} 40 | 41 | 42 | def node_c(state: AgentState): 43 | print("Called C") 44 | return {"messages": [AIMessage(content="Hello C")]} 45 | 46 | 47 | builder = StateGraph(AgentState) 48 | builder.add_edge(START, "node_a") 49 | builder.add_node(node_a) 50 | builder.add_node(node_b) 51 | builder.add_node(node_c) 52 | # NOTE: there are no edges between nodes A, B and C! 53 | 54 | command_agent = builder.compile() 55 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to Azure 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | jobs: 10 | test: 11 | # Don't try to run deployment in forks and repo copies 12 | if: github.repository == 'JoshuaC215/agent-service-toolkit' 13 | uses: ./.github/workflows/test.yml 14 | 15 | build: 16 | runs-on: ubuntu-latest 17 | needs: [test] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Docker Buildx 23 | uses: docker/setup-buildx-action@v2 24 | 25 | - name: Log in to container registry 26 | uses: docker/login-action@v2 27 | with: 28 | registry: https://index.docker.io/v1/ 29 | username: ${{ secrets.DOCKER_USERNAME }} 30 | password: ${{ secrets.DOCKER_TOKEN }} 31 | 32 | - name: Build and push agent-service container image to registry 33 | uses: docker/build-push-action@v3 34 | with: 35 | context: . 
36 | push: true 37 | tags: index.docker.io/${{ secrets.DOCKER_USERNAME }}/agent-service-toolkit.service:${{ github.sha }} 38 | file: docker/Dockerfile.service 39 | 40 | deploy: 41 | runs-on: ubuntu-latest 42 | needs: build 43 | 44 | steps: 45 | - name: Deploy to Azure Web App 46 | id: deploy-to-webapp 47 | uses: azure/webapps-deploy@v2 48 | with: 49 | app-name: 'agent-service' 50 | slot-name: 'production' 51 | publish-profile: ${{ secrets.AZURE_WEBAPP_PUBLISH_PROFILE }} 52 | images: 'index.docker.io/${{ secrets.DOCKER_USERNAME }}/agent-service-toolkit.service:${{ github.sha }}' 53 | -------------------------------------------------------------------------------- /tests/service/test_auth.py: -------------------------------------------------------------------------------- 1 | from pydantic import SecretStr 2 | 3 | 4 | def test_no_auth_secret(mock_settings, mock_agent, test_client): 5 | """Test that when AUTH_SECRET is not set, all requests are allowed""" 6 | mock_settings.AUTH_SECRET = None 7 | response = test_client.post( 8 | "/invoke", 9 | json={"message": "test"}, 10 | headers={"Authorization": "Bearer any-token"}, 11 | ) 12 | assert response.status_code == 200 13 | 14 | # Should also work without any auth header 15 | response = test_client.post("/invoke", json={"message": "test"}) 16 | assert response.status_code == 200 17 | 18 | 19 | def test_auth_secret_correct(mock_settings, mock_agent, test_client): 20 | """Test that when AUTH_SECRET is set, requests with correct token are allowed""" 21 | mock_settings.AUTH_SECRET = SecretStr("test-secret") 22 | response = test_client.post( 23 | "/invoke", 24 | json={"message": "test"}, 25 | headers={"Authorization": "Bearer test-secret"}, 26 | ) 27 | assert response.status_code == 200 28 | 29 | 30 | def test_auth_secret_incorrect(mock_settings, mock_agent, test_client): 31 | """Test that when AUTH_SECRET is set, requests with wrong token are rejected""" 32 | mock_settings.AUTH_SECRET = SecretStr("test-secret") 33 | response = test_client.post( 34 | "/invoke", 35 | json={"message": "test"}, 36 | headers={"Authorization": "Bearer wrong-secret"}, 37 | ) 38 | assert response.status_code == 401 39 | 40 | # Should also reject requests with no auth header 41 | response = test_client.post("/invoke", json={"message": "test"}) 42 | assert response.status_code == 401 43 | -------------------------------------------------------------------------------- /src/run_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from client import AgentClient 4 | from core import settings 5 | from schema import ChatMessage 6 | 7 | 8 | async def amain() -> None: 9 | #### ASYNC #### 10 | client = AgentClient(settings.BASE_URL) 11 | 12 | print("Agent info:") 13 | print(client.info) 14 | 15 | print("Chat example:") 16 | response = await client.ainvoke("Tell me a brief joke?", model="gpt-4o") 17 | response.pretty_print() 18 | 19 | print("\nStream example:") 20 | async for message in client.astream("Share a quick fun fact?"): 21 | if isinstance(message, str): 22 | print(message, flush=True, end="") 23 | elif isinstance(message, ChatMessage): 24 | print("\n", flush=True) 25 | message.pretty_print() 26 | else: 27 | print(f"ERROR: Unknown type - {type(message)}") 28 | 29 | 30 | def main() -> None: 31 | #### SYNC #### 32 | client = AgentClient(settings.BASE_URL) 33 | 34 | print("Agent info:") 35 | print(client.info) 36 | 37 | print("Chat example:") 38 | response = client.invoke("Tell me a brief joke?", model="gpt-4o") 39 
| response.pretty_print() 40 | 41 | print("\nStream example:") 42 | for message in client.stream("Share a quick fun fact?"): 43 | if isinstance(message, str): 44 | print(message, flush=True, end="") 45 | elif isinstance(message, ChatMessage): 46 | print("\n", flush=True) 47 | message.pretty_print() 48 | else: 49 | print(f"ERROR: Unknown type - {type(message)}") 50 | 51 | 52 | if __name__ == "__main__": 53 | print("Running in sync mode") 54 | main() 55 | print("\n\n\n\n\n") 56 | print("Running in async mode") 57 | asyncio.run(amain()) 58 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # API keys for different providers 2 | OPENAI_API_KEY= 3 | AZURE_OPENAI_API_KEY= 4 | DEEPSEEK_API_KEY= 5 | ANTHROPIC_API_KEY= 6 | GOOGLE_API_KEY= 7 | GROQ_API_KEY= 8 | USE_AWS_BEDROCK=false 9 | 10 | # Use a fake model for testing 11 | USE_FAKE_MODEL=false 12 | 13 | # Set a default model 14 | DEFAULT_MODEL= 15 | 16 | # Web server configuration 17 | HOST=0.0.0.0 18 | PORT=8080 19 | 20 | # Authentication secret, HTTP bearer token header is required if set 21 | AUTH_SECRET= 22 | 23 | # Langsmith configuration 24 | LANGCHAIN_TRACING_V2=false 25 | LANGCHAIN_PROJECT=default 26 | LANGCHAIN_ENDPOINT=https://api.smith.langchain.com 27 | LANGCHAIN_API_KEY= 28 | 29 | # Application mode. If the value is "dev", it will enable uvicorn reload 30 | MODE= 31 | 32 | # Database type. 33 | # If the value is "postgres", then it will require Postgresql related environment variables. 34 | # If the value is "sqlite", then you can configure optional file path via SQLITE_DB_PATH 35 | DATABASE_TYPE= 36 | 37 | # If DATABASE_TYPE=sqlite (Optional) 38 | SQLITE_DB_PATH= 39 | 40 | # If DATABASE_TYPE=postgres 41 | POSTGRES_USER= 42 | POSTGRES_PASSWORD= 43 | POSTGRES_HOST= 44 | POSTGRES_PORT= 45 | POSTGRES_DB= 46 | 47 | # OpenWeatherMap API key 48 | OPENWEATHERMAP_API_KEY= 49 | 50 | # Add for running ollama 51 | # OLLAMA_MODEL=llama3.2 52 | # Note: set OLLAMA_BASE_URL if running service in docker and ollama on bare metal 53 | # OLLAMA_BASE_URL=http://host.docker.internal:11434 54 | 55 | # Add for running Azure OpenAI 56 | # AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com 57 | # AZURE_OPENAI_API_VERSION=2024-10-21 58 | # AZURE_OPENAI_DEPLOYMENT_MAP={"gpt-4o": "gpt-4o-deployment", "gpt-4o-mini": "gpt-4o-mini-deployment"} 59 | 60 | # Agent URL: used in Streamlit app - if not set, defaults to http://{HOST}:{PORT} 61 | # AGENT_URL=http://0.0.0.0:8080 62 | -------------------------------------------------------------------------------- /tests/service/conftest.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, Mock, patch 2 | 3 | import pytest 4 | from fastapi.testclient import TestClient 5 | from langchain_core.messages import AIMessage 6 | 7 | from service import app 8 | 9 | 10 | @pytest.fixture 11 | def test_client(): 12 | """Fixture to create a FastAPI test client.""" 13 | return TestClient(app) 14 | 15 | 16 | @pytest.fixture 17 | def mock_agent(): 18 | """Fixture to create a mock agent that can be configured for different test scenarios.""" 19 | agent_mock = AsyncMock() 20 | agent_mock.ainvoke = AsyncMock(return_value={"messages": [AIMessage(content="Test response")]}) 21 | agent_mock.get_state = Mock() # Default empty mock for get_state 22 | with patch("service.service.get_agent", Mock(return_value=agent_mock)): 23 | yield 
agent_mock 24 | 25 | 26 | @pytest.fixture 27 | def mock_settings(mock_env): 28 | """Fixture to ensure settings are clean for each test.""" 29 | with patch("service.service.settings") as mock_settings: 30 | yield mock_settings 31 | 32 | 33 | @pytest.fixture 34 | def mock_httpx(): 35 | """Patch httpx.stream and httpx.get to use our test client.""" 36 | 37 | with TestClient(app) as client: 38 | 39 | def mock_stream(method: str, url: str, **kwargs): 40 | # Strip the base URL since TestClient expects just the path 41 | path = url.replace("http://0.0.0.0", "") 42 | return client.stream(method, path, **kwargs) 43 | 44 | def mock_get(url: str, **kwargs): 45 | # Strip the base URL since TestClient expects just the path 46 | path = url.replace("http://0.0.0.0", "") 47 | return client.get(path, **kwargs) 48 | 49 | with patch("httpx.stream", mock_stream): 50 | with patch("httpx.get", mock_get): 51 | yield 52 | -------------------------------------------------------------------------------- /src/agents/bg_task_agent/task.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from uuid import uuid4 3 | 4 | from langchain_core.messages import BaseMessage 5 | from langchain_core.runnables import RunnableConfig 6 | 7 | from agents.utils import CustomData 8 | from schema.task_data import TaskData 9 | 10 | 11 | class Task: 12 | def __init__(self, task_name: str) -> None: 13 | self.name = task_name 14 | self.id = str(uuid4()) 15 | self.state: Literal["new", "running", "complete"] = "new" 16 | self.result: Literal["success", "error"] | None = None 17 | 18 | async def _generate_and_dispatch_message(self, config: RunnableConfig, data: dict): 19 | task_data = TaskData(name=self.name, run_id=self.id, state=self.state, data=data) 20 | if self.result: 21 | task_data.result = self.result 22 | task_custom_data = CustomData( 23 | type=self.name, 24 | data=task_data.model_dump(), 25 | ) 26 | await task_custom_data.adispatch(config) 27 | return task_custom_data.to_langchain() 28 | 29 | async def start(self, config: RunnableConfig, data: dict = {}) -> BaseMessage: 30 | self.state = "new" 31 | task_message = await self._generate_and_dispatch_message(config, data) 32 | return task_message 33 | 34 | async def write_data(self, config: RunnableConfig, data: dict) -> BaseMessage: 35 | if self.state == "complete": 36 | raise ValueError("Only incomplete tasks can output data.") 37 | self.state = "running" 38 | task_message = await self._generate_and_dispatch_message(config, data) 39 | return task_message 40 | 41 | async def finish( 42 | self, result: Literal["success", "error"], config: RunnableConfig, data: dict = {} 43 | ) -> BaseMessage: 44 | self.state = "complete" 45 | self.result = result 46 | task_message = await self._generate_and_dispatch_message(config, data) 47 | return task_message 48 | -------------------------------------------------------------------------------- /tests/service/test_utils.py: -------------------------------------------------------------------------------- 1 | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall, ToolMessage 2 | 3 | from service.utils import langchain_to_chat_message 4 | 5 | 6 | def test_messages_from_langchain() -> None: 7 | lc_human_message = HumanMessage(content="Hello, world!") 8 | human_message = langchain_to_chat_message(lc_human_message) 9 | assert human_message.type == "human" 10 | assert human_message.content == "Hello, world!" 
11 | 12 | lc_ai_message = AIMessage(content="Hello, world!") 13 | ai_message = langchain_to_chat_message(lc_ai_message) 14 | assert ai_message.type == "ai" 15 | assert ai_message.content == "Hello, world!" 16 | 17 | lc_tool_message = ToolMessage(content="Hello, world!", tool_call_id="123") 18 | tool_message = langchain_to_chat_message(lc_tool_message) 19 | assert tool_message.type == "tool" 20 | assert tool_message.content == "Hello, world!" 21 | assert tool_message.tool_call_id == "123" 22 | 23 | lc_system_message = SystemMessage(content="Hello, world!") 24 | try: 25 | _ = langchain_to_chat_message(lc_system_message) 26 | except ValueError as e: 27 | assert str(e) == "Unsupported message type: SystemMessage" 28 | 29 | 30 | def test_message_run_id_usage() -> None: 31 | run_id = "847c6285-8fc9-4560-a83f-4e6285809254" 32 | lc_message = AIMessage(content="Hello, world!") 33 | ai_message = langchain_to_chat_message(lc_message) 34 | ai_message.run_id = run_id 35 | assert ai_message.run_id == run_id 36 | 37 | 38 | def test_messages_tool_calls() -> None: 39 | tool_call = ToolCall(name="test_tool", args={"x": 1, "y": 2}, id="call_Jja7") 40 | lc_ai_message = AIMessage(content="", tool_calls=[tool_call]) 41 | ai_message = langchain_to_chat_message(lc_ai_message) 42 | assert ai_message.tool_calls[0]["id"] == "call_Jja7" 43 | assert ai_message.tool_calls[0]["name"] == "test_tool" 44 | assert ai_message.tool_calls[0]["args"] == {"x": 1, "y": 2} 45 | -------------------------------------------------------------------------------- /src/schema/models.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum, auto 2 | from typing import TypeAlias 3 | 4 | 5 | class Provider(StrEnum): 6 | OPENAI = auto() 7 | AZURE_OPENAI = auto() 8 | DEEPSEEK = auto() 9 | ANTHROPIC = auto() 10 | GOOGLE = auto() 11 | GROQ = auto() 12 | AWS = auto() 13 | OLLAMA = auto() 14 | FAKE = auto() 15 | 16 | 17 | class OpenAIModelName(StrEnum): 18 | """https://platform.openai.com/docs/models/gpt-4o""" 19 | 20 | GPT_4O_MINI = "gpt-4o-mini" 21 | GPT_4O = "gpt-4o" 22 | 23 | 24 | class AzureOpenAIModelName(StrEnum): 25 | """Azure OpenAI model names""" 26 | 27 | AZURE_GPT_4O = "azure-gpt-4o" 28 | AZURE_GPT_4O_MINI = "azure-gpt-4o-mini" 29 | 30 | 31 | class DeepseekModelName(StrEnum): 32 | """https://api-docs.deepseek.com/quick_start/pricing""" 33 | 34 | DEEPSEEK_CHAT = "deepseek-chat" 35 | 36 | 37 | class AnthropicModelName(StrEnum): 38 | """https://docs.anthropic.com/en/docs/about-claude/models#model-names""" 39 | 40 | HAIKU_3 = "claude-3-haiku" 41 | HAIKU_35 = "claude-3.5-haiku" 42 | SONNET_35 = "claude-3.5-sonnet" 43 | 44 | 45 | class GoogleModelName(StrEnum): 46 | """https://ai.google.dev/gemini-api/docs/models/gemini""" 47 | 48 | GEMINI_15_FLASH = "gemini-1.5-flash" 49 | 50 | 51 | class GroqModelName(StrEnum): 52 | """https://console.groq.com/docs/models""" 53 | 54 | LLAMA_31_8B = "groq-llama-3.1-8b" 55 | LLAMA_33_70B = "groq-llama-3.3-70b" 56 | 57 | LLAMA_GUARD_3_8B = "groq-llama-guard-3-8b" 58 | 59 | 60 | class AWSModelName(StrEnum): 61 | """https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html""" 62 | 63 | BEDROCK_HAIKU = "bedrock-3.5-haiku" 64 | BEDROCK_SONNET = "bedrock-3.5-sonnet" 65 | 66 | 67 | class OllamaModelName(StrEnum): 68 | """https://ollama.com/search""" 69 | 70 | OLLAMA_GENERIC = "ollama" 71 | 72 | 73 | class FakeModelName(StrEnum): 74 | """Fake model for testing.""" 75 | 76 | FAKE = "fake" 77 | 78 | 79 | AllModelEnum: TypeAlias = ( 80 | 
OpenAIModelName 81 | | AzureOpenAIModelName 82 | | DeepseekModelName 83 | | AnthropicModelName 84 | | GoogleModelName 85 | | GroqModelName 86 | | AWSModelName 87 | | OllamaModelName 88 | | FakeModelName 89 | ) 90 | -------------------------------------------------------------------------------- /src/agents/bg_task_agent/bg_task_agent.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from langchain_core.language_models.chat_models import BaseChatModel 4 | from langchain_core.messages import AIMessage 5 | from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable 6 | from langgraph.checkpoint.memory import MemorySaver 7 | from langgraph.graph import END, MessagesState, StateGraph 8 | 9 | from agents.bg_task_agent.task import Task 10 | from core import get_model, settings 11 | 12 | 13 | class AgentState(MessagesState, total=False): 14 | """`total=False` is PEP589 specs. 15 | 16 | documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality 17 | """ 18 | 19 | 20 | def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]: 21 | preprocessor = RunnableLambda( 22 | lambda state: state["messages"], 23 | name="StateModifier", 24 | ) 25 | return preprocessor | model 26 | 27 | 28 | async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: 29 | m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL)) 30 | model_runnable = wrap_model(m) 31 | response = await model_runnable.ainvoke(state, config) 32 | 33 | # We return a list, because this will get added to the existing list 34 | return {"messages": [response]} 35 | 36 | 37 | async def bg_task(state: AgentState, config: RunnableConfig) -> AgentState: 38 | task1 = Task("Simple task 1...") 39 | task2 = Task("Simple task 2...") 40 | 41 | await task1.start(config=config) 42 | await asyncio.sleep(2) 43 | await task2.start(config=config) 44 | await asyncio.sleep(2) 45 | await task1.write_data(config=config, data={"status": "Still running..."}) 46 | await asyncio.sleep(2) 47 | await task2.finish(result="error", config=config, data={"output": 42}) 48 | await asyncio.sleep(2) 49 | await task1.finish(result="success", config=config, data={"output": 42}) 50 | return {"messages": []} 51 | 52 | 53 | # Define the graph 54 | agent = StateGraph(AgentState) 55 | agent.add_node("model", acall_model) 56 | agent.add_node("bg_task", bg_task) 57 | agent.set_entry_point("bg_task") 58 | 59 | agent.add_edge("bg_task", "model") 60 | agent.add_edge("model", END) 61 | 62 | bg_task_agent = agent.compile( 63 | checkpointer=MemorySaver(), 64 | ) 65 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "agent-service-toolkit" 3 | version = "0.1.0" 4 | description = "Full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit" 5 | readme = "README.md" 6 | authors = [{ name = "Joshua Carroll", email = "carroll.joshk@gmail.com" }] 7 | classifiers = [ 8 | "Development Status :: 4 - Beta", 9 | "License :: OSI Approved :: MIT License", 10 | "Framework :: FastAPI", 11 | "Programming Language :: Python :: 3.11", 12 | "Programming Language :: Python :: 3.12", 13 | "Programming Language :: Python :: 3.13", 14 | ] 15 | 16 | requires-python = ">=3.11" 17 | 18 | dependencies = [ 19 | "duckduckgo-search>=7.3.0", 20 | "fastapi ~=0.115.5", 21 | 
"grpcio >=1.68.0", 22 | "httpx ~=0.27.2", 23 | "jiter ~=0.8.2", 24 | "langchain-core ~=0.3.33", 25 | "langchain-community ~=0.3.16", 26 | "langchain-openai ~=0.2.9", 27 | "langchain-anthropic ~= 0.3.0", 28 | "langchain-google-genai ~=2.0.11", 29 | "langchain-groq ~=0.2.1", 30 | "langchain-aws ~=0.2.14", 31 | "langchain-ollama ~=0.2.3", 32 | "langgraph ~=0.2.68", 33 | "langgraph-checkpoint-sqlite ~=2.0.1", 34 | "langgraph-checkpoint-postgres ~=2.0.13", 35 | "langsmith ~=0.1.145", 36 | "numexpr ~=2.10.1", 37 | "numpy ~=1.26.4; python_version <= '3.12'", 38 | "numpy ~=2.2.3; python_version >= '3.13'", 39 | "pandas ~=2.2.3", 40 | "psycopg[binary,pool] ~=3.2.4", 41 | "pyarrow >=18.1.0", 42 | "pydantic ~=2.10.1", 43 | "pydantic-settings ~=2.6.1", 44 | "pyowm ~=3.3.0", 45 | "python-dotenv ~=1.0.1", 46 | "setuptools ~=75.6.0", 47 | "streamlit ~=1.40.1", 48 | "tiktoken >=0.8.0", 49 | "uvicorn ~=0.32.1", 50 | ] 51 | 52 | [dependency-groups] 53 | dev = [ 54 | "pre-commit", 55 | "pytest", 56 | "pytest-cov", 57 | "pytest-env", 58 | "pytest-asyncio", 59 | "ruff", 60 | ] 61 | 62 | # Group for the minimal dependencies to run just the client and Streamlit app. 63 | # These are also installed in the default dependencies. 64 | # To install run: `uv sync --frozen --only-group client` 65 | client = [ 66 | "httpx~=0.27.2", 67 | "pydantic ~=2.10.1", 68 | "python-dotenv ~=1.0.1", 69 | "streamlit~=1.40.1", 70 | ] 71 | 72 | [tool.ruff] 73 | line-length = 100 74 | target-version = "py311" 75 | 76 | [tool.ruff.lint] 77 | extend-select = ["I", "U"] 78 | 79 | [tool.pytest.ini_options] 80 | pythonpath = ["src"] 81 | asyncio_default_fixture_loop_scope = "function" 82 | 83 | [tool.pytest_env] 84 | OPENAI_API_KEY = "sk-fake-openai-key" 85 | -------------------------------------------------------------------------------- /tests/core/test_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | from langchain_anthropic import ChatAnthropic 6 | from langchain_community.chat_models import FakeListChatModel 7 | from langchain_groq import ChatGroq 8 | from langchain_ollama import ChatOllama 9 | from langchain_openai import ChatOpenAI 10 | 11 | from core.llm import get_model 12 | from schema.models import ( 13 | AnthropicModelName, 14 | FakeModelName, 15 | GroqModelName, 16 | OllamaModelName, 17 | OpenAIModelName, 18 | ) 19 | 20 | 21 | def test_get_model_openai(): 22 | with patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}): 23 | model = get_model(OpenAIModelName.GPT_4O_MINI) 24 | assert isinstance(model, ChatOpenAI) 25 | assert model.model_name == "gpt-4o-mini" 26 | assert model.temperature == 0.5 27 | assert model.streaming is True 28 | 29 | 30 | def test_get_model_anthropic(): 31 | with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_key"}): 32 | model = get_model(AnthropicModelName.HAIKU_3) 33 | assert isinstance(model, ChatAnthropic) 34 | assert model.model == "claude-3-haiku-20240307" 35 | assert model.temperature == 0.5 36 | assert model.streaming is True 37 | 38 | 39 | def test_get_model_groq(): 40 | with patch.dict(os.environ, {"GROQ_API_KEY": "test_key"}): 41 | model = get_model(GroqModelName.LLAMA_31_8B) 42 | assert isinstance(model, ChatGroq) 43 | assert model.model_name == "llama-3.1-8b-instant" 44 | assert model.temperature == 0.5 45 | 46 | 47 | def test_get_model_groq_guard(): 48 | with patch.dict(os.environ, {"GROQ_API_KEY": "test_key"}): 49 | model = get_model(GroqModelName.LLAMA_GUARD_3_8B) 50 | 
assert isinstance(model, ChatGroq) 51 | assert model.model_name == "llama-guard-3-8b" 52 | assert model.temperature < 0.01 53 | 54 | 55 | def test_get_model_ollama(): 56 | with patch("core.settings.settings.OLLAMA_MODEL", "llama3.3"): 57 | model = get_model(OllamaModelName.OLLAMA_GENERIC) 58 | assert isinstance(model, ChatOllama) 59 | assert model.model == "llama3.3" 60 | assert model.temperature == 0.5 61 | 62 | 63 | def test_get_model_fake(): 64 | model = get_model(FakeModelName.FAKE) 65 | assert isinstance(model, FakeListChatModel) 66 | assert model.responses == ["This is a test response from the fake model."] 67 | 68 | 69 | def test_get_model_invalid(): 70 | with pytest.raises(ValueError, match="Unsupported model:"): 71 | # Using type: ignore since we're intentionally testing invalid input 72 | get_model("invalid_model") # type: ignore 73 | -------------------------------------------------------------------------------- /src/service/utils.py: -------------------------------------------------------------------------------- 1 | from langchain_core.messages import ( 2 | AIMessage, 3 | BaseMessage, 4 | HumanMessage, 5 | ToolMessage, 6 | ) 7 | from langchain_core.messages import ( 8 | ChatMessage as LangchainChatMessage, 9 | ) 10 | 11 | from schema import ChatMessage 12 | 13 | 14 | def convert_message_content_to_string(content: str | list[str | dict]) -> str: 15 | if isinstance(content, str): 16 | return content 17 | text: list[str] = [] 18 | for content_item in content: 19 | if isinstance(content_item, str): 20 | text.append(content_item) 21 | continue 22 | if content_item["type"] == "text": 23 | text.append(content_item["text"]) 24 | return "".join(text) 25 | 26 | 27 | def langchain_to_chat_message(message: BaseMessage) -> ChatMessage: 28 | """Create a ChatMessage from a LangChain message.""" 29 | match message: 30 | case HumanMessage(): 31 | human_message = ChatMessage( 32 | type="human", 33 | content=convert_message_content_to_string(message.content), 34 | ) 35 | return human_message 36 | case AIMessage(): 37 | ai_message = ChatMessage( 38 | type="ai", 39 | content=convert_message_content_to_string(message.content), 40 | ) 41 | if message.tool_calls: 42 | ai_message.tool_calls = message.tool_calls 43 | if message.response_metadata: 44 | ai_message.response_metadata = message.response_metadata 45 | return ai_message 46 | case ToolMessage(): 47 | tool_message = ChatMessage( 48 | type="tool", 49 | content=convert_message_content_to_string(message.content), 50 | tool_call_id=message.tool_call_id, 51 | ) 52 | return tool_message 53 | case LangchainChatMessage(): 54 | if message.role == "custom": 55 | custom_message = ChatMessage( 56 | type="custom", 57 | content="", 58 | custom_data=message.content[0], 59 | ) 60 | return custom_message 61 | else: 62 | raise ValueError(f"Unsupported chat message role: {message.role}") 63 | case _: 64 | raise ValueError(f"Unsupported message type: {message.__class__.__name__}") 65 | 66 | 67 | def remove_tool_calls(content: str | list[str | dict]) -> str | list[str | dict]: 68 | """Remove tool calls from content.""" 69 | if isinstance(content, str): 70 | return content 71 | # Currently only Anthropic models stream tool calls, using content item type tool_use. 
72 | return [ 73 | content_item 74 | for content_item in content 75 | if isinstance(content_item, str) or content_item["type"] != "tool_use" 76 | ] 77 | -------------------------------------------------------------------------------- /src/schema/task_data.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class TaskData(BaseModel): 7 | name: str | None = Field( 8 | description="Name of the task.", default=None, examples=["Check input safety"] 9 | ) 10 | run_id: str = Field( 11 | description="ID of the task run to pair state updates to.", 12 | default="", 13 | examples=["847c6285-8fc9-4560-a83f-4e6285809254"], 14 | ) 15 | state: Literal["new", "running", "complete"] | None = Field( 16 | description="Current state of given task instance.", 17 | default=None, 18 | examples=["running"], 19 | ) 20 | result: Literal["success", "error"] | None = Field( 21 | description="Result of given task instance.", 22 | default=None, 23 | examples=["running"], 24 | ) 25 | data: dict[str, Any] = Field( 26 | description="Additional data generated by the task.", 27 | default={}, 28 | ) 29 | 30 | def completed(self) -> bool: 31 | return self.state == "complete" 32 | 33 | def completed_with_error(self) -> bool: 34 | return self.state == "complete" and self.result == "error" 35 | 36 | 37 | class TaskDataStatus: 38 | def __init__(self) -> None: 39 | import streamlit as st 40 | 41 | self.status = st.status("") 42 | self.current_task_data: dict[str, TaskData] = {} 43 | 44 | def add_and_draw_task_data(self, task_data: TaskData) -> None: 45 | status = self.status 46 | status_str = f"Task **{task_data.name}** " 47 | match task_data.state: 48 | case "new": 49 | status_str += "has :blue[started]. Input:" 50 | case "running": 51 | status_str += "wrote:" 52 | case "complete": 53 | if task_data.result == "success": 54 | status_str += ":green[completed successfully]. Output:" 55 | else: 56 | status_str += ":red[ended with error]. 
Output:" 57 | status.write(status_str) 58 | status.write(task_data.data) 59 | status.write("---") 60 | if task_data.run_id not in self.current_task_data: 61 | # Status label always shows the last newly started task 62 | status.update(label=f"""Task: {task_data.name}""") 63 | self.current_task_data[task_data.run_id] = task_data 64 | if all(entry.completed() for entry in self.current_task_data.values()): 65 | # Status is "error" if any task has errored 66 | if any(entry.completed_with_error() for entry in self.current_task_data.values()): 67 | state = "error" 68 | # Status is "complete" if all tasks have completed successfully 69 | else: 70 | state = "complete" 71 | # Status is "running" until all tasks have completed 72 | else: 73 | state = "running" 74 | status.update(state=state) 75 | -------------------------------------------------------------------------------- /tests/service/test_service_e2e.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from langchain_core.messages import AIMessage, ToolCall, ToolMessage 4 | from langchain_core.runnables import RunnableConfig 5 | from langgraph.checkpoint.memory import MemorySaver 6 | from langgraph.graph import END, MessagesState, StateGraph 7 | 8 | from agents.agents import Agent 9 | from agents.utils import CustomData 10 | from client import AgentClient 11 | from schema.schema import ChatMessage 12 | from service.utils import langchain_to_chat_message 13 | 14 | START_MESSAGE = CustomData(type="start", data={"key1": "value1", "key2": 123}) 15 | 16 | STATIC_MESSAGES = [ 17 | AIMessage( 18 | content="", 19 | tool_calls=[ 20 | ToolCall( 21 | name="test_tool", 22 | args={"arg1": "value1"}, 23 | id="test_call_id", 24 | ), 25 | ], 26 | ), 27 | ToolMessage(content="42", tool_call_id="test_call_id"), 28 | AIMessage(content="The answer is 42"), 29 | CustomData(type="end", data={"time": "end"}).to_langchain(), 30 | ] 31 | 32 | 33 | EXPECTED_OUTPUT_MESSAGES = [ 34 | langchain_to_chat_message(m) for m in [START_MESSAGE.to_langchain()] + STATIC_MESSAGES 35 | ] 36 | 37 | 38 | def test_messages_conversion() -> None: 39 | """Verify that our list of messages is converted to the expected output.""" 40 | 41 | messages = EXPECTED_OUTPUT_MESSAGES 42 | 43 | # Verify the sequence of messages 44 | assert len(messages) == 5 45 | 46 | # First message: Custom data start marker 47 | assert messages[0].type == "custom" 48 | assert messages[0].custom_data == {"key1": "value1", "key2": 123} 49 | 50 | # Second message: AI with tool call 51 | assert messages[1].type == "ai" 52 | assert len(messages[1].tool_calls) == 1 53 | assert messages[1].tool_calls[0]["name"] == "test_tool" 54 | assert messages[1].tool_calls[0]["args"] == {"arg1": "value1"} 55 | 56 | # Third message: Tool response 57 | assert messages[2].type == "tool" 58 | assert messages[2].content == "42" 59 | assert messages[2].tool_call_id == "test_call_id" 60 | 61 | # Fourth message: Final AI response 62 | assert messages[3].type == "ai" 63 | assert messages[3].content == "The answer is 42" 64 | 65 | # Fifth message: Custom data end marker 66 | assert messages[4].type == "custom" 67 | assert messages[4].custom_data == {"time": "end"} 68 | 69 | 70 | async def static_messages(state: MessagesState, config: RunnableConfig) -> MessagesState: 71 | await START_MESSAGE.adispatch(config) 72 | return {"messages": STATIC_MESSAGES} 73 | 74 | 75 | agent = StateGraph(MessagesState) 76 | agent.add_node("static_messages", static_messages) 77 | 
agent.set_entry_point("static_messages") 78 | agent.add_edge("static_messages", END) 79 | static_agent = agent.compile(checkpointer=MemorySaver()) 80 | 81 | 82 | def test_agent_stream(mock_httpx): 83 | """Test that streaming from our static agent works correctly with token streaming.""" 84 | agent_meta = Agent(description="A static agent.", graph=static_agent) 85 | with patch.dict("agents.agents.agents", {"static-agent": agent_meta}, clear=True): 86 | client = AgentClient(agent="static-agent") 87 | 88 | # Use stream to get intermediate responses 89 | messages = [] 90 | 91 | def agent_lookup(agent_id): 92 | if agent_id == "static-agent": 93 | return static_agent 94 | return None 95 | 96 | with patch("service.service.get_agent", side_effect=agent_lookup): 97 | for response in client.stream("Test message", stream_tokens=False): 98 | if isinstance(response, ChatMessage): 99 | messages.append(response) 100 | 101 | for expected, actual in zip(EXPECTED_OUTPUT_MESSAGES, messages): 102 | actual.run_id = None 103 | assert expected == actual 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Streamlit and sqlite 2 | .streamlit/secrets.toml 3 | checkpoints.db 4 | checkpoints.db-* 5 | 6 | # VSCode 7 | .vscode 8 | .DS_Store 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .python-version 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
172 | .idea/ 173 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Build and test 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | workflow_call: 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | test-python: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.11", "3.12", "3.13"] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install uv 27 | uses: astral-sh/setup-uv@v4 28 | with: 29 | version: "0.6.3" 30 | - name: Install dependencies with uv 31 | run: | 32 | uv sync --frozen 33 | env: 34 | UV_SYSTEM_PYTHON: 1 35 | - name: Lint and format with ruff 36 | run: | 37 | uv run ruff format --check 38 | uv run ruff check --output-format github 39 | 40 | - name: Test with pytest 41 | run: | 42 | uv run pytest --cov=src/ --cov-report=xml 43 | 44 | - name: Upload coverage reports to Codecov 45 | uses: codecov/codecov-action@v5 46 | with: 47 | token: ${{ secrets.CODECOV_TOKEN }} 48 | 49 | test-docker: 50 | runs-on: ubuntu-latest 51 | 52 | services: 53 | dind: 54 | image: docker:dind 55 | ports: 56 | - 2375:2375 57 | options: >- 58 | --privileged 59 | --health-cmd "docker info" 60 | --health-interval 10s 61 | --health-timeout 5s 62 | --health-retries 5 63 | 64 | steps: 65 | - uses: actions/checkout@v4 66 | 67 | - name: Set up Docker Buildx 68 | uses: docker/setup-buildx-action@v2 69 | with: 70 | driver-opts: network=host 71 | 72 | - name: Build service image 73 | uses: docker/build-push-action@v3 74 | with: 75 | context: . 76 | push: false 77 | load: true 78 | tags: agent-service-toolkit.service:${{ github.sha }} 79 | file: docker/Dockerfile.service 80 | 81 | - name: Build app image 82 | uses: docker/build-push-action@v3 83 | with: 84 | context: . 85 | push: false 86 | load: true 87 | tags: agent-service-toolkit.app:${{ github.sha }} 88 | file: docker/Dockerfile.app 89 | 90 | - name: Start service container 91 | run: docker run -d --name service-container --network host -e USE_FAKE_MODEL=true -e PORT=80 agent-service-toolkit.service:${{ github.sha }} 92 | 93 | - name: Confirm service starts correctly 94 | run: | 95 | timeout 30 bash -c ' 96 | while ! curl -s http://0.0.0.0/health; do 97 | echo "Waiting for service to be ready..." 98 | docker logs service-container 99 | sleep 2 100 | done 101 | ' 102 | 103 | - name: Run app container 104 | run: docker run -d --name app-container --network host -e AGENT_URL=http://0.0.0.0 agent-service-toolkit.app:${{ github.sha }} 105 | 106 | - name: Confirm app starts correctly 107 | run: | 108 | timeout 30 bash -c ' 109 | while ! curl -s http://localhost:8501/healthz; do 110 | echo "Waiting for app to be ready..." 
111 | docker logs app-container 112 | sleep 2 113 | done 114 | ' 115 | 116 | - name: Set up Python 117 | uses: actions/setup-python@v5 118 | with: 119 | python-version-file: "pyproject.toml" 120 | - name: Install uv 121 | uses: astral-sh/setup-uv@v4 122 | with: 123 | version: "0.6.3" 124 | - name: Install ONLY CLIENT dependencies with uv 125 | run: | 126 | uv sync --frozen --only-group client --only-group dev 127 | env: 128 | UV_SYSTEM_PYTHON: 1 129 | - name: Run integration tests 130 | run: | 131 | uv run pytest tests/integration -v --run-docker 132 | env: 133 | AGENT_URL: http://0.0.0.0 134 | 135 | - name: Clean up containers 136 | if: always() 137 | run: | 138 | docker stop service-container app-container || true 139 | docker rm service-container app-container || true 140 | -------------------------------------------------------------------------------- /src/core/llm.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | from typing import TypeAlias 3 | 4 | from langchain_anthropic import ChatAnthropic 5 | from langchain_aws import ChatBedrock 6 | from langchain_community.chat_models import FakeListChatModel 7 | from langchain_google_genai import ChatGoogleGenerativeAI 8 | from langchain_groq import ChatGroq 9 | from langchain_ollama import ChatOllama 10 | from langchain_openai import AzureChatOpenAI, ChatOpenAI 11 | 12 | from core.settings import settings 13 | from schema.models import ( 14 | AllModelEnum, 15 | AnthropicModelName, 16 | AWSModelName, 17 | AzureOpenAIModelName, 18 | DeepseekModelName, 19 | FakeModelName, 20 | GoogleModelName, 21 | GroqModelName, 22 | OllamaModelName, 23 | OpenAIModelName, 24 | ) 25 | 26 | _MODEL_TABLE = { 27 | OpenAIModelName.GPT_4O_MINI: "gpt-4o-mini", 28 | OpenAIModelName.GPT_4O: "gpt-4o", 29 | AzureOpenAIModelName.AZURE_GPT_4O_MINI: settings.AZURE_OPENAI_DEPLOYMENT_MAP.get( 30 | "gpt-4o-mini", "" 31 | ), 32 | AzureOpenAIModelName.AZURE_GPT_4O: settings.AZURE_OPENAI_DEPLOYMENT_MAP.get("gpt-4o", ""), 33 | DeepseekModelName.DEEPSEEK_CHAT: "deepseek-chat", 34 | AnthropicModelName.HAIKU_3: "claude-3-haiku-20240307", 35 | AnthropicModelName.HAIKU_35: "claude-3-5-haiku-latest", 36 | AnthropicModelName.SONNET_35: "claude-3-5-sonnet-latest", 37 | GoogleModelName.GEMINI_15_FLASH: "gemini-1.5-flash", 38 | GroqModelName.LLAMA_31_8B: "llama-3.1-8b-instant", 39 | GroqModelName.LLAMA_33_70B: "llama-3.3-70b-versatile", 40 | GroqModelName.LLAMA_GUARD_3_8B: "llama-guard-3-8b", 41 | AWSModelName.BEDROCK_HAIKU: "anthropic.claude-3-5-haiku-20241022-v1:0", 42 | AWSModelName.BEDROCK_SONNET: "anthropic.claude-3-5-sonnet-20240620-v1:0", 43 | OllamaModelName.OLLAMA_GENERIC: "ollama", 44 | FakeModelName.FAKE: "fake", 45 | } 46 | 47 | ModelT: TypeAlias = ( 48 | ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI | ChatGroq | ChatBedrock | ChatOllama 49 | ) 50 | 51 | 52 | @cache 53 | def get_model(model_name: AllModelEnum, /) -> ModelT: 54 | # NOTE: models with streaming=True will send tokens as they are generated 55 | # if the /stream endpoint is called with stream_tokens=True (the default) 56 | api_model_name = _MODEL_TABLE.get(model_name) 57 | if not api_model_name: 58 | raise ValueError(f"Unsupported model: {model_name}") 59 | 60 | if model_name in OpenAIModelName: 61 | return ChatOpenAI(model=api_model_name, temperature=0.5, streaming=True) 62 | if model_name in AzureOpenAIModelName: 63 | if not settings.AZURE_OPENAI_API_KEY or not settings.AZURE_OPENAI_ENDPOINT: 64 | raise ValueError("Azure OpenAI API key and 
endpoint must be configured") 65 | 66 | return AzureChatOpenAI( 67 | azure_endpoint=settings.AZURE_OPENAI_ENDPOINT, 68 | deployment_name=api_model_name, 69 | api_version=settings.AZURE_OPENAI_API_VERSION, 70 | temperature=0.5, 71 | streaming=True, 72 | timeout=60, 73 | max_retries=3, 74 | ) 75 | if model_name in DeepseekModelName: 76 | return ChatOpenAI( 77 | model=api_model_name, 78 | temperature=0.5, 79 | streaming=True, 80 | openai_api_base="https://api.deepseek.com", 81 | openai_api_key=settings.DEEPSEEK_API_KEY, 82 | ) 83 | if model_name in AnthropicModelName: 84 | return ChatAnthropic(model=api_model_name, temperature=0.5, streaming=True) 85 | if model_name in GoogleModelName: 86 | return ChatGoogleGenerativeAI(model=api_model_name, temperature=0.5, streaming=True) 87 | if model_name in GroqModelName: 88 | if model_name == GroqModelName.LLAMA_GUARD_3_8B: 89 | return ChatGroq(model=api_model_name, temperature=0.0) 90 | return ChatGroq(model=api_model_name, temperature=0.5) 91 | if model_name in AWSModelName: 92 | return ChatBedrock(model_id=api_model_name, temperature=0.5) 93 | if model_name in OllamaModelName: 94 | if settings.OLLAMA_BASE_URL: 95 | chat_ollama = ChatOllama( 96 | model=settings.OLLAMA_MODEL, temperature=0.5, base_url=settings.OLLAMA_BASE_URL 97 | ) 98 | else: 99 | chat_ollama = ChatOllama(model=settings.OLLAMA_MODEL, temperature=0.5) 100 | return chat_ollama 101 | if model_name in FakeModelName: 102 | return FakeListChatModel(responses=["This is a test response from the fake model."]) 103 | -------------------------------------------------------------------------------- /src/agents/llama_guard.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from langchain_core.messages import AIMessage, AnyMessage, HumanMessage 4 | from langchain_core.prompts import PromptTemplate 5 | from pydantic import BaseModel, Field 6 | 7 | from core import get_model, settings 8 | from schema.models import GroqModelName 9 | 10 | 11 | class SafetyAssessment(Enum): 12 | SAFE = "safe" 13 | UNSAFE = "unsafe" 14 | ERROR = "error" 15 | 16 | 17 | class LlamaGuardOutput(BaseModel): 18 | safety_assessment: SafetyAssessment = Field(description="The safety assessment of the content.") 19 | unsafe_categories: list[str] = Field( 20 | description="If content is unsafe, the list of unsafe categories.", default=[] 21 | ) 22 | 23 | 24 | unsafe_content_categories = { 25 | "S1": "Violent Crimes.", 26 | "S2": "Non-Violent Crimes.", 27 | "S3": "Sex Crimes.", 28 | "S4": "Child Exploitation.", 29 | "S5": "Defamation.", 30 | "S6": "Specialized Advice.", 31 | "S7": "Privacy.", 32 | "S8": "Intellectual Property.", 33 | "S9": "Indiscriminate Weapons.", 34 | "S10": "Hate.", 35 | "S11": "Self-Harm.", 36 | "S12": "Sexual Content.", 37 | "S13": "Elections.", 38 | "S14": "Code Interpreter Abuse.", 39 | } 40 | 41 | categories_str = "\n".join([f"{k}: {v}" for k, v in unsafe_content_categories.items()]) 42 | 43 | llama_guard_instructions = f""" 44 | Task: Check if there is unsafe content in '{{role}}' messages in conversations according to our safety policy with the below categories. 45 | 46 | 47 | {categories_str} 48 | 49 | 50 | 51 | 52 | {{conversation_history}} 53 | 54 | 55 | 56 | Provide your safety assessment for ONLY THE LAST {{role}} in the above conversation: 57 | - First line must read 'safe' or 'unsafe'. 
58 | - If unsafe, a second line must include a comma-separated list of violated categories.""" 59 | 60 | 61 | def parse_llama_guard_output(output: str) -> LlamaGuardOutput: 62 | if output == "safe": 63 | return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE) 64 | parsed_output = output.split("\n") 65 | if len(parsed_output) != 2 or parsed_output[0] != "unsafe": 66 | return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR) 67 | try: 68 | categories = parsed_output[1].split(",") 69 | readable_categories = [unsafe_content_categories[c.strip()].strip(".") for c in categories] 70 | return LlamaGuardOutput( 71 | safety_assessment=SafetyAssessment.UNSAFE, 72 | unsafe_categories=readable_categories, 73 | ) 74 | except KeyError: 75 | return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR) 76 | 77 | 78 | class LlamaGuard: 79 | def __init__(self) -> None: 80 | if settings.GROQ_API_KEY is None: 81 | print("GROQ_API_KEY not set, skipping LlamaGuard") 82 | self.model = None 83 | return 84 | self.model = get_model(GroqModelName.LLAMA_GUARD_3_8B).with_config(tags=["llama_guard"]) 85 | self.prompt = PromptTemplate.from_template(llama_guard_instructions) 86 | 87 | def _compile_prompt(self, role: str, messages: list[AnyMessage]) -> str: 88 | role_mapping = {"ai": "Agent", "human": "User"} 89 | messages_str = [ 90 | f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"] 91 | ] 92 | conversation_history = "\n\n".join(messages_str) 93 | return self.prompt.format(role=role, conversation_history=conversation_history) 94 | 95 | def invoke(self, role: str, messages: list[AnyMessage]) -> LlamaGuardOutput: 96 | if self.model is None: 97 | return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE) 98 | compiled_prompt = self._compile_prompt(role, messages) 99 | result = self.model.invoke([HumanMessage(content=compiled_prompt)]) 100 | return parse_llama_guard_output(result.content) 101 | 102 | async def ainvoke(self, role: str, messages: list[AnyMessage]) -> LlamaGuardOutput: 103 | if self.model is None: 104 | return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE) 105 | compiled_prompt = self._compile_prompt(role, messages) 106 | result = await self.model.ainvoke([HumanMessage(content=compiled_prompt)]) 107 | return parse_llama_guard_output(result.content) 108 | 109 | 110 | if __name__ == "__main__": 111 | llama_guard = LlamaGuard() 112 | output = llama_guard.invoke( 113 | "Agent", 114 | [ 115 | HumanMessage(content="What's a good way to harm an animal?"), 116 | AIMessage( 117 | content="There are many ways to harm animals, but some include hitting them with a stick, throwing rocks at them, or poisoning them." 
118 | ), 119 | ], 120 | ) 121 | print(output) 122 | -------------------------------------------------------------------------------- /src/schema/schema.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, NotRequired 2 | 3 | from pydantic import BaseModel, Field, SerializeAsAny 4 | from typing_extensions import TypedDict 5 | 6 | from schema.models import AllModelEnum, AnthropicModelName, OpenAIModelName 7 | 8 | 9 | class AgentInfo(BaseModel): 10 | """Info about an available agent.""" 11 | 12 | key: str = Field( 13 | description="Agent key.", 14 | examples=["research-assistant"], 15 | ) 16 | description: str = Field( 17 | description="Description of the agent.", 18 | examples=["A research assistant for generating research papers."], 19 | ) 20 | 21 | 22 | class ServiceMetadata(BaseModel): 23 | """Metadata about the service including available agents and models.""" 24 | 25 | agents: list[AgentInfo] = Field( 26 | description="List of available agents.", 27 | ) 28 | models: list[AllModelEnum] = Field( 29 | description="List of available LLMs.", 30 | ) 31 | default_agent: str = Field( 32 | description="Default agent used when none is specified.", 33 | examples=["research-assistant"], 34 | ) 35 | default_model: AllModelEnum = Field( 36 | description="Default model used when none is specified.", 37 | ) 38 | 39 | 40 | class UserInput(BaseModel): 41 | """Basic user input for the agent.""" 42 | 43 | message: str = Field( 44 | description="User input to the agent.", 45 | examples=["What is the weather in Tokyo?"], 46 | ) 47 | model: SerializeAsAny[AllModelEnum] | None = Field( 48 | title="Model", 49 | description="LLM Model to use for the agent.", 50 | default=OpenAIModelName.GPT_4O_MINI, 51 | examples=[OpenAIModelName.GPT_4O_MINI, AnthropicModelName.HAIKU_35], 52 | ) 53 | thread_id: str | None = Field( 54 | description="Thread ID to persist and continue a multi-turn conversation.", 55 | default=None, 56 | examples=["847c6285-8fc9-4560-a83f-4e6285809254"], 57 | ) 58 | agent_config: dict[str, Any] = Field( 59 | description="Additional configuration to pass through to the agent", 60 | default={}, 61 | examples=[{"spicy_level": 0.8}], 62 | ) 63 | 64 | 65 | class StreamInput(UserInput): 66 | """User input for streaming the agent's response.""" 67 | 68 | stream_tokens: bool = Field( 69 | description="Whether to stream LLM tokens to the client.", 70 | default=True, 71 | ) 72 | 73 | 74 | class ToolCall(TypedDict): 75 | """Represents a request to call a tool.""" 76 | 77 | name: str 78 | """The name of the tool to be called.""" 79 | args: dict[str, Any] 80 | """The arguments to the tool call.""" 81 | id: str | None 82 | """An identifier associated with the tool call.""" 83 | type: NotRequired[Literal["tool_call"]] 84 | 85 | 86 | class ChatMessage(BaseModel): 87 | """Message in a chat.""" 88 | 89 | type: Literal["human", "ai", "tool", "custom"] = Field( 90 | description="Role of the message.", 91 | examples=["human", "ai", "tool", "custom"], 92 | ) 93 | content: str = Field( 94 | description="Content of the message.", 95 | examples=["Hello, world!"], 96 | ) 97 | tool_calls: list[ToolCall] = Field( 98 | description="Tool calls in the message.", 99 | default=[], 100 | ) 101 | tool_call_id: str | None = Field( 102 | description="Tool call that this message is responding to.", 103 | default=None, 104 | examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"], 105 | ) 106 | run_id: str | None = Field( 107 | description="Run ID of the message.", 108 | 
default=None, 109 | examples=["847c6285-8fc9-4560-a83f-4e6285809254"], 110 | ) 111 | response_metadata: dict[str, Any] = Field( 112 | description="Response metadata. For example: response headers, logprobs, token counts.", 113 | default={}, 114 | ) 115 | custom_data: dict[str, Any] = Field( 116 | description="Custom message data.", 117 | default={}, 118 | ) 119 | 120 | def pretty_repr(self) -> str: 121 | """Get a pretty representation of the message.""" 122 | base_title = self.type.title() + " Message" 123 | padded = " " + base_title + " " 124 | sep_len = (80 - len(padded)) // 2 125 | sep = "=" * sep_len 126 | second_sep = sep + "=" if len(padded) % 2 else sep 127 | title = f"{sep}{padded}{second_sep}" 128 | return f"{title}\n\n{self.content}" 129 | 130 | def pretty_print(self) -> None: 131 | print(self.pretty_repr()) # noqa: T201 132 | 133 | 134 | class Feedback(BaseModel): 135 | """Feedback for a run, to record to LangSmith.""" 136 | 137 | run_id: str = Field( 138 | description="Run ID to record feedback for.", 139 | examples=["847c6285-8fc9-4560-a83f-4e6285809254"], 140 | ) 141 | key: str = Field( 142 | description="Feedback key.", 143 | examples=["human-feedback-stars"], 144 | ) 145 | score: float = Field( 146 | description="Feedback score.", 147 | examples=[0.8], 148 | ) 149 | kwargs: dict[str, Any] = Field( 150 | description="Additional feedback kwargs, passed to LangSmith.", 151 | default={}, 152 | examples=[{"comment": "In-line human feedback"}], 153 | ) 154 | 155 | 156 | class FeedbackResponse(BaseModel): 157 | status: Literal["success"] = "success" 158 | 159 | 160 | class ChatHistoryInput(BaseModel): 161 | """Input for retrieving chat history.""" 162 | 163 | thread_id: str = Field( 164 | description="Thread ID to persist and continue a multi-turn conversation.", 165 | examples=["847c6285-8fc9-4560-a83f-4e6285809254"], 166 | ) 167 | 168 | 169 | class ChatHistory(BaseModel): 170 | messages: list[ChatMessage] 171 | -------------------------------------------------------------------------------- /src/agents/research_assistant.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Literal 3 | 4 | from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun 5 | from langchain_community.utilities import OpenWeatherMapAPIWrapper 6 | from langchain_core.language_models.chat_models import BaseChatModel 7 | from langchain_core.messages import AIMessage, SystemMessage 8 | from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable 9 | from langgraph.checkpoint.memory import MemorySaver 10 | from langgraph.graph import END, MessagesState, StateGraph 11 | from langgraph.managed import RemainingSteps 12 | from langgraph.prebuilt import ToolNode 13 | 14 | from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment 15 | from agents.tools import calculator 16 | from core import get_model, settings 17 | 18 | 19 | class AgentState(MessagesState, total=False): 20 | """`total=False` is PEP589 specs. 
21 | 22 | documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality 23 | """ 24 | 25 | safety: LlamaGuardOutput 26 | remaining_steps: RemainingSteps 27 | 28 | 29 | web_search = DuckDuckGoSearchResults(name="WebSearch") 30 | tools = [web_search, calculator] 31 | 32 | # Add weather tool if API key is set 33 | # Register for an API key at https://openweathermap.org/api/ 34 | if settings.OPENWEATHERMAP_API_KEY: 35 | wrapper = OpenWeatherMapAPIWrapper( 36 | openweathermap_api_key=settings.OPENWEATHERMAP_API_KEY.get_secret_value() 37 | ) 38 | tools.append(OpenWeatherMapQueryRun(name="Weather", api_wrapper=wrapper)) 39 | 40 | current_date = datetime.now().strftime("%B %d, %Y") 41 | instructions = f""" 42 | You are a helpful research assistant with the ability to search the web and use other tools. 43 | Today's date is {current_date}. 44 | 45 | NOTE: THE USER CAN'T SEE THE TOOL RESPONSE. 46 | 47 | A few things to remember: 48 | - Please include markdown-formatted links to any citations used in your response. Only include one 49 | or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS. 50 | - Use calculator tool with numexpr to answer math questions. The user does not understand numexpr, 51 | so for the final response, use human readable format - e.g. "300 * 200", not "(300 \\times 200)". 52 | """ 53 | 54 | 55 | def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]: 56 | model = model.bind_tools(tools) 57 | preprocessor = RunnableLambda( 58 | lambda state: [SystemMessage(content=instructions)] + state["messages"], 59 | name="StateModifier", 60 | ) 61 | return preprocessor | model 62 | 63 | 64 | def format_safety_message(safety: LlamaGuardOutput) -> AIMessage: 65 | content = ( 66 | f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}" 67 | ) 68 | return AIMessage(content=content) 69 | 70 | 71 | async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: 72 | m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL)) 73 | model_runnable = wrap_model(m) 74 | response = await model_runnable.ainvoke(state, config) 75 | 76 | # Run llama guard check here to avoid returning the message if it's unsafe 77 | llama_guard = LlamaGuard() 78 | safety_output = await llama_guard.ainvoke("Agent", state["messages"] + [response]) 79 | if safety_output.safety_assessment == SafetyAssessment.UNSAFE: 80 | return {"messages": [format_safety_message(safety_output)], "safety": safety_output} 81 | 82 | if state["remaining_steps"] < 2 and response.tool_calls: 83 | return { 84 | "messages": [ 85 | AIMessage( 86 | id=response.id, 87 | content="Sorry, need more steps to process this request.", 88 | ) 89 | ] 90 | } 91 | # We return a list, because this will get added to the existing list 92 | return {"messages": [response]} 93 | 94 | 95 | async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState: 96 | llama_guard = LlamaGuard() 97 | safety_output = await llama_guard.ainvoke("User", state["messages"]) 98 | return {"safety": safety_output} 99 | 100 | 101 | async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState: 102 | safety: LlamaGuardOutput = state["safety"] 103 | return {"messages": [format_safety_message(safety)]} 104 | 105 | 106 | # Define the graph 107 | agent = StateGraph(AgentState) 108 | agent.add_node("model", acall_model) 109 | agent.add_node("tools", ToolNode(tools)) 110 | 
agent.add_node("guard_input", llama_guard_input) 111 | agent.add_node("block_unsafe_content", block_unsafe_content) 112 | agent.set_entry_point("guard_input") 113 | 114 | 115 | # Check for unsafe input and block further processing if found 116 | def check_safety(state: AgentState) -> Literal["unsafe", "safe"]: 117 | safety: LlamaGuardOutput = state["safety"] 118 | match safety.safety_assessment: 119 | case SafetyAssessment.UNSAFE: 120 | return "unsafe" 121 | case _: 122 | return "safe" 123 | 124 | 125 | agent.add_conditional_edges( 126 | "guard_input", check_safety, {"unsafe": "block_unsafe_content", "safe": "model"} 127 | ) 128 | 129 | # Always END after blocking unsafe content 130 | agent.add_edge("block_unsafe_content", END) 131 | 132 | # Always run "model" after "tools" 133 | agent.add_edge("tools", "model") 134 | 135 | 136 | # After "model", if there are tool calls, run "tools". Otherwise END. 137 | def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]: 138 | last_message = state["messages"][-1] 139 | if not isinstance(last_message, AIMessage): 140 | raise TypeError(f"Expected AIMessage, got {type(last_message)}") 141 | if last_message.tool_calls: 142 | return "tools" 143 | return "done" 144 | 145 | 146 | agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END}) 147 | 148 | research_assistant = agent.compile(checkpointer=MemorySaver()) 149 | -------------------------------------------------------------------------------- /tests/app/test_streamlit_app.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from unittest.mock import AsyncMock, Mock 3 | 4 | import pytest 5 | from streamlit.testing.v1 import AppTest 6 | 7 | from client import AgentClientError 8 | from schema import ChatHistory, ChatMessage 9 | from schema.models import OpenAIModelName 10 | 11 | 12 | def test_app_simple_non_streaming(mock_agent_client): 13 | """Test the full app - happy path""" 14 | at = AppTest.from_file("../../src/streamlit_app.py").run() 15 | 16 | WELCOME_START = "Hello! I'm an AI-powered research assistant" 17 | PROMPT = "Know any jokes?" 18 | RESPONSE = "Sure! Here's a joke:" 19 | 20 | mock_agent_client.ainvoke = AsyncMock( 21 | return_value=ChatMessage(type="ai", content=RESPONSE), 22 | ) 23 | 24 | assert at.chat_message[0].avatar == "assistant" 25 | assert at.chat_message[0].markdown[0].value.startswith(WELCOME_START) 26 | 27 | at.sidebar.toggle[0].set_value(False) # Use Streaming = False 28 | at.chat_input[0].set_value(PROMPT).run() 29 | print(at) 30 | assert at.chat_message[0].avatar == "user" 31 | assert at.chat_message[0].markdown[0].value == PROMPT 32 | assert at.chat_message[1].avatar == "assistant" 33 | assert at.chat_message[1].markdown[0].value == RESPONSE 34 | assert not at.exception 35 | 36 | 37 | def test_app_settings(mock_agent_client): 38 | """Test the full app - happy path""" 39 | at = AppTest.from_file("../../src/streamlit_app.py").run() 40 | 41 | PROMPT = "Know any jokes?" 42 | RESPONSE = "Sure! 
Here's a joke:" 43 | 44 | mock_agent_client.ainvoke = AsyncMock( 45 | return_value=ChatMessage(type="ai", content=RESPONSE), 46 | ) 47 | 48 | at.sidebar.toggle[0].set_value(False) # Use Streaming = False 49 | assert at.sidebar.selectbox[0].value == "gpt-4o" 50 | assert mock_agent_client.agent == "test-agent" 51 | at.sidebar.selectbox[0].set_value("gpt-4o-mini") 52 | at.sidebar.selectbox[1].set_value("chatbot") 53 | at.chat_input[0].set_value(PROMPT).run() 54 | print(at) 55 | 56 | # Basic checks 57 | assert at.chat_message[0].avatar == "user" 58 | assert at.chat_message[0].markdown[0].value == PROMPT 59 | assert at.chat_message[1].avatar == "assistant" 60 | assert at.chat_message[1].markdown[0].value == RESPONSE 61 | 62 | # Check the args match the settings 63 | assert mock_agent_client.agent == "chatbot" 64 | mock_agent_client.ainvoke.assert_called_with( 65 | message=PROMPT, 66 | model=OpenAIModelName.GPT_4O_MINI, 67 | thread_id="test session id", 68 | ) 69 | assert not at.exception 70 | 71 | 72 | def test_app_thread_id_history(mock_agent_client): 73 | """Test the thread_id is generated""" 74 | 75 | at = AppTest.from_file("../../src/streamlit_app.py").run() 76 | assert at.session_state.thread_id == "test session id" 77 | 78 | # Reset and set thread_id 79 | at = AppTest.from_file("../../src/streamlit_app.py") 80 | at.query_params["thread_id"] = "1234" 81 | HISTORY = [ 82 | ChatMessage(type="human", content="What is the weather?"), 83 | ChatMessage(type="ai", content="The weather is sunny."), 84 | ] 85 | mock_agent_client.get_history.return_value = ChatHistory(messages=HISTORY) 86 | at.run() 87 | print(at) 88 | assert at.session_state.thread_id == "1234" 89 | mock_agent_client.get_history.assert_called_with(thread_id="1234") 90 | assert at.chat_message[0].avatar == "user" 91 | assert at.chat_message[0].markdown[0].value == "What is the weather?" 92 | assert at.chat_message[1].avatar == "assistant" 93 | assert at.chat_message[1].markdown[0].value == "The weather is sunny." 94 | assert not at.exception 95 | 96 | 97 | def test_app_feedback(mock_agent_client): 98 | """TODO: Can't figure out how to interact with st.feedback""" 99 | 100 | pass 101 | 102 | 103 | @pytest.mark.asyncio 104 | async def test_app_streaming(mock_agent_client): 105 | """Test the app with streaming enabled - including tool messages""" 106 | at = AppTest.from_file("../../src/streamlit_app.py").run() 107 | 108 | # Setup mock streaming response 109 | PROMPT = "What is 6 * 7?" 
110 | ai_with_tool = ChatMessage( 111 | type="ai", 112 | content="", 113 | tool_calls=[{"name": "calculator", "id": "test_call_id", "args": {"expression": "6 * 7"}}], 114 | ) 115 | tool_message = ChatMessage(type="tool", content="42", tool_call_id="test_call_id") 116 | final_ai_message = ChatMessage(type="ai", content="The answer is 42") 117 | 118 | messages = [ai_with_tool, tool_message, final_ai_message] 119 | 120 | async def amessage_iter() -> AsyncGenerator[ChatMessage, None]: 121 | for m in messages: 122 | yield m 123 | 124 | mock_agent_client.astream = Mock(return_value=amessage_iter()) 125 | 126 | at.toggle[0].set_value(True) # Use Streaming = True 127 | at.chat_input[0].set_value(PROMPT).run() 128 | print(at) 129 | 130 | assert at.chat_message[0].avatar == "user" 131 | assert at.chat_message[0].markdown[0].value == PROMPT 132 | response = at.chat_message[1] 133 | tool_status = response.status[0] 134 | assert response.avatar == "assistant" 135 | assert tool_status.label == "Tool Call: calculator" 136 | assert tool_status.icon == ":material/check:" 137 | assert tool_status.markdown[0].value == "Input:" 138 | assert tool_status.json[0].value == '{"expression": "6 * 7"}' 139 | assert tool_status.markdown[1].value == "Output:" 140 | assert tool_status.markdown[2].value == "42" 141 | assert response.markdown[-1].value == "The answer is 42" 142 | assert not at.exception 143 | 144 | 145 | @pytest.mark.asyncio 146 | async def test_app_init_error(mock_agent_client): 147 | """Test the app with an error in the agent initialization""" 148 | at = AppTest.from_file("../../src/streamlit_app.py").run() 149 | 150 | # Setup mock streaming response 151 | PROMPT = "What is 6 * 7?" 152 | mock_agent_client.astream.side_effect = AgentClientError("Error connecting to agent") 153 | 154 | at.toggle[0].set_value(True) # Use Streaming = True 155 | at.chat_input[0].set_value(PROMPT).run() 156 | print(at) 157 | 158 | assert at.chat_message[0].avatar == "assistant" 159 | assert at.chat_message[1].avatar == "user" 160 | assert at.chat_message[1].markdown[0].value == PROMPT 161 | assert at.error[0].value == "Error generating response: Error connecting to agent" 162 | assert not at.exception 163 | -------------------------------------------------------------------------------- /tests/core/test_settings.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | from pydantic import SecretStr, ValidationError 7 | 8 | from core.settings import Settings, check_str_is_http 9 | from schema.models import AnthropicModelName, AzureOpenAIModelName, OpenAIModelName 10 | 11 | 12 | def test_check_str_is_http(): 13 | # Test valid HTTP URLs 14 | assert check_str_is_http("http://example.com/") == "http://example.com/" 15 | assert check_str_is_http("https://api.test.com/") == "https://api.test.com/" 16 | 17 | # Test invalid URLs 18 | with pytest.raises(ValidationError): 19 | check_str_is_http("not_a_url") 20 | with pytest.raises(ValidationError): 21 | check_str_is_http("ftp://invalid.com") 22 | 23 | 24 | def test_settings_default_values(): 25 | settings = Settings(_env_file=None) 26 | assert settings.HOST == "0.0.0.0" 27 | assert settings.PORT == 8080 28 | assert settings.USE_AWS_BEDROCK is False 29 | assert settings.USE_FAKE_MODEL is False 30 | 31 | 32 | def test_settings_no_api_keys(): 33 | # Test that settings raises error when no API keys are provided 34 | with patch.dict(os.environ, {}, clear=True): 35 | with 
pytest.raises(ValueError, match="At least one LLM API key must be provided"): 36 | _ = Settings(_env_file=None) 37 | 38 | 39 | def test_settings_with_openai_key(): 40 | with patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}, clear=True): 41 | settings = Settings(_env_file=None) 42 | assert settings.OPENAI_API_KEY == SecretStr("test_key") 43 | assert settings.DEFAULT_MODEL == OpenAIModelName.GPT_4O_MINI 44 | assert settings.AVAILABLE_MODELS == set(OpenAIModelName) 45 | 46 | 47 | def test_settings_with_anthropic_key(): 48 | with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_key"}, clear=True): 49 | settings = Settings(_env_file=None) 50 | assert settings.ANTHROPIC_API_KEY == SecretStr("test_key") 51 | assert settings.DEFAULT_MODEL == AnthropicModelName.HAIKU_3 52 | assert settings.AVAILABLE_MODELS == set(AnthropicModelName) 53 | 54 | 55 | def test_settings_with_multiple_api_keys(): 56 | with patch.dict( 57 | os.environ, 58 | { 59 | "OPENAI_API_KEY": "test_openai_key", 60 | "ANTHROPIC_API_KEY": "test_anthropic_key", 61 | }, 62 | clear=True, 63 | ): 64 | settings = Settings(_env_file=None) 65 | assert settings.OPENAI_API_KEY == SecretStr("test_openai_key") 66 | assert settings.ANTHROPIC_API_KEY == SecretStr("test_anthropic_key") 67 | # When multiple providers are available, OpenAI should be the default 68 | assert settings.DEFAULT_MODEL == OpenAIModelName.GPT_4O_MINI 69 | # Available models should include exactly all OpenAI and Anthropic models 70 | expected_models = set(OpenAIModelName) 71 | expected_models.update(set(AnthropicModelName)) 72 | assert settings.AVAILABLE_MODELS == expected_models 73 | 74 | 75 | def test_settings_base_url(): 76 | settings = Settings(HOST="0.0.0.0", PORT=8000, _env_file=None) 77 | assert settings.BASE_URL == "http://0.0.0.0:8000" 78 | 79 | 80 | def test_settings_is_dev(): 81 | settings = Settings(MODE="dev", _env_file=None) 82 | assert settings.is_dev() is True 83 | 84 | settings = Settings(MODE="prod", _env_file=None) 85 | assert settings.is_dev() is False 86 | 87 | 88 | def test_settings_with_azure_openai_key(): 89 | with patch.dict( 90 | os.environ, 91 | { 92 | "AZURE_OPENAI_API_KEY": "test_key", 93 | "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com", 94 | "AZURE_OPENAI_DEPLOYMENT_MAP": '{"gpt-4o": "deployment-1", "gpt-4o-mini": "deployment-2"}', 95 | }, 96 | clear=True, 97 | ): 98 | settings = Settings(_env_file=None) 99 | assert settings.AZURE_OPENAI_API_KEY.get_secret_value() == "test_key" 100 | assert settings.DEFAULT_MODEL == AzureOpenAIModelName.AZURE_GPT_4O_MINI 101 | assert settings.AVAILABLE_MODELS == set(AzureOpenAIModelName) 102 | 103 | 104 | def test_settings_with_both_openai_and_azure(): 105 | with patch.dict( 106 | os.environ, 107 | { 108 | "OPENAI_API_KEY": "test_openai_key", 109 | "AZURE_OPENAI_API_KEY": "test_azure_key", 110 | "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com", 111 | "AZURE_OPENAI_DEPLOYMENT_MAP": '{"gpt-4o": "deployment-1", "gpt-4o-mini": "deployment-2"}', 112 | }, 113 | clear=True, 114 | ): 115 | settings = Settings(_env_file=None) 116 | assert settings.OPENAI_API_KEY == SecretStr("test_openai_key") 117 | assert settings.AZURE_OPENAI_API_KEY == SecretStr("test_azure_key") 118 | # When multiple providers are available, OpenAI should be the default 119 | assert settings.DEFAULT_MODEL == OpenAIModelName.GPT_4O_MINI 120 | # Available models should include both OpenAI and Azure OpenAI models 121 | expected_models = set(OpenAIModelName) 122 | expected_models.update(set(AzureOpenAIModelName)) 123 | assert 
settings.AVAILABLE_MODELS == expected_models 124 | 125 | 126 | def test_settings_azure_deployment_names(): 127 | # Delete this test 128 | pass 129 | 130 | 131 | def test_settings_azure_missing_deployment_names(): 132 | with patch.dict( 133 | os.environ, 134 | { 135 | "AZURE_OPENAI_API_KEY": "test_key", 136 | "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com", 137 | }, 138 | clear=True, 139 | ): 140 | with pytest.raises(ValidationError, match="AZURE_OPENAI_DEPLOYMENT_MAP must be set"): 141 | Settings(_env_file=None) 142 | 143 | 144 | def test_settings_azure_deployment_map(): 145 | with patch.dict( 146 | os.environ, 147 | { 148 | "AZURE_OPENAI_API_KEY": "test_key", 149 | "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com", 150 | "AZURE_OPENAI_DEPLOYMENT_MAP": '{"gpt-4o": "deploy1", "gpt-4o-mini": "deploy2"}', 151 | }, 152 | clear=True, 153 | ): 154 | settings = Settings(_env_file=None) 155 | assert settings.AZURE_OPENAI_DEPLOYMENT_MAP == { 156 | "gpt-4o": "deploy1", 157 | "gpt-4o-mini": "deploy2", 158 | } 159 | 160 | 161 | def test_settings_azure_invalid_deployment_map(): 162 | with patch.dict( 163 | os.environ, 164 | { 165 | "AZURE_OPENAI_API_KEY": "test_key", 166 | "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com", 167 | "AZURE_OPENAI_DEPLOYMENT_MAP": '{"gpt-4o": "deploy1"}', # Missing required model 168 | }, 169 | clear=True, 170 | ): 171 | with pytest.raises(ValueError, match="Missing required Azure deployments"): 172 | Settings(_env_file=None) 173 | 174 | 175 | def test_settings_azure_openai(): 176 | """Test Azure OpenAI settings.""" 177 | deployment_map = {"gpt-4o": "deployment1", "gpt-4o-mini": "deployment2"} 178 | with patch.dict( 179 | os.environ, 180 | { 181 | "AZURE_OPENAI_API_KEY": "test-key", 182 | "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com", 183 | "AZURE_OPENAI_DEPLOYMENT_MAP": json.dumps(deployment_map), 184 | }, 185 | ): 186 | settings = Settings(_env_file=None) 187 | assert settings.AZURE_OPENAI_API_KEY.get_secret_value() == "test-key" 188 | assert settings.AZURE_OPENAI_ENDPOINT == "https://test.openai.azure.com" 189 | assert settings.AZURE_OPENAI_DEPLOYMENT_MAP == deployment_map 190 | -------------------------------------------------------------------------------- /src/core/settings.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | from json import loads 3 | from typing import Annotated, Any 4 | 5 | from dotenv import find_dotenv 6 | from pydantic import ( 7 | BeforeValidator, 8 | Field, 9 | HttpUrl, 10 | SecretStr, 11 | TypeAdapter, 12 | computed_field, 13 | ) 14 | from pydantic_settings import BaseSettings, SettingsConfigDict 15 | 16 | from schema.models import ( 17 | AllModelEnum, 18 | AnthropicModelName, 19 | AWSModelName, 20 | AzureOpenAIModelName, 21 | DeepseekModelName, 22 | FakeModelName, 23 | GoogleModelName, 24 | GroqModelName, 25 | OllamaModelName, 26 | OpenAIModelName, 27 | Provider, 28 | ) 29 | 30 | 31 | class DatabaseType(StrEnum): 32 | SQLITE = "sqlite" 33 | POSTGRES = "postgres" 34 | 35 | 36 | def check_str_is_http(x: str) -> str: 37 | http_url_adapter = TypeAdapter(HttpUrl) 38 | return str(http_url_adapter.validate_python(x)) 39 | 40 | 41 | class Settings(BaseSettings): 42 | model_config = SettingsConfigDict( 43 | env_file=find_dotenv(), 44 | env_file_encoding="utf-8", 45 | env_ignore_empty=True, 46 | extra="ignore", 47 | validate_default=False, 48 | ) 49 | MODE: str | None = None 50 | 51 | HOST: str = "0.0.0.0" 52 | PORT: int = 8080 53 | 54 | AUTH_SECRET: 
SecretStr | None = None 55 | 56 | OPENAI_API_KEY: SecretStr | None = None 57 | DEEPSEEK_API_KEY: SecretStr | None = None 58 | ANTHROPIC_API_KEY: SecretStr | None = None 59 | GOOGLE_API_KEY: SecretStr | None = None 60 | GROQ_API_KEY: SecretStr | None = None 61 | USE_AWS_BEDROCK: bool = False 62 | OLLAMA_MODEL: str | None = None 63 | OLLAMA_BASE_URL: str | None = None 64 | USE_FAKE_MODEL: bool = False 65 | 66 | # If DEFAULT_MODEL is None, it will be set in model_post_init 67 | DEFAULT_MODEL: AllModelEnum | None = None # type: ignore[assignment] 68 | AVAILABLE_MODELS: set[AllModelEnum] = set() # type: ignore[assignment] 69 | 70 | OPENWEATHERMAP_API_KEY: SecretStr | None = None 71 | 72 | LANGCHAIN_TRACING_V2: bool = False 73 | LANGCHAIN_PROJECT: str = "default" 74 | LANGCHAIN_ENDPOINT: Annotated[str, BeforeValidator(check_str_is_http)] = ( 75 | "https://api.smith.langchain.com" 76 | ) 77 | LANGCHAIN_API_KEY: SecretStr | None = None 78 | 79 | # Database Configuration 80 | DATABASE_TYPE: DatabaseType = ( 81 | DatabaseType.SQLITE 82 | ) # Options: DatabaseType.SQLITE or DatabaseType.POSTGRES 83 | SQLITE_DB_PATH: str = "checkpoints.db" 84 | 85 | # PostgreSQL Configuration 86 | POSTGRES_USER: str | None = None 87 | POSTGRES_PASSWORD: SecretStr | None = None 88 | POSTGRES_HOST: str | None = None 89 | POSTGRES_PORT: int | None = None 90 | POSTGRES_DB: str | None = None 91 | POSTGRES_POOL_SIZE: int = Field( 92 | default=10, description="Maximum number of connections in the pool" 93 | ) 94 | POSTGRES_MIN_SIZE: int = Field( 95 | default=3, description="Minimum number of connections in the pool" 96 | ) 97 | POSTGRES_MAX_IDLE: int = Field(default=5, description="Maximum number of idle connections") 98 | 99 | # Azure OpenAI Settings 100 | AZURE_OPENAI_API_KEY: SecretStr | None = None 101 | AZURE_OPENAI_ENDPOINT: str | None = None 102 | AZURE_OPENAI_API_VERSION: str = "2024-02-15-preview" 103 | AZURE_OPENAI_DEPLOYMENT_MAP: dict[str, str] = Field( 104 | default_factory=dict, description="Map of model names to Azure deployment IDs" 105 | ) 106 | 107 | def model_post_init(self, __context: Any) -> None: 108 | api_keys = { 109 | Provider.OPENAI: self.OPENAI_API_KEY, 110 | Provider.DEEPSEEK: self.DEEPSEEK_API_KEY, 111 | Provider.ANTHROPIC: self.ANTHROPIC_API_KEY, 112 | Provider.GOOGLE: self.GOOGLE_API_KEY, 113 | Provider.GROQ: self.GROQ_API_KEY, 114 | Provider.AWS: self.USE_AWS_BEDROCK, 115 | Provider.OLLAMA: self.OLLAMA_MODEL, 116 | Provider.FAKE: self.USE_FAKE_MODEL, 117 | Provider.AZURE_OPENAI: self.AZURE_OPENAI_API_KEY, 118 | } 119 | active_keys = [k for k, v in api_keys.items() if v] 120 | if not active_keys: 121 | raise ValueError("At least one LLM API key must be provided.") 122 | 123 | for provider in active_keys: 124 | match provider: 125 | case Provider.OPENAI: 126 | if self.DEFAULT_MODEL is None: 127 | self.DEFAULT_MODEL = OpenAIModelName.GPT_4O_MINI 128 | self.AVAILABLE_MODELS.update(set(OpenAIModelName)) 129 | case Provider.DEEPSEEK: 130 | if self.DEFAULT_MODEL is None: 131 | self.DEFAULT_MODEL = DeepseekModelName.DEEPSEEK_CHAT 132 | self.AVAILABLE_MODELS.update(set(DeepseekModelName)) 133 | case Provider.ANTHROPIC: 134 | if self.DEFAULT_MODEL is None: 135 | self.DEFAULT_MODEL = AnthropicModelName.HAIKU_3 136 | self.AVAILABLE_MODELS.update(set(AnthropicModelName)) 137 | case Provider.GOOGLE: 138 | if self.DEFAULT_MODEL is None: 139 | self.DEFAULT_MODEL = GoogleModelName.GEMINI_15_FLASH 140 | self.AVAILABLE_MODELS.update(set(GoogleModelName)) 141 | case Provider.GROQ: 142 | if self.DEFAULT_MODEL is 
None: 143 | self.DEFAULT_MODEL = GroqModelName.LLAMA_31_8B 144 | self.AVAILABLE_MODELS.update(set(GroqModelName)) 145 | case Provider.AWS: 146 | if self.DEFAULT_MODEL is None: 147 | self.DEFAULT_MODEL = AWSModelName.BEDROCK_HAIKU 148 | self.AVAILABLE_MODELS.update(set(AWSModelName)) 149 | case Provider.OLLAMA: 150 | if self.DEFAULT_MODEL is None: 151 | self.DEFAULT_MODEL = OllamaModelName.OLLAMA_GENERIC 152 | self.AVAILABLE_MODELS.update(set(OllamaModelName)) 153 | case Provider.FAKE: 154 | if self.DEFAULT_MODEL is None: 155 | self.DEFAULT_MODEL = FakeModelName.FAKE 156 | self.AVAILABLE_MODELS.update(set(FakeModelName)) 157 | case Provider.AZURE_OPENAI: 158 | if self.DEFAULT_MODEL is None: 159 | self.DEFAULT_MODEL = AzureOpenAIModelName.AZURE_GPT_4O_MINI 160 | self.AVAILABLE_MODELS.update(set(AzureOpenAIModelName)) 161 | # Validate Azure OpenAI settings if Azure provider is available 162 | if not self.AZURE_OPENAI_API_KEY: 163 | raise ValueError("AZURE_OPENAI_API_KEY must be set") 164 | if not self.AZURE_OPENAI_ENDPOINT: 165 | raise ValueError("AZURE_OPENAI_ENDPOINT must be set") 166 | if not self.AZURE_OPENAI_DEPLOYMENT_MAP: 167 | raise ValueError("AZURE_OPENAI_DEPLOYMENT_MAP must be set") 168 | 169 | # Parse deployment map if it's a string 170 | if isinstance(self.AZURE_OPENAI_DEPLOYMENT_MAP, str): 171 | try: 172 | self.AZURE_OPENAI_DEPLOYMENT_MAP = loads( 173 | self.AZURE_OPENAI_DEPLOYMENT_MAP 174 | ) 175 | except Exception as e: 176 | raise ValueError(f"Invalid AZURE_OPENAI_DEPLOYMENT_MAP JSON: {e}") 177 | 178 | # Validate required deployments exist 179 | required_models = {"gpt-4o", "gpt-4o-mini"} 180 | missing_models = required_models - set(self.AZURE_OPENAI_DEPLOYMENT_MAP.keys()) 181 | if missing_models: 182 | raise ValueError(f"Missing required Azure deployments: {missing_models}") 183 | case _: 184 | raise ValueError(f"Unknown provider: {provider}") 185 | 186 | @computed_field 187 | @property 188 | def BASE_URL(self) -> str: 189 | return f"http://{self.HOST}:{self.PORT}" 190 | 191 | def is_dev(self) -> bool: 192 | return self.MODE == "dev" 193 | 194 | 195 | settings = Settings() 196 | -------------------------------------------------------------------------------- /src/service/service.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import warnings 4 | from collections.abc import AsyncGenerator 5 | from contextlib import asynccontextmanager 6 | from typing import Annotated, Any 7 | from uuid import UUID, uuid4 8 | 9 | from fastapi import APIRouter, Depends, FastAPI, HTTPException, status 10 | from fastapi.responses import StreamingResponse 11 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 12 | from langchain_core._api import LangChainBetaWarning 13 | from langchain_core.messages import AnyMessage, HumanMessage 14 | from langchain_core.runnables import RunnableConfig 15 | from langgraph.graph.state import CompiledStateGraph 16 | from langgraph.types import Command 17 | from langsmith import Client as LangsmithClient 18 | 19 | from agents import DEFAULT_AGENT, get_agent, get_all_agent_info 20 | from core import settings 21 | from memory import initialize_database 22 | from schema import ( 23 | ChatHistory, 24 | ChatHistoryInput, 25 | ChatMessage, 26 | Feedback, 27 | FeedbackResponse, 28 | ServiceMetadata, 29 | StreamInput, 30 | UserInput, 31 | ) 32 | from service.utils import ( 33 | convert_message_content_to_string, 34 | langchain_to_chat_message, 35 | remove_tool_calls, 36 | ) 37 | 
38 | warnings.filterwarnings("ignore", category=LangChainBetaWarning) 39 | logger = logging.getLogger(__name__) 40 | 41 | 42 | def verify_bearer( 43 | http_auth: Annotated[ 44 | HTTPAuthorizationCredentials | None, 45 | Depends(HTTPBearer(description="Please provide AUTH_SECRET api key.", auto_error=False)), 46 | ], 47 | ) -> None: 48 | if not settings.AUTH_SECRET: 49 | return 50 | auth_secret = settings.AUTH_SECRET.get_secret_value() 51 | if not http_auth or http_auth.credentials != auth_secret: 52 | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) 53 | 54 | 55 | @asynccontextmanager 56 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 57 | """ 58 | Configurable lifespan that initializes the appropriate database checkpointer based on settings. 59 | """ 60 | try: 61 | async with initialize_database() as saver: 62 | await saver.setup() 63 | agents = get_all_agent_info() 64 | for a in agents: 65 | agent = get_agent(a.key) 66 | agent.checkpointer = saver 67 | yield 68 | except Exception as e: 69 | logger.error(f"Error during database initialization: {e}") 70 | raise 71 | 72 | 73 | app = FastAPI(lifespan=lifespan) 74 | router = APIRouter(dependencies=[Depends(verify_bearer)]) 75 | 76 | 77 | @router.get("/info") 78 | async def info() -> ServiceMetadata: 79 | models = list(settings.AVAILABLE_MODELS) 80 | models.sort() 81 | return ServiceMetadata( 82 | agents=get_all_agent_info(), 83 | models=models, 84 | default_agent=DEFAULT_AGENT, 85 | default_model=settings.DEFAULT_MODEL, 86 | ) 87 | 88 | 89 | def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], UUID]: 90 | run_id = uuid4() 91 | thread_id = user_input.thread_id or str(uuid4()) 92 | 93 | configurable = {"thread_id": thread_id, "model": user_input.model} 94 | 95 | if user_input.agent_config: 96 | if overlap := configurable.keys() & user_input.agent_config.keys(): 97 | raise HTTPException( 98 | status_code=422, 99 | detail=f"agent_config contains reserved keys: {overlap}", 100 | ) 101 | configurable.update(user_input.agent_config) 102 | 103 | kwargs = { 104 | "input": {"messages": [HumanMessage(content=user_input.message)]}, 105 | "config": RunnableConfig( 106 | configurable=configurable, 107 | run_id=run_id, 108 | ), 109 | } 110 | return kwargs, run_id 111 | 112 | 113 | @router.post("/{agent_id}/invoke") 114 | @router.post("/invoke") 115 | async def invoke(user_input: UserInput, agent_id: str = DEFAULT_AGENT) -> ChatMessage: 116 | """ 117 | Invoke an agent with user input to retrieve a final response. 118 | 119 | If agent_id is not provided, the default agent will be used. 120 | Use thread_id to persist and continue a multi-turn conversation. run_id kwarg 121 | is also attached to messages for recording feedback. 122 | """ 123 | agent: CompiledStateGraph = get_agent(agent_id) 124 | kwargs, run_id = _parse_input(user_input) 125 | try: 126 | response = await agent.ainvoke(**kwargs) 127 | output = langchain_to_chat_message(response["messages"][-1]) 128 | output.run_id = str(run_id) 129 | return output 130 | except Exception as e: 131 | logger.error(f"An exception occurred: {e}") 132 | raise HTTPException(status_code=500, detail="Unexpected error") 133 | 134 | 135 | async def message_generator( 136 | user_input: StreamInput, agent_id: str = DEFAULT_AGENT 137 | ) -> AsyncGenerator[str, None]: 138 | """ 139 | Generate a stream of messages from the agent. 140 | 141 | This is the workhorse method for the /stream endpoint. 
142 | """ 143 | agent: CompiledStateGraph = get_agent(agent_id) 144 | kwargs, run_id = _parse_input(user_input) 145 | 146 | # Process streamed events from the graph and yield messages over the SSE stream. 147 | async for event in agent.astream_events(**kwargs, version="v2"): 148 | if not event: 149 | continue 150 | 151 | new_messages = [] 152 | # Yield messages written to the graph state after node execution finishes. 153 | if ( 154 | event["event"] == "on_chain_end" 155 | # on_chain_end gets called a bunch of times in a graph execution 156 | # This filters out everything except for "graph node finished" 157 | and any(t.startswith("graph:step:") for t in event.get("tags", [])) 158 | ): 159 | if isinstance(event["data"]["output"], Command): 160 | new_messages = event["data"]["output"].update.get("messages", []) 161 | elif "messages" in event["data"]["output"]: 162 | new_messages = event["data"]["output"]["messages"] 163 | 164 | # Also yield intermediate messages from agents.utils.CustomData.adispatch(). 165 | if event["event"] == "on_custom_event" and "custom_data_dispatch" in event.get("tags", []): 166 | new_messages = [event["data"]] 167 | 168 | for message in new_messages: 169 | try: 170 | chat_message = langchain_to_chat_message(message) 171 | chat_message.run_id = str(run_id) 172 | except Exception as e: 173 | logger.error(f"Error parsing message: {e}") 174 | yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n" 175 | continue 176 | # LangGraph re-sends the input message, which feels weird, so drop it 177 | if chat_message.type == "human" and chat_message.content == user_input.message: 178 | continue 179 | yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n" 180 | 181 | # Yield tokens streamed from LLMs. 182 | if ( 183 | event["event"] == "on_chat_model_stream" 184 | and user_input.stream_tokens 185 | and "llama_guard" not in event.get("tags", []) 186 | ): 187 | content = remove_tool_calls(event["data"]["chunk"].content) 188 | if content: 189 | # Empty content in the context of OpenAI usually means 190 | # that the model is asking for a tool to be invoked. 191 | # So we only print non-empty content. 192 | yield f"data: {json.dumps({'type': 'token', 'content': convert_message_content_to_string(content)})}\n\n" 193 | continue 194 | 195 | yield "data: [DONE]\n\n" 196 | 197 | 198 | def _sse_response_example() -> dict[int, Any]: 199 | return { 200 | status.HTTP_200_OK: { 201 | "description": "Server Sent Event Response", 202 | "content": { 203 | "text/event-stream": { 204 | "example": "data: {'type': 'token', 'content': 'Hello'}\n\ndata: {'type': 'token', 'content': ' World'}\n\ndata: [DONE]\n\n", 205 | "schema": {"type": "string"}, 206 | } 207 | }, 208 | } 209 | } 210 | 211 | 212 | @router.post( 213 | "/{agent_id}/stream", 214 | response_class=StreamingResponse, 215 | responses=_sse_response_example(), 216 | ) 217 | @router.post("/stream", response_class=StreamingResponse, responses=_sse_response_example()) 218 | async def stream(user_input: StreamInput, agent_id: str = DEFAULT_AGENT) -> StreamingResponse: 219 | """ 220 | Stream an agent's response to a user input, including intermediate messages and tokens. 221 | 222 | If agent_id is not provided, the default agent will be used. 223 | Use thread_id to persist and continue a multi-turn conversation. run_id kwarg 224 | is also attached to all messages for recording feedback. 225 | 226 | Set `stream_tokens=false` to return intermediate messages but not token-by-token. 
227 | """ 228 | return StreamingResponse( 229 | message_generator(user_input, agent_id), 230 | media_type="text/event-stream", 231 | ) 232 | 233 | 234 | @router.post("/feedback") 235 | async def feedback(feedback: Feedback) -> FeedbackResponse: 236 | """ 237 | Record feedback for a run to LangSmith. 238 | 239 | This is a simple wrapper for the LangSmith create_feedback API, so the 240 | credentials can be stored and managed in the service rather than the client. 241 | See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post 242 | """ 243 | client = LangsmithClient() 244 | kwargs = feedback.kwargs or {} 245 | client.create_feedback( 246 | run_id=feedback.run_id, 247 | key=feedback.key, 248 | score=feedback.score, 249 | **kwargs, 250 | ) 251 | return FeedbackResponse() 252 | 253 | 254 | @router.post("/history") 255 | def history(input: ChatHistoryInput) -> ChatHistory: 256 | """ 257 | Get chat history. 258 | """ 259 | # TODO: Hard-coding DEFAULT_AGENT here is wonky 260 | agent: CompiledStateGraph = get_agent(DEFAULT_AGENT) 261 | try: 262 | state_snapshot = agent.get_state( 263 | config=RunnableConfig( 264 | configurable={ 265 | "thread_id": input.thread_id, 266 | } 267 | ) 268 | ) 269 | messages: list[AnyMessage] = state_snapshot.values["messages"] 270 | chat_messages: list[ChatMessage] = [langchain_to_chat_message(m) for m in messages] 271 | return ChatHistory(messages=chat_messages) 272 | except Exception as e: 273 | logger.error(f"An exception occurred: {e}") 274 | raise HTTPException(status_code=500, detail="Unexpected error") 275 | 276 | 277 | @app.get("/health") 278 | async def health_check(): 279 | """Health check endpoint.""" 280 | return {"status": "ok"} 281 | 282 | 283 | app.include_router(router) 284 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🧰 AI Agent Service Toolkit 2 | 3 | [![build status](https://github.com/JoshuaC215/agent-service-toolkit/actions/workflows/test.yml/badge.svg)](https://github.com/JoshuaC215/agent-service-toolkit/actions/workflows/test.yml) [![codecov](https://codecov.io/github/JoshuaC215/agent-service-toolkit/graph/badge.svg?token=5MTJSYWD05)](https://codecov.io/github/JoshuaC215/agent-service-toolkit) [![Python Version](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2FJoshuaC215%2Fagent-service-toolkit%2Frefs%2Fheads%2Fmain%2Fpyproject.toml)](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/pyproject.toml) 4 | [![GitHub License](https://img.shields.io/github/license/JoshuaC215/agent-service-toolkit)](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/LICENSE) [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_red.svg)](https://agent-service-toolkit.streamlit.app/) 5 | 6 | A full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit. 7 | 8 | It includes a [LangGraph](https://langchain-ai.github.io/langgraph/) agent, a [FastAPI](https://fastapi.tiangolo.com/) service to serve it, a client to interact with the service, and a [Streamlit](https://streamlit.io/) app that uses the client to provide a chat interface. Data structures and settings are built with [Pydantic](https://github.com/pydantic/pydantic). 9 | 10 | This project offers a template for you to easily build and run your own agents using the LangGraph framework. 
It demonstrates a complete setup from agent definition to user interface, making it easier to get started with LangGraph-based projects by providing a full, robust toolkit. 11 | 12 | **[🎥 Watch a video walkthrough of the repo and app](https://www.youtube.com/watch?v=pdYVHw_YCNY)** 13 | 14 | ## Overview 15 | 16 | ### [Try the app!](https://agent-service-toolkit.streamlit.app/) 17 | 18 | 19 | 20 | ### Quickstart 21 | 22 | Run directly in python 23 | 24 | ```sh 25 | # At least one LLM API key is required 26 | echo 'OPENAI_API_KEY=your_openai_api_key' >> .env 27 | 28 | # uv is recommended but "pip install ." also works 29 | pip install uv 30 | uv sync --frozen 31 | # "uv sync" creates .venv automatically 32 | source .venv/bin/activate 33 | python src/run_service.py 34 | 35 | # In another shell 36 | source .venv/bin/activate 37 | streamlit run src/streamlit_app.py 38 | ``` 39 | 40 | Run with docker 41 | 42 | ```sh 43 | echo 'OPENAI_API_KEY=your_openai_api_key' >> .env 44 | docker compose watch 45 | ``` 46 | 47 | ### Architecture Diagram 48 | 49 | 50 | 51 | ### Key Features 52 | 53 | 1. **LangGraph Agent**: A customizable agent built using the LangGraph framework. 54 | 1. **FastAPI Service**: Serves the agent with both streaming and non-streaming endpoints. 55 | 1. **Advanced Streaming**: A novel approach to support both token-based and message-based streaming. 56 | 1. **Content Moderation**: Implements LlamaGuard for content moderation (requires Groq API key). 57 | 1. **Streamlit Interface**: Provides a user-friendly chat interface for interacting with the agent. 58 | 1. **Multiple Agent Support**: Run multiple agents in the service and call by URL path 59 | 1. **Asynchronous Design**: Utilizes async/await for efficient handling of concurrent requests. 60 | 1. **Feedback Mechanism**: Includes a star-based feedback system integrated with LangSmith. 61 | 1. **Dynamic Metadata**: `/info` endpoint provides dynamically configured metadata about the service and available agents and models. 62 | 1. **Docker Support**: Includes Dockerfiles and a docker compose file for easy development and deployment. 63 | 1. **Testing**: Includes robust unit and integration tests for the full repo. 64 | 65 | ### Key Files 66 | 67 | The repository is structured as follows: 68 | 69 | - `src/agents/`: Defines several agents with different capabilities 70 | - `src/schema/`: Defines the protocol schema 71 | - `src/core/`: Core modules including LLM definition and settings 72 | - `src/service/service.py`: FastAPI service to serve the agents 73 | - `src/client/client.py`: Client to interact with the agent service 74 | - `src/streamlit_app.py`: Streamlit app providing a chat interface 75 | - `tests/`: Unit and integration tests 76 | 77 | ## Setup and Usage 78 | 79 | 1. Clone the repository: 80 | 81 | ```sh 82 | git clone https://github.com/JoshuaC215/agent-service-toolkit.git 83 | cd agent-service-toolkit 84 | ``` 85 | 86 | 2. Set up environment variables: 87 | Create a `.env` file in the root directory. At least one LLM API key or configuration is required. See the [`.env.example` file](./.env.example) for a full list of available environment variables, including a variety of model provider API keys, header-based authentication, LangSmith tracing, testing and development modes, and OpenWeatherMap API key. 88 | 89 | 3. You can now run the agent service and the Streamlit app locally, either with Docker or just using Python. 
The Docker setup is recommended for simpler environment setup and immediate reloading of the services when you make changes to your code. 90 | 91 | ### Building or customizing your own agent 92 | 93 | To customize the agent for your own use case: 94 | 95 | 1. Add your new agent to the `src/agents` directory. You can copy `research_assistant.py` or `chatbot.py` and modify it to change the agent's behavior and tools. 96 | 1. Import and add your new agent to the `agents` dictionary in `src/agents/agents.py`. Your agent can be called by `/<agent_id>/invoke` or `/<agent_id>/stream`. 97 | 1. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities. 98 | 99 | ### Docker Setup 100 | 101 | This project includes a Docker setup for easy development and deployment. The `compose.yaml` file defines two services: `agent_service` and `streamlit_app`. The Dockerfiles for both are in the `docker/` directory. 102 | 103 | For local development, we recommend using [docker compose watch](https://docs.docker.com/compose/file-watch/). This feature allows for a smoother development experience by automatically updating your containers when changes are detected in your source code. 104 | 105 | 1. Make sure you have Docker and Docker Compose (>=[2.23.0](https://docs.docker.com/compose/release-notes/#2230)) installed on your system. 106 | 107 | 2. Build and launch the services in watch mode: 108 | 109 | ```sh 110 | docker compose watch 111 | ``` 112 | 113 | 3. The services will now automatically update when you make changes to your code: 114 | - Changes in the relevant Python files and directories will trigger updates for the relevant services. 115 | - NOTE: If you make changes to the `pyproject.toml` or `uv.lock` files, you will need to rebuild the services by running `docker compose up --build`. 116 | 117 | 4. Access the Streamlit app by navigating to `http://localhost:8501` in your web browser. 118 | 119 | 5. The agent service API will be available at `http://0.0.0.0:8080`. You can also use the OpenAPI docs at `http://0.0.0.0:8080/redoc`. 120 | 121 | 6. Use `docker compose down` to stop the services. 122 | 123 | This setup allows you to develop and test your changes in real time without manually restarting the services. 124 | 125 | ### Building other apps on the AgentClient 126 | 127 | The repo includes a generic `src/client/client.AgentClient` that can be used to interact with the agent service. This client is designed to be flexible and can be used to build other apps on top of the agent. It supports both synchronous and asynchronous invocations, and streaming and non-streaming requests. 128 | 129 | See the `src/run_client.py` file for full examples of how to use the `AgentClient`. A quick example: 130 | 131 | ```python 132 | from client import AgentClient 133 | client = AgentClient() 134 | 135 | response = client.invoke("Tell me a brief joke?") 136 | response.pretty_print() 137 | # ================================== Ai Message ================================== 138 | # 139 | # A man walked into a library and asked the librarian, "Do you have any books on Pavlov's dogs and Schrödinger's cat?" 140 | # The librarian replied, "It rings a bell, but I'm not sure if it's here or not." 141 | 142 | ``` 143 | 144 | ### Development with LangGraph Studio 145 | 146 | The agent supports [LangGraph Studio](https://github.com/langchain-ai/langgraph-studio), a new IDE for developing agents in LangGraph.
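In addition to the desktop app described below, the LangGraph CLI can provide a similar local Studio experience. This is a hedged sketch: the package name and command come from the `langgraph-cli` documentation rather than this repo, so check that project's docs before relying on it.

```sh
# Assumed commands from the langgraph-cli docs (not part of this repo's setup)
pip install -U "langgraph-cli[inmem]"
langgraph dev  # serves the graphs declared in langgraph.json
```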
147 | 148 | You can simply install LangGraph Studio, add your `.env` file to the root directory as described above, and then launch LangGraph studio pointed at the root directory. Customize `langgraph.json` as needed. 149 | 150 | ### Using Ollama 151 | 152 | ⚠️ _**Note:** Ollama support in agent-service-toolkit is experimental and may not work as expected. The instructions below have been tested using Docker Desktop on a MacBook Pro. Please file an issue for any challenges you encounter._ 153 | 154 | You can also use [Ollama](https://ollama.com) to run the LLM powering the agent service. 155 | 156 | 1. Install Ollama using instructions from https://github.com/ollama/ollama 157 | 1. Install any model you want to use, e.g. `ollama pull llama3.2` and set the `OLLAMA_MODEL` environment variable to the model you want to use, e.g. `OLLAMA_MODEL=llama3.2` 158 | 159 | If you are running the service locally (e.g. `python src/run_service.py`), you should be all set! 160 | 161 | If you are running the service in Docker, you will also need to: 162 | 163 | 1. [Configure the Ollama server as described here](https://github.com/ollama/ollama/blob/main/docs/faq.md#how-do-i-configure-ollama-server), e.g. by running `launchctl setenv OLLAMA_HOST "0.0.0.0"` on MacOS and restart Ollama. 164 | 1. Set the `OLLAMA_BASE_URL` environment variable to the base URL of the Ollama server, e.g. `OLLAMA_BASE_URL=http://host.docker.internal:11434` 165 | 1. Alternatively, you can run `ollama/ollama` image in Docker and use a similar configuration (however it may be slower in some cases). 166 | 167 | ### Local development without Docker 168 | 169 | You can also run the agent service and the Streamlit app locally without Docker, just using a Python virtual environment. 170 | 171 | 1. Create a virtual environment and install dependencies: 172 | 173 | ```sh 174 | pip install uv 175 | uv sync --frozen 176 | source .venv/bin/activate 177 | ``` 178 | 179 | 2. Run the FastAPI server: 180 | 181 | ```sh 182 | python src/run_service.py 183 | ``` 184 | 185 | 3. In a separate terminal, run the Streamlit app: 186 | 187 | ```sh 188 | streamlit run src/streamlit_app.py 189 | ``` 190 | 191 | 4. Open your browser and navigate to the URL provided by Streamlit (usually `http://localhost:8501`). 192 | 193 | ## Projects built with or inspired by agent-service-toolkit 194 | 195 | The following are a few of the public projects that drew code or inspiration from this repo. 196 | 197 | - **[alexrisch/agent-web-kit](https://github.com/alexrisch/agent-web-kit)** A Next.JS frontend for agent-service-toolkit 198 | - **[raushan-in/dapa](https://github.com/raushan-in/dapa)** - Digital Arrest Protection App (DAPA) enables users to report financial scams and frauds efficiently via a user-friendly platform. 199 | 200 | **Please create a pull request editing the README or open a discussion with any new ones to be added!** Would love to include more projects. 201 | 202 | ## Contributing 203 | 204 | Contributions are welcome! Please feel free to submit a Pull Request. Currently the tests need to be run using the local development without Docker setup. To run the tests for the agent service: 205 | 206 | 1. Ensure you're in the project root directory and have activated your virtual environment. 207 | 208 | 2. Install the development dependencies and pre-commit hooks: 209 | 210 | ```sh 211 | pip install uv 212 | uv sync --frozen 213 | pre-commit install 214 | ``` 215 | 216 | 3. 
Run the tests using pytest: 217 | 218 | ```sh 219 | pytest 220 | ``` 221 | 222 | ## License 223 | 224 | This project is licensed under the MIT License - see the LICENSE file for details. 225 | -------------------------------------------------------------------------------- /tests/client/test_client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from unittest.mock import AsyncMock, Mock, patch 4 | 5 | import pytest 6 | from httpx import Request, Response 7 | 8 | from client import AgentClient, AgentClientError 9 | from schema import AgentInfo, ChatHistory, ChatMessage, ServiceMetadata 10 | from schema.models import OpenAIModelName 11 | 12 | 13 | def test_init(mock_env): 14 | """Test client initialization with different parameters.""" 15 | # Test default values 16 | client = AgentClient(get_info=False) 17 | assert client.base_url == "http://0.0.0.0" 18 | assert client.timeout is None 19 | 20 | # Test custom values 21 | client = AgentClient( 22 | base_url="http://test", 23 | timeout=30.0, 24 | get_info=False, 25 | ) 26 | assert client.base_url == "http://test" 27 | assert client.timeout == 30.0 28 | client.update_agent("test-agent", verify=False) 29 | assert client.agent == "test-agent" 30 | 31 | 32 | def test_headers(mock_env): 33 | """Test header generation with and without auth.""" 34 | # Test without auth 35 | client = AgentClient(get_info=False) 36 | assert client._headers == {} 37 | 38 | # Test with auth 39 | with patch.dict(os.environ, {"AUTH_SECRET": "test-secret"}, clear=True): 40 | client = AgentClient(get_info=False) 41 | assert client._headers == {"Authorization": "Bearer test-secret"} 42 | 43 | 44 | def test_invoke(agent_client): 45 | """Test synchronous invocation.""" 46 | QUESTION = "What is the weather?" 47 | ANSWER = "The weather is sunny." 48 | 49 | # Mock successful response 50 | mock_request = Request("POST", "http://test/invoke") 51 | mock_response = Response( 52 | 200, 53 | json={"type": "ai", "content": ANSWER}, 54 | request=mock_request, 55 | ) 56 | with patch("httpx.post", return_value=mock_response): 57 | response = agent_client.invoke(QUESTION) 58 | assert isinstance(response, ChatMessage) 59 | assert response.type == "ai" 60 | assert response.content == ANSWER 61 | 62 | # Test with model and thread_id 63 | with patch("httpx.post", return_value=mock_response) as mock_post: 64 | response = agent_client.invoke( 65 | QUESTION, 66 | model="gpt-4o", 67 | thread_id="test-thread", 68 | ) 69 | assert isinstance(response, ChatMessage) 70 | # Verify request 71 | args, kwargs = mock_post.call_args 72 | assert kwargs["json"]["message"] == QUESTION 73 | assert kwargs["json"]["model"] == "gpt-4o" 74 | assert kwargs["json"]["thread_id"] == "test-thread" 75 | 76 | # Test error response 77 | error_response = Response(500, text="Internal Server Error", request=mock_request) 78 | with patch("httpx.post", return_value=error_response): 79 | with pytest.raises(AgentClientError) as exc: 80 | agent_client.invoke(QUESTION) 81 | assert "500 Internal Server Error" in str(exc.value) 82 | 83 | 84 | @pytest.mark.asyncio 85 | async def test_ainvoke(agent_client): 86 | """Test asynchronous invocation.""" 87 | QUESTION = "What is the weather?" 88 | ANSWER = "The weather is sunny." 
89 | 90 | # Test successful response 91 | mock_request = Request("POST", "http://test/invoke") 92 | mock_response = Response(200, json={"type": "ai", "content": ANSWER}, request=mock_request) 93 | with patch("httpx.AsyncClient.post", return_value=mock_response): 94 | response = await agent_client.ainvoke(QUESTION) 95 | assert isinstance(response, ChatMessage) 96 | assert response.type == "ai" 97 | assert response.content == ANSWER 98 | 99 | # Test with model and thread_id 100 | with patch("httpx.AsyncClient.post", return_value=mock_response) as mock_post: 101 | response = await agent_client.ainvoke( 102 | QUESTION, 103 | model="gpt-4o", 104 | thread_id="test-thread", 105 | ) 106 | assert isinstance(response, ChatMessage) 107 | assert response.type == "ai" 108 | assert response.content == ANSWER 109 | # Verify request 110 | args, kwargs = mock_post.call_args 111 | assert kwargs["json"]["message"] == QUESTION 112 | assert kwargs["json"]["model"] == "gpt-4o" 113 | assert kwargs["json"]["thread_id"] == "test-thread" 114 | 115 | # Test error response 116 | error_response = Response(500, text="Internal Server Error", request=mock_request) 117 | with patch("httpx.AsyncClient.post", return_value=error_response): 118 | with pytest.raises(AgentClientError) as exc: 119 | await agent_client.ainvoke(QUESTION) 120 | assert "500 Internal Server Error" in str(exc.value) 121 | 122 | 123 | def test_stream(agent_client): 124 | """Test synchronous streaming.""" 125 | QUESTION = "What is the weather?" 126 | TOKENS = ["The", " weather", " is", " sunny", "."] 127 | FINAL_ANSWER = "The weather is sunny." 128 | 129 | # Create mock response with streaming events 130 | events = ( 131 | [f"data: {json.dumps({'type': 'token', 'content': token})}" for token in TOKENS] 132 | + [ 133 | f"data: {json.dumps({'type': 'message', 'content': {'type': 'ai', 'content': FINAL_ANSWER}})}" 134 | ] 135 | + ["data: [DONE]"] 136 | ) 137 | 138 | # Mock the streaming response 139 | mock_response = Mock() 140 | mock_response.status_code = 200 141 | mock_response.iter_lines.return_value = events 142 | mock_response.request = Request("POST", "http://test/stream") 143 | mock_response.__enter__ = Mock(return_value=mock_response) 144 | mock_response.__exit__ = Mock(return_value=None) 145 | 146 | with patch("httpx.stream", return_value=mock_response): 147 | # Collect all streamed responses 148 | responses = list(agent_client.stream(QUESTION)) 149 | 150 | # Verify tokens were streamed 151 | assert len(responses) == len(TOKENS) + 1 # tokens + final message 152 | for i, token in enumerate(TOKENS): 153 | assert responses[i] == token 154 | 155 | # Verify final message 156 | final_message = responses[-1] 157 | assert isinstance(final_message, ChatMessage) 158 | assert final_message.type == "ai" 159 | assert final_message.content == FINAL_ANSWER 160 | 161 | # Test error response 162 | error_response = Response( 163 | 500, text="Internal Server Error", request=Request("POST", "http://test/stream") 164 | ) 165 | error_response_mock = Mock() 166 | error_response_mock.__enter__ = Mock(return_value=error_response) 167 | error_response_mock.__exit__ = Mock(return_value=None) 168 | with patch("httpx.stream", return_value=error_response_mock): 169 | with pytest.raises(AgentClientError) as exc: 170 | list(agent_client.stream(QUESTION)) 171 | assert "500 Internal Server Error" in str(exc.value) 172 | 173 | 174 | @pytest.mark.asyncio 175 | async def test_astream(agent_client): 176 | """Test asynchronous streaming.""" 177 | QUESTION = "What is the weather?" 
178 | TOKENS = ["The", " weather", " is", " sunny", "."] 179 | FINAL_ANSWER = "The weather is sunny." 180 | 181 | # Create mock response with streaming events 182 | events = ( 183 | [f"data: {json.dumps({'type': 'token', 'content': token})}" for token in TOKENS] 184 | + [ 185 | f"data: {json.dumps({'type': 'message', 'content': {'type': 'ai', 'content': FINAL_ANSWER}})}" 186 | ] 187 | + ["data: [DONE]"] 188 | ) 189 | 190 | # Create an async iterator for the events 191 | async def async_events(): 192 | for event in events: 193 | yield event 194 | 195 | # Mock the streaming response 196 | mock_response = AsyncMock() 197 | mock_response.status_code = 200 198 | mock_response.request = Request("POST", "http://test/stream") 199 | mock_response.aiter_lines = Mock(return_value=async_events()) 200 | mock_response.__aenter__ = AsyncMock(return_value=mock_response) 201 | 202 | mock_client = AsyncMock() 203 | mock_client.__aenter__.return_value = mock_client 204 | mock_client.stream = Mock(return_value=mock_response) 205 | 206 | with patch("httpx.AsyncClient", return_value=mock_client): 207 | # Collect all streamed responses 208 | responses = [] 209 | async for response in agent_client.astream(QUESTION): 210 | responses.append(response) 211 | 212 | # Verify tokens were streamed 213 | assert len(responses) == len(TOKENS) + 1 # tokens + final message 214 | for i, token in enumerate(TOKENS): 215 | assert responses[i] == token 216 | 217 | # Verify final message 218 | final_message = responses[-1] 219 | assert isinstance(final_message, ChatMessage) 220 | assert final_message.type == "ai" 221 | assert final_message.content == FINAL_ANSWER 222 | 223 | # Test error response 224 | error_response = Response( 225 | 500, text="Internal Server Error", request=Request("POST", "http://test/stream") 226 | ) 227 | error_response_mock = AsyncMock() 228 | error_response_mock.__aenter__ = AsyncMock(return_value=error_response) 229 | 230 | mock_client.stream.return_value = error_response_mock 231 | 232 | with patch("httpx.AsyncClient", return_value=mock_client): 233 | with pytest.raises(AgentClientError) as exc: 234 | async for _ in agent_client.astream(QUESTION): 235 | pass 236 | assert "500 Internal Server Error" in str(exc.value) 237 | 238 | 239 | @pytest.mark.asyncio 240 | async def test_acreate_feedback(agent_client): 241 | """Test asynchronous feedback creation.""" 242 | RUN_ID = "test-run" 243 | KEY = "test-key" 244 | SCORE = 0.8 245 | KWARGS = {"comment": "Great response!"} 246 | 247 | # Test successful response 248 | mock_response = Response(200, json={}, request=Request("POST", "http://test/feedback")) 249 | with patch("httpx.AsyncClient.post", return_value=mock_response) as mock_post: 250 | await agent_client.acreate_feedback(RUN_ID, KEY, SCORE, KWARGS) 251 | # Verify request 252 | args, kwargs = mock_post.call_args 253 | assert kwargs["json"]["run_id"] == RUN_ID 254 | assert kwargs["json"]["key"] == KEY 255 | assert kwargs["json"]["score"] == SCORE 256 | assert kwargs["json"]["kwargs"] == KWARGS 257 | 258 | # Test error response 259 | error_response = Response( 260 | 500, text="Internal Server Error", request=Request("POST", "http://test/feedback") 261 | ) 262 | with patch("httpx.AsyncClient.post", return_value=error_response): 263 | with pytest.raises(AgentClientError) as exc: 264 | await agent_client.acreate_feedback(RUN_ID, KEY, SCORE) 265 | assert "500 Internal Server Error" in str(exc.value) 266 | 267 | 268 | def test_get_history(agent_client): 269 | """Test chat history retrieval.""" 270 | THREAD_ID = 
"test-thread" 271 | HISTORY = { 272 | "messages": [ 273 | {"type": "human", "content": "What is the weather?"}, 274 | {"type": "ai", "content": "The weather is sunny."}, 275 | ] 276 | } 277 | 278 | # Mock successful response 279 | mock_response = Response(200, json=HISTORY, request=Request("POST", "http://test/history")) 280 | with patch("httpx.post", return_value=mock_response): 281 | history = agent_client.get_history(THREAD_ID) 282 | assert isinstance(history, ChatHistory) 283 | assert len(history.messages) == 2 284 | assert history.messages[0].type == "human" 285 | assert history.messages[1].type == "ai" 286 | 287 | # Test error response 288 | error_response = Response( 289 | 500, text="Internal Server Error", request=Request("POST", "http://test/history") 290 | ) 291 | with patch("httpx.post", return_value=error_response): 292 | with pytest.raises(AgentClientError) as exc: 293 | agent_client.get_history(THREAD_ID) 294 | assert "500 Internal Server Error" in str(exc.value) 295 | 296 | 297 | def test_info(agent_client): 298 | assert agent_client.info is None 299 | assert agent_client.agent == "test-agent" 300 | 301 | # Mock info response 302 | test_info = ServiceMetadata( 303 | default_agent="custom-agent", 304 | agents=[AgentInfo(key="custom-agent", description="Custom agent")], 305 | default_model=OpenAIModelName.GPT_4O, 306 | models=[OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI], 307 | ) 308 | test_response = Response( 309 | 200, json=test_info.model_dump(), request=Request("GET", "http://test/info") 310 | ) 311 | 312 | # Update an existing client with info 313 | with patch("httpx.get", return_value=test_response): 314 | agent_client.retrieve_info() 315 | 316 | assert agent_client.info == test_info 317 | assert agent_client.agent == "custom-agent" 318 | 319 | # Test invalid update_agent 320 | with pytest.raises(AgentClientError) as exc: 321 | agent_client.update_agent("unknown-agent") 322 | assert "Agent unknown-agent not found in available agents: custom-agent" in str(exc.value) 323 | 324 | # Test a fresh client with info 325 | with patch("httpx.get", return_value=test_response): 326 | agent_client = AgentClient(base_url="http://test") 327 | assert agent_client.info == test_info 328 | assert agent_client.agent == "custom-agent" 329 | 330 | # Test error on invoke if no agent set 331 | agent_client = AgentClient(base_url="http://test", get_info=False) 332 | with pytest.raises(AgentClientError) as exc: 333 | agent_client.invoke("test") 334 | assert "No agent selected. Use update_agent() to select an agent." in str(exc.value) 335 | -------------------------------------------------------------------------------- /src/client/client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections.abc import AsyncGenerator, Generator 4 | from typing import Any 5 | 6 | import httpx 7 | 8 | from schema import ( 9 | ChatHistory, 10 | ChatHistoryInput, 11 | ChatMessage, 12 | Feedback, 13 | ServiceMetadata, 14 | StreamInput, 15 | UserInput, 16 | ) 17 | 18 | 19 | class AgentClientError(Exception): 20 | pass 21 | 22 | 23 | class AgentClient: 24 | """Client for interacting with the agent service.""" 25 | 26 | def __init__( 27 | self, 28 | base_url: str = "http://0.0.0.0", 29 | agent: str = None, 30 | timeout: float | None = None, 31 | get_info: bool = True, 32 | ) -> None: 33 | """ 34 | Initialize the client. 35 | 36 | Args: 37 | base_url (str): The base URL of the agent service. 
38 | agent (str): The name of the default agent to use. 39 | timeout (float, optional): The timeout for requests. 40 | get_info (bool, optional): Whether to fetch agent information on init. 41 | Default: True 42 | """ 43 | self.base_url = base_url 44 | self.auth_secret = os.getenv("AUTH_SECRET") 45 | self.timeout = timeout 46 | self.info: ServiceMetadata | None = None 47 | self.agent: str | None = None 48 | if get_info: 49 | self.retrieve_info() 50 | if agent: 51 | self.update_agent(agent) 52 | 53 | @property 54 | def _headers(self) -> dict[str, str]: 55 | headers = {} 56 | if self.auth_secret: 57 | headers["Authorization"] = f"Bearer {self.auth_secret}" 58 | return headers 59 | 60 | def retrieve_info(self) -> None: 61 | try: 62 | response = httpx.get( 63 | f"{self.base_url}/info", 64 | headers=self._headers, 65 | timeout=self.timeout, 66 | ) 67 | response.raise_for_status() 68 | except httpx.HTTPError as e: 69 | raise AgentClientError(f"Error getting service info: {e}") 70 | 71 | self.info: ServiceMetadata = ServiceMetadata.model_validate(response.json()) 72 | if not self.agent or self.agent not in [a.key for a in self.info.agents]: 73 | self.agent = self.info.default_agent 74 | 75 | def update_agent(self, agent: str, verify: bool = True) -> None: 76 | if verify: 77 | if not self.info: 78 | self.retrieve_info() 79 | agent_keys = [a.key for a in self.info.agents] 80 | if agent not in agent_keys: 81 | raise AgentClientError( 82 | f"Agent {agent} not found in available agents: {', '.join(agent_keys)}" 83 | ) 84 | self.agent = agent 85 | 86 | async def ainvoke( 87 | self, 88 | message: str, 89 | model: str | None = None, 90 | thread_id: str | None = None, 91 | agent_config: dict[str, Any] | None = None, 92 | ) -> ChatMessage: 93 | """ 94 | Invoke the agent asynchronously. Only the final message is returned. 95 | 96 | Args: 97 | message (str): The message to send to the agent 98 | model (str, optional): LLM model to use for the agent 99 | thread_id (str, optional): Thread ID for continuing a conversation 100 | agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent 101 | 102 | Returns: 103 | AnyMessage: The response from the agent 104 | """ 105 | if not self.agent: 106 | raise AgentClientError("No agent selected. Use update_agent() to select an agent.") 107 | request = UserInput(message=message) 108 | if thread_id: 109 | request.thread_id = thread_id 110 | if model: 111 | request.model = model 112 | if agent_config: 113 | request.agent_config = agent_config 114 | async with httpx.AsyncClient() as client: 115 | try: 116 | response = await client.post( 117 | f"{self.base_url}/{self.agent}/invoke", 118 | json=request.model_dump(), 119 | headers=self._headers, 120 | timeout=self.timeout, 121 | ) 122 | response.raise_for_status() 123 | except httpx.HTTPError as e: 124 | raise AgentClientError(f"Error: {e}") 125 | 126 | return ChatMessage.model_validate(response.json()) 127 | 128 | def invoke( 129 | self, 130 | message: str, 131 | model: str | None = None, 132 | thread_id: str | None = None, 133 | agent_config: dict[str, Any] | None = None, 134 | ) -> ChatMessage: 135 | """ 136 | Invoke the agent synchronously. Only the final message is returned. 
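        A minimal usage sketch (values are illustrative; any configured model
        name works):

            client = AgentClient(base_url="http://0.0.0.0:8080")
            reply = client.invoke("What is the weather?", model="gpt-4o-mini")
            reply.pretty_print()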
137 | 138 | Args: 139 | message (str): The message to send to the agent 140 | model (str, optional): LLM model to use for the agent 141 | thread_id (str, optional): Thread ID for continuing a conversation 142 | agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent 143 | 144 | Returns: 145 | ChatMessage: The response from the agent 146 | """ 147 | if not self.agent: 148 | raise AgentClientError("No agent selected. Use update_agent() to select an agent.") 149 | request = UserInput(message=message) 150 | if thread_id: 151 | request.thread_id = thread_id 152 | if model: 153 | request.model = model 154 | if agent_config: 155 | request.agent_config = agent_config 156 | try: 157 | response = httpx.post( 158 | f"{self.base_url}/{self.agent}/invoke", 159 | json=request.model_dump(), 160 | headers=self._headers, 161 | timeout=self.timeout, 162 | ) 163 | response.raise_for_status() 164 | except httpx.HTTPError as e: 165 | raise AgentClientError(f"Error: {e}") 166 | 167 | return ChatMessage.model_validate(response.json()) 168 | 169 | def _parse_stream_line(self, line: str) -> ChatMessage | str | None: 170 | line = line.strip() 171 | if line.startswith("data: "): 172 | data = line[6:] 173 | if data == "[DONE]": 174 | return None 175 | try: 176 | parsed = json.loads(data) 177 | except Exception as e: 178 | raise Exception(f"Error JSON parsing message from server: {e}") 179 | match parsed["type"]: 180 | case "message": 181 | # Convert the JSON formatted message to an AnyMessage 182 | try: 183 | return ChatMessage.model_validate(parsed["content"]) 184 | except Exception as e: 185 | raise Exception(f"Server returned invalid message: {e}") 186 | case "token": 187 | # Yield the str token directly 188 | return parsed["content"] 189 | case "error": 190 | raise Exception(parsed["content"]) 191 | return None 192 | 193 | def stream( 194 | self, 195 | message: str, 196 | model: str | None = None, 197 | thread_id: str | None = None, 198 | agent_config: dict[str, Any] | None = None, 199 | stream_tokens: bool = True, 200 | ) -> Generator[ChatMessage | str, None, None]: 201 | """ 202 | Stream the agent's response synchronously. 203 | 204 | Each intermediate message of the agent process is yielded as a ChatMessage. 205 | If stream_tokens is True (the default value), the response will also yield 206 | content tokens from streaming models as they are generated. 207 | 208 | Args: 209 | message (str): The message to send to the agent 210 | model (str, optional): LLM model to use for the agent 211 | thread_id (str, optional): Thread ID for continuing a conversation 212 | agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent 213 | stream_tokens (bool, optional): Stream tokens as they are generated 214 | Default: True 215 | 216 | Returns: 217 | Generator[ChatMessage | str, None, None]: The response from the agent 218 | """ 219 | if not self.agent: 220 | raise AgentClientError("No agent selected. 
Use update_agent() to select an agent.") 221 | request = StreamInput(message=message, stream_tokens=stream_tokens) 222 | if thread_id: 223 | request.thread_id = thread_id 224 | if model: 225 | request.model = model 226 | if agent_config: 227 | request.agent_config = agent_config 228 | try: 229 | with httpx.stream( 230 | "POST", 231 | f"{self.base_url}/{self.agent}/stream", 232 | json=request.model_dump(), 233 | headers=self._headers, 234 | timeout=self.timeout, 235 | ) as response: 236 | response.raise_for_status() 237 | for line in response.iter_lines(): 238 | if line.strip(): 239 | parsed = self._parse_stream_line(line) 240 | if parsed is None: 241 | break 242 | yield parsed 243 | except httpx.HTTPError as e: 244 | raise AgentClientError(f"Error: {e}") 245 | 246 | async def astream( 247 | self, 248 | message: str, 249 | model: str | None = None, 250 | thread_id: str | None = None, 251 | agent_config: dict[str, Any] | None = None, 252 | stream_tokens: bool = True, 253 | ) -> AsyncGenerator[ChatMessage | str, None]: 254 | """ 255 | Stream the agent's response asynchronously. 256 | 257 | Each intermediate message of the agent process is yielded as an AnyMessage. 258 | If stream_tokens is True (the default value), the response will also yield 259 | content tokens from streaming modelsas they are generated. 260 | 261 | Args: 262 | message (str): The message to send to the agent 263 | model (str, optional): LLM model to use for the agent 264 | thread_id (str, optional): Thread ID for continuing a conversation 265 | agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent 266 | stream_tokens (bool, optional): Stream tokens as they are generated 267 | Default: True 268 | 269 | Returns: 270 | AsyncGenerator[ChatMessage | str, None]: The response from the agent 271 | """ 272 | if not self.agent: 273 | raise AgentClientError("No agent selected. Use update_agent() to select an agent.") 274 | request = StreamInput(message=message, stream_tokens=stream_tokens) 275 | if thread_id: 276 | request.thread_id = thread_id 277 | if model: 278 | request.model = model 279 | if agent_config: 280 | request.agent_config = agent_config 281 | async with httpx.AsyncClient() as client: 282 | try: 283 | async with client.stream( 284 | "POST", 285 | f"{self.base_url}/{self.agent}/stream", 286 | json=request.model_dump(), 287 | headers=self._headers, 288 | timeout=self.timeout, 289 | ) as response: 290 | response.raise_for_status() 291 | async for line in response.aiter_lines(): 292 | if line.strip(): 293 | parsed = self._parse_stream_line(line) 294 | if parsed is None: 295 | break 296 | yield parsed 297 | except httpx.HTTPError as e: 298 | raise AgentClientError(f"Error: {e}") 299 | 300 | async def acreate_feedback( 301 | self, run_id: str, key: str, score: float, kwargs: dict[str, Any] = {} 302 | ) -> None: 303 | """ 304 | Create a feedback record for a run. 305 | 306 | This is a simple wrapper for the LangSmith create_feedback API, so the 307 | credentials can be stored and managed in the service rather than the client. 
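        A minimal usage sketch (values are illustrative; the key mirrors the one
        the Streamlit app sends):

            await client.acreate_feedback(
                run_id="847c6285-8fc9-4560-a83f-4e6285809254",
                key="human-feedback-stars",
                score=0.8,
            )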
308 | See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post 309 | """ 310 | request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs) 311 | async with httpx.AsyncClient() as client: 312 | try: 313 | response = await client.post( 314 | f"{self.base_url}/feedback", 315 | json=request.model_dump(), 316 | headers=self._headers, 317 | timeout=self.timeout, 318 | ) 319 | response.raise_for_status() 320 | response.json() 321 | except httpx.HTTPError as e: 322 | raise AgentClientError(f"Error: {e}") 323 | 324 | def get_history( 325 | self, 326 | thread_id: str, 327 | ) -> ChatHistory: 328 | """ 329 | Get chat history. 330 | 331 | Args: 332 | thread_id (str, optional): Thread ID for identifying a conversation 333 | """ 334 | request = ChatHistoryInput(thread_id=thread_id) 335 | try: 336 | response = httpx.post( 337 | f"{self.base_url}/history", 338 | json=request.model_dump(), 339 | headers=self._headers, 340 | timeout=self.timeout, 341 | ) 342 | response.raise_for_status() 343 | except httpx.HTTPError as e: 344 | raise AgentClientError(f"Error: {e}") 345 | 346 | return ChatHistory.model_validate(response.json()) 347 | -------------------------------------------------------------------------------- /tests/service/test_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | from types import SimpleNamespace 3 | from unittest.mock import AsyncMock, patch 4 | 5 | import langsmith 6 | import pytest 7 | from langchain_core.messages import AIMessage, HumanMessage 8 | from langgraph.pregel.types import StateSnapshot 9 | 10 | from agents.agents import Agent 11 | from schema import ChatHistory, ChatMessage, ServiceMetadata 12 | from schema.models import OpenAIModelName 13 | 14 | 15 | def test_invoke(test_client, mock_agent) -> None: 16 | QUESTION = "What is the weather in Tokyo?" 17 | ANSWER = "The weather in Tokyo is 70 degrees." 18 | mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=ANSWER)]} 19 | 20 | response = test_client.post("/invoke", json={"message": QUESTION}) 21 | assert response.status_code == 200 22 | 23 | mock_agent.ainvoke.assert_awaited_once() 24 | input_message = mock_agent.ainvoke.await_args.kwargs["input"]["messages"][0] 25 | assert input_message.content == QUESTION 26 | 27 | output = ChatMessage.model_validate(response.json()) 28 | assert output.type == "ai" 29 | assert output.content == ANSWER 30 | 31 | 32 | def test_invoke_custom_agent(test_client, mock_agent) -> None: 33 | """Test that /invoke works with a custom agent_id path parameter.""" 34 | CUSTOM_AGENT = "custom_agent" 35 | QUESTION = "What is the weather in Tokyo?" 36 | CUSTOM_ANSWER = "The weather in Tokyo is sunny." 37 | DEFAULT_ANSWER = "This is from the default agent." 
38 | 39 | # Create a separate mock for the default agent 40 | default_mock = AsyncMock() 41 | default_mock.ainvoke.return_value = {"messages": [AIMessage(content=DEFAULT_ANSWER)]} 42 | 43 | # Configure our custom mock agent 44 | mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=CUSTOM_ANSWER)]} 45 | 46 | # Patch get_agent to return the correct agent based on the provided agent_id 47 | def agent_lookup(agent_id): 48 | if agent_id == CUSTOM_AGENT: 49 | return mock_agent 50 | return default_mock 51 | 52 | with patch("service.service.get_agent", side_effect=agent_lookup): 53 | response = test_client.post(f"/{CUSTOM_AGENT}/invoke", json={"message": QUESTION}) 54 | assert response.status_code == 200 55 | 56 | # Verify custom agent was called and default wasn't 57 | mock_agent.ainvoke.assert_awaited_once() 58 | default_mock.ainvoke.assert_not_awaited() 59 | 60 | input_message = mock_agent.ainvoke.await_args.kwargs["input"]["messages"][0] 61 | assert input_message.content == QUESTION 62 | 63 | output = ChatMessage.model_validate(response.json()) 64 | assert output.type == "ai" 65 | assert output.content == CUSTOM_ANSWER # Verify we got the custom agent's response 66 | 67 | 68 | def test_invoke_model_param(test_client, mock_agent) -> None: 69 | """Test that the model parameter is correctly passed to the agent.""" 70 | QUESTION = "What is the weather in Tokyo?" 71 | ANSWER = "The weather in Tokyo is sunny." 72 | CUSTOM_MODEL = "claude-3.5-sonnet" 73 | mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=ANSWER)]} 74 | 75 | response = test_client.post("/invoke", json={"message": QUESTION, "model": CUSTOM_MODEL}) 76 | assert response.status_code == 200 77 | 78 | # Verify the model was passed correctly in the config 79 | mock_agent.ainvoke.assert_awaited_once() 80 | config = mock_agent.ainvoke.await_args.kwargs["config"] 81 | assert config["configurable"]["model"] == CUSTOM_MODEL 82 | 83 | # Verify the response is still correct 84 | output = ChatMessage.model_validate(response.json()) 85 | assert output.type == "ai" 86 | assert output.content == ANSWER 87 | 88 | # Verify an invalid model throws a validation error 89 | INVALID_MODEL = "gpt-7-notreal" 90 | response = test_client.post("/invoke", json={"message": QUESTION, "model": INVALID_MODEL}) 91 | assert response.status_code == 422 92 | 93 | 94 | def test_invoke_custom_agent_config(test_client, mock_agent) -> None: 95 | """Test that the agent_config parameter is correctly passed to the agent.""" 96 | QUESTION = "What is the weather in Tokyo?" 97 | ANSWER = "The weather in Tokyo is sunny." 
98 | CUSTOM_CONFIG = {"spicy_level": 0.1, "additional_param": "value_foo"} 99 | 100 | mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=ANSWER)]} 101 | 102 | response = test_client.post( 103 | "/invoke", json={"message": QUESTION, "agent_config": CUSTOM_CONFIG} 104 | ) 105 | assert response.status_code == 200 106 | 107 | # Verify the agent_config was passed correctly in the config 108 | mock_agent.ainvoke.assert_awaited_once() 109 | config = mock_agent.ainvoke.await_args.kwargs["config"] 110 | assert config["configurable"]["spicy_level"] == 0.1 111 | assert config["configurable"]["additional_param"] == "value_foo" 112 | 113 | # Verify the response is still correct 114 | output = ChatMessage.model_validate(response.json()) 115 | assert output.type == "ai" 116 | assert output.content == ANSWER 117 | 118 | # Verify a reserved key in agent_config throws a validation error 119 | INVALID_CONFIG = {"model": "gpt-4o"} 120 | response = test_client.post( 121 | "/invoke", json={"message": QUESTION, "agent_config": INVALID_CONFIG} 122 | ) 123 | assert response.status_code == 422 124 | 125 | 126 | @patch("service.service.LangsmithClient") 127 | def test_feedback(mock_client: langsmith.Client, test_client) -> None: 128 | ls_instance = mock_client.return_value 129 | ls_instance.create_feedback.return_value = None 130 | body = { 131 | "run_id": "847c6285-8fc9-4560-a83f-4e6285809254", 132 | "key": "human-feedback-stars", 133 | "score": 0.8, 134 | } 135 | response = test_client.post("/feedback", json=body) 136 | assert response.status_code == 200 137 | assert response.json() == {"status": "success"} 138 | ls_instance.create_feedback.assert_called_once_with( 139 | run_id="847c6285-8fc9-4560-a83f-4e6285809254", 140 | key="human-feedback-stars", 141 | score=0.8, 142 | ) 143 | 144 | 145 | def test_history(test_client, mock_agent) -> None: 146 | QUESTION = "What is the weather in Tokyo?" 147 | ANSWER = "The weather in Tokyo is 70 degrees." 148 | user_question = HumanMessage(content=QUESTION) 149 | agent_response = AIMessage(content=ANSWER) 150 | mock_agent.get_state.return_value = StateSnapshot( 151 | values={"messages": [user_question, agent_response]}, 152 | next=(), 153 | config={}, 154 | metadata=None, 155 | created_at=None, 156 | parent_config=None, 157 | tasks=(), 158 | ) 159 | 160 | response = test_client.post( 161 | "/history", json={"thread_id": "7bcc7cc1-99d7-4b1d-bdb5-e6f90ed44de6"} 162 | ) 163 | assert response.status_code == 200 164 | 165 | output = ChatHistory.model_validate(response.json()) 166 | assert output.messages[0].type == "human" 167 | assert output.messages[0].content == QUESTION 168 | assert output.messages[1].type == "ai" 169 | assert output.messages[1].content == ANSWER 170 | 171 | 172 | @pytest.mark.asyncio 173 | async def test_stream(test_client, mock_agent) -> None: 174 | """Test streaming tokens and messages.""" 175 | QUESTION = "What is the weather in Tokyo?" 176 | TOKENS = ["The", " weather", " in", " Tokyo", " is", " sunny", "."] 177 | FINAL_ANSWER = "The weather in Tokyo is sunny." 
178 | 179 | # Configure mock to use our async iterator function 180 | events = [ 181 | { 182 | "event": "on_chat_model_stream", 183 | "data": {"chunk": SimpleNamespace(content=token)}, 184 | "tags": [], 185 | } 186 | for token in TOKENS 187 | ] + [ 188 | { 189 | "event": "on_chain_end", 190 | "data": {"output": {"messages": [AIMessage(content=FINAL_ANSWER)]}}, 191 | "tags": ["graph:step:1"], 192 | } 193 | ] 194 | 195 | async def mock_astream_events(**kwargs): 196 | for event in events: 197 | yield event 198 | 199 | mock_agent.astream_events = mock_astream_events 200 | 201 | # Make request with streaming 202 | with test_client.stream( 203 | "POST", "/stream", json={"message": QUESTION, "stream_tokens": True} 204 | ) as response: 205 | assert response.status_code == 200 206 | 207 | # Collect all SSE messages 208 | messages = [] 209 | for line in response.iter_lines(): 210 | if line and line.strip() != "data: [DONE]": # Skip [DONE] message 211 | messages.append(json.loads(line.lstrip("data: "))) 212 | 213 | # Verify streamed tokens 214 | token_messages = [msg for msg in messages if msg["type"] == "token"] 215 | assert len(token_messages) == len(TOKENS) 216 | for i, msg in enumerate(token_messages): 217 | assert msg["content"] == TOKENS[i] 218 | 219 | # Verify final message 220 | final_messages = [msg for msg in messages if msg["type"] == "message"] 221 | assert len(final_messages) == 1 222 | assert final_messages[0]["content"]["content"] == FINAL_ANSWER 223 | assert final_messages[0]["content"]["type"] == "ai" 224 | 225 | 226 | @pytest.mark.asyncio 227 | async def test_stream_no_tokens(test_client, mock_agent) -> None: 228 | """Test streaming without tokens.""" 229 | QUESTION = "What is the weather in Tokyo?" 230 | FINAL_ANSWER = "The weather in Tokyo is sunny." 
231 | 232 | # Configure mock to use our async iterator function 233 | events = [ 234 | { 235 | "event": "on_chat_model_stream", 236 | "data": {"chunk": SimpleNamespace(content=token)}, 237 | "tags": [], 238 | } 239 | for token in ["The", " weather", " in", " Tokyo", " is", " sunny", "."] 240 | ] + [ 241 | { 242 | "event": "on_chain_end", 243 | "data": {"output": {"messages": [AIMessage(content=FINAL_ANSWER)]}}, 244 | "tags": ["graph:step:1"], 245 | } 246 | ] 247 | 248 | async def mock_astream_events(**kwargs): 249 | for event in events: 250 | yield event 251 | 252 | mock_agent.astream_events = mock_astream_events 253 | 254 | # Make request with streaming disabled 255 | with test_client.stream( 256 | "POST", "/stream", json={"message": QUESTION, "stream_tokens": False} 257 | ) as response: 258 | assert response.status_code == 200 259 | 260 | # Collect all SSE messages 261 | messages = [] 262 | for line in response.iter_lines(): 263 | if line and line.strip() != "data: [DONE]": # Skip [DONE] message 264 | messages.append(json.loads(line.lstrip("data: "))) 265 | 266 | # Verify no token messages 267 | token_messages = [msg for msg in messages if msg["type"] == "token"] 268 | assert len(token_messages) == 0 269 | 270 | # Verify final message 271 | final_messages = [msg for msg in messages if msg["type"] == "message"] 272 | assert len(final_messages) == 1 273 | assert final_messages[0]["content"]["content"] == FINAL_ANSWER 274 | assert final_messages[0]["content"]["type"] == "ai" 275 | 276 | 277 | def test_info(test_client, mock_settings) -> None: 278 | """Test that /info returns the correct service metadata.""" 279 | 280 | base_agent = Agent(description="A base agent.", graph=None) 281 | mock_settings.AUTH_SECRET = None 282 | mock_settings.DEFAULT_MODEL = OpenAIModelName.GPT_4O_MINI 283 | mock_settings.AVAILABLE_MODELS = {OpenAIModelName.GPT_4O_MINI, OpenAIModelName.GPT_4O} 284 | with patch.dict("agents.agents.agents", {"base-agent": base_agent}, clear=True): 285 | response = test_client.get("/info") 286 | assert response.status_code == 200 287 | output = ServiceMetadata.model_validate(response.json()) 288 | 289 | assert output.default_agent == "research-assistant" 290 | assert len(output.agents) == 1 291 | assert output.agents[0].key == "base-agent" 292 | assert output.agents[0].description == "A base agent." 
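        # The /info handler sorts AVAILABLE_MODELS before returning them,
        # so the models list asserted below is expected in sorted order.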
293 | 294 | assert output.default_model == OpenAIModelName.GPT_4O_MINI 295 | assert output.models == [OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI] 296 | 297 | 298 | @pytest.mark.asyncio 299 | async def test_stream_with_commands(test_client, mock_agent) -> None: 300 | """Test streaming when agent returns Command objects.""" 301 | QUESTION = "Test command streaming" 302 | 303 | # Configure mock agent to return a Command followed by a regular message 304 | events = [ 305 | { 306 | "event": "on_chain_end", 307 | "data": {"output": {"messages": [AIMessage(content="Hello a")]}}, 308 | "tags": ["graph:step:1"], 309 | }, 310 | { 311 | "event": "on_chain_end", 312 | "data": {"output": {"messages": [AIMessage(content="Hello B")]}}, 313 | "tags": ["graph:step:2"], 314 | }, 315 | ] 316 | 317 | async def mock_astream_events(**kwargs): 318 | for event in events: 319 | yield event 320 | 321 | mock_agent.astream_events = mock_astream_events 322 | 323 | # Make request with streaming 324 | with test_client.stream( 325 | "POST", "/stream", json={"message": QUESTION, "stream_tokens": True} 326 | ) as response: 327 | assert response.status_code == 200 328 | 329 | # Collect all SSE messages 330 | messages = [] 331 | for line in response.iter_lines(): 332 | if line and line.strip() != "data: [DONE]": # Skip [DONE] message 333 | messages.append(json.loads(line.lstrip("data: "))) 334 | 335 | # Verify messages 336 | final_messages = [msg for msg in messages if msg["type"] == "message"] 337 | assert len(final_messages) == 2 338 | first_message = final_messages[0]["content"]["content"] 339 | second_message = final_messages[1]["content"]["content"] 340 | assert first_message in ["Hello a", "Hello b"] 341 | if first_message == "Hello a": 342 | assert second_message == "Hello B" 343 | else: 344 | assert second_message == "Hello C" 345 | assert final_messages[0]["content"]["type"] == "ai" 346 | assert final_messages[1]["content"]["type"] == "ai" 347 | -------------------------------------------------------------------------------- /src/streamlit_app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import urllib.parse 4 | from collections.abc import AsyncGenerator 5 | 6 | import streamlit as st 7 | from dotenv import load_dotenv 8 | from pydantic import ValidationError 9 | from streamlit.runtime.scriptrunner import get_script_run_ctx 10 | 11 | from client import AgentClient, AgentClientError 12 | from schema import ChatHistory, ChatMessage 13 | from schema.task_data import TaskData, TaskDataStatus 14 | 15 | # A Streamlit app for interacting with the langgraph agent via a simple chat interface. 16 | # The app has three main functions which are all run async: 17 | 18 | # - main() - sets up the streamlit app and high level structure 19 | # - draw_messages() - draws a set of chat messages - either replaying existing messages 20 | # or streaming new ones. 21 | # - handle_feedback() - Draws a feedback widget and records feedback from the user. 22 | 23 | # The app heavily uses AgentClient to interact with the agent's FastAPI endpoints. 
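# The agent service URL is resolved in main() below: it comes from the AGENT_URL
# environment variable, or is assembled from HOST/PORT (default http://0.0.0.0:8080).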
24 | 25 | 26 | APP_TITLE = "Agent Service Toolkit" 27 | APP_ICON = "🧰" 28 | 29 | 30 | async def main() -> None: 31 | st.set_page_config( 32 | page_title=APP_TITLE, 33 | page_icon=APP_ICON, 34 | menu_items={}, 35 | ) 36 | 37 | # Hide the streamlit upper-right chrome 38 | st.html( 39 | """ 40 | 47 | """, 48 | ) 49 | if st.get_option("client.toolbarMode") != "minimal": 50 | st.set_option("client.toolbarMode", "minimal") 51 | await asyncio.sleep(0.1) 52 | st.rerun() 53 | 54 | if "agent_client" not in st.session_state: 55 | load_dotenv() 56 | agent_url = os.getenv("AGENT_URL") 57 | if not agent_url: 58 | host = os.getenv("HOST", "0.0.0.0") 59 | port = os.getenv("PORT", 8080) 60 | agent_url = f"http://{host}:{port}" 61 | try: 62 | with st.spinner("Connecting to agent service..."): 63 | st.session_state.agent_client = AgentClient(base_url=agent_url) 64 | except AgentClientError as e: 65 | st.error(f"Error connecting to agent service at {agent_url}: {e}") 66 | st.markdown("The service might be booting up. Try again in a few seconds.") 67 | st.stop() 68 | agent_client: AgentClient = st.session_state.agent_client 69 | 70 | if "thread_id" not in st.session_state: 71 | thread_id = st.query_params.get("thread_id") 72 | if not thread_id: 73 | thread_id = get_script_run_ctx().session_id 74 | messages = [] 75 | else: 76 | try: 77 | messages: ChatHistory = agent_client.get_history(thread_id=thread_id).messages 78 | except AgentClientError: 79 | st.error("No message history found for this Thread ID.") 80 | messages = [] 81 | st.session_state.messages = messages 82 | st.session_state.thread_id = thread_id 83 | 84 | # Config options 85 | with st.sidebar: 86 | st.header(f"{APP_ICON} {APP_TITLE}") 87 | "" 88 | "Full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit" 89 | with st.popover(":material/settings: Settings", use_container_width=True): 90 | model_idx = agent_client.info.models.index(agent_client.info.default_model) 91 | model = st.selectbox("LLM to use", options=agent_client.info.models, index=model_idx) 92 | agent_list = [a.key for a in agent_client.info.agents] 93 | agent_idx = agent_list.index(agent_client.info.default_agent) 94 | agent_client.agent = st.selectbox( 95 | "Agent to use", 96 | options=agent_list, 97 | index=agent_idx, 98 | ) 99 | use_streaming = st.toggle("Stream results", value=True) 100 | 101 | @st.dialog("Architecture") 102 | def architecture_dialog() -> None: 103 | st.image( 104 | "https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png?raw=true" 105 | ) 106 | "[View full size on Github](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png)" 107 | st.caption( 108 | "App hosted on [Streamlit Cloud](https://share.streamlit.io/) with FastAPI service running in [Azure](https://learn.microsoft.com/en-us/azure/app-service/)" 109 | ) 110 | 111 | if st.button(":material/schema: Architecture", use_container_width=True): 112 | architecture_dialog() 113 | 114 | with st.popover(":material/policy: Privacy", use_container_width=True): 115 | st.write( 116 | "Prompts, responses and feedback in this app are anonymously recorded and saved to LangSmith for product evaluation and improvement purposes only." 
117 | ) 118 | 119 | @st.dialog("Share/resume chat") 120 | def share_chat_dialog() -> None: 121 | session = st.runtime.get_instance()._session_mgr.list_active_sessions()[0] 122 | st_base_url = urllib.parse.urlunparse( 123 | [session.client.request.protocol, session.client.request.host, "", "", "", ""] 124 | ) 125 | # if it's not localhost, switch to https by default 126 | if not st_base_url.startswith("https") and "localhost" not in st_base_url: 127 | st_base_url = st_base_url.replace("http", "https") 128 | chat_url = f"{st_base_url}?thread_id={st.session_state.thread_id}" 129 | st.markdown(f"**Chat URL:**\n```text\n{chat_url}\n```") 130 | st.info("Copy the above URL to share or revisit this chat") 131 | 132 | if st.button(":material/upload: Share/resume chat", use_container_width=True): 133 | share_chat_dialog() 134 | 135 | "[View the source code](https://github.com/JoshuaC215/agent-service-toolkit)" 136 | st.caption( 137 | "Made with :material/favorite: by [Joshua](https://www.linkedin.com/in/joshua-k-carroll/) in Oakland" 138 | ) 139 | 140 | # Draw existing messages 141 | messages: list[ChatMessage] = st.session_state.messages 142 | 143 | if len(messages) == 0: 144 | WELCOME = "Hello! I'm an AI-powered research assistant with web search and a calculator. Ask me anything!" 145 | with st.chat_message("ai"): 146 | st.write(WELCOME) 147 | 148 | # draw_messages() expects an async iterator over messages 149 | async def amessage_iter() -> AsyncGenerator[ChatMessage, None]: 150 | for m in messages: 151 | yield m 152 | 153 | await draw_messages(amessage_iter()) 154 | 155 | # Generate new message if the user provided new input 156 | if user_input := st.chat_input(): 157 | messages.append(ChatMessage(type="human", content=user_input)) 158 | st.chat_message("human").write(user_input) 159 | try: 160 | if use_streaming: 161 | stream = agent_client.astream( 162 | message=user_input, 163 | model=model, 164 | thread_id=st.session_state.thread_id, 165 | ) 166 | await draw_messages(stream, is_new=True) 167 | else: 168 | response = await agent_client.ainvoke( 169 | message=user_input, 170 | model=model, 171 | thread_id=st.session_state.thread_id, 172 | ) 173 | messages.append(response) 174 | st.chat_message("ai").write(response.content) 175 | st.rerun() # Clear stale containers 176 | except AgentClientError as e: 177 | st.error(f"Error generating response: {e}") 178 | st.stop() 179 | 180 | # If messages have been generated, show feedback widget 181 | if len(messages) > 0 and st.session_state.last_message: 182 | with st.session_state.last_message: 183 | await handle_feedback() 184 | 185 | 186 | async def draw_messages( 187 | messages_agen: AsyncGenerator[ChatMessage | str, None], 188 | is_new: bool = False, 189 | ) -> None: 190 | """ 191 | Draws a set of chat messages - either replaying existing messages 192 | or streaming new ones. 193 | 194 | This function has additional logic to handle streaming tokens and tool calls. 195 | - Use a placeholder container to render streaming tokens as they arrive. 196 | - Use a status container to render tool calls. Track the tool inputs and outputs 197 | and update the status container accordingly. 198 | 199 | The function also needs to track the last message container in session state 200 | since later messages can draw to the same container. This is also used for 201 | drawing the feedback widget in the latest chat message. 202 | 203 | Args: 204 | messages_aiter: An async iterator over messages to draw. 205 | is_new: Whether the messages are new or not. 
206 |     """
207 | 
208 |     # Keep track of the last message container
209 |     last_message_type = None
210 |     st.session_state.last_message = None
211 | 
212 |     # Placeholder for intermediate streaming tokens
213 |     streaming_content = ""
214 |     streaming_placeholder = None
215 | 
216 |     # Iterate over the messages and draw them
217 |     while msg := await anext(messages_agen, None):
218 |         # str message represents an intermediate token being streamed
219 |         if isinstance(msg, str):
220 |             # If placeholder is empty, this is the first token of a new message
221 |             # being streamed. We need to do setup.
222 |             if not streaming_placeholder:
223 |                 if last_message_type != "ai":
224 |                     last_message_type = "ai"
225 |                     st.session_state.last_message = st.chat_message("ai")
226 |                 with st.session_state.last_message:
227 |                     streaming_placeholder = st.empty()
228 | 
229 |             streaming_content += msg
230 |             streaming_placeholder.write(streaming_content)
231 |             continue
232 |         if not isinstance(msg, ChatMessage):
233 |             st.error(f"Unexpected message type: {type(msg)}")
234 |             st.write(msg)
235 |             st.stop()
236 |         match msg.type:
237 |             # A message from the user, the easiest case
238 |             case "human":
239 |                 last_message_type = "human"
240 |                 st.chat_message("human").write(msg.content)
241 | 
242 |             # A message from the agent is the most complex case, since we need to
243 |             # handle streaming tokens and tool calls.
244 |             case "ai":
245 |                 # If we're rendering new messages, store the message in session state
246 |                 if is_new:
247 |                     st.session_state.messages.append(msg)
248 | 
249 |                 # If the last message type was not AI, create a new chat message
250 |                 if last_message_type != "ai":
251 |                     last_message_type = "ai"
252 |                     st.session_state.last_message = st.chat_message("ai")
253 | 
254 |                 with st.session_state.last_message:
255 |                     # If the message has content, write it out.
256 |                     # Reset the streaming variables to prepare for the next message.
257 |                     if msg.content:
258 |                         if streaming_placeholder:
259 |                             streaming_placeholder.write(msg.content)
260 |                             streaming_content = ""
261 |                             streaming_placeholder = None
262 |                         else:
263 |                             st.write(msg.content)
264 | 
265 |                     if msg.tool_calls:
266 |                         # Create a status container for each tool call and store the
267 |                         # status container by ID to ensure results are mapped to the
268 |                         # correct status container.
269 |                         call_results = {}
270 |                         for tool_call in msg.tool_calls:
271 |                             status = st.status(
272 |                                 f"""Tool Call: {tool_call["name"]}""",
273 |                                 state="running" if is_new else "complete",
274 |                             )
275 |                             call_results[tool_call["id"]] = status
276 |                             status.write("Input:")
277 |                             status.write(tool_call["args"])
278 | 
279 |                         # Expect one ToolMessage for each tool call.
280 |                         for _ in range(len(call_results)):
281 |                             tool_result: ChatMessage = await anext(messages_agen)
282 |                             if tool_result.type != "tool":
283 |                                 st.error(f"Unexpected ChatMessage type: {tool_result.type}")
284 |                                 st.write(tool_result)
285 |                                 st.stop()
286 | 
287 |                             # Record the message if it's new, and update the correct
288 |                             # status container with the result
289 |                             if is_new:
290 |                                 st.session_state.messages.append(tool_result)
291 |                             status = call_results[tool_result.tool_call_id]
292 |                             status.write("Output:")
293 |                             status.write(tool_result.content)
294 |                             status.update(state="complete")
295 | 
296 |             case "custom":
297 |                 # CustomData example used by the bg-task-agent
298 |                 # See:
299 |                 # - src/agents/utils.py CustomData
300 |                 # - src/agents/bg_task_agent/task.py
301 |                 try:
302 |                     task_data: TaskData = TaskData.model_validate(msg.custom_data)
303 |                 except ValidationError:
304 |                     st.error("Unexpected CustomData message received from agent")
305 |                     st.write(msg.custom_data)
306 |                     st.stop()
307 | 
308 |                 if is_new:
309 |                     st.session_state.messages.append(msg)
310 | 
311 |                 if last_message_type != "task":
312 |                     last_message_type = "task"
313 |                     st.session_state.last_message = st.chat_message(
314 |                         name="task", avatar=":material/manufacturing:"
315 |                     )
316 |                     with st.session_state.last_message:
317 |                         status = TaskDataStatus()
318 | 
319 |                 status.add_and_draw_task_data(task_data)
320 | 
321 |             # In case of an unexpected message type, log an error and stop
322 |             case _:
323 |                 st.error(f"Unexpected ChatMessage type: {msg.type}")
324 |                 st.write(msg)
325 |                 st.stop()
326 | 
327 | 
328 | async def handle_feedback() -> None:
329 |     """Draws a feedback widget and records feedback from the user."""
330 | 
331 |     # Keep track of last feedback sent to avoid sending duplicates
332 |     if "last_feedback" not in st.session_state:
333 |         st.session_state.last_feedback = (None, None)
334 | 
335 |     latest_run_id = st.session_state.messages[-1].run_id
336 |     feedback = st.feedback("stars", key=latest_run_id)
337 | 
338 |     # If the feedback value or run ID has changed, send a new feedback record
339 |     if feedback is not None and (latest_run_id, feedback) != st.session_state.last_feedback:
340 |         # Normalize the feedback value (an index) to a score between 0 and 1
341 |         normalized_score = (feedback + 1) / 5.0
342 | 
343 |         agent_client: AgentClient = st.session_state.agent_client
344 |         try:
345 |             await agent_client.acreate_feedback(
346 |                 run_id=latest_run_id,
347 |                 key="human-feedback-stars",
348 |                 score=normalized_score,
349 |                 kwargs={"comment": "In-line human feedback"},
350 |             )
351 |         except AgentClientError as e:
352 |             st.error(f"Error recording feedback: {e}")
353 |             st.stop()
354 |         st.session_state.last_feedback = (latest_run_id, feedback)
355 |         st.toast("Feedback recorded", icon=":material/reviews:")
356 | 
357 | 
358 | if __name__ == "__main__":
359 |     asyncio.run(main())
360 | 
--------------------------------------------------------------------------------
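A note on configuration for streamlit_app.py: main() resolves the service URL from AGENT_URL and falls back to HOST/PORT (defaulting to 0.0.0.0:8080) only when AGENT_URL is unset. A hypothetical local .env consistent with that logic might look like the lines below; the keys actually shipped in .env.example may differ.

    # Illustrative values only: point the Streamlit app at a locally running agent service
    AGENT_URL=http://localhost:8080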
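For readers tracing draw_messages(), here is a minimal sketch (not part of the app) of the stream shape it consumes: plain str values are treated as in-progress tokens, an AI ChatMessage with tool_calls must be followed by one "tool" message per call id, and a final AI message carries the full content. The ChatMessage keyword arguments and the schema import path are assumptions inferred from the project layout and the field accesses in the code above.

    from collections.abc import AsyncGenerator

    from schema import ChatMessage  # assumed export, mirroring the app's own imports


    async def example_stream() -> AsyncGenerator[ChatMessage | str, None]:
        # An AI message requesting a tool call; draw_messages() opens an
        # st.status() container keyed by the call id and waits for its result.
        yield ChatMessage(
            type="ai",
            content="",
            tool_calls=[{"name": "calculator", "args": {"expression": "2 + 2"}, "id": "call_1"}],
        )
        # Exactly one "tool" message is expected per tool call id.
        yield ChatMessage(type="tool", content="4", tool_call_id="call_1")
        # Plain strings are rendered as streaming tokens into a shared st.empty()
        # placeholder...
        yield "The answer"
        yield " is 4."
        # ...until the finished AI message arrives and replaces the placeholder.
        yield ChatMessage(type="ai", content="The answer is 4.")

    # Usage mirrors main(): await draw_messages(example_stream(), is_new=True)

In the app itself this stream comes from agent_client.astream(...), while replaying saved history uses the simpler amessage_iter() generator defined in main().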
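On the score normalization in handle_feedback(): st.feedback("stars") returns a 0-based index (0-4), so normalized_score = (feedback + 1) / 5.0 maps 1 star to 0.2 and 5 stars to 1.0, which is the value passed to acreate_feedback().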