').append(senderElement);
21 | senderLine.addClass('d-flex justify-content-between align-items-center');
22 |
23 | let copyButton;
24 | if (sender.toLowerCase() === 'autoanalyst' && (isQuery || data)) {
25 | // Create the copy to clipboard button and append it to the sender line
26 | copyButton = $('
');
27 | copyButton.addClass('btn btn-small btn-outline-secondary ml-2'); // Add Bootstrap classes
28 | copyButton.html('❐');
29 | copyButton.attr('data-clipboard-text', isQuery ? message : JSON.stringify(data));
30 | senderLine.append(copyButton);
31 | }
32 |
33 | let messageContent;
34 |
35 | if (isQuery) {
36 | messageContent = $('').append($('').addClass('sql').text(message));
37 | messageContent.addClass('query-content'); // Add the query-content class for query messages
38 | hljs.highlightBlock(messageContent[0]);
39 | } else {
40 | messageContent = typeof message === 'string' ? $('').text(message) : message;
41 | }
42 |
43 |
44 | messageItem.append(senderLine);
45 | messageItem.append(messageContent);
46 |
47 | $('#chatWindow').append(messageItem);
48 | scrollToBottom();
49 |
50 | if (copyButton) {
51 | new ClipboardJS(copyButton[0]);
52 | }
53 | }
54 |
55 |
// Send the user's question to the /analyze endpoint and render the response
// into the chat window as a data grid, a Plotly plot, or a SQL query message.
// Relies on globals: jQuery ($), csrf_token, addToChat, Plotly, jqGrid plugin.
function sendMessage() {
    const message = $("#messageInput").val();
    if (message.trim() === '') return;

    addToChat('You', message); // Add message to chat before making the call
    $("#messageInput").val(''); // Clear the message input

    // Show loading animation while the request is in flight.
    // NOTE(review): the original markup inside $('...') was lost in this copy;
    // reconstructed as a plain <div class="loading-message"> — confirm intended markup.
    const loadingMessage = $('<div>').addClass('loading-message');
    loadingMessage.text('AutoAnalyst is thinking...');
    $('#chatWindow').append(loadingMessage);

    $.ajax({
        url: '/analyze',
        type: 'POST',
        data: JSON.stringify({ question: message }),
        contentType: 'application/json',
        dataType: 'json',
        headers: { "X-CSRFToken": csrf_token },
        success: function (response) {
            const result = response;
            console.log(result);
            loadingMessage.remove(); // Remove loading animation

            if (result.analysis_type === "data") {
                if (!result.result_data || result.result_data.length === 0) {
                    // Guard: Object.keys(result.result_data[0]) would throw on
                    // an empty/missing result set.
                    addToChat('AutoAnalyst', 'The query returned no rows.');
                } else {
                    // Get column names dynamically from the first result row.
                    const colNames = Object.keys(result.result_data[0]);

                    // Create colModel dynamically.
                    const colModel = colNames.map(function (name) {
                        return { name: name, index: name, width: 100 };
                    });

                    // jqGrid requires a <table> for the grid itself and a
                    // separate <div> for the pager.
                    const gridContainer = $('<div>');
                    const gridElement = $('<table>');
                    const pagerElement = $('<div>');
                    gridContainer.append(gridElement);
                    gridContainer.append(pagerElement);

                    gridElement.jqGrid({
                        datatype: "local",
                        data: result.result_data,
                        colNames: colNames,
                        colModel: colModel,
                        rowNum: 10,
                        rowList: [10, 20, 30],
                        pager: pagerElement,
                        viewrecords: true,
                        caption: "Analysis Results"
                    });

                    // Add gridContainer as a new message to the chat; pass the
                    // raw rows so addToChat can wire up the copy button.
                    addToChat('AutoAnalyst', gridContainer, false, result.result_data);
                }
            } else if (result.analysis_type === "plot") {
                const plotContainer = $('<div>');
                const plotElement = $('<div>');
                plotContainer.append(plotElement);

                // Add plotContainer as a new message to the chat.
                addToChat('AutoAnalyst', plotContainer);
                console.log(result.result_plot);

                if (result.result_plot) { // Check if result.result_plot is defined
                    // Defer plotting until the container is attached to the DOM.
                    setTimeout(function () {
                        const plotData = JSON.parse(result.result_plot);
                        Plotly.newPlot(plotElement[0], plotData.data, plotData.layout);
                    }, 0);
                }
            } else if (result.analysis_type === "query") {
                addToChat('AutoAnalyst', result.query, true);
            }

            $('#error').hide(); // hide the error message on success
        },
        error: function (xhr, status, error) {
            loadingMessage.remove(); // Remove loading animation
            let errorMessage = "An unexpected error occurred.";
            if (xhr.status === 400 || xhr.status === 500) {
                // Guard against non-JSON error bodies (e.g. an HTML error page).
                try {
                    const resp = JSON.parse(xhr.responseText);
                    if (resp && resp.error) errorMessage = resp.error;
                } catch (e) {
                    // Unparseable body: keep the generic message.
                }
            }
            // Use .text() rather than .html() so server-supplied error text
            // cannot inject markup into the page.
            $('#error').text(errorMessage);
            $('#error').show();
        }
    });
}
151 |
// Intercept the chat form's submit event and route it through sendMessage()
// instead of letting the browser perform a full-page form post.
$('#messageForm').on('submit', function (event) {
    event.preventDefault();
    sendMessage();
});
156 |
// Keep the newest chat message in view by scrolling the chat window to its
// full scroll height.
function scrollToBottom() {
    const windowEl = $('#chatWindow');
    windowEl.scrollTop(windowEl.prop("scrollHeight"));
}
161 |
--------------------------------------------------------------------------------
/auto_analyst/llms/openai.py:
--------------------------------------------------------------------------------
1 | from auto_analyst.llms.base import BaseLLM
2 | import openai
3 | import enum
4 | import re
5 | import logging
6 | from typing import (
7 | Optional,
8 | List,
9 | Dict,
10 | )
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
class Model(enum.Enum):
    """Enum for OpenAI LLM Models.

    Each member's value is the model-name string sent verbatim to the
    OpenAI chat-completion API (used as ``self.model.value`` below)."""

    # GPT-4 family.
    GPT_4 = "gpt-4"
    GPT_4_0314 = "gpt-4-0314"
    GPT_4_32K = "gpt-4-32k"
    GPT_4_32K_0314 = "gpt-4-32k-0314"
    # GPT-3.5 family.
    GPT_3_5_TURBO = "gpt-3.5-turbo"
    GPT_3_5_TURBO_0301 = "gpt-3.5-turbo-0301"
