├── .env.example
├── .gitignore
├── README.md
├── config.json
├── configs
├── agents.json
├── example.txt
├── resources.json
├── sys_prompt.yaml
├── tasks.json
└── tools.json
├── main.py
├── requirements.txt
├── src
├── __init__.py
├── agents.py
├── clients.py
├── inference.py
├── prompter.py
├── rag_tools.py
├── resources.py
├── schema.py
├── tasks.py
├── tools.py
├── utils.py
└── validator.py
└── test_main.py
/.env.example:
--------------------------------------------------------------------------------
1 | ANTHROPIC_API_KEY=xxx
2 | ANTHROPIC_MODEL=claude-3-opus-20240229
3 |
4 | GROQ_API_KEY=xxx
5 | GROQ_MODEL=llama3-70b-8192
6 |
7 | SEC_API_API_KEY=xxx
8 |
9 | OLLAMA_MODEL=interstellarninja/hermes-2-theta-llama-3-8b
10 | #OLLAMA_MODEL=interstellarninja/hermes-2-pro-llama-3-8b
11 |
12 | LMSTUDIO_MODEL=NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF
13 | #LMSTUDIO_MODEL=NousResearch/Hermes-2-Pro-Mistral-7B-GGUF
14 |
15 | ORCHESTRATOR_CLIENT=ollama
16 | AGENT_CLIENT=ollama
17 |
18 | LOCAL_MODEL_PATH=NousResearch/Hermes-2-Pro-Llama-3-8B
19 | LOAD_IN_4BIT=True
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Virtual environment
2 | venv
3 |
4 | # Python
5 | *.pyc
6 |
7 | # Logs
8 | *.log
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | A framework for orchestrating AI agents using a mermaid graph & networkx
2 |
3 | Example:
4 | ```
5 | streamlit run main.py
6 | ```
7 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resources": [
3 | {
4 | "type": "text",
5 | "path": "inputs/press_release.txt",
6 | "template": "Here are your thoughts on the statement '{chunk}' from the file '{file}' (start: {start}, end: {end}):"
7 | },
8 | {
9 | "type": "pdf",
10 | "path": "inputs/sec_filings_10k.pdf",
11 | "template": "The following excerpt is from the PDF '{file}' (start: {start}, end: {end}):\\n{chunk}"
12 | },
13 | {
14 | "type": "web",
15 | "path": "https://blogs.nvidia.com/",
16 | "template": "The following content is scraped from the web page '{file}':\\n{chunk}"
17 | }
18 | ],
19 | "agents": [
20 | {
21 | "name": "Researcher",
22 | "role": "Text Analysis"
23 | },
24 | {
25 | "name": "Web Analyzer",
26 | "role": "Web Scraping"
27 | }
28 | ],
29 | "tasks": [
30 | {
31 | "name": "Text Task",
32 | "agent": "Researcher",
33 | "resource": "Text"
34 | },
35 | {
36 | "name": "Web Scrape Task",
37 | "agent": "Web Analyzer",
38 | "resource": "Web"
39 | }
40 | ]
41 | }
42 |
--------------------------------------------------------------------------------
/configs/agents.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "role": "Financial Analyst",
4 | "goal": "Impress all customers with your financial data and market trends analysis",
5 | "persona": "The most seasoned financial analyst with lots of expertise in stock market analysis and investment strategies that is working for a super important customer.",
6 | "tools": [
7 | "search_10q",
8 | "search_10k",
9 | "get_stock_fundamentals",
10 | "get_financial_statements" ]
11 | },
12 | {
13 | "role": "Research Analyst",
14 |     "goal": "Being the best at gathering and interpreting data and amazing your customer with it",
15 | "persona": "Known as the BEST research analyst, you're skilled in sifting through news, company announcements, and market sentiments. Now you're working on a super important customer",
16 | "tools": [
17 | "google_search_and_scrape",
18 | "get_company_news"
19 | ]
20 | },
21 | {
22 | "role": "Investment Advisor",
23 |     "goal": "Impress your customers with full analyses over stocks and complete investment recommendations",
24 | "persona": "You're the most experienced investment advisor and you combine various analytical insights to formulate strategic investment advice. You are now working for a super important customer you need to impress.",
25 | "tools": []
26 | },
27 | {
28 | "role": "Summarizer",
29 | "persona": "You are a skilled Data Analyst with a knack for distilling complex information into concise summaries.",
30 | "goal": "Compile a summary report based on the extracted information."
31 | }
32 | ]
--------------------------------------------------------------------------------
/configs/example.txt:
--------------------------------------------------------------------------------
1 |
2 | To fulfill this query, we will need to:
3 | 1. Gather financial statements and key financial ratios for NVDA using the Financial Analyst agent
4 | 2. Gather news, company announcements, and market sentiments related to NVDA using the Research Analyst agent
5 | 3. Have the Investment Advisor agent analyze the data collected by the Financial Analyst and Research Analyst to formulate investment recommendations
6 | 4. Have the Summarizer agent compile a final summary report with the analysis and recommendations
7 | The Financial Analyst and Research Analyst can work in parallel to gather data. Once they are done, the Investment Advisor will analyze the data. Finally, the Summarizer will generate the report.
8 |
9 |
10 | [
11 | {
12 | "role": "Financial Analyst",
13 | "goal": "Gather financial statements and key financial ratios for NVDA",
14 | "persona": "The most seasoned financial analyst with lots of expertise in stock market analysis and investment strategies that is working for a super important customer.",
15 | "tools": ["get_stock_fundamentals", "get_financial_statements", "get_key_financial_ratios"],
16 | "dependencies": []
17 | },
18 | {
19 | "role": "Research Analyst",
20 | "goal": "Gather news, company announcements, and market sentiments related to NVDA",
21 | "persona": "Known as the BEST research analyst, you're skilled in sifting through news, company announcements, and market sentiments. Now you're working on a super important customer",
22 | "tools": ["google_search_and_scrape", "get_company_news"],
23 | "dependencies": []
24 | },
25 | {
26 | "role": "Investment Advisor",
27 | "goal": "Analyze the collected data and provide investment recommendations",
28 | "persona": "You're the most experienced investment advisor and you combine various analytical insights to formulate strategic investment advice. You are now working for a super important customer you need to impress.",
29 | "tools": [],
30 | "dependencies": ["Financial Analyst", "Research Analyst"]
31 | },
32 | {
33 | "role": "Summarizer",
34 | "persona": "You are a skilled Data Analyst with a knack for distilling complex information into concise summaries.",
35 | "goal": "Compile a summary report based on the extracted information.",
36 | "tools": ["speak_to_the_user", "get_current_stock_price"],
37 | "dependencies": ["Investment Advisor"]
38 | }
39 | ]
40 |
41 |
42 | graph TD;
43 | A[Financial Analyst] --> C[Investment Advisor];
44 | B[Research Analyst] --> C[Investment Advisor];
45 | C[Investment Advisor] --> D[Summarizer];
46 |
--------------------------------------------------------------------------------
/configs/resources.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "type": "text",
4 | "path": "inputs/press_release.txt",
5 | "template": "Here are your thoughts on the statement '{chunk}' from the file '{file}' (start: {start}, end: {end}):"
6 | },
7 | {
8 | "type": "pdf",
9 | "path": "inputs/sec_filings_10k.pdf",
10 | "template": "The following excerpt is from the PDF '{file}' (start: {start}, end: {end}):\\n{chunk}"
11 | },
12 | {
13 | "type": "web",
14 | "path": "https://blogs.nvidia.com/",
15 | "template": "The following content is scraped from the web page '{file}':\\n{chunk}"
16 | }
17 | ]
--------------------------------------------------------------------------------
/configs/sys_prompt.yaml:
--------------------------------------------------------------------------------
1 | Role: |
2 | You are a helpful AI assistant that generates an agent mermaid graph and agent metadata to help fulfill user queries. The current date is: {date}.
3 | Objective: |
4 | You may use agentic frameworks for reasoning and planning to help with user query.
5 |   Before generating the graph, use a scratchpad to plan out which agents to dispatch and how they should interact to fulfill the query.
6 | You may dispatch agents in parallel, sequentially, or as a task graph, depending on what is optimal for the query.
7 | Agents: |
8 | Here are the sub agent personas you have available:
9 |
10 | {agents}
11 |
12 | Tools: |
13 | And here are the available tools:
14 |
15 | {tools}
16 |
17 | Schema: |
18 | When specifying each agent's metadata in the graph, follow this JSON schema:
19 |
20 | {schema}
21 |
22 | Instructions: |
23 |   Return the final agent metadata JSON list inside <json> tags.
24 |   And return the final Mermaid graph code inside <mermaid> tags.
--------------------------------------------------------------------------------
/configs/tasks.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "instructions": "Analyze the provided text and identify key insights and patterns.",
4 | "expected_output": "A list of key insights and patterns found in the text.",
5 | "agent": "Researcher",
6 | "output_file": "txt_analyzed.txt",
7 | "tool_name": "text_reader"
8 | },
9 | {
10 | "instructions": "Scrape the content from the provided URL and provide a summary.",
11 | "expected_output": "A summary of the scraped web content.",
12 | "agent": "Web Analyzer",
13 | "tool_name": "web_scraper",
14 | "output_file": "web_task_output.txt"
15 | },
16 | {
17 | "instructions": "Analyze the provided system documentation and develop a comprehensive plan for enhancing system performance, reliability, and efficiency.",
18 | "expected_output": "A detailed plan outlining strategies and steps for optimizing the systems.",
19 | "agent": "Planner",
20 | "tool_name": "system_docs",
21 | "output_file": "system_plan.txt"
22 | },
23 | {
24 | "instructions": "Search the provided files for information relevant to the given query.",
25 | "expected_output": "A list of relevant files with their similarity scores.",
26 | "agent": "Semantic Searcher",
27 | "tool_name": "semantic_search",
28 | "context": ["system_plan", "txt_task"]
29 | },
30 | {
31 | "instructions": "Using the insights from the researcher and web analyzer, compile a summary report.",
32 | "expected_output": "A well-structured summary report based on the extracted information.",
33 | "agent": "Summarizer",
34 | "context": ["system_plan", "txt_task", "web_task"],
35 | "output_file": "task2_output.txt"
36 | },
37 | {
38 | "instructions": "Analyze the sentiment of the extracted information.",
39 | "expected_output": "A sentiment analysis report based on the extracted information.",
40 | "agent": "Sentimentalizer",
41 | "context": ["summary", "txt_task"],
42 | "output_file": "sentimentalizer_output.txt",
43 | "tool_name": "sentiment_analysis"
44 | },
45 | {
46 | "instructions": "Extract named entities from the summary report.",
47 | "expected_output": "A list of extracted named entities.",
48 | "agent": "Entity Extractor",
49 | "context": ["summary", "search_task"],
50 | "output_file": "ner_output.txt",
51 | "tool_name": "ner_extraction"
52 | },
53 | {
54 | "instructions": "Generate a mermaid diagram based on the summary report.",
55 | "expected_output": "A mermaid graph illustrating the relationships and connections in the summary report.\n```mermaid\ngraph TD\n",
56 | "agent": "Mermaid",
57 | "context": ["summary", "txt_task", "search_task"],
58 | "output_file": "mermaid_output.txt"
59 | }
60 | ]
--------------------------------------------------------------------------------
/configs/tools.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "type": "function",
4 | "function": {
5 | "name": "speak_to_the_user",
6 | "description": "speak_to_the_user(message: str) -> str - Prompts the user to provide more context or feedback through the terminal or Streamlit interface.\n\nArgs:\n prompt (str): The prompt or question to ask the user.\n\nReturns:\n str: The user's response to the prompt.",
7 | "parameters": {
8 | "type": "object",
9 | "properties": {
10 | "message": {
11 | "type": "string"
12 | }
13 | },
14 | "required": [
15 | "message"
16 | ]
17 | }
18 | }
19 | },
20 | {
21 | "type": "function",
22 | "function": {
23 | "name": "code_interpreter",
24 | "description": "code_interpreter(code_markdown: str) -> dict | str - Execute the provided Python code string on the terminal using exec.\n\n The string should contain valid, executable and pure Python code in markdown syntax.\n Code should also import any required Python packages.\n\n Args:\n code_markdown (str): The Python code with markdown syntax to be executed.\n For example: ```python\n\n```\n\n Returns:\n dict | str: A dictionary containing variables declared and values returned by function calls,\n or an error message if an exception occurred.\n\n Note:\n Use this function with caution, as executing arbitrary code can pose security risks.",
25 | "parameters": {
26 | "type": "object",
27 | "properties": {
28 | "code_markdown": {
29 | "type": "string"
30 | }
31 | },
32 | "required": [
33 | "code_markdown"
34 | ]
35 | }
36 | }
37 | },
38 | {
39 | "type": "function",
40 | "function": {
41 | "name": "google_search_and_scrape",
42 | "description": "google_search_and_scrape(query: str) -> dict - Performs a Google search for the given query, retrieves the top search result URLs,\nand scrapes the text content and table data from those pages in parallel.\n\nArgs:\n query (str): The search query.\nReturns:\n list: A list of dictionaries containing the URL, text content, and table data for each scraped page.",
43 | "parameters": {
44 | "type": "object",
45 | "properties": {
46 | "query": {
47 | "type": "string"
48 | }
49 | },
50 | "required": [
51 | "query"
52 | ]
53 | }
54 | }
55 | },
56 | {
57 | "type": "function",
58 | "function": {
59 | "name": "get_current_stock_price",
60 | "description": "get_current_stock_price(symbol: str) -> float - Get the current stock price for a given symbol.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.",
61 | "parameters": {
62 | "type": "object",
63 | "properties": {
64 | "symbol": {
65 | "type": "string"
66 | }
67 | },
68 | "required": [
69 | "symbol"
70 | ]
71 | }
72 | }
73 | },
74 | {
75 | "type": "function",
76 | "function": {
77 | "name": "get_stock_fundamentals",
78 | "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n dict: A dictionary containing fundamental data.\n Keys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.",
79 | "parameters": {
80 | "type": "object",
81 | "properties": {
82 | "symbol": {
83 | "type": "string"
84 | }
85 | },
86 | "required": [
87 | "symbol"
88 | ]
89 | }
90 | }
91 | },
92 | {
93 | "type": "function",
94 | "function": {
95 | "name": "get_financial_statements",
96 | "description": "get_financial_statements(symbol: str) -> dict - Get financial statements for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).",
97 | "parameters": {
98 | "type": "object",
99 | "properties": {
100 | "symbol": {
101 | "type": "string"
102 | }
103 | },
104 | "required": [
105 | "symbol"
106 | ]
107 | }
108 | }
109 | },
110 | {
111 | "type": "function",
112 | "function": {
113 | "name": "get_key_financial_ratios",
114 | "description": "get_key_financial_ratios(symbol: str) -> dict - Get key financial ratios for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing key financial ratios.",
115 | "parameters": {
116 | "type": "object",
117 | "properties": {
118 | "symbol": {
119 | "type": "string"
120 | }
121 | },
122 | "required": [
123 | "symbol"
124 | ]
125 | }
126 | }
127 | },
128 | {
129 | "type": "function",
130 | "function": {
131 | "name": "get_analyst_recommendations",
132 | "description": "get_analyst_recommendations(symbol: str) -> pandas.core.frame.DataFrame - Get analyst recommendations for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing analyst recommendations.",
133 | "parameters": {
134 | "type": "object",
135 | "properties": {
136 | "symbol": {
137 | "type": "string"
138 | }
139 | },
140 | "required": [
141 | "symbol"
142 | ]
143 | }
144 | }
145 | },
146 | {
147 | "type": "function",
148 | "function": {
149 | "name": "get_dividend_data",
150 | "description": "get_dividend_data(symbol: str) -> pandas.core.frame.DataFrame - Get dividend data for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing dividend data.",
151 | "parameters": {
152 | "type": "object",
153 | "properties": {
154 | "symbol": {
155 | "type": "string"
156 | }
157 | },
158 | "required": [
159 | "symbol"
160 | ]
161 | }
162 | }
163 | },
164 | {
165 | "type": "function",
166 | "function": {
167 | "name": "get_company_news",
168 | "description": "get_company_news(symbol: str) -> pandas.core.frame.DataFrame - Get company news and press releases for a given stock symbol.\nThis function returns titles and url which need further scraping using other tools.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press releases.",
169 | "parameters": {
170 | "type": "object",
171 | "properties": {
172 | "symbol": {
173 | "type": "string"
174 | }
175 | },
176 | "required": [
177 | "symbol"
178 | ]
179 | }
180 | }
181 | },
182 | {
183 | "type": "function",
184 | "function": {
185 | "name": "get_technical_indicators",
186 | "description": "get_technical_indicators(symbol: str) -> pandas.core.frame.DataFrame - Get technical indicators for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing technical indicators.",
187 | "parameters": {
188 | "type": "object",
189 | "properties": {
190 | "symbol": {
191 | "type": "string"
192 | }
193 | },
194 | "required": [
195 | "symbol"
196 | ]
197 | }
198 | }
199 | },
200 | {
201 | "type": "function",
202 | "function": {
203 | "name": "get_company_profile",
204 | "description": "get_company_profile(symbol: str) -> dict - Get company profile and overview for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing company profile and overview.",
205 | "parameters": {
206 | "type": "object",
207 | "properties": {
208 | "symbol": {
209 | "type": "string"
210 | }
211 | },
212 | "required": [
213 | "symbol"
214 | ]
215 | }
216 | }
217 | },
218 | {
219 | "type": "function",
220 | "function": {
221 | "name": "search_10q",
222 | "description": "search_10q(data) - Useful to search information from the latest 10-Q form for a\n given stock.\n The input to this tool should be a pipe (|) separated text of\n length two, representing the stock ticker you are interested and what\n question you have from it.\n\t\tFor example, `AAPL|what was last quarter's revenue`.",
223 | "parameters": {
224 | "type": "object",
225 | "properties": {
226 | "data": {}
227 | },
228 | "required": [
229 | "data"
230 | ]
231 | }
232 | }
233 | },
234 | {
235 | "type": "function",
236 | "function": {
237 | "name": "search_10k",
238 | "description": "search_10k(data) - Useful to search information from the latest 10-K form for a\ngiven stock.\nThe input to this tool should be a pipe (|) separated text of\nlength two, representing the stock ticker you are interested, what\nquestion you have from it.\nFor example, `AAPL|what was last year's revenue`.",
239 | "parameters": {
240 | "type": "object",
241 | "properties": {
242 | "data": {}
243 | },
244 | "required": [
245 | "data"
246 | ]
247 | }
248 | }
249 | },
250 | {
251 | "type": "function",
252 | "function": {
253 | "name": "get_historical_price",
254 | "description": "get_historical_price(symbol, start_date, end_date) - Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.\n- symbol (str): Stock ticker symbol.\n- end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date\n- start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date",
255 | "parameters": {
256 | "type": "object",
257 | "properties": {
258 | "symbol": {},
259 | "start_date": {},
260 | "end_date": {}
261 | },
262 | "required": [
263 | "symbol",
264 | "start_date",
265 | "end_date"
266 | ]
267 | }
268 | }
269 | },
270 | {
271 | "type": "function",
272 | "function": {
273 | "name": "plot_price_over_time",
274 | "description": "plot_price_over_time(symbol, start_date, end_date) - Plots the historical stock prices for a given symbol over a specified date range.\n- symbol (str): Stock ticker symbol.\n- start_date (str): Start date for the historical data in 'YYYY-MM-DD' format.\n- end_date (str): End date for the historical data in 'YYYY-MM-DD' format.",
275 | "parameters": {
276 | "type": "object",
277 | "properties": {
278 | "symbol": {},
279 | "start_date": {},
280 | "end_date": {}
281 | },
282 | "required": [
283 | "symbol",
284 | "start_date",
285 | "end_date"
286 | ]
287 | }
288 | }
289 | }
290 | ]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# Standard library
import argparse
import json
import os
import re
from datetime import datetime
from typing import List

