├── .gitignore
├── 1_🛖_Home.py
├── Langchain_sf_con_sample.py
├── README.md
├── app_secrets.py
├── images
    └── ERD.png
├── pages
    └── 2_🤝_tbl_definitions.py
├── prompts
    └── tpch_prompt.yaml
├── requirements.txt
├── sql_assistant.py
└── sql_execution.py


/.gitignore:
--------------------------------------------------------------------------------
1 | /venv
2 | __pycache__/


--------------------------------------------------------------------------------
/1_🛖_Home.py:
--------------------------------------------------------------------------------
 1 | from app_secrets import OPENAI_API_KEY
 2 | import os
 3 | import streamlit as st
 4 | from sql_execution import execute_sf_query
 5 | from langchain import OpenAI
 6 | from langchain.prompts import load_prompt
 7 | from pathlib import Path
 8 | from PIL import Image
 9 |
10 |
11 | def write_to_training_file(file_path, prompt, sql):
12 |     """Append one user prompt and its generated SQL to the training file.
13 |
14 |     Returns "success" when the entry was written, otherwise an error message
15 |     that the caller shows in the UI.
16 |     """
17 |     try:
18 |         path = Path(file_path)
19 |         # trainings/ is not part of the repository, so create it on first use
20 |         path.parent.mkdir(parents=True, exist_ok=True)
21 |         # append: each accepted scenario is added instead of overwriting the file
22 |         with open(path, 'a', encoding='utf-8') as file:
23 |             file.write("\n prompt : {}".format(prompt))
24 |             file.write("\n sql : {}".format(sql))
25 |             file.write("\n label : 1 \n\n")
26 |         return "success"
27 |     except OSError as err:
28 |         print("problem in opening file:", err)
29 |         return "problem in opening file"
30 |
31 |
32 | #setup env variable
33 | os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
34 | # project root directory: this page sits at the repository root, so use its
35 | # own folder rather than searching the parents for a fixed checkout name
36 | root_dir = Path(__file__).resolve().parent
37 | #frontend
38 | st.set_page_config(
39 |     page_title="Query Assistant",
40 |     page_icon="🌄"
41 | )
42 | st.sidebar.success("Select a page above")
43 |
44 | tab_titles = [
45 |     "Results",
46 |     "Query",
47 |     "ER Diagram"
48 | ]
49 |
50 | st.title("Your Project Assistant")
51 | prompt = st.text_input("enter your query")
52 | tabs = st.tabs(tab_titles)
53 | with tabs[2]:
54 |     image = Image.open("{}/images/ERD.png".format(root_dir))
55 |     st.image(image, caption="Entity Relationship")
56 |
57 | prompt_template = load_prompt(f"{root_dir}/prompts/tpch_prompt.yaml")
58 |
59 | # SQL generation should be repeatable, so do not sample
60 | llm = OpenAI(temperature=0)
61 |
62 | if prompt:
63 |     # Streamlit reruns this script on every interaction, including the button
64 |     # below; keep the SQL generated for this prompt so a rerun does not ask the
65 |     # LLM again and save a different query than the one displayed
66 |     if st.session_state.get("last_prompt") != prompt:
67 |         final_prompt = prompt_template.format(input=prompt)
68 |         st.session_state["query_text"] = llm(prompt=final_prompt)
69 |         st.session_state["last_prompt"] = prompt
70 |     query_text = st.session_state["query_text"]
71 |     # NOTE: the generated SQL is executed as-is with the configured SF_ROLE
72 |     output = execute_sf_query(query_text)
73 |     with tabs[0]:
74 |         st.write(output)
75 |     with tabs[1]:
76 |         st.write(query_text)
77 |         add_to_training_data = st.button("Add to training data")
78 |         if add_to_training_data:
79 |             file_path = "{}/trainings/gpt_trainings.txt".format(root_dir)
80 |             write_to_file_status = write_to_training_file(file_path=file_path, prompt=prompt, sql=query_text)
81 |             if write_to_file_status == "success":
82 |                 st.write("Scenario added to trainings file")
83 |             else:
84 |                 st.write(write_to_file_status)


--------------------------------------------------------------------------------
/Langchain_sf_con_sample.py:
--------------------------------------------------------------------------------
 1 | from langchain.document_loaders import SnowflakeLoader
 2 | from app_secrets import SF_USER, SF_PASSWORD, SF_ACCOUNT, SF_WAREHOUSE, SF_ROLE
 3 |
 4 | # Credentials come from app_secrets.py; never hard-code them in this file.
 5 | # NOTE: the password that used to be written here remains in git history - rotate it.
 6 | QUERY = "select * from analytics.raw.customer limit 10"
 7 | snowflake_loader = SnowflakeLoader(
 8 |     query=QUERY,
 9 |     user=SF_USER,
10 |     password=SF_PASSWORD,
11 |     account=SF_ACCOUNT,
12 |     warehouse=SF_WAREHOUSE,
13 |     role=SF_ROLE,
14 |     database="analytics",
15 |     schema="raw",
16 | )
17 | snowflake_documents = snowflake_loader.load()
18 | for i, document in enumerate(snowflake_documents, start=1):
19 |     print("row number = {} =======================".format(i))
20 |     print(document.page_content)


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pawanrawat0926/ai_sql_assistant/6dfa85e93c9622a80b9027a614b6f251e6a7ae5a/README.md


