├── .gitignore ├── .gitmodules ├── README.md ├── data_crew.py ├── nyc_salaries_sampled.db ├── prompts.py ├── spaceship_titanic.db └── tools └── sql_tool.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Bronco"] 2 | path = Bronco 3 | url = git@github.com:GKjohns/Bronco.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DataAnalysisCrew 2 | A data analysis AI built with crewAI 3 | -------------------------------------------------------------------------------- /data_crew.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | sys.path.append('..') 5 | from crewai import Agent, Task, Crew, Process 6 | from langchain.tools import StructuredTool 7 | from langchain_community.tools import DuckDuckGoSearchRun 8 | from langchain_community.retrievers import ArxivRetriever 9 | from langchain_experimental.tools import PythonREPLTool 10 | from langchain_community.agent_toolkits import SQLDatabaseToolkit 11 | from langchain.sql_database import SQLDatabase 12 | from langchain_community.agent_toolkits import create_sql_agent 13 | from langchain_openai import ChatOpenAI 14 | from pydantic.v1 import BaseModel, Field 15 | 16 | from Bronco import bronco 17 | from prompts import CrewGenPrompts 18 | 19 | from tools.sql_tool import build_sql_tool 20 | 21 | def review_config(config, keep_file=False) -> str: 22 | 23 | # write the config to a file 24 | with open('crew_config.py', 'w') as f: 25 | f.write(str(config)) 26 | 27 | # open the file in vim 28 | os.system('vim crew_config.py') 29 | 30 | # read the reviewed file back into memory 31 | with open('crew_config.py', 'r') as f: 32 | reviewed_config = f.read() 33 | 34 | # delete the file 35 | if not keep_file: 36 | os.system('rm crew_config.py') 37 | 38 | return eval(reviewed_config) 39 | 40 | def extract_python_code(text): 41 | # Regular expression pattern to find the first code block marked as Python code 42 | pattern = r'```python(.*?)```' 43 | match = re.search(pattern, text, re.DOTALL) 44 | 45 | # If a match is not found, return the whole text string 46 | if not match: 47 | return text 48 | 49 | code = match.group(1).strip() # Use .strip() to remove leading/trailing whitespace 50 | try: 51 | # Here, eval is not safe to use directly without proper context and validation 52 | return eval(code) 53 | except Exception as e: 54 | return f"Error evaluating code: {e}" 55 | 56 | 57 | 58 | def generate_crew_config(objective, tool_names): 59 | crew_generator = bronco.LLMFunction( 60 | prompt_template=CrewGenPrompts.generate_crew_config_prompt, 61 | model_name=bronco.GPT_4, 62 | parser=extract_python_code 63 | ) 64 | 65 | return crew_generator.generate({ 66 | 'objective': objective, 67 | 'tool_names': tool_names 68 | }) 69 | 70 | def generate_agent_config(name, objective, agent_tasks, tool_names): 71 | """ 72 | Generates an agent configuration for a Senior Research Analyst role with specific tasks and tools. 73 | 74 | Parameters: 75 | - name (str): The name of the agent. 76 | - agent_tasks (list): A list of tasks that the agent is responsible for. 77 | - tool_names (list): A list of tools that the agent has access to. 78 | 79 | Returns: 80 | - dict: The generated agent configuration. 81 | """ 82 | 83 | 84 | agent_config_generator = bronco.LLMFunction( 85 | prompt_template=CrewGenPrompts.gen_agent_config_prompt, 86 | model_name=bronco.GPT_4, 87 | parser=extract_python_code, 88 | ) 89 | 90 | # Generating agent configuration using the extracted values 91 | agent_config = agent_config_generator.generate({ 92 | 'name': name, 93 | 'objective': objective, 94 | 'agent_tasks': agent_tasks, 95 | 'tool_names': tool_names 96 | }) 97 | 98 | agent_config.update({'name': name}) 99 | 100 | print(agent_config['name']) 101 | 102 | return agent_config 103 | 104 | def generate_task_config(task_description, objective, agent_dict): 105 | """ 106 | Generates a task configuration for creating a report on houseplant trends in the US in 2023. 107 | 108 | Parameters: 109 | - task_description (str): A brief description of the task. 110 | - objective (str): A detailed objective of what the report should cover. 111 | - agent_dict (dict): A dictionary containing the agent data. 112 | 113 | Returns: 114 | - dict: The generated task configuration. 115 | """ 116 | 117 | # Extracting agent and tool_names from the agent_dict parameter 118 | agent_name = agent_dict['name'] 119 | agent_role = agent_dict['role'] 120 | tool_names = agent_dict.get('tool_names', []) 121 | 122 | task_config_generator = bronco.LLMFunction( 123 | prompt_template=CrewGenPrompts.gen_task_config_prompt, 124 | model_name=bronco.GPT_4, 125 | parser=extract_python_code, 126 | success_func=lambda x: 'description' in x and 'agent' in x 127 | ) 128 | 129 | # Generating task configuration using the extracted values 130 | task_config = task_config_generator.generate({ 131 | 'task_description': task_description, 132 | 'agent_role': agent_role, 133 | 'objective': objective, 134 | 'tool_names': tool_names 135 | }) 136 | 137 | task_config.update({'name': agent_name}) 138 | 139 | return task_config 140 | 141 | def create_full_config(objective, tools, review_intermediate=True, keep_final_config=False): 142 | ''' 143 | Create a full config for a crew based on an objective and a list of tools. 144 | ''' 145 | tool_names = [tool.name for tool in tools] 146 | 147 | print('Generating crew config...') 148 | crew_config = generate_crew_config(objective, tool_names) 149 | 150 | if review_intermediate: 151 | crew_config = review_config(crew_config) 152 | 153 | # Create a config for each agent in the config 154 | agents = [] 155 | for agent_name in crew_config['agents']: 156 | print(f'Generating agent config for {agent_name}...') 157 | agent_tasks = [task['task'] for task in crew_config['tasks'] if task['agent'] == agent_name] 158 | agent_config = generate_agent_config( 159 | name=agent_name, 160 | objective=objective, 161 | agent_tasks=agent_tasks, 162 | tool_names=tool_names 163 | ) 164 | agents.append(agent_config) 165 | 166 | if review_intermediate: 167 | agents = review_config(agents) 168 | 169 | # Create a config for each task in the config 170 | tasks = [] 171 | for task in crew_config['tasks']: 172 | print(f'Generating task config for {task["task"]}...') 173 | task_agent = [agent for agent in agents if agent['name'] == task['agent']][0] 174 | task_description = task['task'] 175 | 176 | task_config = generate_task_config( 177 | task_description=task_description, 178 | objective=objective, 179 | agent_dict=task_agent 180 | ) 181 | 182 | # string agent names need to be replaced with pointers to the agent objects 183 | # occurs during crew initialization, to ensure that we have a serializable config 184 | task_config.update({'agent': task_agent['name']}) 185 | 186 | tasks.append(task_config) 187 | 188 | if review_intermediate: 189 | tasks = review_config(tasks) 190 | # create the full config 191 | crew_config = { 192 | 'agents': agents, 193 | 'tasks': tasks 194 | } 195 | 196 | # Allow the user to review the fully formed config 197 | review_config(crew_config, keep_file=keep_final_config) 198 | 199 | return crew_config 200 | 201 | 202 | def initialize_from_config(config, verbose=2): 203 | ''' 204 | Initialize a Crew object from a configuration dictionary. 205 | ''' 206 | 207 | # agent tools need to be a pointer to the object, not a string 208 | for agent in config['agents']: 209 | agent['tools'] = [tool for tool in tools if tool.name in agent['tools']] 210 | agent_objects = [Agent(**agent) for agent in config['agents']] 211 | 212 | # the task agent needs to pe a pointer to the object, not a string 213 | agent_string_to_object = {} 214 | for agent_str, agent_obj in zip(config['agents'], agent_objects): 215 | agent_string_to_object[agent_str['name']] = agent_obj 216 | for task in config['tasks']: 217 | task['agent'] = agent_string_to_object[task['agent']] 218 | task_objects = [Task(**task) for task in config['tasks']] 219 | 220 | crew = Crew( 221 | agents=agent_objects, 222 | tasks=task_objects, 223 | verbose=verbose 224 | ) 225 | 226 | return crew 227 | 228 | def initialize_crew_from_saved_config(config_file, verbose=2): 229 | with open(config_file, 'r') as f: 230 | config = f.read() 231 | 232 | return initialize_from_config(eval(config), verbose=verbose) 233 | 234 | 235 | 236 | if __name__ == '__main__': 237 | objective = ( 238 | 'Create a report on the impact of the following factors on survival rates' 239 | ' (probability of being successfully transported) on the spaceship titanic.' 240 | '\n- vip status' 241 | '\n- shopping and dining spending' 242 | '\n- cabin class' 243 | ) 244 | 245 | sql_agent_tool = build_sql_tool( 246 | db_uri='sqlite:///./spaceship_titanic.db', 247 | name='query_sql_db_tool', 248 | description='Runs a sql query against the spaceship titanic database and returns the results.' 249 | ) 250 | tools = [sql_agent_tool, PythonREPLTool()] 251 | 252 | print('Objective: ', objective) 253 | print('Tools: ', [tool.name for tool in tools]) 254 | 255 | crew_config = create_full_config( 256 | objective=objective, 257 | tools=tools, 258 | review_intermediate=False, 259 | keep_final_config=True 260 | ) 261 | 262 | crew = initialize_from_config(crew_config) 263 | 264 | crew.kickoff() -------------------------------------------------------------------------------- /nyc_salaries_sampled.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GKjohns/DataAnalysisCrew/8503f15522b327dcf78ea05de2c7d81c30b0b24c/nyc_salaries_sampled.db -------------------------------------------------------------------------------- /prompts.py: -------------------------------------------------------------------------------- 1 | class DataQaPrompts: 2 | 3 | question_to_query_prompt = ''' 4 | # Context 5 | Your job is to write a sql query that answers the following question: 6 | {question} 7 | 8 | Below is a list of columns and sample values. Your query should only use the data contained in the table. The table name is `{table_name}`. 9 | 10 | # Columns and sample values 11 | {table_sample} 12 | 13 | If the question is not a question or is answerable with the given columns, respond to the best of your ability. 14 | Do not use columns that aren't in the table. 15 | Ensure that the query runs and returns the correct output. 16 | 17 | # Your query: 18 | ''' 19 | 20 | results_to_answer_prompt = ''' 21 | # Task 22 | Based on the results of a SQL query, provide a brief summary of the key findings and explicitly answer the following question: 23 | {question} 24 | 25 | The query results from the table `{table_name}` are as follows: 26 | 27 | # Query Results Table 28 | {query_results_table} 29 | 30 | In 2-5 sentences, summarize the main insights from the query results and give a clear and direct answer to the original question. 31 | 32 | # Summary and Answer: 33 | ''' 34 | 35 | 36 | 37 | class CrewGenPrompts: 38 | generate_crew_config_prompt = ''' 39 | Here's an example of configs for a simple agent and a simple task: 40 | ```python 41 | i{{ 42 | 'agents': [ 43 | 'content_writer', # Writes engaging descriptions for plants and company info. 44 | 'web_designer', # Creates the layout and design for the landing page. 45 | 'web_developer', # Builds the website using HTML, CSS, and potentially JavaScript. 46 | 'seo_specialist' # Optimizes content for search engines to increase visibility. 47 | ], 48 | 'tasks': [ 49 | 'tasks': [ 50 | {{'task': 'write_plant_descriptions', 'agent': 'content_writer'}}, 51 | {{'task': 'design_page_layout', 'agent': 'web_designer'}}, 52 | {{'task': 'select_images', 'agent': 'web_designer'}}, 53 | {{'task': 'write_about_us', 'agent': 'content_writer'}}, 54 | {{'task': 'build_webpage', 'agent': 'web_developer'}}, 55 | {{'task': 'implement_seo_practices', 'agent': 'seo_specialist'}}, 56 | {{'task': 'setup_contact_form', 'agent': 'web_developer'}}, 57 | {{'task': 'launch_page_review', 'agent': 'web_designer'}} 58 | ] 59 | ] 60 | }} 61 | ``` 62 | 63 | # Task 64 | Create a list of agents and tasks that would complete the following obective: {objective} 65 | The output should be a config in the form of a dictionary: 66 | ``` 67 | {{'agents': ['agent1', 'agent1', ...], 'tasks': [...]}} 68 | ``` 69 | 70 | Ensure that the agents and tasks are relevant to the objective and that the agents have the necessary skills to complete the tasks. 71 | Remember that each task must be delegated to an agent. Do not create a task that cannot be completed by any of the agents. 72 | Tasks and agent capabilities should be within the abilities that can be completed by a python coder with access to the internet and a powerful LLM AI. 73 | 74 | # Available tools 75 | Here's a list of tools the agents can use to complete their tasks 76 | Note that if a step involves analyzing data, there needs to be a task that involves acquiring the data: 77 | {tool_names} 78 | 79 | # Your crew config 80 | ''' 81 | 82 | gen_agent_config_prompt = ''' 83 | 84 | # Instructions 85 | The overall objective of the larger program is: {objective} 86 | Create a config for an agent with the name {name}, in the format below. 87 | The agent will have to complete the following tasks: {agent_tasks} 88 | The agent may make use of any of the following tools: {tool_names} 89 | 90 | ENSURE THAT YOUR CONFIG IS IN THE FORM OF A DICTIONARY WITH THE KEYS BELOW. 91 | Here's an example of a config for a simple agent: 92 | ```python 93 | {{ 94 | 'role': 'Unicorn Hunter', 95 | 'goal': 'Discover and capture mythical unicorns for study and conservation', 96 | 'backstory': \'''You are part of an ancient society dedicated to the preservation and study of unicorns. With a deep understanding of mythical creatures and their habitats, you embark on expeditions into enchanted forests. Your skills in tracking, magical lore, and non-lethal capture techniques are unparalleled. You work to ensure the survival of unicorns and the balance of their ecosystems, often collaborating with wizards and other mythical beings.\''', 97 | 'verbose': True, 98 | 'allow_delegation': False, 99 | 'tools': ['enchanted_net', 'potion_brewing_kit', 'ancient_tome_of_lore'] 100 | }} 101 | ``` 102 | 103 | Do not duplicate tasks. If a task is already assigned to another agent, do not assign it to this agent. 104 | 105 | # Your agent config for {name} 106 | ''' 107 | 108 | gen_task_config_prompt = ''' 109 | # Instructions 110 | The overall objective of the larger program is: {objective} 111 | Create a config for a task with the following description: {task_description} 112 | The task should be delegated to the following agent: {agent_role} 113 | The agent will have access to the following tools: {tool_names} 114 | 115 | ENSURE THAT YOUR CONFIG IS IN THE FORM OF A PYTHON DICTIONARY with the keys 'description' and 'agent' only. 116 | Here's an example of a config for a simple task: 117 | ```python 118 | {{ 119 | 'description': \'''Conduct a comprehensive analysis of the latest advancements in AI in 2024. 120 | Identify key trends, breakthrough technologies, and potential industry impacts. 121 | Your final answer MUST be a full analysis report\''', 122 | 'agent': 'researcher' 123 | }} 124 | ``` 125 | 126 | # Your task config for {task_description} 127 | ''' 128 | 129 | code_fixer_prompt = ''' 130 | # Task 131 | Refine the formatting and fix any error in the given code snippet. 132 | Your primary goal is to ensure that the code will run successfully without any errors. 133 | Ensure brackets, quotes, and parentheses are balanced and properly nested. 134 | If the code is already well-formatted, return it unchanged. 135 | ONLY OUTPUT THE FIXED OR UNCHANGED CODE. 136 | 137 | # Language/File Type 138 | {language_or_file_type} 139 | 140 | # Input code snippet 141 | ``` 142 | {code_snippet} 143 | ``` 144 | 145 | # Output code snippet 146 | ''' 147 | 148 | 149 | if __name__ == '__main__': 150 | 151 | task_description = 'build_webpage' 152 | objective = ( 153 | 'Create a report on trends in houseplant in the US in 2023. ' 154 | 'The report should include a comprehensive analysis of the market, ' 155 | 'including the most popular plants, the most popular planters, ' 156 | 'and the most popular plant care products. ' 157 | 'The report should also include a section on the most popular houseplant influencers ' 158 | 'and a section on the most popular houseplant hashtags.' 159 | ) 160 | tool_names = ['bing_search', 'code_pad'] 161 | 162 | 163 | task_config_generator = bronco.LLMFunction( 164 | prompt_template=DataQaPrompts.gen_task_config_prompt 165 | ) 166 | 167 | task_config = task_config_generator.generate({ 168 | 'task_description': task_description, 169 | 'agent': 'web_developer', 170 | 'objective': objective, 171 | 'tool_names': tool_names 172 | }) 173 | 174 | print(task_config) -------------------------------------------------------------------------------- /spaceship_titanic.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GKjohns/DataAnalysisCrew/8503f15522b327dcf78ea05de2c7d81c30b0b24c/spaceship_titanic.db -------------------------------------------------------------------------------- /tools/sql_tool.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from uuid import UUID 3 | from pydantic.v1 import BaseModel, Field 4 | from langchain_core.pydantic_v1 import Field 5 | 6 | # from langchain_community.callbacks import Callback 7 | from langchain.sql_database import SQLDatabase 8 | from langchain.tools import StructuredTool 9 | from langchain_openai import ChatOpenAI 10 | from langchain_core.callbacks import CallbackManagerForToolRun 11 | from langchain_community.agent_toolkits import create_sql_agent 12 | from langchain_community.tools import BaseTool 13 | from langchain.sql_database import SQLDatabase 14 | from langchain.tools import StructuredTool 15 | from langchain_openai import ChatOpenAI 16 | from langchain_core.callbacks import CallbackManagerForToolRun 17 | from langchain_community.agent_toolkits import create_sql_agent 18 | from langchain_community.tools import BaseTool 19 | from langchain_core.tools import BaseTool 20 | from langchain_community.tools.sql_database.tool import ( 21 | InfoSQLDatabaseTool, 22 | ListSQLDatabaseTool, 23 | QuerySQLCheckerTool, 24 | BaseSQLDatabaseTool 25 | ) 26 | from langchain_community.utilities.sql_database import SQLDatabase 27 | from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit 28 | 29 | 30 | 31 | class QuerySQLLimitedDataBaseTool(BaseSQLDatabaseTool, BaseTool): 32 | """Tool for querying a SQL database with a limit on the output.""" 33 | 34 | name: str = "sql_db_query" 35 | description: str = """ 36 | Input to this tool is a detailed and correct SQL query, output is a result from the database. 37 | If the query is not correct, an error message will be returned. 38 | If an error is returned, rewrite the query, check the query, and try again. 39 | """ 40 | 41 | def _run( 42 | self, 43 | query: str, 44 | run_manager: Optional[CallbackManagerForToolRun] = None, 45 | ) -> str: 46 | """Execute the query, return the results or an error message.""" 47 | 48 | result = self.db.run_no_throw(query) 49 | 50 | results_list = eval(result) 51 | if len(results_list) > 100: 52 | results_list = results_list[:100] 53 | 54 | return str(results_list) 55 | 56 | class SQLDatabaseToolkitLimited(SQLDatabaseToolkit): 57 | 58 | def get_tools(self) -> List[BaseTool]: 59 | """Get the tools in the toolkit.""" 60 | list_sql_database_tool = ListSQLDatabaseTool(db=self.db) 61 | info_sql_database_tool_description = ( 62 | "Input to this tool is a comma-separated list of tables, output is the " 63 | "schema and sample rows for those tables. " 64 | "Be sure that the tables actually exist by calling " 65 | f"{list_sql_database_tool.name} first! " 66 | "Example Input: table1, table2, table3" 67 | ) 68 | info_sql_database_tool = InfoSQLDatabaseTool( 69 | db=self.db, description=info_sql_database_tool_description 70 | ) 71 | query_sql_database_tool_description = ( 72 | "Input to this tool is a detailed and correct SQL query, output is a " 73 | "result from the database. If the query is not correct, an error message " 74 | "will be returned. If an error is returned, rewrite the query, check the " 75 | "query, and try again. If you encounter an issue with Unknown column " 76 | f"'xxxx' in 'field list', use {info_sql_database_tool.name} " 77 | "to query the correct table fields." 78 | ) 79 | 80 | # Note that we're using the limited version of the query tool 81 | query_sql_database_tool = QuerySQLLimitedDataBaseTool( 82 | db=self.db, description=query_sql_database_tool_description 83 | ) 84 | query_sql_checker_tool_description = ( 85 | "Use this tool to double check if your query is correct before executing " 86 | "it. Always use this tool before executing a query with " 87 | f"{query_sql_database_tool.name}!" 88 | ) 89 | query_sql_checker_tool = QuerySQLCheckerTool( 90 | db=self.db, llm=self.llm, description=query_sql_checker_tool_description 91 | ) 92 | return [ 93 | query_sql_database_tool, 94 | info_sql_database_tool, 95 | list_sql_database_tool, 96 | query_sql_checker_tool, 97 | ] 98 | 99 | 100 | 101 | def build_sql_tool(db_uri, description, name='query_sql_db_tool', llm=None) -> StructuredTool: 102 | '''Builds a tool that can run sql queries against a database.''' 103 | 104 | 105 | toolkit = SQLDatabaseToolkitLimited( 106 | llm=ChatOpenAI(model='gpt-4', temperature=0), 107 | db=SQLDatabase.from_uri(db_uri) 108 | ) 109 | 110 | sql_agent = create_sql_agent( 111 | llm=llm or ChatOpenAI(model='gpt-4', temperature=0), 112 | toolkit=toolkit, 113 | agent_type="openai-tools", 114 | verbose=True 115 | ) 116 | 117 | class SqlAgentInput(BaseModel): 118 | sql_query: str = Field() 119 | 120 | def sql_agent_run_wrapper(*args, **kwargs): 121 | '''Runs a sql query against the spaceship titanic database and returns the results.''' 122 | result = sql_agent.invoke(*args, **kwargs) 123 | 124 | if isinstance(result, dict): 125 | return result['output'] 126 | return str(result) 127 | 128 | sql_agent_tool = StructuredTool.from_function( 129 | func=sql_agent_run_wrapper, 130 | name=name, 131 | description=description, 132 | verbose=True, 133 | return_direct=True, 134 | args_schema=SqlAgentInput 135 | ) 136 | 137 | return sql_agent_tool 138 | 139 | 140 | 141 | 142 | if __name__ == '__main__': 143 | 144 | db_uri = 'sqlite:///./spaceship_titanic.db' 145 | # Example usage 146 | tool = build_sql_tool( 147 | db_uri=db_uri, 148 | name='query_sql_db_tool', 149 | description='Runs a sql query against the spaceship titanic database and returns the results.' 150 | ) 151 | 152 | result = tool.invoke('Is there a difference in the percentage of passengers who spent more than 1000 at the food court between survivors and non-survivors?') 153 | print(result) --------------------------------------------------------------------------------