├── README.md
├── app.py
├── data
│   ├── employees.csv
│   └── purchases.csv
├── groqcloud_darkmode.png
├── requirements.txt
└── verified-queries
    ├── employees-without-purchases.yaml
    ├── most-expensive-purchase.yaml
    ├── most-recent-purchases.yaml
    └── number-of-teslas.yaml

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DuckDB Query Retriever

This repository contains a Streamlit application that lets users ask questions about their DuckDB data using the Groq API. The application matches the user's question against pre-verified SQL queries and their descriptions stored in YAML files, executes the most similar query against the data, and returns the results. If a prompt is not sufficiently similar to any vetted query, no data is returned.

## Features

- **Semantic Search**: The application uses semantic search to find the pre-verified SQL query most similar to the user's question (sketched below).

- **Data Querying**: The application executes the selected SQL query on a DuckDB database and displays the result.

- **Data Summarization**: After executing a SQL query, the application uses the Groq API to summarize the resulting data in relation to the user's original question.

- **Customization**: Users can customize the SentenceTransformer model used for semantic search, the Groq model used for summarization, and the minimum similarity threshold for selecting a verified SQL query.
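The semantic search above boils down to comparing sentence embeddings. Here is a minimal sketch of the matching idea, using the same `all-MiniLM-L6-v2` model the app loads; the question and candidate descriptions are illustrative:

```python
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

descriptions = ["Number of Teslas purchased", "Five most recent purchases"]
question = "How many Teslas were bought?"

# One embedding per description; the description with the highest cosine similarity wins
scores = cosine_similarity([model.encode(question)], model.encode(descriptions))[0]
best_score, best_description = max(zip(scores, descriptions))
print(best_description, round(float(best_score) * 100, 1))
```

The app applies the same comparison across every YAML description and only accepts the winner if its score clears the minimum similarity threshold.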
## Data

The application queries data from CSV files located in the `data` folder:

- `employees.csv`: Contains employee data including their ID, full name, and email address.

- `purchases.csv`: Records purchase details including purchase ID, date, associated employee ID, amount, and product name.
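There is no separate database to load: `execute_duckdb_query` changes into the `data` folder and runs each query against an in-memory DuckDB connection, which reads the CSV files in place. Condensed to its essentials (a sketch, assuming it is run from the repository root):

```python
import os
import duckdb

os.chdir('data')  # the verified queries reference the CSVs by file name
conn = duckdb.connect(database=':memory:')
print(conn.execute("SELECT * FROM 'employees.csv' LIMIT 3").fetchdf())
```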
## Verified Queries

The verified SQL queries and their descriptions are stored in YAML files located in the `verified-queries` folder. Descriptions are used to semantically map prompts to queries (an example file follows the list):

- `most-recent-purchases.yaml`: Returns the 5 most recent purchases

- `most-expensive-purchase.yaml`: Finds the employee with the most expensive purchase

- `number-of-teslas.yaml`: Counts the number of Teslas purchased

- `employees-without-purchases.yaml`: Gets employees without a purchase since Feb 1, 2024
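Each file pairs a natural-language `description` (the text that is embedded for semantic search) with the SQL it vouches for. For example, `number-of-teslas.yaml` contains:

```yaml
description: Number of Teslas purchased
sql: |
  SELECT COUNT(*) as number_of_teslas
  FROM purchases.csv AS p
  JOIN employees.csv AS e ON e.employee_id = p.employee_id
  WHERE p.product_name = 'Tesla'
```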
## Functions

- `get_verified_queries_and_embeddings(directory_path, embedding_model)`: Reads YAML files from the specified directory, loads the verified SQL queries and their descriptions, and generates embeddings for the descriptions using the provided SentenceTransformer model.

- `get_verified_sql(embedding_model, user_question, verified_queries_dict, minimum_similarity)`: Generates an embedding for the user's question, calculates the cosine similarity between the question's embedding and the embeddings of the verified queries, and returns the SQL of the most similar query if its similarity is above the specified minimum similarity threshold.

- `chat_with_groq(client, prompt, model)`: Sends a chat message to the Groq API and returns the content of the response.

- `execute_duckdb_query(query)`: Executes the provided SQL query using DuckDB and returns the result as a DataFrame.

- `get_summarization(client, user_question, df, model, additional_context)`: Generates a prompt that includes the user's question and the DataFrame result, sends the prompt to the Groq API for summarization, and returns the summarized response.

- `main()`: The main function of the application, which initializes the Groq client and the SentenceTransformer model, gets user input from the Streamlit interface, retrieves and executes the most similar verified SQL query, and displays the result and its summarization.

## Usage

To use this application, you need Streamlit and the other required Python libraries installed, as well as a Groq API key, which you can obtain by signing up on the Groq website.
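Install the dependencies, then make the key available to the app. Since `main()` reads the key via `st.secrets["GROQ_API_KEY"]`, one straightforward option is Streamlit's secrets file (the key value below is a placeholder):

```shell
pip install -r requirements.txt
```

```toml
# .streamlit/secrets.toml
GROQ_API_KEY = "your-groq-api-key"
```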
Once you have the necessary requirements, you can run the application by executing the script with Streamlit:

```shell
streamlit run app.py
```

This will start the Streamlit server and open the application in your web browser. You can then ask questions about your DuckDB data, and the application will find the most similar pre-verified SQL query, execute it, and return the results.

## Customizing with Your Own Data

This application is designed to be flexible and can easily be customized to work with your own data. To use your own data, follow these steps:

1. **Replace the CSV files**: The application queries data from CSV files located in the `data` folder. Replace these files with your own CSV files.

2. **Modify the verified queries**: The verified SQL queries and their descriptions are stored in YAML files located in the `verified-queries` folder. Replace these files with your own verified SQL queries and descriptions, following the format sketched below.
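As a sketch, a verified query for a hypothetical `orders.csv` might look like this (the file name, columns, and description are illustrative, not part of this repository). Note that the `description` matters as much as the SQL, since it is what the semantic search compares against incoming questions:

```yaml
description: Total order amount per customer
sql: |
  SELECT customer_id,
         SUM(amount) AS total_order_amount
  FROM orders.csv AS orders
  GROUP BY customer_id
```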
By following these steps, you can tailor the DuckDB Query Retriever to your own data and use cases. Feel free to experiment and build off this repository to create your own powerful data querying applications.

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import streamlit as st
import os
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import duckdb
import yaml
import glob
import sqlparse


def get_verified_queries_and_embeddings(directory_path, embedding_model):
    """
    This function loads pre-verified SQL queries and their descriptions from YAML files in a specified directory.
    It generates embeddings for the descriptions using a provided SentenceTransformer model and returns a dictionary
    mapping file names to their corresponding queries, descriptions, and embeddings.
    """
    verified_queries_yaml_files = glob.glob(os.path.join(directory_path, '*.yaml'))
    verified_queries_dict = {}
    for file in verified_queries_yaml_files:
        with open(file, 'r') as stream:
            try:
                # Strip the directory prefix and the '.yaml' extension to get the query name
                file_name = file[len(directory_path):-5]
                verified_queries_dict[file_name] = yaml.safe_load(stream)
                verified_queries_dict[file_name]['embeddings'] = embedding_model.encode(verified_queries_dict[file_name]['description'])
            except yaml.YAMLError:
                # Skip any file that fails to parse as YAML
                continue

    return verified_queries_dict


def get_verified_sql(embedding_model, user_question, verified_queries_dict, minimum_similarity):
    """
    This function takes a user's question and finds the most similar pre-verified SQL query based on cosine similarity
    between the question's embedding and the embeddings of the verified queries. If the highest similarity is above a
    specified minimum similarity threshold, it formats and returns the SQL of the most similar query. Otherwise, it
    returns None and displays a message indicating that it couldn't find a suitable verified query.
    """

    # Get the embedding for the user's question
    prompt_embeddings = embedding_model.encode(user_question)

    # Calculate embedding similarity for verified queries using cosine similarity
    embeddings_list = [data["embeddings"] for data in verified_queries_dict.values()]
    verified_queries = list(verified_queries_dict.keys())
    similarities = cosine_similarity([prompt_embeddings], embeddings_list)[0]

    # Find the index of the highest similarity
    max_similarity_index = np.argmax(similarities)

    # Retrieve the most similar prompt using the index
    most_similar_prompt = verified_queries[max_similarity_index]
    highest_similarity = similarities[max_similarity_index]
    sql_query = sqlparse.format(verified_queries_dict[most_similar_prompt]['sql'], reindent=True, keyword_case='upper')

    if highest_similarity >= minimum_similarity / 100.0:
        st.markdown(f"Found a verified query: **{most_similar_prompt}** ({round(highest_similarity*100,1)}% similarity)")
        st.markdown("```sql\n" + sql_query + "\n```")
        return sql_query
    else:
        st.markdown(f"Unable to find a verified query to answer your question. Most similar prompt: **{most_similar_prompt}** ({round(highest_similarity*100,1)}% similarity)")
        return None


def chat_with_groq(client, prompt, model):
    """
    This function sends a chat message to the Groq API and returns the content of the response.
    It takes three parameters: the Groq client, the chat prompt, and the model to use for the chat.
    """

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": prompt
            }
        ]
    )

    return completion.choices[0].message.content


def execute_duckdb_query(query):
    """
    This function executes a provided SQL query on a DuckDB database and returns the result as a DataFrame.
    It changes the current working directory to the 'data' folder where the CSV files are located,
    creates a connection to a DuckDB database in memory, executes the query, fetches the result as a DataFrame,
    and then resets the current working directory to its original location.
    """

    original_cwd = os.getcwd()
    os.chdir('data')

    try:
        conn = duckdb.connect(database=':memory:', read_only=False)
        query_result = conn.execute(query).fetchdf().reset_index(drop=True)
    finally:
        # Restore the working directory even if the query fails
        os.chdir(original_cwd)

    return query_result


def get_summarization(client, user_question, df, model, additional_context):
    """
    This function generates a prompt that includes the user's question and the DataFrame result, sends the prompt
    to the Groq API for summarization, and returns the summarized response.
    If additional context is provided, it is included in the prompt.
    """

    prompt = '''
A user asked the following question pertaining to local database tables:

{user_question}

To answer the question, a dataframe was returned:

Dataframe:
{df}

In a few sentences, summarize the data in the table as it pertains to the original user question. Avoid qualifiers like "based on the data" and do not comment on the structure or metadata of the table itself.
'''.format(user_question=user_question, df=df)

    if additional_context != '':
        prompt += '''\n
The user has provided this additional context:
{additional_context}
'''.format(additional_context=additional_context)

    return chat_with_groq(client, prompt, model)


def main():
    """
    This is the main function that runs the application. It initializes the Groq client and the SentenceTransformer
    model, gets user input from the Streamlit interface, retrieves and executes the most similar verified SQL query,
    and displays the result and its summarization.
    """

    # Initialize the Groq client
    groq_api_key = st.secrets["GROQ_API_KEY"]
    client = Groq(
        api_key=groq_api_key
    )

    # Initialize the SentenceTransformer model
    embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Display the Groq logo
    spacer, col = st.columns([5, 1])
    with col:
        st.image('groqcloud_darkmode.png')

    # Display the title and introduction of the application
    st.title("DuckDB Query Retriever")
    multiline_text = """
    Welcome! Ask questions about employee data or purchase details, like "Show the 5 most recent purchases" or "What was the most expensive purchase?". The app matches your question to pre-verified SQL queries for accurate results.
    """

    st.markdown(multiline_text, unsafe_allow_html=True)

    # Add customization options to the sidebar
    st.sidebar.title('Customization')
    additional_context = st.sidebar.text_input('Enter additional summarization context for the LLM here (e.g. write it in Spanish):')
    model = st.sidebar.selectbox(
        'Choose a model',
        ['llama3-8b-8192', 'mixtral-8x7b-32768', 'gemma-7b-it']
    )
    minimum_similarity = st.sidebar.slider('Minimum Similarity:', 1, 100, value=50)

    # Get the user's question
    user_question = st.text_input("Ask a question:", value='How many Teslas were purchased?')

    if user_question:

        # Load the verified queries and their embeddings
        verified_queries_dict = get_verified_queries_and_embeddings('verified-queries/', embedding_model)

        # Find the most similar verified SQL query to the user's question
        verified_sql_query = get_verified_sql(embedding_model, user_question, verified_queries_dict, minimum_similarity)

        # If a verified query is returned, generate the output and summarization
        if verified_sql_query is not None:
            results_df = execute_duckdb_query(verified_sql_query)
            st.markdown(results_df.to_html(index=False), unsafe_allow_html=True)
            summarization = get_summarization(client, user_question, results_df, model, additional_context)
            st.write('')
            st.write(summarization.replace('$', '\\$'))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/data/employees.csv:
--------------------------------------------------------------------------------
employee_id,name,email
1,Richard Hendricks,richard@piedpiper.com
2,Erlich Bachman,erlich@aviato.com
3,Dinesh Chugtai,dinesh@piedpiper.com
4,Bertram Gilfoyle,gilfoyle@piedpiper.com
5,Jared Dunn,jared@piedpiper.com
6,Monica Hall,monica@raviga.com
7,Gavin Belson,gavin@hooli.com
--------------------------------------------------------------------------------
/data/purchases.csv:
--------------------------------------------------------------------------------
purchase_id,purchase_date,product_name,employee_id,amount
1,'2024-02-01',iPhone,1,750
2,'2024-02-02',Tesla,2,70000
3,'2024-02-03',Humane pin,3,500
4,'2024-02-04',iPhone,4,700
5,'2024-02-05',Tesla,5,75000
--------------------------------------------------------------------------------
/groqcloud_darkmode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/definitive-io/duckdb-rag/494a667f7d6ecdb8dc22942609b5335f8ef12bc3/groqcloud_darkmode.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
aiohttp==3.9.3
aiosignal==1.3.1
altair==5.2.0
annotated-types==0.6.0
anyio==3.6.2
appnope==0.1.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.0.8
async-timeout==4.0.3
attrs==22.1.0
backcall==0.2.0
beautifulsoup4==4.11.1
bleach==5.0.1
blinker==1.7.0
cachetools==5.3.2
certifi==2024.2.2
cffi==1.15.1
charset-normalizer==3.3.2
click==8.1.7
dataclasses-json==0.6.4
debugpy==1.6.3
decorator==5.1.1
defusedxml==0.7.1
distro==1.9.0
duckdb==0.9.2
entrypoints==0.4
executing==1.1.1
fastjsonschema==2.16.2
filelock==3.13.1
frozenlist==1.4.1
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.41
groq==0.4.0
h11==0.14.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.3
idna==3.4
importlib-metadata==7.0.1
install==1.3.5
ipykernel==6.16.1
ipython==8.5.0
ipython-genutils==0.2.0
ipywidgets==8.0.2
jedi==0.18.1
Jinja2==3.1.2
joblib==1.3.2
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.16.0
jupyter==1.0.0
jupyter-console==6.4.4
jupyter-server==1.21.0
jupyter_client==7.4.3
jupyter_core==4.11.2
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.3
langchain==0.1.5
langchain-community==0.0.19
langchain-core==0.1.21
langsmith==0.0.87
markdown-it-py==3.0.0
MarkupSafe==2.1.1
marshmallow==3.20.2
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==2.0.4
mpmath==1.3.0
multidict==6.0.5
mypy-extensions==1.0.0
nbclassic==0.4.5
nbclient==0.7.0
nbconvert==7.2.2
nbformat==5.7.0
nest-asyncio==1.5.6
networkx==3.2.1
nltk==3.8.1
notebook==6.5.1
notebook_shim==0.2.0
numpy==1.23.4
openai==1.12.0
packaging==23.2
pandas==1.5.1
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
pillow==10.2.0
pinecone-client==3.0.2
prometheus-client==0.15.0
prompt-toolkit==3.0.31
protobuf==4.25.2
psutil==5.9.3
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==15.0.0
pycparser==2.21
pydantic==2.6.1
pydantic_core==2.16.2
pydeck==0.8.1b0
Pygments==2.13.0
pyparsing==3.0.9
pyrsistent==0.18.1
python-dateutil==2.8.2
pytz==2022.5
PyYAML==6.0
pyzmq==24.0.1
qtconsole==5.3.2
QtPy==2.2.1
regex==2023.12.25
requests==2.31.0
rich==13.7.0
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
Send2Trash==1.8.0
sentence-transformers==2.3.1
sentencepiece==0.1.99
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.3.2.post1
SQLAlchemy==2.0.25
sqlparse==0.4.4
stack-data==0.5.1
streamlit==1.31.0
sympy==1.12
tenacity==8.2.3
terminado==0.16.0
threadpoolctl==3.2.0
tiktoken==0.6.0
tinycss2==1.2.1
tokenizers==0.15.1
toml==0.10.2
toolz==0.12.1
torch==2.2.0
tornado==6.2
tqdm==4.66.1
traitlets==5.5.0
transformers==4.37.2
typing-inspect==0.9.0
typing_extensions==4.9.0
tzlocal==5.2
urllib3==2.2.0
validators==0.22.0
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.1
widgetsnbextension==4.0.3
yarl==1.9.4
zipp==3.17.0
--------------------------------------------------------------------------------
/verified-queries/employees-without-purchases.yaml:
--------------------------------------------------------------------------------
description: Employees without a purchase since Feb 1, 2024
sql: |
  SELECT employees.name as employees_without_purchases
  FROM employees.csv AS employees
  LEFT JOIN purchases.csv AS purchases ON employees.employee_id = purchases.employee_id
    AND purchases.purchase_date > '2024-02-01'
  WHERE purchases.purchase_id IS NULL
--------------------------------------------------------------------------------
/verified-queries/most-expensive-purchase.yaml:
--------------------------------------------------------------------------------
description: Employee with the most expensive purchase
sql: |
  SELECT employees.name AS employee_name,
         MAX(amount) AS max_purchase_amount
  FROM purchases.csv AS purchases
  JOIN employees.csv AS employees ON purchases.employee_id = employees.employee_id
  GROUP BY employees.name
  ORDER BY max_purchase_amount DESC
  LIMIT 1
--------------------------------------------------------------------------------
/verified-queries/most-recent-purchases.yaml:
--------------------------------------------------------------------------------
description: Five most recent purchases
sql: |
  SELECT purchases.product_name,
         purchases.amount,
         employees.name
  FROM purchases.csv AS purchases
  JOIN employees.csv AS employees ON purchases.employee_id = employees.employee_id
  ORDER BY purchases.purchase_date DESC
  LIMIT 5;
--------------------------------------------------------------------------------
/verified-queries/number-of-teslas.yaml:
--------------------------------------------------------------------------------
description: Number of Teslas purchased
sql: |
  SELECT COUNT(*) as number_of_teslas
  FROM purchases.csv AS p
  JOIN employees.csv AS e ON e.employee_id = p.employee_id
  WHERE p.product_name = 'Tesla'
--------------------------------------------------------------------------------