# Third-party
import logfire
import networkx as nx
import streamlit as st
from matplotlib import pyplot as plt

# Local application imports.
# NOTE: the star import stays first so that the explicit imports below take
# precedence over any same-named symbols it may export.
from src.rag_tools import *
from src.agents import Agent
from src.clients import CLIENTS
from src.prompter import PromptManager
from src.resources import Resource
from src.tools import get_openai_tools, get_function_names
from src.utils import inference_logger
21 |
class AgentOrchestrator:
    """Coordinates a set of LLM agents as a dependency graph.

    Given a user query, asks an orchestrator LLM to produce a mermaid graph
    plus per-agent metadata, builds a networkx DiGraph from that metadata,
    then executes the agents in topological order, feeding each agent the
    outputs of its predecessors. Progress is rendered through Streamlit and
    interactions are persisted as JSON log files.
    """

    def __init__(self, agents: List['Agent'], agents_config, resources: List['Resource'], verbose: bool = False, log_file: str = "orchestrator_log.json"):
        # Backend name for the orchestrator LLM (e.g. "ollama"), from the environment.
        self.client = os.getenv('ORCHESTRATOR_CLIENT')
        self.agents = agents
        self.agents_config = agents_config
        self.resources = resources
        self.verbose = verbose
        self.log_file = log_file
        # Orchestrator-level log entries, flushed by save_logs().
        self.log_data = []
        # Accumulated per-agent interaction records, flushed by save_llama_logs().
        self.llama_logs = []

    def run(self, query: str) -> str:
        """Plan and execute the agent graph for `query`; return the combined agent outputs."""
        #tools = get_openai_tools()
        tools = get_function_names()

        # NOTE(review): each `tool` is accessed via .name/.description below, so
        # get_function_names() must yield objects with those attributes, not
        # plain strings — confirm against src/tools.py.
        tool_descriptions = []
        for tool in tools:
            tool_descriptions.append({
                "name": tool.name,
                "description": tool.description
            })

        mermaid_graph, agents_metadata = self.load_or_generate_graph(query, self.agents_config, tool_descriptions, self.resources)
        st.write(mermaid_graph)

        # One node per agent, keyed by role; all metadata stored as node attributes.
        G = nx.DiGraph()
        for agent_data in agents_metadata:
            agent_role = agent_data["role"]
            G.add_node(agent_role, **agent_data)

        # Add edges to the graph based on the dependencies
        # (edge direction: dependency -> dependent, so topological order runs dependencies first).
        for agent_data in agents_metadata:
            agent_role = agent_data["role"]
            dependencies = agent_data.get("dependencies", [])
            for dependency in dependencies:
                G.add_edge(dependency, agent_role)

        # Visualize the graph using python-mermaid
        self.visualize_graph(G)

        # Create a dictionary to store the output of each agent
        agent_outputs = {}

        # Execute agents in topological order (respecting dependencies)
        for agent_role in nx.topological_sort(G):
            agent_data = G.nodes[agent_role]
            agent = Agent(**agent_data)

            st.write(f"Starting Agent: {agent.role}", unsafe_allow_html=True)
            if agent.verbose:
                st.write(f"Agent Persona: {agent.persona}", unsafe_allow_html=True)
                st.write(f"Agent Goal: {agent.goal}", unsafe_allow_html=True)

            # Prepare the input messages for the agent: one message per
            # predecessor that has already produced output.
            input_messages = []
            for predecessor in G.predecessors(agent_role):
                if predecessor in agent_outputs:
                    input_messages.append({"role": predecessor, "content": agent_outputs[predecessor]})

            agent.input_messages = input_messages

            # Execute the agent
            output = agent.execute()

            if agent.verbose:
                st.write(f"Agent Output:\n{output}\n", unsafe_allow_html=True)

            agent_outputs[agent_role] = output
            self.llama_logs.extend(agent.interactions)

        # Collect the final output from all the agents
        final_output = "\n".join([f"Agent: {role}\nGoal: {G.nodes[role]['goal']}\nOutput:\n{output}\n" for role, output in agent_outputs.items()])

        self.save_logs()
        self.save_llama_logs()

        return final_output

    def visualize_graph(self, G):
        """Render the agent dependency graph into the Streamlit app via matplotlib."""
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=False, node_size=1000, node_color='lightblue', font_size=12, font_weight='bold', arrows=True)
        # Nodes are keyed by role, so this maps each node to its role string for labeling.
        labels = nx.get_node_attributes(G, 'role')
        nx.draw_networkx_labels(G, pos, labels, font_size=12)
        plt.axis('off')
        plt.tight_layout()
        st.pyplot(plt)
        plt.close()

    def load_or_generate_graph(self, query, agents, tools, resources):
        """Return (mermaid_graph, agents_metadata), using on-disk cache files when present.

        NOTE(review): the cache files are keyed by fixed names, not by query —
        a second, different query will reuse the first query's plan until the
        files are deleted. Confirm this is intentional.
        """
        mermaid_graph_file = "mermaid_graph.txt"
        agent_metadata_file = "agent_metadata.json"

        if os.path.exists(mermaid_graph_file) and os.path.exists(agent_metadata_file):
            with open(mermaid_graph_file, "r") as file:
                mermaid_graph = file.read()
            with open(agent_metadata_file, "r") as file:
                agents_metadata = json.load(file)
        else:
            mermaid_graph = self.agent_dispatcher(query, agents, tools, resources)
            agents_metadata = self.extract_agents_from_mermaid(mermaid_graph)

            with open(mermaid_graph_file, "w") as file:
                file.write(mermaid_graph)
            with open(agent_metadata_file, "w") as file:
                json.dump(agents_metadata, file, indent=2)

        return mermaid_graph, agents_metadata

    def agent_dispatcher(self, query, agents, tools, resources):
        """Ask the orchestrator LLM to plan the agent graph for `query`; return its raw response."""
        chat = [{"role": "user", "content": query}]
        prompter = PromptManager()
        sys_prompt = prompter.generate_prompt(tools, agents, resources, one_shot=True)

        #response = CLIENTS.chat_completion(
        #    client="anthropic",
        #    messages=[
        #        {"role": "system", "content": sys_prompt},
        #        *chat
        #    ]
        #)
        inference_logger.info(f"Running inference with {self.client}")
        response = CLIENTS.chat_completion(
            client=self.client,
            messages=[
                {"role": "system", "content": sys_prompt},
                *chat
            ]
        )
        # NOTE(review): the response is logged twice (formatted and raw) — likely
        # one of these two lines is redundant.
        inference_logger.info(f"Assistant Message:\n{response}")
        inference_logger.info(response)
        st.write(response)
        return response

    def extract_agents_from_mermaid(self, mermaid_graph):
        """Parse the LLM response into the agent metadata list.

        NOTE(review): the patterns r'(.*?)' always match the EMPTY string at
        position 0 — the delimiter tokens (presumably markup tags around the
        graph and the JSON metadata) appear to have been stripped from this
        copy of the source. As written, metadata_content is '' and
        json.loads('') raises. Recover the original delimiters and restore
        them in both re.search calls.
        """
        graph_content = re.search(r'(.*?)', mermaid_graph, re.DOTALL)

        if graph_content:
            graph_content = graph_content.group(1)

        metadata_content = re.search(r'(.*?)', mermaid_graph, re.DOTALL)
        if metadata_content:
            metadata_content = metadata_content.group(1)

        # Matches mermaid edges of the form "A --> B".
        dependency_pattern = r'(\w+) --> (\w+)'

        agents_metadata = json.loads(metadata_content)
        # NOTE(review): `dependencies` is collected but never applied — the loop
        # that would attach it to agents_metadata is commented out below.
        dependencies = []

        for match in re.finditer(dependency_pattern, graph_content):
            source = match.group(1)
            target = match.group(2)
            dependencies.append((source, target))

        #for agent in agents_metadata:
        #    agent["dependencies"] = [target for source, target in dependencies if source == agent["role"]]

        return agents_metadata

    def save_llama_logs(self):
        """Write accumulated agent interactions to a timestamped JSON file."""
        with open(("qa_interactions" + datetime.now().strftime("%Y%m%d%H%M%S") + ".json"), "w") as file:
            json.dump(self.llama_logs, file, indent=2)

    def save_logs(self):
        """Write orchestrator-level log entries to the configured log file."""
        with open(self.log_file, "w") as file:
            json.dump(self.log_data, file, indent=2)