25 |
class OpenAILLM(BaseLLM):
    """Chat-completion wrapper around the OpenAI API.

    Attributes:
        api_key (str): OpenAI API Key
        model (Model): OpenAI LLM Model
        temperature (float): Temperature for generating reply
        frequency_penalty (float): Frequency penalty for generating reply
        presence_penalty (float): Presence penalty for generating reply"""

    # Matches a fenced code block (``` optionally followed by a language tag)
    # and captures its body; used to extract code from LLM replies.
    _CODE_BLOCK_RE = r"```.*?\n(.*?)```"

    def __init__(
        self,
        api_key: str,
        model: Model,
        temperature: float = 0.2,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
    ):
        """Initialize OpenAI LLM
        Args:
            api_key (str): OpenAI API Key
            model (Model): OpenAI LLM Model
            temperature (float): Temperature for generating reply
            frequency_penalty (float): Frequency penalty for generating reply
            presence_penalty (float): Presence penalty for generating reply
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        # NOTE: this sets the key on the shared `openai` module, so the most
        # recently constructed instance wins if several use different keys.
        openai.api_key = self.api_key

    @staticmethod
    def _build_messages(
        prompt: Optional[str],
        system_prompt: Optional[str],
        messages: Optional[List[Dict[str, str]]],
    ) -> List[Dict[str, str]]:
        """Validate the prompt arguments and build the chat message list.
        Args:
            prompt: User prompt (used when no explicit messages are given)
            system_prompt: System prompt (used when no explicit messages are given)
            messages: Explicit message list; takes precedence when non-empty
        Returns:
            List[Dict[str, str]]: Messages to send to the API
        Raises:
            ValueError: If neither messages nor prompt/system_prompt is given"""
        if not prompt and not system_prompt and not messages:
            raise ValueError(
                "Please provide either messages or prompt and system_prompt"
            )
        if not messages:
            messages = [
                {"role": "system", "content": system_prompt},  # type: ignore[dict-item]
                {"role": "user", "content": prompt},  # type: ignore[dict-item]
            ]
        return messages

    @classmethod
    def _extract_code(cls, reply: str) -> str:
        """Return the first fenced code block in `reply`, or the whole reply
        unchanged when no fence is found."""
        matches = re.findall(cls._CODE_BLOCK_RE, reply, re.DOTALL)
        return matches[0].strip() if matches else reply

    def get_reply(
        self,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs,
    ) -> str:
        """Get reply from OpenAI LLM
        Args:
            prompt (Optional[str]): Prompt to be used for generating reply
            system_prompt (Optional[str]): System prompt to be used for generating reply
            messages (Optional[List[Dict[str, str]]]): List of messages to be used for generating reply
        Returns:
            str: Reply from OpenAI LLM
        Raises:
            ValueError: If neither messages nor prompt/system_prompt is given
            Exception: On any OpenAI API failure (wrapped with context)"""
        messages = self._build_messages(prompt, system_prompt, messages)

        logger.info(f"Messages: {messages}")
        try:
            response = openai.ChatCompletion.create(
                model=self.model.value,
                messages=messages,
                temperature=self.temperature,
                frequency_penalty=self.frequency_penalty,
                # BUG FIX: presence_penalty was stored in __init__ but never
                # forwarded to the API.
                presence_penalty=self.presence_penalty,
            )
            logger.info(f"Response: {response}")
        except openai.error.APIConnectionError as e:
            # Handle connection error here
            logger.error(f"Failed to connect to OpenAI API: {e}")
            raise Exception(f"Failed to connect to OpenAI API: {e}")
        except openai.error.APIError as e:
            logger.error(f"OpenAI API Error: {e}")
            raise Exception(f"OpenAI API Error: {e}")
        except openai.error.RateLimitError as e:
            logger.error(f"OpenAI API Rate Limit Error: {e}")
            raise Exception(f"OpenAI API Rate Limit Error: {e}")
        except openai.error.AuthenticationError as e:
            logger.error(
                f"OpenAI API Authentication Error:{e}\nCheck your OpenAI API key in config.json"
            )
            raise Exception(
                f"OpenAI API Authentication Error:{e}\nCheck your OpenAI API key in config.json"
            )
        except openai.error.InvalidRequestError as e:
            logger.error(f"OpenAI API Invalid Request Error: {e}")
            raise Exception(f"OpenAI API Invalid Request Error: {e}")
        return response["choices"][0]["message"]["content"].strip()

    async def get_reply_async(
        self, prompt=None, system_prompt=None, messages: Optional[list] = None
    ):
        """Async variant of get_reply (no granular API error handling).
        Args/Returns/Raises: see get_reply."""
        messages = self._build_messages(prompt, system_prompt, messages)
        logger.info(f"Messages: {messages}")
        # BUG FIX: ChatCompletion.create() is synchronous and not awaitable in
        # the pre-1.0 openai SDK; the async entry point is acreate().
        response = await openai.ChatCompletion.acreate(
            model=self.model.value,
            messages=messages,
            temperature=self.temperature,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        logger.info(f"Response: {response}")

        return response["choices"][0]["message"]["content"].strip()

    def get_code(
        self,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs,
    ) -> str:
        """Get code from OpenAI LLM reply
        Args:
            prompt (Optional[str]): Prompt to be used for generating reply
            system_prompt (Optional[str]): System prompt to be used for generating reply
            messages (Optional[List[Dict[str, str]]]): List of messages to be used for generating reply
        Returns:
            str: First fenced code block in the reply, or the full reply"""
        reply = self.get_reply(prompt, system_prompt, messages)
        return self._extract_code(reply)

    async def get_code_async(
        self, prompt=None, system_prompt=None, messages: Optional[list] = None
    ):
        """Async variant of get_code. Args/Returns: see get_code."""
        reply = await self.get_reply_async(prompt, system_prompt, messages)
        return self._extract_code(reply)
172 |
--------------------------------------------------------------------------------
/auto_analyst/prompts/general.py:
--------------------------------------------------------------------------------
1 | import jinja2
2 | from typing import (
3 | Dict,
4 | List,
5 | Optional,
6 | )
7 | from ..data_catalog.base import Table
8 |
9 | environment = jinja2.Environment()
10 |
# System prompts that steer the driver LLM for each kind of request.
analysis_type_system_prompt = "You are a helpful assistant that determines whether a question is asking for a SQL query, tabular data or a plot."
query_system_prompt = "You are a helpful assistant that only writes SQL SELECT queries. Reply only with SQL queries, wrap your query in triple backquotes."
transformed_data_system_prompt = "You are a helpful assistant that assists in defining the data needed to answer a question."
plotly_system_prompt = "You are a helpful assistant that only writes Python code using plotly library. Reply only with Python code, wrap your query in triple backquotes."
yes_no_system_prompt = "You are a helpful assistant that answers yes or no questions. Reply only with yes or no."

# Few-shot examples pairing a question with its analysis type
# ("data" = tabular result, "plot" = chart, "query" = raw SQL).
type_examples = [
    {"question": "How many sales were made in August?", "type": "data"},
    {"question": "Relationship between customer age and time spent", "type": "plot"},
    {"question": "Query to get 1000 random customers who live in USA", "type": "query"},
    {"question": "What is the average amount per transaction?", "type": "data"},
    {"question": "1, 7, 14, 28 retention for customer who signed up in August", "type": "plot"},
    {"question": "Plot the timeseries of ad impressions", "type": "plot"},
    {"question": "Query to get last quarters' sales", "type": "query"},
    {"question": "Top 10 customers by number of transactions", "type": "data"},
    {"question": "Histogram of customer age", "type": "plot"},
]

# Few-shot transcript for the type-classification call: the system prompt
# followed by each example rendered as a user/assistant message pair.
type_messages = [{"role": "system", "content": analysis_type_system_prompt}]
type_messages += [
    msg
    for ex in type_examples
    for msg in (
        {"role": "user", "content": ex["question"]},
        {"role": "assistant", "content": ex["type"]},
    )
]


def render_type_messages(question) -> List[Dict[str, str]]:
    """Build the classification transcript for one question.

    Args:
        question (str): Question to be classified
    Returns:
        List[Dict[str, str]]: A new list — the few-shot transcript with the
        question appended as a final user message"""
    return [*type_messages, {"role": "user", "content": question}]
54 |
55 |
# Template for the SQL-generation prompt: lists tables, schemas, then asks
# either to answer the question directly or to produce the transformed data.
query_template = environment.from_string(
    """
