├── .env.example ├── .gitignore ├── README.md ├── config.json ├── configs ├── agents.json ├── example.txt ├── resources.json ├── sys_prompt.yaml ├── tasks.json └── tools.json ├── main.py ├── requirements.txt ├── src ├── __init__.py ├── agents.py ├── clients.py ├── inference.py ├── prompter.py ├── rag_tools.py ├── resources.py ├── schema.py ├── tasks.py ├── tools.py ├── utils.py └── validator.py └── test_main.py /.env.example: -------------------------------------------------------------------------------- 1 | ANTHROPIC_API_KEY=xxx 2 | ANTHROPIC_MODEL=claude-3-opus-20240229 3 | 4 | GROQ_API_KEY=xxx 5 | GROQ_MODEL=llama3-70b-8192 6 | 7 | SEC_API_API_KEY=xxx 8 | 9 | OLLAMA_MODEL=interstellarninja/hermes-2-theta-llama-3-8b 10 | #OLLAMA_MODEL=interstellarninja/hermes-2-pro-llama-3-8b 11 | 12 | LMSTUDIO_MODEL=NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF 13 | #LMSTUDIO_MODEL=NousResearch/Hermes-2-Pro-Mistral-7B-GGUF 14 | 15 | ORCHESTRATOR_CLIENT=ollama 16 | AGENT_CLIENT=ollama 17 | 18 | LOCAL_MODEL_PATH=NousResearch/Hermes-2-Pro-Llama-3-8B 19 | LOAD_IN_4BIT=True 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Virtuel environment 2 | venv 3 | 4 | # Python 5 | *.pyc 6 | 7 | # Logs 8 | *.log -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A framework for orchestrating AI agents using a mermaid graph & networkx 2 | 3 | Example: 4 | ``` 5 | streamlit run main.py 6 | ``` 7 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resources": [ 3 | { 4 | "type": "text", 5 | "path": "inputs/press_release.txt", 6 | "template": "Here are your thoughts on the statement '{chunk}' from the file 
'{file}' (start: {start}, end: {end}):" 7 | }, 8 | { 9 | "type": "pdf", 10 | "path": "inputs/sec_filings_10k.pdf", 11 | "template": "The following excerpt is from the PDF '{file}' (start: {start}, end: {end}):\\n{chunk}" 12 | }, 13 | { 14 | "type": "web", 15 | "path": "https://blogs.nvidia.com/", 16 | "template": "The following content is scraped from the web page '{file}':\\n{chunk}" 17 | } 18 | ], 19 | "agents": [ 20 | { 21 | "name": "Researcher", 22 | "role": "Text Analysis" 23 | }, 24 | { 25 | "name": "Web Analyzer", 26 | "role": "Web Scraping" 27 | } 28 | ], 29 | "tasks": [ 30 | { 31 | "name": "Text Task", 32 | "agent": "Researcher", 33 | "resource": "Text" 34 | }, 35 | { 36 | "name": "Web Scrape Task", 37 | "agent": "Web Analyzer", 38 | "resource": "Web" 39 | } 40 | ] 41 | } 42 | -------------------------------------------------------------------------------- /configs/agents.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "Financial Analyst", 4 | "goal": "Impress all customers with your financial data and market trends analysis", 5 | "persona": "The most seasoned financial analyst with lots of expertise in stock market analysis and investment strategies that is working for a super important customer.", 6 | "tools": [ 7 | "search_10q", 8 | "search_10k", 9 | "get_stock_fundamentals", 10 | "get_financial_statements" ] 11 | }, 12 | { 13 | "role": "Research Analyst", 14 | "goal": "Being the best at gather, interpret data and amaze your customer with it", 15 | "persona": "Known as the BEST research analyst, you're skilled in sifting through news, company announcements, and market sentiments. 
Now you're working on a super important customer", 16 | "tools": [ 17 | "google_search_and_scrape", 18 | "get_company_news" 19 | ] 20 | }, 21 | { 22 | "role": "Investment Advisor", 23 | "goal": "Impress your customers with full analyses over stocks and completer investment recommendations", 24 | "persona": "You're the most experienced investment advisor and you combine various analytical insights to formulate strategic investment advice. You are now working for a super important customer you need to impress.", 25 | "tools": [] 26 | }, 27 | { 28 | "role": "Summarizer", 29 | "persona": "You are a skilled Data Analyst with a knack for distilling complex information into concise summaries.", 30 | "goal": "Compile a summary report based on the extracted information." 31 | } 32 | ] -------------------------------------------------------------------------------- /configs/example.txt: -------------------------------------------------------------------------------- 1 | 2 | To fulfill this query, we will need to: 3 | 1. Gather financial statements and key financial ratios for NVDA using the Financial Analyst agent 4 | 2. Gather news, company announcements, and market sentiments related to NVDA using the Research Analyst agent 5 | 3. Have the Investment Advisor agent analyze the data collected by the Financial Analyst and Research Analyst to formulate investment recommendations 6 | 4. Have the Summarizer agent compile a final summary report with the analysis and recommendations 7 | The Financial Analyst and Research Analyst can work in parallel to gather data. Once they are done, the Investment Advisor will analyze the data. Finally, the Summarizer will generate the report. 
8 | 9 | 10 | [ 11 | { 12 | "role": "Financial Analyst", 13 | "goal": "Gather financial statements and key financial ratios for NVDA", 14 | "persona": "The most seasoned financial analyst with lots of expertise in stock market analysis and investment strategies that is working for a super important customer.", 15 | "tools": ["get_stock_fundamentals", "get_financial_statements", "get_key_financial_ratios"], 16 | "dependencies": [] 17 | }, 18 | { 19 | "role": "Research Analyst", 20 | "goal": "Gather news, company announcements, and market sentiments related to NVDA", 21 | "persona": "Known as the BEST research analyst, you're skilled in sifting through news, company announcements, and market sentiments. Now you're working on a super important customer", 22 | "tools": ["google_search_and_scrape", "get_company_news"], 23 | "dependencies": [] 24 | }, 25 | { 26 | "role": "Investment Advisor", 27 | "goal": "Analyze the collected data and provide investment recommendations", 28 | "persona": "You're the most experienced investment advisor and you combine various analytical insights to formulate strategic investment advice. 
You are now working for a super important customer you need to impress.", 29 | "tools": [], 30 | "dependencies": ["Financial Analyst", "Research Analyst"] 31 | }, 32 | { 33 | "role": "Summarizer", 34 | "persona": "You are a skilled Data Analyst with a knack for distilling complex information into concise summaries.", 35 | "goal": "Compile a summary report based on the extracted information.", 36 | "tools": ["speak_to_the_user", "get_current_stock_price"], 37 | "dependencies": ["Investment Advisor"] 38 | } 39 | ] 40 | 41 | 42 | graph TD; 43 | A[Financial Analyst] --> C[Investment Advisor]; 44 | B[Research Analyst] --> C[Investment Advisor]; 45 | C[Investment Advisor] --> D[Summarizer]; 46 | -------------------------------------------------------------------------------- /configs/resources.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "type": "text", 4 | "path": "inputs/press_release.txt", 5 | "template": "Here are your thoughts on the statement '{chunk}' from the file '{file}' (start: {start}, end: {end}):" 6 | }, 7 | { 8 | "type": "pdf", 9 | "path": "inputs/sec_filings_10k.pdf", 10 | "template": "The following excerpt is from the PDF '{file}' (start: {start}, end: {end}):\\n{chunk}" 11 | }, 12 | { 13 | "type": "web", 14 | "path": "https://blogs.nvidia.com/", 15 | "template": "The following content is scraped from the web page '{file}':\\n{chunk}" 16 | } 17 | ] -------------------------------------------------------------------------------- /configs/sys_prompt.yaml: -------------------------------------------------------------------------------- 1 | Role: | 2 | You are a helpful AI assistant that generates an agent mermaid graph and agent metadata to help fulfill user queries. The current date is: {date}. 3 | Objective: | 4 | You may use agentic frameworks for reasoning and planning to help with user query. 
5 | Before generating the graph, use a to plan out which agents to dispatch and how they should interact to fulfill the query. 6 | You may dispatch agents in parallel, sequentially, or as a task graph, depending on what is optimal for the query. 7 | Agents: | 8 | Here are the sub agent personas you have available: 9 | 10 | {agents} 11 | 12 | Tools: | 13 | And here are the available tools: 14 | 15 | {tools} 16 | 17 | Schema: | 18 | When specifying each agent's metadata in the graph, follow this JSON schema: 19 | 20 | {schema} 21 | 22 | Instructions: | 23 | Return the final agent metadata JSON list inside tags. 24 | And return the final Mermaid graph code inside tags. -------------------------------------------------------------------------------- /configs/tasks.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instructions": "Analyze the provided text and identify key insights and patterns.", 4 | "expected_output": "A list of key insights and patterns found in the text.", 5 | "agent": "Researcher", 6 | "output_file": "txt_analyzed.txt", 7 | "tool_name": "text_reader" 8 | }, 9 | { 10 | "instructions": "Scrape the content from the provided URL and provide a summary.", 11 | "expected_output": "A summary of the scraped web content.", 12 | "agent": "Web Analyzer", 13 | "tool_name": "web_scraper", 14 | "output_file": "web_task_output.txt" 15 | }, 16 | { 17 | "instructions": "Analyze the provided system documentation and develop a comprehensive plan for enhancing system performance, reliability, and efficiency.", 18 | "expected_output": "A detailed plan outlining strategies and steps for optimizing the systems.", 19 | "agent": "Planner", 20 | "tool_name": "system_docs", 21 | "output_file": "system_plan.txt" 22 | }, 23 | { 24 | "instructions": "Search the provided files for information relevant to the given query.", 25 | "expected_output": "A list of relevant files with their similarity scores.", 26 | "agent": "Semantic 
Searcher", 27 | "tool_name": "semantic_search", 28 | "context": ["system_plan", "txt_task"] 29 | }, 30 | { 31 | "instructions": "Using the insights from the researcher and web analyzer, compile a summary report.", 32 | "expected_output": "A well-structured summary report based on the extracted information.", 33 | "agent": "Summarizer", 34 | "context": ["system_plan", "txt_task", "web_task"], 35 | "output_file": "task2_output.txt" 36 | }, 37 | { 38 | "instructions": "Analyze the sentiment of the extracted information.", 39 | "expected_output": "A sentiment analysis report based on the extracted information.", 40 | "agent": "Sentimentalizer", 41 | "context": ["summary", "txt_task"], 42 | "output_file": "sentimentalizer_output.txt", 43 | "tool_name": "sentiment_analysis" 44 | }, 45 | { 46 | "instructions": "Extract named entities from the summary report.", 47 | "expected_output": "A list of extracted named entities.", 48 | "agent": "Entity Extractor", 49 | "context": ["summary", "search_task"], 50 | "output_file": "ner_output.txt", 51 | "tool_name": "ner_extraction" 52 | }, 53 | { 54 | "instructions": "Generate a mermaid diagram based on the summary report.", 55 | "expected_output": "A mermaid graph illustrating the relationships and connections in the summary report.\n```mermaid\ngraph TD\n", 56 | "agent": "Mermaid", 57 | "context": ["summary", "txt_task", "search_task"], 58 | "output_file": "mermaid_output.txt" 59 | } 60 | ] -------------------------------------------------------------------------------- /configs/tools.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "type": "function", 4 | "function": { 5 | "name": "speak_to_the_user", 6 | "description": "speak_to_the_user(message: str) -> str - Prompts the user to provide more context or feedback through the terminal or Streamlit interface.\n\nArgs:\n prompt (str): The prompt or question to ask the user.\n\nReturns:\n str: The user's response to the prompt.", 7 
| "parameters": { 8 | "type": "object", 9 | "properties": { 10 | "message": { 11 | "type": "string" 12 | } 13 | }, 14 | "required": [ 15 | "message" 16 | ] 17 | } 18 | } 19 | }, 20 | { 21 | "type": "function", 22 | "function": { 23 | "name": "code_interpreter", 24 | "description": "code_interpreter(code_markdown: str) -> dict | str - Execute the provided Python code string on the terminal using exec.\n\n The string should contain valid, executable and pure Python code in markdown syntax.\n Code should also import any required Python packages.\n\n Args:\n code_markdown (str): The Python code with markdown syntax to be executed.\n For example: ```python\n\n```\n\n Returns:\n dict | str: A dictionary containing variables declared and values returned by function calls,\n or an error message if an exception occurred.\n\n Note:\n Use this function with caution, as executing arbitrary code can pose security risks.", 25 | "parameters": { 26 | "type": "object", 27 | "properties": { 28 | "code_markdown": { 29 | "type": "string" 30 | } 31 | }, 32 | "required": [ 33 | "code_markdown" 34 | ] 35 | } 36 | } 37 | }, 38 | { 39 | "type": "function", 40 | "function": { 41 | "name": "google_search_and_scrape", 42 | "description": "google_search_and_scrape(query: str) -> dict - Performs a Google search for the given query, retrieves the top search result URLs,\nand scrapes the text content and table data from those pages in parallel.\n\nArgs:\n query (str): The search query.\nReturns:\n list: A list of dictionaries containing the URL, text content, and table data for each scraped page.", 43 | "parameters": { 44 | "type": "object", 45 | "properties": { 46 | "query": { 47 | "type": "string" 48 | } 49 | }, 50 | "required": [ 51 | "query" 52 | ] 53 | } 54 | } 55 | }, 56 | { 57 | "type": "function", 58 | "function": { 59 | "name": "get_current_stock_price", 60 | "description": "get_current_stock_price(symbol: str) -> float - Get the current stock price for a given symbol.\n\nArgs:\n symbol 
(str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.", 61 | "parameters": { 62 | "type": "object", 63 | "properties": { 64 | "symbol": { 65 | "type": "string" 66 | } 67 | }, 68 | "required": [ 69 | "symbol" 70 | ] 71 | } 72 | } 73 | }, 74 | { 75 | "type": "function", 76 | "function": { 77 | "name": "get_stock_fundamentals", 78 | "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n dict: A dictionary containing fundamental data.\n Keys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", 79 | "parameters": { 80 | "type": "object", 81 | "properties": { 82 | "symbol": { 83 | "type": "string" 84 | } 85 | }, 86 | "required": [ 87 | "symbol" 88 | ] 89 | } 90 | } 91 | }, 92 | { 93 | "type": "function", 94 | "function": { 95 | "name": "get_financial_statements", 96 | "description": "get_financial_statements(symbol: str) -> dict - Get financial statements for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).", 97 | "parameters": { 98 | "type": "object", 99 | "properties": { 100 | "symbol": { 101 | "type": "string" 102 | } 103 | }, 104 | "required": [ 105 | "symbol" 106 | ] 107 | } 108 | } 109 | }, 110 | { 111 | "type": "function", 112 | 
"function": { 113 | "name": "get_key_financial_ratios", 114 | "description": "get_key_financial_ratios(symbol: str) -> dict - Get key financial ratios for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing key financial ratios.", 115 | "parameters": { 116 | "type": "object", 117 | "properties": { 118 | "symbol": { 119 | "type": "string" 120 | } 121 | }, 122 | "required": [ 123 | "symbol" 124 | ] 125 | } 126 | } 127 | }, 128 | { 129 | "type": "function", 130 | "function": { 131 | "name": "get_analyst_recommendations", 132 | "description": "get_analyst_recommendations(symbol: str) -> pandas.core.frame.DataFrame - Get analyst recommendations for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing analyst recommendations.", 133 | "parameters": { 134 | "type": "object", 135 | "properties": { 136 | "symbol": { 137 | "type": "string" 138 | } 139 | }, 140 | "required": [ 141 | "symbol" 142 | ] 143 | } 144 | } 145 | }, 146 | { 147 | "type": "function", 148 | "function": { 149 | "name": "get_dividend_data", 150 | "description": "get_dividend_data(symbol: str) -> pandas.core.frame.DataFrame - Get dividend data for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing dividend data.", 151 | "parameters": { 152 | "type": "object", 153 | "properties": { 154 | "symbol": { 155 | "type": "string" 156 | } 157 | }, 158 | "required": [ 159 | "symbol" 160 | ] 161 | } 162 | } 163 | }, 164 | { 165 | "type": "function", 166 | "function": { 167 | "name": "get_company_news", 168 | "description": "get_company_news(symbol: str) -> pandas.core.frame.DataFrame - Get company news and press releases for a given stock symbol.\nThis function returns titles and url which need further scraping using other tools.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press 
releases.", 169 | "parameters": { 170 | "type": "object", 171 | "properties": { 172 | "symbol": { 173 | "type": "string" 174 | } 175 | }, 176 | "required": [ 177 | "symbol" 178 | ] 179 | } 180 | } 181 | }, 182 | { 183 | "type": "function", 184 | "function": { 185 | "name": "get_technical_indicators", 186 | "description": "get_technical_indicators(symbol: str) -> pandas.core.frame.DataFrame - Get technical indicators for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing technical indicators.", 187 | "parameters": { 188 | "type": "object", 189 | "properties": { 190 | "symbol": { 191 | "type": "string" 192 | } 193 | }, 194 | "required": [ 195 | "symbol" 196 | ] 197 | } 198 | } 199 | }, 200 | { 201 | "type": "function", 202 | "function": { 203 | "name": "get_company_profile", 204 | "description": "get_company_profile(symbol: str) -> dict - Get company profile and overview for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing company profile and overview.", 205 | "parameters": { 206 | "type": "object", 207 | "properties": { 208 | "symbol": { 209 | "type": "string" 210 | } 211 | }, 212 | "required": [ 213 | "symbol" 214 | ] 215 | } 216 | } 217 | }, 218 | { 219 | "type": "function", 220 | "function": { 221 | "name": "search_10q", 222 | "description": "search_10q(data) - Useful to search information from the latest 10-Q form for a\n given stock.\n The input to this tool should be a pipe (|) separated text of\n length two, representing the stock ticker you are interested and what\n question you have from it.\n\t\tFor example, `AAPL|what was last quarter's revenue`.", 223 | "parameters": { 224 | "type": "object", 225 | "properties": { 226 | "data": {} 227 | }, 228 | "required": [ 229 | "data" 230 | ] 231 | } 232 | } 233 | }, 234 | { 235 | "type": "function", 236 | "function": { 237 | "name": "search_10k", 238 | "description": "search_10k(data) - Useful to 
search information from the latest 10-K form for a\ngiven stock.\nThe input to this tool should be a pipe (|) separated text of\nlength two, representing the stock ticker you are interested, what\nquestion you have from it.\nFor example, `AAPL|what was last year's revenue`.", 239 | "parameters": { 240 | "type": "object", 241 | "properties": { 242 | "data": {} 243 | }, 244 | "required": [ 245 | "data" 246 | ] 247 | } 248 | } 249 | }, 250 | { 251 | "type": "function", 252 | "function": { 253 | "name": "get_historical_price", 254 | "description": "get_historical_price(symbol, start_date, end_date) - Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.\n- symbol (str): Stock ticker symbol.\n- end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date\n- start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. 
import os
import re
import json
import argparse
from datetime import datetime
from typing import List  # required for the List[...] annotations below; was only
                         # reachable (if at all) via the star imports before

import networkx as nx
import streamlit as st
from matplotlib import pyplot as plt

from src.rag_tools import *
from src.prompter import PromptManager
from src.resources import Resource
from src.agents import Agent
from src.clients import CLIENTS
from src.tools import get_openai_tools, get_function_names
from src.utils import inference_logger

import logfire


class AgentOrchestrator:
    """Orchestrates a team of LLM agents.

    Asks an "orchestrator" LLM to plan a mermaid agent graph for a user query,
    wires the planned agents into a networkx DiGraph via their declared
    dependencies, executes them in topological order, and streams progress and
    outputs to the Streamlit UI.
    """

    def __init__(self, agents: List['Agent'], agents_config,
                 resources: List['Resource'], verbose: bool = False,
                 log_file: str = "orchestrator_log.json"):
        # Name of the LLM backend used for planning (e.g. "ollama", "anthropic").
        self.client = os.getenv('ORCHESTRATOR_CLIENT')
        self.agents = agents
        self.agents_config = agents_config
        self.resources = resources
        self.verbose = verbose
        self.log_file = log_file
        # NOTE(review): log_data is written out by save_logs() but nothing in
        # this file ever appends to it — the saved file is always "[]".
        self.log_data = []
        # Per-agent interaction transcripts, flushed by save_llama_logs().
        self.llama_logs = []

    def run(self, query: str) -> str:
        """Plan, build, and execute the agent graph for `query`.

        Returns:
            str: the concatenated role/goal/output report of every agent.
        """
        #tools = get_openai_tools()
        tools = get_function_names()

        # Each tool is expected to expose .name / .description attributes
        # (langchain-style tool objects) — presumably what
        # get_function_names() returns; verify against src.tools.
        tool_descriptions = [
            {"name": tool.name, "description": tool.description}
            for tool in tools
        ]

        mermaid_graph, agents_metadata = self.load_or_generate_graph(
            query, self.agents_config, tool_descriptions, self.resources)
        st.write(mermaid_graph)

        G = nx.DiGraph()
        for agent_data in agents_metadata:
            G.add_node(agent_data["role"], **agent_data)

        # Edge "dependency -> agent" encodes "dependency must run first".
        for agent_data in agents_metadata:
            agent_role = agent_data["role"]
            for dependency in agent_data.get("dependencies", []):
                G.add_edge(dependency, agent_role)

        self.visualize_graph(G)

        # role -> output text, filled in execution order.
        agent_outputs = {}

        # Topological order guarantees every predecessor has produced its
        # output before a dependent agent runs.
        for agent_role in nx.topological_sort(G):
            agent_data = G.nodes[agent_role]
            agent = Agent(**agent_data)

            st.write(f"Starting Agent: {agent.role}", unsafe_allow_html=True)
            if agent.verbose:
                st.write(f"Agent Persona: {agent.persona}", unsafe_allow_html=True)
                st.write(f"Agent Goal: {agent.goal}", unsafe_allow_html=True)

            # Feed every predecessor's output to this agent as context.
            agent.input_messages = [
                {"role": predecessor, "content": agent_outputs[predecessor]}
                for predecessor in G.predecessors(agent_role)
                if predecessor in agent_outputs
            ]

            output = agent.execute()

            if agent.verbose:
                st.write(f"Agent Output:\n{output}\n", unsafe_allow_html=True)

            agent_outputs[agent_role] = output
            self.llama_logs.extend(agent.interactions)

        # Collect the final output from all the agents.
        final_output = "\n".join(
            f"Agent: {role}\nGoal: {G.nodes[role]['goal']}\nOutput:\n{output}\n"
            for role, output in agent_outputs.items()
        )

        self.save_logs()
        self.save_llama_logs()

        return final_output

    def visualize_graph(self, G):
        """Render the agent dependency graph with matplotlib inside Streamlit."""
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=False, node_size=1000,
                node_color='lightblue', font_size=12, font_weight='bold',
                arrows=True)
        labels = nx.get_node_attributes(G, 'role')
        nx.draw_networkx_labels(G, pos, labels, font_size=12)
        plt.axis('off')
        plt.tight_layout()
        st.pyplot(plt)
        plt.close()  # release the figure so reruns do not accumulate figures

    def load_or_generate_graph(self, query, agents, tools, resources):
        """Return (mermaid_graph, agents_metadata), reusing cached files when
        both exist on disk, otherwise generating and caching them."""
        mermaid_graph_file = "mermaid_graph.txt"
        agent_metadata_file = "agent_metadata.json"

        if os.path.exists(mermaid_graph_file) and os.path.exists(agent_metadata_file):
            with open(mermaid_graph_file, "r") as file:
                mermaid_graph = file.read()
            with open(agent_metadata_file, "r") as file:
                agents_metadata = json.load(file)
        else:
            mermaid_graph = self.agent_dispatcher(query, agents, tools, resources)
            agents_metadata = self.extract_agents_from_mermaid(mermaid_graph)

            with open(mermaid_graph_file, "w") as file:
                file.write(mermaid_graph)
            with open(agent_metadata_file, "w") as file:
                json.dump(agents_metadata, file, indent=2)

        return mermaid_graph, agents_metadata

    def agent_dispatcher(self, query, agents, tools, resources):
        """Ask the orchestrator LLM to plan a mermaid agent graph for `query`.

        Returns:
            str: the raw assistant response (expected to contain the mermaid
            graph and the agent metadata JSON).
        """
        chat = [{"role": "user", "content": query}]
        prompter = PromptManager()
        sys_prompt = prompter.generate_prompt(tools, agents, resources, one_shot=True)

        inference_logger.info(f"Running inference with {self.client}")
        response = CLIENTS.chat_completion(
            client=self.client,
            messages=[
                {"role": "system", "content": sys_prompt},
                *chat
            ]
        )
        # The original logged the response twice (formatted and raw); once is
        # sufficient.
        inference_logger.info(f"Assistant Message:\n{response}")
        st.write(response)
        return response

    def extract_agents_from_mermaid(self, mermaid_graph):
        """Extract the agent metadata list (JSON) and dependency edges from
        the LLM's planning response.

        NOTE(review): both r'(.*?)' patterns match the empty string at
        position 0, so graph_content and metadata_content are always '' and
        json.loads('') raises. The original patterns almost certainly wrapped
        XML-style delimiter tags (cf. "inside tags" in configs/sys_prompt.yaml)
        that were stripped from this copy of the source — restore the tag
        names (e.g. r'<graph>(.*?)</graph>') before relying on this method.
        """
        graph_content = re.search(r'(.*?)', mermaid_graph, re.DOTALL)
        if graph_content:
            graph_content = graph_content.group(1)

        metadata_content = re.search(r'(.*?)', mermaid_graph, re.DOTALL)
        if metadata_content:
            metadata_content = metadata_content.group(1)

        # Mermaid edge syntax: "A --> B".
        dependency_pattern = r'(\w+) --> (\w+)'

        agents_metadata = json.loads(metadata_content)
        dependencies = []

        for match in re.finditer(dependency_pattern, graph_content):
            source = match.group(1)
            target = match.group(2)
            dependencies.append((source, target))

        # NOTE(review): `dependencies` is collected but never applied to the
        # metadata (the loop that did so is disabled below) — the agents'
        # self-declared "dependencies" fields are what run() actually uses.
        #for agent in agents_metadata:
        #    agent["dependencies"] = [target for source, target in dependencies if source == agent["role"]]

        return agents_metadata

    def save_llama_logs(self):
        """Write the accumulated agent interactions to a timestamped file."""
        with open(("qa_interactions" + datetime.now().strftime("%Y%m%d%H%M%S") + ".json"), "w") as file:
            json.dump(self.llama_logs, file, indent=2)

    def save_logs(self):
        """Write self.log_data to the configured log file."""
        with open(self.log_file, "w") as file:
            json.dump(self.log_data, file, indent=2)


def parse_args():
    """CLI argument parser.

    NOTE(review): unused by the Streamlit entry point below; kept for a
    possible CLI mode.
    """
    parser = argparse.ArgumentParser(description="Run the agent orchestrator with dynamic configurations.")
    parser.add_argument('-q', '--query', type=str, help="user query for agents to assist with", required=True)
    return parser.parse_args()


def mainflow():
    """Streamlit entry point: collect a question, load agent/resource configs,
    and run the orchestrator."""
    st.title("Stock Analysis with MeeseeksAI Agents")
    multiline_text = """
    Try to ask it "What is the current price of Meta stock?" or "Show me the historical prices of Apple vs Microsoft stock over the past 6 months.".
    """

    st.markdown(multiline_text, unsafe_allow_html=True)

    # Get the user's question.
    user_question = st.text_input("Ask a question about a stock or multiple stocks:")

    if user_question:
        file_path = os.getcwd()  # was os.path.join(os.getcwd()) — a no-op join
        with open(os.path.join(file_path, "configs/agents.json"), "r") as file:
            agents_data = json.load(file)
        agents = [Agent(**agent_data) for agent_data in agents_data]

        with open(os.path.join(file_path, "configs/resources.json"), "r") as file:
            resources_data = json.load(file)
        resources = [Resource(**resource_data) for resource_data in resources_data]

        orchestrator = AgentOrchestrator(
            agents=agents,
            agents_config=agents_data,
            resources=resources,
            verbose=True,
            log_file="orchestrator_log" + datetime.now().strftime("%Y%m%d%H%M%S") + ".json"
        )

        # NOTE(review): the return value (the final report) is discarded;
        # consider rendering it with st.write.
        orchestrator.run(user_question)


if __name__ == "__main__":
    mainflow()
/requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | PyYAML 3 | requests 4 | bs4 5 | tiktoken 6 | pandas 7 | yfinance 8 | PyPDF2 9 | textblob 10 | gpt4all 11 | spacy 12 | groq 13 | python-dotenv 14 | logfire 15 | sec_api 16 | unstructured 17 | sentence_transformers 18 | faiss-cpu 19 | streamlit 20 | plotly 21 | langchain 22 | anthropic -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interstellarninja/MeeseeksAI/432fac235a8521a84dce541be9561323df80766e/src/__init__.py -------------------------------------------------------------------------------- /src/agents.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Optional 3 | import uuid 4 | from pydantic import BaseModel, Field 5 | from datetime import datetime 6 | from src.clients import CLIENTS 7 | from src import tools 8 | from src.tools import * 9 | from src.rag_tools import * 10 | from src.utils import inference_logger 11 | from src.utils import validate_and_extract_tool_calls 12 | from langchain.tools import StructuredTool, BaseTool 13 | 14 | from langchain_core.messages import ToolMessage 15 | 16 | ## TODO: add default tools such as "get_user_feedback", "get_additional_context", "code_interpreter" etc. 
class Agent(BaseModel):
    """A function-calling LLM agent.

    Each agent owns a role/persona/goal, a set of named tools, and a client
    backend (resolved via src.clients.CLIENTS).  execute() runs an iterative
    tool-calling loop: prompt the model, parse <tool_call> style output with
    validate_and_extract_tool_calls, run the tools, feed results back, and
    repeat up to max_iter rounds.
    """

    class Config:
        arbitrary_types_allowed = True  # Allow arbitrary types
        exclude = {"client", "tool_objects"}

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    role: str
    persona: Optional[str] = None
    goal: str
    tools: List[str] = []
    dependencies: Optional[List[str]] = None
    user_feedback: bool = False
    verbose: bool = True
    model: str = Field(default_factory=lambda: os.getenv('AGENT_MODEL'))  # agent model from environment variable
    max_iter: int = 5
    client: str = Field(default_factory=lambda: os.getenv('AGENT_CLIENT'))
    tool_objects: Dict[str, Any] = {}
    input_messages: List[Dict] = []
    interactions: List[Dict] = []

    def __init__(self, **data: Any):
        """Validate fields, require a client backend, and resolve tool names."""
        super().__init__(**data)
        if not self.client:
            # client comes from AGENT_CLIENT env var by default; an empty
            # value means no backend was configured.
            raise ValueError("Invalid client specified.")
        self.tool_objects = self.create_tool_objects()

    def create_tool_objects(self) -> Dict[str, Any]:
        """Map each name in self.tools to the callable/tool object of the same
        name in this module's globals (populated by the star-imports above).

        Raises:
            ValueError: if a tool name has no matching global.
        """
        tool_objects = {}

        # Opt-in human-in-the-loop tool.  NOTE(review): this mutates the
        # shared self.tools list; calling twice would append the name twice.
        if self.user_feedback:
            self.tools.append('speak_to_the_user')

        for tool_name in self.tools:
            if tool_name in globals():
                tool_objects[tool_name] = globals()[tool_name]
            else:
                raise ValueError(f"Tool '{tool_name}' not found.")
        return tool_objects

    def execute(self) -> str:
        """Run the agent loop and return the final assistant message.

        Builds the system prompt (persona + tool schemas), appends any
        messages handed over from upstream agents, then alternates between
        model inference and tool execution for up to max_iter iterations.
        """
        messages = []
        if self.persona and self.verbose:
            current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            messages.append({"role": "system", "content": f"You are a {self.role} with the persona: {self.persona}. Current date and time: {current_datetime}"})

        # Check if the agent has available tools
        if self.tool_objects:
            #print(self.tool_objects)
            # Serialize the tool objects to JSON schema
            tool_schemas = []
            for tool_name, tool_object in self.tool_objects.items():
                # Only StructuredTool instances carry name/description/args_schema;
                # plain @tool functions are invoked via execute_function_call below.
                if isinstance(tool_object, StructuredTool):
                    tool_schema = {
                        "name": tool_object.name,
                        "description": tool_object.description,
                        "parameters": tool_object.args_schema.schema()
                    }
                    tool_schemas.append(tool_schema)

            # Append the tool schemas to the system prompt within tags
            system_prompt = "You are a function calling AI model."
            system_prompt += f"\nHere are the available tools:\n\n{json.dumps(tool_schemas, indent=2)}\n\n"
            system_prompt += f"You should call the tools provided to you sequentially\n"
            system_prompt += """
            Please use XML tags to record your reasoning and planning before you call the functions as follows:

            {step-by-step reasoning and plan in bullet points}

            For each function call return a json object with function name and arguments within XML tags as follows:

            {"arguments": , "name": }

            """
            messages.append({"role": "system", "content": system_prompt})

        # Hand-off context from upstream agents in the orchestration graph.
        if self.input_messages:
            for input_message in self.input_messages:
                role = input_message["role"]
                content = input_message["content"]
                inference_logger.info(f"Appending input messages from previous agent: {role}")
                messages.append({"role": "system", "content": f"\n<{role}>\n{content}\n\n"})

        messages.append({"role": "user", "content": f"Your task is to {self.goal}."})

        depth = 0
        while depth < self.max_iter:
            inference_logger.info(f"Running inference with {self.client}")
            result = CLIENTS.chat_completion(
                client=self.client,
                messages=messages,
            )
            inference_logger.info(f"Assistant Message:\n{result}")
            messages.append({"role": "assistant", "content": result})


            # Process the agent's response and extract tool calls
            if self.tool_objects:
                validation, tool_calls, scratchpad, error_message = validate_and_extract_tool_calls(result)

                if validation and tool_calls:
                    inference_logger.info(f"Parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")

                    # Print parsed tool calls as JSON markdown in Streamlit
                    if self.verbose:
                        st.write(f"Agent Plan:\n{scratchpad}\n", unsafe_allow_html=True)
                        st.markdown(f"**Parsed Tool Calls:**")
                        st.json(tool_calls)

                    # Execute the tool calls
                    tool_message = f"Sub-agent iteration {depth} to assist with user query: {self.goal}. Summarize the tool results:\n"
                    if tool_calls:
                        for tool_call in tool_calls:
                            tool_name = tool_call.get("name")
                            tool_object = self.tool_objects.get(tool_name)
                            if tool_object:
                                try:
                                    tool_args = tool_call.get("arguments", {})

                                    # Print invoking tool message in blue
                                    st.markdown(f"Invoking tool: {tool_name}", unsafe_allow_html=True)

                                    # langchain BaseTool instances run via _run;
                                    # plain functions dispatch by name through src.tools.
                                    tool_result = tool_object._run(**tool_args) if isinstance(tool_object, BaseTool) else self.execute_function_call(tool_call)

                                    tool_message += f"\n{tool_result}\n\n"
                                    inference_logger.info(f"Response from tool '{tool_name}':\n{tool_result}")

                                    # Print tool response in green
                                    #st.markdown(f"Tool response: {tool_result}", unsafe_allow_html=True)
                                except Exception as e:
                                    # Tool failures are surfaced back to the model
                                    # as context rather than aborting the loop.
                                    error_message = f"Error executing tool '{tool_name}': {str(e)}"
                                    tool_message += f"\n{error_message}\n\n"
                            else:
                                error_message = f"Tool '{tool_name}' not found."
                                tool_message += f"\n{error_message}\n\n"
                        messages.append({"role": "user", "content": tool_message})
                        #messages.append({"role": "tool", "content": tool_message})
                        #messages.append(ToolMessage(tool_message, tool_call_id=0))
                    else:
                        inference_logger.info(f"No tool calls found in the agent's response.")
                        break
                elif error_message:
                    # Parsing failed: echo the validator's error back to the
                    # model so it can correct its tool-call format.
                    inference_logger.info(f"Error parsing tool calls: {error_message}")
                    tool_message = f"\n{error_message}\n\n"
                    messages.append({"role": "user", "content": tool_message})
                    #messages.append({"role": "tool", "content": tool_message})
                    #messages.append(ToolMessage(tool_message, tool_call_id=0))

                depth += 1
            else:
                # No tools: a single inference round is the final answer.
                break

        # Log the final interaction
        self.log_interaction(messages, result)
        return result

    def execute_function_call(self, tool_call):
        """Invoke a plain function tool from src.tools by name.

        Returns a JSON-ish string of the form
        '{"name": ..., "content": ...}' for inclusion in the tool message.

        Raises:
            ValueError: if the named function does not exist in src.tools.
        """
        function_name = tool_call.get("name")
        function_to_call = getattr(tools, function_name, None)
        function_args = tool_call.get("arguments", {})

        if function_to_call:
            inference_logger.info(f"Invoking function call {function_name} ...")
            function_response = function_to_call(**function_args)
            results_dict = f'{{"name": "{function_name}", "content": {json.dumps(function_response)}}}'
            return results_dict
        else:
            raise ValueError(f"Function '{function_name}' not found.")

    def log_interaction(self, prompt, response):
        """Record one full exchange (messages + final response) on this agent."""
        self.interactions.append({
            "role": self.role,
            "messages": prompt,
            "response": response,
            "agent_messages": self.input_messages,
            "tools": self.tools,
            "timestamp": datetime.now().isoformat()
        })
class Clients:
    """Lazy registry of chat-completion backends.

    Each backend is initialized at most once, on first use, and cached in
    self.clients as {"client": <sdk client>, "model": <model name>}.
    """

    def __init__(self):
        # client name -> {"client": ..., "model": ...}; filled lazily.
        self.clients = {}

    def initialize_ollama(self):
        """Ollama's local server speaks the OpenAI wire protocol."""
        self.clients["ollama"] = {
            "client": OpenAI(
                base_url='http://localhost:11434/v1',
                api_key='ollama',
            ),
            "model": os.getenv("OLLAMA_MODEL")
        }

    def initialize_groq(self):
        self.clients["groq"] = {
            "client": Groq(
                api_key=os.getenv('GROQ_API_KEY')
            ),
            "model": os.getenv("GROQ_MODEL")
        }

    def initialize_anthropic(self):
        self.clients["anthropic"] = {
            "client": Anthropic(
                api_key=os.getenv("ANTHROPIC_API_KEY")
            ),
            "model": os.getenv("ANTHROPIC_MODEL")
        }

    def initialize_lmstudio(self):
        """LM Studio's local server also speaks the OpenAI wire protocol."""
        self.clients["lmstudio"] = {
            "client": OpenAI(
                base_url="http://localhost:1234/v1",
                #base_url="http://192.168.1.2:1234/v1",
                api_key="lm-studio"
            ),
            "model": os.getenv("LMSTUDIO_MODEL")
        }

    def initialize_localllama(self):
        # Imported lazily so torch/transformers are only required when this
        # backend is actually selected.
        from src.inference import ModelInference
        self.clients["localllama"] = {
            "client": ModelInference(
                model_path=os.getenv("LOCAL_MODEL_PATH"),
                load_in_4bit=os.getenv("LOAD_IN_4BIT", "False")
            ),
            "model": None
        }

    def chat_completion(self, client, messages):
        """Run one chat completion against the named backend.

        Args:
            client: backend name ('ollama', 'groq', 'anthropic', 'lmstudio',
                or 'localllama').
            messages: OpenAI-style list of {"role": ..., "content": ...} dicts.

        Returns:
            str: the assistant message content.

        Raises:
            ValueError: for an unknown backend name.
        """
        # BUG FIX: the backend was previously re-initialized on EVERY call,
        # recreating SDK clients and — for localllama — reloading the whole
        # model each time.  Initialize once and cache.
        if client not in self.clients:
            initializers = {
                "ollama": self.initialize_ollama,
                "groq": self.initialize_groq,
                "anthropic": self.initialize_anthropic,
                "lmstudio": self.initialize_lmstudio,
                "localllama": self.initialize_localllama,
            }
            if client not in initializers:
                raise ValueError(f"Unsupported client: {client}")
            initializers[client]()

        entry = self.clients[client]

        # BUG FIX: ModelInference has no OpenAI-style chat.completions API;
        # the localllama branch previously raised AttributeError.  It exposes
        # run_inference(prompt) instead (see src/inference.py).
        if client == "localllama":
            return entry["client"].run_inference(messages)

        # NOTE(review): the Anthropic SDK's chat API is client.messages.create(),
        # not chat.completions.create(); the "anthropic" branch looks like it
        # would fail at runtime as written — confirm before relying on it.
        response = entry["client"].chat.completions.create(
            model=entry["model"],
            messages=messages,
        )
        completion = response.choices[0].message.content

        return completion

CLIENTS = Clients()
class ModelInference:
    """Local HF-transformers chat backend with optional 4-bit quantization.

    Loads a causal LM + tokenizer from model_path and generates assistant
    replies from ChatML-formatted conversations.
    """

    def __init__(self, model_path, load_in_4bit):
        # load_in_4bit arrives as the string "True"/"False" from the
        # LOAD_IN_4BIT environment variable (see src/clients.py).
        self.bnb_config = None

        if load_in_4bit == "True":
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                llm_int8_enable_fp32_cpu_offload=True,
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            return_dict=True,
            quantization_config=self.bnb_config,
            torch_dtype=torch.float32,
            #attn_implementation="flash_attention_2",
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Left padding + eos-as-pad is the usual setup for decoder-only
        # generation with this tokenizer family.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        if self.tokenizer.chat_template is None:
            # Fall back to a ChatML template when the checkpoint ships none.
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

        inference_logger.info(self.model.config)
        inference_logger.info(self.model.generation_config)
        inference_logger.info(self.tokenizer.special_tokens_map)


    def run_inference(self, prompt):
        """Generate one assistant reply.

        Args:
            prompt: list of {"role": ..., "content": ...} chat messages
                (fed to tokenizer.apply_chat_template).

        Returns:
            str | None: the extracted assistant message, or None if no
            assistant segment was found in the generation.
        """
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors='pt'
        )

        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1500,
            temperature=0.8,
            repetition_penalty=1.1,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        # Special tokens are kept so the assistant segment can be located by
        # its <|im_start|> marker below.
        completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_space=True)

        assistant_message = self.get_assistant_message(completion, "chatml", self.tokenizer.eos_token)
        return assistant_message

    def get_assistant_message(self, completion, chat_template, eos_token):
        """define and match pattern to find the assistant message"""
        completion = completion.strip()
        if chat_template == "chatml":
            # Grab everything after the LAST '<|im_start|>assistant' marker.
            assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
        else:
            raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")

        assistant_match = assistant_pattern.search(completion)
        if assistant_match:
            assistant_content = assistant_match.group(1).strip()
            # Strip the trailing end-of-sequence token from the reply text.
            return assistant_content.replace(eos_token, "")
        else:
            assistant_content = None
            inference_logger.info("No match found for the assistant pattern")
            return assistant_content
    def generate_prompt(self, tools, agents, resources, one_shot=False):
        """Build the orchestrator system prompt from configs/sys_prompt.yaml.

        Args:
            tools: serialized tool descriptions interpolated into the prompt.
            agents: serialized agent descriptions interpolated into the prompt.
            resources: unused — the template currently receives resources=None.
            one_shot: when True, also reads configs/example.txt and appends a
                one-shot user/assistant exchange to the local `prompt` list.

        Returns:
            str: the formatted system prompt text.

        NOTE(review): the `prompt` message list (including the one-shot
        example) is built and printed but only `sys_prompt` is returned, so
        the one_shot exchange never reaches the caller — confirm whether
        callers expect the message list instead.
        """
        prompt_path = os.path.join(self.script_dir, '../configs', 'sys_prompt.yaml')
        prompt_schema = self.read_yaml_file(prompt_path)

        # JSON schema of the Agent model, injected so the orchestrator model
        # knows the structure to emit.
        schema_json = json.loads(Agent.schema_json())
        #schema = schema_json.get("properties", {})

        variables = {
            "date": datetime.date.today(),
            "agents": agents,
            "tools": tools,
            "resources": None,
            #"examples": examples,
            "schema": schema_json
        }
        sys_prompt = self.format_yaml_prompt(prompt_schema, variables)
        #print(sys_prompt)

        prompt = [
            {'role': 'system', 'content': sys_prompt}
        ]

        if one_shot:
            #examples = get_fewshot_examples(num_fewshot)
            with open(os.path.join(self.script_dir, '../configs', 'example.txt'), 'r') as file:
                examples = file.read()

            prompt.extend([
                {"role": "user", "content": "Perform fundamental analysis of NVDA stock and provide portfolio recommendations"},
                {"role": "assistant", "content": examples}
            ])
        print(prompt)
        return sys_prompt
