├── .gitignore ├── LICENSE ├── README.md ├── assets ├── full_history.png ├── last_message.png └── supervisor.png ├── examples.ipynb ├── llama_index_supervisor ├── __init__.py ├── agent_name.py ├── events.py ├── handoff.py └── supervisor.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | uv.lock 177 | .python-version 178 | 179 | test.cool -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 johnmalek312 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤖 llama-index-supervisor 2 | 3 | A Python library for creating hierarchical multi-agent systems using [LlamaIndex](https://github.com/run-llama/llama_index). Hierarchical systems are a multi-agent architecture in which specialized agents are coordinated by a central supervisor agent. The supervisor controls all communication flow and task delegation, deciding which agent to invoke based on the current context and task requirements. Inspired by [LangGraph Supervisor](https://github.com/langchain-ai/langgraph-supervisor-py). 4 | 5 | ## Features 6 | 7 | - 🤖 **Supervisor Workflow**: Create a supervisor to orchestrate multiple specialized agents or tools using LlamaIndex's workflow engine. 
8 | - 🛠️ **Agent Handoff**: Built-in mechanism for the supervisor to hand off tasks to appropriate agents. 9 | - 📝 **Flexible Message History**: Control how agent interactions are reflected in the overall chat history (`full_history` or `last_message`). 10 | - 🧩 **Custom Agent Integration**: Use custom LlamaIndex `Workflow` classes as agents within the supervisor framework. 11 | - 🌳 **Hierarchical Structures**: Build multi-level agent hierarchies by using supervisors as agents within other supervisors. Supports adding the hierarchy structure to the context (`add_tree_structure`). 12 | - 🏷️ **Agent Name Attribution**: Automatically adds the agent's name to its messages for clarity in the history (`name_addition`). 13 | 14 | ## Installation 15 | 16 | ```bash 17 | pip install llama-index-supervisor 18 | ``` 19 | 20 | Here's a simple example of a supervisor managing two specialized agents (a math expert and a research expert) using LlamaIndex components. The diagram below illustrates the basic architecture: 21 | 22 | ![Supervisor Architecture and Handoff](assets/supervisor.png) 23 | 24 | 25 | ```python 26 | # Ensure you have OPENAI_API_KEY set in your environment 27 | # import os 28 | # os.environ["OPENAI_API_KEY"] = "sk-..." 29 | 30 | from llama_index.llms.openai import OpenAI 31 | from llama_index_supervisor import Supervisor 32 | from llama_index.core.agent.workflow import FunctionAgent 33 | from llama_index.core.tools import FunctionTool 34 | from llama_index.core.workflow import Context 35 | 36 | # Initialize the LLM 37 | llm = OpenAI(model="gpt-4o", temperature=0) 38 | 39 | # Define tools 40 | def add(a: float, b: float) -> float: 41 |     """Add two numbers.""" 42 |     return a + b 43 | 44 | def multiply(a: float, b: float) -> float: 45 |     """Multiply two numbers.""" 46 |     return a * b 47 | 48 | def web_search(query: str) -> str: 49 |     """Search the web for information.""" 50 |     # Replace with a real web search implementation if needed 51 |     print(f"--- Searching web for: {query} ---") 52 |     return ( 53 |         "Here are the headcounts for each of the FAANG companies in 2024:\n" 54 |         "1. **Facebook (Meta)**: 67,317 employees.\n" 55 |         "2. **Apple**: 164,000 employees.\n" 56 |         "3. **Amazon**: 1,551,000 employees.\n" 57 |         "4. **Netflix**: 14,000 employees.\n" 58 |         "5. **Google (Alphabet)**: 181,269 employees." 59 |     ) 60 | 61 | # Create function tools 62 | add_tool = FunctionTool.from_defaults(fn=add) 63 | multiply_tool = FunctionTool.from_defaults(fn=multiply) 64 | search_tool = FunctionTool.from_defaults(fn=web_search) 65 | 66 | # Create specialized agents using LlamaIndex FunctionAgent 67 | math_agent = FunctionAgent( 68 |     tools=[add_tool, multiply_tool], 69 |     llm=llm, 70 |     name="math_expert", 71 |     system_prompt="You are a math expert. Always use one tool at a time.", 72 |     description="Specialized in performing mathematical calculations like addition and multiplication." 73 | ) 74 | 75 | research_agent = FunctionAgent( 76 |     tools=[search_tool], 77 |     llm=llm, 78 |     name="research_expert", 79 |     system_prompt="You are a world class researcher with access to web search. Do not do any math.", 80 |     description="Specialized in searching the web for information."
81 | ) 82 | 83 | # Create supervisor workflow 84 | # The supervisor automatically gets descriptions from the agents 85 | supervisor = Supervisor( 86 |     llm=llm, 87 |     agents=[math_agent, research_agent], 88 |     # Optional: Add a tree structure representation to the system prompt 89 |     add_tree_structure=True, 90 |     # Optional: Set a timeout for the workflow 91 |     timeout=60 92 | ) 93 | 94 | # Run the workflow 95 | # Context manages state, including memory, across workflow steps 96 | ctx = Context(supervisor) 97 | response = await supervisor.run( 98 |     input="what's the combined headcount of the FAANG companies in 2024?", 99 |     ctx=ctx  # Pass the context 100 | ) 101 | 102 | # Print the response directly 103 | print(response) 104 | ``` 105 | 106 | ## Message History Management 107 | 108 | The `output_mode` parameter in the `Supervisor` constructor controls how messages from delegated agents are added back to the supervisor's main chat history: 109 | 110 | - `output_mode="full_history"` (Default): Includes all messages (intermediate steps, tool calls, final response) from the delegated agent's run. 111 | 112 | ![Full History Output](assets/full_history.png) 113 | 114 | - `output_mode="last_message"`: Includes only the final message generated by the delegated agent. 115 | 116 | ![Last Message Output](assets/last_message.png) 117 | 118 | ```python 119 | # Example: Only include the final response from agents 120 | supervisor_last_only = Supervisor( 121 |     llm=llm, 122 |     agents=[math_agent, research_agent], 123 |     output_mode="last_message" 124 | ) 125 | ``` 126 | 127 | ## Hierarchical Supervisors 128 | 129 | You can create multi-level hierarchies by using `Supervisor` instances as agents within another `Supervisor`. Ensure each agent (including supervisors acting as agents) has a unique `name` and a `description`. 130 | 131 | ```python 132 | # Assume math_agent, research_agent, writing_agent, publishing_agent are defined 133 | # (writing_agent and publishing_agent would need tools and descriptions) 134 | 135 | # Create mid-level supervisors 136 | research_supervisor = Supervisor( 137 |     llm=llm, 138 |     agents=[research_agent, math_agent], 139 |     name="research_supervisor", 140 |     description="Manages a research team with experts in web research and math.", 141 |     system_prompt="You manage a research team. Delegate tasks appropriately.", 142 |     timeout=60 143 | ) 144 | 145 | writing_supervisor = Supervisor( 146 |     llm=llm, 147 |     # agents=[writing_agent, publishing_agent], # Add writing/publishing agents here 148 |     name="writing_supervisor", 149 |     description="Manages a content team with experts in writing and publishing.", 150 |     system_prompt="You manage a content team. Delegate tasks appropriately.", 151 |     timeout=60 152 | ) 153 | 154 | # Create top-level supervisor 155 | top_level_supervisor = Supervisor( 156 |     llm=llm, 157 |     agents=[research_supervisor, writing_supervisor], 158 |     name="top_level_supervisor", 159 |     description="Executive supervisor coordinating research and writing teams.", 160 |     system_prompt="You are the executive supervisor. For research/math, use research_supervisor. For content/publishing, use writing_supervisor.", 161 |     timeout=120,  # Increase timeout for multi-level 162 |     add_tree_structure=True  # Useful for complex hierarchies 163 | ) 164 | 165 | # Run the top-level supervisor 166 | ctx_top = Context(top_level_supervisor) 167 | query = "Research the FAANG headcounts for 2024, calculate the total, and then write a brief summary."
response_top = await top_level_supervisor.run( 169 |     input=query, 170 |     ctx=ctx_top, 171 | ) 172 | 173 | # Print the final response 174 | print(response_top) 175 | ``` 176 | 177 | ## Using Custom Agents (Workflows) 178 | 179 | You can define your own agent logic by creating a class that inherits from `llama_index.core.workflow.Workflow`. This custom workflow can then be passed as an agent to the `Supervisor`. 180 | 181 | See the `CoolAgent` example in `examples.ipynb` ([Colab Link](https://colab.research.google.com/drive/1l3hDjXbJn5VrT6jFtZzReUibsN-AEUJ-?usp=sharing)) for a demonstration. The key is that your custom workflow class needs a `name` and a `description`. 182 | 183 | ```python 184 | # Simplified snippet from examples.ipynb 185 | from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step 186 | from llama_index.core.llms import LLM, ChatMessage 187 | from llama_index.core.memory import ChatMemoryBuffer 188 | 189 | class CoolAgent(Workflow): 190 |     def __init__(self, name: str, description: str, llm: LLM, system_prompt: str = "You are a cool agent. Be cool."): 191 |         super().__init__() 192 |         self.name = name 193 |         self.description = description 194 |         self.llm = llm 195 |         self.system_message = ChatMessage(role="system", content=system_prompt) 196 | 197 |     @step 198 |     async def start_flow(self, ctx: Context, ev: StartEvent) -> StopEvent: 199 |         memory: ChatMemoryBuffer = await ctx.get("memory", default=ChatMemoryBuffer.from_defaults(llm=self.llm)) 200 |         # Agent logic here... interacts with memory and LLM 201 |         # Example: Ask how it is doing based on supervisor's prompt 202 |         # (the supervisor's chat history is already available in `memory` at this point) 203 |         await memory.aput(ChatMessage(role="user", content="How are you?"))  # Example interaction 204 |         response = await self.llm.achat([self.system_message] + memory.get()) 205 |         await memory.aput(response.message) 206 |         await ctx.set("memory", memory)  # Update context memory 207 |         return StopEvent(result=response.message)  # Return result 208 | 209 | # Create the custom agent 210 | cool_agent = CoolAgent( 211 |     name="cool_agent", 212 |     description="A cool agent that does cool things.", 213 |     llm=llm, 214 |     system_prompt="You are a cool agent. Be super super cool." 215 | ) 216 | 217 | # Use it in the supervisor 218 | supervisor_with_cool = Supervisor( 219 |     llm=llm, 220 |     agents=[cool_agent]  # Pass the custom agent instance 221 | ) 222 | 223 | # Run it 224 | result = await supervisor_with_cool.run( 225 |     input="Ask the cool agent how it is doing." 226 | ) 227 | print(result) 228 | ``` 229 | 230 | ## Agent Handoff Mechanism 231 | 232 | - The `Supervisor` automatically creates internal "handoff" tools (`transfer_to_<agent_name>`) for each agent provided. 233 | - When the supervisor's LLM decides to delegate a task, it calls the corresponding handoff tool. 234 | - The `Supervisor` intercepts this tool call, prepares the context (including chat history), and runs the designated agent's workflow (`agent.run(ctx=...)`). 235 | - If `add_handoff_back_messages=True` (default), special messages are added to the history when control returns to the supervisor, indicating the handoff completion.
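As a minimal sketch (for illustration only — `create_handoff_tool` is an internal helper in the `llama_index_supervisor.handoff` module, not part of the re-exported package API, and the `Supervisor` normally calls it for you), you can invoke the same factory directly to see what gets generated:

```python
from llama_index_supervisor.handoff import create_handoff_tool

# The same factory the Supervisor uses internally for each managed agent.
handoff_tool = create_handoff_tool(
    agent_name="math_expert",
    agent_description="performing mathematical calculations",
)

print(handoff_tool.metadata.name)
# transfer_to_math_expert
print(handoff_tool.metadata.description)
# Transfers the task to math_expert. This agent is responsible for
# performing mathematical calculations. ...
```

The generated tool's function body is a placeholder; the `Supervisor` intercepts handoff tool calls before the tool itself would ever run. 236 | 237 | ## Adding Memory / Context 238 | 239 | Memory in `llama-index-supervisor` is managed through the `llama_index.core.workflow.Context` object and uses `llama_index.core.memory.ChatMemoryBuffer`. 240 | 241 | - The `Context` is passed to the `supervisor.run()` method.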
242 | - The supervisor automatically retrieves or initializes a `ChatMemoryBuffer` within the context under the key `"memory"`. 243 | - You can pre-populate the memory before running the supervisor: 244 | 245 | ```python 246 | # Snippet from examples.ipynb 247 | from llama_index.core.memory import ChatMemoryBuffer 248 | from llama_index.core.llms import ChatMessage 249 | from llama_index.core.workflow import Context 250 | 251 | # Assume supervisor, math_agent, research_agent are defined 252 | 253 | ctx = Context(supervisor)  # Create context linked to the supervisor workflow 254 | 255 | # Pre-populate memory 256 | memory = ChatMemoryBuffer.from_defaults( 257 |     chat_history=[ 258 |         ChatMessage(role="user", content="What is 2+2?"), 259 |         ChatMessage(role="assistant", content="2 + 2 = 4"),  # Example previous turn 260 |     ] 261 | ) 262 | await ctx.set("memory", memory)  # Set the pre-populated memory in the context 263 | 264 | # Run the supervisor with the context containing pre-populated memory 265 | response = await supervisor.run( 266 |     input="Now multiply that by 3.",  # Follow-up question 267 |     ctx=ctx, 268 | ) 269 | ``` 270 | 271 | ## Contributing 272 | 273 | Contributions are welcome! Please feel free to open issues or submit pull requests for any enhancements, bug fixes, or new features. -------------------------------------------------------------------------------- /assets/full_history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmalek312/llama-index-supervisor/38e1a5a059482c1c77e901522c0a26aaa49503e7/assets/full_history.png -------------------------------------------------------------------------------- /assets/last_message.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmalek312/llama-index-supervisor/38e1a5a059482c1c77e901522c0a26aaa49503e7/assets/last_message.png -------------------------------------------------------------------------------- /assets/supervisor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmalek312/llama-index-supervisor/38e1a5a059482c1c77e901522c0a26aaa49503e7/assets/supervisor.png -------------------------------------------------------------------------------- /examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# This is for logging; ignore this cell if you are not using Arize Phoenix\n", 10 | "import llama_index.core\n", 11 | "llama_index.core.set_global_handler(\"arize_phoenix\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "The following example demonstrates a basic supervisor workflow.\n", 19 | "\n", 20 | "1. **Define Tools**: Functions for `add`, `multiply`, and `web_search` are defined and converted into `FunctionTool` objects.\n", 21 | "2. **Create Agents**: Two specialized `FunctionAgent` instances are created:\n", 22 | "   * `math_agent`: Uses `add_tool` and `multiply_tool`.\n", 23 | "   * `research_agent`: Uses `search_tool`.\n", 24 | "3. **Initialize Supervisor**: A `Supervisor` is created to manage the `math_agent` and `research_agent`. The `add_tree_structure=True` adds json context to the llm about the hierarchy of agents and tools.\n", 25 | "4. 
**Run Workflow**: The supervisor is executed with the input query \"what's the combined headcount of the FAANG companies in 2024?\". The supervisor delegates the task to the appropriate agent(s) based on the query and tool capabilities. First, the `research_agent` is likely called to find the headcounts, and then the `math_agent` is called to sum them up.\n", 26 | "5. **Print Response**: The final response from the supervisor is printed." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 10, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "assistant: The combined headcount of the FAANG companies in 2024 is 1,977,586 employees.\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "from llama_index.llms.openai import OpenAI\n", 44 | "from llama_index_supervisor import Supervisor\n", 45 | "from llama_index.core.agent.workflow import FunctionAgent\n", 46 | "from llama_index.core.tools import FunctionTool\n", 47 | "from llama_index.core.workflow import Context\n", 48 | "\n", 49 | "# Initialize the LLM\n", 50 | "llm = OpenAI(model=\"gpt-4o\", temperature=0)\n", 51 | "\n", 52 | "# Define tools\n", 53 | "def add(a: float, b: float) -> float:\n", 54 | " \"\"\"Add two numbers.\"\"\"\n", 55 | " return a + b\n", 56 | "\n", 57 | "def multiply(a: float, b: float) -> float:\n", 58 | " \"\"\"Multiply two numbers.\"\"\"\n", 59 | " return a * b\n", 60 | "\n", 61 | "def web_search(query: str) -> str:\n", 62 | " \"\"\"Search the web for information.\"\"\"\n", 63 | " return (\n", 64 | " \"Here are the headcounts for each of the FAANG companies in 2024:\\n\"\n", 65 | " \"1. **Facebook (Meta)**: 67,317 employees.\\n\"\n", 66 | " \"2. **Apple**: 164,000 employees.\\n\"\n", 67 | " \"3. **Amazon**: 1,551,000 employees.\\n\"\n", 68 | " \"4. **Netflix**: 14,000 employees.\\n\"\n", 69 | " \"5. **Google (Alphabet)**: 181,269 employees.\"\n", 70 | " )\n", 71 | "\n", 72 | "# Create function tools\n", 73 | "add_tool = FunctionTool.from_defaults(fn=add)\n", 74 | "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", 75 | "search_tool = FunctionTool.from_defaults(fn=web_search)\n", 76 | "\n", 77 | "# Create specialized agents\n", 78 | "math_agent = FunctionAgent(\n", 79 | " name=\"math_expert\",\n", 80 | " llm=llm,\n", 81 | " tools=[add_tool, multiply_tool],\n", 82 | " system_prompt=\"You are a math expert. Always use one tool at a time.\"\n", 83 | ")\n", 84 | "\n", 85 | "research_agent = FunctionAgent(\n", 86 | " name=\"research_expert\",\n", 87 | " llm=llm,\n", 88 | " tools=[search_tool],\n", 89 | " system_prompt=\"You are a world class researcher with access to web search. Do not do any math.\"\n", 90 | ")\n", 91 | "\n", 92 | "# Create supervisor workflow\n", 93 | "supervisor = Supervisor(\n", 94 | " llm=llm,\n", 95 | " agents=[math_agent, research_agent],\n", 96 | " add_tree_structure=True,\n", 97 | " timeout=60\n", 98 | ")\n", 99 | "\n", 100 | "# Run the workflow\n", 101 | "ctx = Context(supervisor)\n", 102 | "response = await supervisor.run(\n", 103 | " input=\"what's the combined headcount of the FAANG companies in 2024?\",\n", 104 | " ctx=ctx\n", 105 | ")\n", 106 | "\n", 107 | "print(str(response))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "This example showcases a hierarchical supervisor structure:\n", 115 | "\n", 116 | "1. 
**Define Tools**: Functions for `add`, `multiply`, `web_search`, `format_document`, and `publish_content` are defined and converted into `FunctionTool` objects.\n", 117 | "2. **Create Base Agents**: Four specialized `FunctionAgent` instances are created: `math_agent`, `research_agent`, `writing_agent`, and `publishing_agent`, each equipped with relevant tools.\n", 118 | "3. **Create Mid-Level Supervisors**: Two `Supervisor` instances are created:\n", 119 | " * `research_supervisor`: Manages `research_agent` and `math_agent`.\n", 120 | " * `writing_supervisor`: Manages `writing_agent` and `publishing_agent`.\n", 121 | "4. **Create Top-Level Supervisor**: A `top_level_supervisor` is created to manage the `research_supervisor` and `writing_supervisor`, establishing a three-tier hierarchy. The `add_tree_structure=True` adds json context to the llm about the hierarchy of agents and tools.\n", 122 | "5. **Run Workflow**: The `top_level_supervisor` is executed with the input query \"what's the combined headcount of the FAANG companies in 2024?\". The top supervisor delegates the task to the appropriate mid-level supervisor (`research_supervisor`), which in turn delegates to the appropriate base agent(s) (`research_agent` then `math_agent`).\n", 123 | "6. **Print Response**: The final response coordinated through the hierarchy is printed." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "assistant: The research expert was unable to find explicit headcount numbers for the FAANG companies in 2024 through web searches. It is recommended to check the latest annual reports or press releases from each company for the most accurate and up-to-date information. 
If you need further assistance or a different approach, please let me know!\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "from llama_index.llms.openai import OpenAI\n", 141 | "from llama_index_supervisor import Supervisor\n", 142 | "from llama_index.core.agent.workflow import FunctionAgent\n", 143 | "from llama_index.core.tools import FunctionTool\n", 144 | "from llama_index.core.workflow import Context\n", 145 | "\n", 146 | "# Initialize the LLM\n", 147 | "llm = OpenAI(model=\"gpt-4o\", temperature=0)\n", 148 | "\n", 149 | "# Define all the required tools\n", 150 | "def add(a: float, b: float) -> float:\n", 151 | " \"\"\"Add two numbers.\"\"\"\n", 152 | " return a + b\n", 153 | "\n", 154 | "def multiply(a: float, b: float) -> float:\n", 155 | " \"\"\"Multiply two numbers.\"\"\"\n", 156 | " return a * b\n", 157 | "\n", 158 | "def web_search(query: str) -> str:\n", 159 | " \"\"\"Search the web for information.\"\"\"\n", 160 | " return \"Web search results for: \" + query\n", 161 | "\n", 162 | "def format_document(content: str) -> str:\n", 163 | " \"\"\"Format document with proper structure and styling.\"\"\"\n", 164 | " return f\"Formatted document: {content}\"\n", 165 | "\n", 166 | "def publish_content(content: str) -> str:\n", 167 | " \"\"\"Publish content to appropriate channels.\"\"\"\n", 168 | " return f\"Published: {content}\"\n", 169 | "\n", 170 | "# Create function tools\n", 171 | "add_tool = FunctionTool.from_defaults(fn=add)\n", 172 | "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", 173 | "search_tool = FunctionTool.from_defaults(fn=web_search)\n", 174 | "format_tool = FunctionTool.from_defaults(fn=format_document)\n", 175 | "publish_tool = FunctionTool.from_defaults(fn=publish_content)\n", 176 | "\n", 177 | "# Create base-level agents\n", 178 | "math_agent = FunctionAgent(\n", 179 | " name=\"math_expert\",\n", 180 | " llm=llm,\n", 181 | " tools=[add_tool, multiply_tool],\n", 182 | " system_prompt=\"You are a math expert. Always use one tool at a time.\"\n", 183 | ")\n", 184 | "\n", 185 | "research_agent = FunctionAgent(\n", 186 | " name=\"research_expert\", \n", 187 | " llm=llm,\n", 188 | " tools=[search_tool],\n", 189 | " system_prompt=\"You are a world class researcher with access to web search. Do not do any math.\"\n", 190 | ")\n", 191 | "\n", 192 | "writing_agent = FunctionAgent(\n", 193 | " name=\"writing_expert\",\n", 194 | " llm=llm,\n", 195 | " tools=[format_tool],\n", 196 | " system_prompt=\"You are a professional writer who formats and improves content.\"\n", 197 | ")\n", 198 | "\n", 199 | "publishing_agent = FunctionAgent(\n", 200 | " name=\"publishing_expert\",\n", 201 | " llm=llm,\n", 202 | " tools=[publish_tool],\n", 203 | " system_prompt=\"You are a publishing expert who knows how to distribute content effectively.\"\n", 204 | ")\n", 205 | "\n", 206 | "# Create mid-level supervisors (research team and writing team)\n", 207 | "research_supervisor = Supervisor(\n", 208 | " llm=llm,\n", 209 | " agents=[research_agent, math_agent],\n", 210 | " name=\"research_supervisor\",\n", 211 | " system_prompt=\"You manage a research team with experts in research and math. Delegate tasks appropriately.\", \n", 212 | " timeout=60\n", 213 | ")\n", 214 | "\n", 215 | "writing_supervisor = Supervisor(\n", 216 | " llm=llm,\n", 217 | " agents=[writing_agent, publishing_agent],\n", 218 | " name=\"writing_supervisor\", \n", 219 | " system_prompt=\"You manage a content team with experts in writing and publishing. 
Delegate tasks appropriately.\",\n", 220 | "    timeout=60\n", 221 | ")\n", 222 | "\n", 223 | "# Create top-level supervisor\n", 224 | "top_level_supervisor = Supervisor(\n", 225 | "    llm=llm,\n", 226 | "    agents=[research_supervisor, writing_supervisor],\n", 227 | "    name=\"top_level_supervisor\",\n", 228 | "    system_prompt=\"You are the executive supervisor coordinating between the research and writing teams. For research or math problems, use the research_supervisor. For content creation and publishing, use the writing_supervisor.\",\n", 229 | "    timeout=60,\n", 230 | "    add_tree_structure=True\n", 231 | ")\n", 232 | "\n", 233 | "query = \"what's the combined headcount of the FAANG companies in 2024?\"\n", 234 | "\n", 235 | "ctx = Context(top_level_supervisor)\n", 236 | "response = await top_level_supervisor.run(\n", 237 | "    input=query,\n", 238 | "    ctx=ctx,\n", 239 | ")\n", 240 | "\n", 241 | "print(response)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 4, 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "data": { 251 | "text/plain": [ 252 | "ChatResponse(message=ChatMessage(role=<MessageRole.ASSISTANT: 'assistant'>, additional_kwargs={}, blocks=[TextBlock(block_type='text', text='The total number of employees across all FAANG companies in 2024 is 1,000,000.')]), raw=ChatCompletionChunk(id='chatcmpl-BJKyiHSvUAIe43ZbrzFn0twZebwKk', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role=None, tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1743948572, model='gpt-4o-2024-08-06', object='chat.completion.chunk', service_tier='default', system_fingerprint='fp_898ac29719', usage=None), delta='', logprobs=None, additional_kwargs={})" 253 | ] 254 | }, 255 | "execution_count": 4, 256 | "metadata": {}, 257 | "output_type": "execute_result" 258 | } 259 | ], 260 | "source": [ 261 | "## Adding chat messages\n", 262 | "from llama_index.llms.openai import OpenAI\n", 263 | "llm = OpenAI(model=\"gpt-4o\", temperature=0)\n", 264 | "from llama_index.core.llms.llm import ChatMessage\n", 265 | "from llama_index.core.memory import ChatMemoryBuffer\n", 266 | "from llama_index.core.workflow import Context\n", 267 | "from llama_index_supervisor import Supervisor\n", 268 | "\n", 269 | "supervisor = Supervisor(llm=llm, agents=[math_agent, research_agent], system_prompt=\"You are a math expert and a world class researcher with access to web search.\")\n", 270 | "ctx = Context(\n", 271 | "    supervisor\n", 272 | ")\n", 273 | "memory = ChatMemoryBuffer.from_defaults(\n", 274 | "    chat_history=[\n", 275 | "        # these messages will be added to the chat history right after the system prompt\n", 276 | "        ChatMessage(role=\"user\", content=\"what's the combined headcount of the FAANG companies in 2024?\"),\n", 277 | "        ChatMessage(role=\"assistant\", content=\"The combined headcount is 1,000,000.\"),\n", 278 | "    ]\n", 279 | ")\n", 280 | "await ctx.set(\"memory\", memory)\n", 281 | "\n", 282 | "await supervisor.run(\n", 283 | "    input=\"Paraphrase your answer.\",\n", 284 | "    ctx=ctx,\n", 285 | ")\n", 286 | "\n", 287 | "# chat history will look like this\n", 288 | "# [\n", 289 | "#     ChatMessage(role=\"system\", content=\"You are a math expert and a world class researcher with access to web search.\"),\n", 290 | "#     ChatMessage(role=\"user\", content=\"what's the combined headcount of the FAANG companies in 2024?\"),\n", 291 | "#     ChatMessage(role=\"assistant\", content=\"The combined headcount is 1,000,000.\"),\n", 292 | "#     ChatMessage(role=\"user\", content=\"Paraphrase your 
answer.\")\n", 293 | "#]" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 1, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "\n", 303 | "from llama_index.llms.openai import OpenAI\n", 304 | "llm = OpenAI(model=\"gpt-4o\", temperature=0)\n", 305 | "from llama_index.core.llms.llm import ChatMessage\n", 306 | "from llama_index.core.memory import ChatMemoryBuffer\n", 307 | "from llama_index.core.workflow import Context\n", 308 | "from llama_index_supervisor import Supervisor\n", 309 | "\n", 310 | "### you can also use your own workflow as an agent and pass it into the supervisor\n", 311 | "from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step\n", 312 | "from llama_index.core.llms import LLM\n", 313 | "\n", 314 | "\n", 315 | "class CoolAgent(Workflow):\n", 316 | " def __init__(self, name: str, llm: LLM, system_prompt: str = \"You are a cool agent. Be cool.\"):\n", 317 | " super().__init__()\n", 318 | " self.name = name\n", 319 | " self.llm = llm\n", 320 | " self.system_message = ChatMessage(role=\"system\", content=system_prompt)\n", 321 | " @step\n", 322 | " async def start_flow(self, ctx: Context, ev: StartEvent) -> StopEvent:\n", 323 | " # the supervisor send a copy of the chat history through the ctx\n", 324 | " memory: ChatMemoryBuffer = await ctx.get(\"memory\", default=ChatMemoryBuffer.from_defaults(llm=self.llm))\n", 325 | " # you can use the memory to get the chat history\n", 326 | " # messages = memory.get() # or get_all()\n", 327 | " # do something with the messages\n", 328 | " await memory.aput(\n", 329 | " ChatMessage(role=\"user\", content=\"How are you?\")\n", 330 | " )\n", 331 | " response = await self.llm.achat([self.system_message] + memory.get()) # dont add the system message to the memory so it doesnt get sent back to the supervisor\n", 332 | " \n", 333 | " # add response to the memory\n", 334 | " await memory.aput(response.message)\n", 335 | " # set ctx(\"memory\") to the updated memory (just in case no memory key in context was passed in the first place)\n", 336 | " await ctx.set(\"memory\", memory)\n", 337 | " return StopEvent(result=response.message) # the supervisor doesn't check the result of the workflow, so you can return anything you want in case you use the agent\n", 338 | " \n", 339 | "\n", 340 | "cool_agent = CoolAgent(\n", 341 | " name=\"cool_agent\",\n", 342 | " llm=llm,\n", 343 | " system_prompt=\"You are a cool agent. 
Be super super cool.\"\n", 344 | ")\n", 345 | "\n", 346 | "supervisor = Supervisor(\n", 347 | "    llm=llm,\n", 348 | "    agents=[cool_agent]\n", 349 | ")\n", 350 | "\n", 351 | "result = await supervisor.run(\n", 352 | "    input=\"Ask the cool agent how it is doing.\"\n", 353 | ")\n" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 2, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "assistant: The cool agent says it's \"doing as cool as a cucumber in a bowl of hot sauce!\" How about you?\n" 366 | ] 367 | } 368 | ], 369 | "source": [ 370 | "print(result)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [] 379 | } 380 | ], 381 | "metadata": { 382 | "kernelspec": { 383 | "display_name": "Python 3", 384 | "language": "python", 385 | "name": "python3" 386 | }, 387 | "language_info": { 388 | "codemirror_mode": { 389 | "name": "ipython", 390 | "version": 3 391 | }, 392 | "file_extension": ".py", 393 | "mimetype": "text/x-python", 394 | "name": "python", 395 | "nbconvert_exporter": "python", 396 | "pygments_lexer": "ipython3", 397 | "version": "3.12.4" 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 2 402 | } 403 | -------------------------------------------------------------------------------- /llama_index_supervisor/__init__.py: -------------------------------------------------------------------------------- 1 | from .supervisor import Supervisor 2 | 3 | __all__ = ["Supervisor"] 4 | -------------------------------------------------------------------------------- /llama_index_supervisor/agent_name.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Literal 3 | 4 | from llama_index.core.llms import ChatMessage 5 | from llama_index.core.llms.llm import LLM 6 | from llama_index.core.llms.function_calling import FunctionCallingLLM 7 | NAME_PATTERN = re.compile(r"<name>(.*?)</name>", re.DOTALL) 8 | CONTENT_PATTERN = re.compile(r"<content>(.*?)</content>", re.DOTALL) 9 | 10 | AgentNameMode = Literal["inline"] 11 | 12 | def add_inline_agent_name(message: ChatMessage, name: str): 13 |     """Add name and content XML tags to the message content. 14 | 15 |     Examples: 16 | 17 |         >>> add_inline_agent_name(ChatMessage(role="assistant", content="Hello"), "assistant") 18 |         # message.content becomes "<name>assistant</name><content>Hello</content>" 19 | 20 |         >>> add_inline_agent_name(ChatMessage(role="assistant", content="Hi"), "math_expert") 21 |         # message.content becomes "<name>math_expert</name><content>Hi</content>" 22 |     """ 23 |     if not isinstance(message, ChatMessage): 24 |         return message 25 |     message.content = ( 26 |         f"<name>{name}</name><content>{message.content}</content>" 27 |     ) 28 | 29 | 30 | def remove_inline_agent_name(message: ChatMessage) -> ChatMessage: 31 |     """Remove explicit name and content XML tags from the message content.
32 | 33 |     Examples: 34 | 35 |         >>> remove_inline_agent_name(ChatMessage(role="assistant", content="<name>assistant</name><content>Hello</content>", name="assistant")) 36 |         ChatMessage(role="assistant", content="Hello", name="assistant") 37 | 38 |         >>> remove_inline_agent_name(ChatMessage(role="assistant", content="<name>bot</name><content>Hi</content>", name="bot")) 39 |         ChatMessage(role="assistant", content="Hi", name="bot") 40 |     """ 41 |     if not isinstance(message, ChatMessage): 42 |         return message 43 | 44 | 45 |     content = message.content or "" 46 | 47 |     name_match: re.Match | None = NAME_PATTERN.search(content) 48 |     content_match: re.Match | None = CONTENT_PATTERN.search(content) 49 |     if not name_match or not content_match: 50 |         return message 51 | 52 |     if name_match.group(1) != message.name: 53 |         return message 54 | 55 |     parsed_content = content_match.group(1) 56 |     parsed_message = message.model_copy() 57 |     parsed_message.content = parsed_content 58 |     return parsed_message 59 | -------------------------------------------------------------------------------- /llama_index_supervisor/events.py: -------------------------------------------------------------------------------- 1 | from llama_index.core.llms import ChatMessage 2 | from llama_index.core.tools import ToolSelection, ToolOutput 3 | from llama_index.core.workflow import Event 4 | 5 | 6 | class InputEvent(Event): 7 |     input: list[ChatMessage] 8 | 9 | 10 | class StreamEvent(Event): 11 |     delta: str 12 | 13 | 14 | class ToolCallEvent(Event): 15 |     tool_calls: list[ToolSelection] 16 | 17 | 18 | class FunctionOutputEvent(Event): 19 |     output: ToolOutput -------------------------------------------------------------------------------- /llama_index_supervisor/handoff.py: -------------------------------------------------------------------------------- 1 | import re 2 | import uuid 3 | from llama_index.core.tools import FunctionTool 4 | from llama_index.core.workflow import Context 5 | from llama_index.core.llms import ChatMessage, MessageRole 6 | 7 | 8 | WHITESPACE_RE = re.compile(r"\s+") 9 | 10 | HANDOFF_TOOL_DESCRIPTION = ( 11 |     "Transfers the task to {agent_name}. This agent is responsible for {agent_description}. " 12 |     "Use this tool when the task falls within the agent’s expertise or when delegation is necessary for better task execution. " 13 |     "Provide a clear reason for the transfer and describe the task in detail to ensure smooth handover." 14 | ) 15 | 16 | def _normalize_agent_name(agent_name: str) -> str: 17 |     """Normalize an agent name to be used inside the tool name.""" 18 |     return WHITESPACE_RE.sub("_", agent_name.strip()).lower() 19 | 20 | 21 | def create_handoff_tool(agent_name: str, agent_description: str) -> FunctionTool: 22 |     """Create a tool that can hand off control to the requested agent. 23 | 24 |     Args: 25 |         agent_name: The name of the agent to hand off control to, i.e. 26 |             the name of the agent node in the multi-agent hierarchy. 27 |             Agent names should be simple, clear and unique, preferably in snake_case, 28 |             although you are ultimately limited only by the tool names 29 |             accepted by LLM providers 30 |             (the tool name will look like this: `transfer_to_<agent_name>`). 31 |         agent_description: A description of the agent's responsibilities and expertise.
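            For example, `create_handoff_tool("Math Expert", "doing math")` yields a tool named `transfer_to_math_expert`.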
32 | """ 33 | tool_name = f"transfer_to_{_normalize_agent_name(agent_name)}" 34 | if not agent_description: 35 | #Ask agent for help 36 | agent_description = "Ask agent {} for help.".format(agent_name) 37 | 38 | def handoff_to_agent(ctx: Context, task: str, reason: str) -> str: 39 | return # filler function 40 | 41 | return FunctionTool.from_defaults(fn=handoff_to_agent, name=tool_name, description=HANDOFF_TOOL_DESCRIPTION.format( 42 | agent_name=agent_name, agent_description=agent_description)) 43 | 44 | 45 | def create_handoff_back_messages( 46 | agent_name: str, supervisor_name: str 47 | ) -> tuple[ChatMessage]: 48 | """Create a pair of (AIMessage, ToolMessage) to add to the message history when returning control to the supervisor.""" 49 | tool_call_id = f"call_{uuid.uuid4().hex[:8]}" 50 | tool_name = f"transfer_back_to_{_normalize_agent_name(supervisor_name)}" 51 | tool_calls = [ 52 | { 53 | "id": tool_call_id, 54 | "function": {"name": tool_name, "arguments": "{}"}, 55 | "type": "function", 56 | } 57 | ] 58 | return ( 59 | ChatMessage( 60 | role=MessageRole.ASSISTANT, 61 | content=f"Transferring back to {supervisor_name}", 62 | additional_kwargs={"tool_calls": tool_calls}, 63 | name=agent_name, 64 | ), 65 | ChatMessage( 66 | tool_call_id=tool_call_id, 67 | role="tool", 68 | content=f"Successfully transferred back to {supervisor_name}", 69 | additional_kwargs={ 70 | "tool_call_id": tool_call_id, 71 | "name": tool_name, 72 | }, 73 | ), 74 | ) 75 | -------------------------------------------------------------------------------- /llama_index_supervisor/supervisor.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from llama_index.core.llms.function_calling import FunctionCallingLLM 4 | from llama_index.core.memory import ChatMemoryBuffer 5 | from llama_index.core.llms import ChatMessage 6 | from llama_index.core.tools.types import BaseTool 7 | from llama_index.core.workflow import ( 8 | Context, 9 | Workflow, 10 | StartEvent, 11 | StopEvent, 12 | step, 13 | ) 14 | from .events import InputEvent, StreamEvent, ToolCallEvent 15 | from llama_index.core.agent.workflow import BaseWorkflowAgent 16 | from .handoff import ( 17 | _normalize_agent_name, 18 | create_handoff_tool, 19 | create_handoff_back_messages, 20 | ) 21 | from .agent_name import add_inline_agent_name 22 | import json 23 | import re 24 | DEFAULT_SYSTEM_PROMPT = ( 25 | "You are a {agent_name}. You will be responsible for managing the workflow of other agents and tools. " 26 | "You will receive user input and delegate tasks to the appropriate agents or tools. " 27 | "You will also handle the responses from the agents and tools, and provide feedback to the user. " 28 | ) 29 | 30 | DEFAULT_DESCRIPTION = ( 31 | "This agent is responsible for managing the following agents and tools. " 32 | "Agents: {agents}. " 33 | "Tools: {tools}. " 34 | ) 35 | 36 | TREE_STRUCTURE_PROMPT = ( 37 | "The following is a tree structure of the agents and tools in the workflow. " 38 | "The tree structure is as follows:\n\n{tree_structure}\n\n" 39 | ) 40 | ## This is an event driven workflow agent, functions that are decorated with @step return an event that is passed to the next step in the workflow. 41 | ## So the moment an event is returned, the workflow manager will pick it up and pass it to the next step which takes the event as input. 
42 | class Supervisor(Workflow): 43 |     def __init__( 44 |         self, 45 |         llm: FunctionCallingLLM, 46 |         agents: list[BaseWorkflowAgent | Workflow] = [], 47 |         tools: list[BaseTool] = [], 48 |         name: str = "supervisor", 49 |         system_prompt: str | None = None, 50 |         add_handoff_back_messages: bool = True, 51 |         output_mode: str = "full_history", 52 |         description: str | None = None, 53 |         add_tree_structure: bool = False, 54 |         name_addition: bool = True, 55 |         *args: Any, 56 |         **kwargs: Any, 57 |     ) -> None: 58 |         """Initialize the Supervisor agent. 59 |         Args: 60 |             llm: The LLM to use for the supervisor agent. 61 |             agents: A list of agents to manage. 62 |             tools: A list of tools to use. 63 |             name: The name of the supervisor agent. (sent to LLM as part of the system prompt) 64 |             system_prompt: The system prompt to use for the supervisor agent. (defaults to DEFAULT_SYSTEM_PROMPT) 65 |             add_handoff_back_messages: Whether to add handoff back messages to the chat history. 66 |             output_mode: The output mode for the supervisor agent. Can be either 'full_history' or 'last_message'. 67 |             description: The description of the supervisor agent. (sent to the LLM as part of the function description if this supervisor is used as an agent by another supervisor; defaults to DEFAULT_DESCRIPTION) 68 |             add_tree_structure: Whether to add a tree structure to the context to give the LLM more context about the agents and tools. 69 |             name_addition: Whether to add the name of the agent that the message belongs to in the message. (defaults to True) 70 |         """ 71 |         super().__init__(*args, **kwargs) 72 |         self.validate_agents(agents) 73 |         assert ( 74 |             llm.metadata.is_function_calling_model 75 |         ), "Supervisor only supports function calling LLMs" 76 |         assert output_mode in [ 77 |             "full_history", 78 |             "last_message", 79 |         ], "output_mode must be either 'full_history' or 'last_message'" 80 |         assert ( 81 |             len(agents) + len(tools) > 0 82 |         ), "At least one agent or tool must be provided" 83 | 84 |         # Initialize core attributes 85 |         self.name = name 86 |         self.llm = llm 87 |         if not system_prompt: 88 |             system_prompt = DEFAULT_SYSTEM_PROMPT.format(agent_name=name) 89 |         if isinstance(system_prompt, str): 90 |             self.system_prompt = [ChatMessage(role="system", content=system_prompt)] 91 |         elif isinstance(system_prompt, ChatMessage): 92 |             self.system_prompt = [system_prompt] 93 |         elif isinstance(system_prompt, list): 94 |             # Accept a list of strings, ChatMessages, or a mix of both 95 |             self.system_prompt = [ 96 |                 sp if isinstance(sp, ChatMessage) else ChatMessage(role="system", content=sp) 97 |                 for sp in system_prompt 98 |             ] 99 |         self.add_handoff_back_messages = add_handoff_back_messages 100 |         self.output_mode = output_mode 101 | 102 |         # Initialize tools and agents 103 |         self.tools = tools or [] 104 |         self.agents = agents 105 | 106 |         # Setup agents and tools 107 |         self._setup_agents() 108 |         self._setup_tools() 109 | 110 |         # Set up the description 111 |         self.description = description or DEFAULT_DESCRIPTION.format( 112 |             agents=", ".join([name for name in self.agent_names]), 113 |             tools=", ".join([name for name in self.tools_by_name.keys()]), 114 |         ) 115 |         self.name_addition = name_addition 116 |         self.add_tree_structure = add_tree_structure 117 |         self.tree_dict = {} 118 |         if add_tree_structure: 119 |             self.tree_dict = self._build_agent_tool_tree(self) 120 | 121 |     def _build_agent_tool_tree(self, entity: BaseWorkflowAgent | Workflow) -> dict[str, Any]: 122 |         """ 123 |         Recursively build the tree structure of agents and tools.
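        Example shape, assuming a supervisor named "supervisor" that owns one plain `web_search` tool and manages a `math_expert` agent with `add`/`multiply` tools:
            {"supervisor": {"tools": ["web_search"], "agents": {"math_expert": {"tools": ["add", "multiply"]}}}}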
124 | 125 |         Args: 126 |             entity: The current agent or supervisor instance. 127 | 128 |         Returns: 129 |             A dictionary representing the tree structure, potentially wrapped 130 |             with the top-level agent's name if entity is self. 131 |         """ 132 |         subtree: dict[str, Any] = {} 133 | 134 |         # Get tools directly associated with the current entity (excluding agent handoff tools for supervisors) 135 |         entity_tools = [] 136 |         if hasattr(entity, 'tools'): 137 |             # Ensure agent_tools exists before trying to access it, default to empty list 138 |             agent_tools = getattr(entity, 'agent_tools', []) 139 |             # Ensure metadata and name exist before accessing 140 |             agent_tool_names = set( 141 |                 getattr(getattr(at, 'metadata', None), 'name', None) 142 |                 for at in agent_tools 143 |             ) 144 |             agent_tool_names.discard(None)  # Remove None if metadata/name was missing 145 | 146 |             entity_tools = [ 147 |                 getattr(getattr(tool, 'metadata', None), 'name', None) 148 |                 for tool in entity.tools 149 |                 if hasattr(tool, 'metadata') and ( 150 |                     not isinstance(entity, Supervisor) or 151 |                     getattr(getattr(tool, 'metadata', None), 'name', None) not in agent_tool_names 152 |                 ) 153 |             ] 154 |             # Filter out None names if metadata or name was missing 155 |             entity_tools = [name for name in entity_tools if name] 156 | 157 |         if entity_tools: 158 |             subtree["tools"] = sorted(entity_tools)  # Sort for consistent output 159 | 160 |         # Get agents associated with the current entity 161 |         if hasattr(entity, 'agents') and entity.agents: 162 |             subtree["agents"] = {} 163 |             for agent in entity.agents: 164 |                 # Check if agent has a name attribute before using it as a key 165 |                 agent_name = getattr(agent, 'name', None) 166 |                 if agent_name: 167 |                     # Recursively build the tree for sub-agents 168 |                     # (sub-supervisors recurse into their own agents and tools) 169 |                     subtree["agents"][agent_name] = self._build_agent_tool_tree(agent) 170 | 171 |         # If the entity being processed is the top-level 'self', wrap the result 172 |         if entity is self: 173 |             # Ensure self has a name, provide a default if not 174 |             self_name = getattr(self, 'name', 'root_agent')  # Use a default name if needed 175 |             return {self_name: subtree} 176 |         else: 177 |             # Otherwise, return the subtree directly for recursive calls 178 |             return subtree 179 | 180 | 181 |     def _setup_agents(self) -> None: 182 |         """Register and initialize all agents.""" 183 | 184 |         self.agent_names = set() 185 |         self.agents_by_name = {} 186 | 187 |         for agent in self.agents: 188 |             normalized_name = _normalize_agent_name(agent.name) 189 |             if normalized_name in self.agents_by_name: 190 |                 raise ValueError( 191 |                     f"Duplicate agent name found: {normalized_name}. Agent names must be unique." 192 |                 ) 193 |             self.agent_names.add(normalized_name) 194 |             self.agents_by_name[normalized_name] = agent 195 |     def validate_agents(self, agents): 196 |         for agent in agents: 197 |             # Every agent must expose a non-empty string name and a callable run() method 198 |             assert hasattr(agent, "name") and agent.name, f"Agent {agent} is missing 'name' or name is empty" 199 |             assert isinstance(agent.name, str) and agent.name.strip(), f"Agent {agent} has an invalid 'name'. Name must be a non-empty string."
200 |             assert hasattr(agent, "run") and callable(getattr(agent, "run")), f"Agent {agent} is missing 'run()' method" 201 | 202 |     def _setup_tools(self) -> None: 203 |         """Create tools for agents and register all tools.""" 204 |         # Create agent handoff tools 205 |         self.agent_tools = [create_handoff_tool(agent.name, agent.description if hasattr(agent, "description") else "") for agent in self.agents] 206 |         self.tools.extend(self.agent_tools) 207 | 208 |         # Create lookup dictionaries 209 |         self.tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools} 210 |         self.agents_by_tool_name = { 211 |             tool.metadata.get_name(): tool for tool in self.agent_tools 212 |         } 213 | 214 |     @step 215 |     async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent: 216 |         """Prepare chat history from user input.""" 217 | 218 |         # Get or create memory 219 |         memory: ChatMemoryBuffer = await ctx.get( 220 |             "memory", default=ChatMemoryBuffer.from_defaults(llm=self.llm) 221 |         ) 222 |         user_input = ev.get("input", default=None) 223 |         assert len(memory.get_all()) > 0 or user_input, "Either an input or a non-empty memory must be provided." 224 |         if self.add_tree_structure: 225 |             # Add tree structure to memory 226 |             await memory.aput( 227 |                 ChatMessage(role="system", content=TREE_STRUCTURE_PROMPT.format(tree_structure=json.dumps(self.tree_dict, indent=2))) 228 |             ) 229 |         # Add user input to memory 230 |         if user_input: 231 |             await memory.aput(ChatMessage(role="user", content=user_input)) 232 |         # Update context 233 |         await ctx.set("memory", memory) 234 |         input_messages = memory.get() 235 |         return InputEvent(input=input_messages) 236 | 237 |     @step 238 |     async def handle_llm_input( 239 |         self, ctx: Context, ev: InputEvent 240 |     ) -> ToolCallEvent | StopEvent: 241 |         """Process input through LLM and handle streaming response.""" 242 | 243 |         chat_history = ev.input 244 | 245 |         # Stream response from LLM 246 |         response = await self._get_llm_response(ctx, chat_history) 247 | 248 |         # Save the final response 249 |         memory = await ctx.get("memory") 250 | 251 |         # Check for tool calls 252 |         tool_calls = self.llm.get_tool_calls_from_response( 253 |             response, error_on_no_tool_call=False 254 |         ) 255 |         if not tool_calls: 256 |             message = response.message.model_copy(deep=True)  # copy so the name tag isn't added to the returned response 257 |         else: 258 |             message = response.message 259 |         add_inline_agent_name(message, self.name) 260 |         await memory.aput(message) 261 |         await ctx.set("memory", memory) 262 | 263 |         if not tool_calls: 264 |             return StopEvent(result=response) 265 | 266 |         return ToolCallEvent(tool_calls=tool_calls) 267 | 268 |     async def _get_llm_response(self, ctx: Context, chat_history): 269 |         """Get streaming response from LLM.""" 270 |         response_stream = await self.llm.astream_chat_with_tools( 271 |             self.tools, chat_history=self.system_prompt + chat_history 272 |         ) 273 |         response = None 274 |         async for response in response_stream: 275 |             ctx.write_event_to_stream(StreamEvent(delta=response.delta or "")) 276 | 277 |         return response 278 | 279 |     @step 280 |     async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent: 281 |         """Handle tool calls and agent handoffs.""" 282 |         tool_calls = ev.tool_calls 283 | 284 |         # Split agent handoffs from regular tool calls 285 |         agent_handoffs, regular_tools = self._split_tool_calls(tool_calls) 286 | 287 |         # Process all tool calls 288 |         tool_msgs = [] 289 |         await self._process_regular_tools(regular_tools, tool_msgs) 290 |         await self._process_agent_handoffs(ctx, agent_handoffs, tool_msgs) 291 | 292 |         # Update memory and return input event 293 | 
await self._update_memory(ctx, tool_msgs) 294 | return await self._get_input_event(ctx) 295 | 296 | def _split_tool_calls(self, tool_calls): 297 | """Split tool calls into agent handoffs and regular tools.""" 298 | agent_handoffs = [ 299 | tc 300 | for tc in tool_calls 301 | if any(tc.tool_name == at.metadata.name for at in self.agent_tools) 302 | ] 303 | regular_tools = [tc for tc in tool_calls if tc not in agent_handoffs] 304 | return agent_handoffs, regular_tools 305 | 306 | async def _process_regular_tools(self, regular_tools, tool_msgs: list) -> None: 307 | """Process regular tool calls.""" 308 | for tool_call in regular_tools: 309 | tool_name = tool_call.tool_name 310 | 311 | additional_kwargs = { 312 | "tool_call_id": tool_call.tool_id, 313 | "name": tool_name, 314 | } 315 | 316 | if not (tool := self.tools_by_name.get(tool_name)): 317 | 318 | tool_msgs.append( 319 | self._create_tool_error_message( 320 | f"Tool {tool_name} does not exist", additional_kwargs 321 | ) 322 | ) 323 | continue 324 | 325 | try: 326 | tool_output = tool(**tool_call.tool_kwargs) 327 | 328 | tool_msgs.append( 329 | ChatMessage( 330 | role="tool", 331 | content=tool_output.content, 332 | additional_kwargs=additional_kwargs, 333 | ) 334 | ) 335 | except Exception as e: 336 | 337 | tool_msgs.append( 338 | self._create_tool_error_message( 339 | f"Encountered error in tool call: {e}", additional_kwargs 340 | ) 341 | ) 342 | 343 | def _create_tool_error_message( 344 | self, content: str, kwargs: dict[str, Any] 345 | ) -> ChatMessage: 346 | """Create a tool error message.""" 347 | return ChatMessage( 348 | role="tool", 349 | content=content, 350 | additional_kwargs=kwargs, 351 | ) 352 | 353 | async def _process_agent_handoffs( 354 | self, ctx: Context, agent_handoffs, tool_msgs: list 355 | ) -> None: 356 | """Process agent handoff tool calls.""" 357 | if len(agent_handoffs) > 1: 358 | # Multiple handoffs - return error 359 | 360 | handoff_names = [h.tool_name for h in agent_handoffs] 361 | 362 | for handoff in agent_handoffs: 363 | tool_msgs.append( 364 | ChatMessage( 365 | role="tool", 366 | content=f"Multiple agent handoff tools selected: {', '.join(handoff_names)} - please select only one.", 367 | additional_kwargs={ 368 | "tool_call_id": handoff.tool_id, 369 | "name": handoff.tool_name, 370 | }, 371 | ) 372 | ) 373 | elif len(agent_handoffs) == 1: 374 | # Process single handoff 375 | await self._process_agent_handoff(ctx, agent_handoffs[0], tool_msgs) 376 | 377 | async def _process_agent_handoff( 378 | self, ctx: Context, handoff, tool_msgs: list 379 | ) -> None: 380 | """Process a single agent handoff.""" 381 | handoff_agent = handoff.tool_name.removeprefix("transfer_to_") 382 | 383 | agent = self.agents_by_name.get(handoff_agent) 384 | if not agent: 385 | 386 | tool_msgs.append( 387 | ChatMessage( 388 | role="tool", 389 | content=f"Agent {handoff.tool_name} does not exist", 390 | additional_kwargs={ 391 | "tool_call_id": handoff.tool_id, 392 | "name": handoff.tool_name, 393 | }, 394 | ) 395 | ) 396 | return 397 | 398 | # Extract handoff parameters 399 | parameters = handoff.tool_kwargs 400 | task = parameters.get("task") 401 | reason = parameters.get("reason") 402 | 403 | # Add success message 404 | tool_msgs.append( 405 | ChatMessage( 406 | role="tool", 407 | content=f"Transitioned to {agent.name}. 
Your task is: `{task}`, reason: `{reason}`", 408 |                 additional_kwargs={ 409 |                     "tool_call_id": handoff.tool_id, 410 |                     "name": handoff.tool_name, 411 |                 }, 412 |             ) 413 |         ) 414 |         # this adds tool_msgs to the memory and clears tool_msgs 415 |         await self._update_memory(ctx, tool_msgs) 416 | 417 |         # Run the agent 418 |         await self._run_agent(ctx, agent) 419 |         # Add handoff back messages if needed 420 |         if self.add_handoff_back_messages: 421 |             handoff_messages = create_handoff_back_messages( 422 |                 agent_name=agent.name, supervisor_name=self.name 423 |             ) 424 |             tool_msgs.extend(handoff_messages) 425 |             # this adds handoff_messages to the memory 426 |             await self._update_memory(ctx, tool_msgs) 427 | 428 |     async def _run_agent( 429 |         self, ctx: Context, agent: BaseWorkflowAgent | Workflow 430 |     ) -> None: 431 |         """Run an agent with the current context.""" 432 | 433 |         new_ctx = Context(agent) 434 | 435 |         memory: ChatMemoryBuffer = await ctx.get("memory") 436 |         new_memory = memory.model_copy() 437 | 438 |         # Create a new chat_store instance (assuming it has a copy method or constructor) 439 |         new_memory.chat_store = memory.chat_store.model_copy()  # or manually copy if needed 440 |         # Shallow copy the lists in store 441 |         new_memory.chat_store.store = {key: value[:] for key, value in memory.chat_store.store.items()} 442 | 443 |         await new_ctx.set("memory", new_memory) 444 | 445 |         # Run the agent 446 |         await agent.run(ctx=new_ctx, chat_history=new_memory.get()) 447 |         new_memory = await new_ctx.get("memory") 448 |         if self.name_addition: 449 |             self._add_name_to_messages(new_memory.get_all(), agent, start_range=len(memory.get_all())) 450 | 451 |         # Update supervisor memory with agent's memory 452 |         if self.output_mode == "full_history": 453 |             memo: ChatMemoryBuffer = new_memory 454 |             await ctx.set("memory", memo) 455 |         elif self.output_mode == "last_message": 456 |             memo: ChatMemoryBuffer = await new_ctx.get("memory") 457 |             last_message = memo.get_all()[-1] 458 |             await memory.aput(last_message) 459 |     def _add_name_to_messages(self, messages: list[ChatMessage], agent: BaseWorkflowAgent, start_range: int) -> None: 460 |         """Add agent name to messages.""" 461 |         for message in messages[start_range:]: 462 |             # handle None message.content inside the if statement 463 |             if message.role == "assistant" and not re.search(r"<name>.*?</name><content>.*?</content>", message.content or ""): 464 |                 add_inline_agent_name(message, agent.name) 465 | 466 |     async def _update_memory(self, ctx: Context, messages: list[ChatMessage]) -> None: 467 |         """Update memory with the provided messages.""" 468 |         if not messages: 469 |             return 470 |         memory = await ctx.get("memory") 471 |         for msg in messages: 472 |             await memory.aput(msg) 473 |         messages.clear()  # Empty the list after processing 474 |         await ctx.set("memory", memory) 475 | 476 |     async def _get_input_event(self, ctx: Context) -> InputEvent: 477 |         """Get an input event from the current memory.""" 478 |         memory = await ctx.get("memory") 479 |         return InputEvent(input=memory.get()) 480 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "llama_index_supervisor" 3 | version = "0.1.11" 4 | description = "This project is a Python-based multi-agent supervisor inspired by LangGraph, adapted for llama_index. It delegates tasks to specialized agents and tools, processes function-calling LLM responses, and manages conversation history. 
The structured workflow handles tool calls, processes agent handoffs with error handling, and provides flexible message management for hierarchical multi-agent systems." 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "llama-index-core>=0.12.27", 9 | ] 10 | 11 | [tool.setuptools] 12 | packages = ["llama_index_supervisor"] --------------------------------------------------------------------------------