Given the following tables
{% for tbl in source_data %}
{{ tbl.name }}: {{ tbl.description }}
{% endfor %}

With schema:
{% for table, schema_df in table_schema.items() %}
{{ table }}: {{ schema_df.to_string(index=False) }}
{% endfor %}

{% if analysis_type != 'plot' %}
Write a SQL query to answer the following question:
{{ question }}{% else %}
Write a SQL query to get the following data:
{{ transformed_data }}
{% endif %}"""
)


def render_query_prompt(
    question: str,
    source_data: List[Table],
    table_schema: Dict,
    analysis_type: str,
    transformed_data: str = "",
) -> str:
    """Render the SQL-generation prompt.

    Args:
        question (str): Question to be answered
        source_data (List[Table]): Source tables (name + description)
        table_schema (Dict): Table name -> schema DataFrame
        analysis_type (str): Analysis type; 'plot' switches the request to
            fetching `transformed_data` instead of answering directly
        transformed_data (str, optional): Transformed-data spec. Defaults to "".
    Returns:
        str: Fully rendered prompt text"""
    context = {
        "question": question,
        "source_data": source_data,
        "table_schema": table_schema,
        "analysis_type": analysis_type,
        "transformed_data": transformed_data,
    }
    return query_template.render(**context)
100 |
101 |
# Template for asking the LLM to repair a query that failed at execution time.
update_query_template = environment.from_string(
    """