class SemanticFileSearchTool(BaseTool):
    """Embeds a fixed set of files with Embed4All and answers queries by
    cosine similarity over the cached chunk embeddings."""

    name = "semantic_file_search"
    description = "Performs semantic search on a set of files based on a given query"
    args_schema: Type[BaseModel] = SemanticFileSearchInput

    def __init__(self, file_paths: List[str], embed_model: str, embed_dim: int = 768, chunk_size: int = 1000, top_k: int = 3):
        # NOTE(review): BaseTool is a pydantic model, and this __init__ never
        # calls super().__init__() yet assigns attributes (embedder, top_k, …)
        # that are not declared as fields — pydantic normally rejects that at
        # runtime.  Confirm this class actually instantiates.
        self.embedder = Embed4All(embed_model)
        self.embed_dim = embed_dim
        self.chunk_size = chunk_size
        self.top_k = top_k
        self.chunker = TextChunker(text=None, chunk_size=chunk_size)
        self.file_embeddings = self.load_or_generate_file_embeddings(file_paths)

    def load_or_generate_file_embeddings(self, file_paths: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Load cached embeddings from a pickle keyed by the file set, or
        compute and cache them."""
        file_hash = self.get_file_hash(file_paths)
        pickle_file = f"file_embeddings_{file_hash}.pickle"
        if os.path.exists(pickle_file):
            self.load_embeddings(pickle_file)
        else:
            self.file_embeddings = self.generate_file_embeddings(file_paths)
            self.save_embeddings(pickle_file)
        return self.file_embeddings

    def get_file_hash(self, file_paths: List[str]) -> str:
        """Cache key for the file set.

        NOTE(review): hashes only the sorted basenames, not file contents or
        mtimes — editing a file does NOT invalidate the cached embeddings.
        """
        file_contents = "".join(sorted([os.path.basename(path) for path in file_paths]))
        return hashlib.sha256(file_contents.encode()).hexdigest()

    def generate_file_embeddings(self, file_paths: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Read each file (PDF or plain text), chunk it, and embed each chunk
        with the 'search_document' prefix."""
        file_embeddings = {}
        for file_path in file_paths:
            if file_path.endswith('.pdf'):
                text = self.extract_text_from_pdf(file_path)
            else:
                with open(file_path, 'r') as file:
                    text = file.read()
            chunks = self.chunker.chunk_text(text=text, chunk_size=self.chunk_size)
            chunk_embeddings = [self.embedder.embed(chunk['text'], prefix='search_document') for chunk in chunks]
            # Stored as (chunk_text, embedding) pairs per file path.
            file_embeddings[file_path] = [(chunk['text'], embedding) for chunk, embedding in zip(chunks, chunk_embeddings)]
        return file_embeddings

    def extract_text_from_pdf(self, file_path: str) -> str:
        """Concatenate the extracted text of every PDF page."""
        with open(file_path, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"
            return text

    def save_embeddings(self, pickle_file: str):
        with open(pickle_file, 'wb') as f:
            pickle.dump(self.file_embeddings, f)

    def load_embeddings(self, pickle_file: str):
        # NOTE(review): pickle.load on a file from disk — safe only if the
        # cache file is trusted/local.
        with open(pickle_file, 'rb') as f:
            self.file_embeddings = pickle.load(f)

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        """Return the top_k chunks ranked by cosine similarity to the query."""
        query_embedding = self.embedder.embed(query, prefix='search_query')
        scores = []
        for file_path, chunk_data in self.file_embeddings.items():
            for chunk_text, embedding in chunk_data:
                chunk_score = self.cosine_similarity(query_embedding, embedding)
                scores.append(((file_path, chunk_text), chunk_score))
        sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
        top_scores = sorted_scores[:self.top_k]
        result = []
        for (file_path, chunk_text), score in top_scores:
            result.append({
                'file': file_path,
                'text': chunk_text,
                'score': score
            })
        return result

    def cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Standard cosine similarity; imports numpy locally to keep it an
        optional dependency of this method only."""
        import numpy as np
        a = np.array(a)
        b = np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    async def _arun(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        raise NotImplementedError("SemanticFileSearchTool does not support async")
class NERExtractionTool(BaseTool):
    """Extracts spaCy named entities (text span + character offsets + label)
    from input text."""

    name = "ner_extraction"
    description = "Extracts named entities from text using spaCy"
    args_schema: Type[BaseModel] = NERExtractionInput

    def __init__(self):
        # NOTE(review): BaseTool is a pydantic model; this __init__ skips
        # super().__init__() and assigns `nlp`, which is not a declared
        # field — pydantic normally rejects that.  Confirm instantiation works.
        self.nlp = spacy.load("en_core_web_sm")

    def _run(self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        """Return one dict per entity: {"text", "start", "end", "label"},
        with start/end as character offsets into the input."""
        doc = self.nlp(text)
        entities = []

        for ent in doc.ents:
            entities.append({
                "text": ent.text,
                "start": ent.start_char,
                "end": ent.end_char,
                "label": ent.label_
            })

        return entities

    async def _arun(self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> List[Dict[str, Any]]:
        raise NotImplementedError("NERExtractionTool does not support async")
class Resource(BaseModel):
    """A loadable input source — plain text file, PDF, or web page — that can
    be chunked and rendered into prompt context."""

    type: str  # one of: 'text', 'pdf', 'web'
    path: str  # filesystem path, or URL for type 'web'
    context_template: Optional[str] = None  # format string with {chunk}/{file}/{start}/{end}
    data: Optional[str] = None  # raw loaded content
    chunks: Optional[List[Dict[str, Any]]] = []  # chunk dicts: {"text", "start", "end"}

    def load_resource(self):
        """Load and return the raw content for this resource based on `type`.

        Raises:
            ValueError: if `type` is not 'text', 'pdf', or 'web'.
        """
        # BUG FIX: this class previously read self.resource_type and
        # self.resource_path throughout, but the declared fields are `type`
        # and `path` — every load raised AttributeError.
        if self.type == 'text':
            return self.load_text()
        elif self.type == 'pdf':
            return self.load_pdf()
        elif self.type == 'web':
            return self.load_web()
        else:
            raise ValueError(f"Unsupported resource type: {self.type}")

    def load_text(self):
        """Read the plain-text file at `path`."""
        with open(self.path, 'r') as file:
            return file.read()

    def load_pdf(self):
        """Extract and concatenate the text of every page of the PDF at `path`."""
        with open(self.path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"
            return text

    def load_web(self):
        """Fetch the raw body of the URL in `path`."""
        response = requests.get(self.path)
        return response.text

    def chunk_resource(self, chunk_size: int, overlap: int = 0):
        """Split self.data into token-based chunks, stored on self.chunks."""
        chunker = TextChunker(self.data, chunk_size, overlap)
        self.chunks = chunker.chunk_text()

    def contextualize_chunk(self, chunk: Dict[str, Any]) -> str:
        """Render a chunk through `context_template`, or return its bare text.

        Args:
            chunk: a dict with 'text', 'start', 'end' as produced by TextChunker.
        """
        if self.context_template:
            # NOTE(review): config.json supplies the key "template", not
            # "context_template" — confirm the loader maps it, otherwise this
            # branch never runs.
            return self.context_template.format(
                chunk=chunk['text'],
                file=self.path,
                start=chunk['start'],
                end=chunk['end']
            )
        else:
            return chunk['text']
class Task(BaseModel):
    """A unit of work assigned to an Agent, with optional upstream context
    tasks whose outputs are folded into this task's prompt."""

    id: str = None  # replaced with a fresh uuid4 string in __init__
    instructions: str
    expected_output: str
    agent: Optional[str] = None  # agent role name; resolved to an Agent object in __init__
    async_execution: bool = False
    # NOTE(review): annotated List[str] but execute() treats entries as Task
    # objects (task.output, task.agent.role) — confirm the intended type.
    context: Optional[List[str]] = None
    output_file: Optional[str] = None
    callback: Optional[Callable] = None
    human_input: bool = False
    tool_name: Optional[str] = None
    input_tasks: Optional[List["Task"]] = None
    output_tasks: Optional[List["Task"]] = None
    context_agent_role: str = None
    prompt_data: Optional[List] = None
    output: Optional[str] = None


    def __init__(self, **data: Any):
        super().__init__(**data)
        self.id = str(uuid.uuid4())
        self.agent = self.load_agent(self.agent)
        self.context = self.context or []
        self.prompt_data = []  # List to hold prompt data for logging
        self.output = None

    def execute(self, context: Optional[str] = None) -> str:
        """Run this task through its agent, threading in upstream outputs.

        Raises:
            Exception: if no agent is assigned.
        """
        if not self.agent:
            raise Exception("No agent assigned to the task.")

        # Only upstream tasks that have already produced output contribute.
        context_tasks = [task for task in self.context if task.output]
        if context_tasks:
            self.context_agent_role = context_tasks[0].agent.role
            original_context = "\n".join([f"{task.agent.role}: {task.output}" for task in context_tasks])

            if self.tool_name == 'semantic_search':
                # Semantic search gets the raw outputs as the query, without
                # the "role:" prefixes.
                query = "\n".join([task.output for task in context_tasks])
                context = query
            else:
                context = original_context

        # Prepare the prompt for logging before execution
        prompt_details = self.prepare_prompt(context)
        self.prompt_data.append(prompt_details)

        # Execute the task with the agent
        # NOTE(review): Agent in src/agents.py defines execute(), not
        # execute_task(self, context) — confirm which Agent class is resolved
        # here, as written this call looks like it raises AttributeError.
        result = self.agent.execute_task(self, context)
        self.output = result

        if self.output_file:
            with open(self.output_file, "w") as file:
                file.write(result)

        if self.callback:
            self.callback(self)

        return result

    def prepare_prompt(self, context):
        """ Prepare and return the prompt details for logging """
        prompt = {
            "timestamp": datetime.now().isoformat(),
            "task_id": self.id,
            "instructions": self.instructions,
            "context": context,
            "expected_output": self.expected_output
        }
        return prompt

    def load_agent(self, agent_role: str) -> Optional[Agent]:
        """ Load the agent based on the given role """
        # BUG FIX: next() without a default raised a bare StopIteration for an
        # unknown role, and the lookup also ran (and crashed) when agent_role
        # was None — the field's default.  Return None for "no agent" and
        # raise a descriptive error for an unknown role instead.
        if agent_role is None:
            return None
        match = next((agent for agent in agents if agent.role == agent_role), None)
        if match is None:
            raise ValueError(f"No agent found with role: {agent_role}")
        return match
@tool
def speak_to_the_user(message: str) -> str:
    """
    Prompts the user to provide more context or feedback through the
    Streamlit interface.

    Args:
        message (str): The prompt or question to ask the user.
            (The original docstring documented a ``prompt`` argument that
            does not exist.)

    Returns:
        str: The user's response to the prompt.
    """
    st.write(message)
    user_input = st.text_input("Please provide your response:")
    return user_input

@tool
def code_interpreter(code_markdown: str) -> dict | str:
    """
    Execute the provided Python code string on the terminal using exec.

    The string should contain valid, executable and pure Python code in
    markdown syntax. Code should also import any required Python packages.

    Args:
        code_markdown (str): The Python code with markdown syntax to be executed.
            For example: ```python\n\n```

    Returns:
        dict | str: A dictionary containing variables declared and values
        returned by function calls, or an error message if an exception occurred.

    Note:
        SECURITY: executing arbitrary (model-generated) code is arbitrary
        code execution — use with extreme caution.
    """
    try:
        # Strip the opening ```python fence and the closing ``` fence.
        code_lines = code_markdown.split('\n')[1:-1]
        code_without_markdown = '\n'.join(code_lines)

        # Fresh namespace so executed code cannot clobber this module.
        exec_namespace = {}
        exec(code_without_markdown, exec_namespace)

        # Collect declared variables and results of zero/derivable-arg calls.
        result_dict = {}
        for name, value in exec_namespace.items():
            if name.startswith('_'):
                # Skip private/dunder names uniformly. The original only
                # excluded them for non-callables, so private callables
                # (and exec-injected entries) were still invoked.
                continue
            if callable(value):
                try:
                    result_dict[name] = value()
                except TypeError:
                    # The function requires arguments; resolve them by name
                    # from the executed namespace.
                    arg_names = inspect.getfullargspec(value).args
                    args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names}
                    result_dict[name] = value(**args)
            else:
                result_dict[name] = value

        return result_dict

    except Exception as e:
        error_message = f"An error occurred: {e}"
        inference_logger.error(error_message)
        return error_message

@tool("Make a calculation")
def calculate(operation):
    """Useful to perform any mathematical calculations,
    like sum, minus, multiplication, division, etc.
    The input to this tool should be a mathematical
    expression, a couple examples are `200*7` or `5000/2*10`
    """
    # SECURITY: eval() on model/user-provided input can execute arbitrary
    # code, not just arithmetic — consider an AST-based expression evaluator.
    return eval(operation)

@tool
def google_search_and_scrape(query: str) -> list:
    """
    Performs a Google search for the given query, retrieves the top search
    result URLs, and scrapes the text content and table data from those
    pages in parallel.

    Args:
        query (str): The search query.
    Returns:
        list: A list of dictionaries containing the URL, text content, and
        table data for each scraped page. (Annotation corrected — the
        original claimed ``dict`` but has always returned a list.)
    """
    num_results = 2
    url = 'https://www.google.com/search'
    params = {'q': query, 'num': num_results}
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}

    inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
    # timeout keeps the tool from hanging forever on a stalled connection
    response = requests.get(url, params=params, headers=headers, timeout=15)
    soup = BeautifulSoup(response.text, 'html.parser')
    urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]

    inference_logger.info(f"Scraping text from urls, please wait...")
    [inference_logger.info(url) for url in urls]
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(
                lambda url: (url, requests.get(url, headers=headers, timeout=15).text if isinstance(url, str) else None),
                url,
            )
            for url in urls[:num_results] if isinstance(url, str)
        ]
        results = []
        for future in concurrent.futures.as_completed(futures):
            url, html = future.result()
            if html is None:
                # The worker returns None for non-string URLs; skip instead
                # of crashing BeautifulSoup on it.
                continue
            soup = BeautifulSoup(html, 'html.parser')
            paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()]
            text_content = ' '.join(paragraphs)
            text_content = re.sub(r'\s+', ' ', text_content)
            table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table') for row in table.find_all('tr')]
            if text_content or table_data:
                results.append({'url': url, 'content': text_content, 'tables': table_data})
    return results