187 |
def parse_args():
    """Parse command-line arguments; requires a -q/--query string."""
    arg_parser = argparse.ArgumentParser(description="Run the agent orchestrator with dynamic configurations.")
    arg_parser.add_argument('-q', '--query', required=True, type=str, help="user query for agents to assist with")
    return arg_parser.parse_args()
192 |
def mainflow():
    """Streamlit entry point: collect a stock question and run the agent orchestrator.

    Loads agent and resource definitions from configs/, builds an
    AgentOrchestrator with a timestamped log file, and runs it against the
    user's question.
    """
    st.title("Stock Analysis with MeeseeksAI Agents")
    multiline_text = """
    Try to ask it "What is the current price of Meta stock?" or "Show me the historical prices of Apple vs Microsoft stock over the past 6 months.".
    """

    st.markdown(multiline_text, unsafe_allow_html=True)

    # Add customization options to the sidebar
    #st.sidebar.title('Customization')
    #additional_context = st.sidebar.text_input('Enter additional summarization context for the LLM here (i.e. write it in spanish):')

    # Get the user's question
    user_question = st.text_input("Ask a question about a stock or multiple stocks:")

    if user_question:
        # NOTE(review): single-argument os.path.join is a no-op — this is just os.getcwd().
        file_path = os.path.join(os.getcwd())
        with open(os.path.join(file_path, "configs/agents.json"), "r") as file:
            agents_data = json.load(file)
        agents = [Agent(**agent_data) for agent_data in agents_data]

        with open(os.path.join(file_path, "configs/resources.json"), "r") as file:
            resources_data = json.load(file)
        resources = [Resource(**resource_data) for resource_data in resources_data]

        orchestrator = AgentOrchestrator(
            agents=agents,
            agents_config=agents_data,
            resources=resources,
            verbose=True,
            log_file="orchestrator_log" + datetime.now().strftime("%Y%m%d%H%M%S") + ".json"
        )

        # NOTE(review): the return value of run() is discarded here, yet the
        # commented-out block below references `result` — capture it as
        # `result = orchestrator.run(user_question)` if that block is revived.
        orchestrator.run(user_question)

        ## Wrap the final output in a scrollable container
        #output_container = st.container()
        #with output_container:
        #    st.write(f"Final output:\n{result}")
        #
        ## Make the output container scrollable
        #output_container_height = min(len(result.split('\n')) * 30, 500) # Adjust the height based on the number of lines
        #output_container.markdown(
        #    f"""
        # NOTE(review): several commented-out lines (an inline HTML style block)
        # appear to have been elided from this copy of the source.
        #
        #    """,
        #    unsafe_allow_html=True
        #)
246 |
# Script entry point: launch the Streamlit app (run via `streamlit run main.py`).
if __name__ == "__main__":
    mainflow()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai
2 | PyYAML
3 | requests
4 | bs4
5 | tiktoken
6 | pandas
7 | yfinance
8 | PyPDF2
9 | textblob
10 | gpt4all
11 | spacy
12 | groq
13 | python-dotenv
14 | logfire
15 | sec_api
16 | unstructured
17 | sentence_transformers
18 | faiss-cpu
19 | streamlit
20 | plotly
21 | langchain
22 | anthropic
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interstellarninja/MeeseeksAI/432fac235a8521a84dce541be9561323df80766e/src/__init__.py
--------------------------------------------------------------------------------
/src/agents.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, Dict, List, Optional
3 | import uuid
4 | from pydantic import BaseModel, Field
5 | from datetime import datetime
6 | from src.clients import CLIENTS
7 | from src import tools
8 | from src.tools import *
9 | from src.rag_tools import *
10 | from src.utils import inference_logger
11 | from src.utils import validate_and_extract_tool_calls
12 | from langchain.tools import StructuredTool, BaseTool
13 |
14 | from langchain_core.messages import ToolMessage
15 |
16 | ## TODO: add default tools such as "get_user_feedback", "get_additional_context", "code_interpreter" etc.
17 |
class Agent(BaseModel):
    """A tool-calling worker agent.

    Builds a chat prompt from its role/persona/goal (plus any messages handed
    over from upstream agents), runs inference through ``CLIENTS``, parses
    tool calls out of the model output, executes the matching tool objects,
    and feeds the results back into the conversation.

    NOTE(review): ``os`` and ``st`` (Streamlit) are not imported in this
    module directly; they appear to arrive via the star imports above —
    confirm.
    """

    class Config:
        arbitrary_types_allowed = True # Allow arbitrary types
        exclude = {"client", "tool_objects"}

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    role: str
    persona: Optional[str] = None
    goal: str
    # Names of tools, resolved against this module's globals() in create_tool_objects().
    tools: List[str] = []
    dependencies: Optional[List[str]] = None
    user_feedback: bool = False
    verbose: bool = True
    model: str = Field(default_factory=lambda: os.getenv('AGENT_MODEL')) # agent model from environment variable
    max_iter: int = 5
    # Backend name understood by CLIENTS.chat_completion (e.g. "ollama").
    client: str = Field(default_factory=lambda: os.getenv('AGENT_CLIENT'))
    tool_objects: Dict[str, Any] = {}
    input_messages: List[Dict] = []
    interactions: List[Dict] = []

    def __init__(self, **data: Any):
        """Validate fields, then resolve tool names to tool objects.

        Raises:
            ValueError: If no client is configured (AGENT_CLIENT unset) or a
                tool name cannot be resolved.
        """
        super().__init__(**data)
        if not self.client:
            raise ValueError("Invalid client specified.")
        self.tool_objects = self.create_tool_objects()

    def create_tool_objects(self) -> Dict[str, Any]:
        """Map each name in ``self.tools`` to the same-named object in this
        module's globals; appends the user-feedback tool when enabled."""
        tool_objects = {}

        if self.user_feedback:
            self.tools.append('speak_to_the_user')

        for tool_name in self.tools:
            if tool_name in globals():
                tool_objects[tool_name] = globals()[tool_name]
            else:
                raise ValueError(f"Tool '{tool_name}' not found.")
        return tool_objects

    def execute(self) -> str:
        """Run the agent loop and return the final assistant message.

        Assembles system/user messages, then repeatedly calls the LLM,
        executing any parsed tool calls and appending their results as a new
        user message, until the model stops emitting tool calls.
        """
        messages = []
        if self.persona and self.verbose:
            current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            messages.append({"role": "system", "content": f"You are a {self.role} with the persona: {self.persona}. Current date and time: {current_datetime}"})

        # Check if the agent has available tools
        if self.tool_objects:
            #print(self.tool_objects)
            # Serialize the tool objects to JSON schema
            tool_schemas = []
            for tool_name, tool_object in self.tool_objects.items():
                if isinstance(tool_object, StructuredTool):
                    tool_schema = {
                        "name": tool_object.name,
                        "description": tool_object.description,
                        "parameters": tool_object.args_schema.schema()
                    }
                    tool_schemas.append(tool_schema)

            # Append the tool schemas to the system prompt within tags
            system_prompt = "You are a function calling AI model."
            system_prompt += f"\nHere are the available tools:\n\n{json.dumps(tool_schemas, indent=2)}\n\n"
            system_prompt += f"You should call the tools provided to you sequentially\n"
            system_prompt += """
Please use XML tags to record your reasoning and planning before you call the functions as follows:

{step-by-step reasoning and plan in bullet points}

For each function call return a json object with function name and arguments within XML tags as follows:

{"arguments": , "name": }

"""
            messages.append({"role": "system", "content": system_prompt})

        if self.input_messages:
            # Hand-over: outputs of upstream agents are injected as extra system messages.
            for input_message in self.input_messages:
                role = input_message["role"]
                content = input_message["content"]
                inference_logger.info(f"Appending input messages from previous agent: {role}")
                messages.append({"role": "system", "content": f"\n<{role}>\n{content}\n{role}>\n"})

        messages.append({"role": "user", "content": f"Your task is to {self.goal}."})

        depth = 0
        while depth < self.max_iter:
            inference_logger.info(f"Running inference with {self.client}")
            result = CLIENTS.chat_completion(
                client=self.client,
                messages=messages,
            )
            inference_logger.info(f"Assistant Message:\n{result}")
            messages.append({"role": "assistant", "content": result})


            # Process the agent's response and extract tool calls
            if self.tool_objects:
                validation, tool_calls, scratchpad, error_message = validate_and_extract_tool_calls(result)

                if validation and tool_calls:
                    inference_logger.info(f"Parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")

                    # Print parsed tool calls as JSON markdown in Streamlit
                    if self.verbose:
                        st.write(f"Agent Plan:\n{scratchpad}\n", unsafe_allow_html=True)
                        st.markdown(f"**Parsed Tool Calls:**")
                        st.json(tool_calls)

                    # Execute the tool calls
                    tool_message = f"Sub-agent iteration {depth} to assist with user query: {self.goal}. Summarize the tool results:\n"
                    if tool_calls:
                        for tool_call in tool_calls:
                            tool_name = tool_call.get("name")
                            tool_object = self.tool_objects.get(tool_name)
                            if tool_object:
                                try:
                                    tool_args = tool_call.get("arguments", {})

                                    # Print invoking tool message in blue
                                    st.markdown(f"Invoking tool: {tool_name}", unsafe_allow_html=True)

                                    # langchain tools go through _run; plain functions through execute_function_call.
                                    tool_result = tool_object._run(**tool_args) if isinstance(tool_object, BaseTool) else self.execute_function_call(tool_call)

                                    tool_message += f"\n{tool_result}\n\n"
                                    inference_logger.info(f"Response from tool '{tool_name}':\n{tool_result}")

                                    # Print tool response in green
                                    #st.markdown(f"Tool response: {tool_result}", unsafe_allow_html=True)
                                except Exception as e:
                                    error_message = f"Error executing tool '{tool_name}': {str(e)}"
                                    tool_message += f"\n{error_message}\n\n"
                            else:
                                error_message = f"Tool '{tool_name}' not found."
                                tool_message += f"\n{error_message}\n\n"
                        messages.append({"role": "user", "content": tool_message})
                        #messages.append({"role": "tool", "content": tool_message})
                        #messages.append(ToolMessage(tool_message, tool_call_id=0))
                else:
                    inference_logger.info(f"No tool calls found in the agent's response.")
                    break
            # NOTE(review): `error_message` is only ever assigned inside the
            # `if self.tool_objects:` branch above, so for an agent without
            # tools this elif raises NameError on the first iteration — confirm.
            elif error_message:
                inference_logger.info(f"Error parsing tool calls: {error_message}")
                tool_message = f"\n{error_message}\n\n"
                messages.append({"role": "user", "content": tool_message})
                #messages.append({"role": "tool", "content": tool_message})
                #messages.append(ToolMessage(tool_message, tool_call_id=0))

                # NOTE(review): depth only advances in this error branch, so
                # max_iter does not bound the successful tool-calling path;
                # that path ends only when the model stops emitting tool calls.
                depth += 1
            else:
                break

        # Log the final interaction
        self.log_interaction(messages, result)
        return result

    def execute_function_call(self, tool_call):
        """Invoke a plain function from src.tools by name with the parsed
        arguments and wrap its result as a JSON-like string."""
        function_name = tool_call.get("name")
        function_to_call = getattr(tools, function_name, None)
        function_args = tool_call.get("arguments", {})

        if function_to_call:
            inference_logger.info(f"Invoking function call {function_name} ...")
            function_response = function_to_call(**function_args)
            results_dict = f'{{"name": "{function_name}", "content": {json.dumps(function_response)}}}'
            return results_dict
        else:
            raise ValueError(f"Function '{function_name}' not found.")

    def log_interaction(self, prompt, response):
        """Record one completed run (full message list, final response,
        hand-over messages and tool list) with a timestamp."""
        self.interactions.append({
            "role": self.role,
            "messages": prompt,
            "response": response,
            "agent_messages": self.input_messages,
            "tools": self.tools,
            "timestamp": datetime.now().isoformat()
        })
