├── .streamlit
└── config.toml
├── README.md
├── main.py
└── requirements.txt
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | primaryColor="#D2946E"
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Open the live app](https://ask-csv.streamlit.app/)
2 |
3 | # Ask Your Data (GPT-powered)
4 |
5 | A web app that lets you interact with data from an uploaded .csv file.
6 |
7 | - Ask questions about your data, such as "What were the total sales in the US in 2022?"
8 | - Visualize data from your csv file using natural language. For example, "plot total sales by country and product category"
9 |
10 |
11 | Upload a csv file with data:
12 |
13 |
14 |
15 | Ask a question about your data:
16 |
17 |
18 |
19 | Create a visualization:
20 |
21 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | import sqlite3
4 | from sqlite3 import Connection
5 | import openai
6 | import plotly.express as px
7 | import plotly.graph_objs as go
8 | import numpy as np
9 | import re
10 | from dateutil.parser import parse
11 | import traceback
12 |
13 |
14 | footer_html = """
15 |
38 | """
39 |
40 | page_bg_img = f"""
41 |
53 | """
54 |
55 |
def create_connection(db_name: str) -> Connection:
    """Open a SQLite database and hand back the live connection.

    Args:
        db_name: Database target passed straight to ``sqlite3.connect``
            (the app calls this with ``":memory:"``).

    Returns:
        The open ``sqlite3.Connection``; the caller owns it.
    """
    return sqlite3.connect(db_name)
59 |
def run_query(conn: Connection, query: str) -> pd.DataFrame:
    """Execute a SQL statement and return its result set as a DataFrame.

    Args:
        conn: Open SQLite connection to run the statement against.
        query: SQL text to execute.

    Returns:
        A DataFrame holding the rows produced by ``query``.
    """
    return pd.read_sql_query(sql=query, con=conn)
63 |
def create_table(conn: Connection, df: pd.DataFrame, table_name: str):
    """Write a DataFrame into a SQLite table, replacing any existing one.

    Args:
        conn: Open SQLite connection that receives the table.
        df: Data to store; the DataFrame index is not written.
        table_name: Name of the table to create (or overwrite).
    """
    df.to_sql(name=table_name, con=conn, if_exists="replace", index=False)
66 |
67 |
def generate_gpt_reponse(gpt_input, max_tokens):
    """Send a single-turn prompt to gpt-3.5-turbo and return the reply text.

    Works with both generations of the ``openai`` package: the 1.x client
    API and the legacy (<1.0) module-level ``ChatCompletion`` API. The
    legacy entry point was removed in 1.0, so calling it unconditionally
    fails on a fresh install of an unpinned ``openai`` requirement.

    Args:
        gpt_input: Prompt text sent as the only user message.
        max_tokens: Upper bound on the number of tokens in the completion.

    Returns:
        The content of the first completion choice with surrounding
        whitespace removed; an empty string if the reply has no content.
    """
    # load api key from secrets
    api_key = st.secrets["openai_api_key"]

    request = {
        "model": "gpt-3.5-turbo",
        "max_tokens": max_tokens,
        # temperature 0 keeps the generated SQL / code as repeatable as possible
        "temperature": 0,
        "messages": [
            {"role": "user", "content": gpt_input},
        ],
    }

    if hasattr(openai, "OpenAI"):
        # openai>=1.0: requests go through an explicit client object.
        client = openai.OpenAI(api_key=api_key)
        completion = client.chat.completions.create(**request)
        content = completion.choices[0].message.content
    else:
        # openai<1.0: module-level configuration and ChatCompletion API.
        openai.api_key = api_key
        completion = openai.ChatCompletion.create(**request)
        content = completion.choices[0].message['content']

    # The message content may be None; never call .strip() on None.
    return (content or "").strip()
84 |
85 |
def extract_code(gpt_response):
    """Extract the code or SQL query from a GPT response.

    If the response contains a Markdown code fence, the contents of the
    first fenced block are returned with surrounding whitespace removed.
    A language tag on the opening fence line (``python``, ``sql``, ...) is
    dropped, but only there: occurrences of the same word inside the code
    itself are left untouched. A fence that is never closed (for example a
    reply truncated by the token limit) yields everything after the
    opening fence instead of raising.

    Args:
        gpt_response: Raw text returned by the model.

    Returns:
        The extracted code, or ``gpt_response`` unchanged when it contains
        no code fence.
    """
    if "```" not in gpt_response:
        return gpt_response

    # Opening fence, then an optional single-token language tag that is
    # terminated by a newline, then the body up to the closing fence (or the
    # end of the text when the fence is left open). A one-line fence such as
    # ```SELECT 1``` has no newline-terminated tag, so nothing is dropped.
    pattern = r'```(?:[ \t]*[\w+-]*[ \t]*\r?\n)?(.*?)(?:```|\Z)'
    match = re.search(pattern, gpt_response, re.DOTALL)
    if match is None:
        # Defensive: the pattern matches whenever a fence is present.
        return gpt_response

    return match.group(1).strip()
101 |
102 |
# Page setup: icon and title only (Streamlit's default layout is kept),
# followed by the custom background styling.
st.set_page_config(page_icon="🤖", page_title="Ask CSV")
st.markdown(page_bg_img, unsafe_allow_html=True)

st.title("ASK CSV 🤖 (GPT-powered)")
st.header('Use Natural Language to Query Your Data')

uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])