143 | """ 144 | try: 145 | stock = yf.Ticker(symbol) 146 | # Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market 147 | current_price = stock.info.get("regularMarketPrice", stock.info.get("currentPrice")) 148 | return current_price if current_price else None 149 | except Exception as e: 150 | print(f"Error fetching current price for {symbol}: {e}") 151 | return None 152 | 153 | @tool 154 | def get_stock_fundamentals(symbol: str) -> dict: 155 | """ 156 | Get fundamental data for a given stock symbol using yfinance API. 157 | 158 | Args: 159 | symbol (str): The stock symbol. 160 | 161 | Returns: 162 | dict: A dictionary containing fundamental data. 163 | Keys: 164 | - 'symbol': The stock symbol. 165 | - 'company_name': The long name of the company. 166 | - 'sector': The sector to which the company belongs. 167 | - 'industry': The industry to which the company belongs. 168 | - 'market_cap': The market capitalization of the company. 169 | - 'pe_ratio': The forward price-to-earnings ratio. 170 | - 'pb_ratio': The price-to-book ratio. 171 | - 'dividend_yield': The dividend yield. 172 | - 'eps': The trailing earnings per share. 173 | - 'beta': The beta value of the stock. 174 | - '52_week_high': The 52-week high price of the stock. 175 | - '52_week_low': The 52-week low price of the stock. 
176 | """ 177 | try: 178 | stock = yf.Ticker(symbol) 179 | info = stock.info 180 | fundamentals = { 181 | 'symbol': symbol, 182 | 'company_name': info.get('longName', ''), 183 | 'sector': info.get('sector', ''), 184 | 'industry': info.get('industry', ''), 185 | 'market_cap': info.get('marketCap', None), 186 | 'pe_ratio': info.get('forwardPE', None), 187 | 'pb_ratio': info.get('priceToBook', None), 188 | 'dividend_yield': info.get('dividendYield', None), 189 | 'eps': info.get('trailingEps', None), 190 | 'beta': info.get('beta', None), 191 | '52_week_high': info.get('fiftyTwoWeekHigh', None), 192 | '52_week_low': info.get('fiftyTwoWeekLow', None) 193 | } 194 | return fundamentals 195 | except Exception as e: 196 | print(f"Error getting fundamentals for {symbol}: {e}") 197 | return {} 198 | 199 | @tool 200 | def get_financial_statements(symbol: str) -> dict: 201 | """ 202 | Get financial statements for a given stock symbol. 203 | 204 | Args: 205 | symbol (str): The stock symbol. 206 | 207 | Returns: 208 | dict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement). 209 | """ 210 | try: 211 | stock = yf.Ticker(symbol) 212 | financials = stock.financials 213 | return financials 214 | except Exception as e: 215 | print(f"Error fetching financial statements for {symbol}: {e}") 216 | return {} 217 | 218 | @tool 219 | def get_key_financial_ratios(symbol: str) -> dict: 220 | """ 221 | Get key financial ratios for a given stock symbol. 222 | 223 | Args: 224 | symbol (str): The stock symbol. 225 | 226 | Returns: 227 | dict: Dictionary containing key financial ratios. 
228 | """ 229 | try: 230 | stock = yf.Ticker(symbol) 231 | key_ratios = stock.info 232 | return key_ratios 233 | except Exception as e: 234 | print(f"Error fetching key financial ratios for {symbol}: {e}") 235 | return {} 236 | 237 | @tool 238 | def get_analyst_recommendations(symbol: str) -> pd.DataFrame: 239 | """ 240 | Get analyst recommendations for a given stock symbol. 241 | 242 | Args: 243 | symbol (str): The stock symbol. 244 | 245 | Returns: 246 | pd.DataFrame: DataFrame containing analyst recommendations. 247 | """ 248 | try: 249 | stock = yf.Ticker(symbol) 250 | recommendations = stock.recommendations 251 | return recommendations 252 | except Exception as e: 253 | print(f"Error fetching analyst recommendations for {symbol}: {e}") 254 | return pd.DataFrame() 255 | 256 | @tool 257 | def get_dividend_data(symbol: str) -> pd.DataFrame: 258 | """ 259 | Get dividend data for a given stock symbol. 260 | 261 | Args: 262 | symbol (str): The stock symbol. 263 | 264 | Returns: 265 | pd.DataFrame: DataFrame containing dividend data. 266 | """ 267 | try: 268 | stock = yf.Ticker(symbol) 269 | dividends = stock.dividends 270 | return dividends 271 | except Exception as e: 272 | print(f"Error fetching dividend data for {symbol}: {e}") 273 | return pd.DataFrame() 274 | 275 | @tool 276 | def get_company_news(symbol: str) -> pd.DataFrame: 277 | """ 278 | Get company news and press releases for a given stock symbol. 279 | This function returns titles and url which need further scraping using other tools. 280 | 281 | Args: 282 | symbol (str): The stock symbol. 283 | 284 | Returns: 285 | pd.DataFrame: DataFrame containing company news and press releases. 
286 | """ 287 | try: 288 | news = yf.Ticker(symbol).news 289 | return news 290 | except Exception as e: 291 | print(f"Error fetching company news for {symbol}: {e}") 292 | return pd.DataFrame() 293 | 294 | @tool 295 | def get_technical_indicators(symbol: str) -> pd.DataFrame: 296 | """ 297 | Get technical indicators for a given stock symbol. 298 | 299 | Args: 300 | symbol (str): The stock symbol. 301 | 302 | Returns: 303 | pd.DataFrame: DataFrame containing technical indicators. 304 | """ 305 | try: 306 | indicators = yf.Ticker(symbol).history(period="max") 307 | return indicators 308 | except Exception as e: 309 | print(f"Error fetching technical indicators for {symbol}: {e}") 310 | return pd.DataFrame() 311 | 312 | @tool 313 | def get_company_profile(symbol: str) -> dict: 314 | """ 315 | Get company profile and overview for a given stock symbol. 316 | 317 | Args: 318 | symbol (str): The stock symbol. 319 | 320 | Returns: 321 | dict: Dictionary containing company profile and overview. 322 | """ 323 | try: 324 | profile = yf.Ticker(symbol).info 325 | return profile 326 | except Exception as e: 327 | print(f"Error fetching company profile for {symbol}: {e}") 328 | return { 329 | 330 | } 331 | 332 | @tool 333 | def search_10q(data): 334 | """ 335 | Useful to search information from the latest 10-Q form for a 336 | given stock. 337 | The input to this tool should be a pipe (|) separated text of 338 | length two, representing the stock ticker you are interested and what 339 | question you have from it. 340 | For example, `AAPL|what was last quarter's revenue`. 
341 | """ 342 | stock, ask = data.split("|") 343 | queryApi = QueryApi(api_key=os.environ['SEC_API_API_KEY']) 344 | query = { 345 | "query": { 346 | "query_string": { 347 | "query": f"ticker:{stock} AND formType:\"10-Q\"" 348 | } 349 | }, 350 | "from": "0", 351 | "size": "1", 352 | "sort": [{ "filedAt": { "order": "desc" }}] 353 | } 354 | 355 | fillings = queryApi.get_filings(query)['filings'] 356 | if len(fillings) == 0: 357 | return "Sorry, I couldn't find any filling for this stock, check if the ticker is correct." 358 | link = fillings[0]['linkToFilingDetails'] 359 | inference_logger.info(f"Running embedding search on {link}") 360 | answer = embedding_search(link, ask) 361 | return answer 362 | 363 | @tool 364 | def search_10k(data): 365 | """ 366 | Useful to search information from the latest 10-K form for a 367 | given stock. 368 | The input to this tool should be a pipe (|) separated text of 369 | length two, representing the stock ticker you are interested, what 370 | question you have from it. 371 | For example, `AAPL|what was last year's revenue`. 372 | """ 373 | stock, ask = data.split("|") 374 | queryApi = QueryApi(api_key=os.environ['SEC_API_API_KEY']) 375 | query = { 376 | "query": { 377 | "query_string": { 378 | "query": f"ticker:{stock} AND formType:\"10-K\"" 379 | } 380 | }, 381 | "from": "0", 382 | "size": "1", 383 | "sort": [{ "filedAt": { "order": "desc" }}] 384 | } 385 | 386 | fillings = queryApi.get_filings(query)['filings'] 387 | if len(fillings) == 0: 388 | return "Sorry, I couldn't find any filling for this stock, check if the ticker is correct." 389 | link = fillings[0]['linkToFilingDetails'] 390 | inference_logger.info(f"Running embedding search on {link}") 391 | answer = embedding_search(link, ask) 392 | return answer 393 | 394 | @tool 395 | def get_historical_price(symbol, start_date, end_date): 396 | """ 397 | Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'. 398 | - symbol (str): Stock ticker symbol. 
399 | - end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date 400 | - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date 401 | """ 402 | 403 | data = yf.Ticker(symbol) 404 | hist = data.history(start=start_date, end=end_date) 405 | hist = hist.reset_index() 406 | hist[symbol] = hist['Close'] 407 | return hist[['Date', symbol]] 408 | 409 | @tool 410 | def plot_price_over_time(symbol, start_date, end_date): 411 | """ 412 | Plots the historical stock prices for a given symbol over a specified date range. 413 | - symbol (str): Stock ticker symbol. 414 | - start_date (str): Start date for the historical data in 'YYYY-MM-DD' format. 415 | - end_date (str): End date for the historical data in 'YYYY-MM-DD' format. 
416 | """ 417 | 418 | historical_price_df = get_historical_price(symbol, start_date, end_date) 419 | 420 | # Create a Plotly figure 421 | fig = go.Figure() 422 | 423 | # Add a trace for the stock symbol 424 | fig.add_trace(go.Scatter(x=historical_price_df['Date'], y=historical_price_df[symbol], mode='lines+markers', name=symbol)) 425 | 426 | # Update the layout to add titles and format axis labels 427 | fig.update_layout( 428 | title=f'Stock Price Over Time: {symbol}', 429 | xaxis_title='Date', 430 | yaxis_title='Stock Price (USD)', 431 | yaxis_tickprefix='$', 432 | yaxis_tickformat=',.2f', 433 | xaxis=dict( 434 | tickangle=-45, 435 | nticks=20, 436 | tickfont=dict(size=10), 437 | ), 438 | yaxis=dict( 439 | showgrid=True, # Enable y-axis grid lines 440 | gridcolor='lightgrey', # Set grid line color 441 | ), 442 | plot_bgcolor='gray', # Set plot background to white 443 | paper_bgcolor='gray', # Set overall figure background to white 444 | ) 445 | 446 | # Show the figure 447 | st.plotly_chart(fig, use_container_width=True) 448 | 449 | def get_function_names(): 450 | current_module = sys.modules[__name__] 451 | module_source = inspect.getsource(current_module) 452 | 453 | tree = ast.parse(module_source) 454 | tool_functions = [] 455 | 456 | for node in ast.walk(tree): 457 | if isinstance(node, ast.FunctionDef) and any(isinstance(d, ast.Name) and d.id == 'tool' for d in node.decorator_list): 458 | func_obj = getattr(current_module, node.name) 459 | tool_functions.append(func_obj) 460 | 461 | return tool_functions 462 | 463 | def get_openai_tools() -> List[dict]: 464 | functions = get_function_names() 465 | 466 | tools = [convert_to_openai_tool(f) for f in functions] 467 | return tools 468 | 469 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import datetime 3 | import importlib 4 | import inspect 5 | import json 6 | 
def get_tool_names():
    """Return the names of classes defined in the src.tools module."""
    # Get all the classes visible in this module's namespace.
    classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)

    # Keep only classes whose defining module is src.tools (drops imported ones).
    class_names = [cls[0] for cls in classes if cls[1].__module__ == 'src.tools']

    return class_names

def get_fewshot_examples(num_fewshot):
    """Return the first *num_fewshot* few-shot examples.

    Raises:
        ValueError: if fewer than *num_fewshot* examples are available.
    """
    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    with open(example_path, 'r') as file:
        examples = json.load(file)  # json.load takes the file object, not the path
    if num_fewshot > len(examples):
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
    return examples[:num_fewshot]

def validate_and_extract_tool_calls(assistant_content):
    """Extract the scratchpad text and tool-call JSON payloads from an
    assistant message.

    NOTE(review): the tag names inside the regexes below had been stripped to
    ``r'(.*?)'`` in the version under review (apparently by an HTML-eating
    export step); they are reconstructed here as <scratchpad>/<tool_call>
    from the surrounding comments — confirm against the model's actual
    output format.

    Returns:
        tuple: (validation_result, tool_calls, scratchpad_text, error_message)
            - validation_result (bool): True if at least one tool call parsed.
            - tool_calls (list): parsed JSON payloads of each tool call.
            - scratchpad_text (str | None): stripped scratchpad contents.
            - error_message (str | None): last parsing/extraction error.
    """
    validation_result = False
    tool_calls = []
    error_message = None
    scratchpad_text = None

    try:
        # Find the text within <scratchpad>...</scratchpad> tags.
        scratchpad_pattern = r'<scratchpad>(.*?)</scratchpad>'
        scratchpad_match = re.search(scratchpad_pattern, assistant_content, re.DOTALL)
        if scratchpad_match:
            scratchpad_text = scratchpad_match.group(1).strip()

        # Find all <tool_call>...</tool_call> tags and their contents.
        tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
        tool_call_matches = re.findall(tool_call_pattern, assistant_content, re.DOTALL)

        if not tool_call_matches:
            error_message = None
        else:
            for match in tool_call_matches:
                json_text = match.strip()

                try:
                    json_data = json.loads(json_text)
                    tool_calls.append(json_data)
                    validation_result = True
                except json.JSONDecodeError as json_err:
                    error_message = f"JSON parsing failed:\n"\
                                    f"- JSON Decode Error: {json_err}\n"\
                                    f"- Problematic JSON text: {json_text}"
                    inference_logger.error(error_message)
                    continue

    except Exception as err:
        error_message = f"Error during tool call extraction: {err}"
        inference_logger.error(error_message)

    return validation_result, tool_calls, scratchpad_text, error_message
def extract_json_from_markdown(text):
    """
    Extracts and parses the JSON string inside a ```json fenced block.

    Args:
        text (str): The input text containing the JSON string.

    Returns:
        dict | list | None: The JSON data loaded from the extracted string,
        or None if the fenced block is missing or fails to parse.
    """
    json_pattern = r'```json\r?\n(.*?)\r?\n```'
    match = re.search(json_pattern, text, re.DOTALL)
    if not match:
        print("JSON string not found in the text.")
        return None
    json_string = match.group(1)
    try:
        return json.loads(json_string)
    except json.JSONDecodeError as e:
        # Same behavior as before: report and return None (the original fell
        # off the end of the function here, returning None implicitly).
        print(f"Error decoding JSON string: {e}")
        return None

def embedding_search(url, query):
    """Download an HTML filing, chunk it, embed the chunks into a FAISS
    index, and return the chunks most relevant to *query*.
    """
    text = download_form_html(url)
    elements = partition_html(text=text)
    content = "\n".join([str(el) for el in elements])
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len,
        is_separator_regex=False,
    )
    docs = text_splitter.create_documents([content])

    # Load a pre-trained sentence transformer model
    embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    # Create FAISS index and retriever
    index = FAISS.from_documents(docs, embedding_model)
    retriever = index.as_retriever()

    answers = retriever.invoke(query, top_k=4)
    chunks = []
    for i, doc in enumerate(answers):
        chunk = f"\n\n{doc.page_content}\n\n"
        chunks.append(chunk)

    result = "".join(chunks)
    return f"\n{result}"

def download_form_html(url):
    """Fetch a filing page with browser-like headers and return the HTML text."""
    headers = {
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
        'Accept-Encoding': 'gzip, deflate, br',
        'Accept-Language': 'en-US,en;q=0.9,pt-BR;q=0.8,pt;q=0.7',
        'Cache-Control': 'max-age=0',
        'Dnt': '1',
        'Sec-Ch-Ua': '"Not_A Brand";v="8", "Chromium";v="120"',
        'Sec-Ch-Ua-Mobile': '?0',
        'Sec-Ch-Ua-Platform': '"macOS"',
        'Sec-Fetch-Dest': 'document',
        'Sec-Fetch-Mode': 'navigate',
        'Sec-Fetch-Site': 'none',
        'Sec-Fetch-User': '?1',
        'Upgrade-Insecure-Requests': '1',
        'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
    }

    # timeout keeps the request from hanging indefinitely on a stalled server
    response = requests.get(url, headers=headers, timeout=30)
    return response.text
def validate_function_call_schema(call, signatures):
    """Validate a model-produced function call against a list of function
    signatures.

    Returns:
        tuple: (is_valid, error_message) — error_message is None on success.
    """
    try:
        call_data = FunctionCall(**call)
    except ValidationError as e:
        return False, str(e)

    for signature in signatures:
        try:
            signature_data = FunctionSignature(**signature)
            if signature_data.function.name == call_data.name:
                # Validate types in function arguments
                for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
                    if arg_name in call_data.arguments:
                        call_arg_value = call_data.arguments[arg_name]
                        if call_arg_value:
                            try:
                                validate_argument_type(arg_name, call_arg_value, arg_schema)
                            except Exception as arg_validation_error:
                                return False, str(arg_validation_error)

                # Check if all required arguments are present
                required_arguments = signature_data.function.parameters.get('required', [])
                result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
                if not result:
                    return False, f"Missing required arguments: {missing_arguments}"

                return True, None
        except Exception as e:
            # Handle validation errors for the function signature
            return False, str(e)

    # No matching function signature found
    return False, f"No matching function signature found for function: {call_data.name}"

def check_required_arguments(call_arguments, required_arguments):
    """Return (all_present, missing) for the required argument names."""
    missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
    return not bool(missing_arguments), missing_arguments

def validate_enum_value(arg_name, arg_value, enum_values):
    """Raise if *arg_value* is not one of the schema's enum values."""
    if arg_value not in enum_values:
        raise Exception(
            f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
        )

def validate_argument_type(arg_name, arg_value, arg_schema):
    """Validate one argument value against its JSON-schema fragment.

    Raises:
        Exception: on enum or type mismatch.
    """
    arg_type = arg_schema.get('type', None)
    if arg_type:
        if arg_type == 'string' and 'enum' in arg_schema:
            enum_values = arg_schema['enum']
            if None not in enum_values and enum_values != []:
                try:
                    validate_enum_value(arg_name, arg_value, enum_values)
                except Exception as e:
                    # Propagate the validation error message
                    raise Exception(f"Error validating function call: {e}")

        python_type = get_python_type(arg_type)
        # bool subclasses int in Python, so isinstance(True, int) is True;
        # reject bools for numeric JSON types explicitly (the original let
        # True/False validate as integers and numbers).
        if arg_type in ('integer', 'number') and isinstance(arg_value, bool):
            raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
        if not isinstance(arg_value, python_type):
            raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")

def get_python_type(json_type):
    """Map a JSON-schema type name to the Python type(s) for isinstance checks.

    Raises:
        ValueError: for an unknown JSON type name (the original raised a
        bare, uninformative KeyError here).
    """
    type_mapping = {
        'string': str,
        'number': (int, float),
        'integer': int,
        'boolean': bool,
        'array': list,
        'object': dict,
        'null': type(None),
    }
    try:
        return type_mapping[json_type]
    except KeyError:
        raise ValueError(f"Unknown JSON type: {json_type!r}") from None

def validate_json_data(json_object, json_schema):
    """Parse *json_object* (JSON text, Python-literal text, or fenced
    markdown) and validate it against *json_schema*.

    Returns:
        tuple: (valid, result_json, error_message)
    """
    valid = False
    error_message = None
    result_json = None

    try:
        # Attempt to load JSON using json.loads
        try:
            result_json = json.loads(json_object)
        except json.decoder.JSONDecodeError:
            # If json.loads fails, try ast.literal_eval
            try:
                result_json = ast.literal_eval(json_object)
            except (SyntaxError, ValueError) as e:
                try:
                    result_json = extract_json_from_markdown(json_object)
                except Exception as e:
                    error_message = f"JSON decoding error: {e}"
                    inference_logger.info(f"Validation failed for JSON data: {error_message}")
                    return valid, result_json, error_message

        # Return early if all decoding strategies failed
        if result_json is None:
            error_message = "Failed to decode JSON data"
            inference_logger.info(f"Validation failed for JSON data: {error_message}")
            return valid, result_json, error_message

        # Validate each item in the list against schema if it's a list
        if isinstance(result_json, list):
            for index, item in enumerate(result_json):
                try:
                    validate(instance=item, schema=json_schema)
                    inference_logger.info(f"Item {index+1} is valid against the schema.")
                except ValidationError as e:
                    error_message = f"Validation failed for item {index+1}: {e}"
                    break
        else:
            # Default to validation without list
            try:
                validate(instance=result_json, schema=json_schema)
            except ValidationError as e:
                error_message = f"Validation failed: {e}"

    except Exception as e:
        error_message = f"Error occurred: {e}"

    if error_message is None:
        valid = True
        inference_logger.info("JSON data is valid against the schema.")
    else:
        inference_logger.info(f"Validation failed for JSON data: {error_message}")

    return valid, result_json, error_message
class Squad:
    """Runs a sequence of tasks with their agents, accumulating each task's
    output into the context of the next, and logs inputs/outputs to a JSON
    file.
    """

    def __init__(self, agents: List[Agent], tasks: List[Task], resources: List[Resources], verbose: bool = False, log_file: str = "squad_log.json"):
        self.id = str(uuid.uuid4())
        self.agents = agents
        self.tasks = tasks
        self.resources = resources
        self.verbose = verbose
        self.log_file = log_file
        self.log_data = []
        # NOTE(review): llama_logs is initialized but never written or read
        # in this file — confirm whether it is still needed.
        self.llama_logs = []

    def run(self, inputs: Optional[Dict[str, Any]] = None) -> str:
        """Execute all tasks in order; return the accumulated context string."""
        context = ""
        for task in self.tasks:
            if self.verbose:
                print(f"Starting Task:\n{task.instructions}")

            self.log_data.append({
                "timestamp": datetime.now().isoformat(),
                "type": "input",
                "agent_role": task.agent.role,
                "task_name": task.instructions,
                "task_id": task.id,
                "content": task.instructions
            })

            output = task.execute(context=context)
            task.output = output

            if self.verbose:
                print(f"Task output:\n{output}\n")

            self.log_data.append({
                "timestamp": datetime.now().isoformat(),
                "type": "output",
                "agent_role": task.agent.role,
                "task_name": task.instructions,
                "task_id": task.id,
                "content": output
            })

            # Thread each task's result into the next task's context.
            context += f"Task:\n{task.instructions}\nOutput:\n{output}\n\n"

        self.save_logs()
        return context

    def save_logs(self):
        """Write the accumulated log entries to the configured log file."""
        with open(self.log_file, "w", encoding="utf-8") as file:
            json.dump(self.log_data, file, indent=2)

def load_configuration(file_path):
    """Load and return the JSON configuration at *file_path*."""
    with open(file_path, 'r', encoding="utf-8") as file:
        return json.load(file)

def initialize_resources(config):
    """Build Resources objects from the 'resources' section of the config."""
    resources = []
    for res in config["resources"]:
        resources.append(Resources(res['type'], res['path'], res['template']))
    return resources

def initialize_agents_and_tasks(config):
    """Build Agent and Task objects from the config.

    NOTE(review): the repo's config.json describes agents with keys
    'name'/'role' and tasks with 'name'/'agent'/'resource', which do not
    match the stub constructors Agent(role, tools) and Task(instructions,
    agent, tool_name) — ``**`` expansion will raise TypeError with that
    config. Confirm the intended schema.
    """
    agents = [Agent(**ag) for ag in config['agents']]
    tasks = [Task(**tk) for tk in config['tasks']]
    return agents, tasks

def parse_args():
    """Parse command-line arguments for the squad runner."""
    parser = argparse.ArgumentParser(description="Run the squad with dynamic configurations.")
    parser.add_argument('-c', '--config', type=str, help="Path to configuration JSON file", required=True)
    parser.add_argument('-v', '--verbose', action='store_true', help="Enable verbose output")
    return parser.parse_args()

def mainflow():
    """Entry point: load config, build the squad, run it, print the result."""
    args = parse_args()
    config = load_configuration(args.config)

    resources = initialize_resources(config)
    agents, tasks = initialize_agents_and_tasks(config)

    squad = Squad(
        agents=agents,
        tasks=tasks,
        resources=resources,
        verbose=args.verbose,
        log_file="squad_goals_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".json"
    )

    result = squad.run()
    print(f"Final output:\n{result}")

if __name__ == "__main__":
    mainflow()