--------------------------------------------------------------------------------
/src/clients.py:
--------------------------------------------------------------------------------
1 | import os
2 | from openai import OpenAI
3 | from anthropic import Anthropic
4 | from groq import Groq
5 | from dotenv import load_dotenv
6 |
7 | load_dotenv(".env", override=True)
8 |
class Clients:
    """Lazy registry of LLM backends keyed by provider name.

    Each ``initialize_*`` method creates the SDK client for one provider and
    stores it (with its model name) in ``self.clients``.  ``chat_completion``
    initializes the requested provider on first use and dispatches the call
    in that provider's native API shape.
    """

    def __init__(self):
        # provider name -> {"client": sdk_client, "model": model_name}
        self.clients = {}

    def initialize_ollama(self):
        # Ollama serves an OpenAI-compatible endpoint; the api_key is a dummy.
        self.clients["ollama"] = {
            "client": OpenAI(
                base_url='http://localhost:11434/v1',
                api_key='ollama',
            ),
            "model": os.getenv("OLLAMA_MODEL")
        }

    def initialize_groq(self):
        self.clients["groq"] = {
            "client": Groq(
                api_key=os.getenv('GROQ_API_KEY')
            ),
            "model": os.getenv("GROQ_MODEL")
        }

    def initialize_anthropic(self):
        self.clients["anthropic"] = {
            "client": Anthropic(
                api_key=os.getenv("ANTHROPIC_API_KEY")
            ),
            "model": os.getenv("ANTHROPIC_MODEL")
        }

    def initialize_lmstudio(self):
        self.clients["lmstudio"] = {
            "client": OpenAI(
                base_url="http://localhost:1234/v1",
                #base_url="http://192.168.1.2:1234/v1",
                api_key="lm-studio"
            ),
            "model": os.getenv("LMSTUDIO_MODEL")
        }

    def initialize_localllama(self):
        from src.inference import ModelInference
        self.clients["localllama"] = {
            "client": ModelInference(
                model_path=os.getenv("LOCAL_MODEL_PATH"),
                load_in_4bit=os.getenv("LOAD_IN_4BIT", "False")
            ),
            "model": None
        }

    def chat_completion(self, client, messages):
        """Run one chat completion on the named backend and return its text.

        Args:
            client: Backend name: "ollama", "groq", "anthropic", "lmstudio",
                or "localllama".
            messages: OpenAI-style list of {"role", "content"} dicts.

        Returns:
            The assistant message content as a string.

        Raises:
            ValueError: If ``client`` is not a supported backend name.
        """
        initializers = {
            "ollama": self.initialize_ollama,
            "groq": self.initialize_groq,
            "anthropic": self.initialize_anthropic,
            "lmstudio": self.initialize_lmstudio,
            "localllama": self.initialize_localllama,
        }
        if client not in initializers:
            raise ValueError(f"Unsupported client: {client}")
        # Initialize lazily and only once per backend (the original re-created
        # the SDK client on every call).
        if client not in self.clients:
            initializers[client]()

        backend = self.clients[client]

        if client == "localllama":
            # Fix: ModelInference (src/inference.py) is not OpenAI-compatible;
            # it consumes the chat messages directly via run_inference.
            return backend["client"].run_inference(messages)

        if client == "anthropic":
            # Fix: the Anthropic SDK has no chat.completions; it separates the
            # system prompt from the turn list and requires max_tokens.
            system_prompt = "\n".join(
                m["content"] for m in messages if m["role"] == "system"
            )
            turns = [m for m in messages if m["role"] != "system"]
            response = backend["client"].messages.create(
                model=backend["model"],
                max_tokens=4096,
                system=system_prompt,
                messages=turns,
            )
            return response.content[0].text

        # OpenAI-compatible backends: ollama, groq, lmstudio.
        response = backend["client"].chat.completions.create(
            model=backend["model"],
            messages=messages,
        )
        return response.choices[0].message.content
79 |
# Module-level singleton; src/agents.py imports this to run inference.
CLIENTS = Clients()
81 |
82 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 |
4 | from transformers import (
5 | AutoModelForCausalLM,
6 | AutoTokenizer,
7 | BitsAndBytesConfig
8 | )
9 |
10 | from src.utils import (
11 | inference_logger,
12 | )
13 |
class ModelInference:
    """Local HuggingFace chat-model wrapper used by the "localllama" client.

    Loads a causal LM (optionally 4-bit NF4 quantized via bitsandbytes),
    applies the tokenizer's chat template to a message list, and extracts the
    final assistant turn from the generated text.
    """

    def __init__(self, model_path, load_in_4bit):
        """Load the model and tokenizer.

        Args:
            model_path: HF hub id or local path of the model.
            load_in_4bit: String flag from the environment ("True"/"False");
                when "True", enable 4-bit quantization.
        """
        self.bnb_config = None

        if load_in_4bit == "True":
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                llm_int8_enable_fp32_cpu_offload=True,
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            return_dict=True,
            quantization_config=self.bnb_config,
            torch_dtype=torch.float32,
            #attn_implementation="flash_attention_2",
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Pad with EOS on the left so generation is not disturbed by padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        if self.tokenizer.chat_template is None:
            # Fall back to the ChatML template for models that ship without one.
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

        inference_logger.info(self.model.config)
        inference_logger.info(self.model.generation_config)
        inference_logger.info(self.tokenizer.special_tokens_map)


    def run_inference(self, prompt):
        """Generate a completion for *prompt* (a chat message list) and
        return only the assistant's reply text."""
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors='pt'
        )

        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1500,
            temperature=0.8,
            repetition_penalty=1.1,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        # Fix: the decode kwarg is `clean_up_tokenization_spaces` (plural);
        # the old misspelling was silently ignored.
        completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_spaces=True)

        assistant_message = self.get_assistant_message(completion, "chatml", self.tokenizer.eos_token)
        return assistant_message

    def get_assistant_message(self, completion, chat_template, eos_token):
        """Extract the LAST assistant turn from *completion*.

        Returns the assistant text with *eos_token* removed, or ``None`` when
        no assistant turn is found.

        Raises:
            NotImplementedError: For chat templates other than "chatml".
        """
        completion = completion.strip()
        if chat_template == "chatml":
            # Tempered pattern: capture everything after the final
            # "<|im_start|>assistant" marker up to end of string.
            assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
        else:
            raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")

        assistant_match = assistant_pattern.search(completion)
        if assistant_match:
            assistant_content = assistant_match.group(1).strip()
            return assistant_content.replace(eos_token, "")
        else:
            assistant_content = None
            inference_logger.info("No match found for the assistant pattern")
            return assistant_content
--------------------------------------------------------------------------------
/src/prompter.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from pydantic import BaseModel
3 | from typing import Dict
4 | from src.schema import Agent
5 | from src.utils import (
6 | get_fewshot_examples
7 | )
8 | import yaml
9 | import json
10 | import os
11 |
class PromptSchema(BaseModel):
    """Sections of the orchestrator system prompt loaded from
    configs/sys_prompt.yaml; each field is a format-string template rendered
    by PromptManager.format_yaml_prompt."""
    Role: str
    Objective: str
    Agents: str
    Tools: str
    #Resources: str
    #Examples: str
    Schema: str
    Instructions: str
21 |
class PromptManager:
    """Builds the orchestrator system prompt from configs/sys_prompt.yaml."""

    def __init__(self):
        # Resolve config paths relative to this source file, not the CWD.
        self.script_dir = os.path.dirname(os.path.abspath(__file__))

    def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str:
        """Render every section template with *variables* and concatenate.

        Sections other than "Instructions" are flattened onto one line; the
        "Examples" section is skipped when no examples are supplied.
        """
        rendered_sections = []
        for section, template in prompt_schema.dict().items():
            if section == "Examples" and variables.get("examples") is None:
                continue
            rendered = template.format(**variables)
            if section != "Instructions":
                rendered = rendered.replace("\n", " ")
            rendered_sections.append(rendered)
        return "".join(rendered_sections)

    def read_yaml_file(self, file_path: str) -> PromptSchema:
        """Load *file_path* as YAML and map its sections onto a PromptSchema."""
        with open(file_path, 'r') as file:
            content = yaml.safe_load(file)

        section_names = ("Role", "Objective", "Agents", "Tools", "Schema", "Instructions")
        return PromptSchema(**{name: content.get(name, '') for name in section_names})

    def generate_prompt(self, tools, agents, resources, one_shot=False):
        """Return the rendered system prompt for the orchestrator.

        When *one_shot* is set, also assembles (and prints) a message list
        that includes a canned user/assistant example exchange.
        """
        schema = self.read_yaml_file(os.path.join(self.script_dir, '../configs', 'sys_prompt.yaml'))

        variables = {
            "date": datetime.date.today(),
            "agents": agents,
            "tools": tools,
            "resources": None,
            #"examples": examples,
            "schema": json.loads(Agent.schema_json()),
        }
        sys_prompt = self.format_yaml_prompt(schema, variables)

        prompt = [{'role': 'system', 'content': sys_prompt}]

        if one_shot:
            with open(os.path.join(self.script_dir, '../configs', 'example.txt'), 'r') as file:
                examples = file.read()

            prompt.extend([
                {"role": "user", "content": "Perform fundamental analysis of NVDA stock and provide portfolio recommendations"},
                {"role": "assistant", "content": examples}
            ])
            print(prompt)
        return sys_prompt
--------------------------------------------------------------------------------
/src/rag_tools.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import PyPDF2
4 | import requests
5 | from bs4 import BeautifulSoup
6 | import spacy
7 | from textblob import TextBlob
8 | import pickle
9 | from gpt4all import Embed4All
10 | from typing import Any, Callable, Dict, List, Optional, Type
11 | from pydantic import BaseModel, Field
12 | from langchain.tools import BaseTool
13 | from langchain.callbacks.manager import CallbackManagerForToolRun
14 |
15 | from src.resources import TextChunker
16 |
class WikipediaSearchInput(BaseModel):
    """Argument schema for WikipediaSearchTool; the Field descriptions are
    surfaced to the model through the tool's JSON schema."""
    query: str = Field(description="The search query for Wikipedia")
    top_k: int = Field(default=3, description="The number of top search results to return")
