├── .gitignore ├── Chinook.db ├── requirements.txt ├── evaluate.py ├── app.py ├── original_app.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /Chinook.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwchase17/sql-qa/HEAD/Chinook.db -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | langchain 2 | langchain_experimental 3 | openai 4 | streamlit 5 | tiktoken 6 | faiss-cpu 7 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from langchain.smith import RunEvalConfig, run_on_dataset 2 | from langsmith import Client 3 | 4 | from utils import get_agent 5 | 6 | if __name__ == "__main__": 7 | 8 | client = Client() 9 | eval_config = RunEvalConfig( 10 | evaluators=[ 11 | "qa" 12 | ], 13 | ) 14 | chain_results = run_on_dataset( 15 | client, 16 | dataset_name="misspelled-examples", 17 | concurrency_level=1, 18 | llm_or_chain_factory=get_agent, 19 | evaluation=eval_config, 20 | ) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from utils import get_agent 4 | 5 | st.set_page_config(page_title='🦜🔗 Ask the SQL DB App') 6 | st.title('🦜🔗 Ask the SQL DB App') 7 | st.info(""" 8 | Most 'question answering' applications run over unstructured text data. 9 | But a lot of the data in the world is tabular data! 10 | This is an attempt to create an application using [LangChain](https://github.com/langchain-ai/langchain) to let you ask questions of data in tabular format. 11 | The special property about this application is that it is **robust to spelling mistakes**: you can spell an artist or song wrong but you should still get the results you are looking for. 12 | For this demo application, we will use the Chinook dataset in a SQL database. 13 | Please explore it [here](https://github.com/lerocha/chinook-database) to get a sense for what questions you can ask. 14 | Please leave feedback on how well the question is answered, and we will use that improve the application! 15 | """) 16 | 17 | agent = get_agent() 18 | 19 | agent.verbose = True 20 | agent.return_intermediate_steps = False 21 | 22 | from langsmith import Client 23 | 24 | client = Client() 25 | def send_feedback(run_id, score): 26 | client.create_feedback(run_id, "user_score", score=score) 27 | 28 | query_text = st.text_input('Enter your question:', placeholder = 'How many artists are there?') 29 | 30 | print(query_text) 31 | 32 | result = None 33 | with st.form('myform', clear_on_submit=True): 34 | submitted = st.form_submit_button('Submit') 35 | if submitted: 36 | with st.spinner('Calculating...'): 37 | response = agent(query_text, include_run_info=True) 38 | result = response["output"] 39 | run_id = response["__run"].run_id 40 | if result is not None: 41 | st.info(result) 42 | col_blank, col_text, col1, col2 = st.columns([10, 2,1,1]) 43 | with col_text: 44 | st.text("Feedback:") 45 | with col1: 46 | st.button("👍", on_click=send_feedback, args=(run_id, 1)) 47 | with col2: 48 | st.button("👎", on_click=send_feedback, args=(run_id, 0)) -------------------------------------------------------------------------------- /original_app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import langchain 3 | from langchain.utilities import SQLDatabase 4 | from langchain_experimental.sql import SQLDatabaseChain 5 | from langchain.chat_models import ChatOpenAI 6 | from langsmith import Client 7 | from langchain.smith import RunEvalConfig, run_on_dataset 8 | from pydantic import BaseModel, Field 9 | 10 | db = SQLDatabase.from_uri("sqlite:///Chinook.db") 11 | llm = ChatOpenAI(temperature=0) 12 | db_chain = SQLDatabaseChain.from_llm(llm, db, return_intermediate_steps=True) 13 | 14 | from langsmith import Client 15 | client = Client() 16 | def send_feedback(run_id, score): 17 | client.create_feedback(run_id, "user_score", score=score) 18 | 19 | st.set_page_config(page_title='🦜🔗 Ask the SQL DB App') 20 | st.title('🦜🔗 Ask the SQL DB App') 21 | st.info("Most 'question answering' applications run over unstructured text data. But a lot of the data in the world is tabular data! This is an attempt to create an application using [LangChain](https://github.com/langchain-ai/langchain) to let you ask questions of data in tabular format. For this demo application, we will use the Chinook dataset in a SQL database. Please explore the schema [here](https://www.sqlitetutorial.net/wp-content/uploads/2015/11/sqlite-sample-database-color.jpg) to get a sense for what questions you can ask. Please leave feedback on well the question is answered, and we will use that improve the application!") 22 | 23 | query_text = st.text_input('Enter your question:', placeholder = 'Ask something like "How many artists are there?" or "Which artist has the most albums"') 24 | # Form input and query 25 | result = None 26 | with st.form('myform', clear_on_submit=True): 27 | submitted = st.form_submit_button('Submit') 28 | if submitted: 29 | with st.spinner('Calculating...'): 30 | inputs = {"query": query_text} 31 | response = db_chain(inputs, include_run_info=True) 32 | result = response["result"] 33 | sql_command = response["intermediate_steps"][1] 34 | sql_result = response["intermediate_steps"][3] 35 | run_id = response["__run"].run_id 36 | if result is not None: 37 | st.info(result) 38 | st.code(sql_command) 39 | st.code(sql_result) 40 | col_blank, col_text, col1, col2 = st.columns([10, 2,1,1]) 41 | with col_text: 42 | st.text("Feedback:") 43 | with col1: 44 | st.button("👍", on_click=send_feedback, args=(run_id, 1)) 45 | with col2: 46 | st.button("👎", on_click=send_feedback, args=(run_id, 0)) 47 | 48 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import re 3 | 4 | from langchain import OpenAI 5 | from langchain.agents import AgentExecutor, OpenAIFunctionsAgent 6 | from langchain.agents.agent_toolkits import ( 7 | SQLDatabaseToolkit, 8 | create_retriever_tool, 9 | create_sql_agent, 10 | ) 11 | from langchain.agents.agent_types import AgentType 12 | from langchain.chat_models import ChatOpenAI 13 | from langchain.embeddings import OpenAIEmbeddings 14 | from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder 15 | from langchain.sql_database import SQLDatabase 16 | from langchain.tools import Tool 17 | from langchain.utilities import SQLDatabase 18 | from langchain.vectorstores import FAISS 19 | from langchain_experimental.sql import SQLDatabaseChain 20 | from pydantic import BaseModel, Field 21 | 22 | 23 | def run_query_save_results(db, query): 24 | res = db.run(query) 25 | res = [el for sub in ast.literal_eval(res) for el in sub if el] 26 | res = [re.sub(r'\b\d+\b', '', string).strip() for string in res] 27 | 28 | return res 29 | 30 | def run_query_save_results_names(db, query): 31 | res = db.run(query) 32 | res = ast.literal_eval(res) 33 | res = [' '.join(i) for i in res] 34 | 35 | return res 36 | 37 | def get_retriever(texts): 38 | embeddings = OpenAIEmbeddings() 39 | vector_db = FAISS.from_texts(texts, embeddings) 40 | 41 | return vector_db.as_retriever() 42 | 43 | def get_agent(): 44 | 45 | db = SQLDatabase.from_uri("sqlite:///Chinook.db") 46 | llm = ChatOpenAI(temperature=0, model_name='gpt-4') 47 | 48 | artists = run_query_save_results(db, "SELECT Name FROM Artist") 49 | customers = run_query_save_results(db, "SELECT Company, Address, City, State, Country FROM Customer") 50 | employees = run_query_save_results(db, "SELECT Address, City, State, Country FROM Employee") 51 | albums = run_query_save_results(db, "SELECT Title FROM Album") 52 | 53 | customer_names = run_query_save_results_names(db, "SELECT FirstName, LastName FROM Customer") 54 | employee_names = run_query_save_results_names(db, "SELECT FirstName, LastName FROM Employee") 55 | 56 | texts = ( 57 | artists + 58 | customers + 59 | customer_names + 60 | employee_names + 61 | employees + 62 | albums 63 | ) 64 | 65 | retriever = get_retriever(texts) 66 | 67 | retriever_tool = create_retriever_tool( 68 | retriever, 69 | name='name_search', 70 | description='use to learn how a piece of data is actually written, can be from names, surnames addresses etc' 71 | ) 72 | 73 | sql_agent = create_sql_agent( 74 | llm=llm, 75 | toolkit=SQLDatabaseToolkit(db=db, llm=llm), 76 | verbose=True, 77 | agent_type=AgentType.OPENAI_FUNCTIONS 78 | ) 79 | 80 | sql_tool = Tool( 81 | func=sql_agent.run, 82 | name="db_agent", 83 | description="use to get information from the databases, ask exactly what you want in natural language" 84 | ) 85 | 86 | # db_chain = SQLDatabaseChain.from_llm( 87 | # OpenAI(temperature=0, verbose=True), 88 | # db 89 | # ) 90 | 91 | # sql_tool = Tool( 92 | # func=db_chain.run, 93 | # name="db_agent", 94 | # description="use to get information from the databases, ask exactly what you want in natural language" 95 | # ) 96 | 97 | TEMPLATE = """You are working with an SQL database. 98 | 99 | You have a tool called `name_search` through which you can lookup the name of any entity that is present in the database. This could be a person name, an address, a music track name or others. 100 | You should always use this `name_search` tool to search for the correct way that something is written before you use the `db_agent` tool. 101 | You should use the `name_search` tool ONLY ONCE and you should also use the `db_agent` tool ONLY ONCE. 102 | 103 | If the user questions contains a term that is not spelled correctly, you should assume that the user meant the correct spelling and answer the question for the correctly spelled term. 104 | 105 | As soon as you have an answer to the question, you should return and not invoke more functions. 106 | """ 107 | 108 | class PythonInputs(BaseModel): 109 | query: str = Field(description="code snippet to run") 110 | 111 | template = TEMPLATE.format() 112 | 113 | prompt = ChatPromptTemplate.from_messages([ 114 | ("system", template), 115 | MessagesPlaceholder(variable_name="agent_scratchpad"), 116 | ("human", "{input}") 117 | ]) 118 | 119 | tools = [ 120 | sql_tool, 121 | retriever_tool 122 | ] 123 | 124 | agent = OpenAIFunctionsAgent( 125 | llm=llm, 126 | prompt=prompt, 127 | tools=tools 128 | ) 129 | 130 | agent_executor = AgentExecutor( 131 | agent=agent, 132 | tools=tools, 133 | max_iterations=2, 134 | early_stopping_method="generate" 135 | ) 136 | return agent_executor --------------------------------------------------------------------------------