├── .gitignore
├── LICENSE
├── README.md
├── TinyAgent.zip
├── figs
│   └── tinyagent.png
├── requirements.txt
├── run_tiny_agent_server.py
└── src
    ├── agents
    │   ├── agent.py
    │   ├── structured_chat_agent.py
    │   └── tools.py
    ├── callbacks
    │   └── callbacks.py
    ├── chains
    │   ├── chain.py
    │   └── llm_chain.py
    ├── executors
    │   ├── agent_executor.py
    │   └── schema.py
    ├── llm_compiler
    │   ├── constants.py
    │   ├── llm_compiler.py
    │   ├── output_parser.py
    │   ├── planner.py
    │   └── task_fetching_unit.py
    ├── tiny_agent
    │   ├── computer.py
    │   ├── config.py
    │   ├── models.py
    │   ├── prompts.py
    │   ├── run_apple_script.py
    │   ├── sub_agents
    │   │   ├── compose_email_agent.py
    │   │   ├── notes_agent.py
    │   │   ├── pdf_summarizer_agent.py
    │   │   └── sub_agent.py
    │   ├── tiny_agent.py
    │   ├── tiny_agent_tools.py
    │   ├── tool_rag
    │   │   ├── base_tool_rag.py
    │   │   ├── classifier_tool_rag.py
    │   │   ├── simple_tool_rag.py
    │   │   └── text-embedding-3-small
    │   │       └── embeddings.pkl
    │   ├── tools
    │   │   ├── calendar.py
    │   │   ├── contacts.py
    │   │   ├── mail.py
    │   │   ├── maps.py
    │   │   ├── notes.py
    │   │   ├── reminders.py
    │   │   ├── sms.py
    │   │   ├── spotlight_search.py
    │   │   └── zoom.py
    │   └── transcription.py
    ├── tools
    │   └── base.py
    └── utils
        ├── data_utils.py
        ├── graph_utils.py
        ├── logger_utils.py
        ├── model_utils.py
        └── plan_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vs/
2 | .vscode/
3 | .idea/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 | docs/docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 | notebooks/
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .envrc
112 | .venv
113 | .venvs
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
138 | # macOS display setting files
139 | .DS_Store
140 |
141 | # Wandb directory
142 | wandb/
143 |
144 | # asdf tool versions
145 | .tool-versions
146 | /.ruff_cache/
147 |
148 | *.pkl
149 | *.bin
150 |
151 | # integration test artifacts
152 | data_map*
153 | \[('_type', 'fake'), ('stop', None)]
154 |
155 | # Replit files
156 | *replit*
157 |
158 | node_modules
159 | docs/.yarn/
160 | docs/node_modules/
161 | docs/.docusaurus/
162 | docs/.cache-loader/
163 | docs/_dist
164 | docs/api_reference/api_reference.rst
165 | docs/api_reference/experimental_api_reference.rst
166 | docs/api_reference/_build
167 | docs/api_reference/*/
168 | !docs/api_reference/_static/
169 | !docs/api_reference/templates/
170 | !docs/api_reference/themes/
171 | docs/docs_skeleton/build
172 | docs/docs_skeleton/node_modules
173 | docs/docs_skeleton/yarn.lock
174 |
175 | *.ipynb
176 |
177 | *.xcuserstate
178 | project.xcworkspace/
179 | xcuserdata/
180 | *.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 SqueezeAILab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TinyAgent: Function Calling at the Edge
2 |
3 | [arXiv][#arxiv-paper-package]
4 | [License: MIT][#license-gh-package]
5 |
6 | [#license-gh-package]: https://lbesson.mit-license.org/
7 | [#arxiv-paper-package]: https://arxiv.org/abs/2409.00608
8 |
9 |
10 |
11 | [Get the desktop app](TinyAgent.zip)
12 | |
13 | [Read the blog post](https://bair.berkeley.edu/blog/2024/05/29/tiny-agent)
14 | |
15 | [Read the paper](https://arxiv.org/abs/2409.00608)
16 |
17 |
18 | ![TinyAgent](figs/tinyagent.png)
19 |
20 | TinyAgent aims to enable complex reasoning and function calling capabilities in Small Language Models (SLMs) that can be deployed securely and privately at the edge. Traditional Large Language Models (LLMs) like GPT-4 and Gemini-1.5, while powerful, are often too large and resource-intensive for edge deployment, posing challenges in terms of privacy, connectivity, and latency. TinyAgent addresses these challenges by training specialized SLMs with high-quality, curated data and by focusing on function calling with [LLMCompiler](https://github.com/SqueezeAILab/LLMCompiler). As a driving application, TinyAgent can interact with various macOS applications, assisting users with day-to-day tasks such as composing emails, managing contacts, scheduling calendar events, and organizing Zoom meetings.
21 |
22 | ## Demo
23 |
24 |
25 |
26 |
27 |
28 | ## What can TinyAgent do?
29 |
30 | TinyAgent is equipped with 16 different functions that can interact with various applications on Mac, including:
31 |
32 | ### 📧 Mail
33 |
34 | - **Compose New Email**
35 | - Start a new email with options for adding recipients and attachments.
36 | - _Example Query:_ “Email Sid and Nick about the meeting with attachment project.pdf.”
37 | - **Reply to Emails**
38 | - Respond to received emails, optionally adding new recipients and attachments.
39 | - _Example Query:_ “Reply to Alice's email with the updated budget document attached.”
40 | - **Forward Emails**
41 | - Forward an existing email to other contacts, including additional attachments if necessary.
42 | - _Example Query:_ “Forward the project briefing to the marketing team.”
43 |
44 | ### 📇 Contacts
45 |
46 | - Retrieve phone numbers and email addresses from the contacts database.
47 | - _Example Query:_ “Get John’s phone number” or “Find Alice’s email address.”
48 |
49 | ### 📨 SMS
50 |
51 | - Send text messages to contacts directly from TinyAgent.
52 | - _Example Query:_ “Send an SMS to Canberk saying ‘I’m running late.’”
53 |
54 | ### 📅 Calendar
55 |
56 | - Create new calendar events with specified titles, dates, and times.
57 | - _Example Query:_ “Create an event called 'Meeting with Sid' on Friday at 3 PM.”
58 |
59 | ### 🗺️ Maps
60 |
61 | - Find directions or open map locations for points of interest via Apple Maps.
62 | - _Example Query:_ “Show me directions to the nearest Starbucks.”
63 |
64 | ### 🗒️ Notes
65 |
66 | - Create, open, and append content to notes stored in various folders.
67 | - _Example Query:_ “Create a note called 'Meeting Notes' in the Meetings folder.”
68 |
69 | ### 🗂️ File Management
70 |
71 | - **File Reading**
72 | - Open and read files directly through TinyAgent.
73 | - _Example Query:_ “Open the LLM Compiler.pdf.”
74 | - **PDF Summarization**
75 | - Generate summaries of PDF documents, enhancing content digestion and review efficiency.
76 | - _Example Query:_ “Summarize the document LLM Compiler.pdf and save the summary in my Documents folder.”
77 |
78 | ### ⏰ Reminders
79 |
80 | - Set reminders for various activities or tasks, ensuring nothing is forgotten.
81 | - _Example Query:_ “Remind me to call Sid at 3 PM about the budget approval.”
82 |
83 | ### 🎥 Zoom Meetings
84 |
85 | - Schedule and organize Zoom meetings, including setting names and times.
86 | - _Example Query:_ “Create a Zoom meeting called 'Team Standup' at 10 AM next Monday.”
87 |
88 | ### 💬 Custom Instructions
89 |
90 | - Write and configure specific instructions for your TinyAgent.
91 | - _Example Query:_ “Always cc team members Nick and Sid in all emails.”
92 |
93 | > You can choose to enable/disable certain apps by going to the Preferences window.
94 |
95 | ### 🤖 Sub-Agents
96 |
97 | Depending on the task's complexity, TinyAgent orchestrates the execution of different, more specialized or smaller LMs. TinyAgent can currently operate LMs that summarize PDFs, write emails, or take notes.
98 |
99 | > See the [Customization](#customization) section for how to add your own sub-agents.
100 |
101 | ### 🛠️ ToolRAG
102 |
103 | When faced with challenging tasks, SLM agents require appropriate tools and in-context examples to guide them. If the model sees irrelevant examples, it can hallucinate. Likewise, if the model sees the descriptions of the tools that it doesn’t need, it usually gets confused, and these tools take up unnecessary prompt space. To tackle this, TinyAgent uses ToolRAG to retrieve the best tools and examples suited for a given query. This process has minimal latency and increases the accuracy of TinyAgent substantially. Please take a look at our [blog post](https://bair.berkeley.edu/blog/2024/05/29/tiny-agent) and our [ToolRAG model](https://huggingface.co/squeeze-ai-lab/TinyAgent-ToolRAG) for more details.
104 |
105 | > You need to first install our [ToolRAG model](https://huggingface.co/squeeze-ai-lab/TinyAgent-ToolRAG) from Hugging Face and enable it from the TinyAgent settings to use it.
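For intuition, the sketch below shows what embedding-based tool retrieval looks like. It is a simplified illustration, not TinyAgent's actual implementation (which lives in `src/tiny_agent/tool_rag/`); the function name and the pickle layout are assumptions.

```python
import pickle

import numpy as np


def retrieve_tools(query_embedding: np.ndarray, embeddings_path: str, top_k: int = 5) -> list[str]:
    # Load precomputed tool/example embeddings (layout assumed for illustration).
    with open(embeddings_path, "rb") as f:
        tool_names, tool_embeddings = pickle.load(f)
    # Rank tools by cosine similarity to the query embedding.
    sims = tool_embeddings @ query_embedding / (
        np.linalg.norm(tool_embeddings, axis=1) * np.linalg.norm(query_embedding)
    )
    top = np.argsort(sims)[::-1][:top_k]
    return [tool_names[i] for i in top]
```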
106 |
107 | ### 🎙️ Whisper
108 |
109 | TinyAgent also accepts voice commands through both the OpenAI Whisper API and a local [whisper.cpp](https://github.com/ggerganov/whisper.cpp) deployment. For whisper.cpp, you need to set up the [local whisper server](https://github.com/ggerganov/whisper.cpp/tree/master/examples/server) and provide the server port number in the TinyAgent settings.
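For reference, here is a minimal sketch of the OpenAI Whisper path; TinyAgent's actual clients live in `src/tiny_agent/transcription.py`, and the audio file name below is a placeholder.

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Transcribe a recorded voice command (placeholder file name).
with open("command.wav", "rb") as audio_file:
    transcription = client.audio.transcriptions.create(
        model="whisper-1",
        file=audio_file,
    )
print(transcription.text)
```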
110 |
111 | ## Providers
112 |
113 | You can use TinyAgent with your OpenAI key, Azure deployments, or even your own local models!
114 |
115 | ### OpenAI
116 |
117 | You need to provide your OpenAI API key and the models you want to use in the 'Preferences' window.
118 |
119 | ### Azure Deployments
120 |
121 | You need to provide your deployment name and the endpoints for the main agent/sub-agents/embedding model as well as the context length of the agent models in the 'Preferences' window.
122 |
123 | ### Local Models
124 |
125 | You can plug and play every part of TinyAgent with your local models! TinyAgent can use any OpenAI-compatible server to run models locally. There are several options:
126 |
127 | - **[LMStudio](https://lmstudio.ai/):** For models already on Hugging Face, LMStudio provides an easy-to-use interface to get started with locally served models.
128 |
129 | - **[llama.cpp server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md):** If you want more control over your models, we recommend using the official llama.cpp server. Please read through the linked documentation to get started with it.
130 |
131 | > All TinyAgent needs is the port numbers that you are serving your model at and its context length.
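Because these servers speak the OpenAI API, pointing a client at one is all it takes. A minimal sketch, in which the port and model name are placeholders you should substitute with your own:

```python
from openai import OpenAI

# Point the standard OpenAI client at the local server instead of api.openai.com.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="local-model",  # local servers typically ignore or map this name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```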
132 |
133 | ## Fine-tuned TinyAgents
134 |
135 | We also provide our own fine-tuned open-source models, TinyAgent-1.1B and TinyAgent-7B! We curated a [dataset](https://huggingface.co/datasets/squeeze-ai-lab/TinyAgent-dataset) of 40,000 real-life use cases for TinyAgent and fine-tuned two small open-source language models on this dataset with LoRA. After fine-tuning and using ToolRAG, both TinyAgent-1.1B and TinyAgent-7B exceed the performance of GPT-4-turbo. Check out our [paper](https://arxiv.org/abs/2409.00608) for the specifics of dataset generation, evaluation, and fine-tuning.
136 |
137 | | Model | Success Rate |
138 | | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------ |
139 | | GPT-3.5-turbo | 65.04% |
140 | | GPT-4-turbo | 79.08% |
141 | | [TinyLLama-1.1B-32K-Instruct](https://huggingface.co/Doctor-Shotgun/TinyLlama-1.1B-32k-Instruct) | 12.71% |
142 | | [WizardLM-2-7b](https://huggingface.co/MaziyarPanahi/WizardLM-2-7B-GGUF) | 41.25% |
143 | | TinyAgent-1.1B + ToolRAG / [[hf](https://huggingface.co/squeeze-ai-lab/TinyAgent-1.1B)] [[gguf](https://huggingface.co/squeeze-ai-lab/TinyAgent-1.1B-GGUF)] | **80.06%** |
144 | | TinyAgent-7B + ToolRAG / [[hf](https://huggingface.co/squeeze-ai-lab/TinyAgent-7B)] [[gguf](https://huggingface.co/squeeze-ai-lab/TinyAgent-7B-GGUF)] | **84.95%** |
145 |
146 | ## Customization
147 |
148 | You can customize your TinyAgent by going to `~/Library/Application Support/TinyAgent/tinyagent-llmcompiler` directory and changing the code yourself.
149 |
150 | ### Using TinyAgent programmatically
151 |
152 | You can use TinyAgent programmatically by just passing in a config file.
153 |
154 | ```python
155 | from src.tiny_agent.tiny_agent import TinyAgent
156 | from src.tiny_agent.config import get_tiny_agent_config
157 |
158 | config_path = "..."
159 | tiny_agent_config = get_tiny_agent_config(config_path=config_path)
160 | tiny_agent = TinyAgent(tiny_agent_config)
161 |
162 | await tiny_agent.arun(query="Create a meeting with Sid and Lutfi for tomorrow 2pm to discuss the meeting notes.")
163 | ```
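Note that `arun` is a coroutine, so outside an existing event loop you would wrap the call with `asyncio.run`, e.g.:

```python
import asyncio

from src.tiny_agent.config import get_tiny_agent_config
from src.tiny_agent.tiny_agent import TinyAgent


async def main() -> None:
    tiny_agent_config = get_tiny_agent_config(config_path="...")
    tiny_agent = TinyAgent(tiny_agent_config)
    response = await tiny_agent.arun(query="Create a meeting with Sid and Lutfi for tomorrow 2pm.")
    print(response)


asyncio.run(main())
```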
164 |
165 | ### Adding your own tools
166 |
167 | 1. Navigate to `src/tiny_agent/models.py` and add your tool's name to `TinyAgentToolName(Enum)`.
168 |
169 | ```python
170 | class TinyAgentToolName(Enum):
171 | ...
172 | CUSTOM_TOOL_NAME = "custom_tool"
173 | ```
174 |
175 | 2. Navigate to `src/tiny_agent/tiny_agent_tools.py` and define your tool.
176 |
177 | ```python
178 | def get_custom_tool(...) -> list[Tool]:
179 | async def tool_coroutine(...) -> str: # Needs to return a string
180 | ...
181 | return ...
182 |
183 | custom_tool = Tool(
184 | name=TinyAgentToolName.CUSTOM_TOOL_NAME,
185 | func=tool_coroutine,
186 | description=(
187 | f"{TinyAgentToolName.CUSTOM_TOOL_NAME.value}(...) -> str\n"
188 | ""
189 | )
190 | )
191 |
192 | return [custom_tool]
193 | ```
194 |
195 | 3. Add your tools to the `get_tiny_agent_tools` function.
196 |
197 | ```python
198 | def get_tiny_agent_tools(...):
199 | ...
200 |     tools += get_custom_tool(...)
201 | ...
202 | ```
203 |
204 | > Note: Adding your own tools only works for GPT models since our open-source models and ToolRAG were only fine-tuned on the original TinyAgent toolset.
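For illustration, a hypothetical filled-in version of steps 1-3 might look like this; the `text` parameter and the uppercase behavior are made up for the example.

```python
def get_custom_tool() -> list[Tool]:
    async def tool_coroutine(text: str) -> str:  # needs to return a string
        # Hypothetical behavior: echo the input back in uppercase.
        return text.upper()

    custom_tool = Tool(
        name=TinyAgentToolName.CUSTOM_TOOL_NAME,
        func=tool_coroutine,
        description=(
            f"{TinyAgentToolName.CUSTOM_TOOL_NAME.value}(text: str) -> str\n"
            "Returns the given text in uppercase."
        ),
    )

    return [custom_tool]
```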
205 |
206 | ### Adding your own sub-agents
207 |
208 | You can also add your own custom sub-agents. To do so, please follow these steps:
209 |
210 | 1. All sub-agents inherit from the `SubAgent` class in `src/tiny_agent/sub_agents/sub_agent.py`. Your custom agent should inherit from this abstract class and define the `__call__` method.
211 |
212 | ```python
213 | class CustomSubAgent(SubAgent):
214 |
215 | async def __call__(self, ...) -> str:
216 | ...
217 | return response
218 | ```
219 |
220 | 2. Add your custom agent to `TinyAgent` in `src/tiny_agent/tiny_agent.py`.
221 |
222 | ```python
223 | from src.tiny_agent.sub_agents.custom_agent import CustomSubAgent
224 |
225 | class TinyAgent:
226 |     ...
227 |     custom_agent: CustomSubAgent
228 |
229 |     def __init__(...):
230 |         ...
231 |         self.custom_agent = CustomSubAgent(
232 | sub_agent_llm, config.sub_agent_config, config.custom_instructions
233 | )
234 | ...
235 | ```
236 |
237 | 3. After defining your custom agent and adding it to TinyAgent, you should create a tool that calls this agent. Please refer to the [Adding your own tools](#adding-your-own-tools) section to see how to do so.
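For example, a tool wrapping the custom sub-agent could look like this hypothetical sketch, mirroring the pattern from [Adding your own tools](#adding-your-own-tools):

```python
def get_custom_agent_tool(custom_agent: CustomSubAgent) -> list[Tool]:
    async def tool_coroutine(prompt: str) -> str:
        # Delegate the task to the sub-agent and return its string response.
        return await custom_agent(prompt)

    custom_agent_tool = Tool(
        name=TinyAgentToolName.CUSTOM_TOOL_NAME,
        func=tool_coroutine,
        description=(
            f"{TinyAgentToolName.CUSTOM_TOOL_NAME.value}(prompt: str) -> str\n"
            "Runs the custom sub-agent on the given prompt."
        ),
    )

    return [custom_agent_tool]
```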
238 |
239 | ## Citation
240 |
241 | If you find TinyAgent useful for your work, please consider citing our [paper](https://arxiv.org/pdf/2409.00608):
242 |
243 | ```
244 | @misc{erdogan2024tinyagentfunctioncallingedge,
245 | title={TinyAgent: Function Calling at the Edge},
246 | author={Lutfi Eren Erdogan and Nicholas Lee and Siddharth Jha and Sehoon Kim and Ryan Tabrizi and Suhong Moon and Coleman Hooper and Gopala Anumanchipalli and Kurt Keutzer and Amir Gholami},
247 | year={2024},
248 | eprint={2409.00608},
249 | archivePrefix={arXiv},
250 | primaryClass={cs.CL},
251 | url={https://arxiv.org/abs/2409.00608},
252 | }
253 | ```
254 |
--------------------------------------------------------------------------------
/TinyAgent.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SqueezeAILab/TinyAgent/cc45c0e842f5d163c3df1c8f41d60e90e005867d/TinyAgent.zip
--------------------------------------------------------------------------------
/figs/tinyagent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SqueezeAILab/TinyAgent/cc45c0e842f5d163c3df1c8f41d60e90e005867d/figs/tinyagent.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | PyMuPDF==1.23.26
2 | bs4==0.0.1
3 | fastapi==0.110.1
4 | langchain-community==0.0.9
5 | langchain-core==0.1.7
6 | langchain-openai==0.0.2
7 | langchain==0.1.0
8 | numexpr==2.8.7
9 | openai==1.6.1
10 | protobuf==5.26.1
11 | python-dateutil==2.8.2
12 | sentencepiece==0.2.0
13 | tiktoken==0.5.2
14 | torch==2.2.2
15 | transformers==4.38.2
16 | uvicorn==0.29.0
17 | python-multipart==0.0.9
18 | httpx==0.27.0
19 |
--------------------------------------------------------------------------------
/run_tiny_agent_server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import signal
4 | from http import HTTPStatus
5 | from typing import cast
6 |
7 | from fastapi import FastAPI, HTTPException, Request, Response
8 | from fastapi.exceptions import HTTPException
9 | from fastapi.responses import PlainTextResponse, StreamingResponse
10 | from pydantic import BaseModel
11 | from starlette.datastructures import UploadFile
12 | from starlette.exceptions import HTTPException as StarletteHTTPException
13 |
14 | from src.tiny_agent.config import get_tiny_agent_config
15 | from src.tiny_agent.models import (
16 | LLM_ERROR_TOKEN,
17 | TINY_AGENT_DIR,
18 | ModelType,
19 | streaming_queue,
20 | )
21 | from src.tiny_agent.tiny_agent import TinyAgent
22 | from src.tiny_agent.transcription import (
23 | TranscriptionService,
24 | WhisperCppClient,
25 | WhisperOpenAIClient,
26 | )
27 | from src.utils.logger_utils import enable_logging, enable_logging_to_file, log
28 |
29 | enable_logging(False)
30 | enable_logging_to_file(True)
31 |
32 | CONFIG_PATH = os.path.join(TINY_AGENT_DIR, "Configuration.json")
33 |
34 | app = FastAPI()
35 |
36 |
37 | def empty_queue(q: asyncio.Queue) -> None:
38 | while not q.empty():
39 | try:
40 | q.get_nowait()
41 | q.task_done()
42 | except asyncio.QueueEmpty:
43 | # Handle the case where the queue is already empty
44 | break
45 |
46 |
47 | class TinyAgentRequest(BaseModel):
48 | query: str
49 |
50 |
51 | @app.exception_handler(StarletteHTTPException)
52 | async def custom_http_exception_handler(request, exc):
53 | """
54 | Custom error handling for logging the errors to the TinyAgent log file.
55 | """
56 | log(f"HTTPException {exc.status_code}: {exc.detail}")
57 | return PlainTextResponse(exc.detail, status_code=exc.status_code)
58 |
59 |
60 | @app.post("/generate")
61 | async def execute_command(request: TinyAgentRequest) -> StreamingResponse:
62 | """
63 | This is the main endpoint that calls the TinyAgent to generate a response to the given query.
64 | """
65 | log(f"\n\n====\nReceived request: {request.query}")
66 |
67 | # First, ensure the queue is empty
68 | empty_queue(streaming_queue)
69 |
70 | query = request.query
71 |
72 | if not query or len(query) <= 0:
73 | raise HTTPException(
74 | status_code=HTTPStatus.BAD_REQUEST, detail="No query provided"
75 | )
76 |
77 | try:
78 | tiny_agent_config = get_tiny_agent_config(config_path=CONFIG_PATH)
79 | tiny_agent = TinyAgent(tiny_agent_config)
80 | except Exception as e:
81 | raise HTTPException(
82 | status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
83 | detail=f"Error: {e}",
84 | )
85 |
86 | async def generate():
87 | try:
88 | response_task = asyncio.create_task(tiny_agent.arun(query))
89 |
90 | while True:
91 | # Await a small timeout to periodically check if the task is done
92 | try:
93 | token = await asyncio.wait_for(streaming_queue.get(), timeout=1.0)
94 | if token is None:
95 | break
96 | if token.startswith(LLM_ERROR_TOKEN):
97 | raise Exception(token[len(LLM_ERROR_TOKEN) :])
98 | yield token
99 | except asyncio.TimeoutError:
100 | pass # No new token, check task status
101 |
102 | # Check if the task is done to handle any potential exception
103 | if response_task.done():
104 | break
105 |
106 |             # Tasks created with asyncio.create_task() do not propagate exceptions
107 | # to the calling context. Instead, the exception remains encapsulated within
108 | # the task object itself until the task is awaited or its result is explicitly retrieved.
109 | # Hence, we check here if the task has an exception set by awaiting it, which will
110 | # raise the exception if it exists. If it doesn't, we just yield the result.
111 | await response_task
112 | response = response_task.result()
113 | yield f"\n\n{response}"
114 | except Exception as e:
115 |             # You cannot raise HTTPExceptions in an async generator; they don't
116 | # get caught by the FastAPI exception handling middleware. Hence,
117 | # we are manually catching the exceptions and yielding/logging them.
118 | yield f"Error: {e}"
119 | log(f"Error: {e}")
120 |
121 | return StreamingResponse(generate(), media_type="text/event-stream")
122 |
123 |
124 | @app.post("/voice")
125 | async def get_voice_transcription(request: Request) -> Response:
126 | """
127 |     This endpoint calls Whisper to get a voice transcription. It takes in bytes of audio
128 |     and returns the transcription in plain text.
129 | """
130 | log("\n\n====\nReceived request to get voice transcription")
131 |
132 | body = await request.form()
133 | audio_file = cast(UploadFile, body["audio_pcm"])
134 | sample_rate = int(cast(str, body["sample_rate"]))
135 | raw_bytes = await audio_file.read()
136 |
137 | if not raw_bytes or len(raw_bytes) <= 0:
138 | raise HTTPException(
139 | status_code=HTTPStatus.BAD_REQUEST, detail="No audio provided"
140 | )
141 | if not sample_rate:
142 | raise HTTPException(
143 | status_code=HTTPStatus.BAD_REQUEST, detail="No sampling rate provided"
144 | )
145 |
146 | try:
147 | tiny_agent_config = get_tiny_agent_config(config_path=CONFIG_PATH)
148 | except Exception as e:
149 | raise HTTPException(
150 | status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
151 | detail=f"Error: {e}",
152 | )
153 |
154 | whisper_client = (
155 | WhisperOpenAIClient(tiny_agent_config)
156 | if tiny_agent_config.whisper_config.provider == ModelType.OPENAI
157 | else WhisperCppClient(tiny_agent_config)
158 | )
159 |
160 | transcription_service = TranscriptionService(whisper_client)
161 |
162 | try:
163 | transcription = await transcription_service.transcribe(raw_bytes, sample_rate)
164 | except Exception as e:
165 | raise HTTPException(
166 | status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
167 | detail=f"Error: {e}",
168 | )
169 |
170 | return Response(transcription, status_code=HTTPStatus.OK)
171 |
172 |
173 | @app.post("/quit")
174 | async def shutdown_server() -> Response:
175 | """
176 |     Shuts down the server by sending a SIGTERM signal to the main process,
177 | which is a gentle way to terminate the server. This endpoint should be
178 | protected in real applications to prevent unauthorized shutdowns.
179 | """
180 | os.kill(os.getpid(), signal.SIGTERM)
181 | return Response("Server is shutting down...", status_code=HTTPStatus.OK)
182 |
183 |
184 | @app.get("/ping")
185 | async def ping() -> Response:
186 | """
187 | A simple endpoint to check if the server is running.
188 | """
189 | return Response("pong", status_code=HTTPStatus.OK)
190 |
191 |
192 | if __name__ == "__main__":
193 | import uvicorn
194 |
195 | uvicorn.run(app, host="127.0.0.1", port=50001)
196 |
--------------------------------------------------------------------------------
/src/agents/agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import logging
5 | from abc import abstractmethod
6 | from pathlib import Path
7 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
8 |
9 | import yaml
10 | from langchain.agents.agent import AgentOutputParser, BaseSingleActionAgent
11 | from langchain.agents.agent_types import AgentType
12 | from langchain.callbacks.base import BaseCallbackManager
13 | from langchain.callbacks.manager import Callbacks
14 | from langchain.prompts.few_shot import FewShotPromptTemplate
15 | from langchain.prompts.prompt import PromptTemplate
16 | from langchain.pydantic_v1 import root_validator
17 | from langchain.schema import (
18 | AgentAction,
19 | AgentFinish,
20 | BaseOutputParser,
21 | BasePromptTemplate,
22 | )
23 | from langchain.schema.language_model import BaseLanguageModel
24 | from langchain.schema.messages import BaseMessage
25 | from langchain.tools import BaseTool
26 |
27 | from src.chains.llm_chain import LLMChain
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class AgentOutputParser(BaseOutputParser):
33 | """Base class for parsing agent output into agent action/finish."""
34 |
35 | @abstractmethod
36 | def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
37 | """Parse text into agent action/finish."""
38 |
39 |
40 | class Agent(BaseSingleActionAgent):
41 |     """Agent that calls the language model and decides the action.
42 |
43 | This is driven by an LLMChain. The prompt in the LLMChain MUST include
44 | a variable called "agent_scratchpad" where the agent can put its
45 | intermediary work.
46 |
47 | Copied from Langchain v0.0.283,
48 | but merged with the parent class BaseSingleActionAgent for simplicity.
49 | """
50 |
51 | llm_chain: LLMChain
52 | output_parser: AgentOutputParser
53 | allowed_tools: Optional[List[str]] = None
54 |
55 | @property
56 | def _agent_type(self) -> str:
57 | """Return Identifier of agent type."""
58 | raise NotImplementedError
59 |
60 | def dict(self, **kwargs: Any) -> Dict:
61 | """Return dictionary representation of agent."""
62 | _dict = super().dict()
63 | _type = self._agent_type
64 | if isinstance(_type, AgentType):
65 | _dict["_type"] = str(_type.value)
66 | else:
67 | _dict["_type"] = _type
68 | del _dict["output_parser"]
69 | return _dict
70 |
71 | def get_allowed_tools(self) -> Optional[List[str]]:
72 | return self.allowed_tools
73 |
74 | @property
75 | def return_values(self) -> List[str]:
76 | return ["output"]
77 |
78 | def _fix_text(self, text: str) -> str:
79 | """Fix the text."""
80 | raise ValueError("fix_text not implemented for this agent.")
81 |
82 | @property
83 | def _stop(self) -> List[str]:
84 | return [
85 | f"\n{self.observation_prefix.rstrip()}",
86 | f"\n\t{self.observation_prefix.rstrip()}",
87 | ]
88 |
89 | def _construct_scratchpad(
90 | self, intermediate_steps: List[Tuple[AgentAction, str]]
91 | ) -> Union[str, List[BaseMessage]]:
92 | """Construct the scratchpad that lets the agent continue its thought process."""
93 | thoughts = ""
94 | for action, observation in intermediate_steps:
95 | thoughts += action.log
96 | thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
97 | return thoughts
98 |
99 | def plan(
100 | self,
101 | intermediate_steps: List[Tuple[AgentAction, str]],
102 | callbacks: Callbacks = None,
103 | **kwargs: Any,
104 | ) -> Union[AgentAction, AgentFinish]:
105 |         """Given input, decide what to do.
106 |
107 | Args:
108 | intermediate_steps: Steps the LLM has taken to date,
109 | along with observations
110 | callbacks: Callbacks to run.
111 | **kwargs: User inputs.
112 |
113 | Returns:
114 | Action specifying what tool to use.
115 | """
116 |         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
117 |         full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
118 |         try:
119 |             return self.output_parser.parse(full_output)
120 |         except Exception:
121 |             full_inputs["agent_scratchpad"] = (
122 |                 full_inputs["agent_scratchpad"] + full_output + "\nAction: "
123 |             )
124 |             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
125 |             return self.output_parser.parse("Action: " + full_output)
126 |
127 | async def aplan(
128 | self,
129 | intermediate_steps: List[Tuple[AgentAction, str]],
130 | callbacks: Callbacks = None,
131 | **kwargs: Any,
132 | ) -> Union[AgentAction, AgentFinish]:
133 |         """Given input, decide what to do.
134 |
135 | Args:
136 | intermediate_steps: Steps the LLM has taken to date,
137 | along with observations
138 | callbacks: Callbacks to run.
139 | **kwargs: User inputs.
140 |
141 | Returns:
142 | Action specifying what tool to use.
143 | """
144 |         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
145 |         full_output = await self.llm_chain.apredict(
146 |             callbacks=callbacks, **full_inputs
147 |         )
148 |         try:
149 |             agent_output = await self.output_parser.aparse(full_output)
150 |         except Exception:
151 |             full_inputs["agent_scratchpad"] = (
152 |                 full_inputs["agent_scratchpad"] + full_output + "\nAction: "
153 |             )
154 |             full_output = await self.llm_chain.apredict(
155 |                 callbacks=callbacks, **full_inputs
156 |             )
157 |             agent_output = await self.output_parser.aparse("Action: " + full_output)
158 |
159 | return agent_output
160 |
161 | def get_full_inputs(
162 | self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
163 | ) -> Dict[str, Any]:
164 | """Create the full inputs for the LLMChain from intermediate steps."""
165 | thoughts = self._construct_scratchpad(intermediate_steps)
166 | new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
167 | full_inputs = {**kwargs, **new_inputs}
168 | return full_inputs
169 |
170 | @property
171 | def input_keys(self) -> List[str]:
172 | """Return the input keys.
173 |
174 | :meta private:
175 | """
176 | return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
177 |
178 | @root_validator()
179 | def validate_prompt(cls, values: Dict) -> Dict:
180 | """Validate that prompt matches format."""
181 | prompt = values["llm_chain"].prompt
182 | if "agent_scratchpad" not in prompt.input_variables:
183 | logger.warning(
184 | "`agent_scratchpad` should be a variable in prompt.input_variables."
185 | " Did not find it, so adding it at the end."
186 | )
187 | prompt.input_variables.append("agent_scratchpad")
188 | if isinstance(prompt, PromptTemplate):
189 | prompt.template += "\n{agent_scratchpad}"
190 | elif isinstance(prompt, FewShotPromptTemplate):
191 | prompt.suffix += "\n{agent_scratchpad}"
192 | else:
193 | raise ValueError(f"Got unexpected prompt type {type(prompt)}")
194 | return values
195 |
196 | @property
197 | @abstractmethod
198 | def observation_prefix(self) -> str:
199 | """Prefix to append the observation with."""
200 |
201 | @property
202 | @abstractmethod
203 | def llm_prefix(self) -> str:
204 | """Prefix to append the LLM call with."""
205 |
206 | @classmethod
207 | def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
208 | """Validate that appropriate tools are passed in."""
209 | pass
210 |
211 | @classmethod
212 | @abstractmethod
213 | def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
214 | """Get default output parser for this class."""
215 |
216 | @classmethod
217 | def from_llm_and_tools(
218 | cls,
219 | llm: BaseLanguageModel,
220 | tools: Sequence[BaseTool],
221 | prompt: BasePromptTemplate,
222 | callback_manager: Optional[BaseCallbackManager] = None,
223 | output_parser: Optional[AgentOutputParser] = None,
224 | **kwargs: Any,
225 | ) -> Agent:
226 | """Construct an agent from an LLM and tools."""
227 | cls._validate_tools(tools)
228 | llm_chain = LLMChain(
229 | llm=llm,
230 | prompt=prompt,
231 | callback_manager=callback_manager,
232 | )
233 | tool_names = [tool.name for tool in tools]
234 | _output_parser = output_parser or cls._get_default_output_parser()
235 | return cls(
236 | llm_chain=llm_chain,
237 | allowed_tools=tool_names,
238 | output_parser=_output_parser,
239 | **kwargs,
240 | )
241 |
242 | def return_stopped_response(
243 | self,
244 | early_stopping_method: str,
245 | intermediate_steps: List[Tuple[AgentAction, str]],
246 | **kwargs: Any,
247 | ) -> AgentFinish:
248 | """Return response when agent has been stopped due to max iterations."""
249 | if early_stopping_method == "force":
250 | # `force` just returns a constant string
251 | return AgentFinish(
252 | {"output": "Agent stopped due to iteration limit or time limit."}, ""
253 | )
254 | elif early_stopping_method == "generate":
255 | # Generate does one final forward pass
256 | thoughts = ""
257 | for action, observation in intermediate_steps:
258 | thoughts += action.log
259 | thoughts += (
260 | f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
261 | )
262 | # Adding to the previous steps, we now tell the LLM to make a final pred
263 | thoughts += (
264 | "\n\nI now need to return a final answer based on the previous steps:"
265 | )
266 | new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
267 | full_inputs = {**kwargs, **new_inputs}
268 | full_output = self.llm_chain.predict(**full_inputs)
269 | # We try to extract a final answer
270 | parsed_output = self.output_parser.parse(full_output)
271 | if isinstance(parsed_output, AgentFinish):
272 | # If we can extract, we send the correct stuff
273 | return parsed_output
274 | else:
275 |                 # If we cannot extract a final answer,
276 | # we just return the full output
277 | return AgentFinish({"output": full_output}, full_output)
278 | else:
279 | raise ValueError(
280 | "early_stopping_method should be one of `force` or `generate`, "
281 | f"got {early_stopping_method}"
282 | )
283 |
284 | def tool_run_logging_kwargs(self) -> Dict:
285 | return {
286 | "llm_prefix": self.llm_prefix,
287 | "observation_prefix": self.observation_prefix,
288 | }
289 |
290 | def save(self, file_path: Union[Path, str]) -> None:
291 | """Save the agent.
292 |
293 | Args:
294 | file_path: Path to file to save the agent to.
295 |
296 | Example:
297 | .. code-block:: python
298 |
299 | # If working with agent executor
300 | agent.agent.save(file_path="path/agent.yaml")
301 | """
302 | # Convert file to Path object.
303 | if isinstance(file_path, str):
304 | save_path = Path(file_path)
305 | else:
306 | save_path = file_path
307 |
308 | directory_path = save_path.parent
309 | directory_path.mkdir(parents=True, exist_ok=True)
310 |
311 | # Fetch dictionary to save
312 | agent_dict = self.dict()
313 |
314 | if save_path.suffix == ".json":
315 | with open(file_path, "w") as f:
316 | json.dump(agent_dict, f, indent=4)
317 | elif save_path.suffix == ".yaml":
318 | with open(file_path, "w") as f:
319 | yaml.dump(agent_dict, f, default_flow_style=False)
320 | else:
321 | raise ValueError(f"{save_path} must be json or yaml")
322 |
--------------------------------------------------------------------------------
/src/agents/structured_chat_agent.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, List, Optional, Sequence, Tuple
3 |
4 | from langchain.agents.agent import AgentOutputParser
5 | from langchain.agents.structured_chat.output_parser import (
6 | StructuredChatOutputParserWithRetries,
7 | )
8 | from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
9 | from langchain.callbacks.base import BaseCallbackManager
10 | from langchain.prompts.chat import (
11 | ChatPromptTemplate,
12 | HumanMessagePromptTemplate,
13 | SystemMessagePromptTemplate,
14 | )
15 | from langchain.pydantic_v1 import Field
16 | from langchain.schema import AgentAction, BasePromptTemplate
17 | from langchain.schema.language_model import BaseLanguageModel
18 | from langchain.tools import BaseTool
19 |
20 | from src.agents.agent import Agent
21 | from src.chains.llm_chain import LLMChain
22 |
23 | HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
24 |
25 |
26 | class StructuredChatAgent(Agent):
27 | """Structured Chat Agent."""
28 |
29 | output_parser: AgentOutputParser = Field(
30 | default_factory=StructuredChatOutputParserWithRetries
31 | )
32 | """Output parser for the agent."""
33 |
34 | @property
35 | def observation_prefix(self) -> str:
36 | """Prefix to append the observation with."""
37 | return "Observation: "
38 |
39 | @property
40 | def llm_prefix(self) -> str:
41 | """Prefix to append the llm call with."""
42 | return "Thought:"
43 |
44 | def _construct_scratchpad(
45 | self, intermediate_steps: List[Tuple[AgentAction, str]]
46 | ) -> str:
47 | agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
48 | if not isinstance(agent_scratchpad, str):
49 | raise ValueError("agent_scratchpad should be of type string.")
50 | if agent_scratchpad:
51 | return (
52 | f"This was your previous work "
53 | f"(but I haven't seen any of it! I only see what "
54 | f"you return as final answer):\n{agent_scratchpad}"
55 | )
56 | else:
57 | return agent_scratchpad
58 |
59 | @classmethod
60 | def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
61 | pass
62 |
63 | @classmethod
64 | def _get_default_output_parser(
65 | cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
66 | ) -> AgentOutputParser:
67 | return StructuredChatOutputParserWithRetries.from_llm(llm=llm)
68 |
69 | @property
70 | def _stop(self) -> List[str]:
71 | return ["Observation:"]
72 |
73 | @classmethod
74 | def create_prompt(
75 | cls,
76 | tools: Sequence[BaseTool],
77 | prefix: str = PREFIX,
78 | suffix: str = SUFFIX,
79 | human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
80 | format_instructions: str = FORMAT_INSTRUCTIONS,
81 | input_variables: Optional[List[str]] = None,
82 | memory_prompts: Optional[List[BasePromptTemplate]] = None,
83 | ) -> BasePromptTemplate:
84 | tool_strings = []
85 | for tool in tools:
86 | args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
87 | tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
88 | formatted_tools = "\n".join(tool_strings)
89 | tool_names = ", ".join([tool.name for tool in tools])
90 | format_instructions = format_instructions.format(tool_names=tool_names)
91 | template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
92 | if input_variables is None:
93 | input_variables = ["input", "agent_scratchpad"]
94 | _memory_prompts = memory_prompts or []
95 | messages = [
96 | SystemMessagePromptTemplate.from_template(template),
97 | *_memory_prompts,
98 | HumanMessagePromptTemplate.from_template(human_message_template),
99 | ]
100 | return ChatPromptTemplate(input_variables=input_variables, messages=messages)
101 |
102 | @classmethod
103 | def from_llm_and_tools(
104 | cls,
105 | llm: BaseLanguageModel,
106 | tools: Sequence[BaseTool],
107 | callback_manager: Optional[BaseCallbackManager] = None,
108 | output_parser: Optional[AgentOutputParser] = None,
109 | prefix: str = PREFIX,
110 | suffix: str = SUFFIX,
111 | human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
112 | format_instructions: str = FORMAT_INSTRUCTIONS,
113 | input_variables: Optional[List[str]] = None,
114 | memory_prompts: Optional[List[BasePromptTemplate]] = None,
115 | **kwargs: Any,
116 | ) -> Agent:
117 | """Construct an agent from an LLM and tools."""
118 | cls._validate_tools(tools)
119 | prompt = cls.create_prompt(
120 | tools,
121 | prefix=prefix,
122 | suffix=suffix,
123 | human_message_template=human_message_template,
124 | format_instructions=format_instructions,
125 | input_variables=input_variables,
126 | memory_prompts=memory_prompts,
127 | )
128 | llm_chain = LLMChain(
129 | llm=llm,
130 | prompt=prompt,
131 | callback_manager=callback_manager,
132 | )
133 | tool_names = [tool.name for tool in tools]
134 | _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
135 | return cls(
136 | llm_chain=llm_chain,
137 | allowed_tools=tool_names,
138 | output_parser=_output_parser,
139 | **kwargs,
140 | )
141 |
142 | @property
143 | def _agent_type(self) -> str:
144 | raise ValueError
145 |
--------------------------------------------------------------------------------
/src/agents/tools.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from langchain.callbacks.manager import (
4 | AsyncCallbackManagerForToolRun,
5 | CallbackManagerForToolRun,
6 | )
7 | from langchain.tools import BaseTool
8 |
9 | from src.tools.base import Tool, tool
10 |
11 |
12 | class InvalidTool(BaseTool):
13 | """Tool that is run when invalid tool name is encountered by agent."""
14 |
15 | name: str = "invalid_tool"
16 | description: str = "Called when tool name is invalid. Suggests valid tool names."
17 |
18 | def _run(
19 | self,
20 | requested_tool_name: str,
21 | available_tool_names: List[str],
22 | run_manager: Optional[CallbackManagerForToolRun] = None,
23 | ) -> str:
24 | """Use the tool."""
25 |         available_tool_names_str = ", ".join(available_tool_names)
26 | return (
27 | f"{requested_tool_name} is not a valid tool, "
28 | f"try one of [{available_tool_names_str}]."
29 | )
30 |
31 | async def _arun(
32 | self,
33 | requested_tool_name: str,
34 | available_tool_names: List[str],
35 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
36 | ) -> str:
37 | """Use the tool asynchronously."""
38 |         available_tool_names_str = ", ".join(available_tool_names)
39 | return (
40 | f"{requested_tool_name} is not a valid tool, "
41 | f"try one of [{available_tool_names_str}]."
42 | )
43 |
44 |
45 | __all__ = ["InvalidTool", "BaseTool", "tool", "Tool"]
46 |
--------------------------------------------------------------------------------
/src/callbacks/callbacks.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import tiktoken
4 | from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
5 |
6 |
7 | class StatsCallbackHandler(BaseCallbackHandler):
8 | """Collect useful stats about the run.
9 | Add more stats as needed."""
10 |
11 | def __init__(self) -> None:
12 | super().__init__()
13 | self.cnt = 0
14 | self.input_tokens = 0
15 | self.output_tokens = 0
16 | self.all_times = []
17 | self.start_time = 0
18 |
19 | def on_chat_model_start(self, serialized, prompts, **kwargs):
20 | self.start_time = time.time()
21 |
22 | def on_llm_end(self, response, *args, **kwargs):
23 | token_usage = response.llm_output["token_usage"]
24 | self.input_tokens += token_usage["prompt_tokens"]
25 | self.output_tokens += token_usage["completion_tokens"]
26 | self.cnt += 1
27 | self.all_times.append(round(time.time() - self.start_time, 2))
28 |
29 | def reset(self) -> None:
30 | self.cnt = 0
31 | self.input_tokens = 0
32 | self.output_tokens = 0
33 | self.all_times = []
34 |
35 |     def get_stats(self) -> dict:
36 | return {
37 | "calls": self.cnt,
38 | "input_tokens": self.input_tokens,
39 | "output_tokens": self.output_tokens,
40 | "all_times": self.all_times,
41 | }
42 |
43 |
44 | class AsyncStatsCallbackHandler(AsyncCallbackHandler):
45 | """Collect useful stats about the run.
46 | Add more stats as needed."""
47 |
48 | def __init__(self, stream: bool = False) -> None:
49 | super().__init__()
50 | self.cnt = 0
51 | self.input_tokens = 0
52 | self.output_tokens = 0
53 | # same for gpt-3.5
54 | self.encoder = tiktoken.encoding_for_model("gpt-4")
55 | self.stream = stream
56 | self.all_times = []
57 | self.additional_fields = {}
58 | self.start_time = 0
59 |
60 | async def on_chat_model_start(self, serialized, prompts, **kwargs):
61 | self.start_time = time.time()
62 | if self.stream:
63 | # if streaming mode, on_llm_end response is not collected
64 | # therefore, we need to count input token based on the
65 | # prompt length at the beginning
66 | self.cnt += 1
67 | self.input_tokens += len(self.encoder.encode(prompts[0][0].content))
68 |
69 | async def on_llm_new_token(self, token, *args, **kwargs):
70 | if self.stream:
71 | # if streaming mode, on_llm_end response is not collected
72 | # therefore, we need to manually count output token based on the
73 | # number of streamed out tokens
74 | self.output_tokens += 1
75 |
76 | async def on_llm_end(self, response, *args, **kwargs):
77 | self.all_times.append(round(time.time() - self.start_time, 2))
78 | if not self.stream:
79 | # if not streaming mode, on_llm_end response is collected
80 | # so we can use this stats directly
81 | token_usage = response.llm_output["token_usage"]
82 | self.input_tokens += token_usage["prompt_tokens"]
83 | self.output_tokens += token_usage["completion_tokens"]
84 | self.cnt += 1
85 |
86 | def reset(self) -> None:
87 | self.cnt = 0
88 | self.input_tokens = 0
89 | self.output_tokens = 0
90 | self.all_times = []
91 | self.additional_fields = {}
92 |
93 |     def get_stats(self) -> dict:
94 | return {
95 | "calls": self.cnt,
96 | "input_tokens": self.input_tokens,
97 | "output_tokens": self.output_tokens,
98 | "all_times": self.all_times,
99 | **self.additional_fields,
100 | }
101 |
--------------------------------------------------------------------------------
/src/executors/schema.py:
--------------------------------------------------------------------------------
1 | from langchain.pydantic_v1 import BaseModel
2 |
3 |
4 | class Step(BaseModel):
5 | """Step."""
6 |
7 | value: str
8 | """The value."""
9 |
10 |
11 | class Plan(BaseModel):
12 | """Plan."""
13 |
14 | steps: list[Step]
15 | """The steps."""
16 |
17 |
18 | class StepResponse(BaseModel):
19 | """Step response."""
20 |
21 | response: str
22 | """The response."""
23 |
--------------------------------------------------------------------------------
/src/llm_compiler/constants.py:
--------------------------------------------------------------------------------
1 | END_OF_PLAN = "<END_OF_PLAN>"
2 |
3 | JOINNER_FINISH = "Finish"
4 | JOINNER_REPLAN = "Replan"
5 |
6 | SUMMARY_RESULT = "Summary"
7 |
--------------------------------------------------------------------------------
/src/llm_compiler/llm_compiler.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast
3 |
4 | from langchain.callbacks.manager import (
5 | AsyncCallbackManagerForChainRun,
6 | CallbackManagerForChainRun,
7 | )
8 | from langchain.chat_models.base import BaseChatModel
9 | from langchain.llms import BaseLLM
10 | from langchain.llms.base import BaseLLM
11 | from langchain.prompts.base import StringPromptValue
12 |
13 | from src.callbacks.callbacks import AsyncStatsCallbackHandler
14 | from src.chains.chain import Chain
15 | from src.llm_compiler.constants import JOINNER_REPLAN
16 | from src.llm_compiler.planner import Planner
17 | from src.llm_compiler.task_fetching_unit import Task, TaskFetchingUnit
18 | from src.tiny_agent.models import streaming_queue
19 | from src.tools.base import StructuredTool, Tool
20 | from src.utils.logger_utils import log
21 |
22 |
23 | class LLMCompilerAgent:
24 | """Self defined agent for LLM Compiler."""
25 |
26 | def __init__(self, llm: BaseLLM) -> None:
27 | self.llm = llm
28 |
29 | async def arun(self, prompt: str, callbacks=None) -> str:
30 | response = await self.llm.agenerate_prompt(
31 | prompts=[StringPromptValue(text=prompt)],
32 | stop=[""],
33 | callbacks=callbacks,
34 | )
35 | if isinstance(self.llm, BaseChatModel):
36 | return response.generations[0][0].message.content
37 |
38 | if isinstance(self.llm, BaseLLM):
39 | return response.generations[0][0].text
40 |
41 | raise ValueError("LLM must be either BaseChatModel or BaseLLM")
42 |
43 |
44 | class LLMCompiler(Chain, extra="allow"):
45 | """LLMCompiler Engine."""
46 |
47 | """The step container to use."""
48 | input_key: str = "input"
49 | output_key: str = "output"
50 |
51 | def __init__(
52 | self,
53 | tools: Sequence[Union[Tool, StructuredTool]],
54 | planner_llm: BaseLLM,
55 | planner_example_prompt: str,
56 | planner_example_prompt_replan: Optional[str],
57 | planner_stop: Optional[list[str]],
58 | planner_stream: bool,
59 | agent_llm: BaseLLM,
60 | joinner_prompt: str,
61 | joinner_prompt_final: Optional[str],
62 | max_replans: int,
63 | benchmark: bool,
64 | planner_custom_instructions_prompt: str | None = None,
65 | **kwargs,
66 | ) -> None:
67 | """
68 | Args:
69 | tools: List of tools to use.
70 | max_replans: Maximum number of replans to do.
71 | benchmark: Whether to collect benchmark stats.
72 |
73 | Planner Args:
74 | planner_llm: LLM to use for planning.
75 | planner_example_prompt: Example prompt for planning.
76 | planner_example_prompt_replan: Example prompt for replanning.
77 | Assign this if you want to use different example prompt for replanning.
78 | If not assigned, default to `planner_example_prompt`.
79 | planner_stop: Stop tokens for planning.
80 | planner_stream: Whether to stream the planning.
81 |
82 | Agent Args:
83 | agent_llm: LLM to use for agent.
84 | joinner_prompt: Prompt to use for joinner.
85 | joinner_prompt_final: Prompt to use for joinner at the final replanning iter.
86 | If not assigned, default to `joinner_prompt`.
87 | """
88 | super().__init__(**kwargs)
89 |
90 | if not planner_example_prompt_replan:
91 | log(
92 | "Replan example prompt not specified, using the same prompt as the planner."
93 | )
94 | planner_example_prompt_replan = planner_example_prompt
95 |
96 | self.planner = Planner(
97 | llm=planner_llm,
98 | example_prompt=planner_example_prompt,
99 | example_prompt_replan=planner_example_prompt_replan,
100 | custom_instructions=planner_custom_instructions_prompt,
101 | tools=tools,
102 | stop=planner_stop,
103 | )
104 |
105 | self.agent = LLMCompilerAgent(agent_llm)
106 | self.joinner_prompt = joinner_prompt
107 | self.joinner_prompt_final = joinner_prompt_final or joinner_prompt
108 | self.planner_stream = planner_stream
109 | self.max_replans = max_replans
110 |
111 | # callbacks
112 | self.benchmark = benchmark
113 | if benchmark:
114 | self.planner_callback = AsyncStatsCallbackHandler(stream=planner_stream)
115 | self.executor_callback = AsyncStatsCallbackHandler(stream=False)
116 | else:
117 | self.planner_callback = None
118 | self.executor_callback = None
119 |
120 | def get_all_stats(self):
121 | stats = {}
122 | if self.benchmark:
123 | stats["planner"] = self.planner_callback.get_stats()
124 | stats["executor"] = self.executor_callback.get_stats()
125 | stats["total"] = {
126 | k: v + stats["executor"].get(k, 0) for k, v in stats["planner"].items()
127 | }
128 |
129 | return stats
130 |
131 | def reset_all_stats(self):
132 | if self.planner_callback:
133 | self.planner_callback.reset()
134 | if self.executor_callback:
135 | self.executor_callback.reset()
136 |
137 | @property
138 | def input_keys(self) -> List[str]:
139 | return [self.input_key]
140 |
141 | @property
142 | def output_keys(self) -> List[str]:
143 | return [self.output_key]
144 |
145 | # TODO(sk): move all join related functions to a separate class
146 |
147 |     def _parse_joinner_output(self, raw_answer: str) -> tuple[str, str, bool]:
148 | """We expect the joinner output format to be:
149 | ```
150 | Thought: xxx
151 | Action: Finish/Replan(yyy)
152 | ```
153 | Returns:
154 | thought (xxx)
155 | answer (yyy)
156 | is_replan (True/False)
157 | """
158 | thought, answer, is_replan = "", "", False # default values
159 | raw_answers = raw_answer.split("\n")
160 | for ans in raw_answers:
161 | # Get to the index where the answer is, which is the start of 'Action:'
162 | start_of_answer = ans.find("Action:")
163 | if start_of_answer >= 0:
164 | ans = ans[start_of_answer:]
165 | if ans.startswith("Action:") or ans.startswith(" Answer:"):
166 | answer = ans[ans.find("(") + 1 : ans.rfind(")")]
167 | is_replan = JOINNER_REPLAN in ans
168 | elif ans.startswith("Thought:") or ans.startswith(" Thought:"):
169 | thought = ans.split("Thought:")[1].strip()
170 | # if not is_replan:
171 | # return "", raw_answer, is_replan
172 | return thought, answer, is_replan
173 |
174 | def _generate_context_for_replanner(
175 | self, tasks: Mapping[int, Task], joinner_thought: str
176 | ) -> str:
177 | """Formatted like this:
178 | ```
179 | 1. action 1
180 | Observation: xxx
181 | 2. action 2
182 | Observation: yyy
183 | ...
184 | Thought: joinner_thought
185 | ```
186 | """
187 | previous_plan_and_observations = "\n".join(
188 | [
189 | task.get_though_action_observation(
190 | include_action=True, include_action_idx=True
191 | )
192 | for task in tasks.values()
193 | if not task.is_join
194 | ]
195 | )
196 | joinner_thought = f"Thought: {joinner_thought}"
197 | context = "\n\n".join([previous_plan_and_observations, joinner_thought])
198 | return context
199 |
200 | def _format_contexts(self, contexts: Sequence[str]) -> str:
201 | """contexts is a list of context
202 | each context is formatted as the description of _generate_context_for_replanner
203 | """
204 | formatted_contexts = ""
205 | for context in contexts:
206 | formatted_contexts += f"Previous Plan:\n\n{context}\n\n"
207 | formatted_contexts += "Current Plan:\n\n"
208 | return formatted_contexts
209 |
210 | async def join(
211 | self, input_query: str, agent_scratchpad: str, is_final: bool
212 |     ) -> tuple[str, str, bool]:
213 | if is_final:
214 | joinner_prompt = self.joinner_prompt_final
215 | else:
216 | joinner_prompt = self.joinner_prompt
217 | prompt = (
218 | f"{joinner_prompt}\n" # Instructions and examples
219 | f"Question: {input_query}\n\n" # User input query
220 | f"{agent_scratchpad}\n" # T-A-O
221 | # "---\n"
222 | )
223 | log("Joining prompt:\n", prompt, block=True)
224 | response = await self.agent.arun(
225 | prompt, callbacks=[self.executor_callback] if self.benchmark else None
226 | )
227 | raw_answer = cast(str, response)
228 | log("Question: \n", input_query, block=True)
229 | log("Raw Answer: \n", raw_answer, block=True)
230 | thought, answer, is_replan = self._parse_joinner_output(raw_answer)
231 | if is_final:
232 | # If final, we don't need to replan
233 | is_replan = False
234 | return thought, answer, is_replan
235 |
236 | def _call(
237 | self,
238 | inputs: Dict[str, Any],
239 | run_manager: Optional[CallbackManagerForChainRun] = None,
240 | ):
241 | raise NotImplementedError("LLMCompiler is async only.")
242 |
243 | async def _acall(
244 | self,
245 | inputs: Dict[str, Any],
246 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
247 | ) -> Dict[str, Any]:
248 | contexts = []
249 | joinner_thought = ""
250 | agent_scratchpad = ""
251 | for i in range(self.max_replans):
252 | is_first_iter = i == 0
253 | is_final_iter = i == self.max_replans - 1
254 |
255 | task_fetching_unit = TaskFetchingUnit()
256 | if self.planner_stream:
257 | task_queue = asyncio.Queue()
258 | asyncio.create_task(
259 | self.planner.aplan(
260 | inputs=inputs,
261 | task_queue=task_queue,
262 | is_replan=not is_first_iter,
263 | callbacks=(
264 | [self.planner_callback] if self.planner_callback else None
265 | ),
266 | )
267 | )
268 | await task_fetching_unit.aschedule(
269 | task_queue=task_queue, func=lambda x: None
270 | )
271 | else:
272 | tasks = await self.planner.plan(
273 | inputs=inputs,
274 | is_replan=not is_first_iter,
275 | # callbacks=run_manager.get_child() if run_manager else None,
276 | callbacks=(
277 | [self.planner_callback] if self.planner_callback else None
278 | ),
279 | )
280 | log("Graph of tasks: ", tasks, block=True)
281 | if self.benchmark:
282 | self.planner_callback.additional_fields["num_tasks"] = len(tasks)
283 | task_fetching_unit.set_tasks(tasks)
284 | await task_fetching_unit.schedule()
285 | tasks = task_fetching_unit.tasks
286 |
287 | # collect thought-action-observation
288 | agent_scratchpad += "\n\n"
289 | agent_scratchpad += "".join(
290 | [
291 | task.get_though_action_observation(
292 | include_action=True, include_thought=True
293 | )
294 | for task in tasks.values()
295 | if not task.is_join
296 | # Also allow join tasks with observation which are there to propagate errors from the planning phase
297 | or (task.is_join and task.observation is not None)
298 | ]
299 | )
300 | agent_scratchpad = agent_scratchpad.strip()
301 |
302 | log("Agent scratchpad:\n", agent_scratchpad, block=True)
303 | joinner_thought, answer, is_replan = await self.join(
304 | inputs["input"],
305 | agent_scratchpad=agent_scratchpad,
306 | is_final=is_final_iter,
307 | )
308 | if not is_replan:
309 | log("Break out of replan loop.")
310 | break
311 |
312 | # Collect contexts for the subsequent replanner
313 | context = self._generate_context_for_replanner(
314 | tasks=tasks, joinner_thought=joinner_thought
315 | )
316 | contexts.append(context)
317 | formatted_contexts = self._format_contexts(contexts)
318 | log("Contexts:\n", formatted_contexts, block=True)
319 | inputs["context"] = formatted_contexts
320 |
321 | if is_final_iter:
322 | log("Reached max replan limit.")
323 |
324 | # End the generation request
325 | await streaming_queue.put(None)
326 |
327 | return {self.output_key: answer}
328 |
--------------------------------------------------------------------------------
/src/llm_compiler/output_parser.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import re
3 | from typing import Any, Sequence, Union
4 |
5 | from langchain.agents.agent import AgentOutputParser
6 | from langchain.schema import OutputParserException
7 |
8 | from src.llm_compiler.task_fetching_unit import Task
9 | from src.tools.base import StructuredTool, Tool
10 |
11 | THOUGHT_PATTERN = r"Thought: ([^\n]*)"
12 | ACTION_PATTERN = r"\s*\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?"
13 | # $1 or ${1} -> 1
14 | ID_PATTERN = r"\$\{?(\d+)\}?"
15 |
16 | END_OF_PLAN = "<END_OF_PLAN>"
17 |
18 |
19 | def default_dependency_rule(idx, args: str):
20 | matches = re.findall(ID_PATTERN, args)
21 | numbers = [int(match) for match in matches]
22 | return idx in numbers
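# Illustrative sketch (not from the original source): an action depends on task
# `idx` iff its raw argument string references $idx or ${idx}, e.g.
#   default_dependency_rule(1, '["$1"], "Hi Sid!"')  -> True
#   default_dependency_rule(2, '["$1"], "Hi Sid!"')  -> False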
23 |
24 |
25 | class LLMCompilerPlanParser(AgentOutputParser, extra="allow"):
26 | """Planning output parser."""
27 |
28 | def __init__(self, tools: Sequence[Union[Tool, StructuredTool]], **kwargs):
29 | super().__init__(**kwargs)
30 | self.tools = tools
31 |
32 | def parse(self, text: str) -> dict[int, Any]:
33 | # 1. search("Ronaldo number of kids") -> 1, "search", '"Ronaldo number of kids"'
34 | # pattern = r"(\d+)\. (\w+)\(([^)]+)\)"
35 | pattern = rf"(?:{THOUGHT_PATTERN}\n)?{ACTION_PATTERN}"
36 | matches = re.findall(pattern, text)
37 |
38 | graph_dict = {}
39 |
40 | for match in matches:
41 | # idx = 1, function = "search", args = "Ronaldo number of kids"
42 | # thought will be the preceding thought, if any, otherwise an empty string
43 | thought, idx, tool_name, args, _ = match
44 | idx = int(idx)
45 |
46 | task = instantiate_task(
47 | tools=self.tools,
48 | idx=idx,
49 | tool_name=tool_name,
50 | args=args,
51 | thought=thought,
52 | )
53 |
54 | graph_dict[idx] = task
55 | if task.is_join:
56 | break
57 |
58 | return graph_dict
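    # Illustrative sketch (not from the original source), assuming the tools
    # include get_phone_number and send_sms: parsing the plan text
    #   1. get_phone_number("Sid")
    #   2. send_sms(["$1"], "Hi Sid!")
    #   Thought: I can answer now.
    #   3. join()
    # yields {1: <get_phone_number task>, 2: <send_sms task, dependencies=[1]>,
    # 3: <join task, dependencies=[1, 2]>} and stops consuming matches at join.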
59 |
60 |
61 | ### Helper functions
62 |
63 |
64 | def _parse_llm_compiler_action_args(args: str) -> Sequence[Any]:
65 | """Parse arguments from a string."""
66 | # This will convert the string into a python object
67 | # e.g. '"Ronaldo number of kids"' -> ("Ronaldo number of kids", )
68 | # '"I can answer the question now.", [3]' -> ("I can answer the question now.", [3])
69 | if args == "":
70 | return ()
71 | try:
72 | args = ast.literal_eval(args)
73 |     except (ValueError, SyntaxError):
74 |         pass  # not a Python literal; keep the raw string as-is
75 | if not isinstance(args, list) and not isinstance(args, tuple):
76 | args = (args,)
77 | return args
78 |
79 |
80 | def _find_tool(
81 | tool_name: str, tools: Sequence[Union[Tool, StructuredTool]]
82 | ) -> Union[Tool, StructuredTool]:
83 | """Find a tool by name.
84 |
85 | Args:
86 | tool_name: Name of the tool to find.
87 |
88 | Returns:
89 | Tool or StructuredTool.
90 | """
91 | for tool in tools:
92 | if tool.name == tool_name:
93 | return tool
94 | raise OutputParserException(f"Tool {tool_name} not found")
95 |
96 |
97 | def _get_dependencies_from_graph(
98 |     idx: int, tool_name: str, args: str
99 | ) -> list[int]:
100 | """Get dependencies from a graph."""
101 | if tool_name == "join":
102 | # depends on the previous step
103 | dependencies = list(range(1, idx))
104 | else:
105 | # define dependencies based on the dependency rule in tool_definitions.py
106 | dependencies = [i for i in range(1, idx) if default_dependency_rule(i, args)]
107 |
108 | return dependencies
109 |
110 |
111 | def instantiate_task(
112 | tools: Sequence[Union[Tool, StructuredTool]],
113 | idx: int,
114 | tool_name: str,
115 | args: str,
116 | thought: str,
117 | ) -> Task:
118 | dependencies = _get_dependencies_from_graph(idx, tool_name, args)
119 | args = _parse_llm_compiler_action_args(args)
120 | if tool_name == "join":
121 | # join does not have a tool
122 | tool_func = lambda x: None
123 | stringify_rule = None
124 | else:
125 | tool = _find_tool(tool_name, tools)
126 | tool_func = tool.func
127 | stringify_rule = tool.stringify_rule
128 | return Task(
129 | idx=idx,
130 | name=tool_name,
131 | tool=tool_func,
132 | args=args,
133 | dependencies=dependencies,
134 | stringify_rule=stringify_rule,
135 | thought=thought,
136 | is_join=tool_name == "join",
137 | )
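# Illustrative sketch (not from the original source):
#   instantiate_task(tools, idx=2, tool_name="send_sms", args='["$1"], "Hi Sid!"', thought="")
# looks up send_sms in `tools`, records dependencies=[1] (because of "$1"), and
# parses the raw argument string into (["$1"], "Hi Sid!").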
138 |
--------------------------------------------------------------------------------
/src/llm_compiler/planner.py:
--------------------------------------------------------------------------------
1 | """LLM Compiler Planner"""
2 |
3 | import asyncio
4 | import re
5 | from typing import Any, List, Optional, Sequence, Union
6 | from uuid import UUID
7 |
8 | from langchain.callbacks.base import AsyncCallbackHandler, Callbacks
9 | from langchain.chat_models.base import BaseChatModel
10 | from langchain.llms.base import BaseLLM
11 | from langchain.schema import LLMResult
12 | from langchain.schema.messages import HumanMessage, SystemMessage
13 |
15 | from src.llm_compiler.constants import END_OF_PLAN
16 | from src.llm_compiler.output_parser import (
17 | ACTION_PATTERN,
18 | THOUGHT_PATTERN,
19 | LLMCompilerPlanParser,
20 | instantiate_task,
21 | )
22 | from src.llm_compiler.task_fetching_unit import Task
23 | from src.tiny_agent.models import LLM_ERROR_TOKEN, streaming_queue
24 | from src.tools.base import StructuredTool, Tool
25 | from src.utils.logger_utils import log
26 |
27 | JOIN_DESCRIPTION = (
28 | "join():\n"
29 | " - Collects and combines results from prior actions.\n"
30 |     " - An LLM agent is called upon invoking join to either finalize the user query or wait until the plans are executed.\n"
31 | " - join should always be the last action in the plan, and will be called in two scenarios:\n"
32 | " (a) if the answer can be determined by gathering the outputs from tasks to generate the final response.\n"
33 | " (b) if the answer cannot be determined in the planning phase before you execute the plans. "
34 | )
35 |
36 |
37 | def generate_llm_compiler_prompt(
38 | tools: Sequence[Union[Tool, StructuredTool]],
39 | example_prompt: str,
40 | custom_instructions: str | None,
41 | is_replan: bool = False,
42 | ):
43 | prefix = (
44 | "Given a user query, create a plan to solve it with the utmost parallelizability. "
45 | f"Each plan should comprise an action from the following {len(tools) + 1} types:\n"
46 | )
47 |
48 | # Tools
49 | for i, tool in enumerate(tools):
50 | prefix += f"{i+1}. {tool.description}\n"
51 |
52 | # Join operation
53 |     prefix += f"{len(tools) + 1}. {JOIN_DESCRIPTION}\n\n"
54 |
55 | # Guidelines
56 | prefix += (
57 | "Guidelines:\n"
58 | " - Each action described above contains input/output types and description.\n"
59 | " - You must strictly adhere to the input and output types for each action.\n"
60 | " - The action descriptions contain the guidelines. You MUST strictly follow those guidelines when you use the actions.\n"
61 | " - Each action in the plan should strictly be one of the above types. Follow the Python conventions for each action.\n"
62 | " - Each action MUST have a unique ID, which is strictly increasing.\n"
63 | " - Inputs for actions can either be constants or outputs from preceding actions. "
64 | "In the latter case, use the format $id to denote the ID of the previous action whose output will be the input.\n"
65 | f" - Always call join as the last action in the plan. Say '{END_OF_PLAN}' after you call join\n"
66 | " - Ensure the plan maximizes parallelizability.\n"
67 | " - Only use the provided action types. If a query cannot be addressed using these, invoke the join action for the next steps.\n"
68 | " - Never explain the plan with comments (e.g. #).\n"
69 | " - Never introduce new actions other than the ones provided.\n\n"
70 | )
71 |
72 | if custom_instructions:
73 | prefix += f"{custom_instructions}\n\n"
74 |
75 | if is_replan:
76 |         prefix += (
77 |             ' - You are given a "Previous Plan", which is the plan that the previous agent created, along with the execution results '
78 |             "(given as Observation) of each plan and a general thought (given as Thought) about the executed results. "
79 |             'You MUST use this information to create the next plan under "Current Plan".\n'
80 |             ' - When starting the Current Plan, you should start with a "Thought" that outlines the strategy for the next plan.\n'
81 |             " - In the Current Plan, you should NEVER repeat the actions that were already executed in the Previous Plan.\n"
82 |         )
83 |
84 | # Examples
85 | prefix += "Here are some examples:\n\n"
86 | prefix += example_prompt
87 |
88 | return prefix
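# Illustrative sketch (not from the original source) of the generated system
# prompt's shape for N tools:
#   Given a user query, create a plan to solve it with the utmost parallelizability. ...
#   1. <description of tool 1>
#   ...
#   N. <description of tool N>
#   N+1. join(): ...
#   Guidelines: ...
#   [custom instructions, if any]
#   [replanning guidelines, if is_replan]
#   Here are some examples:
#   <example_prompt>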
89 |
90 |
91 | class StreamingGraphParser:
92 | """Streaming version of the GraphParser."""
93 |
94 | buffer = ""
95 | thought = ""
96 | graph_dict = {}
97 |
98 | def __init__(self, tools: Sequence[Union[Tool, StructuredTool]]) -> None:
99 | self.tools = tools
100 |
101 | def _match_buffer_and_generate_task(self, suffix: str) -> Optional[Task]:
102 | """Runs every time "\n" is encountered in the input stream or at the end of the stream.
103 | Matches the buffer against the regex patterns and generates a task if a match is found.
104 |         Match patterns include:
105 |         1. "Thought: <thought>"
106 |             - in this case, the thought is stored in self.thought, and we reset the buffer.
107 |             - the thought is then used as the thought for the next action.
108 |         2. "<idx>. <tool_name>(<args>)"
109 |             - in this case, the task is instantiated with the idx, tool_name, args, and thought.
110 |             - the thought is reset.
111 |             - the buffer is reset.
112 |         """
113 | if match := re.match(THOUGHT_PATTERN, self.buffer):
114 | # Optionally, action can be preceded by a thought
115 | self.thought = match.group(1)
116 | elif match := re.match(ACTION_PATTERN, self.buffer):
117 | # if action is parsed, return the task, and clear the buffer
118 | idx, tool_name, args, _ = match.groups()
119 | idx = int(idx)
120 | task = instantiate_task(
121 | tools=self.tools,
122 | idx=idx,
123 | tool_name=tool_name,
124 | args=args,
125 | thought=self.thought,
126 | )
127 | self.thought = ""
128 | return task
129 |
130 | return None
131 |
132 | def ingest_token(self, token: str) -> Optional[Task]:
133 | # Append token to buffer
134 | if "\n" in token:
135 | prefix, suffix = token.split("\n", 1)
136 | prefix = prefix.strip()
137 | self.buffer += prefix + "\n"
138 | matched_item = self._match_buffer_and_generate_task(suffix)
139 | self.buffer = suffix
140 | return matched_item
141 | else:
142 | self.buffer += token
143 |
144 | return None
145 |
146 | def finalize(self):
147 | self.buffer = self.buffer + "\n"
148 | return self._match_buffer_and_generate_task("")
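    # Illustrative usage sketch (not from the original source), assuming `tools`
    # contains a get_phone_number tool:
    #   parser = StreamingGraphParser(tools=tools)
    #   parser.ingest_token('1. get_phone_')    # no newline yet -> buffered, returns None
    #   parser.ingest_token('number("Sid")\n')  # completes the line -> returns Task(idx=1)
    #   parser.finalize()                       # flushes whatever is left in the buffer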
149 |
150 |
151 | class TinyAgentEarlyStop(BaseException):
152 |     # Defined as a BaseException to differentiate it from other Exceptions, which are
153 |     # fatal. This is a controlled stop, not an error.
154 | generations = []
155 | llm_output = ""
156 |
157 |
158 | class LLMCompilerCallback(AsyncCallbackHandler):
159 | _queue: asyncio.Queue[Optional[Task]]
160 | _parser: StreamingGraphParser
161 | _tools: Sequence[Union[Tool, StructuredTool]]
162 | _curr_idx: int
163 |
164 | def __init__(
165 | self,
166 | queue: asyncio.Queue[Optional[str]],
167 | tools: Sequence[Union[Tool, StructuredTool]],
168 | ):
169 | self._queue = queue
170 | self._parser = StreamingGraphParser(tools=tools)
171 | self._tools = tools
172 | self._curr_idx = 0
173 |
174 | async def on_llm_start(self, serialized, prompts, **kwargs: Any) -> Any:
175 | """Run when LLM starts running."""
176 |
177 | async def on_llm_new_token(
178 | self,
179 | token: str,
180 | *,
181 | run_id: UUID,
182 | parent_run_id: Optional[UUID] = None,
183 | **kwargs: Any,
184 | ) -> None:
185 | try:
186 | parsed_data = self._parser.ingest_token(token)
187 | print(token, end="", flush=True)
188 | await streaming_queue.put(token)
189 | if parsed_data:
190 | self._curr_idx = parsed_data.idx
191 | await self._queue.put(parsed_data)
192 | if parsed_data.is_join:
193 | await self._queue.put(None)
194 | except Exception as e:
195 | # If there was an error in parsing the token, stop the LLM and propagate the error to
196 | # the joinner for it to handle. The error message will be presented as an observation in the join action.
197 | # This usually happens when the tool name is not recognized/hallucinated by the LLM.
198 | join_tool = instantiate_task(
199 | tools=self._tools,
200 | idx=self._curr_idx + 1,
201 | tool_name="join",
202 | args="",
203 | thought="",
204 | )
205 | join_tool.observation = f"The plan generation was stopped due to an error in tool '{self._parser.buffer.strip()}'! Error: {str(e)}! You MUST correct this error and try again!"
206 | await self._queue.put(join_tool)
207 | await self._queue.put(None)
208 | raise TinyAgentEarlyStop(str(e))
209 |
210 | async def on_llm_end(
211 | self,
212 | response: LLMResult,
213 | *,
214 | run_id: UUID,
215 | parent_run_id: Optional[UUID] = None,
216 | **kwargs: Any,
217 | ) -> None:
218 | parsed_data = self._parser.finalize()
219 | if parsed_data:
220 | await self._queue.put(parsed_data)
221 | await self._queue.put(None)
222 |
223 | # Define the following error callbacks to be able to stop the LLM when an error occurs
224 | async def on_llm_error(
225 | self,
226 | error: BaseException,
227 | *,
228 | run_id: UUID,
229 | parent_run_id: Optional[UUID] = None,
230 | tags: Optional[List[str]] = None,
231 | **kwargs: Any,
232 | ) -> None:
233 | if isinstance(error, TinyAgentEarlyStop):
234 | # Only allow the TinyAgentEarlyStop exception since it is a controlled stop
235 | return
236 | await streaming_queue.put(f"{LLM_ERROR_TOKEN}LLMError: {error}")
237 |
238 | async def on_chain_error(
239 | self,
240 | error: BaseException,
241 | *,
242 | run_id: UUID,
243 | parent_run_id: Optional[UUID] = None,
244 | tags: Optional[List[str]] = None,
245 | **kwargs: Any,
246 | ) -> None:
247 | await streaming_queue.put(f"{LLM_ERROR_TOKEN}ChainError: {error}")
248 |
249 |
250 | class Planner:
251 | def __init__(
252 | self,
253 | llm: BaseChatModel | BaseLLM,
254 | custom_instructions: str | None,
255 | example_prompt: str,
256 | example_prompt_replan: str,
257 | tools: Sequence[Union[Tool, StructuredTool]],
258 | stop: Optional[list[str]],
259 | ):
260 | self.llm = llm
261 |         # A different system prompt is needed when replanning, since it has
262 |         # different guidelines as well as examples provided by the user
263 | self.system_prompt = generate_llm_compiler_prompt(
264 | tools=tools,
265 | custom_instructions=custom_instructions,
266 | example_prompt=example_prompt,
267 | is_replan=False,
268 | )
269 | self.system_prompt_replan = generate_llm_compiler_prompt(
270 | tools=tools,
271 | custom_instructions=custom_instructions,
272 | example_prompt=example_prompt_replan,
273 | is_replan=True,
274 | )
275 | self.tools = tools
276 | self.output_parser = LLMCompilerPlanParser(tools=tools)
277 | self.stop = stop
278 |
279 | async def run_llm(
280 | self,
281 | inputs: dict[str, Any],
282 | is_replan: bool = False,
283 | callbacks: Callbacks = None,
284 | ) -> str:
285 | """Run the LLM."""
286 | if is_replan:
287 | system_prompt = self.system_prompt_replan
288 | assert "context" in inputs, "If replanning, context must be provided"
289 | human_prompt = f"Question: {inputs['input']}\n{inputs['context']}\n"
290 | else:
291 | system_prompt = self.system_prompt
292 | human_prompt = f"Question: {inputs['input']}"
293 |
294 | if isinstance(self.llm, BaseChatModel):
295 | messages = [
296 | SystemMessage(content=system_prompt),
297 | HumanMessage(content=human_prompt),
298 | ]
299 |             try:
300 |                 llm_response = await self.llm._call_async(
301 |                     messages, callbacks=callbacks, stop=self.stop
302 |                 )
303 |                 response = llm_response.content
304 |             except Exception as e:
305 |                 # Put this exception in the streaming queue to stop the LLM, since the whole
306 |                 # planner runs as a concurrent async task that is never awaited, so errors
307 |                 # are not propagated to the main context properly. Then re-raise to stop
308 |                 # here instead of falling through with llm_response unbound.
309 |                 await streaming_queue.put(f"{LLM_ERROR_TOKEN}LLMError: {e}")
310 |                 raise
311 | elif isinstance(self.llm, BaseLLM):
312 | message = system_prompt + "\n\n" + human_prompt
313 | response = await self.llm.apredict(
314 | message,
315 | callbacks=callbacks,
316 | stop=self.stop,
317 | )
318 | else:
319 | raise ValueError("LLM must be either BaseChatModel or BaseLLM")
320 |
321 | log("LLMCompiler planner response: \n", response, block=True)
322 |
323 | return response
324 |
325 | async def plan(
326 | self, inputs: dict, is_replan: bool, callbacks: Callbacks = None, **kwargs: Any
327 | ):
328 | llm_response = await self.run_llm(
329 | inputs=inputs, is_replan=is_replan, callbacks=callbacks
330 | )
331 | llm_response = llm_response + "\n"
332 | return self.output_parser.parse(llm_response)
333 |
334 | async def aplan(
335 | self,
336 | inputs: dict,
337 | task_queue: asyncio.Queue[Optional[str]],
338 | is_replan: bool,
339 | callbacks: Callbacks = None,
340 | **kwargs: Any,
341 |     ) -> None:
342 | """Given input, asynchronously decide what to do."""
343 | all_callbacks = [
344 | LLMCompilerCallback(
345 | queue=task_queue,
346 | tools=self.tools,
347 | )
348 | ]
349 | if callbacks:
350 | all_callbacks.extend(callbacks)
351 | try:
352 | # Actually, we don't need this try-except block here, but we keep it just in case...
353 | await self.run_llm(
354 | inputs=inputs, is_replan=is_replan, callbacks=all_callbacks
355 | )
356 |         except TinyAgentEarlyStop:
357 |             pass
358 |
--------------------------------------------------------------------------------
/src/llm_compiler/task_fetching_unit.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from dataclasses import dataclass
5 | from typing import Any, Callable, Collection, Dict, List, Optional
6 |
7 | from src.utils.logger_utils import log
8 |
9 | SCHEDULING_INTERVAL = 0.01 # seconds
10 |
11 |
12 | def _default_stringify_rule_for_arguments(args):
13 | if len(args) == 1:
14 | return str(args[0])
15 | else:
16 | return str(tuple(args))
17 |
18 |
19 | def _replace_arg_mask_with_real_value(
20 |     args, dependencies: List[int], tasks: Dict[int, Task]
21 | ):
22 | if isinstance(args, (list, tuple)):
23 | return type(args)(
24 | _replace_arg_mask_with_real_value(item, dependencies, tasks)
25 | for item in args
26 | )
27 | elif isinstance(args, str):
28 | for dependency in sorted(dependencies, reverse=True):
29 | # consider both ${1} and $1 (in case planner makes a mistake)
30 | for arg_mask in ["${" + str(dependency) + "}", "$" + str(dependency)]:
31 | if arg_mask in args:
32 | if tasks[dependency].observation is not None:
33 | args = args.replace(
34 | arg_mask, str(tasks[dependency].observation)
35 | )
36 | return args
37 | else:
38 | return args
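# Illustrative sketch (not from the original source): if tasks[1].observation
# is "+1-555-0100" (a hypothetical value), then
#   _replace_arg_mask_with_real_value(["$1"], [1], tasks)      -> ["+1-555-0100"]
#   _replace_arg_mask_with_real_value("call ${1}", [1], tasks) -> "call +1-555-0100"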
39 |
40 |
41 | @dataclass
42 | class Task:
43 | idx: int
44 | name: str
45 | tool: Callable
46 | args: Collection[Any]
47 | dependencies: Collection[int]
48 | stringify_rule: Optional[Callable] = None
49 | thought: Optional[str] = None
50 | observation: Optional[str] = None
51 | is_join: bool = False
52 |
53 | async def __call__(self) -> Any:
54 | log(f"running task {self.name}")
55 | x = await self.tool(*self.args)
56 | log(f"done task {self.name}")
57 | return x
58 |
59 | def get_though_action_observation(
60 | self, include_action=True, include_thought=True, include_action_idx=False
61 | ) -> str:
62 | thought_action_observation = ""
63 | if self.thought and include_thought:
64 | thought_action_observation = f"Thought: {self.thought}\n"
65 | if include_action:
66 | idx = f"{self.idx}. " if include_action_idx else ""
67 | if self.stringify_rule:
68 | # If the user has specified a custom stringify rule for the
69 | # function argument, use it
70 | thought_action_observation += f"{idx}{self.stringify_rule(self.args)}\n"
71 | else:
72 | # Otherwise, we have a default stringify rule
73 | thought_action_observation += (
74 | f"{idx}{self.name}"
75 | f"{_default_stringify_rule_for_arguments(self.args)}\n"
76 | )
77 | if self.observation is not None:
78 | thought_action_observation += f"Observation: {self.observation}\n"
79 | return thought_action_observation
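    # Illustrative sketch (not from the original source): with include_action_idx=True,
    # a finished Task(idx=2, name="send_sms", args=(["+1-555-0100"], "Hi!"),
    # thought="Send it.", observation="SMS sent") stringifies to:
    #   Thought: Send it.
    #   2. send_sms(['+1-555-0100'], 'Hi!')
    #   Observation: SMS sent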
80 |
81 |
82 | class TaskFetchingUnit:
83 |     tasks: Dict[int, Task]
84 |     tasks_done: Dict[int, asyncio.Event]
85 |     remaining_tasks: set[int]
86 |
87 | def __init__(self):
88 | self.tasks = {}
89 | self.tasks_done = {}
90 | self.remaining_tasks = set()
91 |
92 |     def set_tasks(self, tasks: dict[int, Task]):
93 | self.tasks.update(tasks)
94 | self.tasks_done.update({task_idx: asyncio.Event() for task_idx in tasks})
95 | self.remaining_tasks.update(set(tasks.keys()))
96 |
97 | def _all_tasks_done(self):
98 | return all(self.tasks_done[d].is_set() for d in self.tasks_done)
99 |
100 | def _get_all_executable_tasks(self):
101 | return [
102 | task_name
103 | for task_name in self.remaining_tasks
104 | if all(
105 | self.tasks_done[d].is_set() for d in self.tasks[task_name].dependencies
106 | )
107 | ]
108 |
109 | def _preprocess_args(self, task: Task):
110 | """Replace dependency placeholders, i.e. ${1}, in task.args with the actual observation."""
111 | args = []
112 | for arg in task.args:
113 | arg = _replace_arg_mask_with_real_value(arg, task.dependencies, self.tasks)
114 | args.append(arg)
115 | task.args = args
116 |
117 | async def _run_task(self, task: Task):
118 | try:
119 | self._preprocess_args(task)
120 | if not task.is_join:
121 | observation = await task()
122 | task.observation = observation
123 | except Exception as e:
124 | # If an exception occurs, stop LLM execution and propagate the error message to the joinner
125 | # by manually setting the observation of the task to the error message. If this is an error of
126 | # providing the wrong arguments to the tool, then we do some cleaning up in the error message.
127 | error_message = str(e)
128 | if "positional argument" in error_message:
129 | error_message = error_message.split(".")[-2]
130 | task.observation = (
131 | f"Error: {error_message}! You MUST correct this error and try again!"
132 | )
133 |
134 | self.tasks_done[task.idx].set()
135 |
136 | async def schedule(self):
137 | """Run all tasks in self.tasks in parallel, respecting dependencies."""
138 | # run until all tasks are done
139 | while not self._all_tasks_done():
140 | # Find tasks with no dependencies or with all dependencies met
141 | executable_tasks = self._get_all_executable_tasks()
142 |
143 | for task_name in executable_tasks:
144 | asyncio.create_task(self._run_task(self.tasks[task_name]))
145 | self.remaining_tasks.remove(task_name)
146 |
147 | await asyncio.sleep(SCHEDULING_INTERVAL)
148 |
149 | async def aschedule(self, task_queue: asyncio.Queue[Optional[Task]], func):
150 | """Asynchronously listen to task_queue and schedule tasks as they arrive."""
151 | no_more_tasks = False # Flag to check if all tasks are received
152 |
153 | while True:
154 | if not no_more_tasks:
155 | # Wait for a new task to be added to the queue
156 | task = await task_queue.get()
157 |
158 | # Check for sentinel value indicating end of tasks
159 | if task is None:
160 | no_more_tasks = True
161 | else:
162 | # Parse and set the new tasks
163 | self.set_tasks({task.idx: task})
164 |
165 | # Schedule and run executable tasks
166 | executable_tasks = self._get_all_executable_tasks()
167 |
168 | if executable_tasks:
169 | for task_name in executable_tasks:
170 | # The task is executed in a separate task to avoid blocking the loop
171 | # without explicitly awaiting it. This, unfortunately, means that the
172 | # task will not be able to propagate exceptions to the calling context.
173 | # Hence, we need to handle exceptions within the task itself. See ._run_task()
174 | asyncio.create_task(self._run_task(self.tasks[task_name]))
175 | self.remaining_tasks.remove(task_name)
176 | elif no_more_tasks and self._all_tasks_done():
177 | # Exit the loop if no more tasks are expected and all tasks are done
178 | break
179 | else:
180 | # If no executable tasks are found, sleep for the SCHEDULING_INTERVAL
181 | await asyncio.sleep(SCHEDULING_INTERVAL)
182 |
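# Illustrative usage sketch (not from the original source): in streaming mode the
# planner produces tasks into a queue while this unit schedules them, mirroring
# the replan loop in LLMCompiler._acall:
#   unit = TaskFetchingUnit()
#   queue: asyncio.Queue[Optional[Task]] = asyncio.Queue()
#   asyncio.create_task(planner.aplan(inputs=inputs, task_queue=queue, is_replan=False))
#   await unit.aschedule(task_queue=queue, func=lambda x: None)  # returns after the None sentinel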
--------------------------------------------------------------------------------
/src/tiny_agent/computer.py:
--------------------------------------------------------------------------------
1 | from src.tiny_agent.tools.calendar import Calendar
2 | from src.tiny_agent.tools.contacts import Contacts
3 | from src.tiny_agent.tools.mail import Mail
4 | from src.tiny_agent.tools.maps import Maps
5 | from src.tiny_agent.tools.notes import Notes
6 | from src.tiny_agent.tools.reminders import Reminders
7 | from src.tiny_agent.tools.sms import SMS
8 | from src.tiny_agent.tools.spotlight_search import SpotlightSearch
9 | from src.tiny_agent.tools.zoom import Zoom
10 |
11 |
12 | class Computer:
13 | calendar: Calendar
14 | contacts: Contacts
15 | mail: Mail
16 | maps: Maps
17 | notes: Notes
18 | reminders: Reminders
19 | sms: SMS
20 | spotlight_search: SpotlightSearch
21 | zoom: Zoom
22 |
23 | def __init__(self) -> None:
24 | self.calendar = Calendar()
25 | self.contacts = Contacts()
26 | self.mail = Mail()
27 | self.maps = Maps()
28 | self.notes = Notes()
29 | self.reminders = Reminders()
30 | self.sms = SMS()
31 | self.spotlight_search = SpotlightSearch()
32 |
--------------------------------------------------------------------------------
/src/tiny_agent/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Any
4 |
5 | from tiktoken import encoding_name_for_model, get_encoding
6 | from transformers import AutoTokenizer
7 |
8 | from src.tiny_agent.models import (
9 | AgentType,
10 | App,
11 | ModelConfig,
12 | ModelType,
13 | TinyAgentConfig,
14 | WhisperConfig,
15 | )
16 |
17 | DEFAULT_SAFE_CONTEXT_LENGTH = 4096
18 | DEFAULT_EMBEDDING_CONTEXT_LENGTH = 8192
19 | DEFAULT_OPENAI_EMBEDDING_MODEL = "text-embedding-3-small"
20 |
21 |
22 | OPENAI_MODELS = {
23 | "gpt-4": 8192,
24 | "gpt-4-0613": 8192,
25 | "gpt-4-32k": 32768,
26 | "gpt-4-32k-0613": 32768,
27 | "gpt-4-0125-preview": 128000,
28 | "gpt-4-turbo-preview": 128000,
29 | "gpt-4-turbo": 128000,
30 | "gpt-4-1106-preview": 128000,
31 | "gpt-3.5-turbo-0125": 16385,
32 | "gpt-3.5-turbo": 16385,
33 | "gpt-3.5-turbo-1106": 16385,
34 | "gpt-3.5-turbo-instruct": 4096,
35 | }
36 |
37 | AGENT_TYPE_TO_CONFIG_PREFIX = {
38 | AgentType.MAIN: "",
39 | AgentType.SUB_AGENT: "SubAgent",
40 | AgentType.EMBEDDING: "Embedding",
41 | }
42 |
43 |
44 | def load_config(config_path: str) -> dict[str, Any]:
45 | with open(config_path, "r") as file:
46 | return json.load(file)
47 |
48 |
49 | def get_model_config(
50 | config: dict[str, Any],
51 | provider: str,
52 | agent_type: AgentType,
53 | ) -> ModelConfig:
54 | agent_prefix = AGENT_TYPE_TO_CONFIG_PREFIX[agent_type]
55 | model_type = ModelType(provider)
56 |
57 | if model_type == ModelType.AZURE:
58 | _check_azure_config(config, agent_prefix)
59 |         api_key = (
60 |             config["azureApiKey"]
61 |             if _is_valid_config_field(config, "azureApiKey")
62 |             else os.environ["AZURE_OPENAI_API_KEY"]
63 |         )
64 | model_name = config[f"azure{agent_prefix}DeploymentName"]
65 | if agent_type != AgentType.EMBEDDING:
66 | context_length = int(config[f"azure{agent_prefix}CtxLen"])
67 | tokenizer = get_encoding(encoding_name_for_model("gpt-3.5-turbo"))
68 | elif model_type == ModelType.LOCAL:
69 | _check_local_config(config, agent_prefix)
70 | api_key = "lm-studio"
71 | model_name = config[f"local{agent_prefix}ModelName"]
72 | context_length = int(
73 | config[f"local{agent_prefix}CtxLen"]
74 | if len(config[f"local{agent_prefix}CtxLen"]) > 0
75 | else DEFAULT_EMBEDDING_CONTEXT_LENGTH
76 | )
77 | if agent_type != AgentType.EMBEDDING:
78 | tokenizer = AutoTokenizer.from_pretrained(
79 | config[f"local{agent_prefix}TokenizerNameOrPath"],
80 | use_fast=True,
81 | token=config["hfToken"],
82 | )
83 | elif model_type == ModelType.OPENAI:
84 | _check_openai_config(config, agent_prefix)
85 |         api_key = (
86 |             config["openAIApiKey"]
87 |             if _is_valid_config_field(config, "openAIApiKey")
88 |             else os.environ["OPENAI_API_KEY"]
89 |         )
90 | model_name = (
91 | config[f"openAI{agent_prefix}ModelName"]
92 | if agent_type != AgentType.EMBEDDING
93 | else DEFAULT_OPENAI_EMBEDDING_MODEL
94 | )
95 | if agent_type != AgentType.EMBEDDING:
96 | context_length = OPENAI_MODELS[model_name]
97 | tokenizer = get_encoding(encoding_name_for_model("gpt-3.5-turbo"))
98 | else:
99 | raise ValueError("Invalid model type")
100 |
101 | return ModelConfig(
102 | api_key=api_key,
103 | context_length=(
104 | context_length
105 | if agent_type != AgentType.EMBEDDING
106 | else DEFAULT_EMBEDDING_CONTEXT_LENGTH
107 | ),
108 | model_name=model_name,
109 | model_type=model_type,
110 | tokenizer=tokenizer if agent_type != AgentType.EMBEDDING else None,
111 | port=(
112 | int(config[f"local{agent_prefix}Port"])
113 | if model_type == ModelType.LOCAL
114 | and len(config[f"local{agent_prefix}Port"]) > 0
115 | else None
116 | ),
117 | )
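# Illustrative sketch (not from the original source): for the OpenAI provider and
# the MAIN agent (empty prefix), a minimal config dict would look like
#   {"openAIApiKey": "sk-...", "openAIModelName": "gpt-4-turbo"}
# from which context_length is taken from OPENAI_MODELS and the tokenizer from tiktoken.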
118 |
119 |
120 | def get_whisper_config(config: dict[str, Any], provider: str) -> WhisperConfig:
121 | whisper_provider = ModelType(provider)
122 |
123 | api_key = None
124 | port = None
125 | if whisper_provider == ModelType.OPENAI:
126 | if (
127 | not _is_valid_config_field(config, "openAIApiKey")
128 | and os.environ.get("OPENAI_API_KEY") is None
129 | ):
130 | raise ValueError("OpenAI API key for Whisper not found in config")
131 |
132 |         api_key = (
133 |             config.get("openAIApiKey")
134 |             if _is_valid_config_field(config, "openAIApiKey")
135 |             else os.environ.get("OPENAI_API_KEY")
136 |         )
137 | elif whisper_provider == ModelType.LOCAL:
138 | if not _is_valid_config_field(config, "localWhisperPort"):
139 | raise ValueError("Local Whisper port not found in config")
140 | port = int(config["localWhisperPort"])
141 | else:
142 | raise ValueError("Invalid Whisper provider")
143 |
144 | return WhisperConfig(provider=whisper_provider, api_key=api_key, port=port)
145 |
146 |
147 | def get_tiny_agent_config(config_path: str) -> TinyAgentConfig:
148 | config = load_config(config_path)
149 |
150 | if (provider := config.get("provider")) is None or len(provider) == 0:
151 | raise ValueError("Provider not found in config")
152 |
153 | if (sub_agent_provider := config.get("subAgentProvider")) is None or len(
154 | sub_agent_provider
155 | ) == 0:
156 | raise ValueError("Subagent provider not found in config")
157 |
158 | use_in_context_example_retriever = config.get("useToolRAG")
159 | if (
160 | use_in_context_example_retriever is False
161 | and config.get("toolRAGProvider") is None
162 | ):
163 | raise ValueError("In-context example retriever provider not found in config")
164 |
165 | whisper_provider = config.get("whisperProvider")
166 | if whisper_provider is None or len(whisper_provider) == 0:
167 | raise ValueError("Whisper provider not found in config")
168 |
169 | # Get the model configs
170 | llmcompiler_config = get_model_config(config, config["provider"], AgentType.MAIN)
171 | sub_agent_config = get_model_config(
172 | config,
173 | sub_agent_provider,
174 | AgentType.SUB_AGENT,
175 | )
176 | if use_in_context_example_retriever is True:
177 | embedding_model_config = get_model_config(
178 | config, config["toolRAGProvider"], AgentType.EMBEDDING
179 | )
180 |
181 | # Azure config
182 | azure_api_version = config.get("azureApiVersion")
183 |     azure_endpoint = (
184 |         config.get("azureEndpoint")
185 |         if _is_valid_config_field(config, "azureEndpoint")
186 |         else os.environ.get("AZURE_OPENAI_ENDPOINT")
187 |     )
188 |
189 | # Get the other config values
190 | hf_token = config.get("hfToken")
191 | zoom_access_token = config.get("zoomAccessToken")
192 |
193 |     # Ensure an OpenAI API key is available for the Whisper API
194 | if (
195 | not _is_valid_config_field(config, "openAIApiKey")
196 | and os.environ.get("OPENAI_API_KEY") is None
197 | ):
198 | raise ValueError("OpenAI API key is needed for whisper API.")
199 |
200 | whisper_config = get_whisper_config(config, whisper_provider)
201 |
202 | apps = set()
203 | for app in App:
204 | if config[f"{app.value}Enabled"]:
205 | apps.add(app)
206 |
207 | return TinyAgentConfig(
208 | apps=apps,
209 | custom_instructions=config.get("customInstructions"),
210 | llmcompiler_config=llmcompiler_config,
211 | sub_agent_config=sub_agent_config,
212 | embedding_model_config=(
213 | embedding_model_config if use_in_context_example_retriever else None
214 | ),
215 | azure_api_version=azure_api_version,
216 | azure_endpoint=azure_endpoint,
217 | hf_token=hf_token,
218 | zoom_access_token=zoom_access_token,
219 | whisper_config=whisper_config,
220 | )
221 |
222 |
223 | def _is_valid_config_field(config: dict[str, Any], field: str) -> bool:
224 | return (field_value := config.get(field)) is not None and len(field_value) > 0
225 |
226 |
227 | def _check_azure_config(config: dict[str, Any], agent_prefix: str) -> None:
228 | if (
229 | not _is_valid_config_field(config, "azureApiKey")
230 | and os.environ.get("AZURE_OPENAI_API_KEY") is None
231 | ):
232 | raise ValueError("Azure API key not found in config")
233 | if not _is_valid_config_field(config, f"azure{agent_prefix}DeploymentName"):
234 | raise ValueError(
235 | f"Azure {agent_prefix} model deployment name not found in config"
236 | )
237 | if agent_prefix != AGENT_TYPE_TO_CONFIG_PREFIX[
238 | AgentType.EMBEDDING
239 | ] and not _is_valid_config_field(config, f"azure{agent_prefix}CtxLen"):
240 | raise ValueError(f"Azure {agent_prefix} context length not found in config")
241 | if not _is_valid_config_field(config, "azureApiVersion"):
242 | raise ValueError("Azure API version not found in config")
243 | if not _is_valid_config_field(config, "azureEndpoint"):
244 | raise ValueError("Azure endpoint not found in config")
245 |
246 |
247 | def _check_openai_config(config: dict[str, Any], agent_prefix: str) -> None:
248 | if (
249 | not _is_valid_config_field(config, "openAIApiKey")
250 | and os.environ.get("OPENAI_API_KEY") is None
251 | ):
252 | raise ValueError("OpenAI API key not found in config")
253 | # The embedding model for OpenAI API only supports "text-embedding-3-small" hence
254 | # we don't need to check for the model name for the embedding model
255 | if agent_prefix != AGENT_TYPE_TO_CONFIG_PREFIX[
256 | AgentType.EMBEDDING
257 | ] and not _is_valid_config_field(config, f"openAI{agent_prefix}ModelName"):
258 | raise ValueError(f"OpenAI {agent_prefix} model name not found in config")
259 |
260 |
261 | def _check_local_config(config: dict[str, Any], agent_prefix: str) -> None:
262 | # TinyAgent does not support local embedding models. Hence, we only need to check for
263 | # the context length, port, and tokenizer name or path for the local planner model.
264 | is_embedding_model = (
265 | agent_prefix == AGENT_TYPE_TO_CONFIG_PREFIX[AgentType.EMBEDDING]
266 | )
267 | if not is_embedding_model and not _is_valid_config_field(
268 | config, f"local{agent_prefix}CtxLen"
269 | ):
270 | raise ValueError(f"Local {agent_prefix} context length not found in config")
271 | if not is_embedding_model and not _is_valid_config_field(
272 | config, f"local{agent_prefix}Port"
273 | ):
274 | raise ValueError(f"Local {agent_prefix} port not found in config")
275 | if not is_embedding_model and not _is_valid_config_field(
276 | config, f"local{agent_prefix}TokenizerNameOrPath"
277 | ):
278 | raise ValueError(
279 | f"Local {agent_prefix} tokenizer name or path not found in config"
280 | )
281 | if is_embedding_model and not _is_valid_config_field(
282 | config, f"local{agent_prefix}ModelName"
283 | ):
284 | raise ValueError(f"Local {agent_prefix} model name not found in config")
285 |
--------------------------------------------------------------------------------
/src/tiny_agent/models.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from dataclasses import dataclass
4 | from enum import Enum
5 | from typing import Collection
6 |
7 | import torch
8 | from tiktoken import Encoding
9 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
10 |
11 | streaming_queue = asyncio.Queue[str | None]()
12 |
13 | LLM_ERROR_TOKEN = "###LLM_ERROR_TOKEN###"
14 |
15 | TINY_AGENT_DIR = os.path.expanduser("~/Library/Application Support/TinyAgent")
16 |
17 | Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast | Encoding
18 |
19 |
20 | class ModelType(Enum):
21 | AZURE = "azure"
22 | OPENAI = "openai"
23 | LOCAL = "local"
24 |
25 |
26 | class App(Enum):
27 | CALENDAR = "calendar"
28 | CONTACTS = "contacts"
29 | FILES = "files"
30 | MAIL = "mail"
31 | MAPS = "maps"
32 | NOTES = "notes"
33 | REMINDERS = "reminders"
34 | SMS = "sms"
35 | ZOOM = "zoom"
36 |
37 |
38 | class AgentType(Enum):
39 | MAIN = "main"
40 | SUB_AGENT = "sub_agent"
41 | EMBEDDING = "embedding"
42 |
43 |
44 | @dataclass
45 | class ModelConfig:
46 | api_key: str
47 | context_length: int
48 | model_name: str
49 | model_type: ModelType
50 | tokenizer: Tokenizer | None
51 | port: int | None
52 |
53 |
54 | @dataclass
55 | class WhisperConfig:
56 | # Azure is not yet supported for whisper
57 | provider: ModelType
58 | api_key: str | None
59 | port: int | None
60 |
61 |
62 | @dataclass
63 | class TinyAgentConfig:
64 | # Custom configs
65 | apps: Collection[App]
66 | custom_instructions: str | None
67 | # Config for the LLMCompiler model
68 | llmcompiler_config: ModelConfig
69 | # Config for the sub-agent LLM model
70 | sub_agent_config: ModelConfig
71 | # Config for the embedding model
72 | embedding_model_config: ModelConfig | None
73 | # Azure model config
74 | azure_api_version: str | None
75 | azure_endpoint: str | None
76 | # Other tokens
77 | hf_token: str | None
78 | zoom_access_token: str | None
79 | # Whisper config
80 | whisper_config: WhisperConfig
81 |
82 |
83 | class TinyAgentToolName(Enum):
84 | GET_PHONE_NUMBER = "get_phone_number"
85 | GET_EMAIL_ADDRESS = "get_email_address"
86 | CREATE_CALENDAR_EVENT = "create_calendar_event"
87 | OPEN_AND_GET_FILE_PATH = "open_and_get_file_path"
88 | SUMMARIZE_PDF = "summarize_pdf"
89 | COMPOSE_NEW_EMAIL = "compose_new_email"
90 | REPLY_TO_EMAIL = "reply_to_email"
91 | FORWARD_EMAIL = "forward_email"
92 | MAPS_OPEN_LOCATION = "maps_open_location"
93 | MAPS_SHOW_DIRECTIONS = "maps_show_directions"
94 | CREATE_NOTE = "create_note"
95 | OPEN_NOTE = "open_note"
96 | APPEND_NOTE_CONTENT = "append_note_content"
97 | CREATE_REMINDER = "create_reminder"
98 | SEND_SMS = "send_sms"
99 | GET_ZOOM_MEETING_LINK = "get_zoom_meeting_link"
100 |
101 |
102 | @dataclass
103 | class InContextExample:
104 | example: str
105 | embedding: torch.Tensor
106 | tools: list[TinyAgentToolName]
107 |
108 |
109 | class ComposeEmailMode(Enum):
110 | NEW = "new"
111 | REPLY = "reply"
112 | FORWARD = "forward"
113 |
114 |
115 | class NotesMode(Enum):
116 | NEW = "new"
117 | APPEND = "append"
118 |
119 |
120 | class TransportationOptions(Enum):
121 | DRIVING = "d"
122 | WALKING = "w"
123 | PUBLIC_TRANSIT = "r"
124 |
--------------------------------------------------------------------------------
/src/tiny_agent/prompts.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import Sequence
3 |
4 | from src.llm_compiler.constants import (
5 | END_OF_PLAN,
6 | JOINNER_FINISH,
7 | JOINNER_REPLAN,
8 | SUMMARY_RESULT,
9 | )
10 | from src.tiny_agent.models import TinyAgentToolName
11 | from src.tools.base import StructuredTool, Tool
13 |
14 | NOW = datetime.datetime.now()
15 |
16 |
17 | DEFAULT_PLANNER_IN_CONTEXT_EXAMPLES_PROMPT = (
18 | "Question: Notify Lutfi Eren Erdogan about the upcoming Apple meeting that is going to start at 3PM on Friday.\n"
19 | '1. get_phone_number("Lutfi Eren Erdogan")\n'
20 | '2. send_sms(["$1"], "Hey Lutfi, just wanted to let you know about the upcoming Apple meeting. It\'s going to be at 3 PM on Friday.")\n'
21 |     "Thought: I have successfully found the contact and sent the message.\n"
22 | f"3. join(){END_OF_PLAN}\n"
23 | "###\n"
24 | "Question: Create a zoom meeting for the upcoming Apple meeting with Eren Erdoğan.\n"
25 | '1. get_email_address("Eren Erdoğan")\n'
26 | '2. get_zoom_meeting_link("Apple Meeting", "2022-10-14 15:00:00", 60, ["$1"])\n'
27 | '3. create_calendar_event("Apple Meeting", "2022-10-14 15:00:00", "2022-10-14 16:00:00", "$2", [], "", None)\n'
28 |     "Thought: I have successfully created the calendar event.\n"
29 | f"4. join(){END_OF_PLAN}\n"
30 | "###\n"
31 | "Question: Show directions to Apple Park.\n"
32 | '1. maps_show_directions("", "Apple Park", "d")\n'
33 |     "Thought: I have successfully shown the directions.\n"
34 | f"2. join(){END_OF_PLAN}\n"
35 | "###\n"
36 | "Question: Send an email to Amir saying that the meeting is postponed to next week.\n"
37 | '1. get_email_address("Amir")\n'
38 |     '2. compose_new_email(["$1"], [], "Meeting Postponed", "", [])\n'
39 | f"3. join(){END_OF_PLAN}\n"
40 | "###\n"
41 | )
42 |
43 |
44 | TOOL_SPECIFIC_PROMPTS: list[tuple[set[TinyAgentToolName], str]] = [
45 | (
46 | {
47 | TinyAgentToolName.GET_EMAIL_ADDRESS,
48 | TinyAgentToolName.COMPOSE_NEW_EMAIL,
49 | TinyAgentToolName.REPLY_TO_EMAIL,
50 | TinyAgentToolName.FORWARD_EMAIL,
51 | },
52 | f" - Before sending an email, you MUST use the {TinyAgentToolName.GET_EMAIL_ADDRESS.value} tool to get the email addresses of the recipients and cc, unless you are explicitly given their email addresses.\n",
53 | ),
54 | (
55 | {TinyAgentToolName.GET_PHONE_NUMBER, TinyAgentToolName.SEND_SMS},
56 | f" - Before sending an SMS, you MUST use the {TinyAgentToolName.GET_PHONE_NUMBER.value} tool to get the phone number of the contact, unless you are explicitly given their phone number.\n"
57 | " - If you need to send an SMS message to multiple contacts, send it in one message, unless specified otherwise.\n",
58 | ),
59 | (
60 | {
61 | TinyAgentToolName.GET_EMAIL_ADDRESS,
62 | TinyAgentToolName.COMPOSE_NEW_EMAIL,
63 | TinyAgentToolName.REPLY_TO_EMAIL,
64 | TinyAgentToolName.FORWARD_EMAIL,
65 | TinyAgentToolName.GET_PHONE_NUMBER,
66 | TinyAgentToolName.SEND_SMS,
67 | },
68 | f" - If you need to send an email or an sms using {TinyAgentToolName.COMPOSE_NEW_EMAIL.value}, {TinyAgentToolName.REPLY_TO_EMAIL.value}, {TinyAgentToolName.FORWARD_EMAIL.value}, or {TinyAgentToolName.SEND_SMS.value} tools, "
69 | f"you MUST send it before calling join(), or you WILL BE PENALIZED!\n",
70 | ),
71 | (
72 | {TinyAgentToolName.GET_ZOOM_MEETING_LINK},
73 | f" - If you need to create a zoom meeting, you MUST use {TinyAgentToolName.GET_ZOOM_MEETING_LINK.value} to get the newly created zoom meeting link.\n",
74 | ),
75 | (
76 | {TinyAgentToolName.OPEN_NOTE, TinyAgentToolName.APPEND_NOTE_CONTENT},
77 | f" - If you need to append some content to a note, you DON'T HAVE TO call {TinyAgentToolName.OPEN_NOTE.value} before calling {TinyAgentToolName.APPEND_NOTE_CONTENT.value}. You can directly use {TinyAgentToolName.APPEND_NOTE_CONTENT.value} to append some content to the specific note.\n",
78 | ),
79 | (
80 | {TinyAgentToolName.MAPS_OPEN_LOCATION, TinyAgentToolName.MAPS_SHOW_DIRECTIONS},
81 | f" - If you need to show directions to a place, you DON'T HAVE TO call {TinyAgentToolName.MAPS_OPEN_LOCATION.value} before calling {TinyAgentToolName.MAPS_SHOW_DIRECTIONS.value}. You can directly use {TinyAgentToolName.MAPS_SHOW_DIRECTIONS.value} to show directions to the specific place.\n",
82 | ),
83 | ]
84 |
85 |
86 | def get_planner_custom_instructions_prompt(
87 | tools: Sequence[Tool | StructuredTool], custom_instructions: str | None
88 | ) -> str:
89 | prompt = []
90 | prompt.append(
91 | " - You need to start your plan with the '1.' call\n"
92 | f" - Today's date is {NOW.strftime('%A %Y-%m-%d %H:%M')}\n"
93 | " - Unless otherwise specified, the default meeting duration is 60 minutes.\n"
94 | " - Do not use named arguments in your tool calls.\n"
95 | " - You MUST end your plans with the 'join()' call and a '\\n' character.\n"
96 | " - You MUST fill every argument in the tool calls, even if they are optional.\n"
97 | " - The format for dates MUST be in ISO format of 'YYYY-MM-DD HH:MM:SS', unless other specified.\n"
98 | " - If you want to use the result of a previous tool call, you MUST use the '$' sign followed by the index of the tool call.\n"
99 | f" - You MUST ONLY USE join() at the very very end of the plan, or you WILL BE PENALIZED.\n"
100 | )
101 |
102 | for tool_set, instructions in TOOL_SPECIFIC_PROMPTS:
103 | if any(TinyAgentToolName(tool.name) in tool_set for tool in tools):
104 |             prompt.append(instructions)
105 |
106 | if custom_instructions is not None and len(custom_instructions) > 0:
107 | prompt.append(f" - {custom_instructions}")
108 |
109 | return "".join(prompt)
110 |
111 |
112 | PLANNER_PROMPT_REPLAN = (
113 | "Question: Say hi to Sid via SMS.\n\n"
114 | "Previous Plan:\n\n"
115 | "1. join()\n"
116 | "Observation:\nThe plan generation was stopped due to an error in tool 1. get_contact_info('Sid')! "
117 | "Error: Tool get_contact_info not found. You MUST correct this error and try again!"
118 | "\n"
119 | "Current Plan:\n\n"
120 | f"Thought: The error is fixable since I have the {TinyAgentToolName.GET_PHONE_NUMBER.value} tool to retrieve the phone number of Sid. Then I will proceed with sending the SMS.\n"
121 | '1. get_phone_number("Sid")\n'
122 |     '2. send_sms(["$1"], "Hi Sid!")\n'
123 |     "Thought: I have successfully retrieved the phone number and sent the SMS.\n"
124 |     f"3. join(){END_OF_PLAN}\n"
125 | "###\n"
126 | "Question: Summarize 'Apple Demo.pdf'.\n\n"
127 | "Previous Plan:\n\n"
128 | '1. open_and_get_file_path("Apple Demo")\n'
129 | "2. join()\n"
130 | "Observation: summarize_pdf() takes 1 positional arguments but 2 were given! You MUST correct this error and try again!"
131 | "\n"
132 | "Current Plan:\n\n"
133 | f"Thought: Previous plan tried to call the summarize_pdf() tool with the wrong number of arguments. I will correct this and try again.\n"
134 | '1. open_and_get_file_path("Apple Demo")\n'
135 | '2. summarize_pdf("$1")\n'
136 |     "Thought: I have successfully opened the file and summarized it.\n"
137 | f"3. join(){END_OF_PLAN}\n"
138 | "###\n"
139 | )
140 |
141 | JOINNER_REPLAN_RULES = (
142 | f" - If you think the plan is not completed yet or an error in the plan is fixable, you should output {JOINNER_REPLAN}.\n"
143 | f" - If the plan is fixable, you will see a message like 'try again'. If you don't see this message, the error is NOT fixable and you MUST output an error message using 'Action: {JOINNER_FINISH}()'\n"
144 | )
145 |
146 | JOINNER_FINISH_RULES = (
147 | f" - If you need to answer some knowledge question, just answer it directly using 'Action: {JOINNER_FINISH}()'.\n"
148 | f" - If you need to return the result of a summary (summarize_pdf), you MUST use 'Action: {JOINNER_FINISH}({SUMMARY_RESULT})'\n"
149 | f" - If there is an error in one of the tool calls and it is not fixable, you should provide a user-friendly error message using 'Action: {JOINNER_FINISH}()'.\n"
150 | )
151 |
152 | REPLAN_EXAMPLES = [
153 | "Question: Say hi to Sid via SMS.\n"
154 | "join()\n"
155 | "Observation: The plan generation was stopped due to an error in tool 1. get_contact_info('Sid')! "
156 | "Error: Tool get_contact_info not found. You MUST correct this error and try again!"
157 | "Thought: The error is fixable so I need to replan and try again.\n"
158 | f"Action: {JOINNER_REPLAN}\n"
159 | ]
160 |
161 | FINISH_EXAMPLES = [
162 | "Question: Create a zoom meeting for the upcoming Apple meeting with Eren Erdoğan. \n"
163 | 'get_email_address("Eren Erdoğan")\n'
164 | "Observation: eren@gmail.com\n"
165 | 'get_zoom_meeting_link("Apple Meeting", "2022-10-14 15:00:00", 60, ["$1"])\n'
166 | "Observation: https://zoom.us/j/1234567890?pwd=abc123\n"
167 | 'create_calendar_event("Apple Meeting", "2022-10-14 15:00:00", "2022-10-14 16:00:00", "Apple HQ", "$2", None)\n'
168 | "Observation: Event created successfully\n"
169 | "Thought: I don't need to answer a question.\n"
170 | f"Action: {JOINNER_FINISH}(Task completed!)\n",
171 | "Question: What is the content of the Apple meeting notes? \n"
172 | 'get_note_content("Apple Meeting")\n'
173 | "Observation: The meeting is about the new iPhone release.\n"
174 | "Thought: I can just answer the question directly.\n"
175 | f"Action: {JOINNER_FINISH}(The meeting is about the new iPhone release.)\n",
176 | "Question: Compose a new email to John, attaching the Project.pdf file.\n"
177 | 'get_email_address("John")\n'
178 |     "Observation: john@doe.com\n"
179 | 'open_and_get_file_path("Project")\n'
180 | "Observation: /Users/eren/Downloads/Project.pdf\n"
181 | 'compose_new_email([john@doe.com], [], "Project Update", "Please find the attached project update.", ["/Users/eren/Downloads/Project.pdf"])\n'
182 | "Observation: There was an error while composing the email.\n"
183 | "Thought: There was an error with the compose_new_email tool call and it is not possible to fix it. I need to provide a user-friendly error message.\n"
184 | f"Action: {JOINNER_FINISH}(There was an error while composing the email. Please try again later.)\n",
185 | "Question: Summarize the Apple Demo file. \n"
186 | "open_and_get_file_path(Apple Demo)\n"
187 | "Observation: /Users/eren/Downloads/Apple_Demo.pdf\n"
188 | "summarize_pdf(/Users/eren/Downloads/Apple_Demo.pdf)\n"
189 | "Observation: The new iPhone is going to be released in 2023.\n"
190 | f"Action: {JOINNER_FINISH}({SUMMARY_RESULT})\n",
191 | ]
192 |
193 | OUTPUT_PROMPT = (
194 | "Follow these rules:\n"
195 | f" - You MUST only output either {JOINNER_FINISH} or {JOINNER_REPLAN}, or you WILL BE PENALIZED.\n"
196 | f"{JOINNER_FINISH_RULES}"
197 | f"{JOINNER_REPLAN_RULES}"
198 | "\n"
199 | "Here are some examples:\n"
200 | + "###\n".join(FINISH_EXAMPLES)
201 | + "###\n"
202 | + "###\n".join(REPLAN_EXAMPLES)
203 | + "###\n"
204 | )
205 |
206 |
207 | OUTPUT_PROMPT_FINAL = (
208 | "Follow these rules:\n"
209 | f" - You MUST only output {JOINNER_FINISH}, or you WILL BE PENALIZED.\n"
210 | f"{JOINNER_FINISH_RULES}"
211 | "\n"
212 | "Here are some examples:\n" + "###\n".join(FINISH_EXAMPLES) + "###\n"
213 | )
214 |
--------------------------------------------------------------------------------
/src/tiny_agent/run_apple_script.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | def run_applescript(script: str) -> str:
5 | """
6 | Runs the given AppleScript using osascript and returns the result.
7 | """
8 | args = ["osascript", "-e", script]
9 | return subprocess.check_output(args, universal_newlines=True)
10 |
11 |
12 | def run_applescript_capture(script: str) -> tuple[str, str]:
13 | """
14 | Runs the given AppleScript using osascript, captures the output and error, and returns them.
15 | """
16 | args = ["osascript", "-e", script]
17 | result = subprocess.run(args, capture_output=True, text=True, check=False)
18 | stdout, stderr = result.stdout, result.stderr
19 | return stdout, stderr
20 |
21 |
22 | def run_command(command: list[str]) -> tuple[str, str]:
23 |     """
24 |     Executes a shell command (given as an argument list) and returns its stdout and stderr.
25 |     """
26 | result = subprocess.run(command, capture_output=True, text=True, check=False)
27 | stdout, stderr = result.stdout, result.stderr
28 | return stdout, stderr
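# Illustrative usage sketch (not from the original source):
#   out, err = run_applescript_capture('tell application "Notes" to get name of every note')
#   out, err = run_command(["mdfind", "kMDItemDisplayName == 'Project.pdf'"])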
29 |
--------------------------------------------------------------------------------
/src/tiny_agent/sub_agents/compose_email_agent.py:
--------------------------------------------------------------------------------
1 | import re
3 |
4 | from langchain_core.messages import HumanMessage, SystemMessage
5 |
6 | from src.tiny_agent.models import ComposeEmailMode
7 | from src.tiny_agent.sub_agents.sub_agent import SubAgent
8 |
9 |
10 | class ComposeEmailAgent(SubAgent):
11 | _query: str
12 |
13 | @property
14 | def query(self) -> str:
15 | return self._query
16 |
17 | @query.setter
18 | def query(self, query: str) -> None:
19 | self._query = query
20 |
21 | async def __call__(
22 | self,
23 | context: str,
24 | email_thread: str = "",
25 | mode: ComposeEmailMode = ComposeEmailMode.NEW,
26 | ) -> str:
27 | # Define the system prompt for the LLM to generate HTML content
28 | context = context.strip()
29 | cleaned_thread = re.sub(r"\n+", "\n", email_thread.strip())
30 | if mode == ComposeEmailMode.NEW:
31 | email_llm_system_prompt = (
32 | "You are an expert email composer agent. Given an email content or a user query, you MUST generate a well-formatted and "
33 | "informative email. The email should include a polite greeting, a detailed body, and a "
34 | "professional sign-off. You MUST NOT include a subject. The email should be well-structured and free of grammatical errors."
35 | )
36 | elif mode == ComposeEmailMode.REPLY:
37 | email_llm_system_prompt = (
38 | "You are an expert email composer agent. Given the content of the past email thread and a user query, "
39 | "you MUST generate a well-formatted and informative reply to the last email in the thread. "
40 | "The email should include a polite greeting, a detailed body, and a "
41 | "professional sign-off. You MUST NOT include a subject. The email should be well-structured and free of grammatical errors."
42 | )
43 | context += f"\nEmail Thread:\n{cleaned_thread}"
44 | elif mode == ComposeEmailMode.FORWARD:
45 | email_llm_system_prompt = (
46 | "You are an expert email composer agent. Given the content of the past email thread and a user query, "
47 | "you MUST generate a very concise and informative forward of the last email in the thread. "
48 | )
49 | context += f"\nEmail Thread:\n{cleaned_thread}"
50 |
51 | # Add custom instructions to the system prompt if specified
52 | if self._custom_instructions is not None:
53 | email_llm_system_prompt += (
54 | "\nHere are some general facts about the user's preferences, "
55 | f"you MUST keep these in mind when writing your email:\n{self._custom_instructions}"
56 | )
57 |
58 | email_human_prompt = "Context:\n{context}\nQuery: {query}\nEmail Body:\n"
59 | messages = [
60 | SystemMessage(content=email_llm_system_prompt),
61 | HumanMessage(
62 | content=email_human_prompt.format(context=context, query=self._query)
63 | ),
64 | ]
65 |
66 |         # If the context is too long, use the truncated version that fits the sub-agent LLM's context window
67 | new_context = self.check_context_length(messages, context)
68 | if new_context is not None:
69 | messages = [
70 | SystemMessage(content=email_llm_system_prompt),
71 | HumanMessage(
72 | content=email_human_prompt.format(
73 | context=new_context, query=self._query
74 | )
75 | ),
76 | ]
77 |
78 | # Generate the body content for the email
79 | email_content = await self._llm.apredict_messages(messages)
80 |
81 | return str(email_content.content)
82 |
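83 | # Example usage (a hypothetical sketch; assumes an initialized chat LLM and a
84 | # ModelConfig with a tokenizer, as SubAgent requires):
85 | # agent = ComposeEmailAgent(llm, model_config, custom_instructions=None)
86 | # agent.query = "Thank Alice for the slides"
87 | # body = await agent(context="Alice sent the Q3 slides.", mode=ComposeEmailMode.NEW)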
--------------------------------------------------------------------------------
/src/tiny_agent/sub_agents/notes_agent.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from bs4 import BeautifulSoup
4 | from langchain_core.messages import HumanMessage, SystemMessage
5 |
6 | from src.tiny_agent.models import NotesMode
7 | from src.tiny_agent.sub_agents.sub_agent import SubAgent
8 |
9 |
10 | class NotesAgent(SubAgent):
11 |
12 | async def __call__(
13 | self,
14 | name: str,
15 | content: str,
16 | prev_content: str = "",
17 | mode: NotesMode = NotesMode.NEW,
18 | ) -> str:
19 | # Construct the prompt
20 | if mode == NotesMode.NEW:
21 | notes_llm_system_prompt = (
22 | "You are an expert note taking agent. Given the plain text content, you MUST generate a compelling and "
23 | "formatted HTML version of the note (with , , tags, etc.) The content of the note should be rich, well-structured, and verbose. "
24 | )
25 | elif mode == NotesMode.APPEND:
26 | notes_llm_system_prompt = (
27 | "You are an expert note-taking agent that specializes in appending new content to existing notes. "
28 | "Given the content of an existing note and the content to append, you MUST generate a continuation of the note that "
29 | "seamlessly integrates with the existing content. Your additions should maintain the tone, style, "
30 | "and subject matter of the original note. You MUST ONLY output the appended content in HTML format, DO NOT include the entire note, or you WILL BE PENALIZED.\n"
31 | f"Previous Content:\n{prev_content}\n"
32 | )
33 |
34 | notes_llm_system_prompt += (
35 | "The note should include appropriate use of headings, bold and italic text for emphasis, "
36 | "bullet points for lists, and paragraph tags for separation of ideas. DO NOT include the '' tag."
37 | )
38 |
39 | # Add custom instructions to the system prompt if specified
40 | if self._custom_instructions is not None:
41 | notes_llm_system_prompt += (
42 | "\nHere are some general facts about the user's preferences, "
43 | f"you MUST keep these in mind when generating your note:\n{self._custom_instructions}"
44 | )
45 |
46 | notes_human_prompt = (
47 | "Note Title: {name}\nNew Text Content: {content}\nHTML Content:\n"
48 | )
49 | messages = [
50 | SystemMessage(content=notes_llm_system_prompt),
51 | HumanMessage(content=notes_human_prompt.format(name=name, content=content)),
52 | ]
53 | plain_text_content = self.check_context_length(messages, content)
54 |
55 | # If the plain text content is too long, truncate it to fit the sub-agent LLM's context window
56 | if plain_text_content is not None:
57 | messages = [
58 | SystemMessage(content=notes_llm_system_prompt),
59 | HumanMessage(
60 | content=notes_human_prompt.format(name=name, content=plain_text_content)
61 | ),
62 | ]
63 |
64 | # Generate the HTML content for the note
65 | html_content = await self._llm.apredict_messages(messages)
66 |
67 | # If the HTML doesn't start with an <html> tag, then wrap it in one
68 | soup = BeautifulSoup(str(html_content.content), "html.parser")
69 | if not soup.find("html"):
70 | soup = BeautifulSoup(
71 | f"{str(html_content.content)}", "html.parser"
72 | )
73 |
74 | return str(soup.find("html"))
75 |
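76 | # Example usage (a hypothetical sketch; assumes an initialized chat LLM and a
77 | # ModelConfig with a tokenizer):
78 | # agent = NotesAgent(llm, model_config, custom_instructions=None)
79 | # html = await agent(name="Groceries", content="milk, eggs", mode=NotesMode.NEW)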
--------------------------------------------------------------------------------
/src/tiny_agent/sub_agents/pdf_summarizer_agent.py:
--------------------------------------------------------------------------------
1 | import fitz
2 | from langchain_core.messages import HumanMessage, SystemMessage
3 |
4 | from src.tiny_agent.sub_agents.sub_agent import SubAgent
5 |
6 | CONTEXT_LENGTHS = {"gpt-4-1106-preview": 127000, "gpt-3.5-turbo": 16000}  # reference context-window sizes (tokens)
7 |
8 |
9 | class PDFSummarizerAgent(SubAgent):
10 | _cached_summary_result: str
11 |
12 | @property
13 | def cached_summary_result(self) -> str:
14 | return self._cached_summary_result
15 |
16 | async def __call__(self, pdf_path: str) -> str:
17 | # Check if the file exists
18 | if (
19 | pdf_path is None
20 | or len(pdf_path) <= 0
21 | or pdf_path
22 | in (
23 | "No file found after fuzzy matching.",
24 | "No file found with exact or fuzzy name.",
25 | )
26 | ):
27 | return "The PDF file path is invalid or the file doesn't exist."
28 |
29 | try:
30 | pdf_content = PDFSummarizerAgent._extract_text_from_pdf(pdf_path)
31 | except Exception as e:
32 | return f"An error occurred while extracting the content from the PDF file: {str(e)}"
33 |
34 | if len(pdf_content) <= 0:
35 | return "The PDF file is empty or the content couldn't be extracted."
36 |
37 | # Construct the prompt
38 | pdf_summarizer_llm_system_prompt = (
39 | "You are an expert PDF summarizer agent. Given the PDF content, you MUST generate an informative and verbose "
40 | "summary of the content. The summary should include the main points and key details of the content. "
41 | )
42 | pdf_summarizer_human_prompt = "PDF Content:\n{pdf_content}\nSummary:\n"
43 | messages = [
44 | SystemMessage(content=pdf_summarizer_llm_system_prompt),
45 | HumanMessage(
46 | content=pdf_summarizer_human_prompt.format(pdf_content=pdf_content)
47 | ),
48 | ]
49 |
50 | # If the PDF content is too long, truncate it to fit the sub-agent LLM's context window
51 | pdf_content = self.check_context_length(messages, pdf_content)
52 | if pdf_content is not None:
53 | messages = [
54 | SystemMessage(content=pdf_summarizer_llm_system_prompt),
55 | HumanMessage(
56 | content=pdf_summarizer_human_prompt.format(pdf_content=pdf_content)
57 | ),
58 | ]
59 |
60 | # Call LLM
61 | summary = await self._llm.apredict_messages(messages)
62 |
63 | # Cache the summary result
64 | self._cached_summary_result = str(summary.content)
65 |
66 | return self._cached_summary_result
67 |
68 | @staticmethod
69 | def _extract_text_from_pdf(pdf_path: str) -> str:
70 | doc = fitz.open(pdf_path)
71 | text = []
72 | for page in doc:
73 | text.append(page.get_text()) # type: ignore
74 | doc.close()
75 | return "".join(text).replace("\n", " ")
76 |
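77 | # Example usage (a hypothetical sketch; the file path is illustrative):
78 | # agent = PDFSummarizerAgent(llm, model_config, custom_instructions=None)
79 | # summary = await agent("/Users/me/Documents/paper.pdf")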
--------------------------------------------------------------------------------
/src/tiny_agent/sub_agents/sub_agent.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from langchain.llms.base import BaseLLM
4 | from langchain_core.messages import BaseMessage, get_buffer_string
5 |
6 | from src.tiny_agent.config import ModelConfig
7 | from src.tiny_agent.models import Tokenizer
8 |
9 |
10 | class SubAgent(abc.ABC):
11 | # A constant to make the context length check more conservative
12 | # And to allow for more output tokens to be generated
13 | _CONTEXT_LENGTH_TRUST = 500
14 |
15 | _llm: BaseLLM
16 | _tokenizer: Tokenizer
17 | _context_length: int
18 | _custom_instructions: str | None
19 |
20 | def __init__(
21 | self,
22 | llm: BaseLLM,
23 | config: ModelConfig,
24 | custom_instructions: str | None,
25 | ) -> None:
26 | assert (
27 | config.tokenizer is not None
28 | ), "Tokenizer must be provided for sub-agents."
29 | self._llm = llm
30 | self._tokenizer = config.tokenizer
31 | self._context_length = config.context_length
32 | self._custom_instructions = custom_instructions
33 |
34 | @abc.abstractmethod
35 | async def __call__(self, *args, **kwargs) -> str:
36 | pass
37 |
38 | def check_context_length(
39 | self, messages: list[BaseMessage], context: str
40 | ) -> str | None:
41 | """
42 | Checks if the final length of the messages is greater than the context length,
43 | and if so, returns the context with the excess tokens removed. Otherwise, returns None.
44 | """
45 | text = get_buffer_string(messages)
46 | total_length = len(self._tokenizer.encode(text))
47 | length_to_remove = total_length - self._context_length
48 |
49 | if length_to_remove <= 0:
50 | return None
51 |
52 | context = self._tokenizer.decode(
53 | self._tokenizer.encode(context)[
54 | : -length_to_remove - self._CONTEXT_LENGTH_TRUST
55 | ]
56 | )
57 |
58 | return context
59 |
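60 | # Hypothetical sketch of the truncation contract: when the combined messages
61 | # exceed the configured context length, check_context_length returns a shortened
62 | # context; otherwise it returns None and the messages are used as-is.
63 | # truncated = agent.check_context_length(messages, context)
64 | # context = truncated if truncated is not None else context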
--------------------------------------------------------------------------------
/src/tiny_agent/tiny_agent.py:
--------------------------------------------------------------------------------
1 | from src.llm_compiler.constants import END_OF_PLAN, SUMMARY_RESULT
2 | from src.llm_compiler.llm_compiler import LLMCompiler
3 | from src.llm_compiler.planner import generate_llm_compiler_prompt
4 | from src.tiny_agent.computer import Computer
5 | from src.tiny_agent.config import TinyAgentConfig
6 | from src.tiny_agent.prompts import (
7 | DEFAULT_PLANNER_IN_CONTEXT_EXAMPLES_PROMPT,
8 | OUTPUT_PROMPT,
9 | OUTPUT_PROMPT_FINAL,
10 | PLANNER_PROMPT_REPLAN,
11 | get_planner_custom_instructions_prompt,
12 | )
13 | from src.tiny_agent.sub_agents.compose_email_agent import ComposeEmailAgent
14 | from src.tiny_agent.sub_agents.notes_agent import NotesAgent
15 | from src.tiny_agent.sub_agents.pdf_summarizer_agent import PDFSummarizerAgent
16 | from src.tiny_agent.tiny_agent_tools import (
17 | get_tiny_agent_tools,
18 | get_tool_names_from_apps,
19 | )
20 | from src.tiny_agent.tool_rag.base_tool_rag import BaseToolRAG
21 | from src.tiny_agent.tool_rag.classifier_tool_rag import ClassifierToolRAG
22 | from src.utils.model_utils import get_embedding_model, get_model
23 |
24 |
25 | class TinyAgent:
26 | _DEFAULT_TOP_K = 6
27 |
28 | config: TinyAgentConfig
29 | agent: LLMCompiler
30 | computer: Computer
31 | notes_agent: NotesAgent
32 | pdf_summarizer_agent: PDFSummarizerAgent
33 | compose_email_agent: ComposeEmailAgent
34 | tool_rag: BaseToolRAG
35 |
36 | def __init__(self, config: TinyAgentConfig) -> None:
37 | self.config = config
38 |
39 | # Define the models
40 | llm = get_model(
41 | model_type=config.llmcompiler_config.model_type.value,
42 | model_name=config.llmcompiler_config.model_name,
43 | api_key=config.llmcompiler_config.api_key,
44 | stream=False,
45 | vllm_port=config.llmcompiler_config.port,
46 | temperature=0,
47 | azure_api_version=config.azure_api_version,
48 | azure_endpoint=config.azure_endpoint,
49 | azure_deployment=config.llmcompiler_config.model_name,
50 | )
51 | planner_llm = get_model(
52 | model_type=config.llmcompiler_config.model_type.value,
53 | model_name=config.llmcompiler_config.model_name,
54 | api_key=config.llmcompiler_config.api_key,
55 | stream=True,
56 | vllm_port=config.llmcompiler_config.port,
57 | temperature=0,
58 | azure_api_version=config.azure_api_version,
59 | azure_endpoint=config.azure_endpoint,
60 | azure_deployment=config.llmcompiler_config.model_name,
61 | )
62 | sub_agent_llm = get_model(
63 | model_type=config.sub_agent_config.model_type.value,
64 | model_name=config.sub_agent_config.model_name,
65 | api_key=config.sub_agent_config.api_key,
66 | stream=False,
67 | vllm_port=config.sub_agent_config.port,
68 | temperature=0,
69 | azure_api_version=config.azure_api_version,
70 | azure_endpoint=config.azure_endpoint,
71 | azure_deployment=config.sub_agent_config.model_name,
72 | )
73 |
74 | self.computer = Computer()
75 | self.notes_agent = NotesAgent(
76 | sub_agent_llm, config.sub_agent_config, config.custom_instructions
77 | )
78 | self.pdf_summarizer_agent = PDFSummarizerAgent(
79 | sub_agent_llm, config.sub_agent_config, config.custom_instructions
80 | )
81 | self.compose_email_agent = ComposeEmailAgent(
82 | sub_agent_llm, config.sub_agent_config, config.custom_instructions
83 | )
84 |
85 | tools = get_tiny_agent_tools(
86 | computer=self.computer,
87 | notes_agent=self.notes_agent,
88 | pdf_summarizer_agent=self.pdf_summarizer_agent,
89 | compose_email_agent=self.compose_email_agent,
90 | tool_names=get_tool_names_from_apps(config.apps),
91 | zoom_access_token=config.zoom_access_token,
92 | )
93 |
94 | # Define LLMCompiler
95 | self.agent = LLMCompiler(
96 | tools=tools,
97 | planner_llm=planner_llm,
98 | planner_custom_instructions_prompt=get_planner_custom_instructions_prompt(
99 | tools=tools, custom_instructions=config.custom_instructions
100 | ),
101 | planner_example_prompt=DEFAULT_PLANNER_IN_CONTEXT_EXAMPLES_PROMPT,
102 | planner_example_prompt_replan=PLANNER_PROMPT_REPLAN,
103 | planner_stop=[END_OF_PLAN],
104 | planner_stream=True,
105 | agent_llm=llm,
106 | joinner_prompt=OUTPUT_PROMPT,
107 | joinner_prompt_final=OUTPUT_PROMPT_FINAL,
108 | max_replans=2,
109 | benchmark=False,
110 | )
111 |
112 | # Define ToolRAG
113 | if config.embedding_model_config is not None:
114 | embedding_model = get_embedding_model(
115 | model_type=config.embedding_model_config.model_type.value,
116 | model_name=config.embedding_model_config.model_name,
117 | api_key=config.embedding_model_config.api_key,
118 | azure_endpoint=config.azure_endpoint,
119 | azure_embedding_deployment=config.embedding_model_config.model_name,
120 | azure_api_version=config.azure_api_version,
121 | local_port=config.embedding_model_config.port,
122 | context_length=config.embedding_model_config.context_length,
123 | )
124 | self.tool_rag = ClassifierToolRAG(
125 | embedding_model=embedding_model,
126 | tools=tools,
127 | )
128 |
129 | async def arun(self, query: str) -> str:
130 | if self.config.embedding_model_config is not None:
131 | tool_rag_results = self.tool_rag.retrieve_examples_and_tools(
132 | query, top_k=TinyAgent._DEFAULT_TOP_K
133 | )
134 |
135 | new_tools = get_tiny_agent_tools(
136 | computer=self.computer,
137 | notes_agent=self.notes_agent,
138 | pdf_summarizer_agent=self.pdf_summarizer_agent,
139 | compose_email_agent=self.compose_email_agent,
140 | tool_names=tool_rag_results.retrieved_tools_set,
141 | zoom_access_token=self.config.zoom_access_token,
142 | )
143 |
144 | self.agent.planner.system_prompt = generate_llm_compiler_prompt(
145 | tools=new_tools,
146 | example_prompt=tool_rag_results.in_context_examples_prompt,
147 | custom_instructions=get_planner_custom_instructions_prompt(
148 | tools=new_tools, custom_instructions=self.config.custom_instructions
149 | ),
150 | )
151 |
152 | self.compose_email_agent.query = query
153 | result = await self.agent.arun(query)
154 |
155 | if result == SUMMARY_RESULT:
156 | result = self.pdf_summarizer_agent.cached_summary_result
157 |
158 | return result
159 |
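160 | # Example usage (a hypothetical sketch; assumes a fully populated TinyAgentConfig):
161 | # agent = TinyAgent(config)
162 | # result = await agent.arun("Create a meeting with Sid tomorrow at 2pm")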
--------------------------------------------------------------------------------
/src/tiny_agent/tool_rag/base_tool_rag.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import os
3 | import pickle
4 | from dataclasses import dataclass
5 | from typing import Collection, Sequence
6 |
7 | import torch
8 | from langchain_community.embeddings import HuggingFaceEmbeddings
9 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
10 | from typing_extensions import TypedDict
11 |
12 | from src.tiny_agent.config import DEFAULT_OPENAI_EMBEDDING_MODEL
13 | from src.tiny_agent.models import TinyAgentToolName
14 | from src.tools.base import StructuredTool, Tool
15 |
16 | TOOLRAG_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
17 |
18 |
19 | @dataclass
20 | class ToolRAGResult:
21 | in_context_examples_prompt: str
22 | retrieved_tools_set: Collection[TinyAgentToolName]
23 |
24 |
25 | class PickledEmbedding(TypedDict):
26 | example: str
27 | embedding: torch.Tensor
28 | tools: Sequence[str]
29 |
30 |
31 | class BaseToolRAG(abc.ABC):
32 | """
33 | The base class for the ToolRAGs that are used to retrieve the in-context examples and tools based on the user query.
34 | """
35 |
36 | _EMBEDDINGS_DIR_PATH = TOOLRAG_DIR_PATH
37 | _EMBEDDINGS_FILE_NAME = "embeddings.pkl"
38 |
39 | # Embedding model that computes the embeddings for the examples/tools and the user query
40 | _embedding_model: AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings
41 | # The set of available tools so that we do an initial filtering based on the tools that are available
42 | _available_tools: Sequence[TinyAgentToolName]
43 | # The path to the embeddings.pkl file
44 | _embeddings_pickle_path: str
45 |
46 | def __init__(
47 | self,
48 | embedding_model: (
49 | AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings
50 | ),
51 | tools: Sequence[Tool | StructuredTool],
52 | ) -> None:
53 | self._embedding_model = embedding_model
54 | self._available_tools = [TinyAgentToolName(tool.name) for tool in tools]
55 |
56 | # TinyAgent currently only supports the "text-embedding-3-small" model by default,
57 | # so we only use the directory created for the default model.
58 | model_name = DEFAULT_OPENAI_EMBEDDING_MODEL
59 | self._embeddings_pickle_path = os.path.join(
60 | BaseToolRAG._EMBEDDINGS_DIR_PATH,
61 | model_name.split("/")[-1], # Only use the last model name
62 | BaseToolRAG._EMBEDDINGS_FILE_NAME,
63 | )
64 |
65 | @property
66 | @abc.abstractmethod
67 | def tool_rag_type(self) -> str:
68 | pass
69 |
70 | @abc.abstractmethod
71 | def retrieve_examples_and_tools(self, query: str, top_k: int) -> ToolRAGResult:
72 | """
73 | Returns the in-context examples as a formatted prompt and the tools that are relevant to the query.
74 | """
75 | pass
76 |
77 | def _retrieve_top_k_embeddings(
78 | self, query: str, examples: list[PickledEmbedding], top_k: int
79 | ) -> list[PickledEmbedding]:
80 | """
81 | Computes the cosine similarity of each example and retrieves the closest top_k examples.
82 | If there are already fewer than top_k examples, returns them directly.
83 | """
84 | if len(examples) <= top_k:
85 | return examples
86 |
87 | query_embedding = torch.tensor(self._embedding_model.embed_query(query))
88 | embeddings = torch.stack(
89 | [x["embedding"] for x in examples]
90 | ) # Stacking for batch processing
91 |
92 | # Cosine similarity between query_embedding and all chunks
93 | cosine_similarities = torch.nn.functional.cosine_similarity(
94 | embeddings, query_embedding.unsqueeze(0), dim=1
95 | )
96 |
97 | # Retrieve the top k indices from cosine_similarities
98 | _, top_k_indices = torch.topk(cosine_similarities, top_k)
99 |
100 | # Select the chunks corresponding to the top k indices
101 | selected_examples = [examples[i] for i in top_k_indices]
102 |
103 | return selected_examples
104 |
105 | def _load_filtered_embeddings(
106 | self, filter_tools: list[TinyAgentToolName] | None = None
107 | ) -> list[PickledEmbedding]:
108 | """
109 | Loads the embeddings.pkl file that contains a list of PickledEmbedding objects
110 | and returns the filtered results based on the available tools.
111 | """
112 | with open(self._embeddings_pickle_path, "rb") as file:
113 | embeddings: dict[str, PickledEmbedding] = pickle.load(file)
114 |
115 | filtered_embeddings = []
116 | tool_names = [tool.value for tool in filter_tools or self._available_tools]
117 | for embedding in embeddings.values():
118 | # Check if all tools are available in this example
119 | if all(tool in tool_names for tool in embedding["tools"]):
120 | filtered_embeddings.append(embedding)
121 |
122 | return filtered_embeddings
123 |
124 | @staticmethod
125 | def _get_in_context_examples_prompt(embeddings: list[PickledEmbedding]) -> str:
126 | examples = [example["example"] for example in embeddings]
127 | examples_prompt = "###\n".join(examples)
128 | return f"{examples_prompt}###\n"
129 |
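130 | # Hypothetical sketch of the retrieval step (assumes `rag` is a concrete
131 | # BaseToolRAG subclass and `examples` is a list of PickledEmbedding dicts):
132 | # top = rag._retrieve_top_k_embeddings("make a note", examples, top_k=3)
133 | # prompt = BaseToolRAG._get_in_context_examples_prompt(top)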
--------------------------------------------------------------------------------
/src/tiny_agent/tool_rag/classifier_tool_rag.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Sequence
3 |
4 | import torch
5 | from langchain_community.embeddings import HuggingFaceEmbeddings
6 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
7 | from transformers import (
8 | AutoModelForSequenceClassification,
9 | AutoTokenizer,
10 | PreTrainedTokenizer,
11 | PreTrainedTokenizerFast,
12 | )
13 |
14 | from src.tiny_agent.models import TinyAgentToolName
15 | from src.tiny_agent.tool_rag.base_tool_rag import BaseToolRAG, ToolRAGResult
16 | from src.tools.base import StructuredTool, Tool
17 |
18 |
19 | class ClassifierToolRAG(BaseToolRAG):
20 | _CLASSIFIER_MODEL_NAME = "squeeze-ai-lab/TinyAgent-ToolRAG"
21 | _DEFAULT_TOOL_THRESHOLD = 0.5
22 | _NUM_LABELS = 16
23 | _ID_TO_TOOL = {
24 | 0: TinyAgentToolName.CREATE_CALENDAR_EVENT,
25 | 1: TinyAgentToolName.GET_PHONE_NUMBER,
26 | 2: TinyAgentToolName.GET_EMAIL_ADDRESS,
27 | 3: TinyAgentToolName.OPEN_AND_GET_FILE_PATH,
28 | 4: TinyAgentToolName.SUMMARIZE_PDF,
29 | 5: TinyAgentToolName.COMPOSE_NEW_EMAIL,
30 | 6: TinyAgentToolName.REPLY_TO_EMAIL,
31 | 7: TinyAgentToolName.FORWARD_EMAIL,
32 | 8: TinyAgentToolName.MAPS_OPEN_LOCATION,
33 | 9: TinyAgentToolName.MAPS_SHOW_DIRECTIONS,
34 | 10: TinyAgentToolName.CREATE_NOTE,
35 | 11: TinyAgentToolName.OPEN_NOTE,
36 | 12: TinyAgentToolName.APPEND_NOTE_CONTENT,
37 | 13: TinyAgentToolName.CREATE_REMINDER,
38 | 14: TinyAgentToolName.SEND_SMS,
39 | 15: TinyAgentToolName.GET_ZOOM_MEETING_LINK,
40 | }
41 |
42 | _tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast
43 | _classifier_model: Any
44 | _tool_threshold: float
45 |
46 | def __init__(
47 | self,
48 | embedding_model: (
49 | AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings
50 | ),
51 | tools: Sequence[Tool | StructuredTool],
52 | tool_threshold: float = _DEFAULT_TOOL_THRESHOLD,
53 | ):
54 | super().__init__(embedding_model, tools)
55 |
56 | self._tokenizer = AutoTokenizer.from_pretrained(
57 | ClassifierToolRAG._CLASSIFIER_MODEL_NAME
58 | )
59 | self._classifier_model = AutoModelForSequenceClassification.from_pretrained(
60 | ClassifierToolRAG._CLASSIFIER_MODEL_NAME,
61 | num_labels=ClassifierToolRAG._NUM_LABELS,
62 | )
63 | self._tool_threshold = tool_threshold
64 |
65 | @property
66 | def tool_rag_type(self) -> str:
67 | return "classifier_tool_rag"
68 |
69 | def retrieve_examples_and_tools(self, query: str, top_k: int) -> ToolRAGResult:
70 | """
71 | Returns the in-context examples as a formatted prompt and the tools that are relevant to the query.
72 | It first retrieves the best tools for the given query and then it retrieves top k examples
73 | that use the retrieved tools.
74 | """
75 | retrieved_tools = self._classify_tools(query)
76 | # Filter the tools that are available
77 | retrieved_tools = list(set(retrieved_tools) & set(self._available_tools))
78 | filtered_embeddings = self._load_filtered_embeddings(retrieved_tools)
79 | retrieved_embeddings = self._retrieve_top_k_embeddings(
80 | query, filtered_embeddings, top_k
81 | )
82 |
83 | in_context_examples_prompt = BaseToolRAG._get_in_context_examples_prompt(
84 | retrieved_embeddings
85 | )
86 |
87 | return ToolRAGResult(
88 | in_context_examples_prompt=in_context_examples_prompt,
89 | retrieved_tools_set=retrieved_tools,
90 | )
91 |
92 | def _classify_tools(self, query: str) -> list[TinyAgentToolName]:
93 | """
94 | Retrieves the best tools for the given query by classification.
95 | """
96 | inputs = self._tokenizer(
97 | query, return_tensors="pt", padding=True, truncation=True, max_length=512
98 | )
99 |
100 | # Get the output probabilities
101 | with torch.no_grad():
102 | outputs = self._classifier_model(**inputs)
103 | logits = outputs.logits
104 | probs = torch.sigmoid(logits)
105 |
106 | # Retrieve the tools that have a probability greater than the threshold
107 | retrieved_tools = [
108 | ClassifierToolRAG._ID_TO_TOOL[i]
109 | for i, prob in enumerate(probs[0])
110 | if prob > self._tool_threshold
111 | ]
112 |
113 | return retrieved_tools
114 |
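115 | # Example usage (a hypothetical sketch; from_pretrained downloads the classifier
116 | # from the Hugging Face Hub on first use):
117 | # rag = ClassifierToolRAG(embedding_model, tools)
118 | # result = rag.retrieve_examples_and_tools("Remind me to call Bob", top_k=3)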
--------------------------------------------------------------------------------
/src/tiny_agent/tool_rag/simple_tool_rag.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from langchain_community.embeddings import HuggingFaceEmbeddings
4 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
5 |
6 | from src.tiny_agent.models import TinyAgentToolName
7 | from src.tiny_agent.tool_rag.base_tool_rag import BaseToolRAG, ToolRAGResult
8 | from src.tools.base import StructuredTool, Tool
9 |
10 |
11 | class SimpleToolRAG(BaseToolRAG):
12 | def __init__(
13 | self,
14 | embedding_model: (
15 | AzureOpenAIEmbeddings | OpenAIEmbeddings | HuggingFaceEmbeddings
16 | ),
17 | tools: Sequence[Tool | StructuredTool],
18 | ):
19 | super().__init__(embedding_model, tools)
20 |
21 | @property
22 | def tool_rag_type(self) -> str:
23 | return "simple_tool_rag"
24 |
25 | def retrieve_examples_and_tools(self, query: str, top_k: int) -> ToolRAGResult:
26 | """
27 | Returns the in-context examples as a formatted prompt and the tools that are relevant to the query
28 | It first filters the examples based on the tools that are available and then retrieves the examples
29 | and tools based on the query.
30 | """
31 | filtered_embeddings = self._load_filtered_embeddings()
32 | retrieved_embeddings = self._retrieve_top_k_embeddings(
33 | query, filtered_embeddings, top_k
34 | )
35 | in_context_examples_prompt = BaseToolRAG._get_in_context_examples_prompt(
36 | retrieved_embeddings
37 | )
38 |
39 | tools_names = set(
40 | sum(
41 | [
42 | [TinyAgentToolName(tool) for tool in example["tools"]]
43 | for example in retrieved_embeddings
44 | ],
45 | [],
46 | )
47 | )
48 |
49 | return ToolRAGResult(
50 | in_context_examples_prompt=in_context_examples_prompt,
51 | retrieved_tools_set=tools_names,
52 | )
53 |
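54 | # Example usage (a hypothetical sketch; assumes an embedding model and tool list):
55 | # rag = SimpleToolRAG(embedding_model, tools)
56 | # result = rag.retrieve_examples_and_tools("Directions to the office", top_k=3)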
--------------------------------------------------------------------------------
/src/tiny_agent/tool_rag/text-embedding-3-small/embeddings.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SqueezeAILab/TinyAgent/cc45c0e842f5d163c3df1c8f41d60e90e005867d/src/tiny_agent/tool_rag/text-embedding-3-small/embeddings.pkl
--------------------------------------------------------------------------------
/src/tiny_agent/tools/calendar.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import platform
3 | import subprocess
4 |
5 | from src.tiny_agent.run_apple_script import run_applescript, run_applescript_capture
6 |
7 |
8 | class Calendar:
9 | def __init__(self):
10 | self.calendar_app = "Calendar"
11 |
12 | def create_event(
13 | self,
14 | title: str,
15 | start_date: datetime.datetime,
16 | end_date: datetime.datetime,
17 | location: str = "",
18 | invitees: list[str] = [],
19 | notes: str = "",
20 | calendar: str | None = None,
21 | ) -> str:
22 | """
23 | Creates a new event with the given title, start date, end date, location, and notes.
24 | """
25 | if platform.system() != "Darwin":
26 | return "This method is only supported on MacOS"
27 |
28 | applescript_start_date = start_date.strftime("%B %d, %Y %I:%M:%S %p")
29 | applescript_end_date = end_date.strftime("%B %d, %Y %I:%M:%S %p")
30 |
31 | # Check if the given calendar parameter is valid
32 | if calendar is not None:
33 | script = f"""
34 | tell application "{self.calendar_app}"
35 | set calendarExists to name of calendars contains "{calendar}"
36 | end tell
37 | """
38 | exists = run_applescript(script).strip() == "true"
39 | if not exists:
40 | calendar = self._get_first_calendar()
41 | if calendar is None:
42 | return f"Can't find the calendar named {calendar}. Please try again and specify a valid calendar name."
43 |
44 | # If it is not provided, default to the first calendar
45 | elif calendar is None:
46 | calendar = self._get_first_calendar()
47 | if calendar is None:
48 | return "Can't find a default calendar. Please try again and specify a calendar name."
49 |
50 | invitees_script = []
51 | for invitee in invitees:
52 | invitees_script.append(
53 | f"""
54 | make new attendee at theEvent with properties {{email:"{invitee}"}}
55 | """
56 | )
57 | invitees_script = "".join(invitees_script)
58 |
59 | script = f"""
60 | tell application "System Events"
61 | set calendarIsRunning to (name of processes) contains "{self.calendar_app}"
62 | if calendarIsRunning then
63 | tell application "{self.calendar_app}" to activate
64 | else
65 | tell application "{self.calendar_app}" to launch
66 | delay 1
67 | tell application "{self.calendar_app}" to activate
68 | end if
69 | end tell
70 | tell application "{self.calendar_app}"
71 | tell calendar "{calendar}"
72 | set startDate to date "{applescript_start_date}"
73 | set endDate to date "{applescript_end_date}"
74 | set theEvent to make new event at end with properties {{summary:"{title}", start date:startDate, end date:endDate, location:"{location}", description:"{notes}"}}
75 | {invitees_script}
76 | switch view to day view
77 | show theEvent
78 | end tell
79 | tell application "{self.calendar_app}" to reload calendars
80 | end tell
81 | """
82 |
83 | try:
84 | run_applescript(script)
85 | return f"""Event created successfully in the "{calendar}" calendar."""
86 | except subprocess.CalledProcessError as e:
87 | return str(e)
88 |
89 | def _get_first_calendar(self) -> str | None:
90 | script = f"""
91 | tell application "System Events"
92 | set calendarIsRunning to (name of processes) contains "{self.calendar_app}"
93 | if calendarIsRunning is false then
94 | tell application "{self.calendar_app}" to launch
95 | delay 1
96 | end if
97 | end tell
98 | tell application "{self.calendar_app}"
99 | set firstCalendarName to name of first calendar
100 | end tell
101 | return firstCalendarName
102 | """
103 | stdout, _ = run_applescript_capture(script)
104 | if stdout:
105 | return stdout.strip()
106 | else:
107 | return None
108 |
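109 | # Example usage (a hypothetical sketch; macOS only, opens Calendar.app):
110 | # start = datetime.datetime(2024, 5, 1, 14, 0)
111 | # Calendar().create_event("Sync", start, start + datetime.timedelta(hours=1))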
--------------------------------------------------------------------------------
/src/tiny_agent/tools/contacts.py:
--------------------------------------------------------------------------------
1 | import platform
2 |
3 | from src.tiny_agent.run_apple_script import run_applescript_capture
4 |
5 |
6 | class Contacts:
7 | def __init__(self):
8 | pass
9 |
10 | def get_phone_number(self, contact_name: str) -> str:
11 | """
12 | Returns the phone number of a contact by name.
13 | """
14 | if platform.system() != "Darwin":
15 | return "This method is only supported on MacOS"
16 |
17 | script = f"""
18 | tell application "Contacts"
19 | set thePerson to first person whose name is "{contact_name}"
20 | set theNumber to value of first phone of thePerson
21 | return theNumber
22 | end tell
23 | """
24 | stdout, stderr = run_applescript_capture(script)
25 | # If the person is not found, we will try to find similar contacts
26 | if "Can’t get person" in stderr:
27 | first_name = contact_name.split(" ")[0]
28 | names = self.get_full_names_from_first_name(first_name)
29 | if "No contacts found" in names or len(names) == 0:
30 | return "No contacts found"
31 | else:
32 | # Just find the first person
33 | return self.get_phone_number(names[0])
34 | else:
35 | return stdout.replace("\n", "")
36 |
37 | def get_email_address(self, contact_name: str) -> str:
38 | """
39 | Returns the email address of a contact by name.
40 | """
41 | if platform.system() != "Darwin":
42 | return "This method is only supported on MacOS"
43 |
44 | script = f"""
45 | tell application "Contacts"
46 | set thePerson to first person whose name is "{contact_name}"
47 | set theEmail to value of first email of thePerson
48 | return theEmail
49 | end tell
50 | """
51 | stdout, stderr = run_applescript_capture(script)
52 | # If the person is not found, we will try to find similar contacts
53 | if "Can’t get person" in stderr:
54 | first_name = contact_name.split(" ")[0]
55 | names = self.get_full_names_from_first_name(first_name)
56 | if "No contacts found" in names or len(names) == 0:
57 | return "No contacts found"
58 | else:
59 | # Just find the first person
60 | return self.get_email_address(names[0])
61 | else:
62 | return stdout.replace("\n", "")
63 |
64 | def get_full_names_from_first_name(self, first_name: str) -> list[str] | str:
65 | """
66 | Returns a list of full names of contacts that contain the first name provided.
67 | """
68 | if platform.system() != "Darwin":
69 | return "This method is only supported on MacOS"
70 |
71 | script = f"""
72 | tell application "Contacts"
73 | set matchingPeople to every person whose first name contains "{first_name}"
74 | set namesList to {{}}
75 | repeat with aPerson in matchingPeople
76 | set end of namesList to name of aPerson
77 | end repeat
78 | return namesList
79 | end tell
80 | """
81 | names, _ = run_applescript_capture(script)
82 | names = names.strip()
83 | if len(names) > 0:
84 | # Turn name into a list of strings
85 | names = list(map(lambda n: n.strip(), names.split(",")))
86 | return names
87 | else:
88 | return "No contacts found."
89 |
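90 | # Example usage (a hypothetical sketch; macOS only, queries Contacts.app):
91 | # phone = Contacts().get_phone_number("Alice Smith")
92 | # email = Contacts().get_email_address("Alice Smith")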
--------------------------------------------------------------------------------
/src/tiny_agent/tools/mail.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import subprocess
3 |
4 | from src.tiny_agent.run_apple_script import run_applescript, run_applescript_capture
5 |
6 |
7 | class Mail:
8 | def __init__(self) -> None:
9 | self.mail_app: str = "Mail"
10 |
11 | def compose_email(
12 | self,
13 | recipients: list[str],
14 | subject: str,
15 | content: str,
16 | attachments: list[str],
17 | cc: list[str],
18 | ) -> str:
19 | """
20 | Composes a new email with the given recipients, subject, content, and attaches files from the given paths.
21 | Adds cc recipients if provided. Does not send it but opens the composed email to the user.
22 | """
23 | if platform.system() != "Darwin":
24 | return "This method is only supported on MacOS"
25 |
26 | # Format recipients and cc recipients for AppleScript list
27 | recipients_list = Mail._format_email_addresses(recipients)
28 | cc_list = Mail._format_email_addresses(cc)
29 | attachments_str = Mail._format_attachments(attachments)
30 |
31 | content = content.replace('"', '\\"').replace("'", "’")
32 | script = f"""
33 | tell application "{self.mail_app}"
34 | set newMessage to make new outgoing message with properties {{subject:"{subject}", content:"{content}" & return & return}}
35 | tell newMessage
36 | repeat with address in {recipients_list}
37 | make new to recipient at end of to recipients with properties {{address:address}}
38 | end repeat
39 | repeat with address in {cc_list}
40 | make new cc recipient at end of cc recipients with properties {{address:address}}
41 | end repeat
42 | {attachments_str}
43 | end tell
44 | activate
45 | end tell
46 | """
47 |
48 | try:
49 | run_applescript(script)
50 | return "New email composed successfully with attachments and cc."
51 | except subprocess.CalledProcessError as e:
52 | return str(e)
53 |
54 | def reply_to_email(
55 | self, content: str, cc: list[str], attachments: list[str]
56 | ) -> str:
57 | """
58 | Replies to the currently selected email in Mail with the given content.
59 | """
60 | if platform.system() != "Darwin":
61 | return "This method is only supported on MacOS"
62 |
63 | cc_list = Mail._format_email_addresses(cc)
64 | attachments_str = Mail._format_attachments(attachments)
65 |
66 | content = content.replace('"', '\\"').replace("'", "’")
67 | script = f"""
68 | tell application "{self.mail_app}"
69 | activate
70 | set selectedMessages to selected messages of message viewer 1
71 | if (count of selectedMessages) < 1 then
72 | return "No message selected."
73 | else
74 | set theMessage to item 1 of selectedMessages
75 | set theReply to reply theMessage opening window yes
76 | tell theReply
77 | repeat with address in {cc_list}
78 | make new cc recipient at end of cc recipients with properties {{address:address}}
79 | end repeat
80 | set content to "{content}"
81 | {attachments_str}
82 | end tell
83 | end if
84 | end tell
85 | """
86 |
87 | try:
88 | run_applescript(script)
89 | return "Replied to the selected email successfully."
90 | except subprocess.CalledProcessError as e:
91 | return "An email has to be viewed in Mail to reply to it."
92 |
93 | def forward_email(
94 | self, recipients: list[str], cc: list[str], attachments: list[str]
95 | ) -> str:
96 | """
97 | Forwards the currently selected email in Mail to the given recipients with the given content.
98 | """
99 | if platform.system() != "Darwin":
100 | return "This method is only supported on MacOS"
101 |
102 | # Format recipients and cc recipients for AppleScript list
103 | recipients_list = Mail._format_email_addresses(recipients)
104 | cc_list = Mail._format_email_addresses(cc)
105 | attachments_str = Mail._format_attachments(attachments)
106 |
107 | script = f"""
108 | tell application "{self.mail_app}"
109 | activate
110 | set selectedMessages to selected messages of message viewer 1
111 | if (count of selectedMessages) < 1 then
112 | return "No message selected."
113 | else
114 | set theMessage to item 1 of selectedMessages
115 | set theForward to forward theMessage opening window yes
116 | tell theForward
117 | repeat with address in {recipients_list}
118 | make new to recipient at end of to recipients with properties {{address:address}}
119 | end repeat
120 | repeat with address in {cc_list}
121 | make new cc recipient at end of cc recipients with properties {{address:address}}
122 | end repeat
123 | {attachments_str}
124 | end tell
125 | end if
126 | end tell
127 | """
128 |
129 | try:
130 | run_applescript(script)
131 | return "Forwarded the selected email successfully."
132 | except subprocess.CalledProcessError as e:
133 | return "An email has to be viewed in Mail to forward it."
134 |
135 | def get_email_content(self) -> str:
136 | """
137 | Gets the content of the currently viewed email in Mail.
138 | """
139 | if platform.system() != "Darwin":
140 | return "This method is only supported on MacOS"
141 |
142 | script = f"""
143 | tell application "{self.mail_app}"
144 | activate
145 | set selectedMessages to selected messages of message viewer 1
146 | if (count of selectedMessages) < 1 then
147 | return "No message selected."
148 | else
149 | set theMessage to item 1 of selectedMessages
150 | -- Get the content of the message
151 | set theContent to content of theMessage
152 | return theContent
153 | end if
154 | end tell
155 | """
156 |
157 | try:
158 | return run_applescript(script)
159 | except subprocess.CalledProcessError as e:
160 | return "No message selected or found."
161 |
162 | def find_and_select_first_email_from(self, sender: str) -> str:
163 | """
164 | Finds and selects an email in Mail based on the sender's name.
165 | """
166 | if platform.system() != "Darwin":
167 | return "This method is only supported on MacOS"
168 |
169 | script = f"""
170 | tell application "{self.mail_app}"
171 | set theSender to "{sender}"
172 | set theMessage to first message of inbox whose sender contains theSender
173 | set selected messages of message viewer 1 to {{theMessage}}
174 | activate
175 | open theMessage
176 | end tell
177 | """
178 |
179 | try:
180 | run_applescript(script)
181 | return "Found and selected the email successfully."
182 | except subprocess.CalledProcessError as e:
183 | return "No message found from the sender."
184 |
185 | @staticmethod
186 | def _format_email_addresses(emails: list[str]) -> str:
187 | return "{" + ", ".join([f'"{email}"' for email in emails]) + "}"
188 |
189 | @staticmethod
190 | def _format_attachments(attachments: list[str]) -> str:
191 | attachments_str = []
192 | for attachment in attachments:
193 | attachment = attachment.replace('"', '\\"')
194 | attachments_str.append(
195 | f"""
196 | make new attachment with properties {{file name:"{attachment}"}} at after the last paragraph
197 | """
198 | )
199 | return "".join(attachments_str)
200 |
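201 | # Example usage (a hypothetical sketch; macOS only, opens a draft in Mail.app
202 | # without sending it):
203 | # Mail().compose_email(["alice@example.com"], "Hello", "Hi Alice!", [], [])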
--------------------------------------------------------------------------------
/src/tiny_agent/tools/maps.py:
--------------------------------------------------------------------------------
1 | import webbrowser
2 | from urllib.parse import quote_plus
3 |
4 | from src.tiny_agent.models import TransportationOptions
5 |
6 |
7 | class Maps:
8 | def __init__(self):
9 | pass
10 |
11 | def open_location(self, query: str) -> str:
12 | """
13 | Opens the specified location in Apple Maps.
14 | The query can be a place name, address, or coordinates.
15 | """
16 | base_url = "https://maps.apple.com/?q="
17 | query_encoded = quote_plus(query)
18 | full_url = base_url + query_encoded
19 | webbrowser.open(full_url)
20 | return f"Location of {query} in Apple Maps: {full_url}"
21 |
22 | def show_directions(
23 | self,
24 | end: str,
25 | start: str = "",
26 | transport: TransportationOptions = TransportationOptions.DRIVING,
27 | ) -> str:
28 | """
29 | Shows directions from a start location to an end location in Apple Maps.
30 | The transport parameter defaults to 'd' (driving), but can also be 'w' (walking) or 'r' (public transit).
31 | The start location can be left empty to default to the current location of the device.
32 | """
33 | base_url = "https://maps.apple.com/?"
34 | if len(start) > 0:
35 | start_encoded = quote_plus(start)
36 | start_param = f"saddr={start_encoded}&"
37 | else:
38 | start_param = "" # Use the current location
39 | end_encoded = quote_plus(end)
40 | transport_flag = f"dirflg={transport.value}"
41 | full_url = f"{base_url}{start_param}daddr={end_encoded}&{transport_flag}"
42 | webbrowser.open(full_url)
43 | return f"Directions to {end} in Apple Maps: {full_url}"
44 |
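45 | # Example usage (a hypothetical sketch; opens Apple Maps via the default browser):
46 | # Maps().show_directions("San Francisco", transport=TransportationOptions.DRIVING)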
--------------------------------------------------------------------------------
/src/tiny_agent/tools/notes.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | import platform
3 | import subprocess
4 |
5 | from bs4 import BeautifulSoup
6 |
7 | from src.tiny_agent.run_apple_script import run_applescript, run_applescript_capture
8 |
9 |
10 | class Notes:
11 | _DEFAULT_FOLDER = "Notes"
12 |
13 | def __init__(self):
14 | self.notes_app = "Notes"
15 |
16 | def create_note(self, name: str, content: str, folder: str | None = None) -> str:
17 | """
18 | Creates a new note with the given content and focuses on it. If a folder is specified, the note
19 | is created in that folder; otherwise, it's created in the default folder.
20 | """
21 | if platform.system() != "Darwin":
22 | return "This method is only supported on MacOS"
23 |
24 | folder_line = self._get_folder_line(folder)
25 | html_content = content.replace('"', '\\"').replace("'", "’")
26 |
27 | script = f"""
28 | tell application "{self.notes_app}"
29 | tell account "iCloud"
30 | {folder_line}
31 | set newNote to make new note with properties {{body:"{html_content}"}}
32 | end tell
33 | end tell
34 | activate
35 | tell application "System Events"
36 | tell process "Notes"
37 | set frontmost to true
38 | delay 0.5 -- wait a bit for the note to be created and focus to be set
39 | end tell
40 | end tell
41 | tell application "{self.notes_app}"
42 | show newNote
43 | end tell
44 | end tell
45 | """
46 |
47 | try:
48 | run_applescript(script)
49 | return "Note created and focused successfully."
50 | except subprocess.CalledProcessError as e:
51 | return str(e)
52 |
53 | def open_note(
54 | self,
55 | name: str,
56 | folder: str | None = None,
57 | return_content: bool = False,
58 | ) -> str:
59 | """
60 | Opens an existing note by its name and optionally returns its content.
61 | If no exact match is found, attempts fuzzy matching to suggest possible notes.
62 | If return_content is True, returns the content of the note.
63 | """
64 | if platform.system() != "Darwin":
65 | return "This method is only supported on MacOS"
66 |
67 | folder_line = self._get_folder_line(folder)
68 |
69 | # Adjust the script to return content if required
70 | content_line = (
71 | "return body of theNote"
72 | if return_content
73 | else 'return "Note opened successfully."'
74 | )
75 |
76 | # Attempt to directly open the note with the exact name and optionally return its content
77 | script_direct_open = f"""
78 | tell application "{self.notes_app}"
79 | tell account "iCloud"
80 | {folder_line}
81 | set matchingNotes to notes whose name is "{name}"
82 | if length of matchingNotes > 0 then
83 | set theNote to item 1 of matchingNotes
84 | show theNote
85 | {content_line}
86 | else
87 | return "No exact match found."
88 | end if
89 | end tell
90 | end tell
91 | end tell
92 | """
93 |
94 | try:
95 | stdout, _ = run_applescript_capture(script_direct_open)
96 | if (
97 | "Note opened successfully" in stdout
98 | or "No exact match found" not in stdout
99 | ):
100 | if return_content:
101 | return self._convert_note_to_text(stdout.strip())
102 | return stdout.strip() # Successfully opened a note with the exact name
103 |
104 | # If no exact match is found, proceed with fuzzy matching
105 | note_to_open = self._do_fuzzy_matching(name)
106 |
107 | # Focus the note with the closest matching name after fuzzy matching
108 | script_focus = f"""
109 | tell application "{self.notes_app}"
110 | tell account "iCloud"
111 | {folder_line}
112 | set theNote to first note whose name is "{note_to_open}"
113 | show theNote
114 | {content_line}
115 | end tell
116 | end tell
117 | activate
118 | end tell
119 | """
120 | result = run_applescript(script_focus)
121 | if return_content:
122 | return self._convert_note_to_text(result.strip())
123 | return result.strip()
124 | except subprocess.CalledProcessError as e:
125 | return f"Error: {str(e)}"
126 |
127 | def append_to_note(
128 | self, name: str, append_content: str, folder: str | None = None
129 | ) -> str:
130 | """
131 | Appends content to an existing note by its name. If the exact name is not found,
132 | attempts fuzzy matching to find the closest note.
133 | """
134 | if platform.system() != "Darwin":
135 | return "This method is only supported on MacOS"
136 |
137 | folder_line = self._get_folder_line(folder)
138 |
139 | # Try to directly find and append to the note with the exact name
140 | script_find_note = f"""
141 | tell application "{self.notes_app}"
142 | tell account "iCloud"
143 | {folder_line}
144 | set matchingNotes to notes whose name is "{name}"
145 | if length of matchingNotes > 0 then
146 | set theNote to item 1 of matchingNotes
147 | return name of theNote
148 | else
149 | return "No exact match found."
150 | end if
151 | end tell
152 | end tell
153 | end tell
154 | """
155 |
156 | try:
157 | note_name, _ = run_applescript_capture(
158 | script_find_note  # already an f-string; no extra .format() needed
159 | )
160 | note_name = note_name.strip()
161 |
162 | if "No exact match found" in note_name or not note_name:
163 | note_name = self._do_fuzzy_matching(name)
164 | if note_name == "No notes found after fuzzy matching.":
165 | return "No notes found after fuzzy matching."
166 |
167 | # If an exact match is found, append content to the note
168 | html_append_content = append_content.replace('"', '\\"').replace("'", "’")
169 | script_append = f"""
170 | tell application "{self.notes_app}"
171 | tell account "iCloud"
172 | {folder_line}
173 | set theNote to first note whose name is "{note_name}"
174 | set body of theNote to (body of theNote) & "
{html_append_content}"
175 | show theNote
176 | end tell
177 | end tell
178 | end tell
179 | """
180 |
181 | run_applescript(script_append)
182 | return f"Content appended to note '{name}' successfully."
183 | except subprocess.CalledProcessError as e:
184 | return f"Error: {str(e)}"
185 |
186 | def _get_folder_line(self, folder: str | None) -> str:
187 | if folder is not None and len(folder) > 0 and self._check_folder_exists(folder):
188 | return f'tell folder "{folder}"'
189 | return f'tell folder "{Notes._DEFAULT_FOLDER}"'
190 |
191 | def _do_fuzzy_matching(self, name: str) -> str:
192 | script_search = f"""
193 | tell application "{self.notes_app}"
194 | tell account "iCloud"
195 | set noteList to name of every note
196 | end tell
197 | end tell
198 | """
199 | note_names_str, _ = run_applescript_capture(script_search)
200 | note_names = note_names_str.split(", ")
201 | closest_matches = difflib.get_close_matches(name, note_names, n=1, cutoff=0.0)
202 | if not closest_matches:
203 | return "No notes found after fuzzy matching."
204 |
205 | note_to_open = closest_matches[0]
206 | return note_to_open
207 |
208 | def _check_folder_exists(self, folder: str) -> bool:
209 | # Adjusted script to optionally look for a folder
210 | folder_check_script = f"""
211 | tell application "{self.notes_app}"
212 | set folderExists to false
213 | set folderName to "{folder}"
214 | if folderName is not "" then
215 | repeat with eachFolder in folders
216 | if name of eachFolder is folderName then
217 | set folderExists to true
218 | exit repeat
219 | end if
220 | end repeat
221 | end if
222 | return folderExists
223 | end tell
224 | """
225 |
226 | folder_exists, _ = run_applescript_capture(folder_check_script)
227 | folder_exists = folder_exists.strip() == "true"
228 |
229 | return folder_exists
230 |
231 | @staticmethod
232 | def _convert_note_to_text(note_html: str) -> str:
233 | """
234 | Converts an HTML note content to plain text.
235 | """
236 | soup = BeautifulSoup(note_html, "html.parser")
237 | return soup.get_text().strip()
238 |
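239 | # Example usage (a hypothetical sketch; macOS only, requires an iCloud account
240 | # in Notes.app):
241 | # Notes().create_note("Groceries", "<div>milk, eggs</div>")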
--------------------------------------------------------------------------------
/src/tiny_agent/tools/reminders.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import platform
3 | import subprocess
4 |
5 | from src.tiny_agent.run_apple_script import run_applescript
6 |
7 |
8 | class Reminders:
9 | def __init__(self):
10 | self.reminders_app = "Reminders"
11 |
12 | def create_reminder(
13 | self,
14 | name: str,
15 | due_date: datetime.datetime | None = None,
16 | notes: str = "",
17 | list_name: str = "",
18 | priority: int = 0,
19 | all_day: bool = False,
20 | ) -> str:
21 | if platform.system() != "Darwin":
22 | return "This method is only supported on MacOS"
23 |
24 | if due_date is not None:
25 | date_format = "%B %d, %Y" if all_day else "%B %d, %Y %I:%M:%S %p"
26 | due_date_script = f', due date:date "{due_date.strftime(date_format)}"'
27 | # (All-day reminders omit the time component.)
28 | else:
29 | due_date_script = ""
30 |
31 | notes = notes.replace('"', '\\"').replace("'", "’")
32 | script = f"""
33 | tell application "{self.reminders_app}"
34 | set listExists to false
35 | set listName to "{list_name}"
36 | if listName is not "" then
37 | repeat with eachList in lists
38 | if name of eachList is listName then
39 | set listExists to true
40 | exit repeat
41 | end if
42 | end repeat
43 | end if
44 | if listExists then
45 | tell list "{list_name}"
46 | set newReminder to make new reminder with properties {{name:"{name}", body:"{notes}", priority:{priority}{due_date_script}}}
47 | activate
48 | show newReminder
49 | end tell
50 | else
51 | set newReminder to make new reminder with properties {{name:"{name}", body:"{notes}", priority:{priority}{due_date_script}}}
52 | activate
53 | show newReminder
54 | end if
55 | end tell
56 | """
57 |
58 | try:
59 | run_applescript(script)
60 | return f"Reminder '{name}' created successfully in the '{list_name}' list."
61 | except subprocess.CalledProcessError as e:
62 | return f"Failed to create reminder: {str(e)}"
63 |
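64 | # Example usage (a hypothetical sketch; macOS only, opens Reminders.app):
65 | # due = datetime.datetime(2024, 5, 1, 9, 0)
66 | # Reminders().create_reminder("Pay rent", due_date=due, priority=1)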
--------------------------------------------------------------------------------
/src/tiny_agent/tools/sms.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import subprocess
3 |
4 | from src.tiny_agent.run_apple_script import run_applescript
5 |
6 |
7 | class SMS:
8 | def __init__(self):
9 | self.messages_app = "Messages"
10 |
11 | def send(self, to: list[str], message: str) -> str:
12 | """
13 | Opens an SMS draft to the specified recipient using the Messages app,
14 | without sending it, by simulating keystrokes.
15 | """
16 | if platform.system() != "Darwin":
17 | return "This method is only supported on MacOS"
18 |
19 | to_script = []
20 | for recipient in to:
21 | recipient = recipient.replace("\n", "")
22 | to_script.append(
23 | f"""
24 | keystroke "{recipient}"
25 | delay 0.5
26 | keystroke return
27 | delay 0.5
28 | """
29 | )
30 | to_script = "".join(to_script)
31 |
32 | escaped_message = message.replace('"', '\\"').replace("'", "’")
33 |
34 | script = f"""
35 | tell application "System Events"
36 | tell application "{self.messages_app}"
37 | activate
38 | end tell
39 | tell process "{self.messages_app}"
40 | set frontmost to true
41 | delay 0.5
42 | keystroke "n" using command down
43 | delay 0.5
44 | {to_script}
45 | keystroke tab
46 | delay 0.5
47 | keystroke "{escaped_message}"
48 | end tell
49 | end tell
50 | """
51 | try:
52 | run_applescript(script)
53 | return "SMS draft composed"
54 | except subprocess.CalledProcessError as e:
55 | return f"An error occurred while composing the SMS: {str(e)}"
56 |
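57 | # Example usage (a hypothetical sketch; macOS only, drives Messages.app via
58 | # simulated keystrokes, so the window must stay focused):
59 | # SMS().send(["+1 555 010 0000"], "Running 10 minutes late")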
--------------------------------------------------------------------------------
/src/tiny_agent/tools/spotlight_search.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | import os # Import os to check if the path exists
3 | import platform
4 | import subprocess
5 |
6 | from src.tiny_agent.run_apple_script import run_command
7 |
8 |
9 | class SpotlightSearch:
10 | def __init__(self):
11 | pass
12 |
13 | def open(self, name_or_path: str) -> str:
14 | """
15 | Performs a Spotlight search and opens the first item that matches the name.
16 | If there is no exact match, falls back to fuzzy matching.
17 | Additionally, if the input is a path to an existing file, opens it directly.
18 | """
19 | if platform.system() != "Darwin":
20 | return "This method is only supported on MacOS"
21 |
22 | # Check if input is a path and file exists
23 | if name_or_path.startswith("/") and os.path.exists(name_or_path):
24 | try:
25 | subprocess.run(["open", name_or_path])
26 | return name_or_path
27 | except Exception as e:
28 | return f"Error opening file: {e}"
29 |
30 | # Use mdfind for fast searching with Spotlight
31 | command_search_exact = ["mdfind", f"kMDItemDisplayName == '{name_or_path}'"]
32 | stdout, _ = run_command(command_search_exact)
33 |
34 | if stdout:
35 | paths = stdout.strip().split("\n")
36 | path = paths[0] if paths else None
37 | if path:
38 | subprocess.run(["open", path])
39 | return path
40 |
41 | # If no exact match, perform fuzzy search on the file names
42 | command_search_general = ["mdfind", name_or_path]
43 | stdout, stderr = run_command(command_search_general)
44 |
45 | paths = stdout.strip().split("\n") if stdout else []
46 |
47 | if paths:
48 | best_match = difflib.get_close_matches(name_or_path, paths, n=1, cutoff=0.0)
49 | if best_match:
50 | _, stderr = run_command(["open", best_match[0]])
51 | if len(stderr) > 0:
52 | return f"Error: {stderr}"
53 | return best_match[0]
54 | else:
55 | return "No file found after fuzzy matching."
56 | else:
57 | return "No file found with exact or fuzzy name."
58 |
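59 | # Example usage (a hypothetical sketch; macOS only, relies on the local
60 | # Spotlight index):
61 | # path = SpotlightSearch().open("quarterly report")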
--------------------------------------------------------------------------------
/src/tiny_agent/tools/zoom.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import Sequence, TypedDict
3 | from zoneinfo import ZoneInfo
4 |
5 | import aiohttp
6 | import dateutil.parser
7 |
8 |
9 | class Zoom:
10 | _TIMEZONE = "US/Pacific"
11 | _MEETINGS_ENDPOINT = "https://api.zoom.us/v2/users/me/meetings"
12 |
13 | class ZoomMeetingInfo(TypedDict):
14 | """
15 | Zoom meeting information returned by the create meeting endpoint.
16 | https://developers.zoom.us/docs/api/rest/reference/zoom-api/methods/#operation/meetingCreate
17 | Currently only types the fields we need, so add more if needed.
18 | """
19 |
20 | join_url: str
21 |
22 | def __init__(self, access_token: str) -> None:
23 | self._access_token = access_token
24 |
25 | async def get_meeting_link(
26 | self,
27 | topic: str,
28 | start_time: str,
29 | duration: int,
30 | meeting_invitees: Sequence[str],
31 | ) -> str:
32 | """
33 | Create a new Zoom meeting and return the join URL.
34 | """
35 | start_time_utc = (
36 | dateutil.parser.parse(start_time)
37 | .replace(tzinfo=ZoneInfo(Zoom._TIMEZONE))
38 | .astimezone(datetime.timezone.utc)
39 | .strftime("%Y-%m-%dT%H:%M:%SZ")
40 | )
41 |
42 | topic = topic[:200]
43 |
44 | # Use a context-managed session so the HTTP connection is closed cleanly
45 | async with aiohttp.ClientSession() as session:
46 | resp = await session.post(
47 | Zoom._MEETINGS_ENDPOINT,
48 | headers={
49 | "Authorization": f"Bearer {self._access_token}",
50 | "Content-Type": "application/json",
51 | },
52 | json={
53 | "topic": topic,
54 | "start_time": start_time_utc,
55 | "duration": duration,
56 | "meeting_invitees": meeting_invitees,
57 | },
58 | )
59 | info: Zoom.ZoomMeetingInfo = await resp.json()
60 |
61 | return info["join_url"]
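62 |
63 | # Example usage (a hypothetical sketch; requires a valid Zoom OAuth access token):
64 | # zoom = Zoom(access_token)
65 | # link = await zoom.get_meeting_link("Standup", "May 1, 2024 09:00", 30, [])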
--------------------------------------------------------------------------------
/src/tiny_agent/transcription.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import audioop
3 | import io
4 | import json
5 | import wave
6 | from dataclasses import dataclass
7 |
8 | import httpx
9 | from openai import AsyncOpenAI
10 |
11 | from src.tiny_agent.models import TinyAgentConfig
12 |
13 |
14 | @dataclass
15 | class ResampledAudio:
16 | raw_bytes: bytes
17 | sample_rate: int
18 |
19 |
20 | class WhisperClient(abc.ABC):
21 | """An abstract class for the Whisper Servers."""
22 |
23 | def __init__(self, config: TinyAgentConfig):
24 | pass
25 |
26 | @abc.abstractmethod
27 | async def transcribe(self, file: io.BytesIO) -> str:
28 | """Transcribe the audio to text using Whisper."""
29 | pass
30 |
31 | @staticmethod
32 | @abc.abstractmethod
33 | def resample_audio(raw_bytes: bytes, sample_rate: int) -> ResampledAudio:
34 |         """Resample the audio to the target sample rate that the Whisper server expects."""
35 | pass
36 |
37 |
38 | class WhisperOpenAIClient(WhisperClient):
39 |
40 | def __init__(self, config: TinyAgentConfig):
41 | self.client = AsyncOpenAI(api_key=config.whisper_config.api_key)
42 |
43 | async def transcribe(self, file: io.BytesIO) -> str:
44 | transcript = await self.client.audio.transcriptions.create(
45 | model="whisper-1",
46 | file=file,
47 | response_format="verbose_json",
48 | language="en",
49 | )
50 | return transcript.text
51 |
52 | @staticmethod
53 | def resample_audio(raw_bytes: bytes, sample_rate: int) -> ResampledAudio:
54 |         # The OpenAI Whisper API resamples audio server-side, so this is a no-op
55 | return ResampledAudio(raw_bytes, sample_rate)
56 |
57 |
58 | class WhisperCppClient(WhisperClient):
59 | # Whisper.cpp server expects 16kHz audio
60 | _TARGET_SAMPLE_RATE = 16000
61 | _NON_DATA_FIELDS = {
62 | "temperature": "0.0",
63 | "temperature_inc": "0.2",
64 | "response_format": "json",
65 | }
66 |
67 | _base_url: str
68 |
69 | def __init__(self, config: TinyAgentConfig):
70 | self._base_url = f"http://localhost:{config.whisper_config.port}/inference"
71 |
72 | async def transcribe(self, file: io.BytesIO) -> str:
73 | # Preparing the files dictionary
74 | files = {"file": ("audio.wav", file, "audio/wav")}
75 |
76 | # Send the request to the Whisper.cpp server
77 | async with httpx.AsyncClient() as client:
78 | response = await client.post(
79 | self._base_url, files=files, data=WhisperCppClient._NON_DATA_FIELDS
80 | )
81 |
82 | data = response.json()
83 | if "text" not in data:
84 | raise ValueError(
85 | f"Whisper.cpp response does not contain 'text' key: {json.dumps(data)}"
86 | )
87 |
88 | return data["text"]
89 |
90 | @staticmethod
91 | def resample_audio(raw_bytes: bytes, sample_rate: int) -> ResampledAudio:
92 | if sample_rate == WhisperCppClient._TARGET_SAMPLE_RATE:
93 | return ResampledAudio(raw_bytes, sample_rate)
94 |
95 |         # Resample to 16 kHz; ratecv args: 2-byte (16-bit) samples, 1 channel
96 | converted, _ = audioop.ratecv(
97 | raw_bytes, 2, 1, sample_rate, WhisperCppClient._TARGET_SAMPLE_RATE, None
98 | )
99 | return ResampledAudio(converted, WhisperCppClient._TARGET_SAMPLE_RATE)
100 |
101 |
102 | class TranscriptionService:
103 | """
104 | This is the main service that deals with transcribing raw audio data to text using Whisper.
105 | """
106 |
107 | _client: WhisperClient
108 |
109 | def __init__(self, client: WhisperClient):
110 | self._client = client
111 |
112 | async def transcribe(self, raw_bytes: bytes, sample_rate: int) -> str:
113 | """
114 | This method transcribes the audio data to text using Whisper.
115 | """
116 | resampled_audio = self._client.resample_audio(raw_bytes, sample_rate)
117 | raw_bytes, sample_rate = resampled_audio.raw_bytes, resampled_audio.sample_rate
118 |
119 | with io.BytesIO() as memfile:
120 | # Open a new WAV file in write mode using the in-memory stream
121 | with wave.open(memfile, "wb") as wav_file:
122 | # Set the parameters for the WAV file
123 | wav_file.setnchannels(1)
124 | wav_file.setsampwidth(2) # 16-bit PCM, so 2 bytes per sample
125 | wav_file.setframerate(sample_rate)
126 |
127 | # Write the PCM data to the WAV file
128 | wav_file.writeframes(raw_bytes)
129 |             # Name the in-memory file so clients can infer the WAV format,
130 |             # then rewind it so the upload starts at the header
131 |             memfile.name = "audio.wav"
132 |             memfile.seek(0)
133 |             transcript = await self._client.transcribe(memfile)
134 | return transcript.strip()
135 |
--------------------------------------------------------------------------------
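Usage note: a sketch of wiring the classes above together with the Whisper.cpp backend; it assumes a whisper.cpp server is already listening on the configured port, `config` is a populated TinyAgentConfig, and `raw_pcm` holds 16-bit mono PCM samples (`transcribe_pcm` is a hypothetical helper):

import asyncio

from src.tiny_agent.transcription import TranscriptionService, WhisperCppClient

async def transcribe_pcm(config, raw_pcm: bytes, sample_rate: int) -> str:
    client = WhisperCppClient(config)  # resamples to 16 kHz when needed
    service = TranscriptionService(client)
    return await service.transcribe(raw_pcm, sample_rate)

# e.g. print(asyncio.run(transcribe_pcm(config, raw_pcm, 44100)))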
/src/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import uuid
4 | from dataclasses import dataclass, field
5 | from enum import Enum
6 | from typing import Any, Literal
7 |
8 | from src.llm_compiler.constants import JOINNER_FINISH, JOINNER_REPLAN
9 |
10 |
11 | class PlanStepToolName(Enum):
12 | GET_PHONE_NUMBER = "get_phone_number"
13 | GET_EMAIL_ADDRESS = "get_email_address"
14 | CREATE_CALENDAR_EVENT = "create_calendar_event"
15 | OPEN_AND_GET_FILE_PATH = "open_and_get_file_path"
16 | SUMMARIZE_PDF = "summarize_pdf"
17 | COMPOSE_NEW_EMAIL = "compose_new_email"
18 | REPLY_TO_EMAIL = "reply_to_email"
19 | FORWARD_EMAIL = "forward_email"
20 | MAPS_OPEN_LOCATION = "maps_open_location"
21 | MAPS_SHOW_DIRECTIONS = "maps_show_directions"
22 | CREATE_NOTE = "create_note"
23 | OPEN_NOTE = "open_note"
24 | APPEND_NOTE_CONTENT = "append_note_content"
25 | CREATE_REMINDER = "create_reminder"
26 | SEND_SMS = "send_sms"
27 | GET_ZOOM_MEETING_LINK = "get_zoom_meeting_link"
28 | JOIN = "join"
29 |
30 |
31 | class DataPointType(Enum):
32 | PLAN = "plan"
33 | JOIN = "join"
34 | REPLAN = "replan"
35 |
36 |
37 | @dataclass
38 | class PlanStep:
39 | tool_name: PlanStepToolName
40 | tool_args: dict[str, Any]
41 |
42 | def serialize(self) -> dict[str, Any]:
43 | return {
44 | "tool_name": (
45 | self.tool_name.value
46 | if isinstance(self.tool_name, PlanStepToolName)
47 | else "join"
48 | ),
49 | "tool_args": self.tool_args,
50 | }
51 |
52 |
53 | @dataclass
54 | class JoinStep:
55 | thought: str
56 | action: Literal[JOINNER_FINISH, JOINNER_REPLAN] # type: ignore
57 | # Message is only applicable if the action is "Finish"
58 | message: str
59 |
60 | def serialize(self) -> dict[str, Any]:
61 | return {
62 | "thought": self.thought,
63 | "action": self.action,
64 | "message": self.message,
65 | }
66 |
67 |
68 | @dataclass
69 | class DataPoint:
70 | type: DataPointType
71 | raw_input: str
72 | raw_output: str
73 | parsed_output: list[PlanStep] | JoinStep
74 | id: uuid.UUID
75 | index: int
76 |
77 | def serialize(self) -> dict[str, Any]:
78 | return {
79 | "type": self.type.value,
80 | "raw_input": self.raw_input,
81 | "raw_output": self.raw_output,
82 | "parsed_output": (
83 | [step.serialize() for step in self.parsed_output]
84 | if isinstance(self.parsed_output, list)
85 | else self.parsed_output.serialize()
86 | ),
87 | "id": str(self.id),
88 | "index": self.index,
89 | }
90 |
91 |
92 | @dataclass
93 | class Data:
94 | input: str
95 | output: list[DataPoint]
96 | closest_5_queries: list[str] = field(default_factory=list)
97 |
98 | def serialize(self) -> dict[str, Any]:
99 | return {
100 | "input": self.input,
101 | "output": [point.serialize() for point in self.output],
102 | "closest_5_queries": self.closest_5_queries,
103 | }
104 |
105 |
106 | def deserialize_data(data_json: dict[str, Any]) -> dict[str, Data]:
107 | data_objects = {}
108 |
109 | for key, value in data_json.items():
110 | # Assuming 'input' and 'output' are directly accessible within `value`
111 | input_str = value["input"]
112 | output_list = value["output"]
113 |
114 | data_points = []
115 | for output in output_list:
116 | parsed_output = output["parsed_output"]
117 | data_point = DataPoint(
118 | type=DataPointType(output["type"]),
119 | raw_input=output["raw_input"],
120 | raw_output=output["raw_output"],
121 | parsed_output=(
122 | (
123 | [
124 | PlanStep(
125 | tool_name=(PlanStepToolName(step["tool_name"])),
126 | tool_args=step["tool_args"],
127 | )
128 | for step in parsed_output
129 | ]
130 | if isinstance(parsed_output, list)
131 | else JoinStep(
132 | thought=parsed_output["thought"],
133 | action=parsed_output["action"],
134 | message=parsed_output["message"],
135 | )
136 | )
137 | ),
138 | id=uuid.UUID(output["id"]),
139 | index=output["index"],
140 | )
141 |
142 | data_points.append(data_point)
143 |
144 | data_objects[key] = Data(
145 | input=input_str,
146 | output=data_points,
147 | )
148 |
149 | if "closest_5_queries" in value:
150 | data_objects[key].closest_5_queries = value["closest_5_queries"]
151 |
152 | return data_objects
153 |
154 |
155 | def save_data(data_objects: dict[str, Data], json_path: str) -> None:
156 | data_json = {}
157 |
158 | for key, data in data_objects.items():
159 | data_json[key] = data.serialize()
160 |
161 | with open(json_path, "w") as file:
162 | json.dump(data_json, file, indent=4)
163 |
164 |
165 | def initialize_data_objects(json_path: str) -> dict[str, Data]:
166 | if not os.path.exists(json_path):
167 | with open(json_path, "w") as file:
168 | file.write("{}")
169 | data_objects = {}
170 | else:
171 | with open(json_path, "r") as file:
172 | data_objects = json.load(file)
173 |
174 | data_objects = deserialize_data(data_objects)
175 | return data_objects
176 |
--------------------------------------------------------------------------------
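Usage note: a round-trip sketch of the helpers above: build one plan-type data point, save it, and read it back (the file path and query text are illustrative):

import uuid

from src.utils.data_utils import (
    Data,
    DataPoint,
    DataPointType,
    PlanStep,
    PlanStepToolName,
    initialize_data_objects,
    save_data,
)

step = PlanStep(
    tool_name=PlanStepToolName.CREATE_NOTE,
    tool_args={"name": "Groceries", "content": "milk"},
)
point = DataPoint(
    type=DataPointType.PLAN,
    raw_input="Create a note called Groceries",
    raw_output='1. create_note("Groceries", "milk")',
    parsed_output=[step],
    id=uuid.uuid4(),
    index=0,
)
save_data({"example": Data(input=point.raw_input, output=[point])}, "/tmp/data.json")
assert initialize_data_objects("/tmp/data.json")["example"].input == point.raw_input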
/src/utils/graph_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any
3 |
4 | import networkx as nx
5 |
6 | from src.utils.data_utils import PlanStep
7 |
8 | DEPENDENCY_REGEX = r"\$\d+"
9 |
10 |
11 | def check_for_dependency(item: str) -> str | None:
12 | match = re.search(DEPENDENCY_REGEX, item)
13 | if match:
14 | return match.group()
15 | return None
16 |
17 |
18 | def build_graph(plan: list[PlanStep]) -> nx.DiGraph:
19 | function_calls = [step.serialize() for step in plan]
20 | graph = nx.DiGraph()
21 |
22 | def add_new_dependency_edge(
23 | *,
24 | node_name: str,
25 | dependency: str,
26 | ) -> tuple[dict[str, Any], str]:
27 | dependency_index = int(dependency[1:]) - 1
28 | dependency_func = function_calls[dependency_index]
29 | dependency_node_name = f"{dependency_index+1}: {dependency_func['tool_name']}"
30 | graph.add_edge(dependency_node_name, node_name)
31 |
32 | return dependency_func, dependency_node_name
33 |
34 | def add_dependencies(dependencies: list[Any], node_name: str):
35 | nested_deps = []
36 | for dep in dependencies:
37 | if isinstance(dep, list):
38 | # Checking for dependencies in list-type arguments
39 | for item in dep:
40 | if (
41 | isinstance(item, str)
42 | and (item := check_for_dependency(item)) is not None
43 | ):
44 | dependency_func, dependency_node_name = add_new_dependency_edge(
45 | node_name=node_name, dependency=item
46 | )
47 |
48 | # In order to avoid infinite recursion, we check
49 | # if the dependency node is the same as the current node
50 | if dependency_node_name == node_name:
51 | raise ValueError(
52 | "Circular dependency detected. "
53 | "The tool is dependent on itself, which is not allowed."
54 | )
55 |
56 | # Recursively add nested dependencies
57 | dep_dict = {
58 | "tool_name": dependency_func["tool_name"],
59 | "dependencies": (
60 | add_dependencies(
61 | dependency_func["tool_args"], dependency_node_name
62 | )
63 | if dependency_node_name != node_name
64 | else []
65 | ),
66 | }
67 | nested_deps.append(dep_dict)
68 | elif (
69 | isinstance(dep, str) and (dep := check_for_dependency(dep)) is not None
70 | ):
71 | dependency_func, dependency_node_name = add_new_dependency_edge(
72 | node_name=node_name, dependency=dep
73 | )
74 |
75 | # In order to avoid infinite recursion, we check
76 | # if the dependency node is the same as the current node
77 | if dependency_node_name == node_name:
78 | raise ValueError(
79 | "Circular dependency detected. "
80 | "The tool is dependent on itself, which is not allowed."
81 | )
82 |
83 |                 # Recursively add nested dependencies, again guarding
84 |                 # against self-dependency to avoid infinite recursion
85 | dep_dict = {
86 | "tool_name": dependency_func["tool_name"],
87 | "dependencies": (
88 | add_dependencies(
89 | dependency_func["tool_args"], dependency_node_name
90 | )
91 | if dependency_node_name != node_name
92 | else []
93 | ),
94 | }
95 | nested_deps.append(dep_dict)
96 | return nested_deps
97 |
98 | for index, func in enumerate(function_calls):
99 | node_name = f"{index+1}: {func['tool_name']}"
100 | graph.add_node(
101 | node_name, tool_name=func["tool_name"], tool_args=func["tool_args"]
102 | )
103 | dependencies = add_dependencies(func["tool_args"], node_name)
104 | graph.nodes[node_name]["dependencies"] = dependencies
105 |
106 | return graph
107 |
108 |
109 | def node_match(n1: dict, n2: dict) -> bool:
110 | if n1["tool_name"] != n2["tool_name"]:
111 | return False
112 |
113 | # Compare dependency structures recursively
114 | def compare_dependencies(d1, d2):
115 | if len(d1) != len(d2):
116 | return False
117 | for item1, item2 in zip(
118 | sorted(d1, key=lambda x: x["tool_name"]),
119 | sorted(d2, key=lambda x: x["tool_name"]),
120 | ):
121 | if not (
122 | item1["tool_name"] == item2["tool_name"]
123 | and compare_dependencies(item1["dependencies"], item2["dependencies"])
124 | ):
125 | return False
126 | return True
127 |
128 | return compare_dependencies(n1["dependencies"], n2["dependencies"])
129 |
130 |
131 | def compare_graphs_with_success_rate(graph1: nx.DiGraph, graph2: nx.DiGraph) -> float:
132 |     """Outputs 1.0 if the graphs are isomorphic, 0.0 otherwise."""
133 | if nx.is_isomorphic(graph1, graph2, node_match=node_match):
134 | return 1.0
135 | return 0.0
136 |
137 |
138 | def compare_graphs_with_edit_distance(graph1: nx.DiGraph, graph2: nx.DiGraph) -> float:
139 | """
140 | Calculates the graph edit distance between two graphs.
141 | """
142 |     f_node_match = node_match  # identical signature; no wrapper lambda needed
143 |
144 |     # First, check for isomorphism since it is a faster check. If yes, the GED is 0.
145 | if nx.is_isomorphic(graph1, graph2, node_match=f_node_match):
146 | return 0.0
147 |
148 | # Calculate the graph edit distance
149 | ged = nx.graph_edit_distance(
150 | graph1,
151 | graph2,
152 | node_match=f_node_match,
153 | timeout=10,
154 | )
155 |
156 | return ged
157 |
--------------------------------------------------------------------------------
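Usage note: build_graph turns "$N" placeholders in tool arguments into edges from step N to the step that references it. In practice tool_args holds the positional-argument tuple parsed from the plan (despite the dict[str, Any] annotation), so a sketch looks like this (argument values are illustrative):

from src.utils.data_utils import PlanStep, PlanStepToolName
from src.utils.graph_utils import build_graph, compare_graphs_with_success_rate

plan = [
    PlanStep(tool_name=PlanStepToolName.GET_EMAIL_ADDRESS, tool_args=("Alice",)),
    PlanStep(tool_name=PlanStepToolName.COMPOSE_NEW_EMAIL, tool_args=("$1", "Hi!")),
]
graph = build_graph(plan)
print(list(graph.edges()))  # [('1: get_email_address', '2: compose_new_email')]
print(compare_graphs_with_success_rate(graph, build_graph(plan)))  # 1.0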
/src/utils/logger_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 |
8 | from src.tiny_agent.models import TINY_AGENT_DIR
9 |
10 | # Global variable to toggle logging
11 | LOG_ENABLED = True
12 | LOG_TO_FILE = False
13 | LOG_FILE_PATH = os.path.join(TINY_AGENT_DIR, "log.txt")
14 |
15 | # Create the log file if it doesn't exist
16 | if not os.path.exists(LOG_FILE_PATH):
17 | with open(LOG_FILE_PATH, "w") as f:
18 | pass
19 |
20 |
21 | class Logger:
22 | def __init__(self) -> None:
23 | self._latency_dict = defaultdict(list)
24 | self._answer_dict = defaultdict(list)
25 | self._label_dict = defaultdict(list)
26 |
27 | def log(self, latency: float, answer: str, label: str, key: str) -> None:
28 | self._latency_dict[key].append(latency)
29 | self._answer_dict[key].append(answer)
30 | self._label_dict[key].append(label)
31 |
32 |     def _get_mean_latency(self, key: str) -> tuple[float, float]:
33 | latency_array = np.array(self._latency_dict[key])
34 | return latency_array.mean(), latency_array.std()
35 |
36 | def _get_accuracy(self, key: str) -> float:
37 | answer_array = np.array(self._answer_dict[key])
38 | label_array = np.array(self._label_dict[key])
39 | return (answer_array == label_array).mean()
40 |
41 | def get_results(self, key: str) -> dict:
42 | mean_latency, std_latency = self._get_mean_latency(key)
43 | accuracy = self._get_accuracy(key)
44 | return {
45 | "mean_latency": mean_latency,
46 | "std_latency": std_latency,
47 | "accuracy": accuracy,
48 | }
49 |
50 | def save_result(self, key: str, path: str):
51 | with open(f"{path}/dev_react_results.csv", "w") as f:
52 | for i in range(len(self._answer_dict[key])):
53 | f.write(f"{self._answer_dict[key][i]},{self._latency_dict[key][i]}\n")
54 |
55 |
56 | def get_logger() -> Logger:
57 | return Logger()
58 |
59 |
60 | # Custom print function to toggle logging
61 |
62 |
63 | def enable_logging(enable=True):
64 | """Toggle logging on or off based on the given argument."""
65 | global LOG_ENABLED
66 | LOG_ENABLED = enable
67 |
68 |
69 | def enable_logging_to_file(enable=True):
70 |     """Toggle logging to file on or off based on the given argument."""
71 | global LOG_TO_FILE
72 | LOG_TO_FILE = enable
73 |
74 |
75 | def log(*args, block=False, **kwargs):
76 | """Print the given string only if logging is enabled."""
77 | if LOG_ENABLED:
78 | if block:
79 | print("=" * 80)
80 | print(*args, **kwargs)
81 | if block:
82 | print("=" * 80)
83 | if LOG_TO_FILE:
84 | with open(LOG_FILE_PATH, "a") as f:
85 | if block:
86 | print("=" * 80, file=f)
87 | print(*args, **kwargs, file=f)
88 | if block:
89 | print("=" * 80, file=f)
90 |
91 |
92 | def flush_results(save_path, results):
93 |     print("Saving results")
94 |     with open(save_path, "w") as f:
95 |         json.dump(results, f, indent=4)
96 |
--------------------------------------------------------------------------------
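Usage note: log() honors the two global toggles; a brief sketch (block=True prints the "=" separators around the message, and the file sink appends to LOG_FILE_PATH under TINY_AGENT_DIR):

from src.utils.logger_utils import enable_logging, enable_logging_to_file, log

enable_logging(True)  # stdout logging (on by default)
enable_logging_to_file(True)  # also append to LOG_FILE_PATH
log("Planner produced 3 tasks", block=True)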
/src/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
2 | from langchain.llms import OpenAI
3 | from langchain_community.embeddings import HuggingFaceEmbeddings
4 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
5 |
6 | from src.utils.logger_utils import log
7 |
8 | DEFAULT_SAFE_CONTEXT_LENGTH = 512
9 | DEFAULT_SENTENCE_TRANSFORMER_BATCH_SIZE = 128
10 |
11 |
12 | def get_model(
13 | model_type,
14 | model_name,
15 | api_key,
16 | vllm_port,
17 | stream,
18 | temperature=0,
19 | azure_endpoint=None,
20 | azure_deployment=None,
21 | azure_api_version=None,
22 | ):
23 | if model_type == "openai":
24 | if api_key is None:
25 | raise ValueError("api_key must be provided for openai model")
26 | llm = ChatOpenAI(
27 | model_name=model_name, # type: ignore
28 | openai_api_key=api_key, # type: ignore
29 | streaming=stream,
30 | temperature=temperature,
31 | )
32 |
33 | elif model_type == "vllm":
34 | if vllm_port is None:
35 | raise ValueError("vllm_port must be provided for vllm model")
36 |         if stream:
37 |             log(
38 |                 "WARNING: vllm does not support streaming. Setting stream=False."
39 |             )
40 |             stream = False
41 | llm = OpenAI(
42 | openai_api_base=f"http://localhost:{vllm_port}/v1",
43 | model_name=model_name,
44 | temperature=temperature,
45 | max_retries=1,
46 | streaming=stream,
47 | )
48 | elif model_type == "local":
49 | if vllm_port is None:
50 |             raise ValueError("vllm_port must be provided for local model")
51 | llm = ChatOpenAI(
52 | openai_api_key=api_key, # type: ignore
53 | openai_api_base=f"http://localhost:{vllm_port}/v1",
54 | model_name=model_name,
55 | temperature=temperature,
56 | max_retries=1,
57 | streaming=stream,
58 | )
59 | elif model_type == "azure":
60 | if api_key is None:
61 | raise ValueError("api_key must be provided for azure model")
62 | if azure_api_version is None:
63 | raise ValueError("azure_api_version must be provided for azure model")
64 | if azure_endpoint is None:
65 | raise ValueError("azure_endpoint must be provided for azure model")
66 | if azure_deployment is None:
67 | raise ValueError("azure_deployment must be provided for azure model")
68 |
69 | llm = AzureChatOpenAI(
70 | api_key=api_key, # type: ignore
71 | api_version=azure_api_version,
72 | azure_endpoint=azure_endpoint,
73 | azure_deployment=azure_deployment,
74 | streaming=stream,
75 | )
76 |
77 | else:
78 | raise NotImplementedError(f"Unknown model type: {model_type}")
79 |
80 | return llm
81 |
82 |
83 | def get_embedding_model(
84 | model_type: str,
85 |     model_name: str | None,
86 |     api_key: str | None,
87 | azure_embedding_deployment: str,
88 | azure_endpoint: str | None,
89 | azure_api_version: str | None,
90 | local_port: int | None,
91 | context_length: int | None,
92 | ) -> OpenAIEmbeddings | AzureOpenAIEmbeddings | HuggingFaceEmbeddings:
93 | if model_name is None:
94 | raise ValueError("Embedding model's model_name must be provided")
95 |
96 | if model_type == "openai":
97 | if api_key is None:
98 | raise ValueError("api_key must be provided for openai model")
99 | return OpenAIEmbeddings(api_key=api_key, model=model_name)
100 | elif model_type == "azure":
101 | if api_key is None:
102 | raise ValueError("api_key must be provided for azure model")
103 | if azure_api_version is None:
104 | raise ValueError("azure_api_version must be provided for azure model")
105 | if azure_endpoint is None:
106 | raise ValueError("azure_endpoint must be provided for azure model")
107 | return AzureOpenAIEmbeddings(
108 | api_key=api_key,
109 | api_version=azure_api_version,
110 | azure_endpoint=azure_endpoint,
111 | azure_deployment=azure_embedding_deployment,
112 | model=model_name,
113 | )
114 | elif model_type == "local":
115 | if local_port is None:
116 | # Use SentenceTransformer for local embeddings
117 | return HuggingFaceEmbeddings(
118 | model_name=model_name,
119 | encode_kwargs={"batch_size": DEFAULT_SENTENCE_TRANSFORMER_BATCH_SIZE},
120 | )
121 | if context_length is None:
122 | print(
123 | "WARNING: context_length not provided for local model. Using default value (512).",
124 | flush=True,
125 | )
126 | context_length = DEFAULT_SAFE_CONTEXT_LENGTH
127 | return OpenAIEmbeddings(
128 | api_key=api_key,
129 | base_url=f"http://localhost:{local_port}/v1",
130 | model=model_name,
131 | embedding_ctx_length=context_length - 1,
132 | tiktoken_enabled=False,
133 | tiktoken_model_name=model_name,
134 | )
135 | else:
136 | raise NotImplementedError(f"Unknown model type: {model_type}")
137 |
--------------------------------------------------------------------------------
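Usage note: a sketch of requesting an OpenAI-backed chat model from the factory above (the model name is an example and the key comes from the environment):

import os

from src.utils.model_utils import get_model

llm = get_model(
    model_type="openai",
    model_name="gpt-3.5-turbo",
    api_key=os.environ["OPENAI_API_KEY"],
    vllm_port=None,  # only used for the "vllm" and "local" model types
    stream=False,
)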
/src/utils/plan_utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import re
3 | from typing import Any
4 |
5 | from src.llm_compiler.constants import JOINNER_FINISH, JOINNER_REPLAN
6 | from src.llm_compiler.task_fetching_unit import Task
7 | from src.utils.data_utils import JoinStep, PlanStep, PlanStepToolName
8 |
9 | THOUGHT_PATTERN = r"Thought: ([^\n]*)"
10 | ACTION_PATTERN = r"\s*\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?"
11 |
12 |
13 | def _parse_llm_compiler_action_args(args: str) -> Any:
14 | """Parse arguments from a string."""
15 | # This will convert the string into a python object
16 | # e.g. '"Ronaldo number of kids"' -> ("Ronaldo number of kids", )
17 | # '"I can answer the question now.", [3]' -> ("I can answer the question now.", [3])
18 | if args == "":
19 | return ()
20 |     try:
21 |         args = ast.literal_eval(args)
22 |     except (ValueError, SyntaxError):
23 |         pass  # keep the raw string if it is not a valid Python literal
24 | if not isinstance(args, list) and not isinstance(args, tuple):
25 | args = (args,) # type: ignore
26 | return args
27 |
28 |
29 | def parse_plan(plan: str) -> dict[int, Any]:
30 | # 1. search("Ronaldo number of kids") -> 1, "search", '"Ronaldo number of kids"'
31 | # pattern = r"(\d+)\. (\w+)\(([^)]+)\)"
32 | pattern = rf"(?:{THOUGHT_PATTERN}\n)?{ACTION_PATTERN}"
33 | matches = re.findall(pattern, plan)
34 |
35 | graph_dict = {}
36 |
37 | for match in matches:
38 | # idx = 1, function = "search", args = "Ronaldo number of kids"
39 | # thought will be the preceding thought, if any, otherwise an empty string
40 | _, idx, tool_name, args, _ = match
41 | idx = int(idx)
42 |
43 | # Create a dummy task
44 | task = Task(
45 | idx=idx,
46 | name=tool_name,
47 | tool=lambda x: None,
48 | args=_parse_llm_compiler_action_args(args),
49 | dependencies=[],
50 | stringify_rule=None,
51 | thought=None,
52 | is_join=tool_name == "join",
53 | )
54 |
55 | graph_dict[idx] = task
56 | if task.is_join:
57 | break
58 |
59 | return graph_dict
60 |
61 |
62 | def get_parsed_planner_output(tasks: dict[int, Any]) -> list[PlanStep]:
63 | steps: list[PlanStep] = []
64 | try:
65 | for _, task in tasks.items():
66 | step = PlanStep(
67 | tool_name=(PlanStepToolName(task.name)),
68 | tool_args=task.args,
69 | )
70 | steps.append(step)
71 | except Exception as e:
72 | print("Tool hallucination error: ", e)
73 |
74 | return steps
75 |
76 |
77 | def get_parsed_planner_output_from_raw(raw_answer: str) -> list[PlanStep]:
78 | tasks = parse_plan(raw_answer)
79 | return get_parsed_planner_output(tasks)
80 |
81 |
82 | def get_parsed_joinner_output(raw_answer: str) -> JoinStep:
83 | thought, answer, is_replan = "", "", False # default values
84 | raw_answers = raw_answer.split("\n")
85 | for ans in raw_answers:
86 | start_of_answer = ans.find("Action:")
87 | if start_of_answer >= 0:
88 | ans = ans[start_of_answer:]
89 | if ans.startswith("Action:"):
90 | answer = ans[ans.find("(") + 1 : ans.rfind(")")]
91 | is_replan = JOINNER_REPLAN in ans
92 | elif ans.startswith("Thought:"):
93 | thought = ans.split("Thought:")[1].strip()
94 |
95 | step = JoinStep(
96 | thought=thought,
97 | action=JOINNER_REPLAN if is_replan else JOINNER_FINISH,
98 | message=answer,
99 | )
100 | return step
101 |
102 |
103 | def evaluate_plan(label_plan: list[PlanStep], predicted_plan: list[PlanStep]) -> float:
104 | """Returns the accuracy of the predicted plan based on the tool names and the right ordering."""
105 | # Assuming the lengths of the two plans are the same
106 | correct = 0
107 | for label_step, predicted_step in zip(label_plan, predicted_plan):
108 | if label_step.tool_name == predicted_step.tool_name:
109 | correct += 1
110 | return correct / len(label_plan)
111 |
112 |
113 | def evaluate_tool_recall(
114 | label_plan: list[PlanStep], predicted_plan: list[PlanStep]
115 | ) -> float:
116 | """
117 | Instead of determining the score based on the correct function call in the plan,
118 | determines the score based on the correct set of function calls in the plan.
119 | """
120 |     label_set = {step.tool_name for step in label_plan}
121 |     predicted_set = {step.tool_name for step in predicted_plan}
122 |
123 | # Check how many of the predicted steps are in the label set
124 | correct = len(label_set.intersection(predicted_set))
125 | return correct / len(label_set)
126 |
127 |
128 | def evaluate_join(label_join: JoinStep, predicted_join: JoinStep) -> float:
129 | """Returns 1.0 if the predicted join step is correct, else 0.0."""
130 | if label_join.action == predicted_join.action:
131 | return 1.0
132 | return 0.0
133 |
--------------------------------------------------------------------------------
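Usage note: a sketch of the parsing path end to end: a raw planner string goes through get_parsed_planner_output_from_raw, and the result can be scored with evaluate_plan (the plan text is illustrative):

from src.utils.plan_utils import evaluate_plan, get_parsed_planner_output_from_raw

raw = (
    "Thought: I need Alice's email address first.\n"
    '1. get_email_address("Alice")\n'
    '2. compose_new_email(["$1"], [], "Hello", "Hi Alice!", [])\n'
    "3. join()\n"
)
steps = get_parsed_planner_output_from_raw(raw)
print([s.tool_name.value for s in steps])
# ['get_email_address', 'compose_new_email', 'join']
print(evaluate_plan(steps, steps))  # 1.0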