20 |
class WikipediaSearchTool(BaseTool):
    """Searches Wikipedia full-text search and returns chunked page text
    for the top hits."""

    name = "wikipedia_search"
    description = "Searches Wikipedia for relevant information based on a given query"
    args_schema: Type[BaseModel] = WikipediaSearchInput

    def _run(self, query: str, top_k: int = 3, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, str]]:
        """Scrape the search result list, fetch each page, and return up to
        *top_k* results as {"title", "url", "chunks"} dicts."""
        url = f"https://en.wikipedia.org/w/index.php?search={query}&title=Special:Search&fulltext=1"
        response = requests.get(url)
        soup = BeautifulSoup(response.text, 'html.parser')

        search_results = []
        for result in soup.find_all('li', class_='mw-search-result'):
            title = result.find('a').get_text()
            url = 'https://en.wikipedia.org' + result.find('a')['href']
            page_response = requests.get(url)
            page_soup = BeautifulSoup(page_response.text, 'html.parser')
            content = page_soup.find('div', class_='mw-parser-output').get_text()
            # Fix: TextChunker.chunk_text (src/resources.py) takes no
            # num_chunks kwarg — the old call raised TypeError at runtime.
            # Chunk the page, then keep only the first 10 chunks.
            chunks = TextChunker().chunk_text(text=content, chunk_size=1000)[:10]
            search_results.append({'title': title, 'url': url, 'chunks': chunks})
            if len(search_results) >= top_k:
                break

        return search_results

    async def _arun(self, query: str, top_k: int = 3, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, str]]:
        raise NotImplementedError("WikipediaSearchTool does not support async")
47 |
class SemanticFileSearchInput(BaseModel):
    """Argument schema for SemanticFileSearchTool."""
    query: str = Field(description="The search query for semantic file search")