--------------------------------------------------------------------------------
/app_secrets.py:
--------------------------------------------------------------------------------
 1 | # Template only: fill in the placeholders locally and do not commit real credentials.
 2 | OPENAI_API_KEY="YOUR OPENAI API KEY"
 3 | SF_USER='SF USER NAME'
 4 | SF_PASSWORD='SF PASSWORD FOR ABOVE USER'
 5 | SF_ACCOUNT='SF ACCOUNT IDENTIFIER'
 6 | SF_WAREHOUSE='SF WAREHOUSE'
 7 | SF_DATABASE = 'SF DATABASE'
 8 | SF_SCHEMA = 'SF SCHEMA'
 9 | # LLM-generated SQL is executed with this role; a read-only role is safer than ACCOUNTADMIN
10 | SF_ROLE='ACCOUNTADMIN'


--------------------------------------------------------------------------------
/images/ERD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pawanrawat0926/ai_sql_assistant/6dfa85e93c9622a80b9027a614b6f251e6a7ae5a/images/ERD.png


--------------------------------------------------------------------------------
/pages/2_🤝_tbl_definitions.py:
--------------------------------------------------------------------------------
 1 | import streamlit as st
 2 |
 3 | tab_title = [
 4 |     "REGION",
 5 |     "NATION",
 6 |     "CUSTOMER",
 7 |     "ORDERS"
 8 | ]
 9 |
10 | tbl_tabs = st.tabs(tab_title)
11 | # the definitions below are SQL DDL, so highlight them as SQL
12 | with tbl_tabs[0]:
13 |     st.code('''
14 |     REGION (
15 |         R_REGIONKEY NUMBER(38,0),
16 |         R_NAME VARCHAR(25),
17 |         R_COMMENT VARCHAR(152)
18 |     )
19 |     ''',language="sql")
20 | with tbl_tabs[1]:
21 |     st.code('''
22 |     NATION (
23 |         N_NATIONKEY NUMBER(38,0),
24 |         N_NAME VARCHAR(25),
25 |         N_REGIONKEY NUMBER(38,0),
26 |         N_COMMENT VARCHAR(152)
27 |     )
28 |     ''',language="sql")
29 | with tbl_tabs[2]:
30 |     st.code('''
31 |     CUSTOMER (
32 |         C_CUSTKEY NUMBER(38,0),
33 |         C_NAME VARCHAR(25),
34 |         C_ADDRESS VARCHAR(40),
35 |         C_NATIONKEY NUMBER(38,0),
36 |         C_PHONE VARCHAR(15),
37 |         C_ACCTBAL NUMBER(12,2),
38 |         C_MKTSEGMENT VARCHAR(10),
39 |         C_COMMENT VARCHAR(117)
40 |     )
41 |     ''',language="sql")
42 | with tbl_tabs[3]:
43 |     st.code('''
44 |     ORDERS (
45 |         O_ORDERKEY NUMBER(38,0),
46 |         O_CUSTKEY NUMBER(38,0),
47 |         O_ORDERSTATUS VARCHAR(1),
48 |         O_TOTALPRICE NUMBER(12,2),
49 |         O_ORDERDATE DATE,
50 |         O_ORDERPRIORITY VARCHAR(15),
51 |         O_CLERK VARCHAR(15),
52 |         O_SHIPPRIORITY NUMBER(38,0),
53 |         O_COMMENT VARCHAR(79)
54 |     )
55 |     ''',language="sql")