if uploaded_file is None:
    st.info("""
    👆 Upload a .csv file first. Sample to try: [sample_data.csv](https://docs.google.com/spreadsheets/d/e/2PACX-1vTeB7_jzJtacH3XrFh553m9ahL0e7IIrTxMhbPtQ8Jmp9gCJKkU624Uk1uMbCEN_-9Sf7ikd1a85wIK/pub?gid=0&single=true&output=csv)
    """)

else:
    df = pd.read_csv(uploaded_file)

    # Convert columns whose name contains "date" to datetimes. The name check
    # is only a heuristic ("update_count" matches too), so a column that
    # cannot be parsed is left as it was instead of crashing the app.
    for col in df.columns:
        if 'date' in str(col).lower():
            try:
                df[col] = pd.to_datetime(df[col])
            except (ValueError, TypeError, OverflowError):
                pass

    # reset index
    df = df.reset_index(drop=True)

    # replace space with _ in column names so they are easier to use in SQL
    df.columns = df.columns.str.replace(' ', '_')

    # comma-separated column list that is embedded in the GPT prompts
    cols = ", ".join(df.columns)

    with st.expander("Preview of the uploaded file"):
        st.table(df.head())

    conn = create_connection(":memory:")
    try:
        table_name = "my_table"
        create_table(conn, df, table_name)

        selected_mode = st.selectbox("What do you want to do?", ["Ask your data", "Create a chart [beta]"])

        if selected_mode == 'Ask your data':

            user_input = st.text_area("Write a concise and clear question about your data. For example: What is the total sales in the USA in 2022?", value='What is the total sales in the USA in 2022?')

            if st.button("Get Response"):

                try:
                    # create gpt prompt
                    gpt_input = 'Write a sql lite query based on this question: {} The table name is my_table and the table has the following columns: {}. ' \
                                'Return only a sql query and nothing else'.format(user_input, cols)

                    query = generate_gpt_reponse(gpt_input, max_tokens=200)
                    query_clean = extract_code(query)
                    result = run_query(conn, query_clean)

                    with st.expander("SQL query used"):
                        st.code(query_clean)

                    # a 1x1 result is a single value: show it as a one-liner
                    if result.shape == (1, 1):
                        val = result.iloc[0, 0]
                        st.subheader('Your response: {}'.format(val))

                    else:
                        st.subheader("Your result:")
                        st.table(result)

                except Exception:
                    # Keep the user-facing message generic, but leave a trace in
                    # the server log so failures can be diagnosed.
                    traceback.print_exc()
                    st.error('Oops, there was an error :( Please try again with a different question.')

        elif selected_mode == 'Create a chart [beta]':

            user_input = st.text_area(
                "Briefly explain what you want to plot from your data. For example: Plot total sales by country and product category", value='Plot total sales by country and product category')

            if st.button("Create a visualization [beta]"):
                try:
                    # create gpt prompt
                    gpt_input = 'Write code in Python using Plotly to address the following request: {} ' \
                                'Use df that has the following columns: {}. Do not use animation_group argument and return only code with no import statements, use transparent background, the data has been already loaded in a df variable'.format(user_input, cols)

                    gpt_response = generate_gpt_reponse(gpt_input, max_tokens=1500)

                    extracted_code = extract_code(gpt_response)

                    # render the figure inside the Streamlit page instead of a new window
                    extracted_code = extracted_code.replace('fig.show()', 'st.plotly_chart(fig)')

                    with st.expander("Code used"):
                        st.code(extracted_code)

                    # SECURITY: this runs model-generated code with the app's full
                    # privileges (module globals, st.secrets, the file system).
                    # The prompt embeds user text and CSV column names, so a
                    # crafted input can lead to arbitrary code execution. Run
                    # this in a sandbox before exposing the app to untrusted users.
                    exec(extracted_code)

                except Exception:
                    # Keep the user-facing message generic, but leave a trace in
                    # the server log so failures can be diagnosed.
                    traceback.print_exc()
                    st.error('Oops, there was an error :( Please try again with a different question.')
    finally:
        # The in-memory database is rebuilt on every script run; release it.
        conn.close()

# footer
st.markdown(footer_html, unsafe_allow_html=True)
212 |
213 |
214 |
215 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | pandas
3 | openai
4 | plotly
5 | numpy
--------------------------------------------------------------------------------