class SemanticFileSearchTool(BaseTool):
    """Embeds a fixed set of files with Embed4All and answers queries by
    cosine similarity over the cached chunk embeddings.

    NOTE(review): __init__ assigns attributes without calling
    super().__init__(); confirm the installed langchain BaseTool accepts
    this.
    """

    name = "semantic_file_search"
    description = "Performs semantic search on a set of files based on a given query"
    args_schema: Type[BaseModel] = SemanticFileSearchInput

    def __init__(self, file_paths: List[str], embed_model: str, embed_dim: int = 768, chunk_size: int = 1000, top_k: int = 3):
        self.embedder = Embed4All(embed_model)
        self.embed_dim = embed_dim
        self.chunk_size = chunk_size
        self.top_k = top_k
        self.chunker = TextChunker(text=None, chunk_size=chunk_size)
        self.file_embeddings = self.load_or_generate_file_embeddings(file_paths)

    def load_or_generate_file_embeddings(self, file_paths: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Load cached embeddings from a pickle keyed by the file-name hash,
        or compute and cache them on first use."""
        file_hash = self.get_file_hash(file_paths)
        pickle_file = f"file_embeddings_{file_hash}.pickle"
        if os.path.exists(pickle_file):
            # NOTE: pickle.load runs arbitrary code; only safe because the
            # cache file is produced locally by save_embeddings below.
            self.load_embeddings(pickle_file)
        else:
            self.file_embeddings = self.generate_file_embeddings(file_paths)
            self.save_embeddings(pickle_file)
        return self.file_embeddings

    def get_file_hash(self, file_paths: List[str]) -> str:
        """Hash of the sorted file *names* (not contents).

        NOTE(review): because only basenames are hashed, editing a file's
        contents does NOT invalidate the pickle cache — confirm intended.
        """
        file_contents = "".join(sorted([os.path.basename(path) for path in file_paths]))
        return hashlib.sha256(file_contents.encode()).hexdigest()

    def generate_file_embeddings(self, file_paths: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Chunk each file (PDF or plain text) and embed every chunk; returns
        {file_path: [(chunk_text, embedding), ...]}."""
        file_embeddings = {}
        for file_path in file_paths:
            if file_path.endswith('.pdf'):
                text = self.extract_text_from_pdf(file_path)
            else:
                with open(file_path, 'r') as file:
                    text = file.read()
            chunks = self.chunker.chunk_text(text=text, chunk_size=self.chunk_size)
            chunk_embeddings = [self.embedder.embed(chunk['text'], prefix='search_document') for chunk in chunks]
            file_embeddings[file_path] = [(chunk['text'], embedding) for chunk, embedding in zip(chunks, chunk_embeddings)]
        return file_embeddings

    def extract_text_from_pdf(self, file_path: str) -> str:
        """Concatenate the extracted text of every PDF page."""
        with open(file_path, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"
            return text

    def save_embeddings(self, pickle_file: str):
        """Persist the embedding cache to disk."""
        with open(pickle_file, 'wb') as f:
            pickle.dump(self.file_embeddings, f)

    def load_embeddings(self, pickle_file: str):
        """Restore the embedding cache from disk."""
        with open(pickle_file, 'rb') as f:
            self.file_embeddings = pickle.load(f)

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        """Rank all cached chunks by cosine similarity to *query* and return
        the top_k as {"file", "text", "score"} dicts."""
        query_embedding = self.embedder.embed(query, prefix='search_query')
        scores = []
        for file_path, chunk_data in self.file_embeddings.items():
            for chunk_text, embedding in chunk_data:
                chunk_score = self.cosine_similarity(query_embedding, embedding)
                scores.append(((file_path, chunk_text), chunk_score))
        sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
        top_scores = sorted_scores[:self.top_k]
        result = []
        for (file_path, chunk_text), score in top_scores:
            result.append({
                'file': file_path,
                'text': chunk_text,
                'score': score
            })
        return result

    def cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Cosine similarity of two equal-length vectors."""
        import numpy as np
        a = np.array(a)
        b = np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    async def _arun(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        raise NotImplementedError("SemanticFileSearchTool does not support async")
133 |
class TextReaderInput(BaseModel):
    """Argument schema for TextReaderTool."""
    text_file: str = Field(description="The path to the text file to read")
136 |
class TextReaderTool(BaseTool):
    """Reads a text file and returns at most ``num_chunks`` token chunks."""

    name = "text_reader"
    description = "Reads text from a file and chunks it into smaller pieces"
    args_schema: Type[BaseModel] = TextReaderInput

    def __init__(self, chunk_size: int, num_chunks: int):
        self.chunk_size = chunk_size
        self.num_chunks = num_chunks

    def _run(self, text_file: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        with open(text_file, "r") as file:
            contents = file.read()
        all_chunks = TextChunker(contents, self.chunk_size, overlap=0).chunk_text()
        # Cap the output at the configured number of chunks.
        return all_chunks[:self.num_chunks]

    async def _arun(self, text_file: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        raise NotImplementedError("TextReaderTool does not support async")
155 |
class WebScraperInput(BaseModel):
    """Argument schema for WebScraperTool."""
    url: str = Field(description="The URL of the web page to scrape")
158 |
class WebScraperTool(BaseTool):
    """Fetches a web page, strips its markup, and returns token chunks."""

    name = "web_scraper"
    description = "Scrapes text from a web page and chunks it into smaller pieces"
    args_schema: Type[BaseModel] = WebScraperInput

    def __init__(self, chunk_size: int, num_chunks: int):
        self.chunk_size = chunk_size
        self.num_chunks = num_chunks

    def _run(self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        page = BeautifulSoup(requests.get(url).text, 'html.parser')

        # Collapse every whitespace run left over from the markup.
        flattened = ' '.join(page.get_text(separator='\n').split())

        chunks = TextChunker(flattened, self.chunk_size, overlap=0).chunk_text()
        return chunks[:self.num_chunks]

    async def _arun(self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        raise NotImplementedError("WebScraperTool does not support async")
181 |
class NERExtractionInput(BaseModel):
    """Argument schema for NERExtractionTool."""
    text: str = Field(description="The text to extract named entities from")
184 |
class NERExtractionTool(BaseTool):
    """Extracts spaCy named entities with character offsets and labels."""

    name = "ner_extraction"
    description = "Extracts named entities from text using spaCy"
    args_schema: Type[BaseModel] = NERExtractionInput

    def __init__(self):
        # Small English pipeline; must be installed separately.
        self.nlp = spacy.load("en_core_web_sm")

    def _run(self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        doc = self.nlp(text)
        return [
            {
                "text": ent.text,
                "start": ent.start_char,
                "end": ent.end_char,
                "label": ent.label_
            }
            for ent in doc.ents
        ]

    async def _arun(self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        raise NotImplementedError("NERExtractionTool does not support async")
209 |
class SemanticAnalysisInput(BaseModel):
    """Argument schema for SemanticAnalysisTool."""
    text: str = Field(description="The text to perform sentiment analysis on")
212 |
class SemanticAnalysisTool(BaseTool):
    """Scores text sentiment (polarity and subjectivity) with TextBlob."""

    name = "semantic_analysis"
    description = "Performs sentiment analysis using TextBlob"
    args_schema: Type[BaseModel] = SemanticAnalysisInput

    def _run(self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> Dict[str, Any]:
        sentiment = TextBlob(text).sentiment
        return {
            "polarity": sentiment.polarity,
            "subjectivity": sentiment.subjectivity
        }

    async def _arun(self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> Dict[str, Any]:
        raise NotImplementedError("SemanticAnalysisTool does not support async")
--------------------------------------------------------------------------------
/src/resources.py:
--------------------------------------------------------------------------------
1 | import PyPDF2
2 | import requests
3 | from typing import Dict, Any, List, Optional
4 |
5 | from tiktoken import get_encoding
6 |
7 | from pydantic import BaseModel
8 | import PyPDF2
9 | import requests
10 | from typing import Dict, Any, List
11 |
class Resource(BaseModel):
    """A loadable knowledge source: a text file, a PDF, or a web page.

    ``data`` holds the raw text once loaded; ``chunks`` holds the token
    chunks produced by ``chunk_resource``.
    """
    type: str  # 'text', 'pdf', or 'web'
    path: str  # file path, or URL for 'web'
    context_template: Optional[str] = None  # optional wrapper with {chunk}/{file}/{start}/{end}
    data: Optional[str] = None
    chunks: Optional[List[Dict[str, Any]]] = []

    def load_resource(self):
        """Load and return the raw text for this resource.

        Fix: the original referenced ``self.resource_type`` and
        ``self.resource_path``, which do not exist — the declared fields are
        ``type`` and ``path`` — so every call raised AttributeError.

        Raises:
            ValueError: For an unsupported ``type``.
        """
        if self.type == 'text':
            return self.load_text()
        elif self.type == 'pdf':
            return self.load_pdf()
        elif self.type == 'web':
            return self.load_web()
        else:
            raise ValueError(f"Unsupported resource type: {self.type}")

    def load_text(self):
        """Read a plain-text file at ``self.path``."""
        with open(self.path, 'r') as file:
            return file.read()

    def load_pdf(self):
        """Extract and concatenate the text of every page of a PDF."""
        with open(self.path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"
            return text

    def load_web(self):
        """Fetch ``self.path`` as a URL and return the raw response body."""
        response = requests.get(self.path)
        return response.text

    def chunk_resource(self, chunk_size: int, overlap: int = 0):
        """Split ``self.data`` into token chunks.

        NOTE(review): load_resource returns the text rather than assigning
        ``self.data``; callers must set ``data`` before chunking — confirm.
        """
        chunker = TextChunker(self.data, chunk_size, overlap)
        self.chunks = chunker.chunk_text()

    def contextualize_chunk(self, chunk: Dict[str, Any]) -> str:
        """Wrap *chunk* in ``context_template`` when one is configured,
        otherwise return the chunk text unchanged."""
        if self.context_template:
            return self.context_template.format(
                chunk=chunk['text'],
                file=self.path,
                start=chunk['start'],
                end=chunk['end']
            )
        else:
            return chunk['text']
59 |
class TextChunker:
    """Splits text into (optionally overlapping) windows of BPE tokens.

    Uses tiktoken's ``cl100k_base`` encoding; the "start"/"end" values in
    each chunk are token offsets, not character offsets.
    """

    def __init__(self, text: str = None, chunk_size: int = 1000, overlap: int = 0):
        """
        Args:
            text: Default text to chunk (may be overridden per call).
            chunk_size: Tokens per chunk.
            overlap: Tokens shared between consecutive chunks; must be
                smaller than chunk_size.
        """
        self.text = text
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.encoding = get_encoding("cl100k_base")

    def chunk_text(self, text: str = None, chunk_size: int = None, start_pos: int = 0) -> List[Dict[str, Any]]:
        """Chunk *text* (or ``self.text``) into token windows.

        Returns:
            A list of {"text", "start", "end"} dicts (empty for empty text).

        Raises:
            ValueError: If overlap >= chunk_size.
        """
        if text is not None:
            self.text = text
        if chunk_size is not None:
            self.chunk_size = chunk_size

        # Fix: with overlap >= chunk_size the scan advanced by <= 0 tokens
        # per step and the original looped forever.
        if self.overlap >= self.chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")

        tokens = self.encoding.encode(self.text)
        num_tokens = len(tokens)

        chunks = []
        current_pos = start_pos

        while current_pos < num_tokens:
            chunk_start = max(0, current_pos - self.overlap)
            chunk_end = min(current_pos + self.chunk_size, num_tokens)

            chunk_tokens = tokens[chunk_start:chunk_end]
            chunk_text = self.encoding.decode(chunk_tokens)

            chunks.append({
                "text": chunk_text,
                "start": chunk_start,
                "end": chunk_end
            })

            current_pos += self.chunk_size - self.overlap

        return chunks
95 |
--------------------------------------------------------------------------------
/src/schema.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Literal, Optional
2 | from pydantic import BaseModel, Field
3 |
class Tool(BaseModel):
    """Minimal schema for a tool reference (name only)."""
    name: str
6 |
class Agent(BaseModel):
    """Agent-node schema the orchestrator LLM is asked to emit (distinct
    from the runtime Agent class in src/agents.py); Field descriptions are
    part of the JSON schema shown to the model."""
    role: str = Field(..., description="The role of the agent")
    goal: str = Field(..., description="The goal of the agent")
    persona: str = Field(..., description="The persona of the agent")
    tools: List[str] = Field(..., description="tool names available to the agent; use only tools provided and do not make them up")
    dependencies: List[str] = Field(..., description="List of agent nodes this agent depends on")
13 |
class FunctionCall(BaseModel):
    """A single model-issued function call: a name plus parsed arguments."""

    arguments: dict
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""
25 |
class FunctionDefinition(BaseModel):
    """Function metadata: name, optional description, and an optional
    JSON-schema-style parameters dict."""
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, object]] = None
30 |
class FunctionSignature(BaseModel):
    """Pairs a FunctionDefinition with its literal "function" type tag."""
    function: FunctionDefinition
    type: Literal["function"]
--------------------------------------------------------------------------------
/src/tasks.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, List, Optional
3 | from src.agents import Agent
4 | import uuid
5 | import json
6 | from datetime import datetime
7 | from pydantic import BaseModel
8 |
9 | import os
10 | import json
11 |
# Load the agent roster from configs/agents.json (path resolved relative to
# this source file) so tasks can be bound to agents by role.
# NOTE(review): this performs file I/O and Agent construction as an import
# side effect — importing src.tasks fails if the config is missing.
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs/agents.json")
with open(file_path, "r") as file:
    agents = json.load(file)

agents = [Agent(**agent) for agent in agents]
17 |
class Task(BaseModel):
    """A unit of work routed to one agent, with optional upstream context.

    ``context`` holds upstream Task objects whose outputs are folded into
    this task's prompt; ``output`` holds the result after ``execute``.
    """

    id: Optional[str] = None  # fix: was `str = None`; replaced with a fresh uuid4 in __init__
    instructions: str
    expected_output: str
    agent: Optional[str] = None  # role name at construction; resolved to an Agent in __init__
    async_execution: bool = False
    context: Optional[List[str]] = None  # NOTE(review): actually holds Task objects at runtime — confirm
    output_file: Optional[str] = None
    callback: Optional[Callable] = None
    human_input: bool = False
    tool_name: Optional[str] = None
    input_tasks: Optional[List["Task"]] = None
    output_tasks: Optional[List["Task"]] = None
    context_agent_role: Optional[str] = None  # fix: was `str = None` (invalid non-optional default)
    prompt_data: Optional[List] = None
    output: Optional[str] = None


    def __init__(self, **data: Any):
        """Validate fields, mint a fresh id, and resolve the agent role."""
        super().__init__(**data)
        self.id = str(uuid.uuid4())
        self.agent = self.load_agent(self.agent)
        self.context = self.context or []
        self.prompt_data = []  # List to hold prompt data for logging
        self.output = None

    def execute(self, context: Optional[str] = None) -> str:
        """Run the task through its agent; store and return the result.

        Raises:
            Exception: If no agent could be resolved for this task.
        """
        if not self.agent:
            raise Exception("No agent assigned to the task.")

        # Fold the outputs of completed upstream tasks into the context.
        context_tasks = [task for task in self.context if task.output]
        if context_tasks:
            self.context_agent_role = context_tasks[0].agent.role
            original_context = "\n".join([f"{task.agent.role}: {task.output}" for task in context_tasks])

            if self.tool_name == 'semantic_search':
                # Semantic search consumes the bare upstream outputs as the query.
                query = "\n".join([task.output for task in context_tasks])
                context = query
            else:
                context = original_context

        # Prepare the prompt for logging before execution
        prompt_details = self.prepare_prompt(context)
        self.prompt_data.append(prompt_details)

        # Execute the task with the agent.
        # NOTE(review): src.agents.Agent defines execute(), not
        # execute_task(); confirm which Agent type is expected here.
        result = self.agent.execute_task(self, context)
        self.output = result

        if self.output_file:
            with open(self.output_file, "w") as file:
                file.write(result)

        if self.callback:
            self.callback(self)

        return result

    def prepare_prompt(self, context):
        """ Prepare and return the prompt details for logging """
        prompt = {
            "timestamp": datetime.now().isoformat(),
            "task_id": self.id,
            "instructions": self.instructions,
            "context": context,
            "expected_output": self.expected_output
        }
        return prompt

    def load_agent(self, agent_role: str) -> Optional[Agent]:
        """Return the module-level agent whose role matches *agent_role*.

        Fix: the original used bare ``next(...)``, which raised StopIteration
        for an unknown role; return None instead, matching the annotation.
        """
        return next((agent for agent in agents if agent.role == agent_role), None)
90 |
91 |
--------------------------------------------------------------------------------
/src/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | import re
4 | import inspect
5 | import sys
6 | import requests
7 | import pandas as pd
8 | import yfinance as yf
9 | import concurrent.futures
10 |
11 | import plotly.graph_objects as go
12 | import streamlit as st
13 |
14 | from sec_api import QueryApi
15 | from typing import List
16 | from bs4 import BeautifulSoup
17 | from src.utils import embedding_search, inference_logger
18 | from langchain.tools import tool
19 | from langchain_core.utils.function_calling import convert_to_openai_tool
20 |
@tool
def speak_to_the_user(message: str) -> str:
    """
    Prompts the user to provide more context or feedback through the Streamlit interface.

    Args:
        message (str): The prompt or question to show the user.

    Returns:
        str: The user's response to the prompt.
    """
    st.write(message)
    user_input = st.text_input("Please provide your response:")
    return user_input
35 |
@tool
def code_interpreter(code_markdown: str) -> dict | str:
    """
    Execute the provided Python code string on the terminal using exec.

    The string should contain valid, executable and pure Python code in markdown syntax.
    Code should also import any required Python packages.

    Args:
        code_markdown (str): The Python code with markdown syntax to be executed.
            For example: ```python\n\n```

    Returns:
        dict | str: A dictionary containing variables declared and values returned by function calls,
            or an error message if an exception occurred.

    Note:
        Use this function with caution, as executing arbitrary code can pose security risks.
    """
    # SECURITY NOTE(review): exec() of model-generated code is inherently unsafe;
    # run this in a sandboxed environment if exposed to untrusted input.
    try:
        # Extracting code from Markdown code block
        # (drops the opening ```python fence and the closing ``` fence)
        code_lines = code_markdown.split('\n')[1:-1]
        code_without_markdown = '\n'.join(code_lines)

        # Create a new namespace for code execution
        exec_namespace = {}

        # Execute the code in the new namespace
        exec(code_without_markdown, exec_namespace)

        # Collect variables and function call results
        result_dict = {}
        for name, value in exec_namespace.items():
            if callable(value):
                # Zero-argument callables are invoked and their return value recorded.
                try:
                    result_dict[name] = value()
                except TypeError:
                    # If the function requires arguments, attempt to call it with arguments from the namespace
                    arg_names = inspect.getfullargspec(value).args
                    args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names}
                    result_dict[name] = value(**args)
            elif not name.startswith('_'): # Exclude variables starting with '_'
                result_dict[name] = value

        return result_dict

    except Exception as e:
        error_message = f"An error occurred: {e}"
        inference_logger.error(error_message)
        return error_message
86 |
@tool("Make a calculation")
def calculate(operation):
    """Useful to perform any mathematical calculations,
    like sum, minus, multiplication, division, etc.
    The input to this tool should be a mathematical
    expression, a couple examples are `200*7` or `5000/2*10`
    """
    # SECURITY NOTE(review): eval() on model-supplied input can execute arbitrary
    # Python, not just arithmetic. Consider a restricted expression evaluator
    # (e.g. an ast-based parser) before exposing this tool to untrusted input.
    return eval(operation)
95 |
@tool
def google_search_and_scrape(query: str) -> List[dict]:
    """
    Performs a Google search for the given query, retrieves the top search result URLs,
    and scrapes the text content and table data from those pages in parallel.

    Args:
        query (str): The search query.
    Returns:
        list: A list of dictionaries containing the URL, text content, and table data for each scraped page.
    """
    num_results = 2
    url = 'https://www.google.com/search'
    params = {'q': query, 'num': num_results}
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}

    inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
    response = requests.get(url, params=params, headers=headers)
    soup = BeautifulSoup(response.text, 'html.parser')
    urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]

    inference_logger.info(f"Scraping text from urls, please wait...")
    # Plain loop: the original used a list comprehension purely for side effects.
    for u in urls:
        inference_logger.info(u)

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(lambda u: (u, requests.get(u, headers=headers).text), u)
            for u in urls[:num_results] if isinstance(u, str)
        ]
        for future in concurrent.futures.as_completed(futures):
            url, html = future.result()
            if not html:
                # Skip empty responses; BeautifulSoup(None) would raise TypeError.
                continue
            soup = BeautifulSoup(html, 'html.parser')
            paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()]
            text_content = ' '.join(paragraphs)
            text_content = re.sub(r'\s+', ' ', text_content)
            table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table') for row in table.find_all('tr')]
            if text_content or table_data:
                results.append({'url': url, 'content': text_content, 'tables': table_data})
    # Fix: the original annotated the return as `dict`, but a list is returned
    # (the docstring already said "list").
    return results
132 |
@tool
def get_current_stock_price(symbol: str) -> float:
    """
    Get the current stock price for a given symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        float: The current stock price, or None if an error occurs.
    """
    try:
        info = yf.Ticker(symbol).info
        # "regularMarketPrice" covers regular trading hours; fall back to
        # "currentPrice" for pre/post market quotes.
        price = info.get("regularMarketPrice", info.get("currentPrice"))
        if price:
            return price
        return None
    except Exception as e:
        print(f"Error fetching current price for {symbol}: {e}")
        return None
152 |
@tool
def get_stock_fundamentals(symbol: str) -> dict:
    """
    Get fundamental data for a given stock symbol using yfinance API.

    Args:
        symbol (str): The stock symbol.

    Returns:
        dict: A dictionary containing fundamental data.
        Keys:
            - 'symbol': The stock symbol.
            - 'company_name': The long name of the company.
            - 'sector': The sector to which the company belongs.
            - 'industry': The industry to which the company belongs.
            - 'market_cap': The market capitalization of the company.
            - 'pe_ratio': The forward price-to-earnings ratio.
            - 'pb_ratio': The price-to-book ratio.
            - 'dividend_yield': The dividend yield.
            - 'eps': The trailing earnings per share.
            - 'beta': The beta value of the stock.
            - '52_week_high': The 52-week high price of the stock.
            - '52_week_low': The 52-week low price of the stock.
    """
    try:
        info = yf.Ticker(symbol).info
        # Output key -> yfinance info key; string-valued fields default to ''.
        string_fields = {
            'company_name': 'longName',
            'sector': 'sector',
            'industry': 'industry',
        }
        numeric_fields = {
            'market_cap': 'marketCap',
            'pe_ratio': 'forwardPE',
            'pb_ratio': 'priceToBook',
            'dividend_yield': 'dividendYield',
            'eps': 'trailingEps',
            'beta': 'beta',
            '52_week_high': 'fiftyTwoWeekHigh',
            '52_week_low': 'fiftyTwoWeekLow',
        }
        fundamentals = {'symbol': symbol}
        for out_key, info_key in string_fields.items():
            fundamentals[out_key] = info.get(info_key, '')
        for out_key, info_key in numeric_fields.items():
            fundamentals[out_key] = info.get(info_key, None)
        return fundamentals
    except Exception as e:
        print(f"Error getting fundamentals for {symbol}: {e}")
        return {}
198 |
@tool
def get_financial_statements(symbol: str) -> pd.DataFrame:
    """
    Get financial statements for a given stock symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        pd.DataFrame: The company's annual income statement from yfinance,
        or an empty dict if an error occurs.
    """
    try:
        # yfinance's `financials` property holds the annual income statement
        # (the original docstring over-promised balance sheet + cash flow too).
        return yf.Ticker(symbol).financials
    except Exception as e:
        print(f"Error fetching financial statements for {symbol}: {e}")
        # Kept as {} for backward compatibility with existing callers.
        return {}
217 |
@tool
def get_key_financial_ratios(symbol: str) -> dict:
    """
    Get key financial ratios for a given stock symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        dict: Dictionary containing key financial ratios.
    """
    try:
        ticker = yf.Ticker(symbol)
        return ticker.info
    except Exception as e:
        print(f"Error fetching key financial ratios for {symbol}: {e}")
        return {}
236 |
@tool
def get_analyst_recommendations(symbol: str) -> pd.DataFrame:
    """
    Get analyst recommendations for a given stock symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        pd.DataFrame: DataFrame containing analyst recommendations.
    """
    try:
        ticker = yf.Ticker(symbol)
        return ticker.recommendations
    except Exception as e:
        print(f"Error fetching analyst recommendations for {symbol}: {e}")
        return pd.DataFrame()
255 |
@tool
def get_dividend_data(symbol: str) -> pd.DataFrame:
    """
    Get dividend data for a given stock symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        pd.DataFrame: DataFrame containing dividend data.
    """
    try:
        return yf.Ticker(symbol).dividends
    except Exception as e:
        print(f"Error fetching dividend data for {symbol}: {e}")
        return pd.DataFrame()
274 |
@tool
def get_company_news(symbol: str) -> list:
    """
    Get company news and press releases for a given stock symbol.
    This function returns titles and url which need further scraping using other tools.

    Args:
        symbol (str): The stock symbol.

    Returns:
        list: List of news entries (dicts with title/link metadata) from yfinance.
    """
    try:
        # Ticker.news yields a list of dicts, so both paths now return a list
        # (the original fell back to an empty DataFrame, an inconsistent type).
        news = yf.Ticker(symbol).news
        return news
    except Exception as e:
        print(f"Error fetching company news for {symbol}: {e}")
        return []
293 |
@tool
def get_technical_indicators(symbol: str) -> pd.DataFrame:
    """
    Get technical indicators for a given stock symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        pd.DataFrame: DataFrame containing technical indicators.
    """
    try:
        ticker = yf.Ticker(symbol)
        return ticker.history(period="max")
    except Exception as e:
        print(f"Error fetching technical indicators for {symbol}: {e}")
        return pd.DataFrame()
311 |
@tool
def get_company_profile(symbol: str) -> dict:
    """
    Get company profile and overview for a given stock symbol.

    Args:
        symbol (str): The stock symbol.

    Returns:
        dict: Dictionary containing company profile and overview.
    """
    try:
        ticker = yf.Ticker(symbol)
        return ticker.info
    except Exception as e:
        print(f"Error fetching company profile for {symbol}: {e}")
        return {}
331 |
@tool
def search_10q(data):
    """
    Useful to search information from the latest 10-Q form for a
    given stock.
    The input to this tool should be a pipe (|) separated text of
    length two, representing the stock ticker you are interested and what
    question you have from it.
    For example, `AAPL|what was last quarter's revenue`.
    """
    # maxsplit=1: only the first pipe separates ticker from question, so a
    # question containing '|' no longer raises "too many values to unpack".
    stock, ask = data.split("|", 1)
    stock, ask = stock.strip(), ask.strip()
    queryApi = QueryApi(api_key=os.environ['SEC_API_API_KEY'])
    query = {
        "query": {
            "query_string": {
                "query": f"ticker:{stock} AND formType:\"10-Q\""
            }
        },
        "from": "0",
        "size": "1",
        "sort": [{ "filedAt": { "order": "desc" }}]
    }

    fillings = queryApi.get_filings(query)['filings']
    if len(fillings) == 0:
        return "Sorry, I couldn't find any filling for this stock, check if the ticker is correct."
    link = fillings[0]['linkToFilingDetails']
    inference_logger.info(f"Running embedding search on {link}")
    answer = embedding_search(link, ask)
    return answer
362 |
@tool
def search_10k(data):
    """
    Useful to search information from the latest 10-K form for a
    given stock.
    The input to this tool should be a pipe (|) separated text of
    length two, representing the stock ticker you are interested, what
    question you have from it.
    For example, `AAPL|what was last year's revenue`.
    """
    # maxsplit=1: only the first pipe separates ticker from question, so a
    # question containing '|' no longer raises "too many values to unpack".
    stock, ask = data.split("|", 1)
    stock, ask = stock.strip(), ask.strip()
    queryApi = QueryApi(api_key=os.environ['SEC_API_API_KEY'])
    query = {
        "query": {
            "query_string": {
                "query": f"ticker:{stock} AND formType:\"10-K\""
            }
        },
        "from": "0",
        "size": "1",
        "sort": [{ "filedAt": { "order": "desc" }}]
    }

    fillings = queryApi.get_filings(query)['filings']
    if len(fillings) == 0:
        return "Sorry, I couldn't find any filling for this stock, check if the ticker is correct."
    link = fillings[0]['linkToFilingDetails']
    inference_logger.info(f"Running embedding search on {link}")
    answer = embedding_search(link, ask)
    return answer
393 |
@tool
def get_historical_price(symbol, start_date, end_date):
    """
    Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.
    - symbol (str): Stock ticker symbol.
    - end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date
    - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date
    """

    ticker = yf.Ticker(symbol)
    history = ticker.history(start=start_date, end=end_date).reset_index()
    # Expose the closing price under the ticker symbol for downstream plotting.
    history[symbol] = history['Close']
    return history[['Date', symbol]]
408 |
@tool
def plot_price_over_time(symbol, start_date, end_date):
    """
    Plots the historical stock prices for a given symbol over a specified date range.
    - symbol (str): Stock ticker symbol.
    - start_date (str): Start date for the historical data in 'YYYY-MM-DD' format.
    - end_date (str): End date for the historical data in 'YYYY-MM-DD' format.
    """

    # NOTE(review): get_historical_price is decorated with @tool; calling it like a
    # plain function may fail on langchain versions where tools are not directly
    # callable with positional args — confirm, or call get_historical_price.func(...).
    historical_price_df = get_historical_price(symbol, start_date, end_date)

    # Create a Plotly figure
    fig = go.Figure()

    # Add a trace for the stock symbol
    fig.add_trace(go.Scatter(x=historical_price_df['Date'], y=historical_price_df[symbol], mode='lines+markers', name=symbol))

    # Update the layout to add titles and format axis labels
    fig.update_layout(
        title=f'Stock Price Over Time: {symbol}',
        xaxis_title='Date',
        yaxis_title='Stock Price (USD)',
        yaxis_tickprefix='$',
        yaxis_tickformat=',.2f',
        xaxis=dict(
            tickangle=-45,
            nticks=20,
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            showgrid=True, # Enable y-axis grid lines
            gridcolor='lightgrey', # Set grid line color
        ),
        plot_bgcolor='gray', # Set plot background to gray (original comment said 'white')
        paper_bgcolor='gray', # Set overall figure background to gray (original comment said 'white')
    )

    # Show the figure
    st.plotly_chart(fig, use_container_width=True)
448 |
def get_function_names():
    """Return the function objects in this module decorated with @tool.

    Walks the module's AST so undecorated helpers (like this one) are skipped.
    Fix: the original only matched bare `@tool` decorators (ast.Name), silently
    omitting decorator *calls* such as `@tool("Make a calculation")` (ast.Call),
    which dropped the `calculate` tool from the exported list.
    """
    current_module = sys.modules[__name__]
    module_source = inspect.getsource(current_module)

    tree = ast.parse(module_source)
    tool_functions = []

    def _is_tool_decorator(dec) -> bool:
        # Matches both `@tool` and `@tool(...)`.
        if isinstance(dec, ast.Name):
            return dec.id == 'tool'
        if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name):
            return dec.func.id == 'tool'
        return False

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and any(_is_tool_decorator(d) for d in node.decorator_list):
            tool_functions.append(getattr(current_module, node.name))

    return tool_functions
462 |
def get_openai_tools() -> List[dict]:
    """Convert every @tool function in this module to OpenAI tool-call format."""
    return [convert_to_openai_tool(fn) for fn in get_function_names()]
468 |
469 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import datetime
3 | import importlib
4 | import inspect
5 | import json
6 | import logging
7 | from logging.handlers import RotatingFileHandler
8 | import os
9 | import re
10 | import sys
11 | from typing import List
12 | import xml.etree.ElementTree as ET
13 | import requests
14 |
15 | from langchain.text_splitter import CharacterTextSplitter
16 | from langchain_community.embeddings import HuggingFaceEmbeddings
17 | from langchain_community.vectorstores import FAISS
18 | from unstructured.partition.html import partition_html
19 |
20 |
# Console logging: timestamped, with source file and line for traceability.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
script_dir = os.path.dirname(os.path.abspath(__file__))
now = datetime.datetime.now()
# One log file per run, kept under src/inference_logs/.
log_folder = os.path.join(script_dir, "inference_logs")
os.makedirs(log_folder, exist_ok=True)
log_file_path = os.path.join(
    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
)
# Use RotatingFileHandler from the logging.handlers module
# (maxBytes=0 disables size-based rotation, so this acts like a plain FileHandler).
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
file_handler.setFormatter(formatter)

# Shared logger used across the project for inference-related events.
inference_logger = logging.getLogger("function-calling-inference")
inference_logger.addHandler(file_handler)
42 |
def get_tool_names():
    """Return the names of classes from src.tools visible in this module's namespace."""
    # Get all the classes defined in the script
    classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)

    # Extract the class names excluding the imported ones
    # NOTE(review): this keeps only classes whose __module__ is 'src.tools', but
    # this module does not import any classes from src.tools, so this likely
    # always returns an empty list — confirm intended behavior.
    class_names = [cls[0] for cls in classes if cls[1].__module__ == 'src.tools']

    return class_names
51 |
def get_fewshot_examples(num_fewshot):
    """return a list of few shot examples"""
    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    # Load the examples from the JSON file alongside this module.
    with open(example_path, 'r') as fh:
        examples = json.load(fh)
    if len(examples) < num_fewshot:
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
    return examples[:num_fewshot]
60 |
def validate_and_extract_tool_calls(assistant_content):
    """Extract scratchpad text and tool-call JSON payloads from a model reply.

    Args:
        assistant_content (str): Raw assistant message text.

    Returns:
        tuple: (validation_result, tool_calls, scratchpad_text, error_message)
            validation_result (bool): True if at least one tool call parsed as JSON.
            tool_calls (list): Parsed JSON objects from each <tool_call> block.
            scratchpad_text (str | None): Contents of the <scratchpad> block, if any.
            error_message (str | None): Description of the last failure, if any.
    """
    validation_result = False
    tool_calls = []
    error_message = None
    scratchpad_text = None

    try:
        # Fix: the pattern r'(.*?)' had lost its XML tags (it only matches the
        # empty string); restore the Hermes-style <scratchpad> tags.
        scratchpad_pattern = r'<scratchpad>(.*?)</scratchpad>'
        scratchpad_match = re.search(scratchpad_pattern, assistant_content, re.DOTALL)
        if scratchpad_match:
            scratchpad_text = scratchpad_match.group(1).strip()

        # Find all <tool_call> blocks and their JSON contents.
        tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
        tool_call_matches = re.findall(tool_call_pattern, assistant_content, re.DOTALL)

        if not tool_call_matches:
            error_message = None
        else:
            for match in tool_call_matches:
                json_text = match.strip()

                try:
                    json_data = json.loads(json_text)
                    tool_calls.append(json_data)
                    validation_result = True
                except json.JSONDecodeError as json_err:
                    error_message = f"JSON parsing failed:\n"\
                        f"- JSON Decode Error: {json_err}\n"\
                        f"- Problematic JSON text: {json_text}"
                    inference_logger.error(error_message)
                    continue

    except Exception as err:
        error_message = f"Error during tool call extraction: {err}"
        inference_logger.error(error_message)

    return validation_result, tool_calls, scratchpad_text, error_message
100 |
def extract_json_from_markdown(text):
    """
    Extracts the JSON string from the given text using a regular expression pattern.

    Args:
        text (str): The input text containing the JSON string.

    Returns:
        dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
    """
    # Look for a fenced ```json ... ``` block (tolerating \r\n line endings).
    match = re.search(r'```json\r?\n(.*?)\r?\n```', text, re.DOTALL)
    if not match:
        print("JSON string not found in the text.")
        return None
    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON string: {e}")
        return None
123 |
def embedding_search(url, query):
    """Download an SEC filing from `url` and answer `query` via FAISS retrieval.

    The filing HTML is partitioned, chunked (1000 chars, 150 overlap), embedded
    with a MiniLM sentence transformer, indexed in FAISS, and the top chunks
    retrieved for `query` are concatenated into the returned string.
    """
    text = download_form_html(url)
    elements = partition_html(text=text)
    content = "\n".join([str(el) for el in elements])
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len,
        is_separator_regex=False,
    )
    docs = text_splitter.create_documents([content])

    # Load a pre-trained sentence transformer model
    embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    # Create FAISS index and retriever
    index = FAISS.from_documents(docs, embedding_model)
    retriever = index.as_retriever()

    answers = retriever.invoke(query, top_k=4)
    chunks = []
    for i, doc in enumerate(answers):
        # NOTE(review): this f-string looks like it once wrapped each chunk in
        # markup tags that were stripped from the source — confirm intended format.
        chunk = f"\n\n{doc.page_content}\n\n"
        chunks.append(chunk)

    result = "".join(chunks)
    return f"\n{result}"
