├── .gitignore
├── LICENSE
├── README.md
├── TinyAgent.zip
├── figs
│   └── tinyagent.png
├── requirements.txt
├── run_tiny_agent_server.py
└── src
    ├── agents
    │   ├── agent.py
    │   ├── structured_chat_agent.py
    │   └── tools.py
    ├── callbacks
    │   └── callbacks.py
    ├── chains
    │   ├── chain.py
    │   └── llm_chain.py
    ├── executors
    │   ├── agent_executor.py
    │   └── schema.py
    ├── llm_compiler
    │   ├── constants.py
    │   ├── llm_compiler.py
    │   ├── output_parser.py
    │   ├── planner.py
    │   └── task_fetching_unit.py
    ├── tiny_agent
    │   ├── computer.py
    │   ├── config.py
    │   ├── models.py
    │   ├── prompts.py
    │   ├── run_apple_script.py
    │   ├── sub_agents
    │   │   ├── compose_email_agent.py
    │   │   ├── notes_agent.py
    │   │   ├── pdf_summarizer_agent.py
    │   │   └── sub_agent.py
    │   ├── tiny_agent.py
    │   ├── tiny_agent_tools.py
    │   ├── tool_rag
    │   │   ├── base_tool_rag.py
    │   │   ├── classifier_tool_rag.py
    │   │   ├── simple_tool_rag.py
    │   │   └── text-embedding-3-small
    │   │       └── embeddings.pkl
    │   ├── tools
    │   │   ├── calendar.py
    │   │   ├── contacts.py
    │   │   ├── mail.py
    │   │   ├── maps.py
    │   │   ├── notes.py
    │   │   ├── reminders.py
    │   │   ├── sms.py
    │   │   ├── spotlight_search.py
    │   │   └── zoom.py
    │   └── transcription.py
    ├── tools
    │   └── base.py
    └── utils
        ├── data_utils.py
        ├── graph_utils.py
        ├── logger_utils.py
        ├── model_utils.py
        └── plan_utils.py

/.gitignore:
--------------------------------------------------------------------------------
.vs/
.vscode/
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
.venvs
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# macOS display setting files
.DS_Store

# Wandb directory
wandb/

# asdf tool versions
.tool-versions
/.ruff_cache/

*.pkl
*.bin

# integration test artifacts
data_map*
\[('_type', 'fake'), ('stop', None)]

# Replit files
*replit*

node_modules
docs/.yarn/
docs/node_modules/
docs/.docusaurus/
docs/.cache-loader/
docs/_dist
docs/api_reference/api_reference.rst
docs/api_reference/experimental_api_reference.rst
docs/api_reference/_build
docs/api_reference/*/
!docs/api_reference/_static/
!docs/api_reference/templates/
!docs/api_reference/themes/
docs/docs_skeleton/build
docs/docs_skeleton/node_modules
docs/docs_skeleton/yarn.lock

*.ipynb

*.xcuserstate
project.xcworkspace/
xcuserdata/
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 SqueezeAILab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TinyAgent: Function Calling at the Edge

[![Arxiv](https://img.shields.io/badge/arXiv-2409.00608-B31B1B.svg)][#arxiv-paper-package]
[![GitHub license](https://img.shields.io/badge/License-MIT-blue.svg)][#license-gh-package]

[#license-gh-package]: https://lbesson.mit-license.org/
[#arxiv-paper-package]: https://arxiv.org/abs/2409.00608

<div align="center">

Get the desktop app
|
<a href="https://bair.berkeley.edu/blog/2024/05/29/tiny-agent">Read the blog post</a>
|
<a href="https://arxiv.org/abs/2409.00608">Read the paper</a>
</div>


![Thumbnail](figs/tinyagent.png)

TinyAgent aims to enable complex reasoning and function calling capabilities in Small Language Models (SLMs) that can be deployed securely and privately at the edge. Traditional Large Language Models (LLMs) like GPT-4 and Gemini-1.5, while powerful, are often too large and resource-intensive for edge deployment, posing challenges in terms of privacy, connectivity, and latency. TinyAgent addresses these challenges by training specialized SLMs with high-quality, curated data, and focusing on function calling with [LLMCompiler](https://github.com/SqueezeAILab/LLMCompiler). As a driving application, TinyAgent can interact with various macOS applications, assisting users with day-to-day tasks such as composing emails, managing contacts, scheduling calendar events, and organizing Zoom meetings.

## Demo

TinyAgent Demo

## What can TinyAgent do?

TinyAgent is equipped with 16 different functions that can interact with different applications on Mac, which include:

### 📧 Mail

- **Compose New Email**
  - Start a new email with options for adding recipients and attachments.
  - _Example Query:_ “Email Sid and Nick about the meeting with attachment project.pdf.”
- **Reply to Emails**
  - Respond to received emails, optionally adding new recipients and attachments.
  - _Example Query:_ “Reply to Alice's email with the updated budget document attached.”
- **Forward Emails**
  - Forward an existing email to other contacts, including additional attachments if necessary.
  - _Example Query:_ “Forward the project briefing to the marketing team.”

### 📇 Contacts

- Retrieve phone numbers and email addresses from the contacts database.
  - _Example Query:_ “Get John’s phone number” or “Find Alice’s email address.”

### 📨 SMS

- Send text messages to contacts directly from TinyAgent.
  - _Example Query:_ “Send an SMS to Canberk saying ‘I’m running late.’”

### 📅 Calendar

- Create new calendar events with specified titles, dates, and times.
  - _Example Query:_ “Create an event called 'Meeting with Sid' on Friday at 3 PM.”

### 🗺️ Maps

- Find directions or open map locations for points of interest via Apple Maps.
  - _Example Query:_ “Show me directions to the nearest Starbucks.”

### 🗒️ Notes

- Create, open, and append content to notes stored in various folders.
  - _Example Query:_ “Create a note called 'Meeting Notes' in the Meetings folder.”

### 🗂️ File Management

- **File Reading**
  - Open and read files directly through TinyAgent.
  - _Example Query:_ “Open the LLM Compiler.pdf.”
- **PDF Summarization**
  - Generate summaries of PDF documents, enhancing content digestion and review efficiency.
  - _Example Query:_ “Summarize the document LLM Compiler.pdf and save the summary in my Documents folder.”

### ⏰ Reminders

- Set reminders for various activities or tasks, ensuring nothing is forgotten.
  - _Example Query:_ “Remind me to call Sid at 3 PM about the budget approval.”

### 🎥 Zoom Meetings

- Schedule and organize Zoom meetings, including setting names and times.
  - _Example Query:_ “Create a Zoom meeting called 'Team Standup' at 10 AM next Monday.”

### 💬 Custom Instructions

- Write and configure specific instructions for your TinyAgent.
  - _Example Query:_ “Always cc team members Nick and Sid in all emails.”

> You can choose to enable/disable certain apps by going to the Preferences window.

### 🤖 Sub-Agents

Depending on the complexity of the task, TinyAgent orchestrates the execution of more specialized or smaller LMs. TinyAgent can currently operate LMs that summarize a PDF, write emails, or take notes.

> See the [Customization](#customization) section to learn how to add your own sub-agents.

### 🛠️ ToolRAG

When faced with challenging tasks, SLM agents require appropriate tools and in-context examples to guide them. If the model sees irrelevant examples, it can hallucinate. Likewise, if the model sees the descriptions of tools that it doesn’t need, it usually gets confused, and these tools take up unnecessary prompt space. To tackle this, TinyAgent uses ToolRAG to retrieve the tools and examples best suited for a given query. This process has minimal latency and substantially increases the accuracy of TinyAgent. Please take a look at our [blog post](https://bair.berkeley.edu/blog/2024/05/29/tiny-agent) and our [ToolRAG model](https://huggingface.co/squeeze-ai-lab/TinyAgent-ToolRAG) for more details.

> You need to first install our [ToolRAG model](https://huggingface.co/squeeze-ai-lab/TinyAgent-ToolRAG) from Hugging Face and enable it in the TinyAgent settings to use it.

### 🎙️ Whisper

TinyAgent also accepts voice commands, through both the OpenAI Whisper API and a local [whisper.cpp](https://github.com/ggerganov/whisper.cpp) deployment. For whisper.cpp, you need to set up the [local whisper server](https://github.com/ggerganov/whisper.cpp/tree/master/examples/server) and provide the server port number in the TinyAgent settings.

## Providers

You can use TinyAgent with your OpenAI key, Azure deployments, or even your own local models!

### OpenAI

You need to provide your OpenAI API key and the models you want to use in the 'Preferences' window.

### Azure Deployments

You need to provide your deployment name and the endpoints for the main agent/sub-agents/embedding model, as well as the context length of the agent models, in the 'Preferences' window.

### Local Models

You can plug-and-play every part of TinyAgent with your local models! TinyAgent can use an OpenAI-compatible server to run models locally. There are several options:

- **[LMStudio](https://lmstudio.ai/):** For models already on Hugging Face, LMStudio provides an easy-to-use interface to get started with locally served models.

- **[llama.cpp server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md):** If you want more control over your models, we recommend using the official llama.cpp server. Please read through the linked documentation to get started with it.

> All TinyAgent needs is the port number you are serving your model at and the model's context length.

## Fine-tuned TinyAgents

We also provide our own fine-tuned open-source models, TinyAgent-1.1B and TinyAgent-7B! We curated a [dataset](https://huggingface.co/datasets/squeeze-ai-lab/TinyAgent-dataset) of 40,000 real-life use cases for TinyAgent and fine-tuned two small open-source language models on this dataset with LoRA.
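
Both checkpoints are released in Hugging Face and GGUF formats (links in the table below), so they can be served with any OpenAI-compatible local server as described in the [Local Models](#local-models) section. As a quick sanity check before pointing TinyAgent at a locally served checkpoint, you can query the server directly — a minimal sketch, where the port and model name are illustrative and depend on your local setup:

```python
from openai import OpenAI

# Illustrative values: adjust the port to wherever your local server
# (e.g., LMStudio or the llama.cpp server) is listening, and the model
# name to whatever identifier that server assigns to the loaded model.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="tinyagent-1.1b",
    messages=[{"role": "user", "content": "Reply with a single word: ready?"}],
)
print(response.choices[0].message.content)
```
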
After fine-tuning and using ToolRAG, both TinyAgent-1.1B and TinyAgent-7B exceed the performance of GPT-4-turbo. Check out our [paper](https://arxiv.org/abs/2409.00608) for the specifics of dataset generation, evaluation, and fine-tuning.

| Model                                                                                                                                                       | Success Rate |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------ |
| GPT-3.5-turbo                                                                                                                                               | 65.04%       |
| GPT-4-turbo                                                                                                                                                 | 79.08%       |
| [TinyLLama-1.1B-32K-Instruct](https://huggingface.co/Doctor-Shotgun/TinyLlama-1.1B-32k-Instruct)                                                            | 12.71%       |
| [WizardLM-2-7b](https://huggingface.co/MaziyarPanahi/WizardLM-2-7B-GGUF)                                                                                    | 41.25%       |
| TinyAgent-1.1B + ToolRAG / [[hf](https://huggingface.co/squeeze-ai-lab/TinyAgent-1.1B)] [[gguf](https://huggingface.co/squeeze-ai-lab/TinyAgent-1.1B-GGUF)] | **80.06%**   |
| TinyAgent-7B + ToolRAG / [[hf](https://huggingface.co/squeeze-ai-lab/TinyAgent-7B)] [[gguf](https://huggingface.co/squeeze-ai-lab/TinyAgent-7B-GGUF)]       | **84.95%**   |

## Customization

You can customize your TinyAgent by going to the `~/Library/Application Support/TinyAgent/tinyagent-llmcompiler` directory and changing the code yourself.

### Using TinyAgent programmatically

You can use TinyAgent programmatically by just passing in a config file.

```python
from src.tiny_agent.tiny_agent import TinyAgent
from src.tiny_agent.config import get_tiny_agent_config

config_path = "..."
tiny_agent_config = get_tiny_agent_config(config_path=config_path)
tiny_agent = TinyAgent(tiny_agent_config)

await tiny_agent.arun(query="Create a meeting with Sid and Lutfi for tomorrow 2pm to discuss the meeting notes.")
```

### Adding your own tools

1. Navigate to `src/tiny_agent/models.py` and add your tool's name to `TinyAgentToolName(Enum)`.

```python
class TinyAgentToolName(Enum):
    ...
    CUSTOM_TOOL_NAME = "custom_tool"
```

2. Navigate to `src/tiny_agent/tiny_agent_tools.py` and define your tool.

```python
def get_custom_tools(...) -> list[Tool]:
    async def tool_coroutine(...) -> str:  # Needs to return a string
        ...
        return ...

    custom_tool = Tool(
        name=TinyAgentToolName.CUSTOM_TOOL_NAME,
        func=tool_coroutine,
        description=(
            f"{TinyAgentToolName.CUSTOM_TOOL_NAME.value}(...) -> str\n"
            ""
        ),
    )

    return [custom_tool]
```

3. Add your tools to the `get_tiny_agent_tools` function.

```python
def get_tiny_agent_tools(...):
    ...
    tools += get_custom_tools(...)
    ...
```

> Note: Adding your own tools only works for GPT models, since our open-source models and ToolRAG were only fine-tuned on the original TinyAgent toolset.

### Adding your own sub-agents

You can also add your own custom sub-agents. To do so, please follow these steps:

1. All sub-agents inherit from the `SubAgent` class in `src/tiny_agent/sub_agents/sub_agent.py`. Your custom agent should inherit from this abstract class and define the `__call__` method.

```python
class CustomSubAgent(SubAgent):

    async def __call__(self, ...) -> str:
        ...
        return response
```

2. Add your custom agent to `TinyAgent` in `src/tiny_agent/tiny_agent.py`.

```python
from src.tiny_agent.sub_agents.custom_agent import CustomSubAgent

class TinyAgent:
    ...
    custom_agent: CustomSubAgent

    def __init__(...):
        ...
        self.custom_agent = CustomSubAgent(
            sub_agent_llm, config.sub_agent_config, config.custom_instructions
        )
        ...
```

3. After defining your custom agent and adding it to TinyAgent, you should create a tool that calls this agent. Please refer to the [Adding your own tools](#adding-your-own-tools) section to see how to do so.

## Citation

We would appreciate it if you could cite our [paper](https://arxiv.org/pdf/2409.00608) if you found TinyAgent useful for your work:

```
@misc{erdogan2024tinyagentfunctioncallingedge,
      title={TinyAgent: Function Calling at the Edge},
      author={Lutfi Eren Erdogan and Nicholas Lee and Siddharth Jha and Sehoon Kim and Ryan Tabrizi and Suhong Moon and Coleman Hooper and Gopala Anumanchipalli and Kurt Keutzer and Amir Gholami},
      year={2024},
      eprint={2409.00608},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2409.00608},
}
```
--------------------------------------------------------------------------------
/TinyAgent.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SqueezeAILab/TinyAgent/cc45c0e842f5d163c3df1c8f41d60e90e005867d/TinyAgent.zip
--------------------------------------------------------------------------------
/figs/tinyagent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SqueezeAILab/TinyAgent/cc45c0e842f5d163c3df1c8f41d60e90e005867d/figs/tinyagent.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
PyMuPDF==1.23.26
bs4==0.0.1
fastapi==0.110.1
langchain-community==0.0.9
langchain-core==0.1.7
langchain-openai==0.0.2
langchain==0.1.0
numexpr==2.8.7
openai==1.6.1
protobuf==5.26.1
python-dateutil==2.8.2
sentencepiece==0.2.0
tiktoken==0.5.2
torch==2.2.2
transformers==4.38.2
uvicorn==0.29.0
python-multipart==0.0.9
httpx==0.27.0
--------------------------------------------------------------------------------
/run_tiny_agent_server.py:
--------------------------------------------------------------------------------
import asyncio
import os
import signal
from http import HTTPStatus
from typing import cast

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.exceptions import HTTPException
from fastapi.responses import PlainTextResponse, StreamingResponse
from pydantic import BaseModel
from starlette.datastructures import UploadFile
from starlette.exceptions import HTTPException as StarletteHTTPException

from src.tiny_agent.config import get_tiny_agent_config
from src.tiny_agent.models import (
    LLM_ERROR_TOKEN,
    TINY_AGENT_DIR,
    ModelType,
    streaming_queue,
)
from src.tiny_agent.tiny_agent import TinyAgent
from src.tiny_agent.transcription import (
    TranscriptionService,
    WhisperCppClient,
    WhisperOpenAIClient,
)
from src.utils.logger_utils import enable_logging, enable_logging_to_file, log

enable_logging(False)
enable_logging_to_file(True)

CONFIG_PATH = os.path.join(TINY_AGENT_DIR, "Configuration.json")

app = FastAPI()


def empty_queue(q: asyncio.Queue) -> None:
    while not q.empty():
        try:
            q.get_nowait()
            q.task_done()
        except asyncio.QueueEmpty:
            # Handle the case where the queue is already empty
            break


class TinyAgentRequest(BaseModel):
    query: str


@app.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request, exc):
    """
    Custom error handling for logging the errors to the TinyAgent log file.
    """
    log(f"HTTPException {exc.status_code}: {exc.detail}")
    return PlainTextResponse(exc.detail, status_code=exc.status_code)


@app.post("/generate")
async def execute_command(request: TinyAgentRequest) -> StreamingResponse:
    """
    This is the main endpoint that calls TinyAgent to generate a response to the given query.
    """
    log(f"\n\n====\nReceived request: {request.query}")

    # First, ensure the queue is empty
    empty_queue(streaming_queue)

    query = request.query

    if not query or len(query) <= 0:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST, detail="No query provided"
        )

    try:
        tiny_agent_config = get_tiny_agent_config(config_path=CONFIG_PATH)
        tiny_agent = TinyAgent(tiny_agent_config)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail=f"Error: {e}",
        )

    async def generate():
        try:
            response_task = asyncio.create_task(tiny_agent.arun(query))

            while True:
                # Await a small timeout to periodically check if the task is done
                try:
                    token = await asyncio.wait_for(streaming_queue.get(), timeout=1.0)
                    if token is None:
                        break
                    if token.startswith(LLM_ERROR_TOKEN):
                        raise Exception(token[len(LLM_ERROR_TOKEN) :])
                    yield token
                except asyncio.TimeoutError:
                    pass  # No new token, check task status

                # Check if the task is done to handle any potential exception
                if response_task.done():
                    break

            # Tasks created with asyncio.create_task() do not propagate exceptions
            # to the calling context. Instead, the exception remains encapsulated within
            # the task object itself until the task is awaited or its result is explicitly retrieved.
            # Hence, we check here if the task has an exception set by awaiting it, which will
            # raise the exception if it exists. If it doesn't, we just yield the result.
            await response_task
            response = response_task.result()
            yield f"\n\n{response}"
        except Exception as e:
            # You cannot raise HTTPExceptions in an async generator; they don't
            # get caught by the FastAPI exception handling middleware. Hence,
            # we are manually catching the exceptions and yielding/logging them.
            yield f"Error: {e}"
            log(f"Error: {e}")

    return StreamingResponse(generate(), media_type="text/event-stream")


@app.post("/voice")
async def get_voice_transcription(request: Request) -> Response:
    """
    This endpoint calls Whisper to get a voice transcription. It takes in bytes of audio
    and returns the transcription in plain text.
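    Expects multipart form data with an "audio_pcm" file field containing the raw
    PCM bytes and a "sample_rate" field describing them.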
    """
    log("\n\n====\nReceived request to get voice transcription")

    body = await request.form()
    audio_file = cast(UploadFile, body["audio_pcm"])
    sample_rate = int(cast(str, body["sample_rate"]))
    raw_bytes = await audio_file.read()

    if not raw_bytes or len(raw_bytes) <= 0:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST, detail="No audio provided"
        )
    if not sample_rate:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST, detail="No sampling rate provided"
        )

    try:
        tiny_agent_config = get_tiny_agent_config(config_path=CONFIG_PATH)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail=f"Error: {e}",
        )

    whisper_client = (
        WhisperOpenAIClient(tiny_agent_config)
        if tiny_agent_config.whisper_config.provider == ModelType.OPENAI
        else WhisperCppClient(tiny_agent_config)
    )

    transcription_service = TranscriptionService(whisper_client)

    try:
        transcription = await transcription_service.transcribe(raw_bytes, sample_rate)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail=f"Error: {e}",
        )

    return Response(transcription, status_code=HTTPStatus.OK)


@app.post("/quit")
async def shutdown_server() -> Response:
    """
    Shuts down the server by sending a SIGTERM signal to the main process,
    which is a gentle way to terminate the server. This endpoint should be
    protected in real applications to prevent unauthorized shutdowns.
    """
    os.kill(os.getpid(), signal.SIGTERM)
    return Response("Server is shutting down...", status_code=HTTPStatus.OK)


@app.get("/ping")
async def ping() -> Response:
    """
    A simple endpoint to check if the server is running.
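    Returns "pong" with HTTP 200 so clients can health-check the server.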
    """
    return Response("pong", status_code=HTTPStatus.OK)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=50001)
--------------------------------------------------------------------------------
/src/agents/agent.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import json
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import yaml
from langchain.agents.agent import AgentOutputParser, BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.pydantic_v1 import root_validator
from langchain.schema import (
    AgentAction,
    AgentFinish,
    BaseOutputParser,
    BasePromptTemplate,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.tools import BaseTool

from src.chains.llm_chain import LLMChain

logger = logging.getLogger(__name__)


class AgentOutputParser(BaseOutputParser):
    """Base class for parsing agent output into agent action/finish."""

    @abstractmethod
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse text into agent action/finish."""


class Agent(BaseSingleActionAgent):
    """Agent that calls the language model and decides on the action.

    This is driven by an LLMChain. The prompt in the LLMChain MUST include
    a variable called "agent_scratchpad" where the agent can put its
    intermediary work.

    Copied from Langchain v0.0.283,
    but merged with the parent class BaseSingleActionAgent for simplicity.
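    Subclasses supply the parsing and formatting details via the abstract
    `observation_prefix`, `llm_prefix`, and `_get_default_output_parser` members.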
    """

    llm_chain: LLMChain
    output_parser: AgentOutputParser
    allowed_tools: Optional[List[str]] = None

    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent."""
        _dict = super().dict()
        _type = self._agent_type
        if isinstance(_type, AgentType):
            _dict["_type"] = str(_type.value)
        else:
            _dict["_type"] = _type
        del _dict["output_parser"]
        return _dict

    def get_allowed_tools(self) -> Optional[List[str]]:
        return self.allowed_tools

    @property
    def return_values(self) -> List[str]:
        return ["output"]

    def _fix_text(self, text: str) -> str:
        """Fix the text."""
        raise ValueError("fix_text not implemented for this agent.")

    @property
    def _stop(self) -> List[str]:
        return [
            f"\n{self.observation_prefix.rstrip()}",
            f"\n\t{self.observation_prefix.rstrip()}",
        ]

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> Union[str, List[BaseMessage]]:
        """Construct the scratchpad that lets the agent continue its thought process."""
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
        return thoughts

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        try:
            full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
            return self.output_parser.parse(full_output)
        except Exception as e:
            full_inputs["agent_scratchpad"] = (
                full_inputs["agent_scratchpad"] + full_output + "\nAction: "
            )
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
            return self.output_parser.parse("Action: " + full_output)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
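
        If parsing the first completion fails, the partial output plus an
        "Action: " hint is appended to the scratchpad and the LLM is queried once more.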
143 | """ 144 | try: 145 | full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) 146 | full_output = await self.llm_chain.apredict( 147 | callbacks=callbacks, **full_inputs 148 | ) 149 | agent_output = await self.output_parser.aparse(full_output) 150 | except Exception as e: 151 | full_inputs["agent_scratchpad"] = ( 152 | full_inputs["agent_scratchpad"] + full_output + "\nAction: " 153 | ) 154 | full_output = await self.llm_chain.apredict( 155 | callbacks=callbacks, **full_inputs 156 | ) 157 | agent_output = await self.output_parser.aparse("Action: " + full_output) 158 | 159 | return agent_output 160 | 161 | def get_full_inputs( 162 | self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any 163 | ) -> Dict[str, Any]: 164 | """Create the full inputs for the LLMChain from intermediate steps.""" 165 | thoughts = self._construct_scratchpad(intermediate_steps) 166 | new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} 167 | full_inputs = {**kwargs, **new_inputs} 168 | return full_inputs 169 | 170 | @property 171 | def input_keys(self) -> List[str]: 172 | """Return the input keys. 173 | 174 | :meta private: 175 | """ 176 | return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"}) 177 | 178 | @root_validator() 179 | def validate_prompt(cls, values: Dict) -> Dict: 180 | """Validate that prompt matches format.""" 181 | prompt = values["llm_chain"].prompt 182 | if "agent_scratchpad" not in prompt.input_variables: 183 | logger.warning( 184 | "`agent_scratchpad` should be a variable in prompt.input_variables." 185 | " Did not find it, so adding it at the end." 186 | ) 187 | prompt.input_variables.append("agent_scratchpad") 188 | if isinstance(prompt, PromptTemplate): 189 | prompt.template += "\n{agent_scratchpad}" 190 | elif isinstance(prompt, FewShotPromptTemplate): 191 | prompt.suffix += "\n{agent_scratchpad}" 192 | else: 193 | raise ValueError(f"Got unexpected prompt type {type(prompt)}") 194 | return values 195 | 196 | @property 197 | @abstractmethod 198 | def observation_prefix(self) -> str: 199 | """Prefix to append the observation with.""" 200 | 201 | @property 202 | @abstractmethod 203 | def llm_prefix(self) -> str: 204 | """Prefix to append the LLM call with.""" 205 | 206 | @classmethod 207 | def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: 208 | """Validate that appropriate tools are passed in.""" 209 | pass 210 | 211 | @classmethod 212 | @abstractmethod 213 | def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: 214 | """Get default output parser for this class.""" 215 | 216 | @classmethod 217 | def from_llm_and_tools( 218 | cls, 219 | llm: BaseLanguageModel, 220 | tools: Sequence[BaseTool], 221 | prompt: BasePromptTemplate, 222 | callback_manager: Optional[BaseCallbackManager] = None, 223 | output_parser: Optional[AgentOutputParser] = None, 224 | **kwargs: Any, 225 | ) -> Agent: 226 | """Construct an agent from an LLM and tools.""" 227 | cls._validate_tools(tools) 228 | llm_chain = LLMChain( 229 | llm=llm, 230 | prompt=prompt, 231 | callback_manager=callback_manager, 232 | ) 233 | tool_names = [tool.name for tool in tools] 234 | _output_parser = output_parser or cls._get_default_output_parser() 235 | return cls( 236 | llm_chain=llm_chain, 237 | allowed_tools=tool_names, 238 | output_parser=_output_parser, 239 | **kwargs, 240 | ) 241 | 242 | def return_stopped_response( 243 | self, 244 | early_stopping_method: str, 245 | intermediate_steps: List[Tuple[AgentAction, str]], 246 | **kwargs: Any, 247 | ) -> AgentFinish: 
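        # "force" returns a canned message, while "generate" makes one final
        # LLM pass over the scratchpad to try to produce a real answer.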
248 | """Return response when agent has been stopped due to max iterations.""" 249 | if early_stopping_method == "force": 250 | # `force` just returns a constant string 251 | return AgentFinish( 252 | {"output": "Agent stopped due to iteration limit or time limit."}, "" 253 | ) 254 | elif early_stopping_method == "generate": 255 | # Generate does one final forward pass 256 | thoughts = "" 257 | for action, observation in intermediate_steps: 258 | thoughts += action.log 259 | thoughts += ( 260 | f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" 261 | ) 262 | # Adding to the previous steps, we now tell the LLM to make a final pred 263 | thoughts += ( 264 | "\n\nI now need to return a final answer based on the previous steps:" 265 | ) 266 | new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} 267 | full_inputs = {**kwargs, **new_inputs} 268 | full_output = self.llm_chain.predict(**full_inputs) 269 | # We try to extract a final answer 270 | parsed_output = self.output_parser.parse(full_output) 271 | if isinstance(parsed_output, AgentFinish): 272 | # If we can extract, we send the correct stuff 273 | return parsed_output 274 | else: 275 | # If we can extract, but the tool is not the final tool, 276 | # we just return the full output 277 | return AgentFinish({"output": full_output}, full_output) 278 | else: 279 | raise ValueError( 280 | "early_stopping_method should be one of `force` or `generate`, " 281 | f"got {early_stopping_method}" 282 | ) 283 | 284 | def tool_run_logging_kwargs(self) -> Dict: 285 | return { 286 | "llm_prefix": self.llm_prefix, 287 | "observation_prefix": self.observation_prefix, 288 | } 289 | 290 | def save(self, file_path: Union[Path, str]) -> None: 291 | """Save the agent. 292 | 293 | Args: 294 | file_path: Path to file to save the agent to. 295 | 296 | Example: 297 | .. code-block:: python 298 | 299 | # If working with agent executor 300 | agent.agent.save(file_path="path/agent.yaml") 301 | """ 302 | # Convert file to Path object. 
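        # so we can inspect its suffix and create parent directories below.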
303 | if isinstance(file_path, str): 304 | save_path = Path(file_path) 305 | else: 306 | save_path = file_path 307 | 308 | directory_path = save_path.parent 309 | directory_path.mkdir(parents=True, exist_ok=True) 310 | 311 | # Fetch dictionary to save 312 | agent_dict = self.dict() 313 | 314 | if save_path.suffix == ".json": 315 | with open(file_path, "w") as f: 316 | json.dump(agent_dict, f, indent=4) 317 | elif save_path.suffix == ".yaml": 318 | with open(file_path, "w") as f: 319 | yaml.dump(agent_dict, f, default_flow_style=False) 320 | else: 321 | raise ValueError(f"{save_path} must be json or yaml") 322 | -------------------------------------------------------------------------------- /src/agents/structured_chat_agent.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, List, Optional, Sequence, Tuple 3 | 4 | from langchain.agents.agent import AgentOutputParser 5 | from langchain.agents.structured_chat.output_parser import ( 6 | StructuredChatOutputParserWithRetries, 7 | ) 8 | from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX 9 | from langchain.callbacks.base import BaseCallbackManager 10 | from langchain.prompts.chat import ( 11 | ChatPromptTemplate, 12 | HumanMessagePromptTemplate, 13 | SystemMessagePromptTemplate, 14 | ) 15 | from langchain.pydantic_v1 import Field 16 | from langchain.schema import AgentAction, BasePromptTemplate 17 | from langchain.schema.language_model import BaseLanguageModel 18 | from langchain.tools import BaseTool 19 | 20 | from src.agents.agent import Agent 21 | from src.chains.llm_chain import LLMChain 22 | 23 | HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}" 24 | 25 | 26 | class StructuredChatAgent(Agent): 27 | """Structured Chat Agent.""" 28 | 29 | output_parser: AgentOutputParser = Field( 30 | default_factory=StructuredChatOutputParserWithRetries 31 | ) 32 | """Output parser for the agent.""" 33 | 34 | @property 35 | def observation_prefix(self) -> str: 36 | """Prefix to append the observation with.""" 37 | return "Observation: " 38 | 39 | @property 40 | def llm_prefix(self) -> str: 41 | """Prefix to append the llm call with.""" 42 | return "Thought:" 43 | 44 | def _construct_scratchpad( 45 | self, intermediate_steps: List[Tuple[AgentAction, str]] 46 | ) -> str: 47 | agent_scratchpad = super()._construct_scratchpad(intermediate_steps) 48 | if not isinstance(agent_scratchpad, str): 49 | raise ValueError("agent_scratchpad should be of type string.") 50 | if agent_scratchpad: 51 | return ( 52 | f"This was your previous work " 53 | f"(but I haven't seen any of it! 
I only see what " 54 | f"you return as final answer):\n{agent_scratchpad}" 55 | ) 56 | else: 57 | return agent_scratchpad 58 | 59 | @classmethod 60 | def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: 61 | pass 62 | 63 | @classmethod 64 | def _get_default_output_parser( 65 | cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any 66 | ) -> AgentOutputParser: 67 | return StructuredChatOutputParserWithRetries.from_llm(llm=llm) 68 | 69 | @property 70 | def _stop(self) -> List[str]: 71 | return ["Observation:"] 72 | 73 | @classmethod 74 | def create_prompt( 75 | cls, 76 | tools: Sequence[BaseTool], 77 | prefix: str = PREFIX, 78 | suffix: str = SUFFIX, 79 | human_message_template: str = HUMAN_MESSAGE_TEMPLATE, 80 | format_instructions: str = FORMAT_INSTRUCTIONS, 81 | input_variables: Optional[List[str]] = None, 82 | memory_prompts: Optional[List[BasePromptTemplate]] = None, 83 | ) -> BasePromptTemplate: 84 | tool_strings = [] 85 | for tool in tools: 86 | args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) 87 | tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") 88 | formatted_tools = "\n".join(tool_strings) 89 | tool_names = ", ".join([tool.name for tool in tools]) 90 | format_instructions = format_instructions.format(tool_names=tool_names) 91 | template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) 92 | if input_variables is None: 93 | input_variables = ["input", "agent_scratchpad"] 94 | _memory_prompts = memory_prompts or [] 95 | messages = [ 96 | SystemMessagePromptTemplate.from_template(template), 97 | *_memory_prompts, 98 | HumanMessagePromptTemplate.from_template(human_message_template), 99 | ] 100 | return ChatPromptTemplate(input_variables=input_variables, messages=messages) 101 | 102 | @classmethod 103 | def from_llm_and_tools( 104 | cls, 105 | llm: BaseLanguageModel, 106 | tools: Sequence[BaseTool], 107 | callback_manager: Optional[BaseCallbackManager] = None, 108 | output_parser: Optional[AgentOutputParser] = None, 109 | prefix: str = PREFIX, 110 | suffix: str = SUFFIX, 111 | human_message_template: str = HUMAN_MESSAGE_TEMPLATE, 112 | format_instructions: str = FORMAT_INSTRUCTIONS, 113 | input_variables: Optional[List[str]] = None, 114 | memory_prompts: Optional[List[BasePromptTemplate]] = None, 115 | **kwargs: Any, 116 | ) -> Agent: 117 | """Construct an agent from an LLM and tools.""" 118 | cls._validate_tools(tools) 119 | prompt = cls.create_prompt( 120 | tools, 121 | prefix=prefix, 122 | suffix=suffix, 123 | human_message_template=human_message_template, 124 | format_instructions=format_instructions, 125 | input_variables=input_variables, 126 | memory_prompts=memory_prompts, 127 | ) 128 | llm_chain = LLMChain( 129 | llm=llm, 130 | prompt=prompt, 131 | callback_manager=callback_manager, 132 | ) 133 | tool_names = [tool.name for tool in tools] 134 | _output_parser = output_parser or cls._get_default_output_parser(llm=llm) 135 | return cls( 136 | llm_chain=llm_chain, 137 | allowed_tools=tool_names, 138 | output_parser=_output_parser, 139 | **kwargs, 140 | ) 141 | 142 | @property 143 | def _agent_type(self) -> str: 144 | raise ValueError 145 | -------------------------------------------------------------------------------- /src/agents/tools.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from langchain.callbacks.manager import ( 4 | AsyncCallbackManagerForToolRun, 5 | CallbackManagerForToolRun, 6 | ) 7 | from 
langchain.tools import BaseTool 8 | 9 | from src.tools.base import Tool, tool 10 | 11 | 12 | class InvalidTool(BaseTool): 13 | """Tool that is run when invalid tool name is encountered by agent.""" 14 | 15 | name: str = "invalid_tool" 16 | description: str = "Called when tool name is invalid. Suggests valid tool names." 17 | 18 | def _run( 19 | self, 20 | requested_tool_name: str, 21 | available_tool_names: List[str], 22 | run_manager: Optional[CallbackManagerForToolRun] = None, 23 | ) -> str: 24 | """Use the tool.""" 25 | available_tool_names_str = ", ".join([tool for tool in available_tool_names]) 26 | return ( 27 | f"{requested_tool_name} is not a valid tool, " 28 | f"try one of [{available_tool_names_str}]." 29 | ) 30 | 31 | async def _arun( 32 | self, 33 | requested_tool_name: str, 34 | available_tool_names: List[str], 35 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None, 36 | ) -> str: 37 | """Use the tool asynchronously.""" 38 | available_tool_names_str = ", ".join([tool for tool in available_tool_names]) 39 | return ( 40 | f"{requested_tool_name} is not a valid tool, " 41 | f"try one of [{available_tool_names_str}]." 42 | ) 43 | 44 | 45 | __all__ = ["InvalidTool", "BaseTool", "tool", "Tool"] 46 | -------------------------------------------------------------------------------- /src/callbacks/callbacks.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tiktoken 4 | from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler 5 | 6 | 7 | class StatsCallbackHandler(BaseCallbackHandler): 8 | """Collect useful stats about the run. 9 | Add more stats as needed.""" 10 | 11 | def __init__(self) -> None: 12 | super().__init__() 13 | self.cnt = 0 14 | self.input_tokens = 0 15 | self.output_tokens = 0 16 | self.all_times = [] 17 | self.start_time = 0 18 | 19 | def on_chat_model_start(self, serialized, prompts, **kwargs): 20 | self.start_time = time.time() 21 | 22 | def on_llm_end(self, response, *args, **kwargs): 23 | token_usage = response.llm_output["token_usage"] 24 | self.input_tokens += token_usage["prompt_tokens"] 25 | self.output_tokens += token_usage["completion_tokens"] 26 | self.cnt += 1 27 | self.all_times.append(round(time.time() - self.start_time, 2)) 28 | 29 | def reset(self) -> None: 30 | self.cnt = 0 31 | self.input_tokens = 0 32 | self.output_tokens = 0 33 | self.all_times = [] 34 | 35 | def get_stats(self) -> dict[str, int]: 36 | return { 37 | "calls": self.cnt, 38 | "input_tokens": self.input_tokens, 39 | "output_tokens": self.output_tokens, 40 | "all_times": self.all_times, 41 | } 42 | 43 | 44 | class AsyncStatsCallbackHandler(AsyncCallbackHandler): 45 | """Collect useful stats about the run. 
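    In streaming mode, token counts are estimated with tiktoken, since token
    usage is not reported on `on_llm_end` when streaming.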
46 | Add more stats as needed.""" 47 | 48 | def __init__(self, stream: bool = False) -> None: 49 | super().__init__() 50 | self.cnt = 0 51 | self.input_tokens = 0 52 | self.output_tokens = 0 53 | # same for gpt-3.5 54 | self.encoder = tiktoken.encoding_for_model("gpt-4") 55 | self.stream = stream 56 | self.all_times = [] 57 | self.additional_fields = {} 58 | self.start_time = 0 59 | 60 | async def on_chat_model_start(self, serialized, prompts, **kwargs): 61 | self.start_time = time.time() 62 | if self.stream: 63 | # if streaming mode, on_llm_end response is not collected 64 | # therefore, we need to count input token based on the 65 | # prompt length at the beginning 66 | self.cnt += 1 67 | self.input_tokens += len(self.encoder.encode(prompts[0][0].content)) 68 | 69 | async def on_llm_new_token(self, token, *args, **kwargs): 70 | if self.stream: 71 | # if streaming mode, on_llm_end response is not collected 72 | # therefore, we need to manually count output token based on the 73 | # number of streamed out tokens 74 | self.output_tokens += 1 75 | 76 | async def on_llm_end(self, response, *args, **kwargs): 77 | self.all_times.append(round(time.time() - self.start_time, 2)) 78 | if not self.stream: 79 | # if not streaming mode, on_llm_end response is collected 80 | # so we can use this stats directly 81 | token_usage = response.llm_output["token_usage"] 82 | self.input_tokens += token_usage["prompt_tokens"] 83 | self.output_tokens += token_usage["completion_tokens"] 84 | self.cnt += 1 85 | 86 | def reset(self) -> None: 87 | self.cnt = 0 88 | self.input_tokens = 0 89 | self.output_tokens = 0 90 | self.all_times = [] 91 | self.additional_fields = {} 92 | 93 | def get_stats(self) -> dict[str, int]: 94 | return { 95 | "calls": self.cnt, 96 | "input_tokens": self.input_tokens, 97 | "output_tokens": self.output_tokens, 98 | "all_times": self.all_times, 99 | **self.additional_fields, 100 | } 101 | -------------------------------------------------------------------------------- /src/executors/schema.py: -------------------------------------------------------------------------------- 1 | from langchain.pydantic_v1 import BaseModel 2 | 3 | 4 | class Step(BaseModel): 5 | """Step.""" 6 | 7 | value: str 8 | """The value.""" 9 | 10 | 11 | class Plan(BaseModel): 12 | """Plan.""" 13 | 14 | steps: list[Step] 15 | """The steps.""" 16 | 17 | 18 | class StepResponse(BaseModel): 19 | """Step response.""" 20 | 21 | response: str 22 | """The response.""" 23 | -------------------------------------------------------------------------------- /src/llm_compiler/constants.py: -------------------------------------------------------------------------------- 1 | END_OF_PLAN = "" 2 | 3 | JOINNER_FINISH = "Finish" 4 | JOINNER_REPLAN = "Replan" 5 | 6 | SUMMARY_RESULT = "Summary" 7 | -------------------------------------------------------------------------------- /src/llm_compiler/llm_compiler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast 3 | 4 | from langchain.callbacks.manager import ( 5 | AsyncCallbackManagerForChainRun, 6 | CallbackManagerForChainRun, 7 | ) 8 | from langchain.chat_models.base import BaseChatModel 9 | from langchain.llms import BaseLLM 10 | from langchain.llms.base import BaseLLM 11 | from langchain.prompts.base import StringPromptValue 12 | 13 | from src.callbacks.callbacks import AsyncStatsCallbackHandler 14 | from src.chains.chain import Chain 15 | from 
src.llm_compiler.constants import JOINNER_REPLAN 16 | from src.llm_compiler.planner import Planner 17 | from src.llm_compiler.task_fetching_unit import Task, TaskFetchingUnit 18 | from src.tiny_agent.models import streaming_queue 19 | from src.tools.base import StructuredTool, Tool 20 | from src.utils.logger_utils import log 21 | 22 | 23 | class LLMCompilerAgent: 24 | """Self defined agent for LLM Compiler.""" 25 | 26 | def __init__(self, llm: BaseLLM) -> None: 27 | self.llm = llm 28 | 29 | async def arun(self, prompt: str, callbacks=None) -> str: 30 | response = await self.llm.agenerate_prompt( 31 | prompts=[StringPromptValue(text=prompt)], 32 | stop=[""], 33 | callbacks=callbacks, 34 | ) 35 | if isinstance(self.llm, BaseChatModel): 36 | return response.generations[0][0].message.content 37 | 38 | if isinstance(self.llm, BaseLLM): 39 | return response.generations[0][0].text 40 | 41 | raise ValueError("LLM must be either BaseChatModel or BaseLLM") 42 | 43 | 44 | class LLMCompiler(Chain, extra="allow"): 45 | """LLMCompiler Engine.""" 46 | 47 | """The step container to use.""" 48 | input_key: str = "input" 49 | output_key: str = "output" 50 | 51 | def __init__( 52 | self, 53 | tools: Sequence[Union[Tool, StructuredTool]], 54 | planner_llm: BaseLLM, 55 | planner_example_prompt: str, 56 | planner_example_prompt_replan: Optional[str], 57 | planner_stop: Optional[list[str]], 58 | planner_stream: bool, 59 | agent_llm: BaseLLM, 60 | joinner_prompt: str, 61 | joinner_prompt_final: Optional[str], 62 | max_replans: int, 63 | benchmark: bool, 64 | planner_custom_instructions_prompt: str | None = None, 65 | **kwargs, 66 | ) -> None: 67 | """ 68 | Args: 69 | tools: List of tools to use. 70 | max_replans: Maximum number of replans to do. 71 | benchmark: Whether to collect benchmark stats. 72 | 73 | Planner Args: 74 | planner_llm: LLM to use for planning. 75 | planner_example_prompt: Example prompt for planning. 76 | planner_example_prompt_replan: Example prompt for replanning. 77 | Assign this if you want to use different example prompt for replanning. 78 | If not assigned, default to `planner_example_prompt`. 79 | planner_stop: Stop tokens for planning. 80 | planner_stream: Whether to stream the planning. 81 | 82 | Agent Args: 83 | agent_llm: LLM to use for agent. 84 | joinner_prompt: Prompt to use for joinner. 85 | joinner_prompt_final: Prompt to use for joinner at the final replanning iter. 86 | If not assigned, default to `joinner_prompt`. 87 | """ 88 | super().__init__(**kwargs) 89 | 90 | if not planner_example_prompt_replan: 91 | log( 92 | "Replan example prompt not specified, using the same prompt as the planner." 
            )
            planner_example_prompt_replan = planner_example_prompt

        self.planner = Planner(
            llm=planner_llm,
            example_prompt=planner_example_prompt,
            example_prompt_replan=planner_example_prompt_replan,
            custom_instructions=planner_custom_instructions_prompt,
            tools=tools,
            stop=planner_stop,
        )

        self.agent = LLMCompilerAgent(agent_llm)
        self.joinner_prompt = joinner_prompt
        self.joinner_prompt_final = joinner_prompt_final or joinner_prompt
        self.planner_stream = planner_stream
        self.max_replans = max_replans

        # callbacks
        self.benchmark = benchmark
        if benchmark:
            self.planner_callback = AsyncStatsCallbackHandler(stream=planner_stream)
            self.executor_callback = AsyncStatsCallbackHandler(stream=False)
        else:
            self.planner_callback = None
            self.executor_callback = None

    def get_all_stats(self):
        stats = {}
        if self.benchmark:
            stats["planner"] = self.planner_callback.get_stats()
            stats["executor"] = self.executor_callback.get_stats()
            stats["total"] = {
                k: v + stats["executor"].get(k, 0) for k, v in stats["planner"].items()
            }

        return stats

    def reset_all_stats(self):
        if self.planner_callback:
            self.planner_callback.reset()
        if self.executor_callback:
            self.executor_callback.reset()

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    # TODO(sk): move all join related functions to a separate class

    def _parse_joinner_output(self, raw_answer: str) -> tuple[str, str, bool]:
        """We expect the joinner output format to be:
        ```
        Thought: xxx
        Action: Finish/Replan(yyy)
        ```
        Returns:
            thought (xxx)
            answer (yyy)
            is_replan (True/False)
        """
        thought, answer, is_replan = "", "", False  # default values
        raw_answers = raw_answer.split("\n")
        for ans in raw_answers:
            # Get to the index where the answer is, which is the start of 'Action:'
            start_of_answer = ans.find("Action:")
            if start_of_answer >= 0:
                ans = ans[start_of_answer:]
            if ans.startswith("Action:") or ans.startswith(" Answer:"):
                answer = ans[ans.find("(") + 1 : ans.rfind(")")]
                is_replan = JOINNER_REPLAN in ans
            elif ans.startswith("Thought:") or ans.startswith(" Thought:"):
                thought = ans.split("Thought:")[1].strip()
        # if not is_replan:
        #     return "", raw_answer, is_replan
        return thought, answer, is_replan

    def _generate_context_for_replanner(
        self, tasks: Mapping[int, Task], joinner_thought: str
    ) -> str:
        """Formatted like this:
        ```
        1. action 1
        Observation: xxx
        2. action 2
        Observation: yyy
        ...
184 | Thought: joinner_thought 185 | ``` 186 | """ 187 | previous_plan_and_observations = "\n".join( 188 | [ 189 | task.get_though_action_observation( 190 | include_action=True, include_action_idx=True 191 | ) 192 | for task in tasks.values() 193 | if not task.is_join 194 | ] 195 | ) 196 | joinner_thought = f"Thought: {joinner_thought}" 197 | context = "\n\n".join([previous_plan_and_observations, joinner_thought]) 198 | return context 199 | 200 | def _format_contexts(self, contexts: Sequence[str]) -> str: 201 | """contexts is a list of context 202 | each context is formatted as the description of _generate_context_for_replanner 203 | """ 204 | formatted_contexts = "" 205 | for context in contexts: 206 | formatted_contexts += f"Previous Plan:\n\n{context}\n\n" 207 | formatted_contexts += "Current Plan:\n\n" 208 | return formatted_contexts 209 | 210 | async def join( 211 | self, input_query: str, agent_scratchpad: str, is_final: bool 212 | ) -> str: 213 | if is_final: 214 | joinner_prompt = self.joinner_prompt_final 215 | else: 216 | joinner_prompt = self.joinner_prompt 217 | prompt = ( 218 | f"{joinner_prompt}\n" # Instructions and examples 219 | f"Question: {input_query}\n\n" # User input query 220 | f"{agent_scratchpad}\n" # T-A-O 221 | # "---\n" 222 | ) 223 | log("Joining prompt:\n", prompt, block=True) 224 | response = await self.agent.arun( 225 | prompt, callbacks=[self.executor_callback] if self.benchmark else None 226 | ) 227 | raw_answer = cast(str, response) 228 | log("Question: \n", input_query, block=True) 229 | log("Raw Answer: \n", raw_answer, block=True) 230 | thought, answer, is_replan = self._parse_joinner_output(raw_answer) 231 | if is_final: 232 | # If final, we don't need to replan 233 | is_replan = False 234 | return thought, answer, is_replan 235 | 236 | def _call( 237 | self, 238 | inputs: Dict[str, Any], 239 | run_manager: Optional[CallbackManagerForChainRun] = None, 240 | ): 241 | raise NotImplementedError("LLMCompiler is async only.") 242 | 243 | async def _acall( 244 | self, 245 | inputs: Dict[str, Any], 246 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, 247 | ) -> Dict[str, Any]: 248 | contexts = [] 249 | joinner_thought = "" 250 | agent_scratchpad = "" 251 | for i in range(self.max_replans): 252 | is_first_iter = i == 0 253 | is_final_iter = i == self.max_replans - 1 254 | 255 | task_fetching_unit = TaskFetchingUnit() 256 | if self.planner_stream: 257 | task_queue = asyncio.Queue() 258 | asyncio.create_task( 259 | self.planner.aplan( 260 | inputs=inputs, 261 | task_queue=task_queue, 262 | is_replan=not is_first_iter, 263 | callbacks=( 264 | [self.planner_callback] if self.planner_callback else None 265 | ), 266 | ) 267 | ) 268 | await task_fetching_unit.aschedule( 269 | task_queue=task_queue, func=lambda x: None 270 | ) 271 | else: 272 | tasks = await self.planner.plan( 273 | inputs=inputs, 274 | is_replan=not is_first_iter, 275 | # callbacks=run_manager.get_child() if run_manager else None, 276 | callbacks=( 277 | [self.planner_callback] if self.planner_callback else None 278 | ), 279 | ) 280 | log("Graph of tasks: ", tasks, block=True) 281 | if self.benchmark: 282 | self.planner_callback.additional_fields["num_tasks"] = len(tasks) 283 | task_fetching_unit.set_tasks(tasks) 284 | await task_fetching_unit.schedule() 285 | tasks = task_fetching_unit.tasks 286 | 287 | # collect thought-action-observation 288 | agent_scratchpad += "\n\n" 289 | agent_scratchpad += "".join( 290 | [ 291 | task.get_though_action_observation( 292 | include_action=True, 
include_thought=True 293 | ) 294 | for task in tasks.values() 295 | if not task.is_join 296 | # Also allow join tasks with observation which are there to propagate errors from the planning phase 297 | or (task.is_join and task.observation is not None) 298 | ] 299 | ) 300 | agent_scratchpad = agent_scratchpad.strip() 301 | 302 | log("Agent scratchpad:\n", agent_scratchpad, block=True) 303 | joinner_thought, answer, is_replan = await self.join( 304 | inputs["input"], 305 | agent_scratchpad=agent_scratchpad, 306 | is_final=is_final_iter, 307 | ) 308 | if not is_replan: 309 | log("Break out of replan loop.") 310 | break 311 | 312 | # Collect contexts for the subsequent replanner 313 | context = self._generate_context_for_replanner( 314 | tasks=tasks, joinner_thought=joinner_thought 315 | ) 316 | contexts.append(context) 317 | formatted_contexts = self._format_contexts(contexts) 318 | log("Contexts:\n", formatted_contexts, block=True) 319 | inputs["context"] = formatted_contexts 320 | 321 | if is_final_iter: 322 | log("Reached max replan limit.") 323 | 324 | # End the generation request 325 | await streaming_queue.put(None) 326 | 327 | return {self.output_key: answer} 328 | -------------------------------------------------------------------------------- /src/llm_compiler/output_parser.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import re 3 | from typing import Any, Sequence, Union 4 | 5 | from langchain.agents.agent import AgentOutputParser 6 | from langchain.schema import OutputParserException 7 | 8 | from src.llm_compiler.task_fetching_unit import Task 9 | from src.tools.base import StructuredTool, Tool 10 | 11 | THOUGHT_PATTERN = r"Thought: ([^\n]*)" 12 | ACTION_PATTERN = r"\s*\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?" 13 | # $1 or ${1} -> 1 14 | ID_PATTERN = r"\$\{?(\d+)\}?" 15 | 16 | END_OF_PLAN = "<END_OF_PLAN>" 17 | 18 | 19 | def default_dependency_rule(idx, args: str): 20 | matches = re.findall(ID_PATTERN, args) 21 | numbers = [int(match) for match in matches] 22 | return idx in numbers 23 | 24 | 25 | class LLMCompilerPlanParser(AgentOutputParser, extra="allow"): 26 | """Planning output parser.""" 27 | 28 | def __init__(self, tools: Sequence[Union[Tool, StructuredTool]], **kwargs): 29 | super().__init__(**kwargs) 30 | self.tools = tools 31 | 32 | def parse(self, text: str) -> dict[int, Any]: 33 | # 1. search("Ronaldo number of kids") -> 1, "search", '"Ronaldo number of kids"' 34 | # pattern = r"(\d+)\. (\w+)\(([^)]+)\)" 35 | pattern = rf"(?:{THOUGHT_PATTERN}\n)?{ACTION_PATTERN}" 36 | matches = re.findall(pattern, text) 37 | 38 | graph_dict = {} 39 | 40 | for match in matches: 41 | # idx = 1, function = "search", args = "Ronaldo number of kids" 42 | # thought will be the preceding thought, if any, otherwise an empty string 43 | thought, idx, tool_name, args, _ = match 44 | idx = int(idx) 45 | 46 | task = instantiate_task( 47 | tools=self.tools, 48 | idx=idx, 49 | tool_name=tool_name, 50 | args=args, 51 | thought=thought, 52 | ) 53 | 54 | graph_dict[idx] = task 55 | if task.is_join: 56 | break 57 | 58 | return graph_dict 59 | 60 | 61 | ### Helper functions 62 | 63 | 64 | def _parse_llm_compiler_action_args(args: str) -> Sequence[Any]: 65 | """Parse arguments from a string.""" 66 | # This will convert the string into a python object 67 | # e.g.
'"Ronaldo number of kids"' -> ("Ronaldo number of kids", ) 68 | # '"I can answer the question now.", [3]' -> ("I can answer the question now.", [3]) 69 | if args == "": 70 | return () 71 | try: 72 | args = ast.literal_eval(args) 73 | except: 74 | args = args 75 | if not isinstance(args, list) and not isinstance(args, tuple): 76 | args = (args,) 77 | return args 78 | 79 | 80 | def _find_tool( 81 | tool_name: str, tools: Sequence[Union[Tool, StructuredTool]] 82 | ) -> Union[Tool, StructuredTool]: 83 | """Find a tool by name. 84 | 85 | Args: 86 | tool_name: Name of the tool to find. 87 | 88 | Returns: 89 | Tool or StructuredTool. 90 | """ 91 | for tool in tools: 92 | if tool.name == tool_name: 93 | return tool 94 | raise OutputParserException(f"Tool {tool_name} not found") 95 | 96 | 97 | def _get_dependencies_from_graph( 98 | idx: int, tool_name: str, args: Sequence[Any] 99 | ) -> dict[str, list[str]]: 100 | """Get dependencies from a graph.""" 101 | if tool_name == "join": 102 | # depends on the previous step 103 | dependencies = list(range(1, idx)) 104 | else: 105 | # define dependencies based on the dependency rule in tool_definitions.py 106 | dependencies = [i for i in range(1, idx) if default_dependency_rule(i, args)] 107 | 108 | return dependencies 109 | 110 | 111 | def instantiate_task( 112 | tools: Sequence[Union[Tool, StructuredTool]], 113 | idx: int, 114 | tool_name: str, 115 | args: str, 116 | thought: str, 117 | ) -> Task: 118 | dependencies = _get_dependencies_from_graph(idx, tool_name, args) 119 | args = _parse_llm_compiler_action_args(args) 120 | if tool_name == "join": 121 | # join does not have a tool 122 | tool_func = lambda x: None 123 | stringify_rule = None 124 | else: 125 | tool = _find_tool(tool_name, tools) 126 | tool_func = tool.func 127 | stringify_rule = tool.stringify_rule 128 | return Task( 129 | idx=idx, 130 | name=tool_name, 131 | tool=tool_func, 132 | args=args, 133 | dependencies=dependencies, 134 | stringify_rule=stringify_rule, 135 | thought=thought, 136 | is_join=tool_name == "join", 137 | ) 138 | -------------------------------------------------------------------------------- /src/llm_compiler/planner.py: -------------------------------------------------------------------------------- 1 | """LLM Compiler Planner""" 2 | 3 | import asyncio 4 | import re 5 | from typing import Any, List, Optional, Sequence, Union 6 | from uuid import UUID 7 | 8 | from langchain.callbacks.base import AsyncCallbackHandler, Callbacks 9 | from langchain.chat_models.base import BaseChatModel 10 | from langchain.llms.base import BaseLLM 11 | from langchain.schema import LLMResult 12 | from langchain.schema.messages import HumanMessage, SystemMessage 13 | 14 | from src.executors.schema import Plan 15 | from src.llm_compiler.constants import END_OF_PLAN 16 | from src.llm_compiler.output_parser import ( 17 | ACTION_PATTERN, 18 | THOUGHT_PATTERN, 19 | LLMCompilerPlanParser, 20 | instantiate_task, 21 | ) 22 | from src.llm_compiler.task_fetching_unit import Task 23 | from src.tiny_agent.models import LLM_ERROR_TOKEN, streaming_queue 24 | from src.tools.base import StructuredTool, Tool 25 | from src.utils.logger_utils import log 26 | 27 | JOIN_DESCRIPTION = ( 28 | "join():\n" 29 | " - Collects and combines results from prior actions.\n" 30 | " - A LLM agent is called upon invoking join to either finalize the user query or wait until the plans are executed.\n" 31 | " - join should always be the last action in the plan, and will be called in two scenarios:\n" 32 | " (a) if the answer can be 
determined by gathering the outputs from tasks to generate the final response.\n" 33 | " (b) if the answer cannot be determined in the planning phase before you execute the plans. " 34 | ) 35 | 36 | 37 | def generate_llm_compiler_prompt( 38 | tools: Sequence[Union[Tool, StructuredTool]], 39 | example_prompt: str, 40 | custom_instructions: str | None, 41 | is_replan: bool = False, 42 | ): 43 | prefix = ( 44 | "Given a user query, create a plan to solve it with the utmost parallelizability. " 45 | f"Each plan should comprise an action from the following {len(tools) + 1} types:\n" 46 | ) 47 | 48 | # Tools 49 | for i, tool in enumerate(tools): 50 | prefix += f"{i+1}. {tool.description}\n" 51 | 52 | # Join operation 53 | prefix += f"{len(tools) + 1}. {JOIN_DESCRIPTION}\n\n" 54 | 55 | # Guidelines 56 | prefix += ( 57 | "Guidelines:\n" 58 | " - Each action described above contains input/output types and a description.\n" 59 | " - You must strictly adhere to the input and output types for each action.\n" 60 | " - The action descriptions contain the guidelines. You MUST strictly follow those guidelines when you use the actions.\n" 61 | " - Each action in the plan should strictly be one of the above types. Follow the Python conventions for each action.\n" 62 | " - Each action MUST have a unique ID, which is strictly increasing.\n" 63 | " - Inputs for actions can either be constants or outputs from preceding actions. " 64 | "In the latter case, use the format $id to denote the ID of the previous action whose output will be the input.\n" 65 | f" - Always call join as the last action in the plan. Say '{END_OF_PLAN}' after you call join.\n" 66 | " - Ensure the plan maximizes parallelizability.\n" 67 | " - Only use the provided action types. If a query cannot be addressed using these, invoke the join action for the next steps.\n" 68 | " - Never explain the plan with comments (e.g. #).\n" 69 | " - Never introduce new actions other than the ones provided.\n\n" 70 | ) 71 | 72 | if custom_instructions: 73 | prefix += f"{custom_instructions}\n\n" 74 | 75 | if is_replan: 76 | prefix += ( 77 | ' - You are given "Previous Plan" which is the plan that the previous agent created along with the execution results ' 78 | "(given as Observation) of each plan and a general thought (given as Thought) about the executed results. " 79 | 'You MUST use this information to create the next plan under "Current Plan".\n' 80 | ' - When starting the Current Plan, you should start with "Thought" that outlines the strategy for the next plan.\n' 81 | " - In the Current Plan, you should NEVER repeat the actions that are already executed in the Previous Plan.\n" 82 | ) 83 | 84 | # Examples 85 | prefix += "Here are some examples:\n\n" 86 | prefix += example_prompt 87 | 88 | return prefix 89 | 90 | 91 | class StreamingGraphParser: 92 | """Streaming version of the GraphParser.""" 93 | 94 | buffer = "" 95 | thought = "" 96 | graph_dict = {} 97 | 98 | def __init__(self, tools: Sequence[Union[Tool, StructuredTool]]) -> None: 99 | self.tools = tools 100 | 101 | def _match_buffer_and_generate_task(self, suffix: str) -> Optional[Task]: 102 | """Runs every time "\n" is encountered in the input stream or at the end of the stream. 103 | Matches the buffer against the regex patterns and generates a task if a match is found. 104 | Match patterns include: 105 | 1. Thought: <thought> 106 | - this case, the thought is stored in self.thought, and we reset the buffer. 107 | - the thought is then used as the thought for the next action. 108 | 2. <idx>. <tool_name>(<args>) 109 | - this case, the tool is instantiated with the idx, tool_name, args, and thought. 110 | - the thought is reset. 111 | - the buffer is reset. 112 | """ 113 | if match := re.match(THOUGHT_PATTERN, self.buffer): 114 | # Optionally, action can be preceded by a thought 115 | self.thought = match.group(1) 116 | elif match := re.match(ACTION_PATTERN, self.buffer): 117 | # if action is parsed, return the task, and clear the buffer 118 | idx, tool_name, args, _ = match.groups() 119 | idx = int(idx) 120 | task = instantiate_task( 121 | tools=self.tools, 122 | idx=idx, 123 | tool_name=tool_name, 124 | args=args, 125 | thought=self.thought, 126 | ) 127 | self.thought = "" 128 | return task 129 | 130 | return None 131 | 132 | def ingest_token(self, token: str) -> Optional[Task]: 133 | # Append token to buffer 134 | if "\n" in token: 135 | prefix, suffix = token.split("\n", 1) 136 | prefix = prefix.strip() 137 | self.buffer += prefix + "\n" 138 | matched_item = self._match_buffer_and_generate_task(suffix) 139 | self.buffer = suffix 140 | return matched_item 141 | else: 142 | self.buffer += token 143 | 144 | return None 145 | 146 | def finalize(self): 147 | self.buffer = self.buffer + "\n" 148 | return self._match_buffer_and_generate_task("") 149 | 150 | 151 | class TinyAgentEarlyStop(BaseException): 152 | # Defining this as a BaseException to differentiate it from other Exception's that are 153 | # fatal. This is a controlled stop and not an error. 154 | generations = [] 155 | llm_output = "" 156 | 157 | 158 | class LLMCompilerCallback(AsyncCallbackHandler): 159 | _queue: asyncio.Queue[Optional[Task]] 160 | _parser: StreamingGraphParser 161 | _tools: Sequence[Union[Tool, StructuredTool]] 162 | _curr_idx: int 163 | 164 | def __init__( 165 | self, 166 | queue: asyncio.Queue[Optional[str]], 167 | tools: Sequence[Union[Tool, StructuredTool]], 168 | ): 169 | self._queue = queue 170 | self._parser = StreamingGraphParser(tools=tools) 171 | self._tools = tools 172 | self._curr_idx = 0 173 | 174 | async def on_llm_start(self, serialized, prompts, **kwargs: Any) -> Any: 175 | """Run when LLM starts running.""" 176 | 177 | async def on_llm_new_token( 178 | self, 179 | token: str, 180 | *, 181 | run_id: UUID, 182 | parent_run_id: Optional[UUID] = None, 183 | **kwargs: Any, 184 | ) -> None: 185 | try: 186 | parsed_data = self._parser.ingest_token(token) 187 | print(token, end="", flush=True) 188 | await streaming_queue.put(token) 189 | if parsed_data: 190 | self._curr_idx = parsed_data.idx 191 | await self._queue.put(parsed_data) 192 | if parsed_data.is_join: 193 | await self._queue.put(None) 194 | except Exception as e: 195 | # If there was an error in parsing the token, stop the LLM and propagate the error to 196 | # the joinner for it to handle. The error message will be presented as an observation in the join action. 197 | # This usually happens when the tool name is not recognized/hallucinated by the LLM. 198 | join_tool = instantiate_task( 199 | tools=self._tools, 200 | idx=self._curr_idx + 1, 201 | tool_name="join", 202 | args="", 203 | thought="", 204 | ) 205 | join_tool.observation = f"The plan generation was stopped due to an error in tool '{self._parser.buffer.strip()}'! Error: {str(e)}! You MUST correct this error and try again!"
206 | await self._queue.put(join_tool) 207 | await self._queue.put(None) 208 | raise TinyAgentEarlyStop(str(e)) 209 | 210 | async def on_llm_end( 211 | self, 212 | response: LLMResult, 213 | *, 214 | run_id: UUID, 215 | parent_run_id: Optional[UUID] = None, 216 | **kwargs: Any, 217 | ) -> None: 218 | parsed_data = self._parser.finalize() 219 | if parsed_data: 220 | await self._queue.put(parsed_data) 221 | await self._queue.put(None) 222 | 223 | # Define the following error callbacks to be able to stop the LLM when an error occurs 224 | async def on_llm_error( 225 | self, 226 | error: BaseException, 227 | *, 228 | run_id: UUID, 229 | parent_run_id: Optional[UUID] = None, 230 | tags: Optional[List[str]] = None, 231 | **kwargs: Any, 232 | ) -> None: 233 | if isinstance(error, TinyAgentEarlyStop): 234 | # Only allow the TinyAgentEarlyStop exception since it is a controlled stop 235 | return 236 | await streaming_queue.put(f"{LLM_ERROR_TOKEN}LLMError: {error}") 237 | 238 | async def on_chain_error( 239 | self, 240 | error: BaseException, 241 | *, 242 | run_id: UUID, 243 | parent_run_id: Optional[UUID] = None, 244 | tags: Optional[List[str]] = None, 245 | **kwargs: Any, 246 | ) -> None: 247 | await streaming_queue.put(f"{LLM_ERROR_TOKEN}ChainError: {error}") 248 | 249 | 250 | class Planner: 251 | def __init__( 252 | self, 253 | llm: BaseChatModel | BaseLLM, 254 | custom_instructions: str | None, 255 | example_prompt: str, 256 | example_prompt_replan: str, 257 | tools: Sequence[Union[Tool, StructuredTool]], 258 | stop: Optional[list[str]], 259 | ): 260 | self.llm = llm 261 | # different system prompt is needed when replanning 262 | # since they have different guidelines, and also examples provided by the user 263 | self.system_prompt = generate_llm_compiler_prompt( 264 | tools=tools, 265 | custom_instructions=custom_instructions, 266 | example_prompt=example_prompt, 267 | is_replan=False, 268 | ) 269 | self.system_prompt_replan = generate_llm_compiler_prompt( 270 | tools=tools, 271 | custom_instructions=custom_instructions, 272 | example_prompt=example_prompt_replan, 273 | is_replan=True, 274 | ) 275 | self.tools = tools 276 | self.output_parser = LLMCompilerPlanParser(tools=tools) 277 | self.stop = stop 278 | 279 | async def run_llm( 280 | self, 281 | inputs: dict[str, Any], 282 | is_replan: bool = False, 283 | callbacks: Callbacks = None, 284 | ) -> str: 285 | """Run the LLM.""" 286 | if is_replan: 287 | system_prompt = self.system_prompt_replan 288 | assert "context" in inputs, "If replanning, context must be provided" 289 | human_prompt = f"Question: {inputs['input']}\n{inputs['context']}\n" 290 | else: 291 | system_prompt = self.system_prompt 292 | human_prompt = f"Question: {inputs['input']}" 293 | 294 | if isinstance(self.llm, BaseChatModel): 295 | messages = [ 296 | SystemMessage(content=system_prompt), 297 | HumanMessage(content=human_prompt), 298 | ] 299 | try: 300 | llm_response = await self.llm._call_async( 301 | messages, 302 | callbacks=callbacks, 303 | stop=self.stop, 304 | ) 305 | except Exception as e: 306 | # Put this exception in the streaming queue to stop the LLM, since the whole planner 307 | # runs as a fire-and-forget async task and exceptions would otherwise never reach the main context. 308 | await streaming_queue.put(f"{LLM_ERROR_TOKEN}LLMError: {e}")
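# Illustrative sketch (assumed consumer, not shown in this file): whatever drains
# streaming_queue can branch on the control values pushed here, e.g.:
#     token = await streaming_queue.get()
#     if token is None:                         # generation finished
#         ...
#     elif token.startswith(LLM_ERROR_TOKEN):   # planner/chain error to surface
#         ...
#     else:                                     # ordinary streamed token
#         ...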
309 | await streaming_queue.put(f"{LLM_ERROR_TOKEN}LLMError: {e}") 310 | response = llm_response.content 311 | elif isinstance(self.llm, BaseLLM): 312 | message = system_prompt + "\n\n" + human_prompt 313 | response = await self.llm.apredict( 314 | message, 315 | callbacks=callbacks, 316 | stop=self.stop, 317 | ) 318 | else: 319 | raise ValueError("LLM must be either BaseChatModel or BaseLLM") 320 | 321 | log("LLMCompiler planner response: \n", response, block=True) 322 | 323 | return response 324 | 325 | async def plan( 326 | self, inputs: dict, is_replan: bool, callbacks: Callbacks = None, **kwargs: Any 327 | ): 328 | llm_response = await self.run_llm( 329 | inputs=inputs, is_replan=is_replan, callbacks=callbacks 330 | ) 331 | llm_response = llm_response + "\n" 332 | return self.output_parser.parse(llm_response) 333 | 334 | async def aplan( 335 | self, 336 | inputs: dict, 337 | task_queue: asyncio.Queue[Optional[str]], 338 | is_replan: bool, 339 | callbacks: Callbacks = None, 340 | **kwargs: Any, 341 | ) -> Plan: 342 | """Given input, asynchronously decide what to do.""" 343 | all_callbacks = [ 344 | LLMCompilerCallback( 345 | queue=task_queue, 346 | tools=self.tools, 347 | ) 348 | ] 349 | if callbacks: 350 | all_callbacks.extend(callbacks) 351 | try: 352 | # Actually, we don't need this try-except block here, but we keep it just in case... 353 | await self.run_llm( 354 | inputs=inputs, is_replan=is_replan, callbacks=all_callbacks 355 | ) 356 | except TinyAgentEarlyStop as e: 357 | pass 358 | -------------------------------------------------------------------------------- /src/llm_compiler/task_fetching_unit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Collection, Dict, List, Optional 6 | 7 | from src.utils.logger_utils import log 8 | 9 | SCHEDULING_INTERVAL = 0.01 # seconds 10 | 11 | 12 | def _default_stringify_rule_for_arguments(args): 13 | if len(args) == 1: 14 | return str(args[0]) 15 | else: 16 | return str(tuple(args)) 17 | 18 | 19 | def _replace_arg_mask_with_real_value( 20 | args, dependencies: List[int], tasks: Dict[str, Task] 21 | ): 22 | if isinstance(args, (list, tuple)): 23 | return type(args)( 24 | _replace_arg_mask_with_real_value(item, dependencies, tasks) 25 | for item in args 26 | ) 27 | elif isinstance(args, str): 28 | for dependency in sorted(dependencies, reverse=True): 29 | # consider both ${1} and $1 (in case planner makes a mistake) 30 | for arg_mask in ["${" + str(dependency) + "}", "$" + str(dependency)]: 31 | if arg_mask in args: 32 | if tasks[dependency].observation is not None: 33 | args = args.replace( 34 | arg_mask, str(tasks[dependency].observation) 35 | ) 36 | return args 37 | else: 38 | return args 39 | 40 | 41 | @dataclass 42 | class Task: 43 | idx: int 44 | name: str 45 | tool: Callable 46 | args: Collection[Any] 47 | dependencies: Collection[int] 48 | stringify_rule: Optional[Callable] = None 49 | thought: Optional[str] = None 50 | observation: Optional[str] = None 51 | is_join: bool = False 52 | 53 | async def __call__(self) -> Any: 54 | log(f"running task {self.name}") 55 | x = await self.tool(*self.args) 56 | log(f"done task {self.name}") 57 | return x 58 | 59 | def get_though_action_observation( 60 | self, include_action=True, include_thought=True, include_action_idx=False 61 | ) -> str: 62 | thought_action_observation = "" 63 | if self.thought and include_thought: 64 | 
thought_action_observation = f"Thought: {self.thought}\n" 65 | if include_action: 66 | idx = f"{self.idx}. " if include_action_idx else "" 67 | if self.stringify_rule: 68 | # If the user has specified a custom stringify rule for the 69 | # function argument, use it 70 | thought_action_observation += f"{idx}{self.stringify_rule(self.args)}\n" 71 | else: 72 | # Otherwise, we have a default stringify rule 73 | thought_action_observation += ( 74 | f"{idx}{self.name}" 75 | f"{_default_stringify_rule_for_arguments(self.args)}\n" 76 | ) 77 | if self.observation is not None: 78 | thought_action_observation += f"Observation: {self.observation}\n" 79 | return thought_action_observation 80 | 81 | 82 | class TaskFetchingUnit: 83 | tasks: Dict[str, Task] 84 | tasks_done: Dict[str, asyncio.Event] 85 | remaining_tasks: set[str] 86 | 87 | def __init__(self): 88 | self.tasks = {} 89 | self.tasks_done = {} 90 | self.remaining_tasks = set() 91 | 92 | def set_tasks(self, tasks: dict[str, Any]): 93 | self.tasks.update(tasks) 94 | self.tasks_done.update({task_idx: asyncio.Event() for task_idx in tasks}) 95 | self.remaining_tasks.update(set(tasks.keys())) 96 | 97 | def _all_tasks_done(self): 98 | return all(self.tasks_done[d].is_set() for d in self.tasks_done) 99 | 100 | def _get_all_executable_tasks(self): 101 | return [ 102 | task_name 103 | for task_name in self.remaining_tasks 104 | if all( 105 | self.tasks_done[d].is_set() for d in self.tasks[task_name].dependencies 106 | ) 107 | ] 108 | 109 | def _preprocess_args(self, task: Task): 110 | """Replace dependency placeholders, i.e. ${1}, in task.args with the actual observation.""" 111 | args = [] 112 | for arg in task.args: 113 | arg = _replace_arg_mask_with_real_value(arg, task.dependencies, self.tasks) 114 | args.append(arg) 115 | task.args = args 116 | 117 | async def _run_task(self, task: Task): 118 | try: 119 | self._preprocess_args(task) 120 | if not task.is_join: 121 | observation = await task() 122 | task.observation = observation 123 | except Exception as e: 124 | # If an exception occurs, stop LLM execution and propagate the error message to the joinner 125 | # by manually setting the observation of the task to the error message. If this is an error of 126 | # providing the wrong arguments to the tool, then we do some cleaning up in the error message. 127 | error_message = str(e) 128 | if "positional argument" in error_message: 129 | error_message = error_message.split(".")[-2] 130 | task.observation = ( 131 | f"Error: {error_message}! You MUST correct this error and try again!" 
132 | ) 133 | 134 | self.tasks_done[task.idx].set() 135 | 136 | async def schedule(self): 137 | """Run all tasks in self.tasks in parallel, respecting dependencies.""" 138 | # run until all tasks are done 139 | while not self._all_tasks_done(): 140 | # Find tasks with no dependencies or with all dependencies met 141 | executable_tasks = self._get_all_executable_tasks() 142 | 143 | for task_name in executable_tasks: 144 | asyncio.create_task(self._run_task(self.tasks[task_name])) 145 | self.remaining_tasks.remove(task_name) 146 | 147 | await asyncio.sleep(SCHEDULING_INTERVAL) 148 | 149 | async def aschedule(self, task_queue: asyncio.Queue[Optional[Task]], func): 150 | """Asynchronously listen to task_queue and schedule tasks as they arrive.""" 151 | no_more_tasks = False # Flag to check if all tasks are received 152 | 153 | while True: 154 | if not no_more_tasks: 155 | # Wait for a new task to be added to the queue 156 | task = await task_queue.get() 157 | 158 | # Check for sentinel value indicating end of tasks 159 | if task is None: 160 | no_more_tasks = True 161 | else: 162 | # Parse and set the new tasks 163 | self.set_tasks({task.idx: task}) 164 | 165 | # Schedule and run executable tasks 166 | executable_tasks = self._get_all_executable_tasks() 167 | 168 | if executable_tasks: 169 | for task_name in executable_tasks: 170 | # The task is executed in a separate task to avoid blocking the loop 171 | # without explicitly awaiting it. This, unfortunately, means that the 172 | # task will not be able to propagate exceptions to the calling context. 173 | # Hence, we need to handle exceptions within the task itself. See ._run_task() 174 | asyncio.create_task(self._run_task(self.tasks[task_name])) 175 | self.remaining_tasks.remove(task_name) 176 | elif no_more_tasks and self._all_tasks_done(): 177 | # Exit the loop if no more tasks are expected and all tasks are done 178 | break 179 | else: 180 | # If no executable tasks are found, sleep for the SCHEDULING_INTERVAL 181 | await asyncio.sleep(SCHEDULING_INTERVAL) 182 | -------------------------------------------------------------------------------- /src/tiny_agent/computer.py: -------------------------------------------------------------------------------- 1 | from src.tiny_agent.tools.calendar import Calendar 2 | from src.tiny_agent.tools.contacts import Contacts 3 | from src.tiny_agent.tools.mail import Mail 4 | from src.tiny_agent.tools.maps import Maps 5 | from src.tiny_agent.tools.notes import Notes 6 | from src.tiny_agent.tools.reminders import Reminders 7 | from src.tiny_agent.tools.sms import SMS 8 | from src.tiny_agent.tools.spotlight_search import SpotlightSearch 9 | from src.tiny_agent.tools.zoom import Zoom 10 | 11 | 12 | class Computer: 13 | calendar: Calendar 14 | contacts: Contacts 15 | mail: Mail 16 | maps: Maps 17 | notes: Notes 18 | reminders: Reminders 19 | sms: SMS 20 | spotlight_search: SpotlightSearch 21 | zoom: Zoom 22 | 23 | def __init__(self) -> None: 24 | self.calendar = Calendar() 25 | self.contacts = Contacts() 26 | self.mail = Mail() 27 | self.maps = Maps() 28 | self.notes = Notes() 29 | self.reminders = Reminders() 30 | self.sms = SMS() 31 | self.spotlight_search = SpotlightSearch() 32 | -------------------------------------------------------------------------------- /src/tiny_agent/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any 4 | 5 | from tiktoken import encoding_name_for_model, get_encoding 6 | from transformers 
import AutoTokenizer 7 | 8 | from src.tiny_agent.models import ( 9 | AgentType, 10 | App, 11 | ModelConfig, 12 | ModelType, 13 | TinyAgentConfig, 14 | WhisperConfig, 15 | ) 16 | 17 | DEFAULT_SAFE_CONTEXT_LENGTH = 4096 18 | DEFAULT_EMBEDDING_CONTEXT_LENGTH = 8192 19 | DEFAULT_OPENAI_EMBEDDING_MODEL = "text-embedding-3-small" 20 | 21 | 22 | OPENAI_MODELS = { 23 | "gpt-4": 8192, 24 | "gpt-4-0613": 8192, 25 | "gpt-4-32k": 32768, 26 | "gpt-4-32k-0613": 32768, 27 | "gpt-4-0125-preview": 128000, 28 | "gpt-4-turbo-preview": 128000, 29 | "gpt-4-turbo": 128000, 30 | "gpt-4-1106-preview": 128000, 31 | "gpt-3.5-turbo-0125": 16385, 32 | "gpt-3.5-turbo": 16385, 33 | "gpt-3.5-turbo-1106": 16385, 34 | "gpt-3.5-turbo-instruct": 4096, 35 | } 36 | 37 | AGENT_TYPE_TO_CONFIG_PREFIX = { 38 | AgentType.MAIN: "", 39 | AgentType.SUB_AGENT: "SubAgent", 40 | AgentType.EMBEDDING: "Embedding", 41 | } 42 | 43 | 44 | def load_config(config_path: str) -> dict[str, Any]: 45 | with open(config_path, "r") as file: 46 | return json.load(file) 47 | 48 | 49 | def get_model_config( 50 | config: dict[str, Any], 51 | provider: str, 52 | agent_type: AgentType, 53 | ) -> ModelConfig: 54 | agent_prefix = AGENT_TYPE_TO_CONFIG_PREFIX[agent_type] 55 | model_type = ModelType(provider) 56 | 57 | if model_type == ModelType.AZURE: 58 | _check_azure_config(config, agent_prefix) 59 | api_key = ( 60 | config["azureApiKey"] 61 | if len(config["azureApiKey"]) > 0 62 | else os.environ["AZURE_OPENAI_API_KEY"] 63 | ) 64 | model_name = config[f"azure{agent_prefix}DeploymentName"] 65 | if agent_type != AgentType.EMBEDDING: 66 | context_length = int(config[f"azure{agent_prefix}CtxLen"]) 67 | tokenizer = get_encoding(encoding_name_for_model("gpt-3.5-turbo")) 68 | elif model_type == ModelType.LOCAL: 69 | _check_local_config(config, agent_prefix) 70 | api_key = "lm-studio" 71 | model_name = config[f"local{agent_prefix}ModelName"] 72 | context_length = int( 73 | config[f"local{agent_prefix}CtxLen"] 74 | if len(config[f"local{agent_prefix}CtxLen"]) > 0 75 | else DEFAULT_EMBEDDING_CONTEXT_LENGTH 76 | ) 77 | if agent_type != AgentType.EMBEDDING: 78 | tokenizer = AutoTokenizer.from_pretrained( 79 | config[f"local{agent_prefix}TokenizerNameOrPath"], 80 | use_fast=True, 81 | token=config["hfToken"], 82 | ) 83 | elif model_type == ModelType.OPENAI: 84 | _check_openai_config(config, agent_prefix) 85 | api_key = ( 86 | config["openAIApiKey"] 87 | if len(config["openAIApiKey"]) > 0 88 | else os.environ["OPENAI_API_KEY"] 89 | ) 90 | model_name = ( 91 | config[f"openAI{agent_prefix}ModelName"] 92 | if agent_type != AgentType.EMBEDDING 93 | else DEFAULT_OPENAI_EMBEDDING_MODEL 94 | ) 95 | if agent_type != AgentType.EMBEDDING: 96 | context_length = OPENAI_MODELS[model_name] 97 | tokenizer = get_encoding(encoding_name_for_model("gpt-3.5-turbo")) 98 | else: 99 | raise ValueError("Invalid model type") 100 | 101 | return ModelConfig( 102 | api_key=api_key, 103 | context_length=( 104 | context_length 105 | if agent_type != AgentType.EMBEDDING 106 | else DEFAULT_EMBEDDING_CONTEXT_LENGTH 107 | ), 108 | model_name=model_name, 109 | model_type=model_type, 110 | tokenizer=tokenizer if agent_type != AgentType.EMBEDDING else None, 111 | port=( 112 | int(config[f"local{agent_prefix}Port"]) 113 | if model_type == ModelType.LOCAL 114 | and len(config[f"local{agent_prefix}Port"]) > 0 115 | else None 116 | ), 117 | ) 118 | 119 | 120 | def get_whisper_config(config: dict[str, Any], provider: str) -> WhisperConfig: 121 | whisper_provider = ModelType(provider) 122 | 123 | api_key = None 
124 | port = None 125 | if whisper_provider == ModelType.OPENAI: 126 | if ( 127 | not _is_valid_config_field(config, "openAIApiKey") 128 | and os.environ.get("OPENAI_API_KEY") is None 129 | ): 130 | raise ValueError("OpenAI API key for Whisper not found in config") 131 | 132 | api_key = ( 133 | config.get("openAIApiKey") 134 | if len(config["openAIApiKey"]) > 0 135 | else os.environ.get("OPENAI_API_KEY") 136 | ) 137 | elif whisper_provider == ModelType.LOCAL: 138 | if not _is_valid_config_field(config, "localWhisperPort"): 139 | raise ValueError("Local Whisper port not found in config") 140 | port = int(config["localWhisperPort"]) 141 | else: 142 | raise ValueError("Invalid Whisper provider") 143 | 144 | return WhisperConfig(provider=whisper_provider, api_key=api_key, port=port) 145 | 146 | 147 | def get_tiny_agent_config(config_path: str) -> TinyAgentConfig: 148 | config = load_config(config_path) 149 | 150 | if (provider := config.get("provider")) is None or len(provider) == 0: 151 | raise ValueError("Provider not found in config") 152 | 153 | if (sub_agent_provider := config.get("subAgentProvider")) is None or len( 154 | sub_agent_provider 155 | ) == 0: 156 | raise ValueError("Subagent provider not found in config") 157 | 158 | use_in_context_example_retriever = config.get("useToolRAG") 159 | if ( 160 | use_in_context_example_retriever is False 161 | and config.get("toolRAGProvider") is None 162 | ): 163 | raise ValueError("In-context example retriever provider not found in config") 164 | 165 | whisper_provider = config.get("whisperProvider") 166 | if whisper_provider is None or len(whisper_provider) == 0: 167 | raise ValueError("Whisper provider not found in config") 168 | 169 | # Get the model configs 170 | llmcompiler_config = get_model_config(config, config["provider"], AgentType.MAIN) 171 | sub_agent_config = get_model_config( 172 | config, 173 | sub_agent_provider, 174 | AgentType.SUB_AGENT, 175 | ) 176 | embedding_model_config = None # default, so the reference below is always bound 177 | if use_in_context_example_retriever is True: 178 | embedding_model_config = get_model_config( 179 | config, config["toolRAGProvider"], AgentType.EMBEDDING) 180 | 181 | # Azure config 182 | azure_api_version = config.get("azureApiVersion") 183 | azure_endpoint = ( 184 | config.get("azureEndpoint") 185 | if _is_valid_config_field(config, "azureEndpoint") 186 | else os.environ.get("AZURE_OPENAI_ENDPOINT") 187 | ) 188 | 189 | # Get the other config values 190 | hf_token = config.get("hfToken") 191 | zoom_access_token = config.get("zoomAccessToken") 192 | 193 | # Get the whisper API key which is just the OpenAI API key 194 | if ( 195 | not _is_valid_config_field(config, "openAIApiKey") 196 | and os.environ.get("OPENAI_API_KEY") is None 197 | ): 198 | raise ValueError("OpenAI API key is needed for whisper API.") 199 | 200 | whisper_config = get_whisper_config(config, whisper_provider) 201 | 202 | apps = set() 203 | for app in App: 204 | if config.get(f"{app.value}Enabled"): 205 | apps.add(app) 206 | 207 | return TinyAgentConfig( 208 | apps=apps, 209 | custom_instructions=config.get("customInstructions"), 210 | llmcompiler_config=llmcompiler_config, 211 | sub_agent_config=sub_agent_config, 212 | embedding_model_config=( 213 | embedding_model_config if use_in_context_example_retriever else None 214 | ), 215 | azure_api_version=azure_api_version, 216 | azure_endpoint=azure_endpoint, 217 | hf_token=hf_token, 218 | zoom_access_token=zoom_access_token, 219 | whisper_config=whisper_config, 220 | ) 221 | 222 | 223 | def _is_valid_config_field(config: dict[str, Any], field: str) -> bool: 224 | return
(field_value := config.get(field)) is not None and len(field_value) > 0 225 | 226 | 227 | def _check_azure_config(config: dict[str, Any], agent_prefix: str) -> None: 228 | if ( 229 | not _is_valid_config_field(config, "azureApiKey") 230 | and os.environ.get("AZURE_OPENAI_API_KEY") is None 231 | ): 232 | raise ValueError("Azure API key not found in config") 233 | if not _is_valid_config_field(config, f"azure{agent_prefix}DeploymentName"): 234 | raise ValueError( 235 | f"Azure {agent_prefix} model deployment name not found in config" 236 | ) 237 | if agent_prefix != AGENT_TYPE_TO_CONFIG_PREFIX[ 238 | AgentType.EMBEDDING 239 | ] and not _is_valid_config_field(config, f"azure{agent_prefix}CtxLen"): 240 | raise ValueError(f"Azure {agent_prefix} context length not found in config") 241 | if not _is_valid_config_field(config, "azureApiVersion"): 242 | raise ValueError("Azure API version not found in config") 243 | if not _is_valid_config_field(config, "azureEndpoint"): 244 | raise ValueError("Azure endpoint not found in config") 245 | 246 | 247 | def _check_openai_config(config: dict[str, Any], agent_prefix: str) -> None: 248 | if ( 249 | not _is_valid_config_field(config, "openAIApiKey") 250 | and os.environ.get("OPENAI_API_KEY") is None 251 | ): 252 | raise ValueError("OpenAI API key not found in config") 253 | # The embedding model for OpenAI API only supports "text-embedding-3-small" hence 254 | # we don't need to check for the model name for the embedding model 255 | if agent_prefix != AGENT_TYPE_TO_CONFIG_PREFIX[ 256 | AgentType.EMBEDDING 257 | ] and not _is_valid_config_field(config, f"openAI{agent_prefix}ModelName"): 258 | raise ValueError(f"OpenAI {agent_prefix} model name not found in config") 259 | 260 | 261 | def _check_local_config(config: dict[str, Any], agent_prefix: str) -> None: 262 | # TinyAgent does not support local embedding models. Hence, we only need to check for 263 | # the context length, port, and tokenizer name or path for the local planner model. 
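# For illustration, a hypothetical local-model section of the JSON config that
# would satisfy the checks below might look like this (key names follow the
# f"local{agent_prefix}..." scheme with an empty prefix for the main agent;
# the values are made up):
#     "localCtxLen": "4096",
#     "localPort": "1234",
#     "localTokenizerNameOrPath": "meta-llama/Llama-2-7b-hf",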
264 | is_embedding_model = ( 265 | agent_prefix == AGENT_TYPE_TO_CONFIG_PREFIX[AgentType.EMBEDDING] 266 | ) 267 | if not is_embedding_model and not _is_valid_config_field( 268 | config, f"local{agent_prefix}CtxLen" 269 | ): 270 | raise ValueError(f"Local {agent_prefix} context length not found in config") 271 | if not is_embedding_model and not _is_valid_config_field( 272 | config, f"local{agent_prefix}Port" 273 | ): 274 | raise ValueError(f"Local {agent_prefix} port not found in config") 275 | if not is_embedding_model and not _is_valid_config_field( 276 | config, f"local{agent_prefix}TokenizerNameOrPath" 277 | ): 278 | raise ValueError( 279 | f"Local {agent_prefix} tokenizer name or path not found in config" 280 | ) 281 | if is_embedding_model and not _is_valid_config_field( 282 | config, f"local{agent_prefix}ModelName" 283 | ): 284 | raise ValueError(f"Local {agent_prefix} model name not found in config") 285 | -------------------------------------------------------------------------------- /src/tiny_agent/models.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from dataclasses import dataclass 4 | from enum import Enum 5 | from typing import Collection 6 | 7 | import torch 8 | from tiktoken import Encoding 9 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 10 | 11 | streaming_queue = asyncio.Queue[str | None]() 12 | 13 | LLM_ERROR_TOKEN = "###LLM_ERROR_TOKEN###" 14 | 15 | TINY_AGENT_DIR = os.path.expanduser("~/Library/Application Support/TinyAgent") 16 | 17 | Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast | Encoding 18 | 19 | 20 | class ModelType(Enum): 21 | AZURE = "azure" 22 | OPENAI = "openai" 23 | LOCAL = "local" 24 | 25 | 26 | class App(Enum): 27 | CALENDAR = "calendar" 28 | CONTACTS = "contacts" 29 | FILES = "files" 30 | MAIL = "mail" 31 | MAPS = "maps" 32 | NOTES = "notes" 33 | REMINDERS = "reminders" 34 | SMS = "sms" 35 | ZOOM = "zoom" 36 | 37 | 38 | class AgentType(Enum): 39 | MAIN = "main" 40 | SUB_AGENT = "sub_agent" 41 | EMBEDDING = "embedding" 42 | 43 | 44 | @dataclass 45 | class ModelConfig: 46 | api_key: str 47 | context_length: int 48 | model_name: str 49 | model_type: ModelType 50 | tokenizer: Tokenizer | None 51 | port: int | None 52 | 53 | 54 | @dataclass 55 | class WhisperConfig: 56 | # Azure is not yet supported for whisper 57 | provider: ModelType 58 | api_key: str | None 59 | port: int | None 60 | 61 | 62 | @dataclass 63 | class TinyAgentConfig: 64 | # Custom configs 65 | apps: Collection[App] 66 | custom_instructions: str | None 67 | # Config for the LLMCompiler model 68 | llmcompiler_config: ModelConfig 69 | # Config for the sub-agent LLM model 70 | sub_agent_config: ModelConfig 71 | # Config for the embedding model 72 | embedding_model_config: ModelConfig | None 73 | # Azure model config 74 | azure_api_version: str | None 75 | azure_endpoint: str | None 76 | # Other tokens 77 | hf_token: str | None 78 | zoom_access_token: str | None 79 | # Whisper config 80 | whisper_config: WhisperConfig 81 | 82 | 83 | class TinyAgentToolName(Enum): 84 | GET_PHONE_NUMBER = "get_phone_number" 85 | GET_EMAIL_ADDRESS = "get_email_address" 86 | CREATE_CALENDAR_EVENT = "create_calendar_event" 87 | OPEN_AND_GET_FILE_PATH = "open_and_get_file_path" 88 | SUMMARIZE_PDF = "summarize_pdf" 89 | COMPOSE_NEW_EMAIL = "compose_new_email" 90 | REPLY_TO_EMAIL = "reply_to_email" 91 | FORWARD_EMAIL = "forward_email" 92 | MAPS_OPEN_LOCATION = "maps_open_location" 93 | MAPS_SHOW_DIRECTIONS = 
"maps_show_directions" 94 | CREATE_NOTE = "create_note" 95 | OPEN_NOTE = "open_note" 96 | APPEND_NOTE_CONTENT = "append_note_content" 97 | CREATE_REMINDER = "create_reminder" 98 | SEND_SMS = "send_sms" 99 | GET_ZOOM_MEETING_LINK = "get_zoom_meeting_link" 100 | 101 | 102 | @dataclass 103 | class InContextExample: 104 | example: str 105 | embedding: torch.Tensor 106 | tools: list[TinyAgentToolName] 107 | 108 | 109 | class ComposeEmailMode(Enum): 110 | NEW = "new" 111 | REPLY = "reply" 112 | FORWARD = "forward" 113 | 114 | 115 | class NotesMode(Enum): 116 | NEW = "new" 117 | APPEND = "append" 118 | 119 | 120 | class TransportationOptions(Enum): 121 | DRIVING = "d" 122 | WALKING = "w" 123 | PUBLIC_TRANSIT = "r" 124 | -------------------------------------------------------------------------------- /src/tiny_agent/prompts.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Sequence 3 | 4 | from src.llm_compiler.constants import ( 5 | END_OF_PLAN, 6 | JOINNER_FINISH, 7 | JOINNER_REPLAN, 8 | SUMMARY_RESULT, 9 | ) 10 | from src.tiny_agent.models import TinyAgentToolName 11 | from src.tools.base import StructuredTool, Tool 12 | from src.utils.logger_utils import log 13 | 14 | NOW = datetime.datetime.now() 15 | 16 | 17 | DEFAULT_PLANNER_IN_CONTEXT_EXAMPLES_PROMPT = ( 18 | "Question: Notify Lutfi Eren Erdogan about the upcoming Apple meeting that is going to start at 3PM on Friday.\n" 19 | '1. get_phone_number("Lutfi Eren Erdogan")\n' 20 | '2. send_sms(["$1"], "Hey Lutfi, just wanted to let you know about the upcoming Apple meeting. It\'s going to be at 3 PM on Friday.")\n' 21 | "Thought: I have succesfully found the contact and sent the message.\n" 22 | f"3. join(){END_OF_PLAN}\n" 23 | "###\n" 24 | "Question: Create a zoom meeting for the upcoming Apple meeting with Eren Erdoğan.\n" 25 | '1. get_email_address("Eren Erdoğan")\n' 26 | '2. get_zoom_meeting_link("Apple Meeting", "2022-10-14 15:00:00", 60, ["$1"])\n' 27 | '3. create_calendar_event("Apple Meeting", "2022-10-14 15:00:00", "2022-10-14 16:00:00", "$2", [], "", None)\n' 28 | "Thought: I have succesfully created the calendar event.\n" 29 | f"4. join(){END_OF_PLAN}\n" 30 | "###\n" 31 | "Question: Show directions to Apple Park.\n" 32 | '1. maps_show_directions("", "Apple Park", "d")\n' 33 | "Thought: I have succesfully shown the directions.\n" 34 | f"2. join(){END_OF_PLAN}\n" 35 | "###\n" 36 | "Question: Send an email to Amir saying that the meeting is postponed to next week.\n" 37 | '1. get_email_address("Amir")\n' 38 | '2. compose_new_email(["$1"], [], "Meeting Postponed," "", [])\n' 39 | f"3. 
join(){END_OF_PLAN}\n" 40 | "###\n" 41 | ) 42 | 43 | 44 | TOOL_SPECIFIC_PROMPTS: list[tuple[set[TinyAgentToolName], str]] = [ 45 | ( 46 | { 47 | TinyAgentToolName.GET_EMAIL_ADDRESS, 48 | TinyAgentToolName.COMPOSE_NEW_EMAIL, 49 | TinyAgentToolName.REPLY_TO_EMAIL, 50 | TinyAgentToolName.FORWARD_EMAIL, 51 | }, 52 | f" - Before sending an email, you MUST use the {TinyAgentToolName.GET_EMAIL_ADDRESS.value} tool to get the email addresses of the recipients and cc, unless you are explicitly given their email addresses.\n", 53 | ), 54 | ( 55 | {TinyAgentToolName.GET_PHONE_NUMBER, TinyAgentToolName.SEND_SMS}, 56 | f" - Before sending an SMS, you MUST use the {TinyAgentToolName.GET_PHONE_NUMBER.value} tool to get the phone number of the contact, unless you are explicitly given their phone number.\n" 57 | " - If you need to send an SMS message to multiple contacts, send it in one message, unless specified otherwise.\n", 58 | ), 59 | ( 60 | { 61 | TinyAgentToolName.GET_EMAIL_ADDRESS, 62 | TinyAgentToolName.COMPOSE_NEW_EMAIL, 63 | TinyAgentToolName.REPLY_TO_EMAIL, 64 | TinyAgentToolName.FORWARD_EMAIL, 65 | TinyAgentToolName.GET_PHONE_NUMBER, 66 | TinyAgentToolName.SEND_SMS, 67 | }, 68 | f" - If you need to send an email or an sms using {TinyAgentToolName.COMPOSE_NEW_EMAIL.value}, {TinyAgentToolName.REPLY_TO_EMAIL.value}, {TinyAgentToolName.FORWARD_EMAIL.value}, or {TinyAgentToolName.SEND_SMS.value} tools, " 69 | f"you MUST send it before calling join(), or you WILL BE PENALIZED!\n", 70 | ), 71 | ( 72 | {TinyAgentToolName.GET_ZOOM_MEETING_LINK}, 73 | f" - If you need to create a zoom meeting, you MUST use {TinyAgentToolName.GET_ZOOM_MEETING_LINK.value} to get the newly created zoom meeting link.\n", 74 | ), 75 | ( 76 | {TinyAgentToolName.OPEN_NOTE, TinyAgentToolName.APPEND_NOTE_CONTENT}, 77 | f" - If you need to append some content to a note, you DON'T HAVE TO call {TinyAgentToolName.OPEN_NOTE.value} before calling {TinyAgentToolName.APPEND_NOTE_CONTENT.value}. You can directly use {TinyAgentToolName.APPEND_NOTE_CONTENT.value} to append some content to the specific note.\n", 78 | ), 79 | ( 80 | {TinyAgentToolName.MAPS_OPEN_LOCATION, TinyAgentToolName.MAPS_SHOW_DIRECTIONS}, 81 | f" - If you need to show directions to a place, you DON'T HAVE TO call {TinyAgentToolName.MAPS_OPEN_LOCATION.value} before calling {TinyAgentToolName.MAPS_SHOW_DIRECTIONS.value}. You can directly use {TinyAgentToolName.MAPS_SHOW_DIRECTIONS.value} to show directions to the specific place.\n", 82 | ), 83 | ] 84 | 85 | 86 | def get_planner_custom_instructions_prompt( 87 | tools: Sequence[Tool | StructuredTool], custom_instructions: str | None 88 | ) -> str: 89 | prompt = [] 90 | prompt.append( 91 | " - You need to start your plan with the '1.' 
call\n" 92 | f" - Today's date is {NOW.strftime('%A %Y-%m-%d %H:%M')}\n" 93 | " - Unless otherwise specified, the default meeting duration is 60 minutes.\n" 94 | " - Do not use named arguments in your tool calls.\n" 95 | " - You MUST end your plans with the 'join()' call and a '\\n' character.\n" 96 | " - You MUST fill every argument in the tool calls, even if they are optional.\n" 97 | " - The format for dates MUST be in ISO format of 'YYYY-MM-DD HH:MM:SS', unless other specified.\n" 98 | " - If you want to use the result of a previous tool call, you MUST use the '$' sign followed by the index of the tool call.\n" 99 | f" - You MUST ONLY USE join() at the very very end of the plan, or you WILL BE PENALIZED.\n" 100 | ) 101 | 102 | for tool_set, instructions in TOOL_SPECIFIC_PROMPTS: 103 | if any(TinyAgentToolName(tool.name) in tool_set for tool in tools): 104 | prompt += instructions 105 | 106 | if custom_instructions is not None and len(custom_instructions) > 0: 107 | prompt.append(f" - {custom_instructions}") 108 | 109 | return "".join(prompt) 110 | 111 | 112 | PLANNER_PROMPT_REPLAN = ( 113 | "Question: Say hi to Sid via SMS.\n\n" 114 | "Previous Plan:\n\n" 115 | "1. join()\n" 116 | "Observation:\nThe plan generation was stopped due to an error in tool 1. get_contact_info('Sid')! " 117 | "Error: Tool get_contact_info not found. You MUST correct this error and try again!" 118 | "\n" 119 | "Current Plan:\n\n" 120 | f"Thought: The error is fixable since I have the {TinyAgentToolName.GET_PHONE_NUMBER.value} tool to retrieve the phone number of Sid. Then I will proceed with sending the SMS.\n" 121 | '1. get_phone_number("Sid")\n' 122 | '2. send_sms("$2", "Hi Sid!")\n' 123 | "Thought: I have succesfully created the retrieved the phone number and sent the SMS.\n" 124 | f"4. join(){END_OF_PLAN}\n" 125 | "###\n" 126 | "Question: Summarize 'Apple Demo.pdf'.\n\n" 127 | "Previous Plan:\n\n" 128 | '1. open_and_get_file_path("Apple Demo")\n' 129 | "2. join()\n" 130 | "Observation: summarize_pdf() takes 1 positional arguments but 2 were given! You MUST correct this error and try again!" 131 | "\n" 132 | "Current Plan:\n\n" 133 | f"Thought: Previous plan tried to call the summarize_pdf() tool with the wrong number of arguments. I will correct this and try again.\n" 134 | '1. open_and_get_file_path("Apple Demo")\n' 135 | '2. summarize_pdf("$1")\n' 136 | "Thought: I have succesfully opened the file and summarized it.\n" 137 | f"3. join(){END_OF_PLAN}\n" 138 | "###\n" 139 | ) 140 | 141 | JOINNER_REPLAN_RULES = ( 142 | f" - If you think the plan is not completed yet or an error in the plan is fixable, you should output {JOINNER_REPLAN}.\n" 143 | f" - If the plan is fixable, you will see a message like 'try again'. If you don't see this message, the error is NOT fixable and you MUST output an error message using 'Action: {JOINNER_FINISH}()'\n" 144 | ) 145 | 146 | JOINNER_FINISH_RULES = ( 147 | f" - If you need to answer some knowledge question, just answer it directly using 'Action: {JOINNER_FINISH}()'.\n" 148 | f" - If you need to return the result of a summary (summarize_pdf), you MUST use 'Action: {JOINNER_FINISH}({SUMMARY_RESULT})'\n" 149 | f" - If there is an error in one of the tool calls and it is not fixable, you should provide a user-friendly error message using 'Action: {JOINNER_FINISH}()'.\n" 150 | ) 151 | 152 | REPLAN_EXAMPLES = [ 153 | "Question: Say hi to Sid via SMS.\n" 154 | "join()\n" 155 | "Observation: The plan generation was stopped due to an error in tool 1. get_contact_info('Sid')! 
" 156 | "Error: Tool get_contact_info not found. You MUST correct this error and try again!" 157 | "Thought: The error is fixable so I need to replan and try again.\n" 158 | f"Action: {JOINNER_REPLAN}\n" 159 | ] 160 | 161 | FINISH_EXAMPLES = [ 162 | "Question: Create a zoom meeting for the upcoming Apple meeting with Eren Erdoğan. \n" 163 | 'get_email_address("Eren Erdoğan")\n' 164 | "Observation: eren@gmail.com\n" 165 | 'get_zoom_meeting_link("Apple Meeting", "2022-10-14 15:00:00", 60, ["$1"])\n' 166 | "Observation: https://zoom.us/j/1234567890?pwd=abc123\n" 167 | 'create_calendar_event("Apple Meeting", "2022-10-14 15:00:00", "2022-10-14 16:00:00", "Apple HQ", "$2", None)\n' 168 | "Observation: Event created successfully\n" 169 | "Thought: I don't need to answer a question.\n" 170 | f"Action: {JOINNER_FINISH}(Task completed!)\n", 171 | "Question: What is the content of the Apple meeting notes? \n" 172 | 'get_note_content("Apple Meeting")\n' 173 | "Observation: The meeting is about the new iPhone release.\n" 174 | "Thought: I can just answer the question directly.\n" 175 | f"Action: {JOINNER_FINISH}(The meeting is about the new iPhone release.)\n", 176 | "Question: Compose a new email to John, attaching the Project.pdf file.\n" 177 | 'get_email_address("John")\n' 178 | "Observation: john@doe.com" 179 | 'open_and_get_file_path("Project")\n' 180 | "Observation: /Users/eren/Downloads/Project.pdf\n" 181 | 'compose_new_email([john@doe.com], [], "Project Update", "Please find the attached project update.", ["/Users/eren/Downloads/Project.pdf"])\n' 182 | "Observation: There was an error while composing the email.\n" 183 | "Thought: There was an error with the compose_new_email tool call and it is not possible to fix it. I need to provide a user-friendly error message.\n" 184 | f"Action: {JOINNER_FINISH}(There was an error while composing the email. Please try again later.)\n", 185 | "Question: Summarize the Apple Demo file. \n" 186 | "open_and_get_file_path(Apple Demo)\n" 187 | "Observation: /Users/eren/Downloads/Apple_Demo.pdf\n" 188 | "summarize_pdf(/Users/eren/Downloads/Apple_Demo.pdf)\n" 189 | "Observation: The new iPhone is going to be released in 2023.\n" 190 | f"Action: {JOINNER_FINISH}({SUMMARY_RESULT})\n", 191 | ] 192 | 193 | OUTPUT_PROMPT = ( 194 | "Follow these rules:\n" 195 | f" - You MUST only output either {JOINNER_FINISH} or {JOINNER_REPLAN}, or you WILL BE PENALIZED.\n" 196 | f"{JOINNER_FINISH_RULES}" 197 | f"{JOINNER_REPLAN_RULES}" 198 | "\n" 199 | "Here are some examples:\n" 200 | + "###\n".join(FINISH_EXAMPLES) 201 | + "###\n" 202 | + "###\n".join(REPLAN_EXAMPLES) 203 | + "###\n" 204 | ) 205 | 206 | 207 | OUTPUT_PROMPT_FINAL = ( 208 | "Follow these rules:\n" 209 | f" - You MUST only output {JOINNER_FINISH}, or you WILL BE PENALIZED.\n" 210 | f"{JOINNER_FINISH_RULES}" 211 | "\n" 212 | "Here are some examples:\n" + "###\n".join(FINISH_EXAMPLES) + "###\n" 213 | ) 214 | -------------------------------------------------------------------------------- /src/tiny_agent/run_apple_script.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def run_applescript(script: str) -> str: 5 | """ 6 | Runs the given AppleScript using osascript and returns the result. 
7 | """ 8 | args = ["osascript", "-e", script] 9 | return subprocess.check_output(args, universal_newlines=True) 10 | 11 | 12 | def run_applescript_capture(script: str) -> tuple[str, str]: 13 | """ 14 | Runs the given AppleScript using osascript, captures the output and error, and returns them. 15 | """ 16 | args = ["osascript", "-e", script] 17 | result = subprocess.run(args, capture_output=True, text=True, check=False) 18 | stdout, stderr = result.stdout, result.stderr 19 | return stdout, stderr 20 | 21 | 22 | def run_command(command) -> tuple[str, str]: 23 | """ 24 | Executes a shell command and returns the output. 25 | """ 26 | result = subprocess.run(command, capture_output=True, text=True, check=False) 27 | stdout, stderr = result.stdout, result.stderr 28 | return stdout, stderr 29 | -------------------------------------------------------------------------------- /src/tiny_agent/sub_agents/compose_email_agent.py: -------------------------------------------------------------------------------- 1 | import re 2 | from enum import Enum 3 | 4 | from langchain_core.messages import HumanMessage, SystemMessage 5 | 6 | from src.tiny_agent.models import ComposeEmailMode 7 | from src.tiny_agent.sub_agents.sub_agent import SubAgent 8 | 9 | 10 | class ComposeEmailAgent(SubAgent): 11 | _query: str 12 | 13 | @property 14 | def query(self) -> str: 15 | return self._query 16 | 17 | @query.setter 18 | def query(self, query: str) -> None: 19 | self._query = query 20 | 21 | async def __call__( 22 | self, 23 | context: str, 24 | email_thread: str = "", 25 | mode: ComposeEmailMode = ComposeEmailMode.NEW, 26 | ) -> str: 27 | # Define the system prompt for the LLM to generate HTML content 28 | context = context.strip() 29 | cleaned_thread = re.sub(r"\n+", "\n", email_thread.strip()) 30 | if mode == ComposeEmailMode.NEW: 31 | email_llm_system_prompt = ( 32 | "You are an expert email composer agent. Given an email content or a user query, you MUST generate a well-formatted and " 33 | "informative email. The email should include a polite greeting, a detailed body, and a " 34 | "professional sign-off. You MUST NOT include a subject. The email should be well-structured and free of grammatical errors." 35 | ) 36 | elif mode == ComposeEmailMode.REPLY: 37 | email_llm_system_prompt = ( 38 | "You are an expert email composer agent. Given the content of the past email thread and a user query, " 39 | "you MUST generate a well-formatted and informative reply to the last email in the thread. " 40 | "The email should include a polite greeting, a detailed body, and a " 41 | "professional sign-off. You MUST NOT include a subject. The email should be well-structured and free of grammatical errors." 42 | ) 43 | context += f"\nEmail Thread:\n{cleaned_thread}" 44 | elif mode == ComposeEmailMode.FORWARD: 45 | email_llm_system_prompt = ( 46 | "You are an expert email composer agent. Given the content of the past email thread and a user query, " 47 | "you MUST generate a very concise and informative forward of the last email in the thread. 
" 48 | ) 49 | context += f"\nEmail Thread:\n{cleaned_thread}" 50 | 51 | # Add custom instructions to the system prompt if specified 52 | if self._custom_instructions is not None: 53 | email_llm_system_prompt += ( 54 | "\nHere are some general facts about the user's preferences, " 55 | f"you MUST keep these in mind when writing your email:\n{self._custom_instructions}" 56 | ) 57 | 58 | email_human_prompt = "Context:\n{context}\nQuery: {query}\nEmail Body:\n" 59 | messages = [ 60 | SystemMessage(content=email_llm_system_prompt), 61 | HumanMessage( 62 | content=email_human_prompt.format(context=context, query=self._query) 63 | ), 64 | ] 65 | 66 | # If the context content is too long, then get the first X tokens of the subagent llm 67 | new_context = self.check_context_length(messages, context) 68 | if new_context is not None: 69 | messages = [ 70 | SystemMessage(content=email_llm_system_prompt), 71 | HumanMessage( 72 | content=email_human_prompt.format( 73 | context=new_context, query=self._query 74 | ) 75 | ), 76 | ] 77 | 78 | # Generate the HTML content for the email 79 | email_content = await self._llm.apredict_messages(messages) 80 | 81 | return str(email_content.content) 82 | -------------------------------------------------------------------------------- /src/tiny_agent/sub_agents/notes_agent.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from bs4 import BeautifulSoup 4 | from langchain_core.messages import HumanMessage, SystemMessage 5 | 6 | from src.tiny_agent.models import NotesMode 7 | from src.tiny_agent.sub_agents.sub_agent import SubAgent 8 | 9 | 10 | class NotesAgent(SubAgent): 11 | 12 | async def __call__( 13 | self, 14 | name: str, 15 | content: str, 16 | prev_content: str = "", 17 | mode: NotesMode = NotesMode.NEW, 18 | ) -> str: 19 | # Construct the prompt 20 | if mode == NotesMode.NEW: 21 | notes_llm_system_prompt = ( 22 | "You are an expert note taking agent. Given the plain text content, you MUST generate a compelling and " 23 | "formatted HTML version of the note (with , , tags, etc.) The content of the note should be rich, well-structured, and verbose. " 24 | ) 25 | elif mode == NotesMode.APPEND: 26 | notes_llm_system_prompt = ( 27 | "You are an expert note-taking agent that specializes in appending new content to existing notes. " 28 | "Given the content of an existing note and the content to append, you MUST generate a continuation of the note that " 29 | "seamlessly integrates with the existing content. Your additions should maintain the tone, style, " 30 | "and subject matter of the original note. You MUST ONLY output the appended content in HTML format, DO NOT include the entire note, or you WILL BE PENALIZED.\n" 31 | f"Previous Content:\n{prev_content}\n" 32 | ) 33 | 34 | notes_llm_system_prompt += ( 35 | "The note should include appropriate use of headings, bold and italic text for emphasis, " 36 | "bullet points for lists, and paragraph tags for separation of ideas. DO NOT include the '' tag." 
37 | ) 38 | 39 | # Add custom instructions to the system prompt if specified 40 | if self._custom_instructions is not None: 41 | notes_llm_system_prompt += ( 42 | "\nHere are some general facts about the user's preferences, " 43 | f"you MUST keep these in mind when generating your note:\n{self._custom_instructions}" 44 | ) 45 | 46 | notes_human_prompt = ( 47 | "Note Title: {name}\nNew Text Content: {content}\nHTML Content:\n" 48 | ) 49 | messages = [ 50 | SystemMessage(content=notes_llm_system_prompt), 51 | HumanMessage(content=notes_human_prompt.format(name=name, content=content)), 52 | ] 53 | plain_text_content = self.check_context_length(messages, content) 54 | 55 | # If the messages exceed the context window, rebuild them with the truncated plain text content 56 | if plain_text_content is not None: 57 | messages = [ 58 | SystemMessage(content=notes_llm_system_prompt), 59 | HumanMessage( 60 | content=notes_human_prompt.format(name=name, content=plain_text_content) 61 | ), 62 | ] 63 | 64 | # Generate the HTML content for the note 65 | html_content = await self._llm.apredict_messages(messages) 66 | 67 | # If the output doesn't start with an <html> tag, then wrap it in one 68 | soup = BeautifulSoup(str(html_content.content), "html.parser") 69 | if not soup.find("html"): 70 | soup = BeautifulSoup( 71 | f"<html>{str(html_content.content)}</html>", "html.parser" 72 | ) 73 | 74 | return str(soup.find("html")) 75 | -------------------------------------------------------------------------------- /src/tiny_agent/sub_agents/pdf_summarizer_agent.py: -------------------------------------------------------------------------------- 1 | import fitz 2 | from langchain_core.messages import HumanMessage, SystemMessage 3 | 4 | from src.tiny_agent.sub_agents.sub_agent import SubAgent 5 | 6 | CONTEXT_LENGTHS = {"gpt-4-1106-preview": 127000, "gpt-3.5-turbo": 16000} 7 | 8 | 9 | class PDFSummarizerAgent(SubAgent): 10 | _cached_summary_result: str 11 | 12 | @property 13 | def cached_summary_result(self) -> str: 14 | return self._cached_summary_result 15 | 16 | async def __call__(self, pdf_path: str) -> str: 17 | # Check if the file exists 18 | if ( 19 | pdf_path is None 20 | or len(pdf_path) <= 0 21 | or pdf_path 22 | in ( 23 | "No file found after fuzzy matching.", 24 | "No file found with exact or fuzzy name.", 25 | ) 26 | ): 27 | return "The PDF file path is invalid or the file doesn't exist." 28 | 29 | try: 30 | pdf_content = PDFSummarizerAgent._extract_text_from_pdf(pdf_path) 31 | except Exception as e: 32 | return f"An error occurred while extracting the content from the PDF file: {str(e)}" 33 | 34 | if len(pdf_content) <= 0: 35 | return "The PDF file is empty or the content couldn't be extracted." 36 | 37 | # Construct the prompt 38 | pdf_summarizer_llm_system_prompt = ( 39 | "You are an expert PDF summarizer agent. Given the PDF content, you MUST generate an informative and verbose " 40 | "summary of the content. The summary should include the main points and key details of the content. 
" 41 | ) 42 | pdf_summarizer_human_prompt = "PDF Content:\n{pdf_content}\nSummary:\n" 43 | messages = [ 44 | SystemMessage(content=pdf_summarizer_llm_system_prompt), 45 | HumanMessage( 46 | content=pdf_summarizer_human_prompt.format(pdf_content=pdf_content) 47 | ), 48 | ] 49 | 50 | # If the PDF content is too long, then get the first X tokens of the subagent llm 51 | pdf_content = self.check_context_length(messages, pdf_content) 52 | if pdf_content is not None: 53 | messages = [ 54 | SystemMessage(content=pdf_summarizer_llm_system_prompt), 55 | HumanMessage( 56 | content=pdf_summarizer_human_prompt.format(pdf_content=pdf_content) 57 | ), 58 | ] 59 | 60 | # Call LLM 61 | summary = await self._llm.apredict_messages(messages) 62 | 63 | # Cache the summary result 64 | self._cached_summary_result = str(summary.content) 65 | 66 | return self._cached_summary_result 67 | 68 | @staticmethod 69 | def _extract_text_from_pdf(pdf_path: str) -> str: 70 | doc = fitz.open(pdf_path) 71 | text = [] 72 | for page in doc: 73 | text.append(page.get_text()) # type: ignore 74 | doc.close() 75 | return "".join(text).replace("\n", " ") 76 | -------------------------------------------------------------------------------- /src/tiny_agent/sub_agents/sub_agent.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from langchain.llms.base import BaseLLM 4 | from langchain_core.messages import BaseMessage, get_buffer_string 5 | 6 | from src.tiny_agent.config import ModelConfig 7 | from src.tiny_agent.models import Tokenizer 8 | 9 | 10 | class SubAgent(abc.ABC): 11 | # A constant to make the context length check more conservative 12 | # And to allow for more output tokens to be generated 13 | _CONTEXT_LENGTH_TRUST = 500 14 | 15 | _llm: BaseLLM 16 | _tokenizer: Tokenizer 17 | _context_length: int 18 | _custom_instructions: str | None 19 | 20 | def __init__( 21 | self, 22 | llm: BaseLLM, 23 | config: ModelConfig, 24 | custom_instructions: str | None, 25 | ) -> None: 26 | assert ( 27 | config.tokenizer is not None 28 | ), "Tokenizer must be provided for sub-agents." 29 | self._llm = llm 30 | self._tokenizer = config.tokenizer 31 | self._context_length = config.context_length 32 | self._custom_instructions = custom_instructions 33 | 34 | @abc.abstractmethod 35 | async def __call__(self, *args, **kwargs) -> str: 36 | pass 37 | 38 | def check_context_length( 39 | self, messages: list[BaseMessage], context: str 40 | ) -> str | None: 41 | """ 42 | Checks if the final length of the messages is greater than the context length, 43 | and if so, removes the excess tokens from the context. If no, return None. 
44 | """ 45 | text = get_buffer_string(messages) 46 | total_length = len(self._tokenizer.encode(text)) 47 | length_to_remove = total_length - self._context_length 48 | 49 | if length_to_remove <= 0: 50 | return None 51 | 52 | context = self._tokenizer.decode( 53 | self._tokenizer.encode(context)[ 54 | : -length_to_remove - self._CONTEXT_LENGTH_TRUST 55 | ] 56 | ) 57 | 58 | return context 59 | -------------------------------------------------------------------------------- /src/tiny_agent/tiny_agent.py: -------------------------------------------------------------------------------- 1 | from src.llm_compiler.constants import END_OF_PLAN, SUMMARY_RESULT 2 | from src.llm_compiler.llm_compiler import LLMCompiler 3 | from src.llm_compiler.planner import generate_llm_compiler_prompt 4 | from src.tiny_agent.computer import Computer 5 | from src.tiny_agent.config import TinyAgentConfig 6 | from src.tiny_agent.prompts import ( 7 | DEFAULT_PLANNER_IN_CONTEXT_EXAMPLES_PROMPT, 8 | OUTPUT_PROMPT, 9 | OUTPUT_PROMPT_FINAL, 10 | PLANNER_PROMPT_REPLAN, 11 | get_planner_custom_instructions_prompt, 12 | ) 13 | from src.tiny_agent.sub_agents.compose_email_agent import ComposeEmailAgent 14 | from src.tiny_agent.sub_agents.notes_agent import NotesAgent 15 | from src.tiny_agent.sub_agents.pdf_summarizer_agent import PDFSummarizerAgent 16 | from src.tiny_agent.tiny_agent_tools import ( 17 | get_tiny_agent_tools, 18 | get_tool_names_from_apps, 19 | ) 20 | from src.tiny_agent.tool_rag.base_tool_rag import BaseToolRAG 21 | from src.tiny_agent.tool_rag.classifier_tool_rag import ClassifierToolRAG 22 | from src.utils.model_utils import get_embedding_model, get_model 23 | 24 | 25 | class TinyAgent: 26 | _DEFAULT_TOP_K = 6 27 | 28 | config: TinyAgentConfig 29 | agent: LLMCompiler 30 | computer: Computer 31 | notes_agent: NotesAgent 32 | pdf_summarizer_agent: PDFSummarizerAgent 33 | compose_email_agent: ComposeEmailAgent 34 | tool_rag: BaseToolRAG 35 | 36 | def __init__(self, config: TinyAgentConfig) -> None: 37 | self.config = config 38 | 39 | # Define the models 40 | llm = get_model( 41 | model_type=config.llmcompiler_config.model_type.value, 42 | model_name=config.llmcompiler_config.model_name, 43 | api_key=config.llmcompiler_config.api_key, 44 | stream=False, 45 | vllm_port=config.llmcompiler_config.port, 46 | temperature=0, 47 | azure_api_version=config.azure_api_version, 48 | azure_endpoint=config.azure_endpoint, 49 | azure_deployment=config.llmcompiler_config.model_name, 50 | ) 51 | planner_llm = get_model( 52 | model_type=config.llmcompiler_config.model_type.value, 53 | model_name=config.llmcompiler_config.model_name, 54 | api_key=config.llmcompiler_config.api_key, 55 | stream=True, 56 | vllm_port=config.llmcompiler_config.port, 57 | temperature=0, 58 | azure_api_version=config.azure_api_version, 59 | azure_endpoint=config.azure_endpoint, 60 | azure_deployment=config.llmcompiler_config.model_name, 61 | ) 62 | sub_agent_llm = get_model( 63 | model_type=config.sub_agent_config.model_type.value, 64 | model_name=config.sub_agent_config.model_name, 65 | api_key=config.sub_agent_config.api_key, 66 | stream=False, 67 | vllm_port=config.sub_agent_config.port, 68 | temperature=0, 69 | azure_api_version=config.azure_api_version, 70 | azure_endpoint=config.azure_endpoint, 71 | azure_deployment=config.sub_agent_config.model_name, 72 | ) 73 | 74 | self.computer = Computer() 75 | self.notes_agent = NotesAgent( 76 | sub_agent_llm, config.sub_agent_config, config.custom_instructions 77 | ) 78 | self.pdf_summarizer_agent = 
PDFSummarizerAgent( 79 | sub_agent_llm, config.sub_agent_config, config.custom_instructions 80 | ) 81 | self.compose_email_agent = ComposeEmailAgent( 82 | sub_agent_llm, config.sub_agent_config, config.custom_instructions 83 | ) 84 | 85 | tools = get_tiny_agent_tools( 86 | computer=self.computer, 87 | notes_agent=self.notes_agent, 88 | pdf_summarizer_agent=self.pdf_summarizer_agent, 89 | compose_email_agent=self.compose_email_agent, 90 | tool_names=get_tool_names_from_apps(config.apps), 91 | zoom_access_token=config.zoom_access_token, 92 | ) 93 | 94 | # Define LLMCompiler 95 | self.agent = LLMCompiler( 96 | tools=tools, 97 | planner_llm=planner_llm, 98 | planner_custom_instructions_prompt=get_planner_custom_instructions_prompt( 99 | tools=tools, custom_instructions=config.custom_instructions 100 | ), 101 | planner_example_prompt=DEFAULT_PLANNER_IN_CONTEXT_EXAMPLES_PROMPT, 102 | planner_example_prompt_replan=PLANNER_PROMPT_REPLAN, 103 | planner_stop=[END_OF_PLAN], 104 | planner_stream=True, 105 | agent_llm=llm, 106 | joinner_prompt=OUTPUT_PROMPT, 107 | joinner_prompt_final=OUTPUT_PROMPT_FINAL, 108 | max_replans=2, 109 | benchmark=False, 110 | ) 111 | 112 | # Define ToolRAG 113 | if config.embedding_model_config is not None: 114 | embedding_model = get_embedding_model( 115 | model_type=config.embedding_model_config.model_type.value, 116 | model_name=config.embedding_model_config.model_name, 117 | api_key=config.embedding_model_config.api_key, 118 | azure_endpoint=config.azure_endpoint, 119 | azure_embedding_deployment=config.embedding_model_config.model_name, 120 | azure_api_version=config.azure_api_version, 121 | local_port=config.embedding_model_config.port, 122 | context_length=config.embedding_model_config.context_length, 123 | ) 124 | self.tool_rag = ClassifierToolRAG( 125 | embedding_model=embedding_model, 126 | tools=tools, 127 | ) 128 | 129 | async def arun(self, query: str) -> str: 130 | if self.config.embedding_model_config is not None: 131 | tool_rag_results = self.tool_rag.retrieve_examples_and_tools( 132 | query, top_k=TinyAgent._DEFAULT_TOP_K 133 | ) 134 | 135 | new_tools = get_tiny_agent_tools( 136 | computer=self.computer, 137 | notes_agent=self.notes_agent, 138 | pdf_summarizer_agent=self.pdf_summarizer_agent, 139 | compose_email_agent=self.compose_email_agent, 140 | tool_names=tool_rag_results.retrieved_tools_set, 141 | zoom_access_token=self.config.zoom_access_token, 142 | ) 143 | 144 | self.agent.planner.system_prompt = generate_llm_compiler_prompt( 145 | tools=new_tools, 146 | example_prompt=tool_rag_results.in_context_examples_prompt, 147 | custom_instructions=get_planner_custom_instructions_prompt( 148 | tools=new_tools, custom_instructions=self.config.custom_instructions 149 | ), 150 | ) 151 | 152 | self.compose_email_agent.query = query 153 | result = await self.agent.arun(query) 154 | 155 | if result == SUMMARY_RESULT: 156 | result = self.pdf_summarizer_agent.cached_summary_result 157 | 158 | return result 159 | -------------------------------------------------------------------------------- /src/tiny_agent/tool_rag/base_tool_rag.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import pickle 4 | from dataclasses import dataclass 5 | from typing import Collection, Sequence 6 | 7 | import torch 8 | from langchain_community.embeddings import HuggingFaceEmbeddings 9 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings 10 | from typing_extensions import TypedDict 11 | 12 | from 
src.tiny_agent.config import DEFAULT_OPENAI_EMBEDDING_MODEL 13 | from src.tiny_agent.models import TinyAgentToolName 14 | from src.tools.base import StructuredTool, Tool 15 | 16 | TOOLRAG_DIR_PATH = os.path.dirname(os.path.abspath(__file__)) 17 | 18 | 19 | @dataclass 20 | class ToolRAGResult: 21 | in_context_examples_prompt: str 22 | retrieved_tools_set: Collection[TinyAgentToolName] 23 | 24 | 25 | class PickledEmbedding(TypedDict): 26 | example: str 27 | embedding: torch.Tensor 28 | tools: Sequence[str] 29 | 30 | 31 | class BaseToolRAG(abc.ABC): 32 | """ 33 | The base class for the ToolRAGs that are used to retrieve the in-context examples and tools based on the user query. 34 | """ 35 | 36 | _EMBEDDINGS_DIR_PATH = os.path.join(TOOLRAG_DIR_PATH) 37 | _EMBEDDINGS_FILE_NAME = "embeddings.pkl" 38 | 39 | # Embedding model that computes the embeddings for the examples/tools and the user query 40 | _embedding_model: AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings 41 | # The set of available tools so that we do an initial filtering based on the tools that are available 42 | _available_tools: Sequence[TinyAgentToolName] 43 | # The path to the embeddings.pkl file 44 | _embeddings_pickle_path: str 45 | 46 | def __init__( 47 | self, 48 | embedding_model: ( 49 | AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings 50 | ), 51 | tools: Sequence[Tool | StructuredTool], 52 | ) -> None: 53 | self._embedding_model = embedding_model 54 | self._available_tools = [TinyAgentToolName(tool.name) for tool in tools] 55 | 56 | # TinyAgent currently only supports "text-embedding-3-small" model by default. 57 | # Hence, we only use the directory created for the default model. 58 | model_name = DEFAULT_OPENAI_EMBEDDING_MODEL 59 | self._embeddings_pickle_path = os.path.join( 60 | BaseToolRAG._EMBEDDINGS_DIR_PATH, 61 | model_name.split("/")[-1], # Only use the last model name 62 | BaseToolRAG._EMBEDDINGS_FILE_NAME, 63 | ) 64 | 65 | @property 66 | @abc.abstractmethod 67 | def tool_rag_type(self) -> str: 68 | pass 69 | 70 | @abc.abstractmethod 71 | def retrieve_examples_and_tools(self, query: str, top_k: int) -> ToolRAGResult: 72 | """ 73 | Returns the in-context examples as a formatted prompt and the tools that are relevant to the query. 74 | """ 75 | pass 76 | 77 | def _retrieve_top_k_embeddings( 78 | self, query: str, examples: list[PickledEmbedding], top_k: int 79 | ) -> list[PickledEmbedding]: 80 | """ 81 | Computes the cosine similarity of each example and retrieves the closest top_k examples. 82 | If there are already less than top_k examples, returns the examples directly. 
83 | """ 84 | if len(examples) <= top_k: 85 | return examples 86 | 87 | query_embedding = torch.tensor(self._embedding_model.embed_query(query)) 88 | embeddings = torch.stack( 89 | [x["embedding"] for x in examples] 90 | ) # Stacking for batch processing 91 | 92 | # Cosine similarity between query_embedding and all chunks 93 | cosine_similarities = torch.nn.functional.cosine_similarity( 94 | embeddings, query_embedding.unsqueeze(0), dim=1 95 | ) 96 | 97 | # Retrieve the top k indices from cosine_similarities 98 | _, top_k_indices = torch.topk(cosine_similarities, top_k) 99 | 100 | # Select the chunks corresponding to the top k indices 101 | selected_examples = [examples[i] for i in top_k_indices] 102 | 103 | return selected_examples 104 | 105 | def _load_filtered_embeddings( 106 | self, filter_tools: list[TinyAgentToolName] | None = None 107 | ) -> list[PickledEmbedding]: 108 | """ 109 | Loads the embeddings.pkl file that contains a list of PickledEmbedding objects 110 | and returns the filtered results based on the available tools. 111 | """ 112 | with open(self._embeddings_pickle_path, "rb") as file: 113 | embeddings: dict[str, PickledEmbedding] = pickle.load(file) 114 | 115 | filtered_embeddings = [] 116 | tool_names = [tool.value for tool in filter_tools or self._available_tools] 117 | for embedding in embeddings.values(): 118 | # Check if all tools are available in this example 119 | if all(tool in tool_names for tool in embedding["tools"]): 120 | filtered_embeddings.append(embedding) 121 | 122 | return filtered_embeddings 123 | 124 | @staticmethod 125 | def _get_in_context_examples_prompt(embeddings: list[PickledEmbedding]) -> str: 126 | examples = [example["example"] for example in embeddings] 127 | examples_prompt = "###\n".join(examples) 128 | return f"{examples_prompt}###\n" 129 | -------------------------------------------------------------------------------- /src/tiny_agent/tool_rag/classifier_tool_rag.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Sequence 3 | 4 | import torch 5 | from langchain_community.embeddings import HuggingFaceEmbeddings 6 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings 7 | from transformers import ( 8 | AutoModelForSequenceClassification, 9 | AutoTokenizer, 10 | PreTrainedTokenizer, 11 | PreTrainedTokenizerFast, 12 | ) 13 | 14 | from src.tiny_agent.models import TinyAgentToolName 15 | from src.tiny_agent.tool_rag.base_tool_rag import BaseToolRAG, ToolRAGResult 16 | from src.tools.base import StructuredTool, Tool 17 | 18 | 19 | class ClassifierToolRAG(BaseToolRAG): 20 | _CLASSIFIER_MODEL_NAME = "squeeze-ai-lab/TinyAgent-ToolRAG" 21 | _DEFAULT_TOOL_THRESHOLD = 0.5 22 | _NUM_LABELS = 16 23 | _ID_TO_TOOL = { 24 | 0: TinyAgentToolName.CREATE_CALENDAR_EVENT, 25 | 1: TinyAgentToolName.GET_PHONE_NUMBER, 26 | 2: TinyAgentToolName.GET_EMAIL_ADDRESS, 27 | 3: TinyAgentToolName.OPEN_AND_GET_FILE_PATH, 28 | 4: TinyAgentToolName.SUMMARIZE_PDF, 29 | 5: TinyAgentToolName.COMPOSE_NEW_EMAIL, 30 | 6: TinyAgentToolName.REPLY_TO_EMAIL, 31 | 7: TinyAgentToolName.FORWARD_EMAIL, 32 | 8: TinyAgentToolName.MAPS_OPEN_LOCATION, 33 | 9: TinyAgentToolName.MAPS_SHOW_DIRECTIONS, 34 | 10: TinyAgentToolName.CREATE_NOTE, 35 | 11: TinyAgentToolName.OPEN_NOTE, 36 | 12: TinyAgentToolName.APPEND_NOTE_CONTENT, 37 | 13: TinyAgentToolName.CREATE_REMINDER, 38 | 14: TinyAgentToolName.SEND_SMS, 39 | 15: TinyAgentToolName.GET_ZOOM_MEETING_LINK, 40 | } 41 | 42 | _tokenizer: PreTrainedTokenizer | 
PreTrainedTokenizerFast 43 | _classifier_model: Any 44 | _tool_threshold: float 45 | 46 | def __init__( 47 | self, 48 | embedding_model: ( 49 | AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings 50 | ), 51 | tools: Sequence[Tool | StructuredTool], 52 | tool_threshold: float = _DEFAULT_TOOL_THRESHOLD, 53 | ): 54 | super().__init__(embedding_model, tools) 55 | 56 | self._tokenizer = AutoTokenizer.from_pretrained( 57 | ClassifierToolRAG._CLASSIFIER_MODEL_NAME 58 | ) 59 | self._classifier_model = AutoModelForSequenceClassification.from_pretrained( 60 | ClassifierToolRAG._CLASSIFIER_MODEL_NAME, 61 | num_labels=ClassifierToolRAG._NUM_LABELS, 62 | ) 63 | self._tool_threshold = tool_threshold 64 | 65 | @property 66 | def tool_rag_type(self) -> str: 67 | return "classifier_tool_rag" 68 | 69 | def retrieve_examples_and_tools(self, query: str, top_k: int) -> ToolRAGResult: 70 | """ 71 | Returns the in-context examples as a formatted prompt and the tools that are relevant to the query. 72 | It first retrieves the best tools for the given query and then it retrieves top k examples 73 | that use the retrieved tools. 74 | """ 75 | retrieved_tools = self._classify_tools(query) 76 | # Filter the tools that are available 77 | retrieved_tools = list(set(retrieved_tools) & set(self._available_tools)) 78 | filtered_embeddings = self._load_filtered_embeddings(retrieved_tools) 79 | retrieved_embeddings = self._retrieve_top_k_embeddings( 80 | query, filtered_embeddings, top_k 81 | ) 82 | 83 | in_context_examples_prompt = BaseToolRAG._get_in_context_examples_prompt( 84 | retrieved_embeddings 85 | ) 86 | 87 | return ToolRAGResult( 88 | in_context_examples_prompt=in_context_examples_prompt, 89 | retrieved_tools_set=retrieved_tools, 90 | ) 91 | 92 | def _classify_tools(self, query: str) -> list[TinyAgentToolName]: 93 | """ 94 | Retrieves the best tools for the given query by classification. 
95 | """ 96 | inputs = self._tokenizer( 97 | query, return_tensors="pt", padding=True, truncation=True, max_length=512 98 | ) 99 | 100 | # Get the output probabilities 101 | with torch.no_grad(): 102 | outputs = self._classifier_model(**inputs) 103 | logits = outputs.logits 104 | probs = torch.sigmoid(logits) 105 | 106 | # Retrieve the tools that have a probability greater than the threshold 107 | retrieved_tools = [ 108 | ClassifierToolRAG._ID_TO_TOOL[i] 109 | for i, prob in enumerate(probs[0]) 110 | if prob > self._tool_threshold 111 | ] 112 | 113 | return retrieved_tools 114 | -------------------------------------------------------------------------------- /src/tiny_agent/tool_rag/simple_tool_rag.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from langchain_community.embeddings import HuggingFaceEmbeddings 4 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings 5 | 6 | from src.tiny_agent.models import TinyAgentToolName 7 | from src.tiny_agent.tool_rag.base_tool_rag import BaseToolRAG, ToolRAGResult 8 | from src.tools.base import StructuredTool, Tool 9 | 10 | 11 | class SimpleToolRAG(BaseToolRAG): 12 | def __init__( 13 | self, 14 | embedding_model: ( 15 | AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings 16 | ), 17 | tools: Sequence[Tool | StructuredTool], 18 | ): 19 | super().__init__(embedding_model, tools) 20 | 21 | @property 22 | def tool_rag_type(self) -> str: 23 | return "simple_tool_rag" 24 | 25 | def retrieve_examples_and_tools(self, query: str, top_k: int) -> ToolRAGResult: 26 | """ 27 | Returns the in-context examples as a formatted prompt and the tools that are relevant to the query 28 | It first filters the examples based on the tools that are available and then retrieves the examples 29 | and tools based on the query. 
30 | """ 31 | filtered_embeddings = self._load_filtered_embeddings() 32 | retrieved_embeddings = self._retrieve_top_k_embeddings( 33 | query, filtered_embeddings, top_k 34 | ) 35 | in_context_examples_prompt = BaseToolRAG._get_in_context_examples_prompt( 36 | retrieved_embeddings 37 | ) 38 | 39 | tools_names = set( 40 | sum( 41 | [ 42 | [TinyAgentToolName(tool) for tool in example["tools"]] 43 | for example in retrieved_embeddings 44 | ], 45 | [], 46 | ) 47 | ) 48 | 49 | return ToolRAGResult( 50 | in_context_examples_prompt=in_context_examples_prompt, 51 | retrieved_tools_set=tools_names, 52 | ) 53 | -------------------------------------------------------------------------------- /src/tiny_agent/tool_rag/text-embedding-3-small/embeddings.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SqueezeAILab/TinyAgent/cc45c0e842f5d163c3df1c8f41d60e90e005867d/src/tiny_agent/tool_rag/text-embedding-3-small/embeddings.pkl -------------------------------------------------------------------------------- /src/tiny_agent/tools/calendar.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import platform 3 | import subprocess 4 | 5 | from src.tiny_agent.run_apple_script import run_applescript, run_applescript_capture 6 | 7 | 8 | class Calendar: 9 | def __init__(self): 10 | self.calendar_app = "Calendar" 11 | 12 | def create_event( 13 | self, 14 | title: str, 15 | start_date: datetime.datetime, 16 | end_date: datetime.datetime, 17 | location: str = "", 18 | invitees: list[str] = [], 19 | notes: str = "", 20 | calendar: str | None = None, 21 | ) -> str: 22 | """ 23 | Creates a new event with the given title, start date, end date, location, and notes. 24 | """ 25 | if platform.system() != "Darwin": 26 | return "This method is only supported on MacOS" 27 | 28 | applescript_start_date = start_date.strftime("%B %d, %Y %I:%M:%S %p") 29 | applescript_end_date = end_date.strftime("%B %d, %Y %I:%M:%S %p") 30 | 31 | # Check if the given calendar parameter is valid 32 | if calendar is not None: 33 | script = f""" 34 | tell application "{self.calendar_app}" 35 | set calendarExists to name of calendars contains "{calendar}" 36 | end tell 37 | """ 38 | exists = run_applescript(script) 39 | if not exists: 40 | calendar = self._get_first_calendar() 41 | if calendar is None: 42 | return f"Can't find the calendar named {calendar}. Please try again and specify a valid calendar name." 43 | 44 | # If it is not provded, default to the first calendar 45 | elif calendar is None: 46 | calendar = self._get_first_calendar() 47 | if calendar is None: 48 | return "Can't find a default calendar. Please try again and specify a calendar name." 
49 | 50 | invitees_script = [] 51 | for invitee in invitees: 52 | invitees_script.append( 53 | f""" 54 | make new attendee at theEvent with properties {{email:"{invitee}"}} 55 | """ 56 | ) 57 | invitees_script = "".join(invitees_script) 58 | 59 | script = f""" 60 | tell application "System Events" 61 | set calendarIsRunning to (name of processes) contains "{self.calendar_app}" 62 | if calendarIsRunning then 63 | tell application "{self.calendar_app}" to activate 64 | else 65 | tell application "{self.calendar_app}" to launch 66 | delay 1 67 | tell application "{self.calendar_app}" to activate 68 | end if 69 | end tell 70 | tell application "{self.calendar_app}" 71 | tell calendar "{calendar}" 72 | set startDate to date "{applescript_start_date}" 73 | set endDate to date "{applescript_end_date}" 74 | set theEvent to make new event at end with properties {{summary:"{title}", start date:startDate, end date:endDate, location:"{location}", description:"{notes}"}} 75 | {invitees_script} 76 | switch view to day view 77 | show theEvent 78 | end tell 79 | tell application "{self.calendar_app}" to reload calendars 80 | end tell 81 | """ 82 | 83 | try: 84 | run_applescript(script) 85 | return f"""Event created successfully in the "{calendar}" calendar.""" 86 | except subprocess.CalledProcessError as e: 87 | return str(e) 88 | 89 | def _get_first_calendar(self) -> str | None: 90 | script = f""" 91 | tell application "System Events" 92 | set calendarIsRunning to (name of processes) contains "{self.calendar_app}" 93 | if calendarIsRunning is false then 94 | tell application "{self.calendar_app}" to launch 95 | delay 1 96 | end if 97 | end tell 98 | tell application "{self.calendar_app}" 99 | set firstCalendarName to name of first calendar 100 | end tell 101 | return firstCalendarName 102 | """ 103 | stdout, _ = run_applescript_capture(script) 104 | if stdout: 105 | return stdout.strip() 106 | else: 107 | return None 108 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/contacts.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | from src.tiny_agent.run_apple_script import run_applescript_capture 4 | 5 | 6 | class Contacts: 7 | def __init__(self): 8 | pass 9 | 10 | def get_phone_number(self, contact_name: str) -> str: 11 | """ 12 | Returns the phone number of a contact by name. 13 | """ 14 | if platform.system() != "Darwin": 15 | return "This method is only supported on MacOS" 16 | 17 | script = f""" 18 | tell application "Contacts" 19 | set thePerson to first person whose name is "{contact_name}" 20 | set theNumber to value of first phone of thePerson 21 | return theNumber 22 | end tell 23 | """ 24 | stdout, stderr = run_applescript_capture(script) 25 | # If the person is not found, we will try to find similar contacts 26 | if "Can’t get person" in stderr: 27 | first_name = contact_name.split(" ")[0] 28 | names = self.get_full_names_from_first_name(first_name) 29 | if "No contacts found" in names or len(names) == 0: 30 | return "No contacts found" 31 | else: 32 | # Just find the first person 33 | return self.get_phone_number(names[0]) 34 | else: 35 | return stdout.replace("\n", "") 36 | 37 | def get_email_address(self, contact_name: str) -> str: 38 | """ 39 | Returns the email address of a contact by name. 
40 | """ 41 | if platform.system() != "Darwin": 42 | return "This method is only supported on MacOS" 43 | 44 | script = f""" 45 | tell application "Contacts" 46 | set thePerson to first person whose name is "{contact_name}" 47 | set theEmail to value of first email of thePerson 48 | return theEmail 49 | end tell 50 | """ 51 | stout, stderr = run_applescript_capture(script) 52 | # If the person is not found, we will try to find similar contacts 53 | if "Can’t get person" in stderr: 54 | first_name = contact_name.split(" ")[0] 55 | names = self.get_full_names_from_first_name(first_name) 56 | if "No contacts found" in names or len(names) == 0: 57 | return "No contacts found" 58 | else: 59 | # Just find the first person 60 | return self.get_email_address(names[0]) 61 | else: 62 | return stout.replace("\n", "") 63 | 64 | def get_full_names_from_first_name(self, first_name: str) -> list[str] | str: 65 | """ 66 | Returns a list of full names of contacts that contain the first name provided. 67 | """ 68 | if platform.system() != "Darwin": 69 | return "This method is only supported on MacOS" 70 | 71 | script = f""" 72 | tell application "Contacts" 73 | set matchingPeople to every person whose first name contains "{first_name}" 74 | set namesList to {{}} 75 | repeat with aPerson in matchingPeople 76 | set end of namesList to name of aPerson 77 | end repeat 78 | return namesList 79 | end tell 80 | """ 81 | names, _ = run_applescript_capture(script) 82 | names = names.strip() 83 | if len(names) > 0: 84 | # Turn name into a list of strings 85 | names = list(map(lambda n: n.strip(), names.split(","))) 86 | return names 87 | else: 88 | return "No contacts found." 89 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/mail.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import subprocess 3 | 4 | from src.tiny_agent.run_apple_script import run_applescript, run_applescript_capture 5 | 6 | 7 | class Mail: 8 | def __init__(self) -> None: 9 | self.mail_app: str = "Mail" 10 | 11 | def compose_email( 12 | self, 13 | recipients: list[str], 14 | subject: str, 15 | content: str, 16 | attachments: list[str], 17 | cc: list[str], 18 | ) -> str: 19 | """ 20 | Composes a new email with the given recipients, subject, content, and attaches files from the given paths. 21 | Adds cc recipients if provided. Does not send it but opens the composed email to the user. 
22 | """ 23 | if platform.system() != "Darwin": 24 | return "This method is only supported on MacOS" 25 | 26 | # Format recipients and cc recipients for AppleScript list 27 | recipients_list = Mail._format_email_addresses(recipients) 28 | cc_list = Mail._format_email_addresses(cc) 29 | attachments_str = Mail._format_attachments(attachments) 30 | 31 | content = content.replace('"', '\\"').replace("'", "’") 32 | script = f""" 33 | tell application "{self.mail_app}" 34 | set newMessage to make new outgoing message with properties {{subject:"{subject}", content:"{content}" & return & return}} 35 | tell newMessage 36 | repeat with address in {recipients_list} 37 | make new to recipient at end of to recipients with properties {{address:address}} 38 | end repeat 39 | repeat with address in {cc_list} 40 | make new cc recipient at end of cc recipients with properties {{address:address}} 41 | end repeat 42 | {attachments_str} 43 | end tell 44 | activate 45 | end tell 46 | """ 47 | 48 | try: 49 | run_applescript(script) 50 | return "New email composed successfully with attachments and cc." 51 | except subprocess.CalledProcessError as e: 52 | return str(e) 53 | 54 | def reply_to_email( 55 | self, content: str, cc: list[str], attachments: list[str] 56 | ) -> str: 57 | """ 58 | Replies to the currently selected email in Mail with the given content. 59 | """ 60 | if platform.system() != "Darwin": 61 | return "This method is only supported on MacOS" 62 | 63 | cc_list = Mail._format_email_addresses(cc) 64 | attachments_str = Mail._format_attachments(attachments) 65 | 66 | content = content.replace('"', '\\"').replace("'", "’") 67 | script = f""" 68 | tell application "{self.mail_app}" 69 | activate 70 | set selectedMessages to selected messages of message viewer 1 71 | if (count of selectedMessages) < 1 then 72 | return "No message selected." 73 | else 74 | set theMessage to item 1 of selectedMessages 75 | set theReply to reply theMessage opening window yes 76 | tell theReply 77 | repeat with address in {cc_list} 78 | make new cc recipient at end of cc recipients with properties {{address:address}} 79 | end repeat 80 | set content to "{content}" 81 | {attachments_str} 82 | end tell 83 | end if 84 | end tell 85 | """ 86 | 87 | try: 88 | run_applescript(script) 89 | return "Replied to the selected email successfully." 90 | except subprocess.CalledProcessError as e: 91 | return "An email has to be viewed in Mail to reply to it." 92 | 93 | def forward_email( 94 | self, recipients: list[str], cc: list[str], attachments: list[str] 95 | ) -> str: 96 | """ 97 | Forwards the currently selected email in Mail to the given recipients with the given content. 98 | """ 99 | if platform.system() != "Darwin": 100 | return "This method is only supported on MacOS" 101 | 102 | # Format recipients and cc recipients for AppleScript list 103 | recipients_list = Mail._format_email_addresses(recipients) 104 | cc_list = Mail._format_email_addresses(cc) 105 | attachments_str = Mail._format_attachments(attachments) 106 | 107 | script = f""" 108 | tell application "{self.mail_app}" 109 | activate 110 | set selectedMessages to selected messages of message viewer 1 111 | if (count of selectedMessages) < 1 then 112 | return "No message selected." 
113 | else 114 | set theMessage to item 1 of selectedMessages 115 | set theForward to forward theMessage opening window yes 116 | tell theForward 117 | repeat with address in {recipients_list} 118 | make new to recipient at end of to recipients with properties {{address:address}} 119 | end repeat 120 | repeat with address in {cc_list} 121 | make new cc recipient at end of cc recipients with properties {{address:address}} 122 | end repeat 123 | {attachments_str} 124 | end tell 125 | end if 126 | end tell 127 | """ 128 | 129 | try: 130 | run_applescript(script) 131 | return "Forwarded the selected email successfully." 132 | except subprocess.CalledProcessError as e: 133 | return "An email has to be viewed in Mail to forward it." 134 | 135 | def get_email_content(self) -> str: 136 | """ 137 | Gets the content of the currently viewed email in Mail. 138 | """ 139 | if platform.system() != "Darwin": 140 | return "This method is only supported on MacOS" 141 | 142 | script = f""" 143 | tell application "{self.mail_app}" 144 | activate 145 | set selectedMessages to selected messages of message viewer 1 146 | if (count of selectedMessages) < 1 then 147 | return "No message selected." 148 | else 149 | set theMessage to item 1 of selectedMessages 150 | -- Get the content of the message 151 | set theContent to content of theMessage 152 | return theContent 153 | end if 154 | end tell 155 | """ 156 | 157 | try: 158 | return run_applescript(script) 159 | except subprocess.CalledProcessError as e: 160 | return "No message selected or found." 161 | 162 | def find_and_select_first_email_from(self, sender: str) -> str: 163 | """ 164 | Finds and selects an email in Mail based on the sender's name. 165 | """ 166 | if platform.system() != "Darwin": 167 | return "This method is only supported on MacOS" 168 | 169 | script = f""" 170 | tell application "{self.mail_app}" 171 | set theSender to "{sender}" 172 | set theMessage to first message of inbox whose sender contains theSender 173 | set selected messages of message viewer 1 to {{theMessage}} 174 | activate 175 | open theMessage 176 | end tell 177 | """ 178 | 179 | try: 180 | run_applescript(script) 181 | return "Found and selected the email successfully." 182 | except subprocess.CalledProcessError as e: 183 | return "No message found from the sender." 184 | 185 | @staticmethod 186 | def _format_email_addresses(emails: list[str]) -> str: 187 | return "{" + ", ".join([f'"{email}"' for email in emails]) + "}" 188 | 189 | @staticmethod 190 | def _format_attachments(attachments: list[str]) -> str: 191 | attachments_str = [] 192 | for attachment in attachments: 193 | attachment = attachment.replace('"', '\\"') 194 | attachments_str.append( 195 | f""" 196 | make new attachment with properties {{file name:"{attachment}"}} at after the last paragraph 197 | """ 198 | ) 199 | return "".join(attachments_str) 200 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/maps.py: -------------------------------------------------------------------------------- 1 | import webbrowser 2 | from urllib.parse import quote_plus 3 | 4 | from src.tiny_agent.models import TransportationOptions 5 | 6 | 7 | class Maps: 8 | def __init__(self): 9 | pass 10 | 11 | def open_location(self, query: str): 12 | """ 13 | Opens the specified location in Apple Maps. 14 | The query can be a place name, address, or coordinates. 
15 | """ 16 | base_url = "https://maps.apple.com/?q=" 17 | query_encoded = quote_plus(query) 18 | full_url = base_url + query_encoded 19 | webbrowser.open(full_url) 20 | return f"Location of {query} in Apple Maps: {full_url}" 21 | 22 | def show_directions( 23 | self, 24 | end: str, 25 | start: str = "", 26 | transport: TransportationOptions = TransportationOptions.DRIVING, 27 | ): 28 | """ 29 | Shows directions from a start location to an end location in Apple Maps. 30 | The transport parameter defaults to 'd' (driving), but can also be 'w' (walking) or 'r' (public transit). 31 | The start location can be left empty to default to the current location of the device. 32 | """ 33 | base_url = "https://maps.apple.com/?" 34 | if len(start) > 0: 35 | start_encoded = quote_plus(start) 36 | start_param = f"saddr={start_encoded}&" 37 | else: 38 | start_param = "" # Use the current location 39 | end_encoded = quote_plus(end) 40 | transport_flag = f"dirflg={transport.value}" 41 | full_url = f"{base_url}{start_param}daddr={end_encoded}&{transport_flag}" 42 | webbrowser.open(full_url) 43 | return f"Directions to {end} in Apple Maps: {full_url}" 44 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/notes.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import platform 3 | import subprocess 4 | 5 | from bs4 import BeautifulSoup 6 | 7 | from src.tiny_agent.run_apple_script import run_applescript, run_applescript_capture 8 | 9 | 10 | class Notes: 11 | _DEFAULT_FOLDER = "Notes" 12 | 13 | def __init__(self): 14 | self.notes_app = "Notes" 15 | 16 | def create_note(self, name: str, content: str, folder: str | None = None) -> str: 17 | """ 18 | Creates a new note with the given content and focuses on it. If a folder is specified, the note 19 | is created in that folder; otherwise, it's created in the default folder. 20 | """ 21 | if platform.system() != "Darwin": 22 | return "This method is only supported on MacOS" 23 | 24 | folder_line = self._get_folder_line(folder) 25 | html_content = content.replace('"', '\\"').replace("'", "’") 26 | 27 | script = f""" 28 | tell application "{self.notes_app}" 29 | tell account "iCloud" 30 | {folder_line} 31 | set newNote to make new note with properties {{body:"{html_content}"}} 32 | end tell 33 | end tell 34 | activate 35 | tell application "System Events" 36 | tell process "Notes" 37 | set frontmost to true 38 | delay 0.5 -- wait a bit for the note to be created and focus to be set 39 | end tell 40 | end tell 41 | tell application "{self.notes_app}" 42 | show newNote 43 | end tell 44 | end tell 45 | """ 46 | 47 | try: 48 | run_applescript(script) 49 | return "Note created and focused successfully." 50 | except subprocess.CalledProcessError as e: 51 | return str(e) 52 | 53 | def open_note( 54 | self, 55 | name: str, 56 | folder: str | None = None, 57 | return_content: bool = False, 58 | ) -> str: 59 | """ 60 | Opens an existing note by its name and optionally returns its content. 61 | If no exact match is found, attempts fuzzy matching to suggest possible notes. 62 | If return_content is True, returns the content of the note. 
63 | """ 64 | if platform.system() != "Darwin": 65 | return "This method is only supported on MacOS" 66 | 67 | folder_line = self._get_folder_line(folder) 68 | 69 | # Adjust the script to return content if required 70 | content_line = ( 71 | "return body of theNote" 72 | if return_content 73 | else 'return "Note opened successfully."' 74 | ) 75 | 76 | # Attempt to directly open the note with the exact name and optionally return its content 77 | script_direct_open = f""" 78 | tell application "{self.notes_app}" 79 | tell account "iCloud" 80 | {folder_line} 81 | set matchingNotes to notes whose name is "{name}" 82 | if length of matchingNotes > 0 then 83 | set theNote to item 1 of matchingNotes 84 | show theNote 85 | {content_line} 86 | else 87 | return "No exact match found." 88 | end if 89 | end tell 90 | end tell 91 | end tell 92 | """ 93 | 94 | try: 95 | stdout, _ = run_applescript_capture(script_direct_open) 96 | if ( 97 | "Note opened successfully" in stdout 98 | or "No exact match found" not in stdout 99 | ): 100 | if return_content: 101 | return self._convert_note_to_text(stdout.strip()) 102 | return stdout.strip() # Successfully opened a note with the exact name 103 | 104 | # If no exact match is found, proceed with fuzzy matching 105 | note_to_open = self._do_fuzzy_matching(name) 106 | 107 | # Focus the note with the closest matching name after fuzzy matching 108 | script_focus = f""" 109 | tell application "{self.notes_app}" 110 | tell account "iCloud" 111 | {folder_line} 112 | set theNote to first note whose name is "{note_to_open}" 113 | show theNote 114 | {content_line} 115 | end tell 116 | end tell 117 | activate 118 | end tell 119 | """ 120 | result = run_applescript(script_focus) 121 | if return_content: 122 | return self._convert_note_to_text(result.strip()) 123 | return result.strip() 124 | except subprocess.CalledProcessError as e: 125 | return f"Error: {str(e)}" 126 | 127 | def append_to_note( 128 | self, name: str, append_content: str, folder: str | None = None 129 | ) -> str: 130 | """ 131 | Appends content to an existing note by its name. If the exact name is not found, 132 | attempts fuzzy matching to find the closest note. 133 | """ 134 | if platform.system() != "Darwin": 135 | return "This method is only supported on MacOS" 136 | 137 | folder_line = self._get_folder_line(folder) 138 | 139 | # Try to directly find and append to the note with the exact name 140 | script_find_note = f""" 141 | tell application "{self.notes_app}" 142 | tell account "iCloud" 143 | {folder_line} 144 | set matchingNotes to notes whose name is "{name}" 145 | if length of matchingNotes > 0 then 146 | set theNote to item 1 of matchingNotes 147 | return name of theNote 148 | else 149 | return "No exact match found." 150 | end if 151 | end tell 152 | end tell 153 | end tell 154 | """ 155 | 156 | try: 157 | note_name, _ = run_applescript_capture( 158 | script_find_note.format(notes_app=self.notes_app, name=name) 159 | ) 160 | note_name = note_name.strip() 161 | 162 | if "No exact match found" in note_name or not note_name: 163 | note_name = self._do_fuzzy_matching(name) 164 | if note_name == "No notes found after fuzzy matching.": 165 | return "No notes found after fuzzy matching." 
166 | 167 | # If an exact match is found, append content to the note 168 | html_append_content = append_content.replace('"', '\\"').replace("'", "’") 169 | script_append = f""" 170 | tell application "{self.notes_app}" 171 | tell account "iCloud" 172 | {folder_line} 173 | set theNote to first note whose name is "{note_name}" 174 | set body of theNote to (body of theNote) & "
{html_append_content}" 175 | show theNote 176 | end tell 177 | end tell 178 | end tell 179 | """ 180 | 181 | run_applescript(script_append) 182 | return f"Content appended to note '{name}' successfully." 183 | except subprocess.CalledProcessError as e: 184 | return f"Error: {str(e)}" 185 | 186 | def _get_folder_line(self, folder: str | None) -> str: 187 | if folder is not None and len(folder) > 0 and self._check_folder_exists(folder): 188 | return f'tell folder "{folder}"' 189 | return f'tell folder "{Notes._DEFAULT_FOLDER}"' 190 | 191 | def _do_fuzzy_matching(self, name: str) -> str: 192 | script_search = f""" 193 | tell application "{self.notes_app}" 194 | tell account "iCloud" 195 | set noteList to name of every note 196 | end tell 197 | end tell 198 | """ 199 | note_names_str, _ = run_applescript_capture(script_search) 200 | note_names = note_names_str.split(", ") 201 | closest_matches = difflib.get_close_matches(name, note_names, n=1, cutoff=0.0) 202 | if not closest_matches: 203 | return "No notes found after fuzzy matching." 204 | 205 | note_to_open = closest_matches[0] 206 | return note_to_open 207 | 208 | def _check_folder_exists(self, folder: str) -> bool: 209 | # Adjusted script to optionally look for a folder 210 | folder_check_script = f""" 211 | tell application "{self.notes_app}" 212 | set folderExists to false 213 | set folderName to "{folder}" 214 | if folderName is not "" then 215 | repeat with eachFolder in folders 216 | if name of eachFolder is folderName then 217 | set folderExists to true 218 | exit repeat 219 | end if 220 | end repeat 221 | end if 222 | return folderExists 223 | end tell 224 | """ 225 | 226 | folder_exists, _ = run_applescript_capture(folder_check_script) 227 | folder_exists = folder_exists.strip() == "true" 228 | 229 | return folder_exists 230 | 231 | @staticmethod 232 | def _convert_note_to_text(note_html: str) -> str: 233 | """ 234 | Converts an HTML note content to plain text. 
235 | """ 236 | soup = BeautifulSoup(note_html, "html.parser") 237 | return soup.get_text().strip() 238 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/reminders.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import platform 3 | import subprocess 4 | 5 | from src.tiny_agent.run_apple_script import run_applescript 6 | 7 | 8 | class Reminders: 9 | def __init__(self): 10 | self.reminders_app = "Reminders" 11 | 12 | def create_reminder( 13 | self, 14 | name: str, 15 | due_date: datetime.datetime | None = None, 16 | notes: str = "", 17 | list_name: str = "", 18 | priority: int = 0, 19 | all_day: bool = False, 20 | ) -> str: 21 | if platform.system() != "Darwin": 22 | return "This method is only supported on MacOS" 23 | 24 | if due_date is not None: 25 | due_date_script = f''', due date:date "{due_date.strftime("%B %d, %Y") 26 | if all_day 27 | else due_date.strftime("%B %d, %Y %I:%M:%S %p")}"''' 28 | else: 29 | due_date_script = "" 30 | 31 | notes = notes.replace('"', '\\"').replace("'", "’") 32 | script = f""" 33 | tell application "{self.reminders_app}" 34 | set listExists to false 35 | set listName to "{list_name}" 36 | if listName is not "" then 37 | repeat with eachList in lists 38 | if name of eachList is listName then 39 | set listExists to true 40 | exit repeat 41 | end if 42 | end repeat 43 | end if 44 | if listExists then 45 | tell list "{list_name}" 46 | set newReminder to make new reminder with properties {{name:"{name}", body:"{notes}", priority:{priority}{due_date_script}}} 47 | activate 48 | show newReminder 49 | end tell 50 | else 51 | set newReminder to make new reminder with properties {{name:"{name}", body:"{notes}", priority:{priority}{due_date_script}}} 52 | activate 53 | show newReminder 54 | end if 55 | end tell 56 | """ 57 | 58 | try: 59 | run_applescript(script) 60 | return f"Reminder '{name}' created successfully in the '{list_name}' list." 61 | except subprocess.CalledProcessError as e: 62 | return f"Failed to create reminder: {str(e)}" 63 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/sms.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import subprocess 3 | 4 | from src.tiny_agent.run_apple_script import run_applescript 5 | 6 | 7 | class SMS: 8 | def __init__(self): 9 | self.messages_app = "Messages" 10 | 11 | def send(self, to: list[str], message: str) -> str: 12 | """ 13 | Opens an SMS draft to the specified recipient using the Messages app, 14 | without sending it, by simulating keystrokes. 
15 | """ 16 | if platform.system() != "Darwin": 17 | return "This method is only supported on MacOS" 18 | 19 | to_script = [] 20 | for recipient in to: 21 | recipient = recipient.replace("\n", "") 22 | to_script.append( 23 | f""" 24 | keystroke "{recipient}" 25 | delay 0.5 26 | keystroke return 27 | delay 0.5 28 | """ 29 | ) 30 | to_script = "".join(to_script) 31 | 32 | escaped_message = message.replace('"', '\\"').replace("'", "’") 33 | 34 | script = f""" 35 | tell application "System Events" 36 | tell application "{self.messages_app}" 37 | activate 38 | end tell 39 | tell process "{self.messages_app}" 40 | set frontmost to true 41 | delay 0.5 42 | keystroke "n" using command down 43 | delay 0.5 44 | {to_script} 45 | keystroke tab 46 | delay 0.5 47 | keystroke "{escaped_message}" 48 | end tell 49 | end tell 50 | """ 51 | try: 52 | run_applescript(script) 53 | return "SMS draft composed" 54 | except subprocess.CalledProcessError as e: 55 | return f"An error occurred while composing the SMS: {str(e)}" 56 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/spotlight_search.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import os # Import os to check if the path exists 3 | import platform 4 | import subprocess 5 | 6 | from src.tiny_agent.run_apple_script import run_command 7 | 8 | 9 | class SpotlightSearch: 10 | def __init__(self): 11 | pass 12 | 13 | def open(self, name_or_path: str) -> str: 14 | """ 15 | Does Spotlight Search and opens the first thing that matches the name. 16 | If no exact match, performs fuzzy search. 17 | Additionally, if the input is a path, tries to open the file directly. 18 | """ 19 | if platform.system() != "Darwin": 20 | return "This method is only supported on MacOS" 21 | 22 | # Check if input is a path and file exists 23 | if name_or_path.startswith("/") and os.path.exists(name_or_path): 24 | try: 25 | subprocess.run(["open", name_or_path]) 26 | return name_or_path 27 | except Exception as e: 28 | return f"Error opening file: {e}" 29 | 30 | # Use mdfind for fast searching with Spotlight 31 | command_search_exact = ["mdfind", f"kMDItemDisplayName == '{name_or_path}'"] 32 | stdout, _ = run_command(command_search_exact) 33 | 34 | if stdout: 35 | paths = stdout.strip().split("\n") 36 | path = paths[0] if paths else None 37 | if path: 38 | subprocess.run(["open", path]) 39 | return path 40 | 41 | # If no exact match, perform fuzzy search on the file names 42 | command_search_general = ["mdfind", name_or_path] 43 | stdout, stderr = run_command(command_search_general) 44 | 45 | paths = stdout.strip().split("\n") if stdout else [] 46 | 47 | if paths: 48 | best_match = difflib.get_close_matches(name_or_path, paths, n=1, cutoff=0.0) 49 | if best_match: 50 | _, stderr = run_command(["open", best_match[0]]) 51 | if len(stderr) > 0: 52 | return f"Error: {stderr}" 53 | return best_match[0] 54 | else: 55 | return "No file found after fuzzy matching." 56 | else: 57 | return "No file found with exact or fuzzy name." 
58 | -------------------------------------------------------------------------------- /src/tiny_agent/tools/zoom.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Sequence, TypedDict 3 | from zoneinfo import ZoneInfo 4 | 5 | import aiohttp 6 | import dateutil.parser 7 | 8 | 9 | class Zoom: 10 | _TIMEZONE = "US/Pacific" 11 | _MEETINGS_ENDPOINT = "https://api.zoom.us/v2/users/me/meetings" 12 | 13 | class ZoomMeetingInfo(TypedDict): 14 | """ 15 | Zoom meeting information returned by the create meeting endpoint. 16 | https://developers.zoom.us/docs/api/rest/reference/zoom-api/methods/#operation/meetingCreate 17 | Currently only types the fields we need, so add more if needed. 18 | """ 19 | 20 | join_url: str 21 | 22 | def __init__(self, access_token: str) -> None: 23 | self._access_token = access_token 24 | 25 | async def get_meeting_link( 26 | self, 27 | topic: str, 28 | start_time: str, 29 | duration: int, 30 | meeting_invitees: Sequence[str], 31 | ) -> str: 32 | """ 33 | Create a new Zoom meeting and return the join URL. 34 | """ 35 | start_time_utc = ( 36 | dateutil.parser.parse(start_time) 37 | .replace(tzinfo=ZoneInfo(Zoom._TIMEZONE)) 38 | .astimezone(datetime.timezone.utc) 39 | .strftime("%Y-%m-%dT%H:%M:%SZ") 40 | ) 41 | 42 | topic = topic[:200] 43 | 44 | # Use the session as an async context manager so it is closed properly 45 | async with aiohttp.ClientSession() as session: 46 | resp = await session.post( 47 | Zoom._MEETINGS_ENDPOINT, 48 | headers={ 49 | "Authorization": f"Bearer {self._access_token}", 50 | "Content-Type": "application/json", 51 | }, 52 | json={ 53 | "topic": topic, 54 | "start_time": start_time_utc, 55 | "duration": duration, 56 | "meeting_invitees": meeting_invitees, 57 | }, 58 | ) 59 | info: Zoom.ZoomMeetingInfo = await resp.json() 60 | 61 | return info["join_url"] 62 | -------------------------------------------------------------------------------- /src/tiny_agent/transcription.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import audioop 3 | import io 4 | import json 5 | import wave 6 | from dataclasses import dataclass 7 | 8 | import httpx 9 | from openai import AsyncOpenAI 10 | 11 | from src.tiny_agent.models import TinyAgentConfig 12 | 13 | 14 | @dataclass 15 | class ResampledAudio: 16 | raw_bytes: bytes 17 | sample_rate: int 18 | 19 | 20 | class WhisperClient(abc.ABC): 21 | """An abstract class for Whisper servers.""" 22 | 23 | def __init__(self, config: TinyAgentConfig): 24 | pass 25 | 26 | @abc.abstractmethod 27 | async def transcribe(self, file: io.BytesIO) -> str: 28 | """Transcribe the audio to text using Whisper.""" 29 | pass 30 | 31 | @staticmethod 32 | @abc.abstractmethod 33 | def resample_audio(raw_bytes: bytes, sample_rate: int) -> ResampledAudio: 34 | """Resample the audio to the target sample rate that the Whisper server expects.""" 35 | pass 36 | 37 | 38 | class WhisperOpenAIClient(WhisperClient): 39 | 40 | def __init__(self, config: TinyAgentConfig): 41 | self.client = AsyncOpenAI(api_key=config.whisper_config.api_key) 42 | 43 | async def transcribe(self, file: io.BytesIO) -> str: 44 | transcript = await self.client.audio.transcriptions.create( 45 | model="whisper-1", 46 | file=file, 47 | response_format="verbose_json", 48 | language="en", 49 | ) 50 | return transcript.text 51 | 52 | @staticmethod 53 | def resample_audio(raw_bytes: bytes, sample_rate: int) -> ResampledAudio: 54 | # The OpenAI Whisper API already resamples the audio, so this is a no-op 55 | return ResampledAudio(raw_bytes, sample_rate) 56 | 57 | 58 
| class WhisperCppClient(WhisperClient): 59 | # Whisper.cpp server expects 16kHz audio 60 | _TARGET_SAMPLE_RATE = 16000 61 | _NON_DATA_FIELDS = { 62 | "temperature": "0.0", 63 | "temperature_inc": "0.2", 64 | "response_format": "json", 65 | } 66 | 67 | _base_url: str 68 | 69 | def __init__(self, config: TinyAgentConfig): 70 | self._base_url = f"http://localhost:{config.whisper_config.port}/inference" 71 | 72 | async def transcribe(self, file: io.BytesIO) -> str: 73 | # Preparing the files dictionary 74 | files = {"file": ("audio.wav", file, "audio/wav")} 75 | 76 | # Send the request to the Whisper.cpp server 77 | async with httpx.AsyncClient() as client: 78 | response = await client.post( 79 | self._base_url, files=files, data=WhisperCppClient._NON_DATA_FIELDS 80 | ) 81 | 82 | data = response.json() 83 | if "text" not in data: 84 | raise ValueError( 85 | f"Whisper.cpp response does not contain 'text' key: {json.dumps(data)}" 86 | ) 87 | 88 | return data["text"] 89 | 90 | @staticmethod 91 | def resample_audio(raw_bytes: bytes, sample_rate: int) -> ResampledAudio: 92 | if sample_rate == WhisperCppClient._TARGET_SAMPLE_RATE: 93 | return ResampledAudio(raw_bytes, sample_rate) 94 | 95 | # Resample the audio to 16kHz 96 | converted, _ = audioop.ratecv( 97 | raw_bytes, 2, 1, sample_rate, WhisperCppClient._TARGET_SAMPLE_RATE, None 98 | ) 99 | return ResampledAudio(converted, WhisperCppClient._TARGET_SAMPLE_RATE) 100 | 101 | 102 | class TranscriptionService: 103 | """ 104 | This is the main service that deals with transcribing raw audio data to text using Whisper. 105 | """ 106 | 107 | _client: WhisperClient 108 | 109 | def __init__(self, client: WhisperClient): 110 | self._client = client 111 | 112 | async def transcribe(self, raw_bytes: bytes, sample_rate: int) -> str: 113 | """ 114 | This method transcribes the audio data to text using Whisper. 
115 | """ 116 | resampled_audio = self._client.resample_audio(raw_bytes, sample_rate) 117 | raw_bytes, sample_rate = resampled_audio.raw_bytes, resampled_audio.sample_rate 118 | 119 | with io.BytesIO() as memfile: 120 | # Open a new WAV file in write mode using the in-memory stream 121 | with wave.open(memfile, "wb") as wav_file: 122 | # Set the parameters for the WAV file 123 | wav_file.setnchannels(1) 124 | wav_file.setsampwidth(2) # 16-bit PCM, so 2 bytes per sample 125 | wav_file.setframerate(sample_rate) 126 | 127 | # Write the PCM data to the WAV file 128 | wav_file.writeframes(raw_bytes) 129 | 130 | memfile.name = "audio.wav" 131 | # Rewind so the Whisper client reads the WAV data from the start 132 | memfile.seek(0) 133 | 134 | transcript = await self._client.transcribe(memfile) 135 | 136 | return transcript.strip() 137 | -------------------------------------------------------------------------------- /src/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import uuid 4 | from dataclasses import dataclass, field 5 | from enum import Enum 6 | from typing import Any, Literal 7 | 8 | from src.llm_compiler.constants import JOINNER_FINISH, JOINNER_REPLAN 9 | 10 | 11 | class PlanStepToolName(Enum): 12 | GET_PHONE_NUMBER = "get_phone_number" 13 | GET_EMAIL_ADDRESS = "get_email_address" 14 | CREATE_CALENDAR_EVENT = "create_calendar_event" 15 | OPEN_AND_GET_FILE_PATH = "open_and_get_file_path" 16 | SUMMARIZE_PDF = "summarize_pdf" 17 | COMPOSE_NEW_EMAIL = "compose_new_email" 18 | REPLY_TO_EMAIL = "reply_to_email" 19 | FORWARD_EMAIL = "forward_email" 20 | MAPS_OPEN_LOCATION = "maps_open_location" 21 | MAPS_SHOW_DIRECTIONS = "maps_show_directions" 22 | CREATE_NOTE = "create_note" 23 | OPEN_NOTE = "open_note" 24 | APPEND_NOTE_CONTENT = "append_note_content" 25 | CREATE_REMINDER = "create_reminder" 26 | SEND_SMS = "send_sms" 27 | GET_ZOOM_MEETING_LINK = "get_zoom_meeting_link" 28 | JOIN = "join" 29 | 30 | 31 | class DataPointType(Enum): 32 | PLAN = "plan" 33 | JOIN = "join" 34 | REPLAN = "replan" 35 | 36 | 37 | @dataclass 38 | class PlanStep: 39 | tool_name: PlanStepToolName 40 | tool_args: dict[str, Any] 41 | 42 | def serialize(self) -> dict[str, Any]: 43 | return { 44 | "tool_name": ( 45 | self.tool_name.value 46 | if isinstance(self.tool_name, PlanStepToolName) 47 | else "join" 48 | ), 49 | "tool_args": self.tool_args, 50 | } 51 | 52 | 53 | @dataclass 54 | class JoinStep: 55 | thought: str 56 | action: Literal[JOINNER_FINISH, JOINNER_REPLAN] # type: ignore 57 | # Message is only applicable if the action is "Finish" 58 | message: str 59 | 60 | def serialize(self) -> dict[str, Any]: 61 | return { 62 | "thought": self.thought, 63 | "action": self.action, 64 | "message": self.message, 65 | } 66 | 67 | 68 | @dataclass 69 | class DataPoint: 70 | type: DataPointType 71 | raw_input: str 72 | raw_output: str 73 | parsed_output: list[PlanStep] | JoinStep 74 | id: uuid.UUID 75 | index: int 76 | 77 | def serialize(self) -> dict[str, Any]: 78 | return { 79 | "type": self.type.value, 80 | "raw_input": self.raw_input, 81 | "raw_output": self.raw_output, 82 | "parsed_output": ( 83 | [step.serialize() for step in self.parsed_output] 84 | if isinstance(self.parsed_output, list) 85 | else self.parsed_output.serialize() 86 | ), 87 | "id": str(self.id), 88 | "index": self.index, 89 | } 90 | 91 | 92 | @dataclass 93 | class Data: 94 | input: str 95 | output: list[DataPoint] 96 | closest_5_queries: list[str] = field(default_factory=list) 97 | 98 | def serialize(self) -> dict[str, Any]: 99 | return { 100 | "input":
self.input, 101 | "output": [point.serialize() for point in self.output], 102 | "closest_5_queries": self.closest_5_queries, 103 | } 104 | 105 | 106 | def deserialize_data(data_json: dict[str, Any]) -> dict[str, Data]: 107 | data_objects = {} 108 | 109 | for key, value in data_json.items(): 110 | # Assuming 'input' and 'output' are directly accessible within `value` 111 | input_str = value["input"] 112 | output_list = value["output"] 113 | 114 | data_points = [] 115 | for output in output_list: 116 | parsed_output = output["parsed_output"] 117 | data_point = DataPoint( 118 | type=DataPointType(output["type"]), 119 | raw_input=output["raw_input"], 120 | raw_output=output["raw_output"], 121 | parsed_output=( 122 | ( 123 | [ 124 | PlanStep( 125 | tool_name=(PlanStepToolName(step["tool_name"])), 126 | tool_args=step["tool_args"], 127 | ) 128 | for step in parsed_output 129 | ] 130 | if isinstance(parsed_output, list) 131 | else JoinStep( 132 | thought=parsed_output["thought"], 133 | action=parsed_output["action"], 134 | message=parsed_output["message"], 135 | ) 136 | ) 137 | ), 138 | id=uuid.UUID(output["id"]), 139 | index=output["index"], 140 | ) 141 | 142 | data_points.append(data_point) 143 | 144 | data_objects[key] = Data( 145 | input=input_str, 146 | output=data_points, 147 | ) 148 | 149 | if "closest_5_queries" in value: 150 | data_objects[key].closest_5_queries = value["closest_5_queries"] 151 | 152 | return data_objects 153 | 154 | 155 | def save_data(data_objects: dict[str, Any], json_path: str) -> None: 156 | data_json = {} 157 | 158 | for key, data in data_objects.items(): 159 | data_json[key] = data.serialize() 160 | 161 | with open(json_path, "w") as file: 162 | json.dump(data_json, file, indent=4) 163 | 164 | 165 | def initialize_data_objects(json_path: str) -> dict[str, Data]: 166 | if not os.path.exists(json_path): 167 | with open(json_path, "w") as file: 168 | file.write("{}") 169 | data_objects = {} 170 | else: 171 | with open(json_path, "r") as file: 172 | data_objects = json.load(file) 173 | 174 | data_objects = deserialize_data(data_objects) 175 | return data_objects 176 | -------------------------------------------------------------------------------- /src/utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any 3 | 4 | import networkx as nx 5 | 6 | from src.utils.data_utils import PlanStep 7 | 8 | DEPENDENCY_REGEX = r"\$\d+" 9 | 10 | 11 | def check_for_dependency(item: str) -> str | None: 12 | match = re.search(DEPENDENCY_REGEX, item) 13 | if match: 14 | return match.group() 15 | return None 16 | 17 | 18 | def build_graph(plan: list[PlanStep]) -> nx.DiGraph: 19 | function_calls = [step.serialize() for step in plan] 20 | graph = nx.DiGraph() 21 | 22 | def add_new_dependency_edge( 23 | *, 24 | node_name: str, 25 | dependency: str, 26 | ) -> tuple[dict[str, Any], str]: 27 | dependency_index = int(dependency[1:]) - 1 28 | dependency_func = function_calls[dependency_index] 29 | dependency_node_name = f"{dependency_index+1}: {dependency_func['tool_name']}" 30 | graph.add_edge(dependency_node_name, node_name) 31 | 32 | return dependency_func, dependency_node_name 33 | 34 | def add_dependencies(dependencies: list[Any], node_name: str): 35 | nested_deps = [] 36 | for dep in dependencies: 37 | if isinstance(dep, list): 38 | # Checking for dependencies in list-type arguments 39 | for item in dep: 40 | if ( 41 | isinstance(item, str) 42 | and (item := check_for_dependency(item)) is not None 43 | 
): 44 | dependency_func, dependency_node_name = add_new_dependency_edge( 45 | node_name=node_name, dependency=item 46 | ) 47 | 48 | # In order to avoid infinite recursion, we check 49 | # if the dependency node is the same as the current node 50 | if dependency_node_name == node_name: 51 | raise ValueError( 52 | "Circular dependency detected. " 53 | "The tool is dependent on itself, which is not allowed." 54 | ) 55 | 56 | # Recursively add nested dependencies 57 | dep_dict = { 58 | "tool_name": dependency_func["tool_name"], 59 | "dependencies": ( 60 | add_dependencies( 61 | dependency_func["tool_args"], dependency_node_name 62 | ) 63 | if dependency_node_name != node_name 64 | else [] 65 | ), 66 | } 67 | nested_deps.append(dep_dict) 68 | elif ( 69 | isinstance(dep, str) and (dep := check_for_dependency(dep)) is not None 70 | ): 71 | dependency_func, dependency_node_name = add_new_dependency_edge( 72 | node_name=node_name, dependency=dep 73 | ) 74 | 75 | # In order to avoid infinite recursion, we check 76 | # if the dependency node is the same as the current node 77 | if dependency_node_name == node_name: 78 | raise ValueError( 79 | "Circular dependency detected. " 80 | "The tool is dependent on itself, which is not allowed." 81 | ) 82 | 83 | # Recursively add nested dependencies; the self-dependency check 84 | # above already guards against infinite recursion 85 | dep_dict = { 86 | "tool_name": dependency_func["tool_name"], 87 | "dependencies": ( 88 | add_dependencies( 89 | dependency_func["tool_args"], dependency_node_name 90 | ) 91 | if dependency_node_name != node_name 92 | else [] 93 | ), 94 | } 95 | nested_deps.append(dep_dict) 96 | return nested_deps 97 | 98 | for index, func in enumerate(function_calls): 99 | node_name = f"{index+1}: {func['tool_name']}" 100 | graph.add_node( 101 | node_name, tool_name=func["tool_name"], tool_args=func["tool_args"] 102 | ) 103 | dependencies = add_dependencies(func["tool_args"], node_name) 104 | graph.nodes[node_name]["dependencies"] = dependencies 105 | 106 | return graph 107 | 108 | 109 | def node_match(n1: dict, n2: dict) -> bool: 110 | if n1["tool_name"] != n2["tool_name"]: 111 | return False 112 | 113 | # Compare dependency structures recursively 114 | def compare_dependencies(d1, d2): 115 | if len(d1) != len(d2): 116 | return False 117 | for item1, item2 in zip( 118 | sorted(d1, key=lambda x: x["tool_name"]), 119 | sorted(d2, key=lambda x: x["tool_name"]), 120 | ): 121 | if not ( 122 | item1["tool_name"] == item2["tool_name"] 123 | and compare_dependencies(item1["dependencies"], item2["dependencies"]) 124 | ): 125 | return False 126 | return True 127 | 128 | return compare_dependencies(n1["dependencies"], n2["dependencies"]) 129 | 130 | 131 | def compare_graphs_with_success_rate(graph1: nx.DiGraph, graph2: nx.DiGraph) -> float: 132 | """Outputs 1.0 if the graphs are isomorphic, 0.0 otherwise.""" 133 | if nx.is_isomorphic(graph1, graph2, node_match=node_match): 134 | return 1.0 135 | return 0.0 136 | 137 | 138 | def compare_graphs_with_edit_distance(graph1: nx.DiGraph, graph2: nx.DiGraph) -> float: 139 | """ 140 | Calculates the graph edit distance between two graphs. 141 | """ 142 | f_node_match = node_match 143 | 144 | # First, check for isomorphism since it is a faster check. If yes, the GED is 0.
145 | if nx.is_isomorphic(graph1, graph2, node_match=f_node_match): 146 | return 0.0 147 | 148 | # Calculate the graph edit distance 149 | ged = nx.graph_edit_distance( 150 | graph1, 151 | graph2, 152 | node_match=f_node_match, 153 | timeout=10, 154 | ) 155 | 156 | return ged 157 | -------------------------------------------------------------------------------- /src/utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | 8 | from src.tiny_agent.models import TINY_AGENT_DIR 9 | 10 | # Global variable to toggle logging 11 | LOG_ENABLED = True 12 | LOG_TO_FILE = False 13 | LOG_FILE_PATH = os.path.join(TINY_AGENT_DIR, "log.txt") 14 | 15 | # Create the log file if it doesn't exist 16 | if not os.path.exists(LOG_FILE_PATH): 17 | with open(LOG_FILE_PATH, "w") as f: 18 | pass 19 | 20 | 21 | class Logger: 22 | def __init__(self) -> None: 23 | self._latency_dict = defaultdict(list) 24 | self._answer_dict = defaultdict(list) 25 | self._label_dict = defaultdict(list) 26 | 27 | def log(self, latency: float, answer: str, label: str, key: str) -> None: 28 | self._latency_dict[key].append(latency) 29 | self._answer_dict[key].append(answer) 30 | self._label_dict[key].append(label) 31 | 32 | def _get_mean_latency(self, key: str) -> tuple[float, float]: 33 | latency_array = np.array(self._latency_dict[key]) 34 | return latency_array.mean(), latency_array.std() 35 | 36 | def _get_accuracy(self, key: str) -> float: 37 | answer_array = np.array(self._answer_dict[key]) 38 | label_array = np.array(self._label_dict[key]) 39 | return (answer_array == label_array).mean() 40 | 41 | def get_results(self, key: str) -> dict: 42 | mean_latency, std_latency = self._get_mean_latency(key) 43 | accuracy = self._get_accuracy(key) 44 | return { 45 | "mean_latency": mean_latency, 46 | "std_latency": std_latency, 47 | "accuracy": accuracy, 48 | } 49 | 50 | def save_result(self, key: str, path: str): 51 | with open(f"{path}/dev_react_results.csv", "w") as f: 52 | for i in range(len(self._answer_dict[key])): 53 | f.write(f"{self._answer_dict[key][i]},{self._latency_dict[key][i]}\n") 54 | 55 | 56 | def get_logger() -> Logger: 57 | return Logger() 58 | 59 | 60 | # Custom print function to toggle logging 61 | 62 | 63 | def enable_logging(enable=True): 64 | """Toggle logging on or off based on the given argument.""" 65 | global LOG_ENABLED 66 | LOG_ENABLED = enable 67 | 68 | 69 | def enable_logging_to_file(enable=True): 70 | """Toggle logging to a file on or off based on the given argument.""" 71 | global LOG_TO_FILE 72 | LOG_TO_FILE = enable 73 | 74 | 75 | def log(*args, block=False, **kwargs): 76 | """Print the given string only if logging is enabled.""" 77 | if LOG_ENABLED: 78 | if block: 79 | print("=" * 80) 80 | print(*args, **kwargs) 81 | if block: 82 | print("=" * 80) 83 | if LOG_TO_FILE: 84 | with open(LOG_FILE_PATH, "a") as f: 85 | if block: 86 | print("=" * 80, file=f) 87 | print(*args, **kwargs, file=f) 88 | if block: 89 | print("=" * 80, file=f) 90 | 91 | 92 | def flush_results(save_path, results): 93 | print("Saving results") 94 | with open(save_path, "w") as file: 95 | json.dump(results, file, indent=4) 96 | -------------------------------------------------------------------------------- /src/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from langchain.chat_models import AzureChatOpenAI, ChatOpenAI 2 | from langchain.llms import OpenAI 3 | from
langchain_community.embeddings import HuggingFaceEmbeddings 4 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings 5 | 6 | from src.utils.logger_utils import log 7 | 8 | DEFAULT_SAFE_CONTEXT_LENGTH = 512 9 | DEFAULT_SENTENCE_TRANSFORMER_BATCH_SIZE = 128 10 | 11 | 12 | def get_model( 13 | model_type, 14 | model_name, 15 | api_key, 16 | vllm_port, 17 | stream, 18 | temperature=0, 19 | azure_endpoint=None, 20 | azure_deployment=None, 21 | azure_api_version=None, 22 | ): 23 | if model_type == "openai": 24 | if api_key is None: 25 | raise ValueError("api_key must be provided for openai model") 26 | llm = ChatOpenAI( 27 | model_name=model_name, # type: ignore 28 | openai_api_key=api_key, # type: ignore 29 | streaming=stream, 30 | temperature=temperature, 31 | ) 32 | elif model_type == "vllm": 33 | if vllm_port is None: 34 | raise ValueError("vllm_port must be provided for vllm model") 35 | if stream: 36 | log( 37 | "WARNING: vllm does not support streaming. " 38 | "Setting stream=False for vllm model." 39 | ) 40 | stream = False 41 | llm = OpenAI( 42 | openai_api_base=f"http://localhost:{vllm_port}/v1", 43 | model_name=model_name, 44 | temperature=temperature, 45 | max_retries=1, 46 | streaming=stream, 47 | ) 48 | elif model_type == "local": 49 | if vllm_port is None: 50 | raise ValueError("vllm_port must be provided for local model") 51 | llm = ChatOpenAI( 52 | openai_api_key=api_key, # type: ignore 53 | openai_api_base=f"http://localhost:{vllm_port}/v1", 54 | model_name=model_name, 55 | temperature=temperature, 56 | max_retries=1, 57 | streaming=stream, 58 | ) 59 | elif model_type == "azure": 60 | if api_key is None: 61 | raise ValueError("api_key must be provided for azure model") 62 | if azure_api_version is None: 63 | raise ValueError("azure_api_version must be provided for azure model") 64 | if azure_endpoint is None: 65 | raise ValueError("azure_endpoint must be provided for azure model") 66 | if azure_deployment is None: 67 | raise ValueError("azure_deployment must be provided for azure model") 68 | 69 | llm = AzureChatOpenAI( 70 | api_key=api_key, # type: ignore 71 | api_version=azure_api_version, 72 | azure_endpoint=azure_endpoint, 73 | azure_deployment=azure_deployment, 74 | streaming=stream, 75 | ) 76 | 77 | else: 78 | raise NotImplementedError(f"Unknown model type: {model_type}") 79 | 80 | return llm 81 | 82 | 83 | def get_embedding_model( 84 | model_type: str, 85 | model_name: str, 86 | api_key: str, 87 | azure_embedding_deployment: str, 88 | azure_endpoint: str | None, 89 | azure_api_version: str | None, 90 | local_port: int | None, 91 | context_length: int | None, 92 | ) -> OpenAIEmbeddings | AzureOpenAIEmbeddings | HuggingFaceEmbeddings: 93 | if model_name is None: 94 | raise ValueError("Embedding model's model_name must be provided") 95 | 96 | if model_type == "openai": 97 | if api_key is None: 98 | raise ValueError("api_key must be provided for openai model") 99 | return OpenAIEmbeddings(api_key=api_key, model=model_name) 100 | elif model_type == "azure": 101 | if api_key is None: 102 | raise ValueError("api_key must be provided for azure model") 103 | if azure_api_version is None: 104 | raise ValueError("azure_api_version must be provided for azure model") 105 | if azure_endpoint is None: 106 | raise ValueError("azure_endpoint must be provided for azure model") 107 | return AzureOpenAIEmbeddings( 108 | api_key=api_key, 109 | api_version=azure_api_version, 110 | azure_endpoint=azure_endpoint, 111 | azure_deployment=azure_embedding_deployment, 112 | model=model_name, 113 | )
114 | elif model_type == "local": 115 | if local_port is None: 116 | # Use SentenceTransformer for local embeddings 117 | return HuggingFaceEmbeddings( 118 | model_name=model_name, 119 | encode_kwargs={"batch_size": DEFAULT_SENTENCE_TRANSFORMER_BATCH_SIZE}, 120 | ) 121 | if context_length is None: 122 | print( 123 | "WARNING: context_length not provided for local model. Using default value (512).", 124 | flush=True, 125 | ) 126 | context_length = DEFAULT_SAFE_CONTEXT_LENGTH 127 | return OpenAIEmbeddings( 128 | api_key=api_key, 129 | base_url=f"http://localhost:{local_port}/v1", 130 | model=model_name, 131 | embedding_ctx_length=context_length - 1, 132 | tiktoken_enabled=False, 133 | tiktoken_model_name=model_name, 134 | ) 135 | else: 136 | raise NotImplementedError(f"Unknown model type: {model_type}") 137 | -------------------------------------------------------------------------------- /src/utils/plan_utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import re 3 | from typing import Any 4 | 5 | from src.llm_compiler.constants import JOINNER_FINISH, JOINNER_REPLAN 6 | from src.llm_compiler.task_fetching_unit import Task 7 | from src.utils.data_utils import JoinStep, PlanStep, PlanStepToolName 8 | 9 | THOUGHT_PATTERN = r"Thought: ([^\n]*)" 10 | ACTION_PATTERN = r"\s*\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?" 11 | 12 | 13 | def _parse_llm_compiler_action_args(args: str) -> Any: 14 | """Parse arguments from a string.""" 15 | # This will convert the string into a python object 16 | # e.g. '"Ronaldo number of kids"' -> ("Ronaldo number of kids", ) 17 | # '"I can answer the question now.", [3]' -> ("I can answer the question now.", [3]) 18 | if args == "": 19 | return () 20 | try: 21 | args = ast.literal_eval(args) 22 | except Exception: 23 | pass # literal_eval failed; keep args as the raw string 24 | if not isinstance(args, list) and not isinstance(args, tuple): 25 | args = (args,) # type: ignore 26 | return args 27 | 28 | 29 | def parse_plan(plan: str) -> dict[int, Any]: 30 | # 1. search("Ronaldo number of kids") -> 1, "search", '"Ronaldo number of kids"' 31 | # pattern = r"(\d+)\.
(\w+)\(([^)]+)\)" 32 | pattern = rf"(?:{THOUGHT_PATTERN}\n)?{ACTION_PATTERN}" 33 | matches = re.findall(pattern, plan) 34 | 35 | graph_dict = {} 36 | 37 | for match in matches: 38 | # idx = 1, function = "search", args = "Ronaldo number of kids" 39 | # thought will be the preceding thought, if any, otherwise an empty string 40 | _, idx, tool_name, args, _ = match 41 | idx = int(idx) 42 | 43 | # Create a dummy task 44 | task = Task( 45 | idx=idx, 46 | name=tool_name, 47 | tool=lambda x: None, 48 | args=_parse_llm_compiler_action_args(args), 49 | dependencies=[], 50 | stringify_rule=None, 51 | thought=None, 52 | is_join=tool_name == "join", 53 | ) 54 | 55 | graph_dict[idx] = task 56 | if task.is_join: 57 | break 58 | 59 | return graph_dict 60 | 61 | 62 | def get_parsed_planner_output(tasks: dict[int, Any]) -> list[PlanStep]: 63 | steps: list[PlanStep] = [] 64 | try: 65 | for _, task in tasks.items(): 66 | step = PlanStep( 67 | tool_name=(PlanStepToolName(task.name)), 68 | tool_args=task.args, 69 | ) 70 | steps.append(step) 71 | except Exception as e: 72 | print("Tool hallucination error: ", e) 73 | 74 | return steps 75 | 76 | 77 | def get_parsed_planner_output_from_raw(raw_answer: str) -> list[PlanStep]: 78 | tasks = parse_plan(raw_answer) 79 | return get_parsed_planner_output(tasks) 80 | 81 | 82 | def get_parsed_joinner_output(raw_answer: str) -> JoinStep: 83 | thought, answer, is_replan = "", "", False # default values 84 | raw_answers = raw_answer.split("\n") 85 | for ans in raw_answers: 86 | start_of_answer = ans.find("Action:") 87 | if start_of_answer >= 0: 88 | ans = ans[start_of_answer:] 89 | if ans.startswith("Action:"): 90 | answer = ans[ans.find("(") + 1 : ans.rfind(")")] 91 | is_replan = JOINNER_REPLAN in ans 92 | elif ans.startswith("Thought:"): 93 | thought = ans.split("Thought:")[1].strip() 94 | 95 | step = JoinStep( 96 | thought=thought, 97 | action=JOINNER_REPLAN if is_replan else JOINNER_FINISH, 98 | message=answer, 99 | ) 100 | return step 101 | 102 | 103 | def evaluate_plan(label_plan: list[PlanStep], predicted_plan: list[PlanStep]) -> float: 104 | """Returns the accuracy of the predicted plan based on the tool names and the right ordering.""" 105 | # Assuming the lengths of the two plans are the same 106 | correct = 0 107 | for label_step, predicted_step in zip(label_plan, predicted_plan): 108 | if label_step.tool_name == predicted_step.tool_name: 109 | correct += 1 110 | return correct / len(label_plan) 111 | 112 | 113 | def evaluate_tool_recall( 114 | label_plan: list[PlanStep], predicted_plan: list[PlanStep] 115 | ) -> float: 116 | """ 117 | Instead of determining the score based on the correct function call in the plan, 118 | determines the score based on the correct set of function calls in the plan. 119 | """ 120 | label_set = set([step.tool_name for step in label_plan]) 121 | predicted_set = set([step.tool_name for step in predicted_plan]) 122 | 123 | # Check how many of the predicted steps are in the label set 124 | correct = len(label_set.intersection(predicted_set)) 125 | return correct / len(label_set) 126 | 127 | 128 | def evaluate_join(label_join: JoinStep, predicted_join: JoinStep) -> float: 129 | """Returns 1.0 if the predicted join step is correct, else 0.0.""" 130 | if label_join.action == predicted_join.action: 131 | return 1.0 132 | return 0.0 133 | --------------------------------------------------------------------------------
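--------------------------------------------------------------------------------
Usage sketch (not part of the repository): creating a meeting link with the Zoom tool from src/tiny_agent/tools/zoom.py. The access token and meeting details below are hypothetical placeholders, and `duration` is assumed to be in minutes, per the Zoom meetings API.

import asyncio

from src.tiny_agent.tools.zoom import Zoom


async def main() -> None:
    # "ZOOM_S2S_OAUTH_TOKEN" is a placeholder; a real Server-to-Server
    # OAuth access token from a Zoom app is required for the call to succeed.
    zoom = Zoom(access_token="ZOOM_S2S_OAUTH_TOKEN")
    link = await zoom.get_meeting_link(
        topic="TinyAgent sync",
        start_time="2024-05-01 10:00",  # parsed with dateutil, treated as US/Pacific
        duration=30,
        meeting_invitees=["alice@example.com"],
    )
    print(link)


asyncio.run(main())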
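--------------------------------------------------------------------------------
Usage sketch (not part of the repository): wiring the transcription pipeline from src/tiny_agent/transcription.py. The `config` argument stands in for a real TinyAgentConfig whose `whisper_config` carries the local Whisper.cpp port; the raw PCM buffer is assumed to be 16-bit mono.

from src.tiny_agent.transcription import TranscriptionService, WhisperCppClient


async def transcribe_pcm(config, raw_pcm: bytes, sample_rate: int) -> str:
    # WhisperCppClient resamples the PCM to 16 kHz, and TranscriptionService
    # wraps it in an in-memory WAV before posting to localhost:<port>/inference
    service = TranscriptionService(WhisperCppClient(config))
    return await service.transcribe(raw_pcm, sample_rate)

WhisperOpenAIClient can be swapped in for WhisperCppClient to use the hosted Whisper API instead; the service code is unchanged.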
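--------------------------------------------------------------------------------
Round-trip sketch (not part of the repository) for src/utils/data_utils.py: serialize one plan data point to JSON and load it back. The file name and query text are illustrative only.

import uuid

from src.utils.data_utils import (
    Data,
    DataPoint,
    DataPointType,
    PlanStep,
    PlanStepToolName,
    initialize_data_objects,
    save_data,
)

step = PlanStep(
    tool_name=PlanStepToolName.CREATE_REMINDER,
    tool_args={"name": "Submit report", "due_date": "2024-05-01"},
)
point = DataPoint(
    type=DataPointType.PLAN,
    raw_input="Remind me to submit the report",
    raw_output='1. create_reminder("Submit report", "2024-05-01")',
    parsed_output=[step],
    id=uuid.uuid4(),
    index=0,
)
save_data({"example-1": Data(input=point.raw_input, output=[point])}, "plan_data.json")

# initialize_data_objects reads the JSON back into the dataclasses
restored = initialize_data_objects("plan_data.json")
assert restored["example-1"].output[0].type is DataPointType.PLAN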
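--------------------------------------------------------------------------------
Evaluation sketch (not part of the repository) for src/utils/graph_utils.py: plans become dependency DAGs, where a "$i" argument marks a dependency on step i, and are compared structurally, so differing free-text arguments still count as a match. Positional args are passed as tuples here, mirroring how plan_utils constructs PlanStep objects.

from src.utils.data_utils import PlanStep, PlanStepToolName
from src.utils.graph_utils import build_graph, compare_graphs_with_success_rate

label = [
    PlanStep(PlanStepToolName.GET_PHONE_NUMBER, ("Sid",)),
    PlanStep(PlanStepToolName.SEND_SMS, (["$1"], "On my way!")),  # depends on step 1
]
predicted = [
    PlanStep(PlanStepToolName.GET_PHONE_NUMBER, ("Sid",)),
    PlanStep(PlanStepToolName.SEND_SMS, (["$1"], "Leaving now")),
]

score = compare_graphs_with_success_rate(build_graph(label), build_graph(predicted))
print(score)  # 1.0: same tools with the same dependency structure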
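--------------------------------------------------------------------------------
Selection sketch (not part of the repository) for src/utils/model_utils.py. The model name and API key are placeholders; the "vllm" and "local" types expect an OpenAI-compatible server already listening on the given port.

from src.utils.model_utils import get_model

llm = get_model(
    model_type="openai",
    model_name="gpt-4o",      # hypothetical model choice
    api_key="sk-placeholder",  # placeholder key
    vllm_port=None,
    stream=True,
)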
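--------------------------------------------------------------------------------
Parsing sketch (not part of the repository) for src/utils/plan_utils.py. The raw completion below is hypothetical but follows the "idx. tool(args)" format that THOUGHT_PATTERN and ACTION_PATTERN expect.

from src.utils.plan_utils import evaluate_plan, get_parsed_planner_output_from_raw

raw = (
    "Thought: I need the phone number before texting.\n"
    '1. get_phone_number("Sid")\n'
    '2. send_sms(["$1"], "On my way!")\n'
    "3. join()\n"
)
predicted = get_parsed_planner_output_from_raw(raw)
print([step.tool_name.value for step in predicted])
# -> ['get_phone_number', 'send_sms', 'join']

# Against an identical label plan, per-step tool accuracy is 1.0
print(evaluate_plan(predicted, predicted))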