--------------------------------------------------------------------------------
/prompts/tpch_prompt.yaml:
--------------------------------------------------------------------------------
  1 | _type: prompt
  2 | input_variables:
  3 |     ["input"]
  4 | # literal block scalar: keeps the line breaks and adds no stray quote characters to the prompt
  5 | template: |-
  6 |     Given below are the table structures in analytics database raw schema in snowflake cloud database
  7 |     CUSTOMER (
  8 |         C_CUSTKEY NUMBER(38,0),
  9 |         C_NAME VARCHAR(25),
 10 |         C_ADDRESS VARCHAR(40),
 11 |         C_NATIONKEY NUMBER(38,0),
 12 |         C_PHONE VARCHAR(15),
 13 |         C_ACCTBAL NUMBER(12,2),
 14 |         C_MKTSEGMENT VARCHAR(10),
 15 |         C_COMMENT VARCHAR(117)
 16 |     );
 17 |     LINEITEM (
 18 |         L_ORDERKEY NUMBER(38,0),
 19 |         L_PARTKEY NUMBER(38,0),
 20 |         L_SUPPKEY NUMBER(38,0),
 21 |         L_LINENUMBER NUMBER(38,0),
 22 |         L_QUANTITY NUMBER(12,2),
 23 |         L_EXTENDEDPRICE NUMBER(12,2),
 24 |         L_DISCOUNT NUMBER(12,2),
 25 |         L_TAX NUMBER(12,2),
 26 |         L_RETURNFLAG VARCHAR(1),
 27 |         L_LINESTATUS VARCHAR(1),
 28 |         L_SHIPDATE DATE,
 29 |         L_COMMITDATE DATE,
 30 |         L_RECEIPTDATE DATE,
 31 |         L_SHIPINSTRUCT VARCHAR(25),
 32 |         L_SHIPMODE VARCHAR(10),
 33 |         L_COMMENT VARCHAR(44)
 34 |     );
 35 |     NATION (
 36 |         N_NATIONKEY NUMBER(38,0),
 37 |         N_NAME VARCHAR(25),
 38 |         N_REGIONKEY NUMBER(38,0),
 39 |         N_COMMENT VARCHAR(152)
 40 |     );
 41 |     PART (
 42 |         P_PARTKEY NUMBER(38,0),
 43 |         P_NAME VARCHAR(55),
 44 |         P_MFGR VARCHAR(25),
 45 |         P_BRAND VARCHAR(10),
 46 |         P_TYPE VARCHAR(25),
 47 |         P_SIZE NUMBER(38,0),
 48 |         P_CONTAINER VARCHAR(10),
 49 |         P_RETAILPRICE NUMBER(12,2),
 50 |         P_COMMENT VARCHAR(23)
 51 |     );
 52 |     PARTSUPP (
 53 |         PS_PARTKEY NUMBER(38,0),
 54 |         PS_SUPPKEY NUMBER(38,0),
 55 |         PS_AVAILQTY NUMBER(38,0),
 56 |         PS_SUPPLYCOST NUMBER(12,2),
 57 |         PS_COMMENT VARCHAR(199)
 58 |     );
 59 |     REGION (
 60 |         R_REGIONKEY NUMBER(38,0),
 61 |         R_NAME VARCHAR(25),
 62 |         R_COMMENT VARCHAR(152)
 63 |     );
 64 |     SUPPLIER (
 65 |         S_SUPPKEY NUMBER(38,0),
 66 |         S_NAME VARCHAR(25),
 67 |         S_ADDRESS VARCHAR(40),
 68 |         S_NATIONKEY NUMBER(38,0),
 69 |         S_PHONE VARCHAR(15),
 70 |         S_ACCTBAL NUMBER(12,2),
 71 |         S_COMMENT VARCHAR(101)
 72 |     );
 73 |     ORDERS (
 74 |         O_ORDERKEY NUMBER(38,0),
 75 |         O_CUSTKEY NUMBER(38,0),
 76 |         O_ORDERSTATUS VARCHAR(1),
 77 |         O_TOTALPRICE NUMBER(12,2),
 78 |         O_ORDERDATE DATE,
 79 |         O_ORDERPRIORITY VARCHAR(15),
 80 |         O_CLERK VARCHAR(15),
 81 |         O_SHIPPRIORITY NUMBER(38,0),
 82 |         O_COMMENT VARCHAR(79)
 83 |     );
 84 |     Take the user question and respond with a single Snowflake SQL query only, with no explanation.
 85 |     example :
 86 |     user question : give me the number of orders placed in last 10 days
 87 |     your generated sql query : select o_orderdate , count(*) from analytics.raw.orders where o_orderdate between current_date()-10 and current_date() group by o_orderdate ;
 88 |     example :
 89 |     user question : tell me top 3 nations having the maximum orders
 90 |     your generated sql query : select n.n_name , count(*) as order_count from analytics.raw.orders o
 91 |     inner join analytics.raw.customer c on o.o_custkey = c.c_custkey
 92 |     inner join analytics.raw.nation n on c.c_nationkey = n.n_nationkey
 93 |     group by n.n_name order by order_count desc limit 3
 94 |     ;
 95 |     example :
 96 |     user question : give me the name and address of suppliers for which the available quantity is minimum
 97 |     your generated sql query : select distinct s.s_name , s.s_address
 98 |     from analytics.raw.partsupp ps
 99 |     inner join analytics.raw.supplier s on ps.ps_suppkey = s.s_suppkey
100 |     where ps.ps_availqty = (select min(ps2.ps_availqty) from analytics.raw.partsupp ps2);
101 |     user question : {input}
102 |     your generated sql query :


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pawanrawat0926/ai_sql_assistant/6dfa85e93c9622a80b9027a614b6f251e6a7ae5a/requirements.txt


