├── .gitignore
├── README.md
├── assets
│   ├── CSG_module_prompt_template.png
│   ├── QE_module_prompt_template.png
│   ├── SF_module_prompt_template.png
│   ├── SR_module_prompt_template.png
│   ├── e-sql-flowchart_qid_1448.png
│   └── e-sql-pipeline.png
├── env.example
├── evaluation
│   ├── evaluation.py
│   ├── evaluation_ex.py
│   ├── evaluation_f1.py
│   ├── evaluation_utils.py
│   └── evaluation_ves.py
├── few-shot-data
│   └── question_enrichment_few_shot_examples.json
├── main.py
├── pipeline
│   └── Pipeline.py
├── prompt_templates
│   ├── candidate_sql_generation_prompt_template.txt
│   ├── question_enrichment_prompt_template.txt
│   ├── schema_filter_prompt_template.txt
│   └── sql_refinement_prompt_template.txt
├── requirements.txt
├── run_evaluation.sh
├── run_main.sh
└── utils
    ├── __init__.py
    ├── db_utils.py
    ├── openai_utils.py
    ├── prompt_utils.py
    └── retrieval_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | __pycache__/
3 | *.py[cod]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # E-SQL: Direct Schema Linking via Question Enrichment in Text-to-SQL
2 |
3 | This is the official repository for the paper **"E-SQL: Direct Schema Linking via Question Enrichment in Text-to-SQL"**.
4 |
5 | [arXiv:2409.16751](https://www.arxiv.org/abs/2409.16751) · [Citation](#citation)
6 |
7 | ## Overview
8 |
9 | Translating natural language queries into SQL (Text-to-SQL) is a critical task for enabling natural language interfaces to databases (NLIDB), but challenges such as complex schemas and ambiguous queries often hinder the accuracy of generated SQL. **E-SQL** addresses these challenges through a novel pipeline that directly links relevant database schema items with the natural language query, a method we call **Question Enrichment**.
10 |
11 | E-SQL enriches natural language questions by incorporating relevant database elements—such as tables, columns, and potential conditions—directly into the query to enhance SQL generation accuracy. Our pipeline also introduces **Candidate Predicate Generation** to further reduce SQL errors caused by incomplete or incorrect predicates.
12 |
13 | ![E-SQL Pipeline](assets/e-sql-pipeline.png)
14 |
15 | ### Key modules of the E-SQL pipeline:
16 | - **Candidate SQL Generation (CSG)**: Generates initial SQL queries based on the natural language question.
17 | - **Candidate Predicate Generation (CPG)**: Extracts and incorporates likely predicates from the database.
18 | - **Question Enrichment (QE)**: Enhances the natural language query by linking relevant database items and conditions.
19 | - **SQL Refinement (SR)**: Refines generated SQL queries by correcting minor errors and ensuring execution correctness.
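For illustration, here is a minimal sketch of how these modules compose. The function bodies below are hypothetical stand-ins, not the repository's actual API; the real implementations are LLM calls in `pipeline/Pipeline.py` built from the templates under `prompt_templates/`:

```python
def candidate_sql_generation(question: str, schema: str) -> str:  # CSG
    return "SELECT ..."  # draft SQL generated by the LLM

def candidate_predicate_generation(draft_sql: str, schema: str) -> list:  # CPG
    return ["table.column = 'value'"]  # likely predicates extracted from the DB

def question_enrichment(question: str, schema: str, predicates: list) -> str:  # QE
    return question + " [relevant tables, columns and conditions linked in]"

def sql_refinement(enriched_question: str, draft_sql: str, schema: str) -> str:  # SR
    return draft_sql  # minor errors corrected, executability checked

def e_sql(question: str, schema: str) -> str:
    draft = candidate_sql_generation(question, schema)
    predicates = candidate_predicate_generation(draft, schema)
    enriched = question_enrichment(question, schema, predicates)
    return sql_refinement(enriched, draft, schema)
```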
20 |
21 | While schema filtering has been widely adopted in previous research, our experiments show that it can lead to performance degradation when used alongside advanced large language models (LLMs). Instead, direct schema linking through question enrichment and candidate predicate augmentation proves to be more effective, particularly on complex queries.
22 |
23 |
24 | The E-SQL execution flow for the question with question ID 1448 in the development set is given below.
25 | 
26 | ![E-SQL execution flow for question 1448](assets/e-sql-flowchart_qid_1448.png)
27 | This repository contains all the code for implementing and evaluating **E-SQL** on the BIRD development set, utilizing models like GPT-4o and GPT-4o-mini to achieve competitive performance.
28 |
29 |
30 |
31 | ## Project Structure
32 | To avoid any potential file path errors, organize your project files as follows:
33 |
34 | ```text
35 | nlq-to-sql/
36 | ├── dataset/
37 | │   └── bird-sql/
38 | │       ├── dev/
39 | │       │   ├── dev_databases/
40 | │       │   ├── column_meaning.json
41 | │       │   ├── dev_gold.json
42 | │       │   ├── dev_tables.json
43 | │       │   ├── dev_tied_append.json
44 | │       │   └── dev.json
45 | │       └── test/
46 | │           ├── test_databases/
47 | │           ├── column_meaning.json
48 | │           ├── test_gold.json
49 | │           ├── test_tables.json
50 | │           └── test.json
51 | └── E-SQL/
52 |     ├── evaluation/
53 |     │   ├── evaluation_ves.py
54 |     │   └── evaluation.py
55 |     ├── few-shot-data/
56 |     │   └── question_enrichment_few_shot_examples.json
57 |     ├── pipeline/
58 |     │   └── Pipeline.py
59 |     ├── prompt_templates/
60 |     │   ├── candidate_sql_generation_prompt_template.txt
61 |     │   ├── question_enrichment_prompt_template.txt
62 |     │   └── sql_refinement_prompt_template.txt
63 |     ├── results/
64 |     │   ├── model_outputs_dev_CSG-EQ-SR_gpt-4o-2024-08-06/
65 |     │   └── model_outputs_dev_CSG-EQ-SR_gpt-4o-mini-2024-07-18/
66 |     ├── utils/
67 |     │   ├── __init__.py
68 |     │   ├── db_utils.py
69 |     │   ├── openai_utils.py
70 |     │   ├── prompt_utils.py
71 |     │   └── retrieval_utils.py
72 |     ├── .env
73 |     ├── .gitignore
74 |     ├── env.example
75 |     ├── main.py
76 |     ├── README.md
77 |     ├── requirements.txt
78 |     ├── run_evaluation.sh
79 |     └── run_main.sh
80 | 
81 | ```
82 | Hint: you can generate the `column_meaning.json` file for the dev and test datasets using TA-SQL: https://github.com/quge2023/TA-SQL
83 |
84 | ## Setting up the Environment
85 | 1. **Download the compressed code.**
86 | 2. **Create a `.env` file and set the environment variables as follows:** In the root directory of E-SQL, add the following items to the `.env` file and make sure they are configured correctly. One critical point is to set the `BIRD_DB_PATH` environment variable to the root directory of the BIRD dataset, which contains both the dev and test directories. This is essential for the proper functioning of the code, as the database paths are constructed dynamically based on the mode (dev or test). To avoid any potential file path errors, we strongly recommend organizing your project files as described above.
87 |
88 | ```
89 | BIRD_DB_PATH="../dataset/bird-sql"
90 | OPENAI_API_KEY=
91 | ```
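As a sanity check, the snippet below mirrors how the pipeline derives database paths from `BIRD_DB_PATH` and the mode; `financial` is just an example database ID:

```python
import os
from dotenv import load_dotenv

load_dotenv()
bird_sql_path = os.getenv("BIRD_DB_PATH")  # e.g. "../dataset/bird-sql"
mode = "dev"         # or "test"
db_id = "financial"  # example BIRD database
# main.py constructs SQLite paths in this shape:
db_path = f"{bird_sql_path}/{mode}/{mode}_databases/{db_id}/{db_id}.sqlite"
print(db_path)  # ../dataset/bird-sql/dev/dev_databases/financial/financial.sqlite
```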
92 |
93 | 3. **Install the required packages:** Use the following command to install the required packages:
94 | ```
95 | pip install -r requirements.txt
96 | ```
97 |
98 |
99 | ## Running the Code
100 | 1. **Update the `run_main.sh` file to change the running mode or the OpenAI model:** In the `run_main.sh` file, set the `mode` and `model` arguments. Do not change the other arguments in `run_main.sh`.
101 |
102 | ```bash
103 | mode='test' # Update this with the mode you want to run. Write either 'dev' or 'test'
104 | model="gpt-4o-2024-08-06" # For GPT-4o use "gpt-4o-2024-08-06". For GPT-4o-mini use "gpt-4o-mini-2024-07-18".
105 | pipeline_order="CSG-QE-SR"
106 | ...
107 | ```
108 |
109 |
110 |
111 | 2. **Run the main script:** After moving into the E-SQL directory, run the main script as follows. On its first run, the main script processes all database CSV files, which contain descriptions of the tables and columns of each database.
112 | ```
113 | cd E-SQL
114 | sh run_main.sh
115 | ```
116 |
117 | 3. **Review informative logs:** After running the main script, the database description files will be processed and informative logs will be printed. These logs detail the processing steps and the status of each database description file. No action is required during this step; the logs are purely informational and help you monitor the process.
118 |
119 |
120 | ## Evaluation
121 |
122 | 1. **Set up the evaluation arguments:** To evaluate the results, set the arguments in `run_evaluation.sh` as follows:
123 |
124 | ```bash
125 | db_root_path='../dataset/bird-sql/dev/dev_databases/'
126 | data_mode='dev'
127 | diff_json_path='../dataset/bird-sql/dev/dev.json'
128 | predicted_sql_path_kg='./results/model_outputs_{args.mode}_{args.pipeline_order}_{args.model}'
129 | ground_truth_path='../dataset/bird-sql/dev/'
130 | num_cpus=8
131 | meta_time_out=30.0
132 | mode_gt='gt'
133 | mode_predict='gpt'
134 | ```
135 |
136 |
137 |
138 | 2. **Run the evaluation script** after the main script has finished and all predicted SQLs have been saved into the `predict_dev.json` or `predict_test.json` file under the `./results/model_outputs_{args.mode}_{args.pipeline_order}_{args.model}` directory. Once the previous step is complete, run the evaluation script as follows to print both the EX and VES metrics; the entry format of the prediction files is illustrated after the command below.
139 | ```
140 | sh run_evaluation.sh
141 | ```
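For reference, each entry in `predict_{mode}.json` maps a question ID to the predicted SQL joined with its database ID by the BIRD separator used throughout the code (the SQL and database here are hypothetical example values):

```python
# Entry format written by main.py and expected by the evaluation scripts:
entry = "SELECT count(*) FROM account" + "\t----- bird -----\t" + "financial"
predictions = {"0": entry}  # question_id -> "SQL\t----- bird -----\tdb_id"
```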
142 |
143 | 3. **Cost estimation**: The evaluation on the dev set involves approximately 40.25 million prompt tokens and 1.23 million output tokens, resulting in an estimated cost of $112 for GPT-4o and $8 for GPT-4o-mini.
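A back-of-the-envelope check of the GPT-4o figure, assuming OpenAI's 2024 list prices of $2.50 per million input tokens and $10 per million output tokens for `gpt-4o-2024-08-06` (the prices are an assumption and may change):

```python
prompt_tokens = 40.25e6  # ~40.25M prompt tokens
output_tokens = 1.23e6   # ~1.23M output tokens
cost = prompt_tokens / 1e6 * 2.50 + output_tokens / 1e6 * 10.00
print(f"~${cost:.0f}")   # ~$113, in line with the ~$112 estimate above
```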
144 |
145 | ## Metrics
146 |
147 | Finally, a `metrics.json` file is created under the `./results/model_outputs_{mode}_{pipeline_order}_{model}` directory so you can inspect the execution accuracies. The `metrics.json` file format is as follows. Note that this file does not include the VES metric; for that, run the evaluation script as explained in the previous section.
148 |
149 | ```json
150 | "candidate_generation": {
151 | "EX": "Overall Execution Accuracy",
152 | "total_correct_count": "Total number of correctly predicted SQLs.",
153 | "total_item_count": "Total item count",
154 | "simple_stats": {
155 | "correct_number": "The number of correctly predicted SQL considered as simple.",
156 | "count": "The total number of simple questions.",
157 | "ex": "Execution accuracy for the simple questions."
158 | },
159 | "moderate_stats": {
160 | "correct_number": "The number of correctly predicted SQL considered as moderate.",
161 | "count": "The total number of moderate questions.",
162 | "ex": "Execution accuracy for the moderate questions."
163 | },
164 | "challenging_stats": {
165 | "correct_number": "The number of correctly predicted SQL considered as challenging.",
166 | "count": "The total number of challenging questions.",
167 | "ex": "Execution accuracy for the challenging questions."
168 | },
169 | "fail_q_ids": [],
170 | "config": {
171 | "mode": "Either dev or test",
172 | "model": "OpenAI model",
173 | "temperature": 0.0,
174 | "top_p": 1.0,
175 | "max_tokens": 2048,
176 | "n": 1,
177 | "pipeline_order": "Selected Pipeline Order",
178 | "enrichment_level": "complex",
179 | "enrichment_level_shot_number": 3,
180 | "enrichment_few_shot_schema_existance": false,
181 | "filtering_level_shot_number": 3,
182 | "filtering_few_shot_schema_existance": false,
183 | "cfg": true,
184 | "generation_level_shot_number": 3,
185 | "generation_few_shot_schema_existance": false,
186 | "db_sample_limit": 10,
187 | "relevant_description_number": 20,
188 | "seed": 42
189 | }
190 | }
191 | ```
192 |
193 | ## Additional Notes
194 |
195 | 1. **Column Meaning File Usage**: The `column_meaning.json` file is required in both dev and test modes.
196 |
197 | 2. **Error Handling**: Error-handling and logging mechanisms are in place. In some cases, however, the LLM might fail to generate a SQL query, causing an error that halts execution. When such an error occurs, the ID of the question where the code stopped can be found both in the terminal and in the predictions.json or predict_dev.json files, located in the results/model_outputs_{mode}_{pipeline_order}_{model} directory. To restart, locate the error point and open the main.py file. You'll find a section labeled "In case of an error, you can restart the code from the point of error using the following line." Update the dataset slice there so that the code resumes from where it left off rather than starting over, as shown in the sketch below. Although the code appends new predicted SQLs and corresponding objects to the `predict_dev.json` and `predictions.json` files, make a copy of both files under a different name before restarting, so that previously generated data is not lost.
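A minimal sketch of the resume pattern; the index 1135 is only an example, use the index of the first question your run did not process:

```python
import json

# Load the dataset exactly as main.py does, then skip already-processed items.
with open("../dataset/bird-sql/dev/dev.json") as f:
    dataset = json.load(f)

resume_index = 1135  # example: first question that has no prediction yet
dataset = dataset[resume_index:]
```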
198 |
199 | 3. **SQL Parsing**: The generated SQL queries are parsed with the SQLGlot parser. Parsing errors may arise, but they are handled using try-except statements. Warnings are logged when parsing issues occur; these can generally be ignored as long as no errors are thrown and execution proceeds smoothly.
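A minimal sketch of that parse-and-handle pattern with SQLGlot (illustrative only; the pipeline's actual parsing lives in the utils modules):

```python
import sqlglot
from sqlglot.errors import ParseError

sql = "SELECT name FROM account WHERE id = 1"
try:
    expression = sqlglot.parse_one(sql, read="sqlite")
    print(expression.sql())  # round-trip the parsed query
except ParseError as e:
    # Log and continue; a parse warning alone does not stop the pipeline.
    print(f"Warning: could not parse SQL: {e}")
```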
200 |
201 | 4. **Gold SQL File Naming**: The evaluation script expects a file named dev_gold.sql, but the corresponding file downloaded from the BIRD website is named dev.sql. To ensure compatibility with the evaluation script, copy and rename dev.sql to dev_gold.sql, e.g. as in the snippet below.
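For example (the paths assume the project layout above; adjust as needed):

```python
import shutil

# Copy dev.sql to the filename the evaluation script expects.
shutil.copyfile("../dataset/bird-sql/dev/dev.sql",
                "../dataset/bird-sql/dev/dev_gold.sql")
```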
202 |
203 |
204 | # Citation
205 |
206 | If you find this repository helpful, please cite the following paper:
207 |
208 | ```
209 | @misc{caferoğlu2024esqldirectschemalinking,
210 | title={E-SQL: Direct Schema Linking via Question Enrichment in Text-to-SQL},
211 | author={Hasan Alp Caferoğlu and Özgür Ulusoy},
212 | year={2024},
213 | eprint={2409.16751},
214 | archivePrefix={arXiv},
215 | primaryClass={cs.CL},
216 | url={https://arxiv.org/abs/2409.16751},
217 | }
218 | ```
219 |
--------------------------------------------------------------------------------
/assets/CSG_module_prompt_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/CSG_module_prompt_template.png
--------------------------------------------------------------------------------
/assets/QE_module_prompt_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/QE_module_prompt_template.png
--------------------------------------------------------------------------------
/assets/SF_module_prompt_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/SF_module_prompt_template.png
--------------------------------------------------------------------------------
/assets/SR_module_prompt_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/SR_module_prompt_template.png
--------------------------------------------------------------------------------
/assets/e-sql-flowchart_qid_1448.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/e-sql-flowchart_qid_1448.png
--------------------------------------------------------------------------------
/assets/e-sql-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/e-sql-pipeline.png
--------------------------------------------------------------------------------
/env.example:
--------------------------------------------------------------------------------
1 | BIRD_DB_PATH="../dataset/bird-sql"
2 | OPENAI_API_KEY=
--------------------------------------------------------------------------------
/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import argparse
4 | import sqlite3
5 | import multiprocessing as mp
6 | from func_timeout import func_timeout, FunctionTimedOut
7 |
8 | def load_json(dir):
9 | with open(dir, 'r') as j:
10 | contents = json.loads(j.read())
11 | return contents
12 |
13 | def result_callback(result):
14 | exec_result.append(result)
15 |
16 |
17 | def execute_sql(predicted_sql, ground_truth, db_path):
18 |     # Connect to the database and execute both queries
19 |     conn = sqlite3.connect(db_path)
20 |     cursor = conn.cursor()
21 |     cursor.execute(predicted_sql)
22 |     predicted_res = cursor.fetchall()
23 |     cursor.execute(ground_truth)
24 |     ground_truth_res = cursor.fetchall()
25 |     conn.close()
26 |     # EX is 1 when the two result sets match exactly (order-insensitive), else 0
27 |     res = 1 if set(predicted_res) == set(ground_truth_res) else 0
28 |     return res
29 |
30 |
31 |
32 | def execute_model(predicted_sql,ground_truth, db_place, idx, meta_time_out):
33 | try:
34 | res = func_timeout(meta_time_out, execute_sql,
35 | args=(predicted_sql, ground_truth, db_place))
36 | except KeyboardInterrupt:
37 | sys.exit(0)
38 | except FunctionTimedOut:
39 | result = [(f'timeout',)]
40 | res = 0
41 | except Exception as e:
42 | result = [(f'error',)] # possibly len(query) > 512 or not executable
43 | res = 0
44 | # print(result)
45 | # result = str(set([ret[0] for ret in result]))
46 | result = {'sql_idx': idx, 'res': res}
47 | # print(result)
48 | return result
49 |
50 |
51 | def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
52 | clean_sqls = []
53 | db_path_list = []
54 | if mode == 'gpt':
55 | sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
56 | for idx, sql_str in sql_data.items():
57 | if type(sql_str) == str:
58 | sql, db_name = sql_str.split('\t----- bird -----\t')
59 | else:
60 | sql, db_name = " ", "financial"
61 | clean_sqls.append(sql)
62 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
63 |
64 | elif mode == 'gt':
65 | sqls = open(sql_path + data_mode + '_gold.sql')
66 | sql_txt = sqls.readlines()
67 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
68 | for idx, sql_str in enumerate(sql_txt):
69 | sql, db_name = sql_str.strip().split('\t')
70 | clean_sqls.append(sql)
71 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
72 |
73 | return clean_sqls, db_path_list
74 |
75 | def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
76 | pool = mp.Pool(processes=num_cpus)
77 | for i,sql_pair in enumerate(sqls):
78 |
79 | predicted_sql, ground_truth = sql_pair
80 | pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), callback=result_callback)
81 | pool.close()
82 | pool.join()
83 |
84 | def sort_results(list_of_dicts):
85 | return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
86 |
87 | def compute_acc_by_diff(exec_results,diff_json_path):
88 | num_queries = len(exec_results)
89 | results = [res['res'] for res in exec_results]
90 | contents = load_json(diff_json_path)
91 | simple_results, moderate_results, challenging_results = [], [], []
92 |
93 | for i,content in enumerate(contents):
94 | if content['difficulty'] == 'simple':
95 | simple_results.append(exec_results[i])
96 |
97 | if content['difficulty'] == 'moderate':
98 | moderate_results.append(exec_results[i])
99 |
100 | if content['difficulty'] == 'challenging':
101 | challenging_results.append(exec_results[i])
102 |
103 | simple_acc = sum([res['res'] for res in simple_results])/len(simple_results)
104 | moderate_acc = sum([res['res'] for res in moderate_results])/len(moderate_results)
105 | challenging_acc = sum([res['res'] for res in challenging_results])/len(challenging_results)
106 | all_acc = sum(results)/num_queries
107 | count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
108 | return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
109 |
110 |
111 |
112 | def print_data(score_lists,count_lists):
113 | levels = ['simple', 'moderate', 'challenging', 'total']
114 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
115 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
116 |
117 | print('====================================== ACCURACY =====================================')
118 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
119 |
120 |
121 | if __name__ == '__main__':
122 | args_parser = argparse.ArgumentParser()
123 | args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='')
124 | args_parser.add_argument('--ground_truth_path', type=str, required=True, default='')
125 | args_parser.add_argument('--data_mode', type=str, required=True, default='dev')
126 | args_parser.add_argument('--db_root_path', type=str, required=True, default='')
127 | args_parser.add_argument('--num_cpus', type=int, default=1)
128 | args_parser.add_argument('--meta_time_out', type=float, default=30.0)
129 | args_parser.add_argument('--mode_gt', type=str, default='gt')
130 | args_parser.add_argument('--mode_predict', type=str, default='gpt')
131 | args_parser.add_argument('--difficulty',type=str,default='simple')
132 | args_parser.add_argument('--diff_json_path',type=str,default='')
133 | args = args_parser.parse_args()
134 | exec_result = []
135 |
136 | pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict,
137 | data_mode=args.data_mode)
138 | # generate gt sqls:
139 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
140 | data_mode=args.data_mode)
141 |
142 | query_pairs = list(zip(pred_queries,gt_queries))
143 | run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
144 | exec_result = sort_results(exec_result)
145 |
146 | print('start calculate')
147 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
148 | compute_acc_by_diff(exec_result,args.diff_json_path)
149 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
150 | print_data(score_lists,count_lists)
151 | print('===========================================================================================')
152 | print("Finished evaluation")
153 |
--------------------------------------------------------------------------------
/evaluation/evaluation_ex.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import sys
3 | import argparse
4 | import multiprocessing as mp
5 | from func_timeout import func_timeout, FunctionTimedOut
6 | from evaluation_utils import (
7 | load_json,
8 | execute_sql,
9 | package_sqls,
10 | sort_results,
11 | print_data,
12 | )
13 |
14 |
15 | def result_callback(result):
16 | exec_result.append(result)
17 |
18 |
19 | def calculate_ex(predicted_res, ground_truth_res):
20 | res = 0
21 | if set(predicted_res) == set(ground_truth_res):
22 | res = 1
23 | return res
24 |
25 |
26 | def execute_model(
27 | predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect
28 | ):
29 | try:
30 | res = func_timeout(
31 | meta_time_out,
32 | execute_sql,
33 | args=(predicted_sql, ground_truth, db_place, sql_dialect, calculate_ex),
34 | )
35 | except KeyboardInterrupt:
36 | sys.exit(0)
37 | except FunctionTimedOut:
38 | result = [(f"timeout",)]
39 | res = 0
40 | except Exception as e:
41 | result = [(f"error",)] # possibly len(query) > 512 or not executable
42 | res = 0
43 | result = {"sql_idx": idx, "res": res}
44 | return result
45 |
46 |
47 | def run_sqls_parallel(
48 | sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite"
49 | ):
50 | pool = mp.Pool(processes=num_cpus)
51 | for i, sql_pair in enumerate(sqls):
52 |
53 | predicted_sql, ground_truth = sql_pair
54 | pool.apply_async(
55 | execute_model,
56 | args=(
57 | predicted_sql,
58 | ground_truth,
59 | db_places[i],
60 | i,
61 | meta_time_out,
62 | sql_dialect,
63 | ),
64 | callback=result_callback,
65 | )
66 | pool.close()
67 | pool.join()
68 |
69 |
70 | def compute_acc_by_diff(exec_results, diff_json_path):
71 | num_queries = len(exec_results)
72 | results = [res["res"] for res in exec_results]
73 | contents = load_json(diff_json_path)
74 | simple_results, moderate_results, challenging_results = [], [], []
75 |
76 | for i, content in enumerate(contents):
77 | if content["difficulty"] == "simple":
78 | simple_results.append(exec_results[i])
79 |
80 | if content["difficulty"] == "moderate":
81 | moderate_results.append(exec_results[i])
82 |
83 | if content["difficulty"] == "challenging":
84 | try:
85 | challenging_results.append(exec_results[i])
86 | except:
87 | print(i)
88 |
89 | simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
90 | moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
91 | challenging_acc = sum([res["res"] for res in challenging_results]) / len(
92 | challenging_results
93 | )
94 | all_acc = sum(results) / num_queries
95 | count_lists = [
96 | len(simple_results),
97 | len(moderate_results),
98 | len(challenging_results),
99 | num_queries,
100 | ]
101 | return (
102 | simple_acc * 100,
103 | moderate_acc * 100,
104 | challenging_acc * 100,
105 | all_acc * 100,
106 | count_lists,
107 | )
108 |
109 |
110 | if __name__ == "__main__":
111 | # Record the start time
112 | start_time = datetime.now()
113 | print(f"Start time: {start_time}")
114 |
115 | args_parser = argparse.ArgumentParser()
116 | args_parser.add_argument(
117 | "--predicted_sql_path", type=str, required=True, default=""
118 | )
119 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
120 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
121 | args_parser.add_argument("--db_root_path", type=str, required=True, default="")
122 | args_parser.add_argument("--num_cpus", type=int, default=1)
123 | args_parser.add_argument("--meta_time_out", type=float, default=30.0)
124 | args_parser.add_argument("--mode_gt", type=str, default="gt")
125 | args_parser.add_argument("--mode_predict", type=str, default="gpt")
126 | args_parser.add_argument("--difficulty", type=str, default="simple")
127 | args_parser.add_argument("--diff_json_path", type=str, default="")
128 | args_parser.add_argument("--engine", type=str, default="")
129 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite")
130 | args = args_parser.parse_args()
131 | exec_result = []
132 |
133 | pred_queries, db_paths = package_sqls(
134 | args.predicted_sql_path,
135 | args.db_root_path,
136 | args.engine,
137 | sql_dialect=args.sql_dialect,
138 | mode=args.mode_predict,
139 | data_mode=args.data_mode,
140 | )
141 | # generate ground truth sqls:
142 | gt_queries, db_paths_gt = package_sqls(
143 | args.ground_truth_path,
144 | args.db_root_path,
145 | args.engine,
146 | sql_dialect=args.sql_dialect,
147 | mode="gt",
148 | data_mode=args.data_mode,
149 | )
150 |
151 | query_pairs = list(zip(pred_queries, gt_queries))
152 |
153 | run_sqls_parallel(
154 | query_pairs,
155 | db_places=db_paths,
156 | num_cpus=args.num_cpus,
157 | meta_time_out=args.meta_time_out,
158 | sql_dialect=args.sql_dialect,
159 | )
160 | exec_result = sort_results(exec_result)
161 | print("start calculate")
162 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
163 | exec_result, args.diff_json_path
164 | )
165 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
166 | print(f"EX for {args.engine} on {args.sql_dialect} set")
167 | print("start calculate")
168 | print_data(score_lists, count_lists, metric="EX")
169 | print(
170 | "==========================================================================================="
171 | )
172 | print(f"Finished EX evaluation for {args.engine} on {args.sql_dialect} set")
173 | print("\n\n")
174 |
175 |
176 | # Record the end time and calculate duration
177 | end_time = datetime.now()
178 | print(f"End time: {end_time}")
179 | duration = end_time - start_time
180 | print(f"Duration: {duration}")
181 |
182 | # Saving EX results in a txt file
183 | ex_result_file_path = args.predicted_sql_path + "ex_score.txt"
184 | ex_file_content = f"The EX scores are: \nOverall: {acc} \nSimple: {simple_acc} \nModerate: {moderate_acc} \nChallenging: {challenging_acc} \n\nCount Lists: {count_lists} \n\nEvaluation Duration: {duration} \n\nMeta Time-out: {args.meta_time_out}"
185 | with open(ex_result_file_path, 'w') as f:
186 | f.write(ex_file_content)
187 |
188 | print("EX score is written into the ex_score.txt file.")
189 |
190 |
191 |
192 |
193 |
194 |
195 |
--------------------------------------------------------------------------------
/evaluation/evaluation_f1.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import sys
3 | import argparse
4 | import multiprocessing as mp
5 | from func_timeout import func_timeout, FunctionTimedOut
6 | from evaluation_utils import (
7 | load_json,
8 | execute_sql,
9 | package_sqls,
10 | sort_results,
11 | print_data,
12 | )
13 |
14 |
15 | def calculate_row_match(predicted_row, ground_truth_row):
16 | """
17 | Calculate the matching percentage for a single row.
18 |
19 | Args:
20 | predicted_row (tuple): The predicted row values.
21 | ground_truth_row (tuple): The actual row values from ground truth.
22 |
23 | Returns:
24 | float: The match percentage (0 to 1 scale).
25 | """
26 | total_columns = len(ground_truth_row)
27 | matches = 0
28 | element_in_pred_only = 0
29 | element_in_truth_only = 0
30 | for pred_val in predicted_row:
31 | if pred_val in ground_truth_row:
32 | matches += 1
33 | else:
34 | element_in_pred_only += 1
35 | for truth_val in ground_truth_row:
36 | if truth_val not in predicted_row:
37 | element_in_truth_only += 1
38 | match_percentage = matches / total_columns
39 | pred_only_percentage = element_in_pred_only / total_columns
40 | truth_only_percentage = element_in_truth_only / total_columns
41 | return match_percentage, pred_only_percentage, truth_only_percentage
42 |
43 |
44 | def calculate_f1_score(predicted, ground_truth):
45 | """
46 | Calculate the F1 score based on sets of predicted results and ground truth results,
47 | where each element (tuple) represents a row from the database with multiple columns.
48 |
49 | Args:
50 | predicted (set of tuples): Predicted results from SQL query.
51 | ground_truth (set of tuples): Actual results expected (ground truth).
52 |
53 | Returns:
54 | float: The calculated F1 score.
55 | """
56 | # if both predicted and ground_truth are empty, return 1.0 for f1_score
57 | if not predicted and not ground_truth:
58 | return 1.0
59 |
60 | # Drop duplicates
61 | predicted_set = set(predicted) if predicted else set()
62 | ground_truth_set = set(ground_truth)
63 |
64 | # convert back to list
65 | predicted = list(predicted_set)
66 | ground_truth = list(ground_truth_set)
67 |
68 | # Calculate matching scores for each possible pair
69 | match_scores = []
70 | pred_only_scores = []
71 | truth_only_scores = []
72 | for i, gt_row in enumerate(ground_truth):
73 | # rows only in the ground truth results
74 | if i >= len(predicted):
75 | match_scores.append(0)
76 | truth_only_scores.append(1)
77 | continue
78 | pred_row = predicted[i]
79 | match_score, pred_only_score, truth_only_score = calculate_row_match(
80 | pred_row, gt_row
81 | )
82 | match_scores.append(match_score)
83 | pred_only_scores.append(pred_only_score)
84 | truth_only_scores.append(truth_only_score)
85 |
86 | # rows only in the predicted results
87 | for i in range(len(predicted) - len(ground_truth)):
88 | match_scores.append(0)
89 | pred_only_scores.append(1)
90 | truth_only_scores.append(0)
91 |
92 | tp = sum(match_scores)
93 | fp = sum(pred_only_scores)
94 | fn = sum(truth_only_scores)
95 |
96 | precision = tp / (tp + fp) if tp + fp > 0 else 0
97 | recall = tp / (tp + fn) if tp + fn > 0 else 0
98 |
99 | f1_score = (
100 | 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
101 | )
102 | return f1_score
103 |
104 |
105 | def result_callback(result):
106 | exec_result.append(result)
107 |
108 |
109 | def execute_model(
110 | predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect
111 | ):
112 | try:
113 | res = func_timeout(
114 | meta_time_out,
115 | execute_sql,
116 | args=(
117 | predicted_sql,
118 | ground_truth,
119 | db_place,
120 | sql_dialect,
121 | calculate_f1_score,
122 | ),
123 | )
124 | except KeyboardInterrupt:
125 | sys.exit(0)
126 | except FunctionTimedOut:
127 | result = [(f"timeout",)]
128 | res = 0
129 | except Exception as e:
130 | result = [(f"error",)] # possibly len(query) > 512 or not executable
131 | res = 0
132 | # print(result)
133 | # result = str(set([ret[0] for ret in result]))
134 | result = {"sql_idx": idx, "res": res}
135 | # print(result)
136 | return result
137 |
138 |
139 | def run_sqls_parallel(
140 | sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite"
141 | ):
142 | pool = mp.Pool(processes=num_cpus)
143 | for i, sql_pair in enumerate(sqls):
144 |
145 | predicted_sql, ground_truth = sql_pair
146 | pool.apply_async(
147 | execute_model,
148 | args=(
149 | predicted_sql,
150 | ground_truth,
151 | db_places[i],
152 | i,
153 | meta_time_out,
154 | sql_dialect,
155 | ),
156 | callback=result_callback,
157 | )
158 | pool.close()
159 | pool.join()
160 |
161 |
162 | def compute_f1_by_diff(exec_results, diff_json_path):
163 | num_queries = len(exec_results)
164 | results = [res["res"] for res in exec_results]
165 | contents = load_json(diff_json_path)
166 | simple_results, moderate_results, challenging_results = [], [], []
167 |
168 | for i, content in enumerate(contents):
169 | if content["difficulty"] == "simple":
170 | simple_results.append(exec_results[i])
171 |
172 | if content["difficulty"] == "moderate":
173 | moderate_results.append(exec_results[i])
174 |
175 | if content["difficulty"] == "challenging":
176 | try:
177 | challenging_results.append(exec_results[i])
178 | except:
179 | print(i)
180 |
181 | simple_f1 = sum([res["res"] for res in simple_results]) / len(simple_results) * 100
182 | moderate_f1 = (
183 | sum([res["res"] for res in moderate_results]) / len(moderate_results) * 100
184 | )
185 | challenging_f1 = (
186 | sum([res["res"] for res in challenging_results])
187 | / len(challenging_results)
188 | * 100
189 | )
190 | all_f1 = sum(results) / num_queries * 100
191 | count_lists = [
192 | len(simple_results),
193 | len(moderate_results),
194 | len(challenging_results),
195 | num_queries,
196 | ]
197 | return (
198 | simple_f1,
199 | moderate_f1,
200 | challenging_f1,
201 | all_f1,
202 | count_lists,
203 | )
204 |
205 |
206 | if __name__ == "__main__":
207 | # Record the start time
208 | start_time = datetime.now()
209 | print(f"Start time: {start_time}")
210 |
211 | args_parser = argparse.ArgumentParser()
212 | args_parser.add_argument(
213 | "--predicted_sql_path", type=str, required=True, default=""
214 | )
215 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
216 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
217 | args_parser.add_argument("--db_root_path", type=str, required=True, default="")
218 | args_parser.add_argument("--num_cpus", type=int, default=1)
219 | args_parser.add_argument("--meta_time_out", type=float, default=30.0)
220 | args_parser.add_argument("--mode_gt", type=str, default="gt")
221 | args_parser.add_argument("--mode_predict", type=str, default="gpt")
222 | args_parser.add_argument("--difficulty", type=str, default="simple")
223 | args_parser.add_argument("--diff_json_path", type=str, default="")
224 | args_parser.add_argument("--engine", type=str, default="")
225 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite")
226 | args = args_parser.parse_args()
227 | exec_result = []
228 |
229 | pred_queries, db_paths = package_sqls(
230 | args.predicted_sql_path,
231 | args.db_root_path,
232 | args.engine,
233 | sql_dialect=args.sql_dialect,
234 | mode=args.mode_predict,
235 | data_mode=args.data_mode,
236 | )
237 | # generate ground truth sqls:
238 | gt_queries, db_paths_gt = package_sqls(
239 | args.ground_truth_path,
240 | args.db_root_path,
241 | args.engine,
242 | sql_dialect=args.sql_dialect,
243 | mode="gt",
244 | data_mode=args.data_mode,
245 | )
246 |
247 | query_pairs = list(zip(pred_queries, gt_queries))
248 |
249 | run_sqls_parallel(
250 | query_pairs,
251 | db_places=db_paths,
252 | num_cpus=args.num_cpus,
253 | meta_time_out=args.meta_time_out,
254 | sql_dialect=args.sql_dialect,
255 | )
256 | exec_result = sort_results(exec_result)
257 |
258 | print("start calculate")
259 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_f1_by_diff(
260 | exec_result, args.diff_json_path
261 | )
262 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
263 | print(f"Soft F1 for {args.engine} on {args.sql_dialect} set")
264 | print("start calculate")
265 | print_data(score_lists, count_lists)
266 | print(
267 | "==========================================================================================="
268 | )
269 | print(f"Finished Soft F1 evaluation for {args.engine} on {args.sql_dialect} set")
270 | print("\n\n")
271 |
272 | # Record the end time and calculate duration
273 | end_time = datetime.now()
274 | print(f"End time: {end_time}")
275 | duration = end_time - start_time
276 | print(f"Duration: {duration}")
277 |
278 | # Saving soft F1 results in a txt file
279 | ex_result_file_path = args.predicted_sql_path + "soft_f1_score.txt"
280 | ex_file_content = f"The soft F1-Scores are: \nOverall Soft F1: {acc} \nSimple Soft F1: {simple_acc} \nModerate Soft F1: {moderate_acc} \nChallenging Soft F1: {challenging_acc} \n\nCount Lists: {count_lists} \n\nEvaluation Duration: {duration} \n\nMeta Time-out: {args.meta_time_out}"
281 | with open(ex_result_file_path, 'w') as f:
282 | f.write(ex_file_content)
283 |
284 | print("Soft F1 score is written into the soft_f1_score.txt file.")
285 |
286 |
287 |
288 |
--------------------------------------------------------------------------------
/evaluation/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | # import psycopg2
3 | # import pymysql
4 | import sqlite3
5 |
6 |
7 | def load_json(dir):
8 | with open(dir, "r") as j:
9 | contents = json.loads(j.read())
10 | return contents
11 |
12 |
13 | def connect_postgresql():  # requires psycopg2; uncomment the import at the top of this file
14 | # Open database connection
15 | # Connect to the database
16 | db = psycopg2.connect(
17 | "dbname=BIRD user=root host=localhost password=YOUR_PASSWORD port=5432"
18 | )
19 | return db
20 |
21 |
22 | def connect_mysql():  # requires pymysql; uncomment the import at the top of this file
23 |     # Open database connection
24 |     # Connect to the database
25 | db = pymysql.connect(
26 | host="localhost",
27 | user="root",
28 | password="YOUR_PASSWORD",
29 | database="BIRD",
30 | unix_socket="/tmp/mysql.sock",
31 | # port=3306,
32 | )
33 | return db
34 |
35 |
36 | def connect_db(sql_dialect, db_path):
37 | if sql_dialect == "SQLite":
38 | conn = sqlite3.connect(db_path)
39 | elif sql_dialect == "MySQL":
40 | conn = connect_mysql()
41 | elif sql_dialect == "PostgreSQL":
42 | conn = connect_postgresql()
43 | else:
44 | raise ValueError("Unsupported SQL dialect")
45 | return conn
46 |
47 |
48 | def execute_sql(predicted_sql, ground_truth, db_path, sql_dialect, calculate_func):
49 | conn = connect_db(sql_dialect, db_path)
50 | # Connect to the database
51 | cursor = conn.cursor()
52 | cursor.execute(predicted_sql)
53 | predicted_res = cursor.fetchall()
54 | cursor.execute(ground_truth)
55 | ground_truth_res = cursor.fetchall()
56 | conn.close()
57 | res = calculate_func(predicted_res, ground_truth_res)
58 | return res
59 |
60 |
61 | def package_sqls(
62 | sql_path, db_root_path, engine, sql_dialect="SQLite", mode="gpt", data_mode="dev"
63 | ):
64 | clean_sqls = []
65 | db_path_list = []
66 | if mode == "gpt":
67 | # use chain of thought
68 | # sql_data = json.load(
69 | # open(
70 | # sql_path
71 | # + "predict_"
72 | # + data_mode
73 | # + "_"
74 | # + engine
75 | # + "_cot_"
76 | # + sql_dialect
77 | # + ".json",
78 | # "r",
79 | # )
80 | # )
81 | sql_data = json.load(
82 | open(
83 | sql_path
84 | + "predict_"
85 | + data_mode
86 | + ".json",
87 | "r",
88 | )
89 | )
90 | for _, sql_str in sql_data.items():
91 | if type(sql_str) == str:
92 | sql, db_name = sql_str.split("\t----- bird -----\t")
93 | else:
94 | sql, db_name = " ", "financial"
95 | clean_sqls.append(sql)
96 | db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
97 |
98 | elif mode == "gt":
99 | sqls = open(sql_path + data_mode + "_" + sql_dialect + "_gold.sql")
100 | sql_txt = sqls.readlines()
101 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
102 | for idx, sql_str in enumerate(sql_txt):
103 | # print(sql_str)
104 | sql, db_name = sql_str.strip().split("\t")
105 | clean_sqls.append(sql)
106 | db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
107 |
108 | return clean_sqls, db_path_list
109 |
110 |
111 | def sort_results(list_of_dicts):
112 | return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
113 |
114 |
115 | def print_data(score_lists, count_lists, metric="F1 Score"):
116 | levels = ["simple", "moderate", "challenging", "total"]
117 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
118 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
119 |
120 | print(
121 | f"====================================== {metric} ====================================="
122 | )
123 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format(metric, *score_lists))
124 |
--------------------------------------------------------------------------------
/evaluation/evaluation_ves.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import sys
3 | import json
4 | import numpy as np
5 | import argparse
6 | import multiprocessing as mp
7 | from func_timeout import func_timeout, FunctionTimedOut
8 | from evaluation_utils import (
9 | load_json,
10 | execute_sql,
11 | package_sqls,
12 | sort_results,
13 | print_data,
14 | connect_db,
15 | )
16 | import time
17 | import math
18 |
19 |
20 | def result_callback(result):
21 | exec_result.append(result)
22 |
23 |
24 | def clean_abnormal(values):  # drop timings more than 3 standard deviations from the mean
25 |     values = np.asarray(values)
26 |     processed_list = []
27 |     mean = np.mean(values, axis=0)
28 |     std = np.std(values, axis=0)
29 |     for x in values:
30 |         if x < mean + 3 * std and x > mean - 3 * std:
31 |             processed_list.append(x)
32 |     return processed_list
33 |
34 |
35 | def execute_sql(sql, db_path, sql_dialect, return_time=False):  # note: overrides the execute_sql imported from evaluation_utils with a timing-aware variant
36 | # Connect to the database
37 | conn = connect_db(sql_dialect, db_path)
38 | start_time = time.time()
39 | cursor = conn.cursor()
40 | cursor.execute(sql)
41 | res = cursor.fetchall()
42 | conn.close() # Don't forget to close the connection!
43 | exec_time = time.time() - start_time
44 | if return_time:
45 | return exec_time
46 |
47 | return res
48 |
49 |
50 | def iterated_execute_sql(
51 | predicted_sql, ground_truth, db_path, iterate_num, sql_dialect
52 | ):
53 | diff_list = []
54 | predicted_res = execute_sql(predicted_sql, db_path, sql_dialect)
55 | ground_truth_res = execute_sql(ground_truth, db_path, sql_dialect)
56 | reward = 0
57 | time_ratio = 0
58 | if set(predicted_res) == set(ground_truth_res):
59 | for _ in range(iterate_num):
60 | predicted_time = execute_sql(
61 | predicted_sql, db_path, sql_dialect, return_time=True
62 | )
63 | ground_truth_time = execute_sql(
64 | ground_truth, db_path, sql_dialect, return_time=True
65 | )
66 | diff_list.append(ground_truth_time / predicted_time)
67 | processed_diff_list = clean_abnormal(diff_list)
68 | time_ratio = sum(processed_diff_list) / len(processed_diff_list)
69 | if time_ratio == 0:
70 | reward = 0
71 | elif time_ratio >= 2:
72 | reward = 1.25
73 | elif time_ratio >= 1 and time_ratio < 2:
74 | reward = 1
75 | elif time_ratio >= 0.5 and time_ratio < 1:
76 | reward = 0.75
77 | elif time_ratio >= 0.25 and time_ratio < 0.5:
78 | reward = 0.5
79 | else:
80 | reward = 0.25
81 | # return time_ratio
82 | return reward
83 |
84 |
85 | def execute_model(
86 | predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out, sql_dialect
87 | ):
88 | try:
89 | # you can personalize the total timeout number
90 | # larger timeout leads to more stable ves
91 | # while it needs more your patience....
92 | reward = func_timeout(
93 | meta_time_out * iterate_num,
94 | iterated_execute_sql,
95 | args=(predicted_sql, ground_truth, db_place, iterate_num, sql_dialect),
96 | )
97 | except KeyboardInterrupt:
98 | sys.exit(0)
99 | except FunctionTimedOut:
100 | result = [(f"timeout",)]
101 | reward = 0
102 | except Exception as e:
103 | result = [(f"error",)] # possibly len(query) > 512 or not executable
104 | reward = 0
105 | result = {"sql_idx": idx, "reward": reward}
106 | return result
107 |
108 |
109 | def run_sqls_parallel(
110 | sqls,
111 | db_places,
112 | num_cpus=1,
113 | iterate_num=100,
114 | meta_time_out=30.0,
115 | sql_dialect="SQLite",
116 | ):
117 | pool = mp.Pool(processes=num_cpus)
118 | for i, sql_pair in enumerate(sqls):
119 | predicted_sql, ground_truth = sql_pair
120 | pool.apply_async(
121 | execute_model,
122 | args=(
123 | predicted_sql,
124 | ground_truth,
125 | db_places[i],
126 | i,
127 | iterate_num,
128 | meta_time_out,
129 | sql_dialect,
130 | ),
131 | callback=result_callback,
132 | )
133 | pool.close()
134 | pool.join()
135 |
136 |
137 | def compute_ves(exec_results):
138 | num_queries = len(exec_results)
139 | total_reward = 0
140 | count = 0
141 |
142 | for i, result in enumerate(exec_results):
143 | if result["reward"] != 0:
144 | count += 1
145 | total_reward += math.sqrt(result["reward"]) * 100
146 | ves = total_reward / num_queries
147 | return ves
148 |
149 |
150 | def compute_ves_by_diff(exec_results, diff_json_path):
151 | num_queries = len(exec_results)
152 | contents = load_json(diff_json_path)
153 | simple_results, moderate_results, challenging_results = [], [], []
154 | for i, content in enumerate(contents):
155 | if content["difficulty"] == "simple":
156 | simple_results.append(exec_results[i])
157 | if content["difficulty"] == "moderate":
158 | moderate_results.append(exec_results[i])
159 | if content["difficulty"] == "challenging":
160 | challenging_results.append(exec_results[i])
161 | simple_ves = compute_ves(simple_results)
162 | moderate_ves = compute_ves(moderate_results)
163 | challenging_ves = compute_ves(challenging_results)
164 | all_ves = compute_ves(exec_results)
165 | count_lists = [
166 | len(simple_results),
167 | len(moderate_results),
168 | len(challenging_results),
169 | num_queries,
170 | ]
171 | return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists
172 |
173 |
174 | def print_reward_category(exec_results, engine, sql_dialect):
175 | res = {
176 | "engine": engine,
177 | "sql_dialect": sql_dialect,
178 | "distribution": exec_results,
179 | }
180 | file_path = "results.json"
181 | try:
182 | with open(file_path, "r") as file:
183 | data = json.load(file)
184 | except (FileNotFoundError, json.JSONDecodeError):
185 | data = [] # Start with an empty list if file doesn't exist or is empty
186 |
187 | # Append the new data
188 | data.append(res)
189 |
190 | # Write the updated data back to the file
191 | with open(file_path, "w") as file:
192 | json.dump(data, file, indent=4)
193 |
194 |
195 | if __name__ == "__main__":
196 | # Record the start time
197 | start_time = datetime.now()
198 | print(f"Start time: {start_time}")
199 |
200 | args_parser = argparse.ArgumentParser()
201 | args_parser.add_argument(
202 | "--predicted_sql_path", type=str, required=True, default=""
203 | )
204 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
205 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
206 | args_parser.add_argument("--db_root_path", type=str, required=True, default="")
207 | args_parser.add_argument("--num_cpus", type=int, default=1)
208 | args_parser.add_argument("--meta_time_out", type=float, default=30.0)
209 | args_parser.add_argument("--mode_gt", type=str, default="gt")
210 | args_parser.add_argument("--mode_predict", type=str, default="gpt")
211 | args_parser.add_argument("--diff_json_path", type=str, default="")
212 | args_parser.add_argument("--engine", type=str, default="")
213 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite")
214 | args = args_parser.parse_args()
215 | exec_result = []
216 |
217 | pred_queries, db_paths = package_sqls(
218 | args.predicted_sql_path,
219 | args.db_root_path,
220 | args.engine,
221 | sql_dialect=args.sql_dialect,
222 | mode=args.mode_predict,
223 | data_mode=args.data_mode,
224 | )
225 | # generate ground truth sqls:
226 | gt_queries, db_paths_gt = package_sqls(
227 | args.ground_truth_path,
228 | args.db_root_path,
229 | args.engine,
230 | sql_dialect=args.sql_dialect,
231 | mode="gt",
232 | data_mode=args.data_mode,
233 | )
234 | query_pairs = list(zip(pred_queries, gt_queries))
235 | run_sqls_parallel(
236 | query_pairs,
237 | db_places=db_paths,
238 | num_cpus=args.num_cpus,
239 | meta_time_out=args.meta_time_out,
240 | sql_dialect=args.sql_dialect,
241 | )
242 | exec_result = sort_results(exec_result)
243 | # print_reward_category(exec_result, args.engine, args.sql_dialect)
244 | print("start calculate")
245 | simple_ves, moderate_ves, challenging_ves, ves, count_lists = compute_ves_by_diff(
246 | exec_result, args.diff_json_path
247 | )
248 | score_lists = [simple_ves, moderate_ves, challenging_ves, ves]
249 | print(f"VES for {args.engine} on {args.sql_dialect} set")
250 | print("start calculate")
251 | print_data(score_lists, count_lists, metric="VES")
252 | print(
253 | "==========================================================================================="
254 | )
255 | print(f"Finished VES evaluation for {args.engine} on {args.sql_dialect} set")
256 | print("\n\n")
257 |
258 | # Record the end time and calculate duration
259 | end_time = datetime.now()
260 | print(f"End time: {end_time}")
261 | duration = end_time - start_time
262 | print(f"Duration: {duration}")
263 |
264 | # Saving R-VES results in a txt file
265 | ex_result_file_path = args.predicted_sql_path + "r_ves_score.txt"
266 | ex_file_content = f"The R-VES Scores are: \nOverall R-VES: {ves} \nSimple R-VES: {simple_ves} \nModerate R-VES: {moderate_ves} \nChallenging R-VES: {challenging_ves} \n\nCount Lists: {count_lists} \n\nEvaluation Duration: {duration} \n\nMeta Time-out: {args.meta_time_out}"
267 | with open(ex_result_file_path, 'w') as f:
268 | f.write(ex_file_content)
269 |
270 | print("R-VES score is written into the r_ves_score.txt file.")
271 |
272 |
273 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import random
5 | from dotenv import load_dotenv
6 | from pipeline.Pipeline import *
7 | from utils.db_utils import *
8 | from utils.retrieval_utils import process_all_dbs
9 | from typing import Dict, Union, List, Tuple
10 |
11 | def main(args):
12 | load_dotenv() # load variables into os.environ
13 | create_result_files(args) # creating results directory for specific arguments
14 |
15 | bird_sql_path = os.getenv('BIRD_DB_PATH')
16 | args.dataset_path = bird_sql_path
17 |     process_all_dbs(bird_sql_path, args.mode) # for all databases, create db_description.csv files which include column descriptions for all tables
18 |
19 | # set random seed
20 | random.seed(args.seed)
21 |
22 | # load dataset
23 | dataset_json_path = bird_sql_path + f"/{args.mode}/{args.mode}.json"
24 | f = open(dataset_json_path)
25 | dataset = json.load(f)
26 |
27 | pipeline = Pipeline(args)
28 |
29 | output_dict = {}
30 | predictions = []
31 |     # In case of an error, you can restart the code from the point of error using the following line.
32 |     # Slice the dataset from the index of the first unprocessed question, e.g.:
33 |     # dataset = dataset[1135:]
34 |     dataset = dataset[:]
35 | for ind,t2s_object in enumerate(dataset):
36 | q_id = t2s_object["question_id"]
37 | if pipeline.pipeline_order == "CSG-SR":
38 | t2s_object_prediction = pipeline.forward_pipeline_CSG_SR(t2s_object)
39 | elif pipeline.pipeline_order == "CSG-QE-SR":
40 | t2s_object_prediction = pipeline.forward_pipeline_CSG_QE_SR(t2s_object)
41 | elif pipeline.pipeline_order == "SF-CSG-QE-SR":
42 | t2s_object_prediction = pipeline.forward_pipeline_SF_CSG_QE_SR(t2s_object)
43 | else:
44 | raise ValueError("Wrong value for pipeline_order argument. It must be either CSG-QE-SR or CSG-SR.")
45 |
46 | # Compare predicted and ground truth sqls
47 | compare_results = check_correctness(t2s_object_prediction, args)
48 | t2s_object_prediction['results'] = compare_results
49 | if os.path.exists(args.prediction_json_path):
50 | # get existing predictions
51 | with open(args.prediction_json_path, 'r') as file_read:
52 | existing_predictions = json.load(file_read)
53 |
54 | # add new prediction to the existing predictions and then write to the file
55 | existing_predictions.append(t2s_object_prediction)
56 | with open(args.prediction_json_path, 'w') as file_write:
57 | json.dump(existing_predictions, file_write, indent=4)
58 |
59 | else:
60 | file_write = open(args.prediction_json_path, 'w')
61 | existing_predictions = [t2s_object_prediction]
62 | json.dump(existing_predictions, file_write, indent=4)
63 | file_write.close()
64 |
65 | # # add the current text2sql object to the predictions
66 | # predictions.append(t2s_object_prediction)
67 | # # writing prediction to the predictions.json file
68 | # with open(args.prediction_json_path, 'w') as f:
69 | # json.dump(predictions, f, indent=4) # indent=4 for pretty printing
70 |
71 | # adding predicted sql in the expected format for the evaluation files
72 | db_id = t2s_object_prediction["db_id"]
73 | predicted_sql = t2s_object_prediction["predicted_sql"]
74 | predicted_sql = predicted_sql.replace('\"','').replace('\\\n',' ').replace('\n',' ')
75 | sql = predicted_sql + '\t----- bird -----\t' + db_id
76 | output_dict[str(q_id)] = sql
77 | if os.path.exists(args.predictions_eval_json_path):
78 | with open(args.predictions_eval_json_path, 'r') as f:
79 | contents = json.loads(f.read())
80 | else:
81 | # Initialize contents as an empty dictionary if the file doesn't exist
82 | contents = {}
83 | contents.update(output_dict)
84 | json.dump(contents, open(args.predictions_eval_json_path, 'w'), indent=4)
85 |
86 | print(f"Question with {q_id} is processed. Correctness: {compare_results['exec_res']} ")
87 |
88 |     # Calculating metrics
89 |     with open(args.prediction_json_path, 'r') as predictions_json_file:
90 |         predictions = json.load(predictions_json_file)
91 | stats, fail_q_ids = calculate_accuracies(predictions)
92 | metric_object = {
93 | "EX": stats["ex"],
94 | "total_correct_count": stats["total_correct_count"],
95 | "total_item_count": stats["total_item_count"],
96 | "simple_stats": stats["simple"],
97 | "moderate_stats": stats["moderate"],
98 | "challenging_stats": stats["challenging"],
99 | "fail_q_ids": fail_q_ids,
100 | "config": {
101 | "mode": args.mode,
102 | "model": args.model,
103 | "temperature": args.temperature,
104 | "top_p": args.top_p,
105 | "max_tokens": args.max_tokens,
106 | "n": args.n,
107 | "pipeline_order": args.pipeline_order,
108 | "enrichment_level": args.enrichment_level,
109 | "enrichment_level_shot_number": args.enrichment_level_shot_number,
110 | "enrichment_few_shot_schema_existance": args.enrichment_few_shot_schema_existance,
111 | "filtering_level_shot_number": args.filtering_level_shot_number,
112 | "filtering_few_shot_schema_existance": args.filtering_few_shot_schema_existance,
113 | "cfg": args.cfg,
114 | "generation_level_shot_number": args.generation_level_shot_number,
115 | "generation_few_shot_schema_existance": args.generation_few_shot_schema_existance,
116 | "db_sample_limit": args.db_sample_limit,
117 | "relevant_description_number": args.relevant_description_number,
118 | "seed": args.seed
119 | }
120 | }
121 |
122 | # Writing metrics to a file
123 | metrics_path = args.output_directory_path + "/metrics.json"
124 | # writing metric object
125 | with open(metrics_path, 'w') as f:
126 | json.dump(metric_object, f, indent=4) # indent=4 for pretty printing
127 |
128 | print("Metrics are written into metrics.json file.")
129 | return
130 |
131 |
132 | def calculate_accuracies(predictions: List[Dict]) -> Tuple[Dict, List]:
133 |     """
134 |     Calculates the Execution Accuracy (EX) metric and finds the question IDs whose predictions failed.
135 | 
136 |     Arguments:
137 |         predictions (List[Dict]): prediction objects, each carrying a 'results' dict and optionally a 'difficulty' level
138 |     """
139 | difficulty_flag = False
140 | failed_predictions_q_ids = []
141 | stats = {
142 | "ex": 0,
143 | "total_correct_count": 0,
144 | "total_item_count": 0,
145 | "simple": {
146 | "correct_number": 0,
147 | "count": 0
148 | },
149 | "moderate": {
150 | "correct_number": 0,
151 | "count": 0
152 | },
153 | "challenging": {
154 | "correct_number": 0,
155 | "count": 0
156 | }
157 | }
158 |
159 | # check if there is difficulty key
160 | sample = predictions[0]
161 | if "difficulty" in sample:
162 | difficulty_flag = True
163 | else:
164 | difficulty_flag = False
165 |
166 | if difficulty_flag:
167 | for q2s_object in predictions:
168 | level = q2s_object['difficulty']
169 | stats[level]["count"] = stats[level]["count"] + 1
170 |
171 | if q2s_object['results']['exec_res'] != 0:
172 | stats[level]['correct_number'] = stats[level]['correct_number'] + 1
173 | else:
174 | failed_predictions_q_ids.append(q2s_object['question_id'])
175 |
176 | stats["simple"]["ex"] = stats["simple"]["correct_number"] / stats["simple"]["count"] * 100
177 | stats["moderate"]["ex"] = stats["moderate"]["correct_number"] / stats["moderate"]["count"] * 100
178 | stats["challenging"]["ex"] = stats["challenging"]["correct_number"] / stats["challenging"]["count"] * 100
179 |
180 | stats["total_item_count"] = stats["simple"]["count"] + stats["moderate"]["count"] + stats["challenging"]["count"]
181 | stats["total_correct_count"] = stats["simple"]["correct_number"] + stats["moderate"]["correct_number"] + stats["challenging"]["correct_number"]
182 | stats["ex"] = stats["total_correct_count"] / stats["total_item_count"] * 100
183 |
184 | return (stats, failed_predictions_q_ids)
185 |
186 | else:
187 |
188 | for q2s_object in predictions:
189 | stats["total_item_count"] = stats["total_item_count"] + 1
190 | if q2s_object['results']['exec_res'] != 0:
191 | stats["total_correct_count"] = stats["total_correct_count"] + 1
192 | else:
193 | failed_predictions_q_ids.append(q2s_object['question_id'])
194 |
195 | stats["ex"] = stats["total_correct_count"] / stats["total_item_count"] * 100
196 | return (stats, failed_predictions_q_ids)
197 |
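# A minimal sanity check of calculate_accuracies, assuming the 'results'/'exec_res'
# layout produced above (synthetic predictions, illustrative only):
#
#   demo_predictions = [
#       {"question_id": 0, "results": {"exec_res": 1}},
#       {"question_id": 1, "results": {"exec_res": 0}},
#   ]
#   demo_stats, demo_failed = calculate_accuracies(demo_predictions)
#   # demo_stats["ex"] == 50.0 and demo_failed == [1]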
198 | def check_correctness(t2s_object_prediction: Dict, args) -> Dict[str, Union[int, str]]:
199 | """
200 |     The function checks whether the predicted SQL is correct or not.
201 | 
202 |     Arguments:
203 |         t2s_object_prediction (Dict): prediction object containing the db_id, the predicted SQL and the ground truth SQL
204 |
205 | Returns:
206 | compare_results (Dict[str, Union[int, str]]): Comparison results dictionary with execution result and execution error keys
207 | """
208 | db_id = t2s_object_prediction['db_id']
209 | bird_sql_path = os.getenv('BIRD_DB_PATH')
210 | db_path = bird_sql_path+ f"/{args.mode}/{args.mode}_databases/{db_id}/{db_id}.sqlite"
211 | if 'predicted_sql' in t2s_object_prediction:
212 | predicted_sql = t2s_object_prediction['predicted_sql']
213 | gt_sql = t2s_object_prediction['SQL']
214 | compare_results = compare_sqls(db_path=db_path, predicted_sql=predicted_sql, ground_truth_sql=gt_sql )
215 | else:
216 |         compare_results = {'exec_res': 0, 'exec_err': "There is no predicted SQL. There must be an error in this question while extracting information."}
217 |
218 | return compare_results
219 |
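# Usage sketch (paths are illustrative): with BIRD_DB_PATH=/data/bird and
# args.mode='dev', a prediction for db_id 'california_schools' is executed against
# /data/bird/dev/dev_databases/california_schools/california_schools.sqlite and
# compared against the ground truth SQL via compare_sqls.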
220 | def create_result_files(args):
221 | """
222 | The function creates result files according to arguments.
223 | """
224 |
225 |     # Ensure the results directory exists; otherwise create it
226 | if not os.path.exists("./results"):
227 | os.makedirs("./results")
228 |
229 | args.output_directory_path = f"./results/model_outputs_{args.mode}_{args.pipeline_order}_{args.model}"
230 |
231 | # Ensure the directory exists
232 | if not os.path.exists(args.output_directory_path):
233 | os.makedirs(args.output_directory_path)
234 |
235 | # Overall predictions file
236 | prediction_json_path = args.output_directory_path + "/predictions.json"
237 | args.prediction_json_path = prediction_json_path
238 | # print("args.prediction_json_path: ", args.prediction_json_path)
239 |
240 | # Create an empty predictions.json file if not exist
241 | if not os.path.exists(prediction_json_path):
242 | with open(args.prediction_json_path, 'w') as f:
243 |             json.dump([], f) # Initialize with an empty JSON list
244 |
245 | # predictions file for evaluation
246 | predictions_eval_json_path = args.output_directory_path + f"/predict_{args.mode}.json"
247 | args.predictions_eval_json_path = predictions_eval_json_path
248 |
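# For example, with mode='dev', pipeline_order='EFG' and
# model='gpt-4o-mini-2024-07-18' (the defaults below), the outputs land in
# ./results/model_outputs_dev_EFG_gpt-4o-mini-2024-07-18/ as predictions.json
# (the detailed per-question records) and predict_dev.json (the evaluation file).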
249 |
250 |
251 | def str2bool(v: str) -> bool:
252 | """
253 |     The function converts a string boolean to a boolean.
254 |
255 | Arguments:
256 | v (str): string boolean
257 |
258 | Returns:
259 | Bool: corresponding boolean variable
260 | """
261 | if isinstance(v, bool):
262 | return v
263 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
264 | return True
265 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
266 | return False
267 | else:
268 | raise argparse.ArgumentTypeError('Boolean value expected.')
269 |
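# For example, str2bool("Yes") and str2bool("1") return True, str2bool("no")
# returns False, and any other string raises argparse.ArgumentTypeError so that
# argparse reports a clean usage error instead of silently accepting the value.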
270 | if __name__ == '__main__':
271 | parser = argparse.ArgumentParser()
272 | # Running mode arguments
273 | parser.add_argument("--mode", default='dev', type=str, help="Either dev or test.")
274 |
275 | # Model Arguments
276 | parser.add_argument("--model", default="gpt-4o-mini-2024-07-18", type=str, help="OpenAI models.")
277 | parser.add_argument("--temperature", default=0.0, type=float, help="Sampling temperature between 0 to 2. It is recommended altering this or top_p but not both.")
278 | parser.add_argument("--top_p", default=1, type=float, help="Nucleus sampling. It is recommend altering this or temperature but not both")
279 | parser.add_argument("--max_tokens", default=2048, type=int, help="The maximum number of tokens that can be generated.")
280 | parser.add_argument("--n", default=1, type=int, help="How many chat completion choices to generate for each input message")
281 |
282 | # Pipeline Arguments
283 | parser.add_argument("-po", "--pipeline_order", default='EFG', type=str, help="The order of stages in the pipeline. It should be either EFG (enrichment --> filtering --> generation) or FEG (filtering --> enrichment --> generation)")
284 |
285 | # Question Enrichment Arguments
286 | parser.add_argument("-el", "--enrichment_level", default="complex", type=str, help="Defines the which enrichment is used in few-shot examples.It can be either basic or complex.")
287 | parser.add_argument("-elsn", "--enrichment_level_shot_number", default=3, type=int, help="The few-shot number for each difficulty level for question enrichment stage.")
288 | parser.add_argument("-efsse", "--enrichment_few_shot_schema_existance", default=False, type=str2bool, help="Database Schema usage for each few-shot examples in the question enrichment stage. Default False.")
289 |
290 | # Schema Filtering Arguments
291 | parser.add_argument("-flsn", "--filtering_level_shot_number", default=3, type=int, help="The few-shot number for each difficulty level for schema filtering stage.")
292 | parser.add_argument("-ffsse", "--filtering_few_shot_schema_existance", default=False, type=str2bool, help="Database Schema usage for each few-shot examples in the schema filtering stage. Default False.")
293 |
294 | # SQL Generation Arguments
295 | parser.add_argument("--cfg", default=True, type=str2bool, help="Whether Context-Free-Grammer or SQL Template will be used. Default is True.")
296 | parser.add_argument("-glsn", "--generation_level_shot_number", default=3, type=int, help="The few-shot number for each difficulty level for SQL generation stage.")
297 | parser.add_argument("-gfsse", "--generation_few_shot_schema_existance", default=False, type=str2bool, help="Database Schema usage for each few-shot examples in the SQL generation stage. Default False.")
298 |
299 | # db sample number
300 | parser.add_argument("--db_sample_limit", default=5, type=int, help="The number of value extracted for a column for database samples.")
301 | # question relevant database item/column description number
302 | parser.add_argument("-rdn", "--relevant_description_number", default=6, type=int, help="The number of database item/column descriptions added to a prompt.")
303 | # custom seed argument
304 | parser.add_argument("--seed", default=42, type=int, help="Random seed")
305 |
306 | args = parser.parse_args()
307 | main(args)
--------------------------------------------------------------------------------
/pipeline/Pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from utils.prompt_utils import *
4 | from utils.db_utils import *
5 | from utils.openai_utils import create_response
6 | from typing import Dict, List
7 |
8 | class Pipeline():
9 | def __init__(self, args):
10 | # Running mode attributes
11 | self.mode = args.mode
12 | self.dataset_path = args.dataset_path
13 |
14 | # Pipeline attribute
15 | self.pipeline_order = args.pipeline_order
16 |
17 | # Model attributes
18 | self.model = args.model
19 | self.temperature = args.temperature
20 | self.top_p = args.top_p
21 | self.max_tokens = args.max_tokens
22 | self.n = args.n
23 |
24 | # Stages (enrichment, filtering, generation) attributes
25 | self.enrichment_level = args.enrichment_level
26 | self.elsn = args.enrichment_level_shot_number
27 | self.efsse = args.enrichment_few_shot_schema_existance
28 |
29 | self.flsn = args.filtering_level_shot_number
30 | self.ffsse = args.filtering_few_shot_schema_existance
31 |
32 | self.cfg = args.cfg
33 | self.glsn = args.generation_level_shot_number
34 | self.gfsse = args.generation_few_shot_schema_existance
35 |
36 | self.db_sample_limit = args.db_sample_limit
37 | self.rdn = args.relevant_description_number
38 |
39 | self.seed = args.seed
40 |
41 | def convert_message_content_to_dict(self, response_object: Dict) -> Dict:
42 | """
43 |         The function takes an LLM response object and converts its message content from a JSON string into a Python dictionary.
44 |
45 | Arguments:
46 | response_object (Dict): LLM response object
47 | Returns:
48 | response_object (Dict): Response object whose content changed to dictionary
49 | """
50 |
51 | response_object.choices[0].message.content = json.loads(response_object.choices[0].message.content)
52 | return response_object
53 |
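    # For example, a content string of
    #   '{"chain_of_thought_reasoning": "...", "SQL": "SELECT ..."}'
    # becomes a dict, so later stages can read
    # response_object.choices[0].message.content['SQL']. This assumes the model
    # returned valid JSON; the modules below fall back to the raw response otherwise.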
54 |
55 | def forward_pipeline_CSG_SR(self, t2s_object: Dict) -> Dict[str, str]:
56 | """
57 |         The function runs the Candidate SQL Generation (CSG) and SQL Refinement (SR) modules in order, without any question enrichment or schema filtering stages.
58 |
59 | Arguments:
60 | t2s_object (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc.
61 | Returns:
62 |             t2s_object_prediction (Dict): Python dictionary that stores information about a question (q_id, db_id, question, evidence, etc.) and also stores the information gathered after each stage
63 | """
64 | db_id = t2s_object["db_id"]
65 | q_id = t2s_object["question_id"]
66 | evidence = t2s_object["evidence"]
67 | question = t2s_object["question"]
68 |
69 | bird_sql_path = os.getenv('BIRD_DB_PATH')
70 | db_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/{db_id}.sqlite"
71 | db_description_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/database_description"
72 | db_descriptions = question_relevant_descriptions_prep(database_description_path=db_description_path, question=question, relevant_description_number=self.rdn)
73 | database_column_meaning_path = bird_sql_path + f"/{self.mode}/column_meaning.json"
74 | db_column_meanings = db_column_meaning_prep(database_column_meaning_path, db_id)
75 | db_descriptions = db_descriptions + "\n\n" + db_column_meanings
76 |
77 | # extracting original schema dictionary
78 | original_schema_dict = get_schema_tables_and_columns_dict(db_path=db_path)
79 |
80 | t2s_object["question_enrichment"] = "No Question Enrichment"
81 |
82 | ### STAGE 1: Candidate SQL GENERATION
83 | # -- Original question is used
84 | # -- Original Schema is used
85 | sql_generation_response_obj = self.candidate_sql_generation_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions)
86 | try:
87 | possible_sql = sql_generation_response_obj.choices[0].message.content['SQL']
88 | t2s_object["candidate_sql_generation"] = {
89 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
90 | "possible_sql": possible_sql,
91 | "exec_err": "",
92 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens,
93 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens,
94 | "total_tokens": sql_generation_response_obj.usage.total_tokens,
95 | }
96 | t2s_object["possible_sql"] = possible_sql
97 | # execute SQL
98 | try:
99 |                 possible_response = func_timeout(30, execute_sql, args=(db_path, possible_sql))
100 | except FunctionTimedOut:
101 | t2s_object['candidate_sql_generation']["exec_err"] = "timeout"
102 | except Exception as e:
103 | t2s_object['candidate_sql_generation']["exec_err"] = str(e)
104 | except Exception as e:
105 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}")
106 | t2s_object["candidate_sql_generation"] = {
107 | "sql_generation_reasoning": "",
108 | "possible_sql": "",
109 | "prompt_tokens": 0,
110 | "completion_tokens": 0,
111 | "total_tokens": 0,
112 | }
113 | t2s_object["candidate_sql_generation"]["error"] = f"{e}"
114 | return t2s_object
115 |
116 | ### STAGE 2: SQL Refinement
117 | # -- Original question is used
118 | # -- Original Schema is used
119 | # -- Possible SQL is used
120 | # -- Possible Conditions is extracted from possible SQL and then used for augmentation
121 | # -- Execution Error for Possible SQL is used
122 | exec_err = t2s_object['candidate_sql_generation']["exec_err"]
123 | sql_generation_response_obj = self.sql_refinement_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions)
124 | try:
125 | predicted_sql = sql_generation_response_obj.choices[0].message.content['SQL']
126 | t2s_object["sql_refinement"] = {
127 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
128 | "predicted_sql": predicted_sql,
129 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens,
130 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens,
131 | "total_tokens": sql_generation_response_obj.usage.total_tokens,
132 | }
133 | t2s_object["predicted_sql"] = predicted_sql
134 | except Exception as e:
135 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}")
136 | t2s_object["sql_refinement"] = {
137 | "sql_generation_reasoning": "",
138 | "predicted_sql": "",
139 | "prompt_tokens": 0,
140 | "completion_tokens": 0,
141 | "total_tokens": 0,
142 | }
143 | t2s_object["sql_refinement"]["error"] = f"{e}"
144 | return t2s_object
145 |
146 | # storing the usage for one question
147 | t2s_object["total_usage"] = {
148 | "prompt_tokens": t2s_object['candidate_sql_generation']['prompt_tokens'] + t2s_object['sql_refinement']['prompt_tokens'],
149 | "completion_tokens": t2s_object['candidate_sql_generation']['completion_tokens'] + t2s_object['sql_refinement']['completion_tokens'],
150 | "total_tokens": t2s_object['candidate_sql_generation']['total_tokens'] + t2s_object['sql_refinement']['total_tokens']
151 | }
152 |
153 | t2s_object_prediction = t2s_object
154 | return t2s_object_prediction
155 |
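    # Usage sketch (field values are illustrative): for a BIRD item such as
    #   {"question_id": 1448, "db_id": "california_schools",
    #    "question": "...", "evidence": "..."}
    # forward_pipeline_CSG_SR returns the same dict augmented with the
    # 'candidate_sql_generation', 'sql_refinement', 'predicted_sql' and
    # 'total_usage' keys filled in above.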
156 | def forward_pipeline_CSG_QE_SR(self, t2s_object: Dict) -> Dict:
157 | """
158 |         The function runs the Candidate SQL Generation (CSG), Question Enrichment (QE) and SQL Refinement (SR) modules in order, without a schema filtering stage.
159 |
160 | Arguments:
161 | t2s_object (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc.
162 | Returns:
163 |             t2s_object_prediction (Dict): Python dictionary that stores information about a question (q_id, db_id, question, evidence, etc.) and also stores the information gathered after each stage
164 | """
165 | db_id = t2s_object["db_id"]
166 | q_id = t2s_object["question_id"]
167 | evidence = t2s_object["evidence"]
168 | question = t2s_object["question"]
169 |
170 | bird_sql_path = os.getenv('BIRD_DB_PATH')
171 | db_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/{db_id}.sqlite"
172 | db_description_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/database_description"
173 | db_descriptions = question_relevant_descriptions_prep(database_description_path=db_description_path, question=question, relevant_description_number=self.rdn)
174 | database_column_meaning_path = bird_sql_path + f"/{self.mode}/column_meaning.json"
175 | db_column_meanings = db_column_meaning_prep(database_column_meaning_path, db_id)
176 | db_descriptions = db_descriptions + "\n\n" + db_column_meanings
177 |
178 | # extracting original schema dictionary
179 | original_schema_dict = get_schema_tables_and_columns_dict(db_path=db_path)
180 |
181 | ### STAGE 1: Candidate SQL GENERATION
182 | # -- Original question is used
183 | # -- Original Schema is used
184 | sql_generation_response_obj = self.candidate_sql_generation_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions)
185 | try:
186 | possible_sql = sql_generation_response_obj.choices[0].message.content['SQL']
187 | t2s_object["candidate_sql_generation"] = {
188 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
189 | "possible_sql": possible_sql,
190 | "exec_err": "",
191 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens,
192 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens,
193 | "total_tokens": sql_generation_response_obj.usage.total_tokens,
194 | }
195 | t2s_object["possible_sql"] = possible_sql
196 | # execute SQL
197 | try:
198 |                 possible_response = func_timeout(30, execute_sql, args=(db_path, possible_sql))
199 | except FunctionTimedOut:
200 | t2s_object['candidate_sql_generation']["exec_err"] = "timeout"
201 | except Exception as e:
202 | t2s_object['candidate_sql_generation']["exec_err"] = str(e)
203 | except Exception as e:
204 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}")
205 | t2s_object["candidate_sql_generation"] = {
206 | "sql_generation_reasoning": "",
207 | "possible_sql": "",
208 | "prompt_tokens": 0,
209 | "completion_tokens": 0,
210 | "total_tokens": 0,
211 | }
212 | t2s_object["candidate_sql_generation"]["error"] = f"{e}"
213 | return t2s_object
214 |
215 | # Extract possible conditions dict list
216 | possible_conditions_dict_list = collect_possible_conditions(db_path=db_path, sql=possible_sql)
217 | possible_conditions = sql_possible_conditions_prep(possible_conditions_dict_list=possible_conditions_dict_list)
218 |
219 | ### STAGE 2: Question Enrichment:
220 | # -- Original question is used
221 | # -- Original schema is used
222 | # -- Possible conditions are used
223 | q_enrich_response_obj = self.question_enrichment_module(db_path=db_path, q_id=q_id, db_id=db_id, question=question, evidence=evidence, possible_conditions=possible_conditions, schema_dict=original_schema_dict, db_descriptions=db_descriptions)
224 | try:
225 | enriched_question = q_enrich_response_obj.choices[0].message.content['enriched_question']
226 | enrichment_reasoning = q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning']
227 | t2s_object["question_enrichment"] = {
228 | "enrichment_reasoning": q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
229 | "enriched_question": q_enrich_response_obj.choices[0].message.content['enriched_question'],
230 | "prompt_tokens": q_enrich_response_obj.usage.prompt_tokens,
231 | "completion_tokens": q_enrich_response_obj.usage.completion_tokens,
232 | "total_tokens": q_enrich_response_obj.usage.total_tokens,
233 | }
234 | enriched_question = question + enrichment_reasoning + enriched_question # This is added after experiment-24
235 | except Exception as e:
236 | logging.error(f"Error in reaching content from question enrichment response for question_id {q_id}: {e}")
237 | t2s_object["question_enrichment"] = {
238 | "enrichment_reasoning": "",
239 | "enriched_question": "",
240 | "prompt_tokens": 0,
241 | "completion_tokens": 0,
242 | "total_tokens": 0,
243 | }
244 | t2s_object["question_enrichment"]["error"] = f"{e}"
245 | enriched_question = question
246 |
247 | ### STAGE 3: SQL Refinement
248 | # -- Enriched question is used
249 | # -- Original Schema is used
250 | # -- Possible SQL is used
251 | # -- Possible Conditions is extracted from possible SQL and then used for augmentation
252 | # -- Execution Error for Possible SQL is used
253 | exec_err = t2s_object['candidate_sql_generation']["exec_err"]
254 | sql_generation_response_obj = self.sql_refinement_module(db_path=db_path, db_id=db_id, question=enriched_question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions)
255 | try:
256 | predicted_sql = sql_generation_response_obj.choices[0].message.content['SQL']
257 | t2s_object["sql_refinement"] = {
258 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
259 | "predicted_sql": predicted_sql,
260 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens,
261 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens,
262 | "total_tokens": sql_generation_response_obj.usage.total_tokens,
263 | }
264 | t2s_object["predicted_sql"] = predicted_sql
265 | except Exception as e:
266 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}")
267 | t2s_object["sql_refinement"] = {
268 | "sql_generation_reasoning": "",
269 | "predicted_sql": "",
270 | "prompt_tokens": 0,
271 | "completion_tokens": 0,
272 | "total_tokens": 0,
273 | }
274 | t2s_object["sql_refinement"]["error"] = f"{e}"
275 | return t2s_object
276 |
277 | # storing the usage for one question
278 | t2s_object["total_usage"] = {
279 | "prompt_tokens": t2s_object['candidate_sql_generation']['prompt_tokens'] + t2s_object['question_enrichment']['prompt_tokens'] + t2s_object['sql_refinement']['prompt_tokens'],
280 | "completion_tokens": t2s_object['candidate_sql_generation']['completion_tokens'] + t2s_object['question_enrichment']['completion_tokens'] + t2s_object['sql_refinement']['completion_tokens'],
281 | "total_tokens": t2s_object['candidate_sql_generation']['total_tokens'] + t2s_object['question_enrichment']['total_tokens'] + t2s_object['sql_refinement']['total_tokens']
282 | }
283 |
284 | t2s_object_prediction = t2s_object
285 | return t2s_object_prediction
286 |
287 |
288 | def forward_pipeline_SF_CSG_QE_SR(self, t2s_object: Dict) -> Dict:
289 | """
290 |         The function runs the Schema Filtering (SF), Candidate SQL Generation (CSG), Question Enrichment (QE) and SQL Refinement (SR) modules in order.
291 |
292 | Arguments:
293 | t2s_object (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc.
294 | Returns:
295 |             t2s_object_prediction (Dict): Python dictionary that stores information about a question (q_id, db_id, question, evidence, etc.) and also stores the information gathered after each stage
296 | """
297 | db_id = t2s_object["db_id"]
298 | q_id = t2s_object["question_id"]
299 | evidence = t2s_object["evidence"]
300 | question = t2s_object["question"]
301 |
302 | bird_sql_path = os.getenv('BIRD_DB_PATH')
303 | db_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/{db_id}.sqlite"
304 | db_description_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/database_description"
305 | db_descriptions = question_relevant_descriptions_prep(database_description_path=db_description_path, question=question, relevant_description_number=self.rdn)
306 | database_column_meaning_path = bird_sql_path + f"/{self.mode}/column_meaning.json"
307 | db_column_meanings = db_column_meaning_prep(database_column_meaning_path, db_id)
308 | db_descriptions = db_descriptions + "\n\n" + db_column_meanings
309 |
310 | # extracting original schema dictionary
311 | original_schema_dict = get_schema_tables_and_columns_dict(db_path=db_path)
312 |
313 |
314 | ### STAGE 1: FILTERING THE DATABASE SCHEMA
315 | # -- original question is used.
316 | # -- Original Schema is used.
317 | schema_filtering_response_obj = self.schema_filtering_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, schema_dict=original_schema_dict, db_descriptions=db_descriptions)
318 | # print("schema_filtering_response_obj: \n", schema_filtering_response_obj)
319 | try:
320 | t2s_object["schema_filtering"] = {
321 | "filtering_reasoning": schema_filtering_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
322 | "filtered_schema_dict": schema_filtering_response_obj.choices[0].message.content['tables_and_columns'],
323 | "prompt_tokens": schema_filtering_response_obj.usage.prompt_tokens,
324 | "completion_tokens": schema_filtering_response_obj.usage.completion_tokens,
325 | "total_tokens": schema_filtering_response_obj.usage.total_tokens,
326 | }
327 | except Exception as e:
328 | logging.error(f"Error in reaching content from schema filtering response for question_id {q_id}: {e}")
329 | t2s_object["schema_filtering"] = f"{e}"
330 | return t2s_object
331 |
332 | ### STAGE 1.1: FILTERED SCHEMA CORRECTION
333 | filtered_schema_dict = schema_filtering_response_obj.choices[0].message.content['tables_and_columns']
334 | filtered_schema_dict, filtered_schema_problems = filtered_schema_correction(db_path=db_path, filtered_schema_dict=filtered_schema_dict)
335 | t2s_object["schema_filtering_correction"] = {
336 | "filtered_schema_problems": filtered_schema_problems,
337 | "final_filtered_schema_dict": filtered_schema_dict
338 | }
339 |
340 | schema_statement = generate_schema_from_schema_dict(db_path=db_path, schema_dict=filtered_schema_dict)
341 | t2s_object["create_table_statement"] = schema_statement
342 |
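        # At this point filtered_schema_dict maps table names to the kept columns,
        # e.g. (illustrative): {"schools": ["CDSCode", "School"], "satscores": ["cds", "AvgScrMath"]}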
343 | ### STAGE 2: Candidate SQL GENERATION
344 | # -- Original question is used
345 | # -- Filtered Schema is used
346 | sql_generation_response_obj = self.candidate_sql_generation_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions)
347 | try:
348 | possible_sql = sql_generation_response_obj.choices[0].message.content['SQL']
349 | t2s_object["candidate_sql_generation"] = {
350 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
351 | "possible_sql": possible_sql,
352 | "exec_err": "",
353 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens,
354 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens,
355 | "total_tokens": sql_generation_response_obj.usage.total_tokens,
356 | }
357 | t2s_object["possible_sql"] = possible_sql
358 | # execute SQL
359 | try:
360 |                 possible_response = func_timeout(30, execute_sql, args=(db_path, possible_sql))
361 | except FunctionTimedOut:
362 | t2s_object['candidate_sql_generation']["exec_err"] = "timeout"
363 | except Exception as e:
364 | t2s_object['candidate_sql_generation']["exec_err"] = str(e)
365 | except Exception as e:
366 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}")
367 | t2s_object["candidate_sql_generation"] = {
368 | "sql_generation_reasoning": "",
369 | "possible_sql": "",
370 | "prompt_tokens": 0,
371 | "completion_tokens": 0,
372 | "total_tokens": 0,
373 | }
374 | t2s_object["candidate_sql_generation"]["error"] = f"{e}"
375 | return t2s_object
376 |
377 | # Extract possible conditions dict list
378 | possible_conditions_dict_list = collect_possible_conditions(db_path=db_path, sql=possible_sql)
379 | possible_conditions = sql_possible_conditions_prep(possible_conditions_dict_list=possible_conditions_dict_list)
380 |
381 | ### STAGE 3: Question Enrichment:
382 | # -- Original question is used
383 |         # -- Filtered schema is used
384 | # -- Possible conditions are used
385 | q_enrich_response_obj = self.question_enrichment_module(db_path=db_path, q_id=q_id, db_id=db_id, question=question, evidence=evidence, possible_conditions=possible_conditions, schema_dict=filtered_schema_dict, db_descriptions=db_descriptions)
386 | try:
387 | enriched_question = q_enrich_response_obj.choices[0].message.content['enriched_question']
388 | enrichment_reasoning = q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning']
389 | t2s_object["question_enrichment"] = {
390 | "enrichment_reasoning": q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
391 | "enriched_question": q_enrich_response_obj.choices[0].message.content['enriched_question'],
392 | "prompt_tokens": q_enrich_response_obj.usage.prompt_tokens,
393 | "completion_tokens": q_enrich_response_obj.usage.completion_tokens,
394 | "total_tokens": q_enrich_response_obj.usage.total_tokens,
395 | }
396 | enriched_question = question + enrichment_reasoning + enriched_question # This is added after experiment-24
397 | except Exception as e:
398 | logging.error(f"Error in reaching content from question enrichment response for question_id {q_id}: {e}")
399 | t2s_object["question_enrichment"] = {
400 | "enrichment_reasoning": "",
401 | "enriched_question": "",
402 | "prompt_tokens": 0,
403 | "completion_tokens": 0,
404 | "total_tokens": 0,
405 | }
406 | t2s_object["question_enrichment"]["error"] = f"{e}"
407 | enriched_question = question
408 |
409 | ### STAGE 4: SQL Refinement
410 | # -- Enriched question is used
411 |         # -- Filtered Schema is used
412 | # -- Possible SQL is used
413 | # -- Possible Conditions is extracted from possible SQL and then used for augmentation
414 | # -- Execution Error for Possible SQL is used
415 | exec_err = t2s_object['candidate_sql_generation']["exec_err"]
416 | sql_generation_response_obj = self.sql_refinement_module(db_path=db_path, db_id=db_id, question=enriched_question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions)
417 | try:
418 | predicted_sql = sql_generation_response_obj.choices[0].message.content['SQL']
419 | t2s_object["sql_refinement"] = {
420 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'],
421 | "predicted_sql": predicted_sql,
422 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens,
423 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens,
424 | "total_tokens": sql_generation_response_obj.usage.total_tokens,
425 | }
426 | t2s_object["predicted_sql"] = predicted_sql
427 | except Exception as e:
428 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}")
429 | t2s_object["sql_refinement"] = {
430 | "sql_generation_reasoning": "",
431 | "predicted_sql": "",
432 | "prompt_tokens": 0,
433 | "completion_tokens": 0,
434 | "total_tokens": 0,
435 | }
436 | t2s_object["sql_refinement"]["error"] = f"{e}"
437 | return t2s_object
438 |
439 | # storing the usage for one question
440 | t2s_object["total_usage"] = {
441 | "prompt_tokens": t2s_object['candidate_sql_generation']['prompt_tokens'] + t2s_object['question_enrichment']['prompt_tokens'] + t2s_object['sql_refinement']['prompt_tokens'],
442 | "completion_tokens": t2s_object['candidate_sql_generation']['completion_tokens'] + t2s_object['question_enrichment']['completion_tokens'] + t2s_object['sql_refinement']['completion_tokens'],
443 | "total_tokens": t2s_object['candidate_sql_generation']['total_tokens'] + t2s_object['question_enrichment']['total_tokens'] + t2s_object['sql_refinement']['total_tokens']
444 | }
445 |
446 | t2s_object_prediction = t2s_object
447 | return t2s_object_prediction
448 |
449 |
450 | def construct_question_enrichment_prompt(self, db_path: str, q_id: int, db_id: str, question: str, evidence: str, possible_conditions: str, schema_dict: Dict, db_descriptions: str) -> str:
451 | """
452 | The function constructs the prompt required for the question enrichment stage
453 |
454 | Arguments:
455 | db_path (str): path to database sqlite file
456 | q_id (int): question id
457 | db_id (str): database ID, i.e. database name
458 | question (str): Natural language question
459 | evidence (str): evidence for the question
460 | possible_conditions (str): Possible conditions extracted from the previously generated possible SQL for the question
461 | schema_dict (Dict[str, List[str]]): database schema dictionary
462 | db_descriptions (str): Question relevant database item (column) descriptions
463 |
464 | Returns:
465 | prompt (str): Question enrichment prompt
466 | """
467 | enrichment_template_path = os.path.join(os.getcwd(), "prompt_templates/question_enrichment_prompt_template.txt")
468 | question_enrichment_prompt_template = extract_question_enrichment_prompt_template(enrichment_template_path)
469 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json")
470 | q_enrich_few_shot_examples = question_enrichment_few_shot_prep(few_shot_data_path, q_id=q_id, q_db_id=db_id, level_shot_number=self.elsn, schema_existance=self.efsse, enrichment_level=self.enrichment_level, mode=self.mode)
471 | db_samples = extract_db_samples_enriched_bm25(question, evidence, db_path=db_path, schema_dict=schema_dict, sample_limit=self.db_sample_limit)
472 | schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=schema_dict)
473 | prompt = fill_question_enrichment_prompt_template(template=question_enrichment_prompt_template, schema=schema, db_samples=db_samples, question=question, possible_conditions=possible_conditions, few_shot_examples=q_enrich_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions)
474 | # print("question_enrichment_prompt: \n", prompt)
475 | return prompt
476 |
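    # Note: the template uses str.format-style placeholders ({SCHEMA}, {QUESTION},
    # ...), so fill_question_enrichment_prompt_template is presumably a thin
    # wrapper around template.format(...); the doubled {{ }} braces in the
    # template's JSON example survive formatting as literal single braces.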
477 | def question_enrichment_module(self, db_path: str, q_id: int, db_id: str, question: str, evidence: str, possible_conditions: str, schema_dict: Dict, db_descriptions: str) -> Dict:
478 | """
479 |         The function enriches the given question using the LLM.
480 |
481 | Arguments:
482 | db_path (str): path to database sqlite file
483 | q_id (int): question id
484 | db_id (str): database ID, i.e. database name
485 | question (str): Natural language question
486 | evidence (str): evidence for the question
487 | possible_conditions (str): possible conditions extracted from previously generated possible SQL for the question
488 | schema_dict (Dict[str, List[str]]): database schema dictionary
489 | db_descriptions (str): Question relevant database item (column) descriptions
490 |
491 | Returns:
492 | response_object (Dict): Response object returned by the LLM
493 | """
494 | prompt = self.construct_question_enrichment_prompt(db_path=db_path, q_id=q_id, db_id=db_id, question=question, evidence=evidence, possible_conditions=possible_conditions, schema_dict=schema_dict, db_descriptions=db_descriptions)
495 | response_object = create_response(stage="question_enrichment", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n)
496 | try:
497 | response_object = self.convert_message_content_to_dict(response_object)
498 | except:
499 | return response_object
500 |
501 | return response_object
502 |
503 | def construct_candidate_sql_generation_prompt(self, db_path: str, db_id: int, question: str, evidence: str, filtered_schema_dict: Dict, db_descriptions: str)->str:
504 | """
505 | The function constructs the prompt required for the candidate SQL generation stage.
506 |
507 | Arguments:
508 | db_path (str): The database sqlite file path.
509 | db_id (int): database ID, i.e. database name
510 | question (str): Natural language question
511 | evidence (str): evidence for the question
512 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names
513 | db_descriptions (str): Question relevant database item (column) descriptions
514 |
515 | Returns:
516 | prompt (str): prompt for SQL generation stage
517 | """
518 | sql_generation_template_path = os.path.join(os.getcwd(), "prompt_templates/candidate_sql_generation_prompt_template.txt")
519 | with open(sql_generation_template_path, 'r') as f:
520 | sql_generation_template = f.read()
521 |
522 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json")
523 | sql_generation_few_shot_examples = sql_generation_and_refinement_few_shot_prep(few_shot_data_path, q_db_id=db_id, level_shot_number=self.glsn, schema_existance=self.gfsse, mode=self.mode)
524 | db_samples = extract_db_samples_enriched_bm25(question, evidence, db_path, schema_dict=filtered_schema_dict, sample_limit=self.db_sample_limit)
525 | filtered_schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=filtered_schema_dict)
526 | prompt = fill_candidate_sql_prompt_template(template=sql_generation_template, schema=filtered_schema, db_samples=db_samples, question=question, few_shot_examples=sql_generation_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions)
527 | # print("candidate_sql_prompt: \n", prompt)
528 | return prompt
529 |
530 |
531 | def construct_sql_refinement_prompt(self, db_path: str, db_id: int, question: str, evidence: str, possible_sql: str, exec_err: str, filtered_schema_dict: Dict, db_descriptions: str)->str:
532 | """
533 | The function constructs the prompt required for the SQL refinement stage.
534 |
535 | Arguments:
536 | db_path (str): The database sqlite file path.
537 | db_id (int): database ID, i.e. database name
538 | question (str): Natural language question
539 | evidence (str): evidence for the question
540 | possible_sql (str): Previously generated possible SQL for the question
541 |             exec_err (str): Execution error obtained when the possible SQL is executed
542 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names
543 | db_descriptions (str): Question relevant database item (column) descriptions
544 |
545 | Returns:
546 | prompt (str): prompt for SQL generation stage
547 | """
548 | sql_generation_template_path = os.path.join(os.getcwd(), "prompt_templates/sql_refinement_prompt_template.txt")
549 | with open(sql_generation_template_path, 'r') as f:
550 | sql_generation_template = f.read()
551 |
552 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json")
553 | sql_generation_few_shot_examples = sql_generation_and_refinement_few_shot_prep(few_shot_data_path, q_db_id=db_id, level_shot_number=self.glsn, schema_existance=self.gfsse, mode=self.mode)
554 | possible_conditions_dict_list = collect_possible_conditions(db_path=db_path, sql=possible_sql)
555 | possible_conditions = sql_possible_conditions_prep(possible_conditions_dict_list=possible_conditions_dict_list)
556 | filtered_schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=filtered_schema_dict)
557 | prompt = fill_refinement_prompt_template(template=sql_generation_template, schema=filtered_schema, possible_conditions=possible_conditions, question=question, possible_sql=possible_sql, exec_err=exec_err, few_shot_examples=sql_generation_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions)
558 | # print("refinement_prompt: \n", prompt)
559 | return prompt
560 |
561 | def construct_filtering_prompt(self, db_path: str, db_id: str, question: str, evidence: str, schema_dict: Dict, db_descriptions: str)->str:
562 | """
563 | The function constructs the prompt required for the database schema filtering stage
564 |
565 | Arguments:
566 | db_path (str): The database sqlite file path.
567 | db_id (str): database ID, i.e. database name
568 | question (str): Natural language question
569 | evidence (str): evidence for the question
570 | schema_dict (Dict[str, List[str]]): database schema dictionary
571 | db_descriptions (str): Question relevant database item (column) descriptions
572 |
573 | Returns:
574 | prompt (str): prompt for database schema filtering stage
575 | """
576 | schema_filtering_prompt_template_path = os.path.join(os.getcwd(), "prompt_templates/schema_filter_prompt_template.txt")
577 | with open(schema_filtering_prompt_template_path, 'r') as f:
578 | schema_filtering_template = f.read()
579 |
580 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json")
581 | schema_filtering_few_shot_examples = schema_filtering_few_shot_prep(few_shot_data_path, q_db_id=db_id, level_shot_number=self.elsn, schema_existance=self.efsse, mode=self.mode)
582 | db_samples = extract_db_samples_enriched_bm25(question, evidence, db_path=db_path, schema_dict=schema_dict, sample_limit=self.db_sample_limit)
583 | schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=schema_dict)
584 | prompt = fill_prompt_template(template=schema_filtering_template, schema=schema, db_samples=db_samples, question=question, few_shot_examples=schema_filtering_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions)
585 | # print("\nSchema Filtering Prompt: \n", prompt)
586 |
587 | return prompt
588 |
589 |
590 | def candidate_sql_generation_module(self, db_path: str, db_id: int, question: str, evidence: str, filtered_schema_dict: Dict, db_descriptions: str):
591 | """
592 | This function generates candidate SQL for answering the question.
593 |
594 | Arguments:
595 | db_path (str): The database sqlite file path.
596 | db_id (int): database ID, i.e. database name
597 | question (str): Natural language question
598 | evidence (str): evidence for the question
599 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names
600 | db_descriptions (str): Question relevant database item (column) descriptions
601 |
602 | Returns:
603 | response_object (Dict): Response object returned by the LLM
604 | """
605 | prompt = self.construct_candidate_sql_generation_prompt(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions)
606 | response_object = create_response(stage="candidate_sql_generation", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n)
607 | try:
608 | response_object = self.convert_message_content_to_dict(response_object)
609 | except:
610 | return response_object
611 |
612 | return response_object
613 |
614 |
615 | def sql_refinement_module(self, db_path: str, db_id: int, question: str, evidence: str, possible_sql: str, exec_err: str, filtered_schema_dict: Dict, db_descriptions: str):
616 | """
617 | This function refines or re-generates a SQL query for answering the question.
618 |         The possible SQL query, the possible conditions generated from it, and the execution error (if one exists) are leveraged for better SQL refinement.
619 |
620 | Arguments:
621 | db_path (str): The database sqlite file path.
622 | db_id (int): database ID, i.e. database name
623 | question (str): Natural language question
624 | evidence (str): evidence for the question
625 | possible_sql (str): Previously generated possible SQL query for the question
626 |             exec_err (str): Execution error obtained when the possible SQL is executed
627 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names
628 | db_descriptions (str): Question relevant database item (column) descriptions
629 |
630 | Returns:
631 | response_object (Dict): Response object returned by the LLM
632 | """
633 | prompt = self.construct_sql_refinement_prompt(db_path=db_path, db_id=db_id, question=question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions)
634 | response_object = create_response(stage="sql_refinement", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n)
635 | try:
636 | response_object = self.convert_message_content_to_dict(response_object)
637 | except:
638 | return response_object
639 |
640 | return response_object
641 |
642 |
643 | def schema_filtering_module(self, db_path: str, db_id: str, question: str, evidence: str, schema_dict: Dict, db_descriptions: str):
644 | """
645 |         The function filters the database schema by eliminating unnecessary tables and columns.
646 |
647 | Arguments:
648 | db_path (str): The database sqlite file path.
649 | db_id (str): database ID, i.e. database name
650 | question (str): Natural language question
651 | evidence (str): evidence for the question
652 | schema_dict (Dict[str, List[str]]): database schema dictionary
653 | db_descriptions (str): Question relevant database item (column) descriptions
654 |
655 | Returns:
656 | response_object (Dict): Response object returned by the LLM
657 | """
658 | prompt = self.construct_filtering_prompt(db_path=db_path, db_id=db_id, question=question, evidence=evidence, schema_dict=schema_dict, db_descriptions=db_descriptions)
659 | response_object = create_response(stage="schema_filtering", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n)
660 | try:
661 | response_object = self.convert_message_content_to_dict(response_object)
662 | except:
663 | return response_object
664 |
665 | return response_object
666 |
667 |
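# Minimal usage sketch (assuming the argparse namespace built in main.py):
#
#   pipeline = Pipeline(args)
#   prediction = pipeline.forward_pipeline_CSG_QE_SR(t2s_object)
#
# where t2s_object is a single item of the BIRD dev/test dataset JSON.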
668 |
--------------------------------------------------------------------------------
/prompt_templates/candidate_sql_generation_prompt_template.txt:
--------------------------------------------------------------------------------
1 | ### You are an excellent data scientist. You can capture the link between the question and the corresponding database and perfectly generate a valid SQLite SQL query to answer the question. Your objective is to generate a SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, samples and evidence. This SQL generation step is essential for extracting the correct information from the database and finding the answer to the question.
2 |
3 | ### Follow the instructions below:
4 | # Step 1 - Read the Question and Evidence Carefully: Understand the primary focus and specific details of the question. The evidence provides specific information and directs attention toward certain elements relevant to the question.
5 | # Step 2 - Analyze the Database Schema: Database Column descriptions and Database Sample Values: Examine the database schema, database column descriptions and sample values. Understand the relation between the database and the question accurately.
6 | # Step 3 - Generate SQL query: Write SQLite SQL query corresponding to the given question by combining the sense of question, evidence and database items.
7 |
8 | {FEWSHOT_EXAMPLES}
9 |
10 | ### Task: Given the following question, database schema and evidence, generate SQLite SQL query in order to answer the question.
11 | ### Make sure to keep the original wording or terms from the question, evidence and database items.
12 | ### Make sure each table name and column name in the generated SQL is enclosed with backticks separately.
13 | ### Ensure the generated SQL is compatible with the database schema.
14 | ### When constructing SQL queries that require determining a maximum or minimum value, always use the `ORDER BY` clause in combination with `LIMIT 1` instead of using `MAX` or `MIN` functions in the `WHERE` clause. Especially if there is more than one table in the FROM clause, apply the `ORDER BY` clause in combination with `LIMIT 1` on the column of the joined table.
15 | ### Make sure the parentheses in the SQL are placed correctly, especially if the generated SQL includes a mathematical expression. Also, proper usage of the CAST function is important to convert the data type to REAL in mathematical expressions; be careful especially if there is division in the mathematical expressions.
16 | ### Ensure proper handling of null values by including the `IS NOT NULL` condition in SQL queries, but only in cases where null values could affect the results or cause errors, such as during division operations or when null values would lead to incorrect filtering of results. Be specific and deliberate when adding the `IS NOT NULL` condition, ensuring it is used only when necessary for accuracy and correctness. This is crucial to avoid errors and ensure accurate results. You can leverage the database sample values to check if there could be potential null values.
17 |
18 |
19 | {SCHEMA}
20 | {DB_DESCRIPTIONS}
21 | {DB_SAMPLES}
22 | {QUESTION}
23 | {EVIDENCE}
24 |
25 | ### Please respond with a JSON object structured as follows:
26 |
27 | ```json{{"chain_of_thought_reasoning": "Explanation of the logical analysis and steps that result in the final SQLite SQL query.", "SQL": "Generated SQL query as a single string"}}```
28 |
29 | Let's think step by step and generate the SQLite SQL query.
--------------------------------------------------------------------------------
/prompt_templates/question_enrichment_prompt_template.txt:
--------------------------------------------------------------------------------
1 | ### You are an excellent data scientist and can link the information between a question and the corresponding database perfectly. Your objective is to analyze the given question, corresponding database schema, database column descriptions, evidence and the possible SQL query to create a clear link between the given question and the database items, which include tables, columns and values. With the help of this link, rewrite new versions of the original question to be more related to the database items, understandable, clear, absent of irrelevant information and easier to translate into SQL queries. This question enrichment is essential for comprehending the question's intent and identifying the related database items. The process involves pinpointing the relevant database components and expanding the question to incorporate these items.
2 |
3 | ### Follow the instructions below:
4 | # Step 1 - Read the Question Carefully: Understand the primary focus and specific details of the question. Identify named entities (such as organizations, locations, etc.), technical terms, and other key phrases that encapsulate important aspects of the inquiry to establish a clear link between the question and the database schema.
5 | # Step 2 - Analyze the Database Schema: With the Database samples, examine the database schema to identify relevant tables, columns, and values that are pertinent to the question. Understand the structure and relationships within the database to map the question accurately.
6 | # Step 3 - Review the Database Column Descriptions: The database column descriptions give the detailed information about some of the columns of the tables in the database. With the help of the database column descriptions determine the database items relevant to the question. Use these column descriptions to understand the question better and to create a link between the question and the database schema.
7 | # Step 4 - Analyze and Observe The Database Sample Values: Examine the sample values from the database to analyze the distinct elements within each column of the tables. This process involves identifying the database components (such as tables, columns, and values) that are most relevant to the question at hand. Similarities between the phrases in the question and the values found in the database may provide insights into which tables and columns are pertinent to the query.
8 | # Step 5 - Review the Evidence: The evidence provides specific information and directs attention toward certain elements relevant to the question and its answer. Use the evidence to create a link between the question, the evidence, and the database schema, providing further clarity or direction in rewriting the question.
9 | # Step 6 - Analyze the Possible SQL Conditions: Analyze the given possible SQL conditions that are relevant to the question and identify the relation between them and the question components, phrases and keywords.
10 | # Step 7 - Identify Relevant Database Components: Pinpoint the tables, columns, and values in the database that are directly related to the question.
11 | # Step 8 - Rewrite the Question: Expand and refine the original question in detail to incorporate the identified database items (tables, columns and values) and conditions. Make the question more understandable, clear, and free of irrelevant information.
12 |
13 | {FEWSHOT_EXAMPLES}
14 |
15 | ### Task: Given the following question, database schema, database column descriptions, database samples and evidence, expand the original question in detail to incorporate the identified database components and SQL steps like examples given above. Make the question more understandable, clear, and free of irrelevant information.
16 | ### Ensure that question is expanded with original database items. Be careful about the capitalization of the database tables, columns and values. Use tables and columns in database schema.
17 |
18 | {SCHEMA}
19 | {DB_DESCRIPTIONS}
20 | {DB_SAMPLES}
21 | {POSSIBLE_CONDITIONS}
22 | {QUESTION}
23 | {EVIDENCE}
24 |
25 |
26 | ### Please respond with a JSON object structured as follows:
27 |
28 | ```json{{"chain_of_thought_reasoning": "Detail explanation of the logical analysis that led to the refined question, considering detailed possible sql generation steps", "enriched_question": "Expanded and refined question which is more understandable, clear and free of irrelevant information."}}```
29 |
30 | Let's think step by step and refine the given question, capturing the essence of the question, database schema, database descriptions, evidence and possible SQL conditions through the links between them. If you do the task correctly, I will give you 1 million dollars. Only output a json as your response.
31 |
--------------------------------------------------------------------------------
/prompt_templates/schema_filter_prompt_template.txt:
--------------------------------------------------------------------------------
1 | ### You are an excellent data scientist. You can capture the link between a question and the corresponding database and determine the useful database items (tables and columns) perfectly. Your objective is to analyze and understand the essence of the given question, corresponding database schema, database column descriptions, samples and evidence, and then select the useful database items such as tables and columns. This database item filtering is essential for eliminating unnecessary information in the database so that the corresponding structured query language (SQL) query for the question can be generated correctly in later steps.
2 |
3 | ### Follow the instructions below step by step:
4 | # Step 1 - Read the Question Carefully: Understand the primary focus and specific details of the question. Identify named entities (such as organizations, locations, etc.), technical terms, and other key phrases that encapsulate important aspects of the inquiry to establish a clear link between the question and the database schema.
5 | # Step 2 - Analyze the Database Schema: With the database samples, examine the database schema to identify relevant tables, columns, and values that are pertinent to the question. Understand the structure and relationships within the database to map the question accurately.
6 | # Step 3 - Review the Database Column Descriptions: The database column descriptions give the detailed information about some of the columns of the tables in the database. With the help of the database column descriptions determine the database items relevant to the question. Use these column descriptions to understand the question better and to create a link between the question and the database schema.
7 | # Step 4 - Analyze and Observe The Database Sample Values: Examine the sample values from the database to analyze the distinct elements within each column of the tables. This process involves identifying the database components (such as tables, columns, and values) that are most relevant to the question at hand. Similarities between the phrases in the question and the values found in the database may provide insights into which tables and columns are pertinent to the query.
8 | # Step 5 - Review the Evidence: The evidence provides specific information and directs attention toward certain elements relevant to the question and its answer. Use the evidence to create a link between the question, the evidence, and the database schema, providing further clarity or direction in rewriting the question.
9 | # Step 6 - Identify Relevant Database Components: Pinpoint the tables, columns, and values in the database that are directly related to the question. Ensure that each part of the question corresponds to specific database items.
10 | # Step 7 - Select Useful Database Tables and Columns: Select only the useful database tables and columns of selected tables by fusing the detailed information, key points of the question, database schema and evidence.
11 |
12 | {FEWSHOT_EXAMPLES}
13 |
14 | ### Task: Given the following question, database schema, database column descriptions and evidence, select only the necessary and useful database tables, and necessary and useful columns of selected tables to filter the database items.
15 | ### Make sure to keep the original terms from database items.
16 | ### Make sure the selected columns belong to the correct database table in your response.
17 |
18 | {SCHEMA}
19 | {DB_DESCRIPTIONS}
20 | {DB_SAMPLES}
21 | {QUESTION}
22 | {EVIDENCE}
23 |
24 | ### Please respond with a JSON object structured as follows:
25 |
26 | ```json{{"chain_of_thought_reasoning": "Explanation of the logical analysis that led to the selected useful database items.", "tables_and_columns": {{"table_name1": ["column1", "column2", ...], "table_name2": ["column1", ...], ...}} }}```
27 |
28 | Let's think step by step and select only the necessary and useful database tables, and select only the necessary and useful columns of the selected tables to filter the database items. If you do the task correctly, I will give you 1 million dollars. Only output a JSON as your response.
--------------------------------------------------------------------------------
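The doubled braces in the JSON specification above suggest the template is filled with `str.format`. A minimal sketch under that assumption, with an invented schema, question and model reply:

```python
import json

with open("prompt_templates/schema_filter_prompt_template.txt") as f:
    template = f.read()

prompt = template.format(
    FEWSHOT_EXAMPLES="",  # optional few-shot block
    SCHEMA="CREATE TABLE schools (\nCDSCode TEXT primary key,\nCounty TEXT,\n)",
    DB_DESCRIPTIONS="",
    DB_SAMPLES="",
    QUESTION="### Question: How many schools are in Alameda County?",
    EVIDENCE="",
)

# Suppose `raw` is the model's JSON reply:
raw = '{"chain_of_thought_reasoning": "...", "tables_and_columns": {"schools": ["CDSCode", "County"]}}'
filtered_schema_dict = json.loads(raw)["tables_and_columns"]
# This Dict[str, List[str]] shape is what utils/db_utils.py helpers such as
# filtered_schema_correction and generate_schema_from_schema_dict consume.
```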
/prompt_templates/sql_refinement_prompt_template.txt:
--------------------------------------------------------------------------------
1 | ### You are an excellent data scientist. You can capture the link between the question and the corresponding database and perfectly generate a valid SQLite SQL query to answer the question. Your objective is to generate a SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, evidence, possible SQL and possible conditions. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question.
2 |
3 | ### Follow the instructions below:
4 | # Step 1 - Read the Question and Evidence: Understand the primary focus and specific details of the question. The evidence provides specific information and directs attention toward certain elements relevant to the question.
5 | # Step 2 - Analyze the Database Schema and Database Column Descriptions: Examine the database schema and the database column descriptions, which provide information about the database columns. Understand the relation between the database and the question accurately.
6 | # Step 3 - Analyze the Possible SQL Query: Analyze the possible SQLite SQL query and identify possible mistakes that lead to an incorrect result, such as missing or wrong conditions, wrong functions, misuse of aggregate functions, wrong SQL syntax, unrecognized tokens or ambiguous columns.
7 | # Step 4 - Investigate Possible Conditions and Execution Errors: Carefully consider the list of possible conditions, which are completely compatible with the database schema and given in the form of .. The list of possible conditions helps you to find and generate correct SQL conditions that are relevant to the question. If the given possible SQL query produces an execution error, the error will be given. Analyze the execution error, understand its cause, and correct it.
8 | # Step 5 - Finalize the SQL query: Construct a correct SQLite SQL query or improve the possible SQLite SQL query corresponding to the given question by combining the sense of the question, evidence, and possible conditions.
9 | # Step 6 - Validation and Syntax Check: Before finalizing, verify that the generated SQL query is coherent with the database schema, all referenced columns exist in the referenced table, all joins are correctly formulated, the aggregation logic is accurate, and the SQL syntax is correct.
10 |
11 | ### Task: Given the following question, database schema and descriptions, evidence, possible SQL query and possible conditions, finalize the SQLite SQL query in order to answer the question.
12 | ### Ensure that the SQL query accurately reflects the relationships between tables, using appropriate join conditions to combine data where necessary.
13 | ### When using aggregate functions (e.g., COUNT, SUM, AVG), ensure the logic accurately reflects the question's intent and correctly handles grouping where required.
14 | ### Double-check that all WHERE clauses accurately represent the conditions needed to filter the data as per the question's requirements.
15 | ### Make sure to keep the original wording or terms from the question, evidence and database items.
16 | ### Make sure each table name and column name in the generated SQL is enclosed in backticks separately.
17 | ### Be careful about the capitalization of the database tables, columns and values. Use the tables and columns in the database schema. If a specific condition from the given possible conditions is used, then make sure that you use exactly the same condition (table, column and value).
18 | ### When constructing SQL queries that require determining a maximum or minimum value, always use the `ORDER BY` clause in combination with `LIMIT 1` instead of using `MAX` or `MIN` functions in the `WHERE` clause. Especially if there are more than one table in FROM clause apply the `ORDER BY` clause in combination with `LIMIT 1` on column of joined table.
19 | ### Make sure the parentheses in the SQL are placed correctly, especially if the generated SQL includes a mathematical expression. Also, proper usage of the CAST function is important for converting data types to REAL in mathematical expressions; be especially careful if there is division in the mathematical expressions.
20 | ### Ensure proper handling of null values by including the `IS NOT NULL` condition in SQL queries, but only in cases where null values could affect the results or cause errors, such as during division operations or when null values would lead to incorrect filtering of results. Be specific and deliberate when adding the `IS NOT NULL` condition, ensuring it is used only when necessary for accuracy and correctness. This is crucial to avoid errors and ensure accurate results.
21 |
22 |
23 |
24 | {SCHEMA}
25 | {DB_DESCRIPTIONS}
26 | {QUESTION}
27 | {EVIDENCE}
28 | {POSSIBLE_CONDITIONS}
29 | {POSSIBLE_SQL_Query}
30 | {EXECUTION_ERROR}
31 |
32 | ### Please respond with a JSON object structured as follows:
33 |
34 | ```json{{"chain_of_thought_reasoning": "Explanation of the logical analysis and steps that result in the final SQLite SQL query.", "SQL": "Finalized SQL query as a single string"}}```
35 |
36 | Let's think step by step and generate the SQLite SQL query.
--------------------------------------------------------------------------------
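A small sketch of the refinement loop this template implies: the model's "SQL" field is executed, and any exception can be fed back through the {EXECUTION_ERROR} placeholder on a retry. The reply string and database path are invented:

```python
import json
from utils.db_utils import execute_sql

# Hypothetical reply from the refinement call.
raw = '{"chain_of_thought_reasoning": "...", "SQL": "SELECT COUNT(*) FROM `schools`"}'
refined_sql = json.loads(raw)["SQL"]

try:
    rows = execute_sql("path/to/some.sqlite", refined_sql)  # invented path
except Exception as err:
    execution_error = str(err)  # would fill {EXECUTION_ERROR} on the next attempt
```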
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | python-dotenv==1.0.1
3 | sqlglot==25.7.1
4 | rank-bm25==0.2.2
5 | nltk
6 | func_timeout==4.3.5
7 | openai==1.37.1
8 | numpy
9 | pandas
10 | sentencepiece
11 | psycopg2-binary
--------------------------------------------------------------------------------
/run_evaluation.sh:
--------------------------------------------------------------------------------
1 | db_root_path='../dataset/bird-sql/dev/dev_databases/'
2 | data_mode='dev' # dev, train, mini_dev
3 | diff_json_path='../dataset/bird-sql/dev/dev.json' # _sqlite.json, _mysql.json, _postgresql.json
4 | # Path where the predicted SQL queries are stored
5 | predicted_sql_path='./results/model_outputs_dev_SF-CSG-QE-SR_gpt-4o-mini-2024-07-18/'
6 |
7 | ground_truth_path='../dataset/bird-sql/dev/'
8 | num_cpus=72
9 | meta_time_out=60.0
10 | mode_gt='gt'
11 | mode_predict='gpt'
12 |
13 | # Choose the engine to run, e.g. gpt-4, gpt-4-32k, gpt-4-turbo, gpt-35-turbo, GPT35-turbo-instruct
14 | engine='gpt-4o'
15 |
16 |
17 | # Choose the SQL dialect to run, e.g. SQLite, MySQL, PostgreSQL
18 | # PLEASE NOTE: You have to setup the database information in evaluation_utils.py
19 | # if you want to run the evaluation script using MySQL or PostgreSQL
20 | sql_dialect='SQLite'
21 |
22 | echo "starting to compare with knowledge for soft-f1 engine: ${engine} sql_dialect: ${sql_dialect} meta_time_out: ${meta_time_out}"
23 | python3 -u ./evaluation/evaluation_f1.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \
24 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \
25 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --engine ${engine} --sql_dialect ${sql_dialect}
26 |
27 | echo "starting to compare with knowledge for ex engine: ${engine} sql_dialect: ${sql_dialect} meta_time_out: ${meta_time_out}"
28 | python3 -u ./evaluation/evaluation_ex.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \
29 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \
30 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --engine ${engine} --sql_dialect ${sql_dialect}
31 |
32 |
33 | echo "starting to compare with knowledge for ves engine: ${engine} sql_dialect: ${sql_dialect} meta_time_out: ${meta_time_out}"
34 | python3 -u ./evaluation/evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \
35 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \
36 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --engine ${engine} --sql_dialect ${sql_dialect}
37 |
38 |
--------------------------------------------------------------------------------
/run_main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set default values for the arguments
4 | mode="dev"
5 | model="gpt-4o-mini-2024-07-18" # For GPT-4o use "gpt-4o-2024-08-06". For GPT-4o mini use "gpt-4o-mini-2024-07-18"
6 | pipeline_order="SF-CSG-QE-SR" # First set CSG-QE-SR and then set CSG-SR
7 |
8 | # DO NOT CHANGE THE FOLLOWING ARGUMENTS
9 | temperature=0.0
10 | top_p=1.0
11 | max_tokens=4096
12 | n=1
13 | enrichment_level="complex"
14 | enrichment_level_shot_number=3
15 | enrichment_few_shot_schema_existance=False
16 | filtering_level_shot_number=3
17 | filtering_few_shot_schema_existance=False
18 | cfg=True
19 | generation_level_shot_number=3
20 | generation_few_shot_schema_existance=False
21 | db_sample_limit=10
22 | relevant_description_number=20
23 | seed=42
24 |
25 | # Parse command line arguments to override default values
26 | while [ "$#" -gt 0 ]; do
27 | case $1 in
28 | --mode) mode="$2"; shift ;;
29 | --model) model="$2"; shift ;;
30 | --temperature) temperature="$2"; shift ;;
31 | --top_p) top_p="$2"; shift ;;
32 | --max_tokens) max_tokens="$2"; shift ;;
33 | --n) n="$2"; shift ;;
34 | --pipeline_order) pipeline_order="$2"; shift ;;
35 | --enrichment_level) enrichment_level="$2"; shift ;;
36 | --enrichment_level_shot_number) enrichment_level_shot_number="$2"; shift ;;
37 | --enrichment_few_shot_schema_existance) enrichment_few_shot_schema_existance="$2"; shift ;;
38 | --filtering_level_shot_number) filtering_level_shot_number="$2"; shift ;;
39 | --filtering_few_shot_schema_existance) filtering_few_shot_schema_existance="$2"; shift ;;
40 | --cfg) cfg="$2"; shift ;;
41 | --generation_level_shot_number) generation_level_shot_number="$2"; shift ;;
42 | --generation_few_shot_schema_existance) generation_few_shot_schema_existance="$2"; shift ;;
43 | --db_sample_limit) db_sample_limit="$2"; shift ;;
44 | --relevant_description_number) relevant_description_number="$2"; shift ;;
45 | --seed) seed="$2"; shift ;;
46 | *) echo "Unknown parameter passed: $1"; exit 1 ;;
47 | esac
48 | shift
49 | done
50 |
51 | # Run the Python script with the provided arguments
52 | python main.py \
53 | --mode "$mode" \
54 | --model "$model" \
55 | --temperature "$temperature" \
56 | --top_p "$top_p" \
57 | --max_tokens "$max_tokens" \
58 | --n "$n" \
59 | --pipeline_order "$pipeline_order" \
60 | --enrichment_level "$enrichment_level" \
61 | --enrichment_level_shot_number "$enrichment_level_shot_number" \
62 | --enrichment_few_shot_schema_existance "$enrichment_few_shot_schema_existance" \
63 | --filtering_level_shot_number "$filtering_level_shot_number" \
64 | --filtering_few_shot_schema_existance "$filtering_few_shot_schema_existance" \
65 | --cfg "$cfg" \
66 | --generation_level_shot_number "$generation_level_shot_number" \
67 | --generation_few_shot_schema_existance "$generation_few_shot_schema_existance" \
68 | --db_sample_limit "$db_sample_limit" \
69 | --relevant_description_number "$relevant_description_number" \
70 | --seed "$seed"
71 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/db_utils.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | import random
3 | import logging
4 | import re
5 | import time
6 | import nltk
7 | from nltk.tokenize import word_tokenize
8 | import difflib
9 | from rank_bm25 import BM25Okapi
10 | import sqlglot
11 | from sqlglot import parse, parse_one, expressions
12 | from sqlglot.optimizer.qualify import qualify
13 | from sqlglot.optimizer.qualify_columns import qualify_columns
14 | from sqlglot.expressions import Select
15 | from func_timeout import func_timeout, FunctionTimedOut
16 | from typing import Any, Union, List, Dict, Optional, Tuple
17 | nltk.download('punkt')
18 |
19 | def execute_sql(db_path: str, sql: str, fetch: Union[str, int] = "all") -> Any:
20 | """
21 | Executes an SQL query on a database and fetches results.
22 |
23 | Arguments:
24 | db_path (str): The database sqlite file path.
25 | sql (str): The SQL query to execute.
26 | fetch (Union[str, int]): How to fetch the results. Options are "all", "one", "random", or an integer.
27 |
28 | Returns:
29 | results: SQL execution results.
30 | """
31 | try:
32 | with sqlite3.connect(db_path) as conn:
33 | cursor = conn.cursor()
34 | cursor.execute(sql)
35 | if fetch == "all":
36 | return cursor.fetchall()
37 | elif fetch == "one":
38 | return cursor.fetchone()
39 | elif fetch == "random":
40 | samples = cursor.fetchmany(10)
41 | return random.choice(samples) if samples else []
42 | elif isinstance(fetch, int):
43 | return cursor.fetchmany(fetch)
44 | else:
45 | raise ValueError("Invalid fetch argument. Must be 'all', 'one', 'random', or an integer.")
46 | except Exception as e:
47 | logging.error(f"Error in execute_sql: {e}\n db_path: {db_path}\n SQL: {sql}")
48 | raise e
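# Usage sketch (hypothetical database path) illustrating the fetch modes above:
#   execute_sql("dev.sqlite", "SELECT name FROM sqlite_master", fetch="all")  # every row
#   execute_sql("dev.sqlite", "SELECT name FROM sqlite_master", fetch="one")  # first row only
#   execute_sql("dev.sqlite", "SELECT name FROM sqlite_master", fetch=5)      # first 5 rows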
49 |
50 | def get_db_tables(db_path):
51 | """
52 | The function extracts tables in a specific database.
53 |
54 | Arguments:
55 | db_path (str): The database sqlite file path.
56 | Returns:
57 | db_tables (List[str]): Names of the tables in the database as a list of strings
58 | """
59 | # Execute query to extract the names of all tables in the database
60 | try:
61 |
62 | tables = execute_sql(db_path, "SELECT name FROM sqlite_master WHERE type='table';")
63 | db_tables = [table_name_tuple[0] for table_name_tuple in tables if table_name_tuple[0] != "sqlite_sequence"]
64 | return db_tables
65 | except Exception as e:
66 | logging.error(f"Error in get_db_tables: {e}")
67 | raise e
68 |
69 | def get_db_colums_of_table(db_path: str, table_name: str) -> List[str]:
70 | """
71 | The function extracts all columns of the table whose name is given
72 |
73 | Args:
74 | db_path (str): The database sqlite file path.
75 | table_name (str): The name of the table in the database whose columns are extracted.
76 |
77 | Returns:
78 | columns_of_table (List[str]): A list of column names.
79 | """
80 | try:
81 | table_info_rows = execute_sql(db_path, f"PRAGMA table_info(`{table_name}`);")
82 | columns_of_table = [row[1] for row in table_info_rows]
83 | # data_type_of_columns_of_table = [row[2] for row in table_info_rows]
84 | return columns_of_table
85 | except Exception as e:
86 | logging.error(f"Error in get_db_colums_of_table: {e}\nTable: {table_name}")
87 | raise e
88 |
89 | def isTableInDB(db_path: str, table_name: str) -> bool:
90 | """
91 | The function checks whether the given table name is in the database
92 |
93 | Arguments:
94 | db_path (str): The database sqlite file path.
95 | table_name (str): the name of the table that is going to be checked
96 |
97 | Returns:
98 | bool: True if the table is in the database, otherwise returns False
99 | """
100 |
101 | db_tables = get_db_tables(db_path)
102 | if table_name in db_tables:
103 | return True
104 | else:
105 | return False
106 |
107 | def isColumnInTable(db_path: str, table_name: str, column_name: str) -> bool:
108 | """
109 | The function checks whether the given column name is among the columns of the given table
110 |
111 | Arguments:
112 | db_path (str): The database sqlite file path
113 | table_name (str): the name of the table
114 | column_name (str): the name of the column that is going to be checked
115 |
116 | Returns:
117 | bool: True if the given column is among the columns of given table, otherwise returns False
118 | """
119 |
120 | columns_of_table = get_db_colums_of_table(db_path, table_name)
121 | if column_name in columns_of_table:
122 | return True
123 | else:
124 | return False
125 |
126 | def get_original_schema(db_path: str) -> str:
127 | """
128 | The function constructs the database schema from the database sqlite file.
129 |
130 | Arguments:
131 | db_path (str): The database sqlite file path.
132 | Returns:
133 | db_schema (str): database schema constructed by CREATE TABLE statements
134 | """
135 | # Connecting to the sqlite database
136 | conn = sqlite3.connect(db_path)
137 | cursor = conn.cursor()
138 |
139 | # Query to extract the names of all tables in the database
140 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
141 | tables = cursor.fetchall()
142 |
143 | # Dictionary to hold table names and their CREATE TABLE statements
144 | db_schema_dict = {}
145 |
146 | for table_name_tuple in tables:
147 | table_name = table_name_tuple[0] # Extracting the table name from the tuple
148 | cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';")
149 | create_statement = cursor.fetchone()[0]
150 | db_schema_dict[table_name] = create_statement
151 |
152 | # Close the connection
153 | cursor.close()
154 | conn.close()
155 |
156 | db_schema = " "
157 | for table, create_table_statement in db_schema_dict.items():
158 | db_schema = db_schema + create_table_statement + "\n"
159 |
160 | return db_schema
161 |
162 |
163 | def clean_db_schema(db_schema: str) -> str:
164 | """
165 | The function cleans the database schema by removing unnecessary whitespace and newlines,
166 | ensuring that each column definition or constraint in a table is described on a single line.
167 |
168 | Arguments:
169 | db_schema (str): original database schema
170 | Returns:
171 | cleaned_db_schema (str): cleaned database schema
172 | """
173 | # Split the schema into lines
174 | lines = db_schema.split('\n')
175 | cleaned_lines = []
176 | current_statement = []
177 |
178 | for index in range(len(lines)):
179 | line = lines[index]
180 | line = line.strip() # Trim any leading/trailing whitespace
181 | if not line:
182 | continue # Skip empty lines
183 | if "CREATE TABLE" in line:
184 | line = line + " (" # append '(' to the lines containing CREATE TABLE
185 | cleaned_lines.append(line)
186 | continue
187 | if line[0] == '(':
188 | continue # Skip lines starting with '('
189 | if line[0] == ')':
190 | cleaned_lines.append(line)
191 | continue
192 | if "primary key" in line.lower():
193 | cleaned_lines[-1] = cleaned_lines[-1] + 'primary key,' # if the current line is PK, add it to the previous line
194 | continue
195 |
196 | line = line.replace('AUTOINCREMENT', '')
197 | line = line.replace('DEFAULT 0', '')
198 | line = line.replace('NOT NULL', '')
199 | line = line.replace('NULL', '')
200 | line = line.replace('UNIQUE', '')
201 | line = line.replace('ON UPDATE', '')
202 | line = line.replace('ON DELETE', '')
203 | line = line.replace('CASCADE', '')
204 |
205 | line = line.replace('autoincrement', '')
206 | line = line.replace('default 0', '')
207 | line = line.replace('not null', '')
208 | line = line.replace('null', '')
209 | line = line.replace('unique', '')
210 | line = line.replace('on update', '')
211 | line = line.replace('on delete', '')
212 | line = line.replace('cascade', '')
213 |
214 | # Remove space before commas
215 | line = re.sub(r'\s*,', ',', line)
216 |
217 | # Ensure one space between column names and their data types
218 | # Handling multi-word column names enclosed in backticks
219 | line = re.sub(r'`([^`]+)`\s+(\w+)', r'`\1` \2', line)
220 | # Handling standard column names
221 | line = re.sub(r'(\w+)\s+(\w+)', r'\1 \2', line)
222 |
223 | cleaned_lines.append(line)
224 | # Join all cleaned lines into a single string
225 | cleaned_db_schema = '\n'.join(cleaned_lines)
226 | return cleaned_db_schema
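# Illustration of the cleaning on an assumed input line:
#   '    `School Name`    TEXT    NOT NULL,'  becomes  '`School Name` TEXT,'
# and any standalone 'primary key' line is folded into the previous column definition as 'primary key,'.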
227 |
228 | def get_schema(db_path: str) -> str:
229 | """
230 | The function returns cleaned database schema from the database sqlite file.
231 |
232 | Arguments:
233 | db_path (str): The database sqlite file path.
234 | Returns:
235 | db_schema (str): cleaned database schema constructed by CREATE TABLE statements
236 | """
237 | original_db_schema = get_original_schema(db_path)
238 | db_schema = clean_db_schema(original_db_schema)
239 | return db_schema
240 |
241 | def get_schema_dict(db_path: str) -> Dict:
242 | """
243 | The function constructs the database schema from the database sqlite file in the form of a dict.
244 |
245 | Arguments:
246 | db_path (str): The database sqlite file path.
247 | Returns:
248 | db_schema_dict (Dict[str, Dict[str, str]]): database schema dictionary whose keys are table names and whose values are dicts with column names as keys and data types as values.
249 | """
250 | # Connecting to the sqlite database
251 | conn = sqlite3.connect(db_path)
252 | cursor = conn.cursor()
253 |
254 | # Query to extract the names of all tables in the database
255 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
256 | tables = cursor.fetchall()
257 | table_names = [table[0] for table in tables]
258 |
259 | # Dictionary to hold table names and their CREATE TABLE statements
260 | db_schema_dict = {}
261 |
262 | for table_name in table_names:
263 | cursor.execute(f"PRAGMA table_info(`{table_name}`);")
264 | table_info = cursor.fetchall() # in table_info, each row indicates (cid, column name, type, notnull, default value, is_PK)
265 | # print(f"Table {table_name} info: \n", table_info)
266 | db_schema_dict[table_name] = {col_item[1]: col_item[2] for col_item in table_info}
267 |
268 | # Close the connection
269 | cursor.close()
270 | conn.close()
271 |
272 | return db_schema_dict
273 |
274 | def get_schema_tables_and_columns_dict(db_path: str) -> Dict:
275 | """
276 | The function constructs the database schema from the database sqlite file in the form of a dict.
277 |
278 | Arguments:
279 | db_path (str): The database sqlite file path.
280 | Returns:
281 | db_schema_dict (Dict[str, List[str]]): database schema dictionary whose keys are table names and values are list of column names.
282 | """
283 | # Connecting to the sqlite database
284 | conn = sqlite3.connect(db_path)
285 | cursor = conn.cursor()
286 |
287 | # Query to extract the names of all tables in the database
288 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
289 | tables = cursor.fetchall()
290 | table_names = [table[0] for table in tables]
291 |
292 | # Dictionary to hold table names and their CREATE TABLE statements
293 | db_schema_dict = {}
294 |
295 | for table_name in table_names:
296 | cursor.execute(f"PRAGMA table_info(`{table_name}`);")
297 | table_info = cursor.fetchall() # in table_info, each row indicates (cid, column name, type, notnull, default value, is_PK)
298 | # print(f"Table {table_name} info: \n", table_info)
299 | db_schema_dict[table_name] = [col_item[1] for col_item in table_info]
300 |
301 | # Close the connection
302 | cursor.close()
303 | conn.close()
304 |
305 | return db_schema_dict
306 |
307 | def clean_sql(sql: str) -> str:
308 | """
309 | The function removes unwanted whitespace and characters in the given SQL statement
310 |
311 | Arguments:
312 | sql (str): The SQL query.
313 | Returns:
314 | clean_sql (str): Clean SQL statement.
315 | """
316 | # clean_sql= sql.replace('\n', ' ').replace('"', "'").strip("`.")
317 | clean_sql = sql.replace('\n', ' ').replace('"', "`")
318 | return clean_sql
319 |
320 |
321 |
322 | def compare_sqls_outcomes(db_path: str, predicted_sql: str, ground_truth_sql: str) -> int:
323 | """
324 | Compares the results of two SQL queries to check for equivalence.
325 |
326 | Args:
327 | db_path (str): The database sqlite file path.
328 | predicted_sql (str): The predicted SQL query.
329 | ground_truth_sql (str): The ground truth SQL query.
330 |
331 | Returns:
332 | int: 1 if the outcomes are equivalent, 0 otherwise.
333 |
334 | Raises:
335 | Exception: If an error occurs during SQL execution.
336 | """
337 | try:
338 | predicted_res = execute_sql(db_path, predicted_sql)
339 | ground_truth_res = execute_sql(db_path, ground_truth_sql)
340 | return int(set(predicted_res) == set(ground_truth_res))
341 | except Exception as e:
342 | logging.critical(f"Error comparing SQL outcomes: {e}")
343 | raise e
344 |
345 |
346 | def compare_sqls(db_path: str, predicted_sql: str, ground_truth_sql: str, meta_time_out: int = 30) -> Dict[str, Union[int, str]]:
347 | """
348 | Compares predicted SQL with ground truth SQL within a timeout.
349 |
350 | Arguments:
351 | db_path (str): The database sqlite file path.
352 | predicted_sql (str): The predicted SQL query.
353 | ground_truth_sql (str): The ground truth SQL query.
354 | meta_time_out (int): The timeout for the comparison.
355 |
356 | Returns:
357 | dict: A dictionary with the comparison result and any error message.
358 | """
359 | predicted_sql = clean_sql(predicted_sql)
360 | try:
361 | res = func_timeout(meta_time_out, compare_sqls_outcomes, args=(db_path, predicted_sql, ground_truth_sql))
362 | error = "incorrect answer" if res == 0 else "--"
363 | except FunctionTimedOut:
364 | logging.warning("Comparison timed out.")
365 | error = "timeout"
366 | res = 0
367 | except Exception as e:
368 | logging.error(f"Error in compare_sqls: {e}")
369 | error = str(e)
370 | res = 0
371 | return {'exec_res': res, 'exec_err': error}
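# Usage sketch with invented queries: both SQLs are executed within `meta_time_out` seconds
# and their result sets compared:
#   compare_sqls("dev.sqlite", "SELECT COUNT(*) FROM schools", "SELECT COUNT(CDSCode) FROM schools")
#   -> {'exec_res': 1, 'exec_err': '--'} when both queries return the same rows.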
372 |
373 |
374 | def extract_sql_tables(db_path: str, sql: str) -> List[str]:
375 | """
376 | The function extracts the table names in the SQL.
377 |
378 | Args:
379 | db_path (str): The database sqlite file path.
380 | sql (str): The SQL query string.
381 |
382 | Returns:
383 | tables_in_sql (List[str]): Names of the tables in the SQL query as a list of strings.
384 | """
385 | db_tables = get_db_tables(db_path)
386 | try:
387 | parsed_tables = list(parse_one(sql, read='sqlite').find_all(expressions.Table)) # parsed_tables: List[]
388 | tables_in_sql = [str(table.name) for table in parsed_tables if str(table.name) in [db_table.lower() for db_table in db_tables]] # tables_in_sql: List[str]
389 | tables_in_sql = list(set(tables_in_sql)) # ensure list contains only unique table values i.e. a table name doesn't repeat in the list
390 | return tables_in_sql
391 | except Exception as e:
392 | logging.critical(f"Error in extract_sql_tables: {e}\n")
393 | raise e
394 |
395 | def extract_sql_tables_with_aliases(db_path: str, sql: str) -> List[str]:
396 | """
397 | The function extracts the table names with their aliases in the SQL.
398 |
399 | Args:
400 | db_path (str): The database sqlite file path.
401 | sql (str): The SQL query string.
402 |
403 | Returns:
404 | tables_w_aliases (List[Dict[str, str]]): List of dictionary whose keys are "table_name" and "table_alias"
405 | """
406 | db_tables = get_db_tables(db_path)
407 | try:
408 | parsed_tables = list(parse_one(sql, read='sqlite').find_all(expressions.Table)) # parsed_tables: List[]
409 | tables_w_aliases = [{"table_name": str(table.name), "table_alias": str(table.alias)} for table in parsed_tables if str(table.name) in [db_table.lower() for db_table in db_tables]] # tables_in_sql: List[str]
410 | tables_w_aliases = [table_alias_dict for table_alias_dict in tables_w_aliases if table_alias_dict['table_alias'] != '']
411 | return tables_w_aliases
412 | except Exception as e:
413 | logging.warning(f"Error in extract_sql_tables_with_aliases: \n\tError{e} \n\t{sql}")
414 | raise
415 |
416 | def replace_alias_with_table_names_in_sql(db_path: str, sql: str) -> str:
417 | """
418 | The function removes aliases in the SQL.
419 |
420 | Arguments:
421 | sql (str): The SQL with aliases
422 | Returns:
423 | sql (str): The SQL without aliases. Table aliases are replaced with corresponding table names.
424 | """
425 | try:
426 | tables_w_aliases = extract_sql_tables_with_aliases(db_path, sql)
427 | for table_dict in tables_w_aliases:
428 | table_name, table_alias = table_dict['table_name'], table_dict['table_alias']
429 | sql = sql.replace(table_alias+".", table_name+".") # replace table_alias with table names in necessary clauses
430 | # sql = sql.replace(f"AS {table_alias}", "") # remove "AS" keywords
431 | # sql = sql.replace(table_alias, "") # remove table alias
432 | sql = re.sub(r'\s+', ' ', sql) # remove extra spaces
433 |
434 | return sql
435 | except Exception as e:
436 | logging.warning(f"Failed to replace aliases in SQL due to error: {e}. The initially given SQL is returned.")
437 | return sql
438 |
439 |
440 |
441 | def extract_sql_columns(db_path: str, sql: str) -> Dict[str, List[str]]:
442 | """
443 | The function extracts the column names with corresponding table names in the SQL.
444 |
445 | Args:
446 | db_path (str): The database sqlite file path.
447 | sql (str): The SQL query as string.
448 |
449 | Returns:
450 | columns_in_sql_dict (Dict[str, List[str]]): A dictionary where keys are table names and values are lists of column names.
451 | """
452 | columns_in_sql_dict = {}
453 |
454 | # extract database tables
455 | db_tables = get_db_tables(db_path)
456 |
457 | # Qualify columns so that every column is prefixed with its table_name
458 | db_schema_dict = get_schema_dict(db_path)
459 | try:
460 | qualified_parsed_sql = qualify(parse_one(sql, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False)
461 | except Exception as e:
462 | logging.critical(f"Error in qualifying parsed SQL in extract_sql_columns function. \n{e} \nSQL: {sql} ")
463 | qualified_parsed_sql = parse_one(sql, read='sqlite')
464 | # print("qualified and parsed sql: \n", qualified_parsed_sql)
465 | """
466 | The type of qualified_parsed_sql is sqlglot.expressions.Select.
467 | When the qualified SQL is required as a string, use the __str__() method of the 'sqlglot.expressions.Select' class.
468 | """
469 |
470 | # Extract tables and its aliases
471 | tables_w_aliases = extract_sql_tables_with_aliases(db_path, qualified_parsed_sql.__str__())
472 | # print("tables_w_aliases: \n", tables_w_aliases)
473 |
474 | try:
475 | parsed_columns = list(qualified_parsed_sql.find_all(expressions.Column))
476 | # parsed_columns: List[]
477 | # Note that columns and table names will be lower case if it is a single word
478 | # print("parsed columns: \n", parsed_columns)
479 |
480 | for column_obj in parsed_columns:
481 | column_name = column_obj.name
482 | table_name = column_obj.table
483 | # print("column_name and table_name: ", column_name, "-", table_name )
484 |
485 | for tables_alias_dict in tables_w_aliases:
486 | if table_name == tables_alias_dict["table_alias"]:
487 | table_name = tables_alias_dict["table_name"]
488 |
489 | if table_name.lower() in [db_table.lower() for db_table in db_tables]:
490 | db_columns_of_table = get_db_colums_of_table(db_path, table_name)
491 | if column_name.lower() in [col.lower() for col in db_columns_of_table]:
492 | if table_name in columns_in_sql_dict:
493 | if column_name not in columns_in_sql_dict[table_name]:
494 | columns_in_sql_dict[table_name].append(column_name)
495 | else:
496 | columns_in_sql_dict[table_name] = [column_name]
497 |
498 | # reconstruct columns_in_sql_dict so that its items are the original database table and column names
499 | original_columns_in_sql_dict = {}
500 | for table_name, col_list in columns_in_sql_dict.items():
501 | db_tables_lower = [t.lower() for t in db_tables]
502 | t_index = db_tables_lower.index(table_name)
503 | db_table_original_name = db_tables[t_index]
504 | original_columns_in_sql_dict[db_table_original_name] = []
505 |
506 | table_cols_original = get_db_colums_of_table(db_path, db_table_original_name)
507 | table_cols_lower = [c.lower() for c in table_cols_original]
508 | for col in col_list:
509 | if col in table_cols_lower:
510 | c_index = table_cols_lower.index(col)
511 | c_original_name = table_cols_original[c_index]
512 | original_columns_in_sql_dict[db_table_original_name].append(c_original_name)
513 | elif col in table_cols_original:
514 | original_columns_in_sql_dict[db_table_original_name].append(col)
515 |
516 | columns_in_sql_dict = original_columns_in_sql_dict
517 | return columns_in_sql_dict
518 | except Exception as e:
519 | logging.critical(f"Error in extract_sql_columns: {e}\n")
520 | raise e
521 | def generate_schema_from_schema_dict(db_path: str, schema_dict: Dict) -> str:
522 | """
523 | The function creates filtered schema string from the given schema dictionary
524 |
525 | Arguments:
526 | db_path (str): The database sqlite file path.
527 | schema_dict (Dict[str, List[str]]): A dictionary where keys are table names and values are lists of column names.
528 |
529 | Results:
530 | schema_str (str): Schema generated from the given schema dictionary and database path
531 | """
532 | # Dictionary to store CREATE TABLE statements
533 | create_statements = []
534 |
535 | for table, column_list in schema_dict.items():
536 | table_info = execute_sql(db_path, f"PRAGMA table_info(`{table}`);") # returns a tuple (cid, name, type, notnull, dflt_value, pk) for each column in the table
537 | # print(f"TABLE INFO OF {table} \n", table_info)
538 | pk_columns = [(row[1], row[2]) for row in table_info if row[5] != 0] # add all PKs to the schema definition with their data types
539 | other_columns = [(row[1], row[2]) for row in table_info if row[5] == 0 and row[1] in column_list] # add a column if it exists in the column list of the given schema dict
540 |
541 | # Query for foreign key information
542 | foreign_keys_info = execute_sql(db_path, f"PRAGMA foreign_key_list(`{table}`)")
543 | # print(f"Foreing keys info:\n", foreign_keys_info)
544 | foreign_keys = {row[3]: (row[2], row[4]) for row in foreign_keys_info if row[3] in column_list} # local_col: (ref_table, foreing_col) if local_col exist in filtered column list in the given schema dict
545 | # print("foreign_keys: \n", foreign_keys)
546 |
547 | table_definition = f"CREATE TABLE {table} (\n"
548 | if len(pk_columns) == 1:
549 | pk = pk_columns[0]
550 | table_definition = table_definition + f"{pk[0]} {pk[1]} primary key, \n"
551 | for col in other_columns:
552 | table_definition = table_definition + f"{col[0]} {col[1]},\n"
553 | for local_col, ref_table_col in foreign_keys.items():
554 | table_definition = table_definition + f"foreign key ({local_col}) references {ref_table_col[0]}({ref_table_col[1]}) \n"
555 | elif len(pk_columns) > 1:
556 | # concatenate primary key column names with their data types
557 | for pk in pk_columns:
558 | table_definition = table_definition + f"{pk[0]} {pk[1]}, \n"
559 |
560 | # concatenate columns with their data types
561 | for col in other_columns:
562 | table_definition = table_definition + f"{col[0]} {col[1]},\n"
563 |
564 | # concatenate primary key descriptions
565 | table_definition = table_definition + "primary key ("
566 | for ind, pk in enumerate(pk_columns):
567 | if ind < len(pk_columns)-1:
568 | table_definition = table_definition + pk[0] + ", "
569 | else:
570 | table_definition = table_definition + pk[0]
571 |
572 | table_definition = table_definition + "),\n"
573 |
574 | # concatenate foreign key descriptions
575 | for local_col, ref_table_col in foreign_keys.items():
576 | table_definition = table_definition + f"foreign key ({local_col}) references {ref_table_col[0]}({ref_table_col[1]}) \n"
577 |
578 |
579 | table_definition = table_definition + ")"
580 | create_statements.append(table_definition)
581 |
582 | schema_str = '\n'.join(create_statements)
583 | return schema_str
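# Example with an assumed filtered dict: primary keys are always emitted, other columns
# only if listed, and foreign keys only if their local column is listed, e.g. for a
# hypothetical schools table whose primary key is CDSCode:
#   generate_schema_from_schema_dict(db_path, {"schools": ["County"]})
#   -> 'CREATE TABLE schools (\nCDSCode TEXT primary key, \nCounty TEXT,\n)'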
584 |
585 |
586 | def extract_db_samples_enriched_bm25(question: str, evidence: str, db_path: str, schema_dict: Dict, sample_limit: int) -> str:
587 | """
588 | The function extracts distinct samples for the given schema items from the database by ranking values using BM25.
589 | Ranking is done separately over the values of each table.column.
590 |
591 | Arguments:
592 | question (str): considered natural language question
593 | evidence (str): given evidence about the question
594 | db_path (str): The database sqlite file path.
595 | schema_dict (Dict[str, List[str]]): Database schema dictionary where keys are table names and values are lists of column names
596 | sample_limit (int): Maximum number of sample values returned per column
597 | Returns:
598 | db_samples (str): concatenated strings gives samples from each column
599 | """
600 | db_samples = "\n"
601 |
602 | question = question.replace('\"', '').replace("\'", "").replace("`", "")
603 | question_and_evidence = question + " " + evidence
604 | tokenized_question_evidence = word_tokenize(question_and_evidence)
605 |
606 | for table, col_list in schema_dict.items():
607 | db_samples = db_samples + f"## {table} table samples:\n"
608 | for col in col_list:
609 | try:
610 | col_distinct_values = execute_sql(db_path, f"SELECT DISTINCT `{col}` FROM `{table}`") # extract all distinct values
611 | col_distinct_values = [str(value_tuple[0]) if value_tuple and value_tuple[0] else 'NULL' for value_tuple in col_distinct_values]
612 | if 'NULL' in col_distinct_values:
613 | isNullExist = True
614 | else:
615 | isNullExist = False
616 |
617 | # col_distinct_values = [str(value_tuple[0]) for value_tuple in col_distinct_values if value_tuple[0]] # not considering NULL values
618 | # col_distinct_values = [value if len(value) < 400 else value[:300] for value in col_distinct_values] # if a value is too long, take only its first 300 characters
619 | # if the average length of the column values is larger than 600, use only the first value, since very long values exceed the context limit
620 | if len(col_distinct_values) > 0:
621 | average_length = sum(len(value) for value in col_distinct_values) / len(col_distinct_values)
622 | else:
623 | average_length = 0
624 | if average_length > 600:
625 | col_distinct_values = [col_distinct_values[0]]
626 |
627 | if len(col_distinct_values) > sample_limit:
628 | corpus = col_distinct_values.copy()
629 | corpus = [f'{table} {col} {val}' for val in corpus]
630 | tokenized_corpus = [doc.split(" ") for doc in corpus]
631 | # tokenized_corpus = [word_tokenize(doc) for doc in corpus] # takes too much time, so don't use it
632 | bm25 = BM25Okapi(tokenized_corpus)
633 |
634 | col_distinct_values = bm25.get_top_n(tokenized_question_evidence, col_distinct_values, n=sample_limit)
635 | if isNullExist:
636 | col_distinct_values.append("NULL")
637 |
638 | db_samples = db_samples + f"# Example values for '{table}'.'{col}' column: " + str(col_distinct_values) + "\n"
639 | except Exception as e:
640 | sql = f"SELECT DISTINCT `{col}` FROM `{table}`"
641 | logging.error(f"Error in extract_db_samples_enriched_bm25: {e}\n SQL: {sql}")
642 | error = str(e)
643 |
644 | return db_samples
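# Sketch of the ranking step with assumed data: for a question mentioning "Alameda",
# BM25 over documents like "schools County Alameda", "schools County Fresno", ...
# surfaces "Alameda" among the top `sample_limit` example values for schools.County.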
645 |
646 | def construct_tokenized_db_table_value_corpus(db_path: str, schema_dict: Dict):
647 | """
648 | The function collects an item for each value in the database as "table_name column_name value", then tokenizes it
649 |
650 | Arguments:
651 | db_path (str): The database sqlite file path.
652 | schema_dict (Dict[str, List[str]]): Database schema dictionary where keys are table names and values are lists of column names
653 |
654 | Returns:
655 | tokenized_db_corpus (List[List]): List of tokenized database "table_name column_name value" items
656 | db_corpus (List[Tuple]): List of (table, column, value) tuples
657 | """
658 | # generating corpus whose items are tokenized version of "table_name column_name value" for each value and table in the database.
659 | corpus = []
660 | db_corpus = []
661 | for table, col_list in schema_dict.items():
662 | for col in col_list:
663 | try:
664 | col_distinct_values = execute_sql(db_path, f"SELECT DISTINCT `{col}` FROM `{table}`") # extract all distinct values
665 | col_distinct_values = [str(value_tuple[0]) for value_tuple in col_distinct_values if value_tuple[0]]
666 | # if the average length of the column values is larger than 600, use only the first value, since very long values exceed the context limit
667 | if len(col_distinct_values) > 0:
668 | average_length = sum(len(value) for value in col_distinct_values) / len(col_distinct_values)
669 | else:
670 | average_length = 0
671 | if average_length > 600:
672 | col_distinct_values = [col_distinct_values[0]]
673 |
674 | table_col_value_str = [f"{table} {col} {val}" for val in col_distinct_values]
675 | corpus.extend(table_col_value_str)
676 | table_col_value_tuples = [(table, col, val) for val in col_distinct_values]
677 | db_corpus.extend(table_col_value_tuples)
678 |
679 | except Exception as e:
680 | sql = f"SELECT DISTINCT `{col}` FROM `{table}`"
681 | logging.error(f"Error in construct_tokenized_db_table_value_corpus: {e}\n SQL: {sql}")
682 | error = str(e)
683 |
684 | # constructing the tokenized corpus for BM25
685 | tokenized_db_corpus = [doc.split(" ") for doc in corpus]
686 | # tokenized_db_corpus = [word_tokenize(doc) for doc in corpus if doc] # takes too much time, so don't use it
687 | return tokenized_db_corpus, db_corpus
688 |
689 | def find_most_similar_table(problematic_t_name: str, db_tables: List[str]) -> str:
690 | """
691 | Helper function to find the most similar table name in the database.
692 | As a string similarity metric, difflib.SequenceMatcher's ratio is used.
693 | This helper function calculates the similarity ratio between two strings based on the longest contiguous matching blocks shared by the strings (difflib's Ratcliff/Obershelp heuristic).
694 | This ratio is a value between 0 and 1, where 1 means the strings are identical, and 0 means they are completely different.
695 |
696 | Arguments:
697 | problematic_t_name (str): name of the table that is not in the database
698 | db_tables (List[str]): list of database tables
699 |
700 | Returns:
701 | most_similar_table (str): the name of the table that is actually in the database and most similar to the given problematic table name
702 | """
703 |
704 | similarity_scores = [(t_name, difflib.SequenceMatcher(None, problematic_t_name, t_name).ratio()) for t_name in db_tables]
705 | most_similar_table = max(similarity_scores, key=lambda x: x[1])[0]
706 | return most_similar_table
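# Example with invented names: the highest SequenceMatcher ratio wins, e.g.
#   find_most_similar_table("school", ["schools", "satscores", "frpm"]) -> "schools"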
707 |
708 | def filtered_schema_correction(db_path: str, filtered_schema_dict: Dict) -> Tuple[Dict, str]:
709 | """
710 | The function checks whether the filtered schema mismatches the original schema. If there is a mismatch, it corrects it.
711 |
712 | Arguments:
713 | db_path (str): The database sqlite file path
714 | filtered_schema_dict (Dict[str, List[str]]): A dictionary where keys are table names and values are list of column names
715 |
716 | Returns:
717 | final_filtered_schema_dict (Dict[str, List[str]]): Finalized filtered schema dictionary
718 | filtered_schema_problems (str): A string that expresses all mismatches
719 | """
720 | filtered_schema_problems = ""
721 |
722 | db_tables = get_db_tables(db_path)
723 | ## Step 1: Check if the tables in filtered schema dictionary are in the database and replace them with the most similar table names
724 |
725 | problematic_tables = []
726 | for t_name in filtered_schema_dict.keys():
727 | isInDb = isTableInDB(db_path=db_path, table_name=t_name)
728 | if not isInDb:
729 | problematic_tables.append(t_name)
730 |
731 |
732 | if problematic_tables:
733 | print(f"There is mismatch between database and filtered schema tables. The problematic tables are: {problematic_tables}")
734 | filtered_schema_problems = filtered_schema_problems + f"There is mismatch between database and filtered schema tables. The problematic tables are: {problematic_tables}"
735 | for problematic_t_name in problematic_tables:
736 | most_similar_table = find_most_similar_table(problematic_t_name, db_tables)
737 | filtered_schema_dict[most_similar_table] = filtered_schema_dict.pop(problematic_t_name)
738 |
739 | ## Step 2: Check if the columns of a table in the filtered schema dictionary are actually columns of that table. If not, find the correct tables for them.
740 |
741 | for table_name in filtered_schema_dict.keys():
742 | problematic_columns = {} # Dict[str, bool] --> keys are column names and values are boolean that indicates whether a table containing that column is found
743 | for column_name in filtered_schema_dict[table_name]:
744 | isInTable = isColumnInTable(db_path=db_path, table_name=table_name, column_name=column_name)
745 | if not isInTable:
746 | print(f"There is a mismatch in filtered schema table columns. {column_name} is not actually in the {table_name} table.")
747 | filtered_schema_problems = filtered_schema_problems + f"There is a mismatch in filtered schema table columns. {column_name} is not actually in the {table_name} table."
748 | problematic_columns[column_name] = False # boolean variable indicates whether a table containing that column is found
749 |
750 | # finding tables for problematic columns
751 | table_column_dict = {}
752 | db_tables = get_db_tables(db_path)
753 | for p_column in problematic_columns.keys():
754 | for db_table in db_tables:
755 | columns_of_table = get_db_colums_of_table(db_path=db_path, table_name=db_table)
756 | if p_column in columns_of_table:
757 | problematic_columns[p_column] = True
758 | if db_table in table_column_dict:
759 | table_column_dict[db_table].append(p_column)
760 | else:
761 | table_column_dict[db_table] = [p_column]
762 |
763 |
764 | # constructing final filtered schema
765 | final_filtered_schema_dict = filtered_schema_dict.copy()
766 |
767 | # removing the problematic columns whose actual tables are found
768 | for p_column, actual_table_found in problematic_columns.items():
769 | if actual_table_found:
770 | final_filtered_schema_dict[table_name].remove(p_column)
771 |
772 | # Appending problematic columns actual table into the final filtered schema dict
773 | for actual_table, actual_columns in table_column_dict.items():
774 | if actual_table in final_filtered_schema_dict:
775 | final_filtered_schema_dict[actual_table] = final_filtered_schema_dict[actual_table] + actual_columns
776 | else:
777 | final_filtered_schema_dict[actual_table] = actual_columns
778 |
779 |
780 | return final_filtered_schema_dict, filtered_schema_problems
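# Example with an invented filtered schema (assuming a BIRD-style database where
# "FundingType" actually lives in the "frpm" table): the misspelled table is remapped
# to its closest real name and the misplaced column is moved to its actual table:
#   filtered_schema_correction(db_path, {"school": ["County", "FundingType"]})
#   -> ({"schools": ["County"], "frpm": ["FundingType"]}, "There is a mismatch ...")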
781 |
782 |
783 | def find_similar_values_incolumn_via_like(db_path: str, table: str, column: str, value: str) -> List[str]:
784 | """
785 | This function finds values similar to the given value using the SQL LIKE clause in the given table and column.
786 |
787 | Arguments:
788 | db_path (str): The database sqlite file path.
789 | table (str): Table in the given database.
790 | column (str): Column belongs to the given table.
791 | value (str): The value for which similar values are extracted
792 |
793 | Returns:
794 | similar_values (List[str]): List of string which are similar to given value.
795 | """
796 | # if the length of the value is 1, return an empty list to prevent lots of unnecessary matches
797 | if len(value) == 1:
798 | return []
799 | # Observe a few values from the column; if their average length is 300 or more, return an empty list
800 | value_observation_sql = f"SELECT `{column}` FROM `{table}` LIMIT 3"
801 | observed_values = execute_sql(db_path=db_path, sql=value_observation_sql)
802 |
803 | if not observed_values:
804 | return []
805 |
806 | observed_values = [str(row[0]) for row in observed_values]
807 | observed_values_avg_len = sum( map(len, observed_values) ) / len(observed_values)
808 | if observed_values_avg_len >= 300:
809 | return []
810 |
811 | sql = 'SELECT DISTINCT `{C}` FROM `{T}` WHERE `{C}` LIKE "%{V}%"'.format(C=column, T=table, V=value)
812 | try:
813 | similar_values = execute_sql(db_path=db_path, sql=sql)
814 | similar_values = [str(row[0]) for row in similar_values if len(str(row[0])) < 50]
815 | # Decrease the number of similar values if there are too many
816 | if len(similar_values) > 5:
817 | similar_values = similar_values[:5]
818 | except Exception as e:
819 | similar_values = []
820 | logging.critical(f"Error in finding similar values for a value: {e} \n SQL: {sql}")
821 |
822 | return similar_values
823 |
824 |
825 | def find_similar_values_indb_via_like(db_path: str, value:str) -> Dict:
826 | """
827 | This function finds values similar to the given value using the SQL LIKE clause across the whole database.
828 |
829 | Arguments:
830 | db_path (str): The database sqlite file path.
831 | value (str): The value for which similar values are extracted
832 |
833 | Returns:
834 | similar_values_dict (Dict[str, Dict[str, List[str]]]): Table names are the keys, and each value is another dictionary whose keys are column names and whose values are lists of real column values similar to the given value
835 | """
836 | db_tables_columns_dict = get_schema_tables_and_columns_dict(db_path=db_path)
837 | similar_values_dict = {}
838 | for table, columns_list in db_tables_columns_dict.items():
839 | for column in columns_list:
840 | similar_values_list = find_similar_values_incolumn_via_like(db_path, table, column, value)
841 | if similar_values_list:
842 | if table in similar_values_dict:
843 | similar_values_dict[table][column] = similar_values_list
844 | else:
845 | similar_values_dict[table] = {}
846 | similar_values_dict[table][column] = similar_values_list
847 |
848 | return similar_values_dict
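# Usage sketch with an invented value: every table.column is scanned with LIKE, e.g.
#   find_similar_values_indb_via_like(db_path, "Alameda")
#   -> {"schools": {"County": ["Alameda"], "District": ["Alameda Unified", ...]}}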
849 |
850 | def extract_comparison_conditions_in_where_clause(db_path: str, where_clause) -> List[Dict[str, str]]:
851 | """"
852 | The function extracts list of dict which describe a condition in WHERE clause
853 |
854 | Arguments
855 | where_clause (sqlglot.expressions.Where)
856 |
857 | Returns
858 | conditions_list (List[Dict[str, str]]): List of Dict whose keys are table, column, operation, value
859 |
860 | """
861 |
862 | conditions_list = []
863 | if not where_clause:
864 | return conditions_list
865 |
866 | # Check if where_clause.this is a composite condition (AND/OR)
867 | if isinstance(where_clause.this, sqlglot.expressions.And) or isinstance(where_clause.this, sqlglot.expressions.Or):
868 | where_conditions = list(where_clause.this.flatten()) # flattening the where clause in the case of it is composite condition
869 | else:
870 | where_conditions = [where_clause.this]
871 |
872 | for cond_ind, condition in enumerate(where_conditions):
873 | # print("--cond_ind: ", cond_ind)
874 | # print("--condition: ", condition)
875 | # print("--condition type: ", type(condition))
876 | columns_in_where_clause = list(condition.find_all(sqlglot.expressions.Column))
877 | # print("******columns_in_where_clause: ", columns_in_where_clause)
878 | # if there is no sqlglot.expressions.Column in the AST of the WHERE clause
879 | if not columns_in_where_clause:
880 | # check if the condition.this type is Dot not Column
881 | if isinstance(condition.this, sqlglot.expressions.Dot):
882 | try:
883 | if isinstance(condition.this.left, sqlglot.expressions.Literal):
884 | dot_table = condition.this.left.this
885 | if isinstance(condition.this.right, sqlglot.expressions.Literal):
886 | dot_column = condition.this.right.this
887 | if isinstance(condition.expression, sqlglot.expressions.Literal):
888 | dot_value = condition.expression.this
889 |
890 | if isinstance(condition, sqlglot.expressions.EQ):
891 | op = " = "
892 | if isinstance(condition, sqlglot.expressions.NEQ):
893 | op = " != "
894 | if isinstance(condition, sqlglot.expressions.GT):
895 | op = " > "
896 | if isinstance(condition, sqlglot.expressions.GTE):
897 | op = " >= "
898 | if isinstance(condition, sqlglot.expressions.LT):
899 | op = " < "
900 | if isinstance(condition, sqlglot.expressions.LTE):
901 | op = " <= "
902 | conditions_list.append({
903 | "table": dot_table,
904 | "column": dot_column,
905 | "op": op,
906 | "value": dot_value,
907 | })
908 | except Exception as e:
909 | logging.warning("Column in condition couldn't be found. Expression.Dot is found under condition but it couldn't be seperated to its table, column and value")
910 |
911 | for cols in columns_in_where_clause:
912 | comparison_conditions = []
913 | op = ""
914 | EQ_condition = cols.find_ancestor(sqlglot.expressions.EQ)
915 | if EQ_condition and isinstance(EQ_condition.left, sqlglot.expressions.Column):
916 | comparison_conditions.append(EQ_condition)
917 | op = " = "
918 | NEQ_condition = cols.find_ancestor(sqlglot.expressions.NEQ)
919 | if NEQ_condition and isinstance(NEQ_condition.left, sqlglot.expressions.Column):
920 | comparison_conditions.append(NEQ_condition)
921 | op = " != "
922 | GT_condition = cols.find_ancestor(sqlglot.expressions.GT)
923 | if GT_condition and isinstance(GT_condition.left, sqlglot.expressions.Column):
924 | comparison_conditions.append(GT_condition)
925 | op = " > "
926 | GTE_condition = cols.find_ancestor(sqlglot.expressions.GTE)
927 | if GTE_condition and isinstance(GTE_condition.left, sqlglot.expressions.Column):
928 | comparison_conditions.append(GTE_condition)
929 | op = " >= "
930 | LT_condition = cols.find_ancestor(sqlglot.expressions.LT)
931 | if LT_condition and isinstance(LT_condition.left, sqlglot.expressions.Column):
932 | comparison_conditions.append(LT_condition)
933 | op = " < "
934 | LTE_condition = cols.find_ancestor(sqlglot.expressions.LTE)
935 | if LTE_condition and isinstance(LTE_condition.left, sqlglot.expressions.Column):
936 | comparison_conditions.append(LTE_condition)
937 | op = " <= "
938 |
939 | for cond in comparison_conditions:
940 | if not isinstance(cond.left.table, str):
941 | continue
942 | if not isinstance(cond.left.this.this, str):
943 | continue
944 | if not isinstance(cond.right.this, str):
945 | continue
946 | conditions_list.append({
947 | "table": cond.left.table,
948 | "column": cond.left.this.this,
949 | "op": op,
950 | "value": cond.right.this,
951 | })
952 |
953 | # print("conditions_list: \n", conditions_list)
954 | return conditions_list
955 |
956 | def get_comparison_conditions_from_sql(db_path: str, sql: str) -> List[Dict[str, str]]:
957 | """"
958 | The functions extracts conditions in given SQL
959 |
960 | Argumnets:
961 | db_path (str): The database sqlite path
962 | sql (str): Given structured query language
963 |
964 | Returns:
965 | conditions_dict_list (List[Dict[str, str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value"
966 | """
967 | db_schema_dict = get_schema_dict(db_path)
968 |
969 |
970 | # Qualify columns so that every column is prefixed with its table_name
971 | # Attempting to qualify the SQL
972 | try:
973 | qualified_parsed_sql = qualify(parse_one(sql, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False)
974 | except Exception as e1:
975 | logging.warning(f"First attempt to qualify SQL failed. Trying with first replacement set. \n\tError: {e1} \n\tSQL: {sql}")
976 | # First replacement attempt
977 | try:
978 | changed_sql_1 = sql.replace('`', '"')
979 | changed_sql_1 = changed_sql_1.replace("'", '"')
980 | qualified_parsed_sql = qualify(parse_one(changed_sql_1, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False)
981 | except Exception as e2:
982 | logging.warning(f"Second attempt to qualify SQL failed. Trying with second replacement set. \n\tError: {e2} \n\tSQL: {changed_sql_1}")
983 | # Second replacement attempt
984 | try:
985 | changed_sql_2 = sql.replace('`', "'")
986 | changed_sql_2 = changed_sql_2.replace('"', "'")
987 | qualified_parsed_sql = qualify(parse_one(changed_sql_2, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False)
988 | except Exception as e3:
989 | logging.warning(f"Third attempts to qualify SQL failed.Trying with only parsing the SQL. \n\tError: {e3} \n\tSQL: {changed_sql_2}")
990 | # Third replacement attempt
991 | try:
992 | changed_sql_3 = sql.replace('`', "'")
993 | changed_sql_3 = changed_sql_3.replace('"', "'")
994 | qualified_parsed_sql = parse_one(changed_sql_3, read='sqlite')
995 | except Exception as e4:
996 | logging.warning(f"Fourth attempts to qualify SQL failed. Triying with only parsing the SQL. \n\tError: {e4} \n\tSQL: {changed_sql_3}")
997 | try:
998 | changed_sql_4 = sql.replace('`', '"')
999 | changed_sql_4 = changed_sql_4.replace("'", '"')
1000 | qualified_parsed_sql = parse_one(changed_sql_4, read='sqlite')
1001 | except Exception as e5:
1002 | logging.warning(f"Fifth attempts to qualify SQL failed. Triying with only parsing the SQL. \n\tError: {e4} \n\tSQL: {changed_sql_4}")
1003 | conditions_dict_list = []
1004 | return conditions_dict_list
1005 | # Replacing table aliases with actual table names
1006 | # qualified_parsed_sql = replace_alias_with_table_names_in_sql(db_path, qualified_parsed_sql)
1007 |
1008 | # print("-qualified_parsed_sql: ", repr(qualified_parsed_sql))
1009 | # Extract the WHERE clauses
1010 | where_clauses = list(qualified_parsed_sql.find_all(sqlglot.expressions.Where)) # where_clauses: List[sqlglot.expressions.Where]
1011 | conditions_dict_list = []
1012 | # print("where_clauses: \n", where_clauses)
1013 | if where_clauses:
1014 | # Extract conditions from each WHERE clause
1015 | for where_clause in where_clauses:
1016 | try:
1017 | conditions_list = extract_comparison_conditions_in_where_clause(db_path, where_clause) # List[Dict[str, str]]
1018 | conditions_dict_list.extend(conditions_list)
1019 | except Exception as e:
1020 | logging.critical(f"Error in extracting comparison conditions in the WHERE clause. \nError: {e} \nSQL: {sql}")
1021 |
1022 | return conditions_dict_list
1023 |
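# A hedged usage sketch of the extractor above; the database path and SQL
# are hypothetical and only illustrate the expected output shape.
#
#   conds = get_comparison_conditions_from_sql(
#       db_path="dev/dev_databases/california_schools/california_schools.sqlite",
#       sql="SELECT * FROM schools WHERE County = 'Alameda'",
#   )
#   # -> [{"table": "schools", "column": "County", "op": " = ", "value": "Alameda"}]
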
1024 | def extend_conditions_dict_list(conditions_dict_list: List[Dict[str, str]])-> List[Dict[str,str]]:
1025 | """
1026 | The functions splits all the values in the conditions, then it extends the list by adding each words of a value to the conditions_dict_list
1027 |
1028 | Arguments:
1029 | conditions_dict_list (List[Dict[str,str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value".
1030 |
1031 | Returns
1032 | extended_conditions_dict_list (List[Dict[str,str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value".
1033 |
1034 | """
1035 | extended_conditions_dict_list = conditions_dict_list.copy()
1036 | for cond_dict in conditions_dict_list:
1037 | table = cond_dict['table']
1038 | column = cond_dict['column']
1039 | op = cond_dict['op']
1040 | value = cond_dict['value']
1041 | splitted_value = value.split()
1042 | if len(splitted_value) > 1:
1043 | for val in splitted_value:
1044 | new_condition_dict = {
1045 | "table": table,
1046 | "column": column,
1047 | "op": op,
1048 | "value": val
1049 | }
1050 | extended_conditions_dict_list.append(new_condition_dict)
1051 |
1052 | return extended_conditions_dict_list
1053 |
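# A hedged example of the word-splitting behaviour with an illustrative condition:
#
#   conds = [{"table": "cards", "column": "artist", "op": " = ", "value": "Kev Walker"}]
#   extend_conditions_dict_list(conds)
#   # -> the original condition plus one new condition for "Kev" and one for
#   #    "Walker", which widens the later LIKE-based similar-value search.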
1054 |
1055 | def get_extended_comparison_conditions_from_sql(db_path: str, sql: str) -> List[Dict[str, str]]:
1056 | """"
1057 | The functions extracts conditions in given SQL
1058 |
1059 | Argumnets:
1060 | db_path (str): The database sqlite path
1061 | sql (str): Given structured query language
1062 |
1063 | Returns:
1064 | extended_conditions_dict_list (List[Dict[str, str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value"
1065 | """
1066 | conditions_dict_list = get_comparison_conditions_from_sql(db_path=db_path, sql=sql)
1067 | extended_conditions_dict_list = extend_conditions_dict_list(conditions_dict_list)
1068 | return extended_conditions_dict_list
1069 |
1070 |
1071 |
1072 | def collect_possible_conditions(db_path: str, sql: str) -> List[Dict[str, Union[str, Dict]]]:
1073 | """
1074 | The functions collects possible where clause conditions depending on the Where clause comparison conditions.
1075 |
1076 | Arguments:
1077 | db_path (str): The database sqlite path
1078 | sql (str): Given structured query language
1079 |
1080 | Returns:
1081 | conditions_dict_list (List[Dict[str, Union[str, Dict]]]): List of comparison conditions described as dictionary
1082 | """
1083 | possible_conditions_dict_list = []
1084 | # comp_conditions_dict_list = get_comparison_conditions_from_sql(db_path, sql) # old version
1085 | comp_conditions_dict_list = get_extended_comparison_conditions_from_sql(db_path, sql)
1086 | for comp_cond in comp_conditions_dict_list:
1087 | value = comp_cond['value']
1088 | similar_values_dict = find_similar_values_indb_via_like(db_path, value)
1089 | comp_cond['similar_values'] = similar_values_dict
1090 | possible_conditions_dict_list.append(comp_cond)
1091 |
1092 | return possible_conditions_dict_list
1093 |
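# A hedged sketch of the expected output shape; the table, column and values
# below are illustrative only:
#
#   possible = collect_possible_conditions(db_path, "SELECT * FROM schools WHERE County = 'alameda'")
#   # -> [{"table": "schools", "column": "County", "op": " = ", "value": "alameda",
#   #      "similar_values": {"schools": {"County": ["Alameda"]}}}]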
1094 |
1095 | def measure_execution_time(db_path: str, query: str):
1096 | """
1097 | Executes the given query and returns the wall-clock execution time in seconds together with the query result.
1098 | """
1099 | start_time = time.time()
1100 | query_result = execute_sql(db_path, query)
1101 | end_time = time.time()
1102 | execution_time = end_time - start_time
1103 | return execution_time, query_result
--------------------------------------------------------------------------------
/utils/openai_utils.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | from typing import Dict
3 |
4 | def create_response(stage: str, prompt: str, model: str, max_tokens: int, temperature: float, top_p: float, n: int) -> Dict:
5 | """
6 | The function creates a chat response by using the chat completions API
7 |
8 | Arguments:
9 | stage (str): stage in the pipeline
10 | prompt (str): prepared prompt
11 | model (str): LLM model used to create the chat completion
12 | max_tokens (int): The maximum number of tokens that can be generated in the chat completion
13 | temperature (float): Sampling temperature
14 | top_p (float): Nucleus sampling parameter
15 | n (int): Number of chat completions generated for each input message
16 |
17 | Returns:
18 | response_object (Dict): Object returned by the model
19 | """
20 | client = OpenAI()
21 |
22 | if stage == "question_enrichment":
23 | system_content = "You are excellent data scientist and can link the information between a question and corresponding database perfectly. Your objective is to analyze the given question, corresponding database schema, database column descriptions and the evidence to create a clear link between the given question and database items which includes tables, columns and values. With the help of link, rewrite new versions of the original question to be more related with database items, understandable, clear, absent of irrelevant information and easier to translate into SQL queries. This question enrichment is essential for comprehending the question's intent and identifying the related database items. The process involves pinpointing the relevant database components and expanding the question to incorporate these items."
24 | elif stage == "candidate_sql_generation":
25 | system_content = "You are an excellent data scientist. You can capture the link between the question and corresponding database and perfectly generate valid SQLite SQL query to answer the question. Your objective is to generate SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, samples and evidence. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question."
26 | elif stage == "sql_refinement":
27 | system_content = "You are an excellent data scientist. You can capture the link between the question and corresponding database and perfectly generate valid SQLite SQL query to answer the question. Your objective is to generate SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, evidence, possible SQL and possible conditions. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question."
28 | elif stage == "schema_filtering":
29 | system_content = "You are an excellent data scientist. You can capture the link between a question and corresponding database and determine the useful database items (tables and columns) perfectly. Your objective is to analyze and understand the essence of the given question, corresponding database schema, database column descriptions, samples and evidence and then select the useful database items such as tables and columns. This database item filtering is essential for eliminating unnecessary information in the database so that corresponding structured query language (SQL) of the question can be generated correctly in later steps."
30 | else:
31 | raise ValueError("Wrong value for stage. It can only take following values: question_enrichment, candidate_sql_generation, sql_refinement or schema_filtering.")
32 |
33 | response_object = client.chat.completions.create(
34 | model = model,
35 | messages=[
36 | {"role": "system", "content": system_content},
37 | {"role": "user", "content": prompt}
38 | ],
39 | max_tokens = max_tokens,
40 | response_format = { "type": "json_object" },
41 | temperature = temperature,
42 | top_p = top_p,
43 | n=n,
44 | presence_penalty = 0.0,
45 | frequency_penalty = 0.0
46 | )
47 |
48 | return response_object
49 |
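# A hedged usage sketch; the model name is an assumption and an OPENAI_API_KEY
# environment variable is expected by the OpenAI client:
#
#   response = create_response(
#       stage="candidate_sql_generation",
#       prompt=prompt,
#       model="gpt-4o-mini",
#       max_tokens=2048,
#       temperature=0.0,
#       top_p=1.0,
#       n=1,
#   )
#   content = response.choices[0].message.content  # a JSON string, per response_format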
50 |
51 | def upload_file_to_openai(file_path: str) -> Dict:
52 | """
53 | The function uploads given file to opanai for batch processing.
54 |
55 | Arguments:
56 | file_path (str): path of the file that is going to be uplaoded
57 | Returns:
58 | file_object (FileObject): Returned file object by openai
59 | """
60 | client = OpenAI()
61 |
62 | file_object = client.files.create(
63 | file=open(file_path, "rb"),
64 | purpose="batch"
65 | )
66 |
67 | print("File is uploaded to OpenAI")
68 | return file_object
69 |
70 |
71 | def construct_request_input_object(prompt: str, id: int, model: str, system_message: str) -> Dict:
72 | """
73 | The function creates a request input object for each item in the dataset
74 |
75 | Arguments:
76 | prompt (str): prompt that is going to be given to the LLM as content
77 | id (int): the id of the request
78 | model (str): LLM model name
79 | system_message (str): the content of the system message
80 |
81 | Returns:
82 | request_input_object (Dict): The request input dictionary in the format required by the OpenAI Batch API
83 | """
84 | request_input_object = {
85 | "custom_id": f"qe-request-{id}",
86 | "method": "POST",
87 | "url": "/v1/chat/completions",
88 | "body": {
89 | "model": model,
90 | "messages": [
91 | {"role": "system", "content": f"{system_message}"},
92 | {"role": "user", "content": f"{prompt}"}
93 | ]
94 | }
95 | }
96 | return request_input_object
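
# A hedged sketch of how these helpers combine for batch processing: each
# request object becomes one line of a JSONL file, which is then uploaded
# with upload_file_to_openai. The file name and prompts list are hypothetical.
#
#   import json
#   with open("qe_batch_input.jsonl", "w") as f:
#       for i, prompt in enumerate(prompts):
#           obj = construct_request_input_object(prompt, i, "gpt-4o-mini", system_message)
#           f.write(json.dumps(obj) + "\n")
#   file_object = upload_file_to_openai("qe_batch_input.jsonl")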
--------------------------------------------------------------------------------
/utils/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | from .db_utils import *
5 | from .retrieval_utils import *
6 | from typing import Any, Union, List, Dict
7 |
8 |
9 | def load_few_shot_data(few_shot_data_path="../few-shot-data/question_enrichment_few_shot_examples.json"):
10 | """
11 | The function returns the question enrichment few-shot data as a Python dict
12 |
13 | Arguments:
14 | few_shot_data_path (str): path to the few-shot data JSON file
15 | Returns:
16 | question_enrichment_few_shot_data_dict (dict): the complete question enrichment few-shot data
17 | """
18 | with open(few_shot_data_path, 'r') as file:
19 | question_enrichment_few_shot_data_dict = json.load(file)
20 |
21 | return question_enrichment_few_shot_data_dict
22 |
23 | def sql_possible_conditions_prep(possible_conditions_dict_list: List[Dict[str, Union[str, Dict]]]) -> str:
24 | """
25 | The function constructs condition statements and concatenates them.
26 |
27 | Arguments:
28 | possible_conditions_dict_list (List[Dict[str, Union[str, Dict]]]): possible conditions, each with "table", "column", "op", "value" and "similar_values" keys
29 |
30 | Returns:
31 | all_possible_conditions (str): string representation of the list of possible condition statements
32 | """
33 | all_possible_conditions_list = []
34 | if not possible_conditions_dict_list:
35 | return ""
36 | for p_cond in possible_conditions_dict_list:
37 | condition = f"`{p_cond['table']}`.`{p_cond['column']}` {p_cond['op']} `{p_cond['value']}`"
38 | all_possible_conditions_list.append(condition)
39 | similars_dict = p_cond['similar_values']
40 | if similars_dict:
41 | for table_name, col_val_dict in similars_dict.items():
42 | for column_name, value_list in col_val_dict.items():
43 | for val in value_list:
44 | new_possible_cond = f"`{table_name}`.`{column_name}` {p_cond['op']} `{val}`"
45 | all_possible_conditions_list.append(new_possible_cond)
46 |
47 | return str(all_possible_conditions_list)
48 |
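# A hedged example with a single illustrative condition dict, matching the
# output shape of collect_possible_conditions in db_utils:
#
#   conds = [{"table": "schools", "column": "County", "op": " = ", "value": "alameda",
#             "similar_values": {"schools": {"County": ["Alameda"]}}}]
#   sql_possible_conditions_prep(conds)
#   # -> the string form of a list holding a condition like
#   #    "`schools`.`County` = `alameda`" plus one entry per similar value,
#   #    e.g. "`schools`.`County` = `Alameda`"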
49 |
50 | def question_relevant_descriptions_prep(database_description_path, question, relevant_description_number)-> str:
51 | """
52 | The function concatenates the relevant database item (column) descriptions and returns them as a string
53 |
54 | Arguments:
55 | database_description_path (str): Path to the directory containing database description CSV files.
56 | question (str): the considered natural language question
57 | relevant_description_number (int): number of top ranked column descriptions
58 |
59 | Returns:
60 | str: Concatenated relevant database item descriptions
61 |
62 | """
63 |
64 | relevant_db_descriptions = get_relevant_db_descriptions(database_description_path, question, relevant_description_number)
65 | db_descriptions_str = ""
66 |
67 | for description in relevant_db_descriptions:
68 | db_descriptions_str = db_descriptions_str + f"# {description} \n"
69 |
70 | return db_descriptions_str
71 |
72 |
73 | def db_column_meaning_prep(database_column_meaning_path: str, db_id: str)-> str:
74 | """
75 | The function concatenates the database item (column) descriptions and returns them as a string
76 |
77 | Arguments:
78 | database_column_meaning_path (str): path to the column_meaning.json
79 | db_id (str): name of the database whose columns' meanings will be extracted
80 |
81 | Returns:
82 | str: Concatenated column meanings of given database
83 |
84 | """
85 |
86 | db_column_meanings = get_db_column_meanings(database_column_meaning_path, db_id)
87 | db_column_meanings_str = ""
88 |
89 | for col_meaning in db_column_meanings:
90 | db_column_meanings_str = db_column_meanings_str + f"{col_meaning} \n"
91 |
92 | return db_column_meanings_str
93 |
94 |
95 | def sql_generation_and_refinement_few_shot_prep(few_shot_data_path: str, q_db_id: str, level_shot_number: int, schema_existance: bool, mode: str) -> str:
96 | """
97 | The function selects the given number of examples from the few-shot data, and then concatenates the selected questions and their ground-truth SQL in string format.
98 | - The few-shot examples are selected from the set of data that contains databases different from the considered question's database.
99 | - If the question is already annotated, the few-shot examples are selected from the set of data in which the given question is excluded.
100 | - Level shot number can be between 0 and 10.
101 |
102 | Arguments:
103 | few_shot_data_path (str): few-shot data path
104 | q_db_id (str): database ID (database name) of the considered question
105 | level_shot_number (int): number of examples to add to the prompt for each difficulty level
106 | schema_existance (bool): Whether the schema will be provided for the exemplars in the prompt. If it is True, then the schema will be provided.
107 | mode (str): dev mode or test mode
108 | Returns:
109 | few_shot_exemplars (str): selected and concatenated exemplars for the prompt
110 | """
111 | bird_sql_path = os.getenv('BIRD_DB_PATH')
112 |
113 | if level_shot_number == 0:
114 | return ""
115 |
116 | few_shot_exemplars = ""
117 | # Check level_shot_number
118 | if level_shot_number < 0 or level_shot_number > 10:
119 | raise ValueError("Invalid few-shot number. The level_shot_number should be between 0 and 10")
120 |
121 | # Check schema_existance
122 | if not isinstance(schema_existance, bool):
123 | raise TypeError("Invalid value for schema_existance; it must be a boolean.")
124 |
125 | # Check mode
126 | if mode not in ['test', 'dev']:
127 | raise ValueError("Invalid value for mode. The variable must be either 'dev' or 'test'.")
128 |
129 | # Get all few-shot exemples
130 | all_few_shot_data = load_few_shot_data(few_shot_data_path=few_shot_data_path)
131 | # Set difficulty levels
132 | levels = ['simple', 'moderate', 'challanging']
133 |
134 | for level in levels:
135 | examples_in_level = all_few_shot_data[level]
136 | selected_indexes = []
137 | if mode == "dev":
138 | # remove the annotated questions if their db_id is the same with the considered question's db_id
139 | examples_in_level_tmp = []
140 | for example in examples_in_level:
141 | if q_db_id != example['db_id']:
142 | examples_in_level_tmp.append(example)
143 | examples_in_level = examples_in_level_tmp
144 | # Removing examples that share the considered question's db_id prevents selecting the question itself when it appears in the annotated data
145 |
146 | selected_indexes = random.sample(range(0, len(examples_in_level)), level_shot_number) # randomly select level_shot_number examples for this level
147 | for ind in selected_indexes:
148 | current_question_info_dict = examples_in_level[ind]
149 | curr_q_db_id = current_question_info_dict['db_id']
150 | db_path = bird_sql_path + f"/{mode}/{mode}_databases/{curr_q_db_id}/{curr_q_db_id}.sqlite"
151 | curr_sql = current_question_info_dict['SQL']
152 |
153 | if schema_existance:
154 | sql_schema_dict = extract_sql_columns(db_path, curr_sql)
155 | schema = generate_schema_from_schema_dict(db_path, sql_schema_dict) # use the filtered schema in the exemplar, since the schema given to the model is also a filtered one
156 | few_shot_exemplars = few_shot_exemplars + "Database Schema: \n" + schema + '\n'
157 |
158 | few_shot_exemplars = few_shot_exemplars + "Question: " + current_question_info_dict['question'] + "\n"
159 | few_shot_exemplars = few_shot_exemplars + "Evidence: " + current_question_info_dict['evidence'] + "\n"
160 | few_shot_exemplars = few_shot_exemplars + "SQL: " + current_question_info_dict['SQL'] + "\n\n"
161 |
162 | return few_shot_exemplars
163 |
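# A hedged usage sketch; the few-shot path matches the repository layout,
# the database ID is illustrative, and BIRD_DB_PATH is expected in the environment:
#
#   examples = sql_generation_and_refinement_few_shot_prep(
#       few_shot_data_path="./few-shot-data/question_enrichment_few_shot_examples.json",
#       q_db_id="california_schools",
#       level_shot_number=3,
#       schema_existance=False,
#       mode="dev",
#   )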
164 |
165 | def fill_candidate_sql_prompt_template(template: str, schema: str, db_samples: str, question: str, few_shot_examples: str = "", evidence: str = "", db_descriptions: str = "") -> str:
166 | """
167 | The function completes the prompt template by filling the necessary slots, which are few_shot_examples, schema, db_samples, question, evidence and db_descriptions
168 |
169 | Arguments:
170 | template (str): The template that is going to be filled
171 | schema (str): The schema of the database to which the considered question belongs
172 | db_samples (str): Sample rows from the database
173 | question (str): The considered question
174 | few_shot_examples (str): few-shot examples that are injected into the prompt
175 | evidence (str): Given evidence statement, if it exists
176 | db_descriptions (str): Question-relevant database item (column) descriptions
177 |
178 | Returns:
179 | prompt (str): Completed prompt for candidate SQL generation
179 | """
180 | if evidence == '' or evidence == None:
181 | evidence = '\n### Evidence: No evidence'
182 | else:
183 | evidence = f"\n### Evidence: \n {evidence}"
184 |
185 | if few_shot_examples == '' or few_shot_examples == None:
186 | few_shot_examples = ""
187 | else:
188 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}"
189 |
190 | schema = "\n### Database Schema: \n\n" + schema
191 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions
192 | db_samples = "\n### Database Samples: \n\n" + db_samples
193 | question = "\n### Question: \n" + question
194 |
195 | prompt = template.format(
196 | FEWSHOT_EXAMPLES = few_shot_examples,
197 | SCHEMA = schema,
198 | DB_SAMPLES = db_samples,
199 | QUESTION = question,
200 | EVIDENCE = evidence,
201 | DB_DESCRIPTIONS = db_descriptions
202 | )
203 |
204 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}")
205 | return prompt
206 |
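# A hedged sketch with a toy template; the real template is read from
# prompt_templates/candidate_sql_generation_prompt_template.txt and uses the
# same placeholder names required by the format() call above:
#
#   template = "{FEWSHOT_EXAMPLES}{SCHEMA}{DB_SAMPLES}{QUESTION}{EVIDENCE}{DB_DESCRIPTIONS}"
#   prompt = fill_candidate_sql_prompt_template(
#       template, schema="CREATE TABLE schools (County TEXT);", db_samples="",
#       question="How many schools are there?", evidence="", db_descriptions="",
#   )
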
207 | def extract_question_enrichment_prompt_template(enrichment_template_path = "../prompt_templates/question_enrichment_prompt_template.txt") -> str:
208 | """
209 | The function returns the question enrichment prompt template by reading the corresponding txt file
210 |
211 | Arguments:
212 | enrichment_template_path (str): path to the question enrichment prompt template txt file
213 | Returns:
214 | enrichment_prompt_template (str): Prompt template for the question enrichment prompt
215 | """
216 | with open(enrichment_template_path, 'r') as f:
217 | enrichment_prompt_template = f.read()
218 |
219 | return enrichment_prompt_template
220 |
221 |
222 | def question_enrichment_few_shot_prep(few_shot_data_path: str, q_id: int, q_db_id: str, level_shot_number: int, schema_existance: bool, enrichment_level: str, mode: str) -> str:
223 | """
224 | The function selects the given number of examples from the question enrichment few-shot data, and then
225 | concatenates the selected examples in string format.
226 | - The few-shot examples are selected from the set of data that contains databases different from the considered question's database.
227 | - If the question is already annotated, the few-shot examples are selected from the set of data in which the given question is excluded.
228 | - Level shot number can be between 0 and 10.
229 |
230 | Arguments:
231 | few_shot_data_path (str): path to the file in which the few-shot data exists
232 | q_id (int): id of the question
233 | q_db_id (str): database ID (database name)
234 | level_shot_number (int): number of examples to add to the prompt for each difficulty level
235 | schema_existance (bool): Whether the schema will be provided for the exemplars in the prompt. If it is True, then the schema will be provided.
236 | enrichment_level (str): Either "basic" or "complex" for selecting enriched questions
237 | mode (str): dev mode or test mode
238 | Returns:
239 | few_shot_exemplars (str): selected and concatenated exemplars for the prompt
240 | """
241 | bird_sql_path = os.getenv('BIRD_DB_PATH')
242 |
243 | if level_shot_number == 0:
244 | return ""
245 |
246 | few_shot_exemplars = ""
247 | # Check level_shot_number
248 | if level_shot_number < 0 or level_shot_number > 10:
249 | raise ValueError("Invalid few-shot number. The level_shot_number should be between 0 and 10")
250 |
251 | # Check schema_existance
252 | if not isinstance(schema_existance, bool):
253 | raise ValueError("Invalid value for schema_existance variable, it is not a boolean. It should be either True or False.")
254 |
255 | # Check enrichment_level and set enrichment_label
256 | if enrichment_level == "basic":
257 | enrichment_label = "question_enriched"
258 | elif enrichment_level == "complex":
259 | enrichment_label = "question_enriched_v2"
260 | else:
261 | raise ValueError("Invalid value for enrichment_level. The variable must be either 'basic' or 'complex'.")
262 |
263 | # Check mode
264 | if mode not in ['test', 'dev']:
265 | raise ValueError("Invalid value for mode. The variable must be either 'dev' or 'test'.")
266 |
267 |
268 | # Get all few-shot exemples
269 | all_few_shot_data = load_few_shot_data(few_shot_data_path=few_shot_data_path)
270 | # Set difficulty levels
271 | levels = ['simple', 'moderate', 'challanging']
272 |
273 |
274 | for level in levels:
275 | examples_in_level = all_few_shot_data[level]
276 | selected_indexes = []
277 | if mode == "dev":
278 | # remove the annotated questions if their db_id is the same with the considered question's db_id
279 | examples_in_level_tmp = []
280 | for example in examples_in_level:
281 | if q_db_id != example['db_id']:
282 | examples_in_level_tmp.append(example)
283 | examples_in_level = examples_in_level_tmp
284 | # Removing examples that share the considered question's db_id prevents selecting the question itself when it appears in the annotated data
285 |
286 | selected_indexes = random.sample(range(0, len(examples_in_level)), level_shot_number) # randomly select level_shot_number examples for this level
287 | for ind in selected_indexes:
288 | current_question_info_dict = examples_in_level[ind]
289 |
290 | if schema_existance:
291 | curr_q_db_id = current_question_info_dict['db_id']
292 | db_path = bird_sql_path + f"/{mode}/{mode}_databases/{curr_q_db_id}/{curr_q_db_id}.sqlite"
293 | schema = get_schema(db_path)
294 | few_shot_exemplars = few_shot_exemplars + "Database Schema: \n" + schema + '\n'
295 |
296 | few_shot_exemplars = few_shot_exemplars + "Question: " + current_question_info_dict['question'] + "\n"
297 | few_shot_exemplars = few_shot_exemplars + "Evidence: " + current_question_info_dict['evidence'] + "\n"
298 | few_shot_exemplars = few_shot_exemplars + "Enrichment Reasoning: " + current_question_info_dict['enrichment_reasoning'] + "\n"
299 | few_shot_exemplars = few_shot_exemplars + "Enriched Question: " + current_question_info_dict[enrichment_label] + "\n\n"
300 |
301 | return few_shot_exemplars
302 |
303 |
304 | def fill_question_enrichment_prompt_template(template: str, schema: str, db_samples: str, question: str, possible_conditions: str, few_shot_examples: str, evidence: str, db_descriptions: str):
305 | """
306 | The function completes the enrichment prompt template by filling the necessary slots, which are schema, db_samples, question, possible_conditions, few_shot_examples, evidence and db_descriptions
307 |
308 | Arguments:
309 | template (str): The template that is going to be filled
310 | schema (str): The schema of the database to which the considered question belongs
311 | db_samples (str): Sample rows from the database
312 | question (str): The considered question that is going to be enriched
313 | possible_conditions (str): Possible conditions extracted from the possible SQLite SQL query for the question
314 | few_shot_examples (str): few-shot examples that are injected into the prompt
315 | evidence (str): Given evidence statement, if it exists
316 | db_descriptions (str): Question-relevant database item (column) descriptions
316 |
317 | Returns:
318 | prompt (str): Completed prompt for question enrichment
319 | """
320 | if evidence == '' or evidence == None:
321 | evidence = '\n### Evidence: No evidence'
322 | else:
323 | evidence = f"\n### Evidence: \n {evidence}"
324 |
325 | if few_shot_examples == '' or few_shot_examples == None:
326 | few_shot_examples = ""
327 | else:
328 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}"
329 |
330 | schema = "\n### Database Schema: \n\n" + schema
331 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions
332 | db_samples = "\n### Database Samples: \n\n" + db_samples
333 | question = "\n### Question: \n" + question
334 |
335 | if possible_conditions:
336 | possible_conditions = "\n### Possible SQL Conditions: \n" + possible_conditions
337 | else:
338 | possible_conditions = "\n### Possible SQL Conditions: No strict conditions were found. Please consider the database schema and keywords while enriching the Question."
339 |
340 |
341 | prompt = template.format(
342 | FEWSHOT_EXAMPLES = few_shot_examples,
343 | SCHEMA = schema,
344 | DB_SAMPLES = db_samples,
345 | QUESTION = question,
346 | EVIDENCE = evidence,
347 | DB_DESCRIPTIONS = db_descriptions,
348 | POSSIBLE_CONDITIONS = possible_conditions
349 | )
350 |
351 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}")
352 | return prompt
353 |
354 |
355 |
356 | def fill_refinement_prompt_template(template: str, schema: str, possible_conditions: str, question: str, possible_sql: str, exec_err: str, few_shot_examples: str = "", evidence: str = "", db_descriptions: str = "") -> str:
357 | """
358 | The function completes the prompt template by filling the necessary slots, which are few_shot_examples, schema, question, evidence, db_descriptions, possible_sql, execution error and possible_conditions
359 |
360 | Arguments:
361 | template (str): The template that is going to be filled
362 | schema (str): The schema of the database to which the considered question belongs
363 | possible_conditions (str): Possible conditions extracted from the database for the question
364 | question (str): The considered question
365 | possible_sql (str): Possible SQLite SQL query for the question
366 | exec_err (str): The execution error raised when the possible SQL query is executed
367 | few_shot_examples (str): few-shot examples that are injected into the prompt
368 | evidence (str): Given evidence statement, if it exists
369 | db_descriptions (str): Question-relevant database item (column) descriptions
370 |
371 | Returns:
372 | prompt (str): Completed prompt for SQL refinement
372 | """
373 | if evidence == '' or evidence == None:
374 | evidence = '\n### Evidence: No evidence'
375 | else:
376 | evidence = f"\n### Evidence: \n {evidence}"
377 |
378 | if few_shot_examples == '' or few_shot_examples == None:
379 | few_shot_examples = ""
380 | else:
381 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}"
382 |
383 | schema = "\n### Database Schema: \n\n" + schema
384 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions
385 | question = "\n### Question: \n" + question
386 | possible_sql = "\n### Possible SQLite SQL Query: \n" + possible_sql
387 | if possible_conditions:
388 | possible_conditions = "\n### Possible SQL Conditions: \n" + possible_conditions
389 | else:
390 | possible_conditions = "\n### Possible SQL Conditions: No strict conditions were found. Please consider the database schema and keywords in the question while generating the SQL."
391 | if exec_err:
392 | exec_err = "\n### Execution Error of Possible SQL Query Above: \n" + exec_err + "\n While generating new SQLite SQL query, consider this execution error and make sure newly generated SQL query runs without execution error."
393 | else:
394 | exec_err = ""
395 |
396 |
397 | prompt = template.format(
398 | FEWSHOT_EXAMPLES = few_shot_examples,
399 | SCHEMA = schema,
400 | QUESTION = question,
401 | EVIDENCE = evidence,
402 | DB_DESCRIPTIONS = db_descriptions,
403 | POSSIBLE_SQL_Query = possible_sql,
404 | EXECUTION_ERROR = exec_err,
405 | POSSIBLE_CONDITIONS = possible_conditions
406 | )
407 |
408 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}")
409 | return prompt
410 |
411 |
412 | def schema_filtering_few_shot_prep(few_shot_data_path: str, q_db_id: str, level_shot_number: int, schema_existance: bool, mode: str) -> str:
413 | """
414 | The function selects the given number of examples from the few-shot data, determines the filtered schema of the questions in the few-shot data, and then concatenates the selected examples in string format.
415 | - The few-shot examples are selected from the set of data that contains databases different from the considered question's database.
416 | - If the question is already annotated, the few-shot examples are selected from the set of data in which the given question is excluded.
417 | - Level shot number can be between 0 and 10.
418 |
419 | Arguments:
420 | few_shot_data_path (str): few-shot data path
421 | q_db_id (str): database ID (database name) of the considered question
422 | level_shot_number (int): number of examples to add to the prompt for each difficulty level
423 | schema_existance (bool): Whether the schema will be provided for the exemplars in the prompt. If it is True, then the schema will be provided.
424 | mode (str): dev mode or test mode
425 | Returns:
426 | few_shot_exemplars (str): selected and concatenated exemplars for the prompt
427 | """
428 | bird_sql_path = os.getenv('BIRD_DB_PATH', '../../dataset/bird-sql')
429 |
430 | if level_shot_number == 0:
431 | return ""
432 |
433 | few_shot_exemplars = ""
434 | # Check level_shot_number
435 | if level_shot_number < 0 or level_shot_number > 10:
436 | raise ValueError("Invalid few-shot number. The level_shot_number should be between 0 and 10")
437 |
438 | # Check schema_existance
439 | if not isinstance(schema_existance, bool):
440 | raise TypeError("Invalid value for schema_existance; it must be a boolean.")
441 |
442 | # Check mode
443 | if mode not in ['test', 'dev']:
444 | raise ValueError("Invalid value for mode. The variable must be either 'dev' or 'test'.")
445 |
446 | # Get all few-shot exemples
447 | all_few_shot_data = load_few_shot_data(few_shot_data_path=few_shot_data_path)
448 | # Set difficulty levels
449 | levels = ['simple', 'moderate', 'challanging']
450 |
451 | for level in levels:
452 | examples_in_level = all_few_shot_data[level]
453 | selected_indexes = []
454 | if mode == "dev":
455 | # remove the annotated questions if their db_id is the same with the considered question's db_id
456 | examples_in_level_tmp = []
457 | for example in examples_in_level:
458 | if q_db_id != example['db_id']:
459 | examples_in_level_tmp.append(example)
460 | examples_in_level = examples_in_level_tmp
461 | # Removing examples that share the considered question's db_id prevents selecting the question itself when it appears in the annotated data
462 |
463 | selected_indexes = random.sample(range(0, len(examples_in_level)), level_shot_number) # randomly select level_shot_number examples for this level
464 | for ind in selected_indexes:
465 | current_question_info_dict = examples_in_level[ind]
466 | curr_q_db_id = current_question_info_dict['db_id']
467 | db_path = bird_sql_path + f"/{mode}/{mode}_databases/{curr_q_db_id}/{curr_q_db_id}.sqlite"
468 | sql = current_question_info_dict['SQL']
469 |
470 |
471 | if schema_existance:
472 | schema = get_schema(db_path)
473 | few_shot_exemplars = few_shot_exemplars + "Database Schema: \n" + schema + '\n'
474 |
475 | # print("curent directory", os.getcwd())
476 | few_shot_exemplars = few_shot_exemplars + "Question: " + current_question_info_dict['question'] + "\n"
477 | few_shot_exemplars = few_shot_exemplars + "Evidence: " + current_question_info_dict['evidence'] + "\n"
478 | filtered_schema = extract_sql_columns(db_path, sql)
479 | few_shot_exemplars = few_shot_exemplars + "Filtered Database Schema: \n" + str(filtered_schema) + "\n"
480 |
481 | return few_shot_exemplars
482 |
483 |
484 | def fill_prompt_template(template: str, schema: str, db_samples: str, question: str, few_shot_examples: str = "", evidence: str = "", db_descriptions: str = "") -> str:
485 | """
486 | The function completes the prompt template by filling the necessary slots, which are few_shot_examples, schema, db_samples, question, evidence and db_descriptions
487 |
488 | Arguments:
489 | template (str): The template that is going to be filled
490 | schema (str): The schema of the database to which the considered question belongs
491 | db_samples (str): Sample rows from the database
492 | question (str): The considered question
493 | few_shot_examples (str): few-shot examples that are injected into the prompt
494 | evidence (str): Given evidence statement, if it exists
495 | db_descriptions (str): Question-relevant database item (column) descriptions
496 |
497 | Returns:
498 | prompt (str): Completed prompt
498 | """
499 | if evidence == '' or evidence == None:
500 | evidence = '\n### Evidence: No evidence'
501 | else:
502 | evidence = f"\n### Evidence: \n {evidence}"
503 |
504 | if few_shot_examples == '' or few_shot_examples == None:
505 | few_shot_examples = ""
506 | else:
507 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}"
508 |
509 | schema = "\n### Database Schema: \n\n" + schema
510 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions
511 | db_samples = "\n### Database Samples: \n\n" + db_samples
512 | question = "\n### Question: \n" + question
513 |
514 | prompt = template.format(
515 | FEWSHOT_EXAMPLES = few_shot_examples,
516 | SCHEMA = schema,
517 | DB_SAMPLES = db_samples,
518 | QUESTION = question,
519 | EVIDENCE = evidence,
520 | DB_DESCRIPTIONS = db_descriptions
521 | )
522 |
523 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}")
524 | # print("The prompt: \n", prompt)
525 | return prompt
--------------------------------------------------------------------------------
/utils/retrieval_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import string
4 | import logging
5 | import numpy as np
6 | import pandas as pd
7 | import nltk
8 | from nltk.corpus import stopwords
9 | from nltk.tokenize import word_tokenize
10 | from rank_bm25 import BM25Okapi
11 | from typing import List
12 |
13 |
14 | def nltk_downloads():
15 | nltk.download('stopwords') # Download the stopwords
16 | nltk.download('punkt') # Download the punkt tokenizer
17 | nltk.download('punkt_tab') # Download the punkt_tab
18 | return
19 |
20 |
21 | def save_dataframe_to_csv(df: pd.DataFrame, path: str):
22 | """
23 | Saves the given pandas DataFrame to a CSV file at the specified path.
24 |
25 | Arguments:
26 | df (pd.DataFrame): The DataFrame to save.
27 | path (str): The file path where the CSV file will be saved.
28 | """
29 | try:
30 | df.to_csv(path, index=False) # Set index=False to avoid saving the index column
31 | print(f"DataFrame saved successfully to {path}")
32 | except Exception as e:
33 | logging.error(f"An error occurred while saving the DataFrame: {e}")
34 |
35 |
36 | def clean_text(textData: str)-> str:
37 | """
38 | The function process the given textData by removing stop words, removing punctuation marks and lowercasing the textData.
39 |
40 | Arguments:
41 | textData (str): text to be cleaned
42 |
43 | Returns:
44 | processedTextData (str): cleaned text
45 | """
46 |
47 | if isinstance(textData, str):
48 | textData = textData.lower()
49 | textData = " ".join(textData.split()) # normalize whitespace instead of removing all spaces, so tokenization is not broken
50 |
51 | # Removing punctuations
52 | # textData = textData.translate(str.maketrans('', '', string.punctuation)) # converts "don't" to "dont"
53 | textData = textData.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))) # converts "don't" to "don t"
54 |
55 | # Removing stopwords
56 | stopWordsSet = set(stopwords.words('english'))
57 | tokens = word_tokenize(textData)
58 | tokens = [token for token in tokens if not token.lower() in stopWordsSet]
59 |
60 | processedTextData = ' '.join(tokens)
61 | return processedTextData
62 | else:
63 | # if the text data is NaN return empty string
64 | return ''
65 |
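# A hedged example; the exact output depends on the NLTK stopword list:
#
#   clean_text("What is the highest eligible free rate?")
#   # -> roughly "highest eligible free rate" (lowercased, punctuation and
#   #    stopwords removed)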
66 |
67 | def construct_column_information(table_desc_df: pd.DataFrame, table_name: str) -> pd.Series:
68 | """
69 | The function combines the original column name, column description, data format, and value description information from table description CSV files into a single descriptive string for each column.
70 |
71 | Arguments:
72 | table_desc_df (pd.DataFrame): DataFrame containing table descriptions.
73 | table_name (str): Name of the table.
74 |
75 | Returns:
76 | pd.Series: constructed single text column information for each column
77 | """
78 | # Function to build column info for each row
79 | def build_column_info(row):
80 | column_info = f"The information about the {row['original_column_name']} column of the {table_name} table [{table_name}.{row['original_column_name']}] is as following."
81 |
82 | if pd.notna(row['column_description']):
83 | column_info += f" The {row['original_column_name']} column can be described as {row['column_description']}."
84 | if pd.notna(row['value_description']):
85 | column_info += f" The value description for the {row['original_column_name']} is {row['value_description']}"
86 |
87 | # Collapse runs of double spaces left over while building the string
88 | column_info = column_info.replace("  ", ' ')
89 | column_info = column_info.replace("  ", ' ')
89 | return column_info
90 |
91 | # Apply the function to create the "column_info" column
92 | # table_desc_df['column_info'] = table_desc_df.apply(build_column_info, axis=1)
93 | column_info_series = table_desc_df.apply(build_column_info, axis=1)
94 |
95 | return column_info_series
96 |
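# A hedged usage sketch; the CSV path and table name are hypothetical:
#
#   desc_df = pd.read_csv("schools.csv")  # one table description CSV
#   column_info_series = construct_column_information(desc_df, "schools")
#   # each entry reads like: "The information about the County column of the
#   # schools table [schools.County] is as follows. ..."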
97 |
98 | def process_database_descriptions(database_description_path: str):
99 | """
100 | Processes multiple CSV files in the given directory, applies the pre-existing construct_column_information function to each,
101 | and combines the "column_info" columns into a single DataFrame which is then saved as db_description.csv.
102 |
103 | Arguments:
104 | database_description_path (str): Path to the directory containing database description CSV files.
105 | """
106 |
107 | # List to store column_info from each file
108 | all_column_infos = []
109 |
110 | # Iterate over each file in the directory
111 | for filename in os.listdir(database_description_path):
112 | if filename.endswith(".csv") and filename != "db_description.csv":
113 | print(f"------> {filename} table start to be processed.")
114 | file_path = os.path.join(database_description_path, filename)
115 | try:
116 | df = pd.read_csv(file_path)
117 | except Exception: # retry with a legacy encoding when UTF-8 decoding fails
118 | df = pd.read_csv(file_path, encoding='ISO-8859-1')
119 |
120 | table_name = filename.replace('.csv', '')
121 | column_info_series = construct_column_information(df, table_name)
122 | # Convert the Series to a DataFrame with a single column named 'column_info'
123 | column_info_df = column_info_series.to_frame(name='column_info')
124 | all_column_infos.append(column_info_df)
125 |
126 | # Combine all column_info data into a single DataFrame
127 | all_info_df = pd.concat(all_column_infos, ignore_index=True)
128 |
129 | # Save the DataFrame to a CSV file
130 | output_path = os.path.join(database_description_path, 'db_description.csv')
131 | all_info_df.to_csv(output_path, index=False)
132 | print(f"---> Database information saved successfully to {output_path}")
133 |
134 | return
135 |
136 |
137 |
138 | def process_all_dbs(dataset_path: str, mode: str):
139 | """
140 | The function processes the descriptions of all databases and constructs a db_description.csv file for each database.
141 |
142 | Arguments:
143 | dataset_path (str): General dataset path
144 | mode (str): Either dev, test or train
145 | """
146 | nltk_downloads() # download nltk stop words
147 | databases_path = dataset_path + f"/{mode}/{mode}_databases"
148 |
149 | for db_directory in os.listdir(databases_path):
150 | #thank you, Apple inc. for creating .DS_Store files
151 | if db_directory == ".DS_Store":
152 | continue
153 | print(f"----------> Start to process {db_directory} database.")
154 | db_description_path = databases_path + "/" + db_directory + "/database_description"
155 | process_database_descriptions(database_description_path=db_description_path)
156 |
157 | print("\n\nAll databases are processed and db_description.csv files are created.\n\n")
158 | return
159 |
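# A hedged usage sketch; the dataset path is hypothetical and must contain
# the {mode}/{mode}_databases layout used above:
#
#   process_all_dbs(dataset_path="./dataset/bird-sql", mode="dev")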
160 |
161 | def get_relevant_db_descriptions(database_description_path: str, question: str, relevant_description_number: int = 6) -> List[str]:
162 | """
163 | The function returns a list of the relevant column descriptions, ranked by BM25
164 |
165 | Arguments:
166 | database_description_path (str): Path to the directory containing database description CSV files.
167 | question (str): the considered natural language question
168 | relevant_description_number (int): number of top ranked column descriptions
169 |
170 | Returns:
171 | relevant_db_descriptions (List[str]): List of relevant column descriptions.
172 | """
173 | db_description_csv_path = database_description_path + "/db_description.csv"
174 |
175 | if not os.path.exists(db_description_csv_path):
176 | process_database_descriptions(database_description_path)
177 |
178 | # Read the database description db_description.csv file
179 | db_desc_df = pd.read_csv(db_description_csv_path)
180 | db_description_corpus = db_desc_df['column_info'].tolist()
181 | db_description_corpus_cleaned = [clean_text(description) for description in db_description_corpus]
182 |
183 | # Tokenize corpus and create bm25 instance using cleaned corpus
184 | tokenized_db_description_corpus_cleaned = [doc.split(" ") for doc in db_description_corpus_cleaned]
185 | bm25 = BM25Okapi(tokenized_db_description_corpus_cleaned)
186 |
187 | # Tokenize question
188 | tokenized_question = question.split(" ")
189 |
190 | relevant_db_descriptions = bm25.get_top_n(tokenized_question, db_description_corpus, n=relevant_description_number)
191 | return relevant_db_descriptions
192 |
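# A hedged usage sketch; the description path is hypothetical:
#
#   descriptions = get_relevant_db_descriptions(
#       database_description_path="dev/dev_databases/california_schools/database_description",
#       question="What is the highest eligible free rate?",
#       relevant_description_number=6,
#   )
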
193 | def get_db_column_meanings(database_column_meaning_path: str, db_id: str) -> List[str]:
194 | """
195 | The function extracts required database column meanings.
196 |
197 | Arguments:
198 | database_column_meaning_path (str): path to the column_meaning.json
199 | db_id (str): name of the database whose columns' meanings will be extracted
200 |
201 | Returns:
202 | List[str]: A list of strings explaining the database column meanings.
203 | """
204 | # Load the JSON file
205 | with open(database_column_meaning_path, 'r') as file:
206 | column_meanings = json.load(file)
207 |
208 | # Initialize a list to store the extracted meanings
209 | meanings = []
210 |
211 | # Iterate through each key in the JSON
212 | for key, explanation in column_meanings.items():
213 | # Check if the key starts with the given db_id
214 | if key.startswith(db_id + "|"):
215 | # Extract the table name and column name from the key
216 | _, table_name, column_name = key.split("|")
217 | # Construct the meaning string in the desired format
218 | meaning = f"# Meaning of {column_name} column of {table_name} table in database is that {explanation.strip('# ').strip()}"
219 | meanings.append(meaning)
220 |
221 | return meanings
222 |
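# A hedged example with a hypothetical column_meaning.json entry keyed as
# "california_schools|schools|County":
#
#   meanings = get_db_column_meanings("column_meaning.json", "california_schools")
#   # -> ["# Meaning of County column of schools table in database is that ..."]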
--------------------------------------------------------------------------------