Instructions:
{{ prompt }}

Query:
{{ query }}

Failed with following error:
{{ error }}

Please update the query to answer the question.
"""
)


def render_update_query_prompt(
    prompt: Optional[str],
    query: Optional[str],
    error: Optional[str],
) -> str:
    """Render the query-repair prompt.

    Args:
        prompt (Optional[str]): The original instructions given to the LLM
        query (Optional[str]): The SQL query that failed
        error (Optional[str]): The database error message
    Returns:
        str: Fully rendered prompt text"""
    context = {"prompt": prompt, "query": query, "error": error}
    return update_query_template.render(**context)
135 |
136 |
# Template for asking the LLM to specify the intermediate 'result_data' table
# needed before plotting.
# FIX: repaired garbled grammar in the prompt text ("needed following
# question" -> "needed to answer the following question"; "Answer in
# following format" -> "Answer in the following format").
transformed_data_template = environment.from_string(
    """
Given the following tables
{% for tbl in source_data %}
{{ tbl.name }}: {{ tbl.description }}
{% endfor %}

With schema:
{% for table, schema_df in table_schema.items() %}
{{ table }}: {{ schema_df.to_string(index=False) }}
{% endfor %}

Define table 'result_data' needed to answer the following question:
{{ question }}

Answer in the following format:
Name: result_data
Description:
Schema