--------------------------------------------------------------------------------
/sql_assistant.py:
--------------------------------------------------------------------------------
 1 | from app_secrets import OPENAI_API_KEY
 2 | import os
 3 | import streamlit as st
 4 | from sql_execution import execute_sf_query
 5 | from langchain import OpenAI
 6 | from langchain.prompts import load_prompt
 7 | from pathlib import Path
 8 |
 9 |
10 | def main():
11 |     """Single-page assistant: turn the user's question into SQL and show its result."""
12 |     #setup env variable
13 |     os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
14 |     # project root directory: this script sits at the repository root
15 |     root_dir = Path(__file__).resolve().parent
16 |     #frontend
17 |     st.title("Your Project Assistant")
18 |     prompt = st.text_input("enter your query")
19 |
20 |     prompt_template = load_prompt(f"{root_dir}/prompts/tpch_prompt.yaml")
21 |     final_prompt = prompt_template.format(input=prompt)
22 |
23 |     # SQL generation should be repeatable, so do not sample
24 |     llm = OpenAI(temperature=0)
25 |
26 |     if prompt:
27 |         response = llm(prompt=final_prompt)
28 |         with st.expander(label="SQL Query", expanded=False):
29 |             st.write(response)
30 |         output = execute_sf_query(response)
31 |         st.write(output)
32 |
33 |
34 | if __name__ == "__main__":
35 |     main()


--------------------------------------------------------------------------------
/sql_execution.py:
--------------------------------------------------------------------------------
 1 | import snowflake.connector
 2 | import pandas as pd
 3 | from app_secrets import (
 4 |     SF_USER, SF_PASSWORD, SF_ACCOUNT, SF_WAREHOUSE,
 5 |     SF_DATABASE, SF_SCHEMA, SF_ROLE,
 6 | )
 7 |
 8 |
 9 | def execute_sf_query(sql):
10 |     """Run ``sql`` on Snowflake and return the result set as a pandas DataFrame.
11 |
12 |     The statement is executed as given, without any validation. On failure the
13 |     error is printed and a short message string is returned instead:
14 |     "Query compilation error" when Snowflake rejects the SQL,
15 |     "Snowflake database error" for other database errors, and
16 |     "An error occurred" for anything else.
17 |     """
18 |     # Snowflake connection parameters
19 |     connection_params = {
20 |         'user': SF_USER,
21 |         'password': SF_PASSWORD,
22 |         'account': SF_ACCOUNT,
23 |         'warehouse': SF_WAREHOUSE,
24 |         'database': SF_DATABASE,
25 |         'schema': SF_SCHEMA,
26 |         'role': SF_ROLE
27 |     }
28 |
29 |     # initialised up front so the cleanup below knows what was actually opened
30 |     conn = None
31 |     cur = None
32 |     try:
33 |         # Establish a connection to Snowflake
34 |         conn = snowflake.connector.connect(**connection_params)
35 |
36 |         # Create a cursor object
37 |         cur = conn.cursor()
38 |
39 |         # Execute the query
40 |         try:
41 |             cur.execute(sql)
42 |         except snowflake.connector.errors.ProgrammingError as pe:
43 |             print("Query Compilation Error:", pe)
44 |             return "Query compilation error"
45 |
46 |         # Fetch all results
47 |         query_results = cur.fetchall()
48 |
49 |         # Get column names from the cursor description
50 |         column_names = [col[0] for col in cur.description]
51 |
52 |         # Create a Pandas DataFrame
53 |         return pd.DataFrame(query_results, columns=column_names)
54 |
55 |     except snowflake.connector.errors.DatabaseError as de:
56 |         print("Snowflake Database Error:", de)
57 |         return "Snowflake database error"
58 |
59 |     except Exception as e:
60 |         print("An error occurred:", e)
61 |         return "An error occurred"
62 |
63 |     finally:
64 |         # Close the cursor and connection; best effort, so a failure while
65 |         # closing is reported but does not replace the result or error above
66 |         for resource in (cur, conn):
67 |             if resource is not None:
68 |                 try:
69 |                     resource.close()
70 |                 except Exception as close_err:
71 |                     print("Error while closing Snowflake resource:", close_err)
72 |
73 |
74 | if __name__ == "__main__":
75 |     # Snowflake query
76 |     query = '''
77 |     select n.n_name , count(*) as order_count from analytics.raw.orders o
78 |     inner join analytics.raw.customer c on o.o_custkey = c.c_custkey
79 |     inner join analytics.raw.nation n on c.c_nationkey = n.n_nationkey
80 |     group by n.n_name order by order_count desc limit 3
81 |     '''
82 |     print(execute_sf_query(query))


--------------------------------------------------------------------------------