152 |
def download_form_html(url):
    """Fetch the raw HTML of a filing page, sending browser-like request headers."""
    headers = {
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
        'Accept-Encoding': 'gzip, deflate, br',
        'Accept-Language': 'en-US,en;q=0.9,pt-BR;q=0.8,pt;q=0.7',
        'Cache-Control': 'max-age=0',
        'Dnt': '1',
        'Sec-Ch-Ua': '"Not_A Brand";v="8", "Chromium";v="120"',
        'Sec-Ch-Ua-Mobile': '?0',
        'Sec-Ch-Ua-Platform': '"macOS"',
        'Sec-Fetch-Dest': 'document',
        'Sec-Fetch-Mode': 'navigate',
        'Sec-Fetch-Site': 'none',
        'Sec-Fetch-User': '?1',
        'Upgrade-Insecure-Requests': '1',
        'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
    }

    resp = requests.get(url, headers=headers)
    return resp.text
--------------------------------------------------------------------------------
/src/validator.py:
--------------------------------------------------------------------------------
import ast
import json

from jsonschema import validate
from jsonschema.exceptions import ValidationError as JSONSchemaValidationError
from pydantic import ValidationError

from utils import inference_logger, extract_json_from_markdown
from schema import FunctionCall, FunctionSignature
7 |
def validate_function_call_schema(call, signatures):
    """Validate a function-call dict against a list of function signatures.

    Args:
        call (dict): The call payload ({"name": ..., "arguments": {...}}).
        signatures (list): Candidate function-signature dicts.

    Returns:
        tuple: (valid (bool), error_message (str | None)).
    """
    try:
        call_data = FunctionCall(**call)
    except ValidationError as e:
        return False, str(e)

    signature_error = None
    for signature in signatures:
        try:
            signature_data = FunctionSignature(**signature)
        except Exception as e:
            # Fix: a malformed signature elsewhere in the list should not abort
            # the search for the matching one (the original returned immediately).
            signature_error = str(e)
            continue

        if signature_data.function.name != call_data.name:
            continue

        try:
            # Validate types in function arguments
            for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
                if arg_name in call_data.arguments:
                    call_arg_value = call_data.arguments[arg_name]
                    if call_arg_value:
                        try:
                            validate_argument_type(arg_name, call_arg_value, arg_schema)
                        except Exception as arg_validation_error:
                            return False, str(arg_validation_error)

            # Check if all required arguments are present
            required_arguments = signature_data.function.parameters.get('required', [])
            result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
            if not result:
                return False, f"Missing required arguments: {missing_arguments}"

            return True, None
        except Exception as e:
            # Validation errors on the matching signature are fatal.
            return False, str(e)

    # No matching function signature found
    if signature_error:
        return False, signature_error
    return False, f"No matching function signature found for function: {call_data.name}"