Column Name | Type | Description
"""
)


def render_transformed_data_prompt(
    question: str,
    source_data: List[Table],
    table_schema: Dict,
) -> str:
    """Render the prompt that asks the LLM to define the transformed data.

    Args:
        question (str): Question to be answered
        source_data (List[Table]): Source tables (name + description)
        table_schema (Dict): Table name -> schema DataFrame
    Returns:
        str: Fully rendered prompt text"""
    return transformed_data_template.render(
        question=question,
        source_data=source_data,
        table_schema=table_schema,
    )
179 |
180 |
# Template for asking the LLM to write plotly code against the transformed data.
plotly_code_template = environment.from_string(
    """
For dataframe with following schema:
{{ transformed_data }}

Write plotly code to store the following plot in `fig` variable, don't call fig.show():
{{ question }}"""
)


def render_plotly_code_prompt(
    question: str,
    transformed_data: str,
) -> str:
    """Render the plotly code-generation prompt.

    Args:
        question (str): Question describing the desired plot
        transformed_data (str): Schema description of the input dataframe
    Returns:
        str: Fully rendered prompt text"""
    context = {"question": question, "transformed_data": transformed_data}
    return plotly_code_template.render(**context)
205 |
206 |
# Template for the yes/no sanity check on generated plotly code.
plotly_code_check_template = environment.from_string(
    """
Does the following Python code store a plotly plot in `fig` variable?
{{ code }}

