├── .streamlit └── config.toml ├── README.md ├── main.py └── requirements.txt /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | primaryColor="#D2946E" 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://ask-csv.streamlit.app/) 2 | 3 | # Ask Your Data (GPT-powered) 4 | 5 | A web app that allows to interact with user-uploaded data in .csv. 6 | 7 | - Ask questions about your data, such as "what was the total sales in the US in 2022?" 8 | - Visualize data from your csv file using natural language. For example, "plot total sales by country and product category" 9 | 10 | 11 | Upload a csv file with data: 12 | 13 | 14 | 15 | Write a question to your data: 16 | 17 | 18 | 19 | Create a visualization: 20 | 21 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import sqlite3 4 | from sqlite3 import Connection 5 | import openai 6 | import plotly.express as px 7 | import plotly.graph_objs as go 8 | import numpy as np 9 | import re 10 | from dateutil.parser import parse 11 | import traceback 12 | 13 | 14 | footer_html = """ 15 | 38 | """ 39 | 40 | page_bg_img = f""" 41 | 53 | """ 54 | 55 | 56 | def create_connection(db_name: str) -> Connection: 57 | conn = sqlite3.connect(db_name) 58 | return conn 59 | 60 | def run_query(conn: Connection, query: str) -> pd.DataFrame: 61 | df = pd.read_sql_query(query, conn) 62 | return df 63 | 64 | def create_table(conn: Connection, df: pd.DataFrame, table_name: str): 65 | df.to_sql(table_name, conn, if_exists="replace", index=False) 66 | 67 | 68 | def generate_gpt_reponse(gpt_input, max_tokens): 69 | 70 | # load api key from secrets 71 | openai.api_key = st.secrets["openai_api_key"] 72 | 73 | completion = openai.ChatCompletion.create( 74 | model="gpt-3.5-turbo", 75 | max_tokens=max_tokens, 76 | temperature=0, 77 | messages=[ 78 | {"role": "user", "content": gpt_input}, 79 | ] 80 | ) 81 | 82 | gpt_response = completion.choices[0].message['content'].strip() 83 | return gpt_response 84 | 85 | 86 | def extract_code(gpt_response): 87 | """function to extract code and sql query from gpt response""" 88 | 89 | if "```" in gpt_response: 90 | # extract text between ``` and ``` 91 | pattern = r'```(.*?)```' 92 | code = re.search(pattern, gpt_response, re.DOTALL) 93 | extracted_code = code.group(1) 94 | 95 | # remove python from the code (weird bug) 96 | extracted_code = extracted_code.replace('python', '') 97 | 98 | return extracted_code 99 | else: 100 | return gpt_response 101 | 102 | 103 | # wide layout 104 | st.set_page_config(page_icon="🤖", page_title="Ask CSV") 105 | st.markdown(page_bg_img, unsafe_allow_html=True) 106 | 107 | st.title("ASK CSV 🤖 (GPT-powered)") 108 | st.header('Use Natural Language to Query Your Data') 109 | 110 | uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"]) 111 | 112 | if uploaded_file is None: 113 | st.info(f""" 114 | 👆 Upload a .csv file first. Sample to try: [sample_data.csv](https://docs.google.com/spreadsheets/d/e/2PACX-1vTeB7_jzJtacH3XrFh553m9ahL0e7IIrTxMhbPtQ8Jmp9gCJKkU624Uk1uMbCEN_-9Sf7ikd1a85wIK/pub?gid=0&single=true&output=csv) 115 | """) 116 | 117 | elif uploaded_file: 118 | df = pd.read_csv(uploaded_file) 119 | 120 | # Apply the custom function and convert date columns 121 | for col in df.columns: 122 | # check if a column name contains date substring 123 | if 'date' in col.lower(): 124 | df[col] = pd.to_datetime(df[col]) 125 | # remove timestamp 126 | #df[col] = df[col].dt.date 127 | 128 | # reset index 129 | df = df.reset_index(drop=True) 130 | 131 | # replace space with _ in column names 132 | df.columns = df.columns.str.replace(' ', '_') 133 | 134 | cols = df.columns 135 | cols = ", ".join(cols) 136 | 137 | with st.expander("Preview of the uploaded file"): 138 | st.table(df.head()) 139 | 140 | conn = create_connection(":memory:") 141 | table_name = "my_table" 142 | create_table(conn, df, table_name) 143 | 144 | 145 | selected_mode = st.selectbox("What do you want to do?", ["Ask your data", "Create a chart [beta]"]) 146 | 147 | if selected_mode == 'Ask your data': 148 | 149 | user_input = st.text_area("Write a concise and clear question about your data. For example: What is the total sales in the USA in 2022?", value='What is the total sales in the USA in 2022?') 150 | 151 | if st.button("Get Response"): 152 | 153 | try: 154 | # create gpt prompt 155 | gpt_input = 'Write a sql lite query based on this question: {} The table name is my_table and the table has the following columns: {}. ' \ 156 | 'Return only a sql query and nothing else'.format(user_input, cols) 157 | 158 | query = generate_gpt_reponse(gpt_input, max_tokens=200) 159 | query_clean = extract_code(query) 160 | result = run_query(conn, query_clean) 161 | 162 | with st.expander("SQL query used"): 163 | st.code(query_clean) 164 | 165 | # if result df has one row and one column 166 | if result.shape == (1, 1): 167 | 168 | # get the value of the first row of the first column 169 | val = result.iloc[0, 0] 170 | 171 | # write one liner response 172 | st.subheader('Your response: {}'.format(val)) 173 | 174 | else: 175 | st.subheader("Your result:") 176 | st.table(result) 177 | 178 | except Exception as e: 179 | #st.error(f"An error occurred: {e}") 180 | st.error('Oops, there was an error :( Please try again with a different question.') 181 | 182 | elif selected_mode == 'Create a chart [beta]': 183 | 184 | user_input = st.text_area( 185 | "Briefly explain what you want to plot from your data. For example: Plot total sales by country and product category", value='Plot total sales by country and product category') 186 | 187 | if st.button("Create a visualization [beta]"): 188 | try: 189 | # create gpt prompt 190 | gpt_input = 'Write code in Python using Plotly to address the following request: {} ' \ 191 | 'Use df that has the following columns: {}. Do not use animation_group argument and return only code with no import statements, use transparent background, the data has been already loaded in a df variable'.format(user_input, cols) 192 | 193 | gpt_response = generate_gpt_reponse(gpt_input, max_tokens=1500) 194 | 195 | extracted_code = extract_code(gpt_response) 196 | 197 | extracted_code = extracted_code.replace('fig.show()', 'st.plotly_chart(fig)') 198 | 199 | with st.expander("Code used"): 200 | st.code(extracted_code) 201 | 202 | # execute code 203 | exec(extracted_code) 204 | 205 | except Exception as e: 206 | #st.error(f"An error occurred: {e}") 207 | #st.write(traceback.print_exc()) 208 | st.error('Oops, there was an error :( Please try again with a different question.') 209 | 210 | # footer 211 | st.markdown(footer_html, unsafe_allow_html=True) 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | pandas 3 | openai 4 | plotly 5 | numpy --------------------------------------------------------------------------------