41 |
def check_required_arguments(call_arguments, required_arguments):
    """Return (all_present, missing_names) for the required argument names."""
    missing = []
    for required in required_arguments:
        if required not in call_arguments:
            missing.append(required)
    return len(missing) == 0, missing
45 |
def validate_enum_value(arg_name, arg_value, enum_values):
    """Raise if `arg_value` is not one of the allowed `enum_values`."""
    if arg_value in enum_values:
        return
    allowed = ', '.join(map(str, enum_values))
    raise Exception(
        f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {allowed}"
    )
51 |
def validate_argument_type(arg_name, arg_value, arg_schema):
    """Validate one argument value against its JSON-schema fragment.

    Checks enum membership for string enums, then the Python type mapped from
    the JSON-schema type. Raises Exception on any mismatch.
    """
    arg_type = arg_schema.get('type', None)
    if arg_type:
        if arg_type == 'string' and 'enum' in arg_schema:
            enum_values = arg_schema['enum']
            if None not in enum_values and enum_values != []:
                try:
                    validate_enum_value(arg_name, arg_value, enum_values)
                except Exception as e:
                    # Propagate the validation error message
                    raise Exception(f"Error validating function call: {e}")

        # Fix: bool is a subclass of int in Python, so True/False previously
        # passed the 'integer'/'number' checks; JSON schema treats booleans as
        # a distinct type.
        if isinstance(arg_value, bool) and arg_type in ('integer', 'number'):
            raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")

        python_type = get_python_type(arg_type)
        if not isinstance(arg_value, python_type):
            raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
67 |
def get_python_type(json_type):
    """Map a JSON-schema type name to the matching Python type (or type tuple)."""
    return {
        'string': str,
        'number': (int, float),
        'integer': int,
        'boolean': bool,
        'array': list,
        'object': dict,
        'null': type(None),
    }[json_type]
79 |
def validate_json_data(json_object, json_schema):
    """Parse `json_object` (a string) and validate it against `json_schema`.

    Parsing falls back from json.loads to ast.literal_eval to markdown-fenced
    extraction. Lists are validated item-by-item.

    Returns:
        tuple: (valid (bool), result_json, error_message (str | None)).
    """
    valid = False
    error_message = None
    result_json = None

    try:
        # Attempt to load JSON using json.loads
        try:
            result_json = json.loads(json_object)
        except json.decoder.JSONDecodeError:
            # If json.loads fails, try ast.literal_eval
            try:
                result_json = ast.literal_eval(json_object)
            except (SyntaxError, ValueError):
                try:
                    result_json = extract_json_from_markdown(json_object)
                except Exception as e:
                    error_message = f"JSON decoding error: {e}"
                    inference_logger.info(f"Validation failed for JSON data: {error_message}")
                    return valid, result_json, error_message

        # Return early if both json.loads and ast.literal_eval fail
        if result_json is None:
            error_message = "Failed to decode JSON data"
            inference_logger.info(f"Validation failed for JSON data: {error_message}")
            return valid, result_json, error_message

        # Validate each item in the list against schema if it's a list
        if isinstance(result_json, list):
            for index, item in enumerate(result_json):
                try:
                    validate(instance=item, schema=json_schema)
                    inference_logger.info(f"Item {index+1} is valid against the schema.")
                except JSONSchemaValidationError as e:
                    # Fix: the original caught pydantic's ValidationError here, which
                    # jsonschema.validate never raises, so schema violations fell
                    # through to the generic "Error occurred" handler below.
                    error_message = f"Validation failed for item {index+1}: {e}"
                    break
        else:
            # Default to validation without list
            try:
                validate(instance=result_json, schema=json_schema)
            except JSONSchemaValidationError as e:
                error_message = f"Validation failed: {e}"

    except Exception as e:
        error_message = f"Error occurred: {e}"

    if error_message is None:
        valid = True
        inference_logger.info("JSON data is valid against the schema.")
    else:
        inference_logger.info(f"Validation failed for JSON data: {error_message}")

    return valid, result_json, error_message
--------------------------------------------------------------------------------
/test_main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | from typing import Any, Callable, Dict, List, Optional
4 | from openai import OpenAI
5 | import requests
6 | from bs4 import BeautifulSoup
7 | from datetime import datetime
8 | from tiktoken import get_encoding
9 | import pickle
10 | import requests
11 | from bs4 import BeautifulSoup
12 | import PyPDF2
13 | import os
14 | import hashlib
15 | from src.resources import Resources
16 | import argparse
17 | from datetime import datetime
18 |
19 |
20 | from src.rag_tools import TextReaderTool, WebScraperTool, SemanticAnalysisTool, NERExtractionTool, SemanticFileSearchTool, WikipediaSearchTool
21 | from src import agents
22 | from src.agents import Agent
23 | from src import tasks
24 | from src.tasks import Task
25 |
class Resources:
    """Lightweight local stand-in for a resource definition (type, path, template).

    Note: this shadows the Resources class imported from src.resources above.
    """

    def __init__(self, resource_type, path, template):
        self.resource_type, self.path, self.template = resource_type, path, template
31 |
class Agent:
    """Minimal Agent stub: a role, its tools, and an interaction log.

    Note: this shadows the Agent class imported from src.agents above.
    """

    def __init__(self, role, tools):
        self.role, self.tools = role, tools
        # Simulated interactions log
        self.interactions = []
37 |
class Task:
    """Minimal Task stub carrying instructions, an agent, and an output slot.

    Note: this shadows the Task class imported from src.tasks above.
    """

    def __init__(self, instructions, agent, tool_name=None):
        self.id = str(uuid.uuid4())
        self.instructions = instructions
        self.agent = agent
        self.tool_name = tool_name
        self.output = None

    def execute(self, context):
        # Placeholder for task execution logic
        return f"Executed {self.instructions} using {self.agent.role}"
49 |
class Squad:
    """Runs a sequence of tasks through their agents, accumulating context and logs."""

    def __init__(self, agents: List[Agent], tasks: List[Task], resources: List[Resources], verbose: bool = False, log_file: str = "squad_log.json"):
        self.id = str(uuid.uuid4())
        self.agents = agents
        self.tasks = tasks
        self.resources = resources
        self.verbose = verbose
        self.log_file = log_file
        self.log_data = []
        self.llama_logs = []

    def _log_event(self, event_type, task, content):
        # One structured entry per task input/output.
        self.log_data.append({
            "timestamp": datetime.now().isoformat(),
            "type": event_type,
            "agent_role": task.agent.role,
            "task_name": task.instructions,
            "task_id": task.id,
            "content": content
        })

    def run(self, inputs: Optional[Dict[str, Any]] = None) -> str:
        """Execute all tasks in order, feeding each the accumulated context."""
        context = ""
        for task in self.tasks:
            if self.verbose:
                print(f"Starting Task:\n{task.instructions}")

            self._log_event("input", task, task.instructions)

            output = task.execute(context=context)
            task.output = output

            if self.verbose:
                print(f"Task output:\n{output}\n")

            self._log_event("output", task, output)

            context += f"Task:\n{task.instructions}\nOutput:\n{output}\n\n"

        self.save_logs()
        return context

    def save_logs(self):
        """Dump the accumulated log entries to the configured JSON file."""
        with open(self.log_file, "w") as fh:
            json.dump(self.log_data, fh, indent=2)
99 |
def load_configuration(file_path):
    """Load a JSON configuration file and return the parsed object.

    Args:
        file_path: Path to the configuration JSON file.

    Returns:
        The deserialized configuration (typically a dict).
    """
    # Explicit encoding: JSON is UTF-8 by convention, regardless of locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
103 |
def initialize_resources(config):
    """Build a Resources object for every entry in ``config["resources"]``.

    Args:
        config: Parsed configuration dict whose "resources" key holds a list
            of entries with "type", "path", and "template" fields.

    Returns:
        List of Resources instances, in configuration order.
    """
    # Comprehension instead of a manual append loop (same order, same items).
    return [
        Resources(res['type'], res['path'], res['template'])
        for res in config["resources"]
    ]
109 |
def initialize_agents_and_tasks(config):
    """Instantiate the agents and tasks declared in the configuration.

    Args:
        config: Parsed configuration dict with "agents" and "tasks" lists of
            keyword-argument dicts.

    Returns:
        Tuple ``(agents, tasks)``, each in configuration order.
    """
    agent_list = [Agent(**spec) for spec in config['agents']]
    task_list = [Task(**spec) for spec in config['tasks']]
    return agent_list, task_list
114 |
def parse_args():
    """Parse CLI options: required -c/--config path, optional -v/--verbose."""
    parser = argparse.ArgumentParser(
        description="Run the squad with dynamic configurations."
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        required=True,
        help="Path to configuration JSON file",
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help="Enable verbose output",
    )
    return parser.parse_args()
120 |
def mainflow():
    """Top-level flow: parse CLI args, load the configuration, assemble the
    squad, run it, and print the final output."""
    cli_args = parse_args()
    configuration = load_configuration(cli_args.config)

    resource_list = initialize_resources(configuration)
    agent_list, task_list = initialize_agents_and_tasks(configuration)

    # Timestamped log-file name so repeated runs never clobber each other.
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    squad = Squad(
        agents=agent_list,
        tasks=task_list,
        resources=resource_list,
        verbose=cli_args.verbose,
        log_file="squad_goals_" + stamp + ".json",
    )

    print(f"Final output:\n{squad.run()}")
138 |
# Script entry point: run the workflow only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    mainflow()
141 |
--------------------------------------------------------------------------------