Answer only with yes or no. If you are unsure, answer no."""
)


def render_plotly_code_check_prompt(
    code: str,
) -> str:
    """Render the prompt that verifies generated plotly code sets `fig`.

    Args:
        code (str): Generated Python code to verify
    Returns:
        str: Fully rendered prompt text"""
    return plotly_code_check_template.render(code=code)
227 |
--------------------------------------------------------------------------------
/auto_analyst/auto_analyst.py:
--------------------------------------------------------------------------------
1 | from auto_analyst.databases.base import BaseDatabase
2 | from auto_analyst.data_catalog.base import BaseDataCatalog, Table
3 | from auto_analyst.analysis import Analysis, AnalysisType
4 | import pandas as pd
5 | from auto_analyst.llms.base import BaseLLM
6 | from auto_analyst.prompts import (
7 | render_type_messages,
8 | render_query_prompt,
9 | query_system_prompt,
10 | transformed_data_system_prompt,
11 | yes_no_system_prompt,
12 | plotly_system_prompt,
13 | render_update_query_prompt,
14 | render_transformed_data_prompt,
15 | render_plotly_code_prompt,
16 | render_plotly_code_check_prompt,
17 | )
18 | import logging
19 | from plotly.graph_objs import Figure
20 |
21 | from typing import (
22 | Dict,
23 | List,
24 | Tuple,
25 | )
26 |
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
class AutoAnalyst:
    """Pipeline that turns a natural-language question into an Analysis:
    it classifies the question (query/data/plot) with the driver LLM,
    generates SQL against catalog-resolved tables, runs it, and — for
    plots — generates and executes plotly code."""

    def __init__(
        self,
        database: BaseDatabase,
        datacatalog: BaseDataCatalog,
        driver_llm: BaseLLM,
        query_retry_count: int = 0,
    ) -> None:
        """Args:
            database: Executes generated SQL via run_query().
            datacatalog: Resolves source tables and schemas for a question.
            driver_llm: LLM used for classification, SQL and plotly code.
            query_retry_count: Times a failing query is regenerated and re-run
                before the error propagates (0 = fail on first error)."""
        self.database = database
        self.datacatalog = datacatalog
        self.driver_llm = driver_llm
        # self.analysis: Optional[Analysis] = None
        # self.query: Optional[str] = None
        # self.query_prompt: Optional[str] = None
        self.query_retry_count = query_retry_count
        logger.info(
            f"Initalized AutoAnalyst with retry count: {self.query_retry_count}"
        )

    def _generate_data_query(
        self,
        question: str,
        source_data: List[Table],
        table_schema: Dict,
        analysis_type: str,
        transformed_data: str = "",
    ) -> Tuple[str, str]:
        """Generate query to answer the question.

        Returns both the rendered prompt and the generated SQL so callers can
        reuse the prompt when asking the LLM to repair a failing query."""
        query_prompt = render_query_prompt(
            question=question,
            source_data=source_data,
            table_schema=table_schema,
            analysis_type=analysis_type,
            transformed_data=transformed_data,
        )
        logger.info(f"Query prompt: {query_prompt}")
        query = self.driver_llm.get_code(
            prompt=query_prompt,
            system_prompt=query_system_prompt,
        )

        return query_prompt, query

    def _generate_plotly_code(
        self, question: str, transformed_data: str, retry_count: int = 1
    ) -> str:
        """Generate plotly code to plot the data.

        Asks the LLM for code, then asks it (yes/no) whether that code stores
        a figure in `fig`; on "no" it recurses up to retry_count more times
        before raising ValueError."""
        plotly_code_prompt = render_plotly_code_prompt(
            question=question,
            transformed_data=transformed_data,
        )

        plotly_code = self.driver_llm.get_code(
            prompt=plotly_code_prompt,
            system_prompt=plotly_system_prompt,
        )

        plotly_code_check_prompt = render_plotly_code_check_prompt(
            code=plotly_code,
        )

        plotly_code_check = self.driver_llm.get_reply(
            prompt=plotly_code_check_prompt,
            system_prompt=yes_no_system_prompt,
        )

        # Normalize the reply; anything that isn't exactly "no" is accepted.
        if plotly_code_check.strip().lower() == "no":
            if retry_count > 0:
                plotly_code = self._generate_plotly_code(
                    question=question,
                    transformed_data=transformed_data,
                    retry_count=retry_count - 1,
                )
            else:
                raise ValueError("Plotly code generation failed")
        return plotly_code

    def _update_query(self, query: str, query_prompt: str, error: str) -> str:
        """Ask the LLM to repair a failed query, given the original prompt,
        the failing SQL, and the database error message."""

        update_query_prompt = render_update_query_prompt(
            prompt=query_prompt,
            query=query,
            error=error,
        )
        logger.info(f"Update query prompt: {update_query_prompt}")
        query = self.driver_llm.get_code(
            prompt=update_query_prompt,
            system_prompt=query_system_prompt,
        )
        logger.info(f"Updated query: {query}")
        return query  # Type: ignore

    def _run_query(
        self, query: str, query_prompt: str, retry_count: int = 0
    ) -> pd.DataFrame:
        """Run query and return the result.

        On failure, regenerates the query via _update_query and recurses,
        decrementing retry_count; re-raises the database error once retries
        are exhausted."""
        try:
            return self.database.run_query(query)
        except Exception as e:
            if retry_count > 0:
                query = self._update_query(
                    query=query, query_prompt=query_prompt, error=str(e)
                )
                return self._run_query(
                    query=query, query_prompt=query_prompt, retry_count=retry_count - 1
                )
            else:
                raise e

    def _run_plotly_code(self, plotly_code: str, result_data: pd.DataFrame) -> Figure:
        """Execute LLM-generated plotly code and return the `fig` it builds.

        SECURITY NOTE(review): exec() runs model-generated Python with no
        sandboxing; the only mitigation here is stripping fig.show() calls.
        Raises KeyError if the code never assigns `fig`."""
        # The code sees the query result as `result_data` in its globals.
        namespace = {"result_data": result_data}
        # Strip fig.show() so executing the snippet doesn't try to render.
        plotly_code = plotly_code.replace("fig.show()", "")
        logger.info(f"Plotly code to be executed: {plotly_code}")
        exec(plotly_code, namespace)
        return namespace["fig"]  # type: ignore

    def analyze(self, question: str) -> Analysis:
        """Analyze the question and return the analysis.

        Classifies the question, resolves source tables/schemas, then:
        - "query": attaches generated SQL only;
        - "data": also runs the SQL and attaches the resulting DataFrame;
        - "plot": derives a transformed-data spec, fetches it via SQL, and
          generates + executes plotly code to attach a figure.
        Raises ValueError when no source tables are found (and, via
        AnalysisType(), when the LLM reply is not a valid type)."""

        analysis = Analysis(question)
        logger.info(f"Analyzing question: {question}")

        # Determine whether the question can be answered using query, aggregate data or a plot
        analysis_type = self.driver_llm.get_reply(
            messages=render_type_messages(question)
        )  # type: ignore
        logger.info(f"Analysis type: {analysis_type}")

        # NOTE(review): the raw LLM reply is used directly; AnalysisType(...)
        # raises ValueError unless it matches an enum value exactly.
        analysis.analysis_type = AnalysisType(analysis_type)

        # Determine source data
        source_tables = self.datacatalog.get_source_tables(question)
        if len(source_tables) == 0:
            raise ValueError("No source tables found")

        analysis.metadata = {"source_data": [tbl.to_str() for tbl in source_tables]}  # type: ignore
        logger.info(f"Source tables: {[tbl.to_str() for tbl in source_tables]}")  # type: ignore

        table_schema = self.datacatalog.get_table_schemas([tbl.name for tbl in source_tables])  # type: ignore
        # NOTE(review): metadata is assigned twice; unless Analysis.metadata is
        # a property whose setter merges dicts, this overwrites the
        # "source_data" entry set above — confirm against the Analysis class.
        analysis.metadata = {"table_schema": {k: v.to_dict(orient="records") for k, v in table_schema.items()}}  # type: ignore
        logger.info(f"Table schema: {table_schema}")

        if analysis_type in ["query", "data"]:
            # Generate query
            query_prompt, query = self._generate_data_query(
                question=question,
                source_data=source_tables,  # type: ignore
                table_schema=table_schema,
                analysis_type=analysis_type,
            )
            analysis.query = query

            if analysis_type == "data":
                # Run query
                result_data = self._run_query(
                    query=query,
                    query_prompt=query_prompt,
                    retry_count=self.query_retry_count,
                )
                analysis.result_data = result_data

        elif analysis_type == "plot":
            # First ask the LLM to describe the intermediate table needed,
            # then generate SQL that produces exactly that table.
            transformed_data = self.driver_llm.get_reply(
                prompt=render_transformed_data_prompt(
                    question=question,
                    source_data=source_tables,  # type: ignore
                    table_schema=table_schema,
                ),
                system_prompt=transformed_data_system_prompt,
            )

            query_prompt, query = self._generate_data_query(
                question=question,
                source_data=source_tables,  # type: ignore
                table_schema=table_schema,
                analysis_type=analysis_type,
                transformed_data=transformed_data,
            )

            result_data = self._run_query(
                query=query,
                query_prompt=query_prompt,
                retry_count=self.query_retry_count,
            )
            analysis.result_data = result_data

            # Generate plotting code
            plotly_code = self._generate_plotly_code(
                question=question,
                transformed_data=transformed_data,
            )

            fig = self._run_plotly_code(
                plotly_code=plotly_code, result_data=result_data
            )
            analysis.result_plot = fig
        return analysis
229 |
--------------------------------------------------------------------------------