├── .DS_Store ├── .gitignore ├── README.md ├── __pycache__ └── helpers.cpython-312.pyc ├── assis_api_sql_db.py ├── csv_agent.py ├── data └── salaries_2023.csv ├── db └── salary.db ├── first_agent.py ├── fun_call_db_agent.py ├── fun_calling.py ├── helpers.py ├── questions_sql_agent.md ├── requirements.txt └── sql_db_agent.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pdichone/database-ai-agents/605ffdc68db5f79eed956c3812e1cb318ca25203/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # ignore the venv directory 3 | venv/ 4 | # ignore the .vscode directory 5 | .vscode/ 6 | # ignore the .idea directory 7 | .idea/ 8 | # ignore the .env file 9 | .env -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Welcome to The AI Guild 🚀 3 | 4 | **This code is a part of a module in our vibrant AI community 🚀[Join the AI Guild Community](https://bit.ly/ai-guild-join), where like-minded entrepreneurs and programmers come together to build real-world AI-based solutions.** 5 | 6 | ### What is The AI Guild? 7 | The AI Guild is a collaborative community designed for developers, tech enthusiasts, and entrepreneurs who want to **build practical AI tools** and solutions. Whether you’re just starting or looking to level up your skills, this is the place to dive deeper into AI in a supportive, hands-on environment. 8 | 9 | ### Why Join Us? 10 | - **Collaborate with Like-Minded Builders**: Work alongside a community of individuals passionate about AI, sharing ideas and solving real-world problems together. 11 | - **Access to Exclusive Resources**: Gain entry to our Code & Template Vault, a collection of ready-to-use code snippets, templates, and AI projects. 12 | - **Guided Learning Paths**: Follow structured paths, from AI Basics for Builders to advanced classes like AI Solutions Lab, designed to help you apply your knowledge. 13 | - **Weekly Live Calls & Q&A**: Get direct support, feedback, and guidance during live sessions with the community. 14 | - **Real-World AI Projects**: Work on projects that make an impact, learn from others, and showcase your work. 15 | 16 | ### Success Stories 17 | Here’s what some of our members are saying: 18 | - **"Joining The AI Guild has accelerated my learning. I’ve already built my first AI chatbot with the help of the community!"** 19 | - **"The live calls and feedback have been game-changers. I’ve implemented AI automation in my business, saving hours each week."** 20 | 21 | ### Who is This For? 22 | If you’re eager to: 23 | - Build AI tools that solve real problems 24 | - Collaborate and learn from experienced AI practitioners 25 | - Stay up-to-date with the latest in AI development 26 | - Turn your coding skills into actionable solutions 27 | 28 | Then **The AI Guild** is the perfect fit for you. 29 | 30 | ### Frequently Asked Questions 31 | - **Q: Do I need to be an expert to join?** 32 | - **A:** Not at all! The AI Guild is designed for all skill levels, from beginners to advanced developers. 33 | - **Q: Will I get personalized support?** 34 | - **A:** Yes! You’ll have access to live Q&A sessions and direct feedback on your projects. 35 | - **Q: What kind of projects can I work on?** 36 | - **A:** You can start with small projects like chatbots and automation tools, and progress to more advanced AI solutions tailored to your interests. 37 | 38 | ### How to Get Started 39 | Want to dive deeper and get the full experience? 🚀[Join the AI Guild Community](https://bit.ly/ai-guild-join) and unlock all the benefits of our growing community. 40 | 41 | We look forward to seeing what you’ll build with us! 42 | -------------------------------------------------------------------------------- /__pycache__/helpers.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pdichone/database-ai-agents/605ffdc68db5f79eed956c3812e1cb318ca25203/__pycache__/helpers.cpython-312.pyc -------------------------------------------------------------------------------- /assis_api_sql_db.py: -------------------------------------------------------------------------------- 1 | import json 2 | from langchain.schema import HumanMessage, SystemMessage 3 | import os 4 | from dotenv import load_dotenv 5 | from langchain_openai import ChatOpenAI 6 | import pandas as pd 7 | 8 | from sqlalchemy import create_engine 9 | import numpy as np 10 | from sqlalchemy import text 11 | from openai import OpenAI 12 | import streamlit as st 13 | 14 | import helpers 15 | from helpers import ( 16 | get_avg_salary_and_female_count_for_division, 17 | get_total_overtime_pay_for_department, 18 | get_total_longevity_pay_for_grade, 19 | get_employee_count_by_gender_in_department, 20 | get_employees_with_overtime_above, 21 | ) 22 | 23 | 24 | # Load environment variables from .env file 25 | load_dotenv() 26 | 27 | openai_key = os.getenv("OPENAI_API_KEY") 28 | 29 | llm_name = "gpt-3.5-turbo" 30 | model = ChatOpenAI(api_key=openai_key, model=llm_name) 31 | 32 | 33 | # for the weather function calling 34 | client = OpenAI(api_key=openai_key) 35 | 36 | 37 | # Step 1: create the assistant 38 | assistant = client.beta.assistants.create( 39 | name="Salary Assistant", 40 | description="Assistant to help with salary data", 41 | model=llm_name, 42 | tools=helpers.tools_sql, 43 | ) 44 | 45 | # create a thread 46 | thread = client.beta.threads.create() 47 | print(thread.id) 48 | 49 | message = client.beta.threads.messages.create( 50 | thread_id=thread.id, 51 | role="user", 52 | content="""What is the total overtime pay for the Alcohol Beverage Services department?""", 53 | ) 54 | 55 | messages = client.beta.threads.messages.list(thread_id=thread.id) 56 | print(messages) 57 | 58 | # Run the assistant 59 | run = client.beta.threads.runs.create( 60 | thread_id=thread.id, 61 | assistant_id=assistant.id, 62 | ) 63 | 64 | import time 65 | 66 | start_time = time.time() 67 | 68 | status = run.status 69 | 70 | while status not in ["completed", "cancelled", "expired", "failed"]: 71 | time.sleep(5) 72 | run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) 73 | print( 74 | "Elapsed time: {} minutes {} seconds".format( 75 | int((time.time() - start_time) // 60), int((time.time() - start_time) % 60) 76 | ) 77 | ) 78 | status = run.status 79 | print(f"Status: {status}") 80 | if status == "requires_action": 81 | available_functions = { 82 | "get_avg_salary_and_female_count_for_division": get_avg_salary_and_female_count_for_division, 83 | "get_total_overtime_pay_for_department": get_total_overtime_pay_for_department, 84 | "get_total_longevity_pay_for_grade": get_total_longevity_pay_for_grade, 85 | "get_employee_count_by_gender_in_department": get_employee_count_by_gender_in_department, 86 | "get_employees_with_overtime_above": get_employees_with_overtime_above, 87 | } 88 | 89 | tool_outputs = [] 90 | 91 | for tool_call in run.required_action.submit_tool_outputs.tool_calls: 92 | function_name = tool_call.function.name 93 | function_to_call = available_functions[function_name] 94 | function_args = json.loads(tool_call.function.arguments) 95 | if function_name == "get_employees_with_overtime_above": 96 | function_response = function_to_call(amount=function_args.get("amount")) 97 | elif function_name == "get_total_longevity_pay_for_grade": 98 | function_response = function_to_call(grade=function_args.get("grade")) 99 | else: 100 | function_response = function_to_call(**function_args) 101 | 102 | print(f"Function response: {function_response}") 103 | print(tool_call.id) 104 | 105 | tool_outputs.append( 106 | {"tool_call_id": tool_call.id, "output": str(function_response)} 107 | ) 108 | 109 | run = client.beta.threads.runs.submit_tool_outputs( 110 | thread_id=thread.id, 111 | run_id=run.id, 112 | tool_outputs=tool_outputs, 113 | ) 114 | 115 | messages = client.beta.threads.messages.list(thread_id=thread.id) 116 | 117 | print(messages.model_dump_json(indent=2)) 118 | -------------------------------------------------------------------------------- /csv_agent.py: -------------------------------------------------------------------------------- 1 | from langchain.schema import HumanMessage, SystemMessage 2 | import os 3 | from dotenv import load_dotenv 4 | from langchain_openai import ChatOpenAI 5 | import pandas as pd 6 | 7 | # Load environment variables from .env file 8 | load_dotenv() 9 | 10 | openai_key = os.getenv("OPENAI_API_KEY") 11 | 12 | llm_name = "gpt-3.5-turbo" 13 | model = ChatOpenAI(api_key=openai_key, model=llm_name) 14 | 15 | # read csv file 16 | df = pd.read_csv("./data/salaries_2023.csv").fillna(value=0) 17 | 18 | # print(df.head()) 19 | 20 | from langchain_experimental.agents.agent_toolkits import ( 21 | create_pandas_dataframe_agent, 22 | create_csv_agent, 23 | ) 24 | 25 | agent = create_pandas_dataframe_agent( 26 | llm=model, 27 | df=df, 28 | verbose=True, 29 | ) 30 | # res = agent.invoke("how many rows are there in the dataframe?") 31 | 32 | # print(res) 33 | 34 | # then let's add some pre and sufix prompt 35 | CSV_PROMPT_PREFIX = """ 36 | First set the pandas display options to show all the columns, 37 | get the column names, then answer the question. 38 | """ 39 | 40 | CSV_PROMPT_SUFFIX = """ 41 | - **ALWAYS** before giving the Final Answer, try another method. 42 | Then reflect on the answers of the two methods you did and ask yourself 43 | if it answers correctly the original question. 44 | If you are not sure, try another method. 45 | FORMAT 4 FIGURES OR MORE WITH COMMAS. 46 | - If the methods tried do not give the same result,reflect and 47 | try again until you have two methods that have the same result. 48 | - If you still cannot arrive to a consistent result, say that 49 | you are not sure of the answer. 50 | - If you are sure of the correct answer, create a beautiful 51 | and thorough response using Markdown. 52 | - **DO NOT MAKE UP AN ANSWER OR USE PRIOR KNOWLEDGE, 53 | ONLY USE THE RESULTS OF THE CALCULATIONS YOU HAVE DONE**. 54 | - **ALWAYS**, as part of your "Final Answer", explain how you got 55 | to the answer on a section that starts with: "\n\nExplanation:\n". 56 | In the explanation, mention the column names that you used to get 57 | to the final answer. 58 | """ 59 | QUESTION = "Which grade has the highest average base salary, and compare the average female pay vs male pay?" 60 | 61 | res = agent.invoke(CSV_PROMPT_PREFIX + QUESTION + CSV_PROMPT_SUFFIX) 62 | 63 | # print(f"Final result: {res["output"]}") 64 | 65 | import streamlit as st 66 | 67 | st.title("Database AI Agent with LangChain") 68 | 69 | st.write("### Dataset Preview") 70 | st.write(df.head()) 71 | 72 | # User input for the question 73 | st.write("### Ask a Question") 74 | question = st.text_input( 75 | "Enter your question about the dataset:", 76 | "Which grade has the highest average base salary, and compare the average female pay vs male pay?", 77 | ) 78 | 79 | # Run the agent and display the result 80 | if st.button("Run Query"): 81 | QUERY = CSV_PROMPT_PREFIX + question + CSV_PROMPT_SUFFIX 82 | res = agent.invoke(QUERY) 83 | st.write("### Final Answer") 84 | st.markdown(res["output"]) 85 | -------------------------------------------------------------------------------- /db/salary.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pdichone/database-ai-agents/605ffdc68db5f79eed956c3812e1cb318ca25203/db/salary.db -------------------------------------------------------------------------------- /first_agent.py: -------------------------------------------------------------------------------- 1 | from langchain.schema import HumanMessage, SystemMessage 2 | import os 3 | from dotenv import load_dotenv 4 | from langchain_openai import ChatOpenAI 5 | 6 | # Load environment variables from .env file 7 | load_dotenv() 8 | 9 | openai_key = os.getenv("OPENAI_API_KEY") 10 | 11 | llm_name = "gpt-3.5-turbo" 12 | model = ChatOpenAI(api_key=openai_key, model=llm_name) 13 | 14 | messages = [ 15 | SystemMessage( 16 | content="You are a helpful assistant who is extremely competent as a Computer Scientist! Your name is Rob." 17 | ), 18 | HumanMessage(content="who was the very first computer scientist?"), 19 | ] 20 | 21 | 22 | # res = model.invoke(messages) 23 | # print(res) 24 | 25 | 26 | def first_agent(messages): 27 | res = model.invoke(messages) 28 | return res 29 | 30 | 31 | def run_agent(): 32 | print("Simple AI Agent: Type 'exit' to quit") 33 | while True: 34 | user_input = input("You: ") 35 | if user_input.lower() == "exit": 36 | print("Goodbye!") 37 | break 38 | print("AI Agent is thinking...") 39 | messages = [HumanMessage(content=user_input)] 40 | response = first_agent(messages) 41 | print("AI Agent: getting the response...") 42 | print(f"AI Agent: {response.content}") 43 | 44 | 45 | if __name__ == "__main__": 46 | run_agent() 47 | -------------------------------------------------------------------------------- /fun_call_db_agent.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dotenv import load_dotenv 4 | from langchain_openai import ChatOpenAI 5 | import pandas as pd 6 | 7 | from sqlalchemy import create_engine 8 | import numpy as np 9 | from sqlalchemy import text 10 | from openai import OpenAI 11 | 12 | import helpers 13 | from helpers import ( 14 | get_avg_salary_and_female_count_for_division, 15 | get_total_overtime_pay_for_department, 16 | get_total_longevity_pay_for_grade, 17 | get_employee_count_by_gender_in_department, 18 | get_employees_with_overtime_above, 19 | ) 20 | 21 | 22 | # Load environment variables from .env file 23 | load_dotenv() 24 | 25 | openai_key = os.getenv("OPENAI_API_KEY") 26 | 27 | 28 | llm_name = "gpt-3.5-turbo" 29 | model = ChatOpenAI(api_key=openai_key, model=llm_name) 30 | 31 | # for the weather function calling 32 | client = OpenAI(api_key=openai_key) 33 | 34 | from langchain.agents import create_sql_agent 35 | from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit 36 | from langchain_community.utilities import SQLDatabase 37 | 38 | # create a db from csv file 39 | 40 | # Path to your SQLite database file 41 | database_file_path = "./db/salary.db" 42 | 43 | 44 | # Create an engine to connect to the SQLite database 45 | # SQLite only requires the path to the database file 46 | engine = create_engine(f"sqlite:///{database_file_path}") 47 | file_url = "./data/salaries_2023.csv" 48 | os.makedirs(os.path.dirname(database_file_path), exist_ok=True) 49 | df = pd.read_csv(file_url).fillna(value=0) 50 | df.to_sql("salaries_2023", con=engine, if_exists="replace", index=False) 51 | 52 | 53 | def run_conversation( 54 | query="""What is the average salary and the count of female employees 55 | # in the ABS 85 Administrative Services division?""", 56 | ): 57 | 58 | messages = [ 59 | # { 60 | # "role": "user", 61 | # "content": """What is the average salary and the count of female employees 62 | # in the ABS 85 Administrative Services division?""", 63 | # }, 64 | { 65 | "role": "user", 66 | "content": query, 67 | }, 68 | # { 69 | # "role": "user", # gives error request too large 70 | # "content": """How many employees have overtime pay above 5000?""", 71 | # }, 72 | ] 73 | 74 | # Call the model with the conversation and available functions 75 | response = client.chat.completions.create( 76 | model=llm_name, 77 | messages=messages, 78 | tools=helpers.tools_sql, 79 | tool_choice="auto", # auto is default, but we'll be explicit 80 | ) 81 | response_message = response.choices[0].message 82 | # print(response_message.model_dump_json(indent=2)) 83 | # print("tool calls: ", response_message.tool_calls) 84 | 85 | tool_calls = response_message.tool_calls 86 | if tool_calls: 87 | # Step 3: call the function 88 | available_functions = { 89 | "get_avg_salary_and_female_count_for_division": get_avg_salary_and_female_count_for_division, 90 | "get_total_overtime_pay_for_department": get_total_overtime_pay_for_department, 91 | "get_total_longevity_pay_for_grade": get_total_longevity_pay_for_grade, 92 | "get_employee_count_by_gender_in_department": get_employee_count_by_gender_in_department, 93 | "get_employees_with_overtime_above": get_employees_with_overtime_above, 94 | } 95 | messages.append(response_message) # extend conversation with assistant's reply 96 | 97 | # Step 4: send the info for each function call and function response to the model 98 | for tool_call in tool_calls: 99 | function_name = tool_call.function.name 100 | function_to_call = available_functions[function_name] 101 | function_args = json.loads(tool_call.function.arguments) 102 | if function_name == "get_employees_with_overtime_above": 103 | function_response = function_to_call(amount=function_args.get("amount")) 104 | elif function_name == "get_total_longevity_pay_for_grade": 105 | function_response = function_to_call(grade=function_args.get("grade")) 106 | else: 107 | function_response = function_to_call(**function_args) 108 | messages.append( 109 | { 110 | "tool_call_id": tool_call.id, 111 | "role": "tool", 112 | "name": function_name, 113 | "content": str(function_response), 114 | } 115 | ) # extend conversation with function responses 116 | second_response = client.chat.completions.create( 117 | model=llm_name, 118 | messages=messages, 119 | ) # get a new response from the model where it can see the function response 120 | 121 | return second_response 122 | 123 | 124 | # Example calls to the functions 125 | if __name__ == "__main__": 126 | res = ( 127 | run_conversation( 128 | query="""What is the total longevity pay for employees with the grade 'M3'?""" 129 | ) 130 | .choices[0] 131 | .message.content 132 | ) 133 | 134 | print(res) 135 | # run_conversation() 136 | # Step 1: First direct call to the functions = 137 | # division_name = "ABS 85 Administrative Services" 138 | # department_name = "Alcohol Beverage Services" 139 | # grade = "M3" 140 | # overtime_amount = 5000 141 | 142 | # avg_salary_and_female_count = get_avg_salary_and_female_count_for_division( 143 | # division_name 144 | # ) 145 | # total_overtime_pay = get_total_overtime_pay_for_department(department_name) 146 | # total_longevity_pay = get_total_longevity_pay_for_grade(grade) 147 | # employee_count_by_gender = get_employee_count_by_gender_in_department( 148 | # department_name 149 | # ) 150 | # employees_with_high_overtime = get_employees_with_overtime_above(overtime_amount) 151 | 152 | # print( 153 | # f"Average Salary and Female Count for Division '{division_name}': {avg_salary_and_female_count}" 154 | # ) 155 | # print( 156 | # f"Total Overtime Pay for Department '{department_name}': {total_overtime_pay}" 157 | # ) 158 | # # print(f"Total Longevity Pay for Grade '{grade}': {total_longevity_pay}") 159 | # # print( 160 | # # f"Employee Count by Gender in Department '{department_name}': {employee_count_by_gender}" 161 | # # ) 162 | # # print( 163 | # # f"Employees with Overtime Pay Above {overtime_amount}: {employees_with_high_overtime}" 164 | # # ) 165 | -------------------------------------------------------------------------------- /fun_calling.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from langchain_openai import ChatOpenAI 4 | import json 5 | from openai import OpenAI 6 | 7 | # Load environment variables from .env file 8 | load_dotenv() 9 | 10 | openai_key = os.getenv("OPENAI_API_KEY") 11 | 12 | llm_name = "gpt-3.5-turbo" # use this cause is cheaper! 13 | model = ChatOpenAI(api_key=openai_key, model=llm_name) 14 | 15 | # for the weather function calling 16 | client = OpenAI(api_key=openai_key) 17 | 18 | 19 | # Example dummy function hard coded to return the same weather 20 | # In production, this could be your backend API or an external API 21 | def get_current_weather(location, unit="fahrenheit"): 22 | """Get the current weather in a given location""" 23 | if "tokyo" in location.lower(): 24 | return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit}) 25 | elif "san francisco" in location.lower(): 26 | return json.dumps( 27 | {"location": "San Francisco", "temperature": "72", "unit": unit} 28 | ) 29 | elif "paris" in location.lower(): 30 | return json.dumps({"location": "Paris", "temperature": "22", "unit": unit}) 31 | else: 32 | return json.dumps({"location": location, "temperature": "unknown"}) 33 | 34 | 35 | def run_conversation(): 36 | # Step 1: send the conversation and available functions to the model 37 | messages = [ 38 | { 39 | "role": "user", 40 | "content": "What's the weather like in San Francisco, Tokyo, and Paris?", 41 | } 42 | ] 43 | 44 | # Define the available functions 45 | tools = [ 46 | { 47 | "type": "function", 48 | "function": { 49 | "name": "get_current_weather", 50 | "description": "Get the current weather in a given location", 51 | "parameters": { 52 | "type": "object", 53 | "properties": { 54 | "location": { 55 | "type": "string", 56 | "description": "The city and state, e.g. San Francisco, CA", 57 | }, 58 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, 59 | }, 60 | "required": ["location"], 61 | }, 62 | }, 63 | } 64 | ] 65 | # Call the model with the conversation and available functions 66 | response = client.chat.completions.create( 67 | model="gpt-4o", 68 | messages=messages, 69 | tools=tools, 70 | tool_choice="auto", # auto is default, but we'll be explicit 71 | ) 72 | response_message = response.choices[0].message 73 | print(response_message.model_dump_json(indent=2)) 74 | print("tool calls: ", response_message.tool_calls) 75 | 76 | tool_calls = response_message.tool_calls 77 | # Step 2: check if the model wanted to call a function 78 | if tool_calls: 79 | # Step 3: call the function 80 | # Note: the JSON response may not always be valid; be sure to handle errors 81 | available_functions = { 82 | "get_current_weather": get_current_weather, 83 | } # only one function in this example, but you can have multiple 84 | messages.append(response_message) # extend conversation with assistant's reply 85 | # Step 4: send the info for each function call and function response to the model 86 | for tool_call in tool_calls: 87 | function_name = tool_call.function.name 88 | function_to_call = available_functions[function_name] 89 | function_args = json.loads(tool_call.function.arguments) 90 | function_response = function_to_call( 91 | location=function_args.get("location"), 92 | unit=function_args.get("unit"), 93 | ) 94 | messages.append( 95 | { 96 | "tool_call_id": tool_call.id, 97 | "role": "tool", 98 | "name": function_name, 99 | "content": function_response, 100 | } 101 | ) # extend conversation with function response 102 | second_response = client.chat.completions.create( 103 | model="gpt-4o", 104 | messages=messages, 105 | ) # get a new response from the model where it can see the function response 106 | return second_response 107 | 108 | 109 | print(run_conversation().model_dump_json(indent=2)) 110 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine, text 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | 6 | # Create an engine to connect to the SQLite database 7 | database_file_path = "./db/salary.db" 8 | engine = create_engine(f"sqlite:///{database_file_path}") 9 | 10 | 11 | tools_sql = [ 12 | { 13 | "type": "function", 14 | "function": { 15 | "name": "get_avg_salary_and_female_count_for_division", 16 | "description": """Retrieves the average salary and the count of 17 | female employees in a specific division.""", 18 | "parameters": { 19 | "type": "object", 20 | "properties": { 21 | "division_name": { 22 | "type": "string", 23 | "description": """The name of the division 24 | (e.g., 'ABS 85 Administrative Services').""", 25 | } 26 | }, 27 | "required": ["division_name"], 28 | }, 29 | }, 30 | }, 31 | { 32 | "type": "function", 33 | "function": { 34 | "name": "get_total_overtime_pay_for_department", 35 | "description": """Retrieves the total overtime pay for a 36 | specific department.""", 37 | "parameters": { 38 | "type": "object", 39 | "properties": { 40 | "department_name": { 41 | "type": "string", 42 | "description": """The name of the department 43 | (e.g., 'Alcohol Beverage Services').""", 44 | } 45 | }, 46 | "required": ["department_name"], 47 | }, 48 | }, 49 | }, 50 | { 51 | "type": "function", 52 | "function": { 53 | "name": "get_total_longevity_pay_for_grade", 54 | "description": """Retrieves the total longevity pay for a 55 | specific grade.""", 56 | "parameters": { 57 | "type": "object", 58 | "properties": { 59 | "grade": { 60 | "type": "string", 61 | "description": """The grade of the employees 62 | (e.g., 'M3', 'N25').""", 63 | } 64 | }, 65 | "required": ["grade"], 66 | }, 67 | }, 68 | }, 69 | { 70 | "type": "function", 71 | "function": { 72 | "name": "get_employee_count_by_gender_in_department", 73 | "description": """Retrieves the count of employees by gender 74 | in a specific department.""", 75 | "parameters": { 76 | "type": "object", 77 | "properties": { 78 | "department_name": { 79 | "type": "string", 80 | "description": """The name of the department 81 | (e.g., 'Alcohol Beverage Services').""", 82 | } 83 | }, 84 | "required": ["department_name"], 85 | }, 86 | }, 87 | }, 88 | { 89 | "type": "function", 90 | "function": { 91 | "name": "get_employees_with_overtime_above", 92 | "description": """Retrieves the employees with overtime pay 93 | above a specified amount.""", 94 | "parameters": { 95 | "type": "object", 96 | "properties": { 97 | "amount": { 98 | "type": "number", 99 | "description": """The minimum amount of overtime pay 100 | (e.g., 1000.0).""", 101 | } 102 | }, 103 | "required": ["amount"], 104 | }, 105 | }, 106 | }, 107 | ] 108 | 109 | 110 | def get_avg_salary_and_female_count_for_division(division_name): 111 | try: 112 | query = f""" 113 | SELECT AVG(Base_Salary) AS avg_salary, COUNT(*) AS female_count 114 | FROM salaries_2023 115 | WHERE Division = '{division_name}' AND Gender = 'F'; 116 | """ 117 | query = text(query) 118 | 119 | with engine.connect() as connection: 120 | result = pd.read_sql_query(query, connection) 121 | if not result.empty: 122 | 123 | return result.to_dict("records")[0] 124 | else: 125 | return json.dumps({"avg_salary": np.nan, "female_count": 0}) 126 | # return {"avg_salary": np.nan, "female_count": 0} 127 | except Exception as e: 128 | print(e) 129 | return json.dumps({"avg_salary": np.nan, "female_count": 0}) 130 | # return {"avg_salary": np.nan, "female_count": 0} 131 | 132 | 133 | def get_total_overtime_pay_for_department(department_name): 134 | try: 135 | query = f""" 136 | SELECT SUM(Overtime_Pay) AS total_overtime_pay 137 | FROM salaries_2023 138 | WHERE Department_Name = '{department_name}'; 139 | """ 140 | query = text(query) 141 | 142 | with engine.connect() as connection: 143 | result = pd.read_sql_query(query, connection) 144 | if not result.empty: 145 | 146 | return result.to_dict("records")[0] 147 | else: 148 | return {"total_overtime_pay": 0} 149 | except Exception as e: 150 | print(e) 151 | return {"total_overtime_pay": 0} 152 | 153 | 154 | def get_employees_with_overtime_above(amount): 155 | try: 156 | query = f""" 157 | SELECT * 158 | FROM salaries_2023 159 | WHERE Overtime_Pay > {amount}; 160 | """ 161 | query = text(query) 162 | 163 | with engine.connect() as connection: 164 | result = pd.read_sql_query(query, connection) 165 | if not result.empty: 166 | return result.to_dict("records") 167 | else: 168 | return [] 169 | except Exception as e: 170 | print(e) 171 | return [] 172 | 173 | 174 | def get_employee_count_by_gender_in_department(department_name): 175 | try: 176 | query = f""" 177 | SELECT Gender, COUNT(*) AS employee_count 178 | FROM salaries_2023 179 | WHERE Department_Name = '{department_name}' 180 | GROUP BY Gender; 181 | """ 182 | query = text(query) 183 | 184 | with engine.connect() as connection: 185 | result = pd.read_sql_query(query, connection) 186 | if not result.empty: 187 | return result.to_dict("records") 188 | else: 189 | return [] 190 | except Exception as e: 191 | print(e) 192 | return [] 193 | 194 | 195 | def get_total_longevity_pay_for_grade(grade): 196 | try: 197 | query = f""" 198 | SELECT SUM(Longevity_Pay) AS total_longevity_pay 199 | FROM salaries_2023 200 | WHERE Grade = '{grade}'; 201 | """ 202 | query = text(query) 203 | 204 | with engine.connect() as connection: 205 | result = pd.read_sql_query(query, connection) 206 | if not result.empty: 207 | return result.to_dict("records")[0] 208 | else: 209 | return {"total_longevity_pay": 0} 210 | except Exception as e: 211 | print(e) 212 | return {"total_longevity_pay": 0} 213 | -------------------------------------------------------------------------------- /questions_sql_agent.md: -------------------------------------------------------------------------------- 1 | General Questions: 2 | 3 | What is the average base salary for all employees? 4 | How many employees are there in total? 5 | Which department has the highest average base salary? 6 | Department-Specific Questions: 7 | 8 | What is the total overtime pay for the Alcohol Beverage Services department? 9 | How many employees are in the ABS 85 Administration division? 10 | What is the average base salary for employees in the ABS 85 Administrative Services division? 11 | Gender-Specific Questions: 12 | 13 | What is the average base salary for male employees? 14 | What is the average base salary for female employees? 15 | How many male employees are there compared to female employees? 16 | Grade-Specific Questions: 17 | 18 | What is the average longevity pay for employees in grade M3? 19 | Which grade has the highest average base salary? 20 | Combination Questions: 21 | 22 | What is the average base salary for female employees in the Alcohol Beverage Services department? 23 | How much overtime pay is given on average to employees in grade 21? 24 | Summary and Aggregation Questions: 25 | 26 | What is the total base salary paid by the entire organization? 27 | What is the maximum overtime pay received by any employee? -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyodbc==5.1.0 2 | tabulate==0.9.0 3 | openai==1.56.1 4 | langchain==0.1.6 5 | langchain-community==0.0.20 6 | langchain-core==0.1.23 7 | langchain-experimental==0.0.49 8 | langchain-openai==0.0.5 9 | pandas==2.2.2 10 | SQLAlchemy==2.0.30 11 | pandas==2.2.2 12 | python-dotenv==1.0.1 13 | streamlit 14 | -------------------------------------------------------------------------------- /sql_db_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from langchain_openai import ChatOpenAI 4 | import pandas as pd 5 | 6 | from sqlalchemy import create_engine 7 | 8 | # Load environment variables from .env file 9 | load_dotenv() 10 | 11 | openai_key = os.getenv("OPENAI_API_KEY") 12 | 13 | llm_name = "gpt-3.5-turbo" 14 | model = ChatOpenAI(api_key=openai_key, model=llm_name) 15 | 16 | 17 | from langchain.agents import create_sql_agent 18 | from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit 19 | from langchain_community.utilities import SQLDatabase 20 | 21 | # create a db from csv file 22 | 23 | # Path to your SQLite database file 24 | database_file_path = "./db/salary.db" 25 | 26 | 27 | # Create an engine to connect to the SQLite database 28 | # SQLite only requires the path to the database file 29 | engine = create_engine(f"sqlite:///{database_file_path}") 30 | file_url = "./data/salaries_2023.csv" 31 | os.makedirs(os.path.dirname(database_file_path), exist_ok=True) 32 | df = pd.read_csv(file_url).fillna(value=0) 33 | df.to_sql("salaries_2023", con=engine, if_exists="replace", index=False) 34 | 35 | # print(f"Database created successfully! {df}") 36 | 37 | # Part 2: Prepare the sql prompt 38 | MSSQL_AGENT_PREFIX = """ 39 | 40 | You are an agent designed to interact with a SQL database. 41 | ## Instructions: 42 | - Given an input question, create a syntactically correct {dialect} query 43 | to run, then look at the results of the query and return the answer. 44 | - Unless the user specifies a specific number of examples they wish to 45 | obtain, **ALWAYS** limit your query to at most {top_k} results. 46 | - You can order the results by a relevant column to return the most 47 | interesting examples in the database. 48 | - Never query for all the columns from a specific table, only ask for 49 | the relevant columns given the question. 50 | - You have access to tools for interacting with the database. 51 | - You MUST double check your query before executing it.If you get an error 52 | while executing a query,rewrite the query and try again. 53 | - DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) 54 | to the database. 55 | - DO NOT MAKE UP AN ANSWER OR USE PRIOR KNOWLEDGE, ONLY USE THE RESULTS 56 | OF THE CALCULATIONS YOU HAVE DONE. 57 | - Your response should be in Markdown. However, **when running a SQL Query 58 | in "Action Input", do not include the markdown backticks**. 59 | Those are only for formatting the response, not for executing the command. 60 | - ALWAYS, as part of your final answer, explain how you got to the answer 61 | on a section that starts with: "Explanation:". Include the SQL query as 62 | part of the explanation section. 63 | - If the question does not seem related to the database, just return 64 | "I don\'t know" as the answer. 65 | - Only use the below tools. Only use the information returned by the 66 | below tools to construct your query and final answer. 67 | - Do not make up table names, only use the tables returned by any of the 68 | tools below. 69 | - as part of your final answer, please include the SQL query you used in json format or code format 70 | 71 | ## Tools: 72 | 73 | """ 74 | 75 | MSSQL_AGENT_FORMAT_INSTRUCTIONS = """ 76 | 77 | ## Use the following format: 78 | 79 | Question: the input question you must answer. 80 | Thought: you should always think about what to do. 81 | Action: the action to take, should be one of [{tool_names}]. 82 | Action Input: the input to the action. 83 | Observation: the result of the action. 84 | ... (this Thought/Action/Action Input/Observation can repeat N times) 85 | Thought: I now know the final answer. 86 | Final Answer: the final answer to the original input question. 87 | 88 | Example of Final Answer: 89 | <=== Beginning of example 90 | 91 | Action: query_sql_db 92 | Action Input: 93 | SELECT TOP (10) [base_salary], [grade] 94 | FROM salaries_2023 95 | 96 | WHERE state = 'Division' 97 | 98 | Observation: 99 | [(27437.0,), (27088.0,), (26762.0,), (26521.0,), (26472.0,), (26421.0,), (26408.0,)] 100 | Thought:I now know the final answer 101 | Final Answer: There were 27437 workers making 100,000. 102 | 103 | Explanation: 104 | I queried the `xyz` table for the `salary` column where the department 105 | is 'IGM' and the date starts with '2020'. The query returned a list of tuples 106 | with the bazse salary for each day in 2020. To answer the question, 107 | I took the sum of all the salaries in the list, which is 27437. 108 | I used the following query 109 | 110 | ```sql 111 | SELECT [salary] FROM xyztable WHERE department = 'IGM' AND date LIKE '2020%'" 112 | ``` 113 | ===> End of Example 114 | 115 | """ 116 | 117 | 118 | db = SQLDatabase.from_uri(f"sqlite:///{database_file_path}") 119 | toolkit = SQLDatabaseToolkit(db=db, llm=model) 120 | 121 | QUESTION = """what is the highest average salary by department, and give me the number?" 122 | """ 123 | sql_agent = create_sql_agent( 124 | prefix=MSSQL_AGENT_PREFIX, 125 | format_instructions=MSSQL_AGENT_FORMAT_INSTRUCTIONS, 126 | llm=model, 127 | toolkit=toolkit, 128 | top_k=30, 129 | verbose=True, 130 | ) 131 | 132 | # res = sql_agent.invoke(QUESTION) 133 | 134 | # print(res) 135 | 136 | import streamlit as st 137 | 138 | st.title("SQL Query AI Agent") 139 | 140 | question = st.text_input("Enter your query:") 141 | 142 | if st.button("Run Query"): 143 | if question: 144 | res = sql_agent.invoke(question) 145 | 146 | st.markdown(res["output"]) 147 | else: 148 | st.error("Please enter a query.") 149 | --------------------------------------------------------------------------------