├── .gitignore ├── README.md ├── assets ├── CSG_module_prompt_template.png ├── QE_module_prompt_template.png ├── SF_module_prompt_template.png ├── SR_module_prompt_template.png ├── e-sql-flowchart_qid_1448.png └── e-sql-pipeline.png ├── env.example ├── evaluation ├── evaluation.py ├── evaluation_ex.py ├── evaluation_f1.py ├── evaluation_utils.py └── evaluation_ves.py ├── few-shot-data └── question_enrichment_few_shot_examples.json ├── main.py ├── pipeline └── Pipeline.py ├── prompt_templates ├── candidate_sql_generation_prompt_template.txt ├── question_enrichment_prompt_template.txt ├── schema_filter_prompt_template.txt └── sql_refinement_prompt_template.txt ├── requirements.txt ├── run_evaluation.sh ├── run_main.sh └── utils ├── __init__.py ├── db_utils.py ├── openai_utils.py ├── prompt_utils.py └── retrieval_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | __pycache__/ 3 | *.py[cod] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # E-SQL: Direct Schema Linking via Question Enrichment in Text-to-SQL 2 | 3 | This is the official repository for the paper **"E-SQL: Direct Schema Linking via Question Enrichment in Text-to-SQL"**. 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://www.arxiv.org/abs/2409.16751) [![Bibtex](https://img.shields.io/badge/Cite-BibTeX-orange)](#citation) 6 | 7 | ## Overview 8 | 9 | Translating natural language queries into SQL (Text-to-SQL) is a critical task for enabling natural language interfaces to databases (NLIDB), but challenges such as complex schemas and ambiguous queries often hinder the accuracy of generated SQL. **E-SQL** addresses these challenges through a novel pipeline that directly links relevant database schema items with the natural language query, a method we call **Question Enrichment**. 10 | 11 | E-SQL enriches natural language questions by incorporating relevant database elements—such as tables, columns, and potential conditions—directly into the query to enhance SQL generation accuracy. Our pipeline also introduces **Candidate Predicate Generation** to further reduce SQL errors caused by incomplete or incorrect predicates. 12 | 13 |
14 | 
15 | ### Key modules of the E-SQL pipeline:
16 | - **Candidate SQL Generation (CSG)**: Generates initial SQL queries based on the natural language question.
17 | - **Candidate Predicate Generation (CPG)**: Extracts and incorporates likely predicates from the database.
18 | - **Question Enrichment (QE)**: Enhances the natural language query by linking relevant database items and conditions.
19 | - **SQL Refinement (SR)**: Refines generated SQL queries by correcting minor errors and ensuring execution correctness.
20 | 
21 | While schema filtering has been widely adopted in previous research, our experiments show that it can lead to performance degradation when used alongside advanced large language models (LLMs). Instead, direct schema linking through question enrichment and candidate predicate augmentation proves to be more effective, particularly on complex queries.
22 | 
23 | 
24 | The E-SQL execution flow for the question with question ID 1448 in the development set is given below.
25 | ![E-SQL execution flow for question 1448](assets/e-sql-flowchart_qid_1448.png)
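In code, each BIRD question is routed through these modules according to the `pipeline_order` argument. The sketch below mirrors the dispatch in `main.py`; `args` and `t2s_object` are placeholders for the parsed command-line arguments and a single question entry from `dev.json`/`test.json`.

```python
# Minimal sketch of the pipeline dispatch in main.py (args and t2s_object are placeholders).
from pipeline.Pipeline import Pipeline

pipeline = Pipeline(args)  # args.pipeline_order is e.g. "CSG-QE-SR"

if pipeline.pipeline_order == "CSG-SR":
    prediction = pipeline.forward_pipeline_CSG_SR(t2s_object)
elif pipeline.pipeline_order == "CSG-QE-SR":
    prediction = pipeline.forward_pipeline_CSG_QE_SR(t2s_object)
elif pipeline.pipeline_order == "SF-CSG-QE-SR":
    prediction = pipeline.forward_pipeline_SF_CSG_QE_SR(t2s_object)
```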
26 | 
27 | This repository contains all the code for implementing and evaluating **E-SQL** on the BIRD development set, utilizing models like GPT-4o and GPT-4o-mini to achieve competitive performance.
28 | 
29 | 
30 | 
31 | ## Project Structure
32 | To avoid any potential file path errors, organize your project files as follows:
33 | 
34 | ```text
35 | nlq-to-sql/
36 | ├── dataset/
37 | │   ├── bird-sql/
38 | │   │   ├── dev/
39 | │   │   │   ├── dev_databases/
40 | │   │   │   ├── column_meaning.json
41 | │   │   │   ├── dev_gold.sql
42 | │   │   │   ├── dev_tables.json
43 | │   │   │   ├── dev_tied_append.json
44 | │   │   │   ├── dev.json
45 | │   │   ├── test/
46 | │   │   │   ├── test_databases/
47 | │   │   │   ├── column_meaning.json
48 | │   │   │   ├── test_gold.sql
49 | │   │   │   ├── test_tables.json
50 | │   │   │   ├── test.json
51 | ├── E-SQL/
52 | │   ├── evaluation/
53 | │   │   ├── evaluation_ves.py
54 | │   │   └── evaluation.py
55 | │   ├── few-shot-data/
56 | │   │   └── question_enrichment_few_shot_examples.json
57 | │   ├── pipeline/
58 | │   │   └── Pipeline.py
59 | │   ├── prompt_templates/
60 | │   │   ├── candidate_sql_generation_prompt_template.txt
61 | │   │   ├── question_enrichment_prompt_template.txt
62 | │   │   └── sql_refinement_prompt_template.txt
63 | │   ├── results/
64 | │   │   ├── model_outputs_dev_CSG-QE-SR_gpt-4o-2024-08-06/
65 | │   │   ├── model_outputs_dev_CSG-QE-SR_gpt-4o-mini-2024-07-18/
66 | │   ├── utils/
67 | │   │   ├── __init__.py
68 | │   │   ├── db_utils.py
69 | │   │   ├── openai_utils.py
70 | │   │   ├── prompt_utils.py
71 | │   │   └── retrieval_utils.py
72 | │   ├── .env
73 | │   ├── .gitignore
74 | │   ├── env.example
75 | │   ├── main.py
76 | │   ├── README.md
77 | │   ├── requirements.txt
78 | │   ├── run_evaluation.sh
79 | │   ├── run_main.sh
80 | 
81 | ```
82 | Hint: you can generate `column_meaning.json` for the dev and test datasets with TA-SQL: https://github.com/quge2023/TA-SQL
83 | 
84 | ## Setting up the Environment
85 | 1. **Download the compressed code.**
86 | 2. **Create a `.env` file and set the environment variables as follows:** In the root directory of E-SQL, add the following items to the `.env` file and ensure that your environment variables are correctly configured. One critical point is to set the `BIRD_DB_PATH` environment variable (the variable read by the code, as in `env.example`) to the root directory of the BIRD dataset, which contains both the dev and test directories. This is critical for the proper functioning of the code, as the database paths are constructed dynamically based on the mode (dev or test). To avoid any potential file path errors, we strongly recommend organizing your project files as described above.
87 | 
88 | ```
89 | BIRD_DB_PATH="../dataset/bird-sql"
90 | OPENAI_API_KEY=
91 | ```
92 | 
93 | 3. **Install the required packages:** Use the following command to install the required packages:
94 | ```
95 | pip install -r requirements.txt
96 | ```
97 | 
98 | 
99 | ## Running the Code
100 | 1. **Update the `run_main.sh` file to change the running mode or the OpenAI model:** In the `run_main.sh` file, set the mode and model arguments. Do not change the other arguments in `run_main.sh`.
101 | 
102 | ```bash
103 | mode='test' # Update this with the mode you want to run. Write either 'dev' or 'test'
104 | model="gpt-4o-2024-08-06" # For GPT-4o use "gpt-4o-2024-08-06". For GPT-4o-mini use "gpt-4o-mini-2024-07-18".
105 | pipeline_order="CSG-QE-SR"
106 | ...
107 | ```
108 | 
109 | 
110 | 
111 | 2. **Run the main script:** After moving to the E-SQL directory, run the main script as follows. On its first run, the main script processes all database CSV files that contain the table and column descriptions of the databases.
112 | ```
113 | cd E-SQL
114 | sh run_main.sh
115 | ```
116 | 
117 | 3. **Review the informative logs:** After running the main script, the database description files are processed and informative logs are printed. These logs provide detailed information about the processing steps and the status of each database description file. No additional effort is required during this step; the logs are for informational purposes only, to help you monitor the process.
118 | 
119 | 
120 | ## Evaluation
121 | 
122 | 1. **Setting up the Evaluation Arguments:** In order to evaluate the results, set the arguments in `run_evaluation.sh` as follows (replace the `{args.*}` placeholders with the mode, pipeline order, and model you used):
123 | 
124 | ```bash
125 | db_root_path='../dataset/bird-sql/dev/dev_databases/'
126 | data_mode='dev'
127 | diff_json_path='../dataset/bird-sql/dev/dev.json'
128 | predicted_sql_path_kg='./results/model_outputs_{args.mode}_{args.pipeline_order}_{args.model}/'
129 | ground_truth_path='../dataset/bird-sql/dev/'
130 | num_cpus=8
131 | meta_time_out=30.0
132 | mode_gt='gt'
133 | mode_predict='gpt'
134 | ```
135 | 
136 | 
137 | 
138 | 2. **Run the evaluation script** after the main script has finished and all predicted SQLs have been saved into the `predict_dev.json` or `predict_test.json` file under the `./results/model_outputs_{args.mode}_{args.pipeline_order}_{args.model}` directory. Once the previous step is completed, run the evaluation script as follows to print both the EX and VES metrics.
139 | ```
140 | sh run_evaluation.sh
141 | ```
142 | 
143 | 3. **Cost estimation**: The evaluation on the dev set involves approximately 40.25 million prompt tokens and 1.23 million output tokens, resulting in an estimated cost of $112 for GPT-4o and $8 for GPT-4o-mini.
144 | 
145 | ## Metrics
146 | 
147 | Finally, a `metrics.json` file is created under the `./results/model_outputs_{mode}_{pipeline_order}_{model}` directory so that the execution accuracies can be inspected. The `metrics.json` file format is as follows. Note that this file does not include the VES metric; for VES, run the evaluation script as explained in the previous section.
148 | 
149 | ```json
150 | "candidate_generation": {
151 |     "EX": "Overall Execution Accuracy",
152 |     "total_correct_count": "Total number of correctly predicted SQLs.",
153 |     "total_item_count": "Total item count",
154 |     "simple_stats": {
155 |         "correct_number": "The number of correctly predicted SQLs considered simple.",
156 |         "count": "The total number of simple questions.",
157 |         "ex": "Execution accuracy for the simple questions."
158 |     },
159 |     "moderate_stats": {
160 |         "correct_number": "The number of correctly predicted SQLs considered moderate.",
161 |         "count": "The total number of moderate questions.",
162 |         "ex": "Execution accuracy for the moderate questions."
163 |     },
164 |     "challenging_stats": {
165 |         "correct_number": "The number of correctly predicted SQLs considered challenging.",
166 |         "count": "The total number of challenging questions.",
167 |         "ex": "Execution accuracy for the challenging questions."
168 |     },
169 |     "fail_q_ids": [],
170 |     "config": {
171 |         "mode": "Either dev or test",
172 |         "model": "OpenAI model",
173 |         "temperature": 0.0,
174 |         "top_p": 1.0,
175 |         "max_tokens": 2048,
176 |         "n": 1,
177 |         "pipeline_order": "Selected Pipeline Order",
178 |         "enrichment_level": "complex",
179 |         "enrichment_level_shot_number": 3,
180 |         "enrichment_few_shot_schema_existance": false,
181 |         "filtering_level_shot_number": 3,
182 |         "filtering_few_shot_schema_existance": false,
183 |         "cfg": true,
184 |         "generation_level_shot_number": 3,
185 |         "generation_few_shot_schema_existance": false,
186 |         "db_sample_limit": 10,
187 |         "relevant_description_number": 20,
188 |         "seed": 42
189 |     }
190 | }
191 | ```
192 | 
193 | ## Additional Notes:
194 | 
195 | 1. **Column Meaning File Usage**: The `column_meaning.json` file is required in both dev and test modes.
196 | 
197 | 2. **Error Handling**: Error-handling mechanisms and logging are in place. However, in some cases the LLM may fail to generate a SQL query, causing an error that halts execution. When such an error occurs, the question ID at which the code stopped can be seen both in the terminal and in the `predictions.json` or `predict_dev.json` files located in the `results/model_outputs_{mode}_{pipeline_order}_{model}` directory. To restart, locate the error point and open `main.py`; there you will find a section labeled "In case of an error, you can restart the code from the point of error using the following line." Update the dataset slice accordingly so that the code resumes from where it left off instead of starting over from the beginning (see the snippet after these notes). Although the code appends new predicted SQLs and the corresponding objects to the `predict_dev.json` and `predictions.json` files, please copy both files under a different name after an error so that previously generated data is not lost.
198 | 
199 | 3. **SQL Parsing**: The generated SQL queries are parsed with the SQLGlot SQL parser. Parsing errors may arise, but they are handled with try-except statements. Warnings are logged when parsing issues occur; these can generally be ignored as long as no errors are thrown and the execution proceeds smoothly.
200 | 
201 | 4. **Gold SQL File Naming**: The evaluation script expects a file named `dev_gold.sql`. However, the corresponding file downloaded from the BIRD website is named `dev.sql`. To ensure compatibility with the evaluation script, copy the `dev.sql` file and rename the copy to `dev_gold.sql` (see the snippet after these notes).
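The snippet below makes Notes 2 and 4 concrete. It is only an illustration: the index `245` is a placeholder for the question index at which your run stopped, and the paths assume the directory layout shown in the Project Structure section.

```python
# (Note 2) In main.py, slice the dataset so the run resumes from the failed question.
# dataset = dataset[245:]   # 245 is a placeholder; use the question index printed before the error

# (Note 4) Copy dev.sql to the dev_gold.sql name expected by the evaluation script.
import shutil

shutil.copyfile("../dataset/bird-sql/dev/dev.sql", "../dataset/bird-sql/dev/dev_gold.sql")
```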
202 | 203 | 204 | # Citation 205 | 206 | If you find this repository helpful, please cite the following paper: 207 | 208 | ``` 209 | @misc{caferoğlu2024esqldirectschemalinking, 210 | title={E-SQL: Direct Schema Linking via Question Enrichment in Text-to-SQL}, 211 | author={Hasan Alp Caferoğlu and Özgür Ulusoy}, 212 | year={2024}, 213 | eprint={2409.16751}, 214 | archivePrefix={arXiv}, 215 | primaryClass={cs.CL}, 216 | url={https://arxiv.org/abs/2409.16751}, 217 | } 218 | ``` 219 | -------------------------------------------------------------------------------- /assets/CSG_module_prompt_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/CSG_module_prompt_template.png -------------------------------------------------------------------------------- /assets/QE_module_prompt_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/QE_module_prompt_template.png -------------------------------------------------------------------------------- /assets/SF_module_prompt_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/SF_module_prompt_template.png -------------------------------------------------------------------------------- /assets/SR_module_prompt_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/SR_module_prompt_template.png -------------------------------------------------------------------------------- /assets/e-sql-flowchart_qid_1448.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/e-sql-flowchart_qid_1448.png -------------------------------------------------------------------------------- /assets/e-sql-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/assets/e-sql-pipeline.png -------------------------------------------------------------------------------- /env.example: -------------------------------------------------------------------------------- 1 | BIRD_DB_PATH="../dataset/bird-sql" 2 | OPENAI_API_KEY= -------------------------------------------------------------------------------- /evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import argparse 4 | import sqlite3 5 | import multiprocessing as mp 6 | from func_timeout import func_timeout, FunctionTimedOut 7 | 8 | def load_json(dir): 9 | with open(dir, 'r') as j: 10 | contents = json.loads(j.read()) 11 | return contents 12 | 13 | def result_callback(result): 14 | exec_result.append(result) 15 | 16 | 17 | def execute_sql(predicted_sql,ground_truth, db_path): 18 | conn = sqlite3.connect(db_path) 19 | # Connect to the database 20 | cursor = conn.cursor() 21 | cursor.execute(predicted_sql) 22 | predicted_res = cursor.fetchall() 23 | cursor.execute(ground_truth) 24 | ground_truth_res 
= cursor.fetchall() 25 | res = 0 26 | if set(predicted_res) == set(ground_truth_res): 27 | res = 1 28 | return res 29 | 30 | 31 | 32 | def execute_model(predicted_sql,ground_truth, db_place, idx, meta_time_out): 33 | try: 34 | res = func_timeout(meta_time_out, execute_sql, 35 | args=(predicted_sql, ground_truth, db_place)) 36 | except KeyboardInterrupt: 37 | sys.exit(0) 38 | except FunctionTimedOut: 39 | result = [(f'timeout',)] 40 | res = 0 41 | except Exception as e: 42 | result = [(f'error',)] # possibly len(query) > 512 or not executable 43 | res = 0 44 | # print(result) 45 | # result = str(set([ret[0] for ret in result])) 46 | result = {'sql_idx': idx, 'res': res} 47 | # print(result) 48 | return result 49 | 50 | 51 | def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): 52 | clean_sqls = [] 53 | db_path_list = [] 54 | if mode == 'gpt': 55 | sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r')) 56 | for idx, sql_str in sql_data.items(): 57 | if type(sql_str) == str: 58 | sql, db_name = sql_str.split('\t----- bird -----\t') 59 | else: 60 | sql, db_name = " ", "financial" 61 | clean_sqls.append(sql) 62 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') 63 | 64 | elif mode == 'gt': 65 | sqls = open(sql_path + data_mode + '_gold.sql') 66 | sql_txt = sqls.readlines() 67 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt] 68 | for idx, sql_str in enumerate(sql_txt): 69 | sql, db_name = sql_str.strip().split('\t') 70 | clean_sqls.append(sql) 71 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') 72 | 73 | return clean_sqls, db_path_list 74 | 75 | def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0): 76 | pool = mp.Pool(processes=num_cpus) 77 | for i,sql_pair in enumerate(sqls): 78 | 79 | predicted_sql, ground_truth = sql_pair 80 | pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), callback=result_callback) 81 | pool.close() 82 | pool.join() 83 | 84 | def sort_results(list_of_dicts): 85 | return sorted(list_of_dicts, key=lambda x: x['sql_idx']) 86 | 87 | def compute_acc_by_diff(exec_results,diff_json_path): 88 | num_queries = len(exec_results) 89 | results = [res['res'] for res in exec_results] 90 | contents = load_json(diff_json_path) 91 | simple_results, moderate_results, challenging_results = [], [], [] 92 | 93 | for i,content in enumerate(contents): 94 | if content['difficulty'] == 'simple': 95 | simple_results.append(exec_results[i]) 96 | 97 | if content['difficulty'] == 'moderate': 98 | moderate_results.append(exec_results[i]) 99 | 100 | if content['difficulty'] == 'challenging': 101 | challenging_results.append(exec_results[i]) 102 | 103 | simple_acc = sum([res['res'] for res in simple_results])/len(simple_results) 104 | moderate_acc = sum([res['res'] for res in moderate_results])/len(moderate_results) 105 | challenging_acc = sum([res['res'] for res in challenging_results])/len(challenging_results) 106 | all_acc = sum(results)/num_queries 107 | count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries] 108 | return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists 109 | 110 | 111 | 112 | def print_data(score_lists,count_lists): 113 | levels = ['simple', 'moderate', 'challenging', 'total'] 114 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) 115 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists)) 116 | 117 | 
print('====================================== ACCURACY =====================================') 118 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists)) 119 | 120 | 121 | if __name__ == '__main__': 122 | args_parser = argparse.ArgumentParser() 123 | args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='') 124 | args_parser.add_argument('--ground_truth_path', type=str, required=True, default='') 125 | args_parser.add_argument('--data_mode', type=str, required=True, default='dev') 126 | args_parser.add_argument('--db_root_path', type=str, required=True, default='') 127 | args_parser.add_argument('--num_cpus', type=int, default=1) 128 | args_parser.add_argument('--meta_time_out', type=float, default=30.0) 129 | args_parser.add_argument('--mode_gt', type=str, default='gt') 130 | args_parser.add_argument('--mode_predict', type=str, default='gpt') 131 | args_parser.add_argument('--difficulty',type=str,default='simple') 132 | args_parser.add_argument('--diff_json_path',type=str,default='') 133 | args = args_parser.parse_args() 134 | exec_result = [] 135 | 136 | pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict, 137 | data_mode=args.data_mode) 138 | # generate gt sqls: 139 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt', 140 | data_mode=args.data_mode) 141 | 142 | query_pairs = list(zip(pred_queries,gt_queries)) 143 | run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out) 144 | exec_result = sort_results(exec_result) 145 | 146 | print('start calculate') 147 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = \ 148 | compute_acc_by_diff(exec_result,args.diff_json_path) 149 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 150 | print_data(score_lists,count_lists) 151 | print('===========================================================================================') 152 | print("Finished evaluation") 153 | -------------------------------------------------------------------------------- /evaluation/evaluation_ex.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import sys 3 | import argparse 4 | import multiprocessing as mp 5 | from func_timeout import func_timeout, FunctionTimedOut 6 | from evaluation_utils import ( 7 | load_json, 8 | execute_sql, 9 | package_sqls, 10 | sort_results, 11 | print_data, 12 | ) 13 | 14 | 15 | def result_callback(result): 16 | exec_result.append(result) 17 | 18 | 19 | def calculate_ex(predicted_res, ground_truth_res): 20 | res = 0 21 | if set(predicted_res) == set(ground_truth_res): 22 | res = 1 23 | return res 24 | 25 | 26 | def execute_model( 27 | predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect 28 | ): 29 | try: 30 | res = func_timeout( 31 | meta_time_out, 32 | execute_sql, 33 | args=(predicted_sql, ground_truth, db_place, sql_dialect, calculate_ex), 34 | ) 35 | except KeyboardInterrupt: 36 | sys.exit(0) 37 | except FunctionTimedOut: 38 | result = [(f"timeout",)] 39 | res = 0 40 | except Exception as e: 41 | result = [(f"error",)] # possibly len(query) > 512 or not executable 42 | res = 0 43 | result = {"sql_idx": idx, "res": res} 44 | return result 45 | 46 | 47 | def run_sqls_parallel( 48 | sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite" 49 | ): 50 | pool = mp.Pool(processes=num_cpus) 51 | for i, sql_pair 
in enumerate(sqls): 52 | 53 | predicted_sql, ground_truth = sql_pair 54 | pool.apply_async( 55 | execute_model, 56 | args=( 57 | predicted_sql, 58 | ground_truth, 59 | db_places[i], 60 | i, 61 | meta_time_out, 62 | sql_dialect, 63 | ), 64 | callback=result_callback, 65 | ) 66 | pool.close() 67 | pool.join() 68 | 69 | 70 | def compute_acc_by_diff(exec_results, diff_json_path): 71 | num_queries = len(exec_results) 72 | results = [res["res"] for res in exec_results] 73 | contents = load_json(diff_json_path) 74 | simple_results, moderate_results, challenging_results = [], [], [] 75 | 76 | for i, content in enumerate(contents): 77 | if content["difficulty"] == "simple": 78 | simple_results.append(exec_results[i]) 79 | 80 | if content["difficulty"] == "moderate": 81 | moderate_results.append(exec_results[i]) 82 | 83 | if content["difficulty"] == "challenging": 84 | try: 85 | challenging_results.append(exec_results[i]) 86 | except: 87 | print(i) 88 | 89 | simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) 90 | moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) 91 | challenging_acc = sum([res["res"] for res in challenging_results]) / len( 92 | challenging_results 93 | ) 94 | all_acc = sum(results) / num_queries 95 | count_lists = [ 96 | len(simple_results), 97 | len(moderate_results), 98 | len(challenging_results), 99 | num_queries, 100 | ] 101 | return ( 102 | simple_acc * 100, 103 | moderate_acc * 100, 104 | challenging_acc * 100, 105 | all_acc * 100, 106 | count_lists, 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | # Record the start time 112 | start_time = datetime.now() 113 | print(f"Start time: {start_time}") 114 | 115 | args_parser = argparse.ArgumentParser() 116 | args_parser.add_argument( 117 | "--predicted_sql_path", type=str, required=True, default="" 118 | ) 119 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 120 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 121 | args_parser.add_argument("--db_root_path", type=str, required=True, default="") 122 | args_parser.add_argument("--num_cpus", type=int, default=1) 123 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 124 | args_parser.add_argument("--mode_gt", type=str, default="gt") 125 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 126 | args_parser.add_argument("--difficulty", type=str, default="simple") 127 | args_parser.add_argument("--diff_json_path", type=str, default="") 128 | args_parser.add_argument("--engine", type=str, default="") 129 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 130 | args = args_parser.parse_args() 131 | exec_result = [] 132 | 133 | pred_queries, db_paths = package_sqls( 134 | args.predicted_sql_path, 135 | args.db_root_path, 136 | args.engine, 137 | sql_dialect=args.sql_dialect, 138 | mode=args.mode_predict, 139 | data_mode=args.data_mode, 140 | ) 141 | # generate ground truth sqls: 142 | gt_queries, db_paths_gt = package_sqls( 143 | args.ground_truth_path, 144 | args.db_root_path, 145 | args.engine, 146 | sql_dialect=args.sql_dialect, 147 | mode="gt", 148 | data_mode=args.data_mode, 149 | ) 150 | 151 | query_pairs = list(zip(pred_queries, gt_queries)) 152 | 153 | run_sqls_parallel( 154 | query_pairs, 155 | db_places=db_paths, 156 | num_cpus=args.num_cpus, 157 | meta_time_out=args.meta_time_out, 158 | sql_dialect=args.sql_dialect, 159 | ) 160 | exec_result = sort_results(exec_result) 161 | 
print("start calculate") 162 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff( 163 | exec_result, args.diff_json_path 164 | ) 165 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 166 | print(f"EX for {args.engine} on {args.sql_dialect} set") 167 | print("start calculate") 168 | print_data(score_lists, count_lists, metric="EX") 169 | print( 170 | "===========================================================================================" 171 | ) 172 | print(f"Finished EX evaluation for {args.engine} on {args.sql_dialect} set") 173 | print("\n\n") 174 | 175 | 176 | # Record the end time and calculate duration 177 | end_time = datetime.now() 178 | print(f"End time: {end_time}") 179 | duration = end_time - start_time 180 | print(f"Duration: {duration}") 181 | 182 | # Saving EX results in a txt file 183 | ex_result_file_path = args.predicted_sql_path + "ex_score.txt" 184 | ex_file_content = f"The EX scores are: \nOverall: {acc} \nSimple: {simple_acc} \nModerate: {moderate_acc} \nChallenging: {challenging_acc} \n\nCount Lists: {count_lists} \n\nEvaluation Duration: {duration} \n\nMeta Time-out: {args.meta_time_out}" 185 | with open(ex_result_file_path, 'w') as f: 186 | f.write(ex_file_content) 187 | 188 | print("EX score is written into the ex_score.txt file.") 189 | 190 | 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /evaluation/evaluation_f1.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import sys 3 | import argparse 4 | import multiprocessing as mp 5 | from func_timeout import func_timeout, FunctionTimedOut 6 | from evaluation_utils import ( 7 | load_json, 8 | execute_sql, 9 | package_sqls, 10 | sort_results, 11 | print_data, 12 | ) 13 | 14 | 15 | def calculate_row_match(predicted_row, ground_truth_row): 16 | """ 17 | Calculate the matching percentage for a single row. 18 | 19 | Args: 20 | predicted_row (tuple): The predicted row values. 21 | ground_truth_row (tuple): The actual row values from ground truth. 22 | 23 | Returns: 24 | float: The match percentage (0 to 1 scale). 25 | """ 26 | total_columns = len(ground_truth_row) 27 | matches = 0 28 | element_in_pred_only = 0 29 | element_in_truth_only = 0 30 | for pred_val in predicted_row: 31 | if pred_val in ground_truth_row: 32 | matches += 1 33 | else: 34 | element_in_pred_only += 1 35 | for truth_val in ground_truth_row: 36 | if truth_val not in predicted_row: 37 | element_in_truth_only += 1 38 | match_percentage = matches / total_columns 39 | pred_only_percentage = element_in_pred_only / total_columns 40 | truth_only_percentage = element_in_truth_only / total_columns 41 | return match_percentage, pred_only_percentage, truth_only_percentage 42 | 43 | 44 | def calculate_f1_score(predicted, ground_truth): 45 | """ 46 | Calculate the F1 score based on sets of predicted results and ground truth results, 47 | where each element (tuple) represents a row from the database with multiple columns. 48 | 49 | Args: 50 | predicted (set of tuples): Predicted results from SQL query. 51 | ground_truth (set of tuples): Actual results expected (ground truth). 52 | 53 | Returns: 54 | float: The calculated F1 score. 
55 | """ 56 | # if both predicted and ground_truth are empty, return 1.0 for f1_score 57 | if not predicted and not ground_truth: 58 | return 1.0 59 | 60 | # Drop duplicates 61 | predicted_set = set(predicted) if predicted else set() 62 | ground_truth_set = set(ground_truth) 63 | 64 | # convert back to list 65 | predicted = list(predicted_set) 66 | ground_truth = list(ground_truth_set) 67 | 68 | # Calculate matching scores for each possible pair 69 | match_scores = [] 70 | pred_only_scores = [] 71 | truth_only_scores = [] 72 | for i, gt_row in enumerate(ground_truth): 73 | # rows only in the ground truth results 74 | if i >= len(predicted): 75 | match_scores.append(0) 76 | truth_only_scores.append(1) 77 | continue 78 | pred_row = predicted[i] 79 | match_score, pred_only_score, truth_only_score = calculate_row_match( 80 | pred_row, gt_row 81 | ) 82 | match_scores.append(match_score) 83 | pred_only_scores.append(pred_only_score) 84 | truth_only_scores.append(truth_only_score) 85 | 86 | # rows only in the predicted results 87 | for i in range(len(predicted) - len(ground_truth)): 88 | match_scores.append(0) 89 | pred_only_scores.append(1) 90 | truth_only_scores.append(0) 91 | 92 | tp = sum(match_scores) 93 | fp = sum(pred_only_scores) 94 | fn = sum(truth_only_scores) 95 | 96 | precision = tp / (tp + fp) if tp + fp > 0 else 0 97 | recall = tp / (tp + fn) if tp + fn > 0 else 0 98 | 99 | f1_score = ( 100 | 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 101 | ) 102 | return f1_score 103 | 104 | 105 | def result_callback(result): 106 | exec_result.append(result) 107 | 108 | 109 | def execute_model( 110 | predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect 111 | ): 112 | try: 113 | res = func_timeout( 114 | meta_time_out, 115 | execute_sql, 116 | args=( 117 | predicted_sql, 118 | ground_truth, 119 | db_place, 120 | sql_dialect, 121 | calculate_f1_score, 122 | ), 123 | ) 124 | except KeyboardInterrupt: 125 | sys.exit(0) 126 | except FunctionTimedOut: 127 | result = [(f"timeout",)] 128 | res = 0 129 | except Exception as e: 130 | result = [(f"error",)] # possibly len(query) > 512 or not executable 131 | res = 0 132 | # print(result) 133 | # result = str(set([ret[0] for ret in result])) 134 | result = {"sql_idx": idx, "res": res} 135 | # print(result) 136 | return result 137 | 138 | 139 | def run_sqls_parallel( 140 | sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite" 141 | ): 142 | pool = mp.Pool(processes=num_cpus) 143 | for i, sql_pair in enumerate(sqls): 144 | 145 | predicted_sql, ground_truth = sql_pair 146 | pool.apply_async( 147 | execute_model, 148 | args=( 149 | predicted_sql, 150 | ground_truth, 151 | db_places[i], 152 | i, 153 | meta_time_out, 154 | sql_dialect, 155 | ), 156 | callback=result_callback, 157 | ) 158 | pool.close() 159 | pool.join() 160 | 161 | 162 | def compute_f1_by_diff(exec_results, diff_json_path): 163 | num_queries = len(exec_results) 164 | results = [res["res"] for res in exec_results] 165 | contents = load_json(diff_json_path) 166 | simple_results, moderate_results, challenging_results = [], [], [] 167 | 168 | for i, content in enumerate(contents): 169 | if content["difficulty"] == "simple": 170 | simple_results.append(exec_results[i]) 171 | 172 | if content["difficulty"] == "moderate": 173 | moderate_results.append(exec_results[i]) 174 | 175 | if content["difficulty"] == "challenging": 176 | try: 177 | challenging_results.append(exec_results[i]) 178 | except: 179 | print(i) 180 | 181 | simple_f1 = 
sum([res["res"] for res in simple_results]) / len(simple_results) * 100 182 | moderate_f1 = ( 183 | sum([res["res"] for res in moderate_results]) / len(moderate_results) * 100 184 | ) 185 | challenging_f1 = ( 186 | sum([res["res"] for res in challenging_results]) 187 | / len(challenging_results) 188 | * 100 189 | ) 190 | all_f1 = sum(results) / num_queries * 100 191 | count_lists = [ 192 | len(simple_results), 193 | len(moderate_results), 194 | len(challenging_results), 195 | num_queries, 196 | ] 197 | return ( 198 | simple_f1, 199 | moderate_f1, 200 | challenging_f1, 201 | all_f1, 202 | count_lists, 203 | ) 204 | 205 | 206 | if __name__ == "__main__": 207 | # Record the start time 208 | start_time = datetime.now() 209 | print(f"Start time: {start_time}") 210 | 211 | args_parser = argparse.ArgumentParser() 212 | args_parser.add_argument( 213 | "--predicted_sql_path", type=str, required=True, default="" 214 | ) 215 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 216 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 217 | args_parser.add_argument("--db_root_path", type=str, required=True, default="") 218 | args_parser.add_argument("--num_cpus", type=int, default=1) 219 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 220 | args_parser.add_argument("--mode_gt", type=str, default="gt") 221 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 222 | args_parser.add_argument("--difficulty", type=str, default="simple") 223 | args_parser.add_argument("--diff_json_path", type=str, default="") 224 | args_parser.add_argument("--engine", type=str, default="") 225 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 226 | args = args_parser.parse_args() 227 | exec_result = [] 228 | 229 | pred_queries, db_paths = package_sqls( 230 | args.predicted_sql_path, 231 | args.db_root_path, 232 | args.engine, 233 | sql_dialect=args.sql_dialect, 234 | mode=args.mode_predict, 235 | data_mode=args.data_mode, 236 | ) 237 | # generate ground truth sqls: 238 | gt_queries, db_paths_gt = package_sqls( 239 | args.ground_truth_path, 240 | args.db_root_path, 241 | args.engine, 242 | sql_dialect=args.sql_dialect, 243 | mode="gt", 244 | data_mode=args.data_mode, 245 | ) 246 | 247 | query_pairs = list(zip(pred_queries, gt_queries)) 248 | 249 | run_sqls_parallel( 250 | query_pairs, 251 | db_places=db_paths, 252 | num_cpus=args.num_cpus, 253 | meta_time_out=args.meta_time_out, 254 | sql_dialect=args.sql_dialect, 255 | ) 256 | exec_result = sort_results(exec_result) 257 | 258 | print("start calculate") 259 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_f1_by_diff( 260 | exec_result, args.diff_json_path 261 | ) 262 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 263 | print(f"Soft F1 for {args.engine} on {args.sql_dialect} set") 264 | print("start calculate") 265 | print_data(score_lists, count_lists) 266 | print( 267 | "===========================================================================================" 268 | ) 269 | print(f"Finished Soft F1 evaluation for {args.engine} on {args.sql_dialect} set") 270 | print("\n\n") 271 | 272 | # Record the end time and calculate duration 273 | end_time = datetime.now() 274 | print(f"End time: {end_time}") 275 | duration = end_time - start_time 276 | print(f"Duration: {duration}") 277 | 278 | # Saving soft F1 results in a txt file 279 | ex_result_file_path = args.predicted_sql_path + "soft_f1_score.txt" 280 | 
ex_file_content = f"The soft F1-Scores are: \nOverall Soft F1: {acc} \nSimple Soft F1: {simple_acc} \nModerate Soft F1: {moderate_acc} \nChallenging Soft F1: {challenging_acc} \n\nCount Lists: {count_lists} \n\nEvaluation Duration: {duration} \n\nMeta Time-out: {args.meta_time_out}" 281 | with open(ex_result_file_path, 'w') as f: 282 | f.write(ex_file_content) 283 | 284 | print("Soft F1 score is written into the soft_f1_score.txt file.") 285 | 286 | 287 | 288 | -------------------------------------------------------------------------------- /evaluation/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | # import psycopg2 3 | # import pymysql 4 | import sqlite3 5 | 6 | 7 | def load_json(dir): 8 | with open(dir, "r") as j: 9 | contents = json.loads(j.read()) 10 | return contents 11 | 12 | 13 | def connect_postgresql(): 14 | # Open database connection 15 | # Connect to the database 16 | db = psycopg2.connect( 17 | "dbname=BIRD user=root host=localhost password=YOUR_PASSWORD port=5432" 18 | ) 19 | return db 20 | 21 | 22 | def connect_mysql(): 23 | # Open database connection 24 | # Connect to the database" 25 | db = pymysql.connect( 26 | host="localhost", 27 | user="root", 28 | password="YOUR_PASSWORD", 29 | database="BIRD", 30 | unix_socket="/tmp/mysql.sock", 31 | # port=3306, 32 | ) 33 | return db 34 | 35 | 36 | def connect_db(sql_dialect, db_path): 37 | if sql_dialect == "SQLite": 38 | conn = sqlite3.connect(db_path) 39 | elif sql_dialect == "MySQL": 40 | conn = connect_mysql() 41 | elif sql_dialect == "PostgreSQL": 42 | conn = connect_postgresql() 43 | else: 44 | raise ValueError("Unsupported SQL dialect") 45 | return conn 46 | 47 | 48 | def execute_sql(predicted_sql, ground_truth, db_path, sql_dialect, calculate_func): 49 | conn = connect_db(sql_dialect, db_path) 50 | # Connect to the database 51 | cursor = conn.cursor() 52 | cursor.execute(predicted_sql) 53 | predicted_res = cursor.fetchall() 54 | cursor.execute(ground_truth) 55 | ground_truth_res = cursor.fetchall() 56 | conn.close() 57 | res = calculate_func(predicted_res, ground_truth_res) 58 | return res 59 | 60 | 61 | def package_sqls( 62 | sql_path, db_root_path, engine, sql_dialect="SQLite", mode="gpt", data_mode="dev" 63 | ): 64 | clean_sqls = [] 65 | db_path_list = [] 66 | if mode == "gpt": 67 | # use chain of thought 68 | # sql_data = json.load( 69 | # open( 70 | # sql_path 71 | # + "predict_" 72 | # + data_mode 73 | # + "_" 74 | # + engine 75 | # + "_cot_" 76 | # + sql_dialect 77 | # + ".json", 78 | # "r", 79 | # ) 80 | # ) 81 | sql_data = json.load( 82 | open( 83 | sql_path 84 | + "predict_" 85 | + data_mode 86 | + ".json", 87 | "r", 88 | ) 89 | ) 90 | for _, sql_str in sql_data.items(): 91 | if type(sql_str) == str: 92 | sql, db_name = sql_str.split("\t----- bird -----\t") 93 | else: 94 | sql, db_name = " ", "financial" 95 | clean_sqls.append(sql) 96 | db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") 97 | 98 | elif mode == "gt": 99 | sqls = open(sql_path + data_mode + "_" + sql_dialect + "_gold.sql") 100 | sql_txt = sqls.readlines() 101 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt] 102 | for idx, sql_str in enumerate(sql_txt): 103 | # print(sql_str) 104 | sql, db_name = sql_str.strip().split("\t") 105 | clean_sqls.append(sql) 106 | db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") 107 | 108 | return clean_sqls, db_path_list 109 | 110 | 111 | def sort_results(list_of_dicts): 112 | return 
sorted(list_of_dicts, key=lambda x: x["sql_idx"]) 113 | 114 | 115 | def print_data(score_lists, count_lists, metric="F1 Score"): 116 | levels = ["simple", "moderate", "challenging", "total"] 117 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) 118 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) 119 | 120 | print( 121 | f"====================================== {metric} =====================================" 122 | ) 123 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format(metric, *score_lists)) 124 | -------------------------------------------------------------------------------- /evaluation/evaluation_ves.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import sys 3 | import json 4 | import numpy as np 5 | import argparse 6 | import multiprocessing as mp 7 | from func_timeout import func_timeout, FunctionTimedOut 8 | from evaluation_utils import ( 9 | load_json, 10 | execute_sql, 11 | package_sqls, 12 | sort_results, 13 | print_data, 14 | connect_db, 15 | ) 16 | import time 17 | import math 18 | 19 | 20 | def result_callback(result): 21 | exec_result.append(result) 22 | 23 | 24 | def clean_abnormal(input): 25 | input = np.asarray(input) 26 | processed_list = [] 27 | mean = np.mean(input, axis=0) 28 | std = np.std(input, axis=0) 29 | for x in input: 30 | if x < mean + 3 * std and x > mean - 3 * std: 31 | processed_list.append(x) 32 | return processed_list 33 | 34 | 35 | def execute_sql(sql, db_path, sql_dialect, return_time=False): 36 | # Connect to the database 37 | conn = connect_db(sql_dialect, db_path) 38 | start_time = time.time() 39 | cursor = conn.cursor() 40 | cursor.execute(sql) 41 | res = cursor.fetchall() 42 | conn.close() # Don't forget to close the connection! 43 | exec_time = time.time() - start_time 44 | if return_time: 45 | return exec_time 46 | 47 | return res 48 | 49 | 50 | def iterated_execute_sql( 51 | predicted_sql, ground_truth, db_path, iterate_num, sql_dialect 52 | ): 53 | diff_list = [] 54 | predicted_res = execute_sql(predicted_sql, db_path, sql_dialect) 55 | ground_truth_res = execute_sql(ground_truth, db_path, sql_dialect) 56 | reward = 0 57 | time_ratio = 0 58 | if set(predicted_res) == set(ground_truth_res): 59 | for _ in range(iterate_num): 60 | predicted_time = execute_sql( 61 | predicted_sql, db_path, sql_dialect, return_time=True 62 | ) 63 | ground_truth_time = execute_sql( 64 | ground_truth, db_path, sql_dialect, return_time=True 65 | ) 66 | diff_list.append(ground_truth_time / predicted_time) 67 | processed_diff_list = clean_abnormal(diff_list) 68 | time_ratio = sum(processed_diff_list) / len(processed_diff_list) 69 | if time_ratio == 0: 70 | reward = 0 71 | elif time_ratio >= 2: 72 | reward = 1.25 73 | elif time_ratio >= 1 and time_ratio < 2: 74 | reward = 1 75 | elif time_ratio >= 0.5 and time_ratio < 1: 76 | reward = 0.75 77 | elif time_ratio >= 0.25 and time_ratio < 0.5: 78 | reward = 0.5 79 | else: 80 | reward = 0.25 81 | # return time_ratio 82 | return reward 83 | 84 | 85 | def execute_model( 86 | predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out, sql_dialect 87 | ): 88 | try: 89 | # you can personalize the total timeout number 90 | # larger timeout leads to more stable ves 91 | # while it needs more your patience.... 
92 | reward = func_timeout( 93 | meta_time_out * iterate_num, 94 | iterated_execute_sql, 95 | args=(predicted_sql, ground_truth, db_place, iterate_num, sql_dialect), 96 | ) 97 | except KeyboardInterrupt: 98 | sys.exit(0) 99 | except FunctionTimedOut: 100 | result = [(f"timeout",)] 101 | reward = 0 102 | except Exception as e: 103 | result = [(f"error",)] # possibly len(query) > 512 or not executable 104 | reward = 0 105 | result = {"sql_idx": idx, "reward": reward} 106 | return result 107 | 108 | 109 | def run_sqls_parallel( 110 | sqls, 111 | db_places, 112 | num_cpus=1, 113 | iterate_num=100, 114 | meta_time_out=30.0, 115 | sql_dialect="SQLite", 116 | ): 117 | pool = mp.Pool(processes=num_cpus) 118 | for i, sql_pair in enumerate(sqls): 119 | predicted_sql, ground_truth = sql_pair 120 | pool.apply_async( 121 | execute_model, 122 | args=( 123 | predicted_sql, 124 | ground_truth, 125 | db_places[i], 126 | i, 127 | iterate_num, 128 | meta_time_out, 129 | sql_dialect, 130 | ), 131 | callback=result_callback, 132 | ) 133 | pool.close() 134 | pool.join() 135 | 136 | 137 | def compute_ves(exec_results): 138 | num_queries = len(exec_results) 139 | total_reward = 0 140 | count = 0 141 | 142 | for i, result in enumerate(exec_results): 143 | if result["reward"] != 0: 144 | count += 1 145 | total_reward += math.sqrt(result["reward"]) * 100 146 | ves = total_reward / num_queries 147 | return ves 148 | 149 | 150 | def compute_ves_by_diff(exec_results, diff_json_path): 151 | num_queries = len(exec_results) 152 | contents = load_json(diff_json_path) 153 | simple_results, moderate_results, challenging_results = [], [], [] 154 | for i, content in enumerate(contents): 155 | if content["difficulty"] == "simple": 156 | simple_results.append(exec_results[i]) 157 | if content["difficulty"] == "moderate": 158 | moderate_results.append(exec_results[i]) 159 | if content["difficulty"] == "challenging": 160 | challenging_results.append(exec_results[i]) 161 | simple_ves = compute_ves(simple_results) 162 | moderate_ves = compute_ves(moderate_results) 163 | challenging_ves = compute_ves(challenging_results) 164 | all_ves = compute_ves(exec_results) 165 | count_lists = [ 166 | len(simple_results), 167 | len(moderate_results), 168 | len(challenging_results), 169 | num_queries, 170 | ] 171 | return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists 172 | 173 | 174 | def print_reward_category(exec_results, engine, sql_dialect): 175 | res = { 176 | "engine": engine, 177 | "sql_dialect": sql_dialect, 178 | "distribution": exec_results, 179 | } 180 | file_path = "results.json" 181 | try: 182 | with open(file_path, "r") as file: 183 | data = json.load(file) 184 | except (FileNotFoundError, json.JSONDecodeError): 185 | data = [] # Start with an empty list if file doesn't exist or is empty 186 | 187 | # Append the new data 188 | data.append(res) 189 | 190 | # Write the updated data back to the file 191 | with open(file_path, "w") as file: 192 | json.dump(data, file, indent=4) 193 | 194 | 195 | if __name__ == "__main__": 196 | # Record the start time 197 | start_time = datetime.now() 198 | print(f"Start time: {start_time}") 199 | 200 | args_parser = argparse.ArgumentParser() 201 | args_parser.add_argument( 202 | "--predicted_sql_path", type=str, required=True, default="" 203 | ) 204 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 205 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 206 | args_parser.add_argument("--db_root_path", type=str, 
required=True, default="") 207 | args_parser.add_argument("--num_cpus", type=int, default=1) 208 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 209 | args_parser.add_argument("--mode_gt", type=str, default="gt") 210 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 211 | args_parser.add_argument("--diff_json_path", type=str, default="") 212 | args_parser.add_argument("--engine", type=str, default="") 213 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 214 | args = args_parser.parse_args() 215 | exec_result = [] 216 | 217 | pred_queries, db_paths = package_sqls( 218 | args.predicted_sql_path, 219 | args.db_root_path, 220 | args.engine, 221 | sql_dialect=args.sql_dialect, 222 | mode=args.mode_predict, 223 | data_mode=args.data_mode, 224 | ) 225 | # generate ground truth sqls: 226 | gt_queries, db_paths_gt = package_sqls( 227 | args.ground_truth_path, 228 | args.db_root_path, 229 | args.engine, 230 | sql_dialect=args.sql_dialect, 231 | mode="gt", 232 | data_mode=args.data_mode, 233 | ) 234 | query_pairs = list(zip(pred_queries, gt_queries)) 235 | run_sqls_parallel( 236 | query_pairs, 237 | db_places=db_paths, 238 | num_cpus=args.num_cpus, 239 | meta_time_out=args.meta_time_out, 240 | sql_dialect=args.sql_dialect, 241 | ) 242 | exec_result = sort_results(exec_result) 243 | # print_reward_category(exec_result, args.engine, args.sql_dialect) 244 | print("start calculate") 245 | simple_ves, moderate_ves, challenging_ves, ves, count_lists = compute_ves_by_diff( 246 | exec_result, args.diff_json_path 247 | ) 248 | score_lists = [simple_ves, moderate_ves, challenging_ves, ves] 249 | print(f"VES for {args.engine} on {args.sql_dialect} set") 250 | print("start calculate") 251 | print_data(score_lists, count_lists, metric="VES") 252 | print( 253 | "===========================================================================================" 254 | ) 255 | print(f"Finished VES evaluation for {args.engine} on {args.sql_dialect} set") 256 | print("\n\n") 257 | 258 | # Record the end time and calculate duration 259 | end_time = datetime.now() 260 | print(f"End time: {end_time}") 261 | duration = end_time - start_time 262 | print(f"Duration: {duration}") 263 | 264 | # Saving R-VES results in a txt file 265 | ex_result_file_path = args.predicted_sql_path + "r_ves_score.txt" 266 | ex_file_content = f"The R-VES Scores are: \nOverall R-VES: {ves} \nSimple R-VES: {simple_ves} \nModerate R-VES: {moderate_ves} \nChallenging R-VES: {challenging_ves} \n\nCount Lists: {count_lists} \n\nEvaluation Duration: {duration} \n\nMeta Time-out: {args.meta_time_out}" 267 | with open(ex_result_file_path, 'w') as f: 268 | f.write(ex_file_content) 269 | 270 | print("R-VES score is written into the r_ves_score.txt file.") 271 | 272 | 273 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import random 5 | from dotenv import load_dotenv 6 | from pipeline.Pipeline import * 7 | from utils.db_utils import * 8 | from utils.retrieval_utils import process_all_dbs 9 | from typing import Dict, Union, List, Tuple 10 | 11 | def main(args): 12 | load_dotenv() # load variables into os.environ 13 | create_result_files(args) # creating results directory for specific arguments 14 | 15 | bird_sql_path = os.getenv('BIRD_DB_PATH') 16 | args.dataset_path = bird_sql_path 17 | process_all_dbs(bird_sql_path, 
args.mode) # for all databases, creating db_description.csv files which include column descriptions for all talbes 18 | 19 | # set random seed 20 | random.seed(args.seed) 21 | 22 | # load dataset 23 | dataset_json_path = bird_sql_path + f"/{args.mode}/{args.mode}.json" 24 | f = open(dataset_json_path) 25 | dataset = json.load(f) 26 | 27 | pipeline = Pipeline(args) 28 | 29 | output_dict = {} 30 | predictions = [] 31 | # Incase error you can restart the code from the point of error using following lines 32 | # dataset = dataset[: ] 33 | # dataset = dataset[:] 34 | dataset = dataset[1135:] 35 | for ind,t2s_object in enumerate(dataset): 36 | q_id = t2s_object["question_id"] 37 | if pipeline.pipeline_order == "CSG-SR": 38 | t2s_object_prediction = pipeline.forward_pipeline_CSG_SR(t2s_object) 39 | elif pipeline.pipeline_order == "CSG-QE-SR": 40 | t2s_object_prediction = pipeline.forward_pipeline_CSG_QE_SR(t2s_object) 41 | elif pipeline.pipeline_order == "SF-CSG-QE-SR": 42 | t2s_object_prediction = pipeline.forward_pipeline_SF_CSG_QE_SR(t2s_object) 43 | else: 44 | raise ValueError("Wrong value for pipeline_order argument. It must be either CSG-QE-SR or CSG-SR.") 45 | 46 | # Compare predicted and ground truth sqls 47 | compare_results = check_correctness(t2s_object_prediction, args) 48 | t2s_object_prediction['results'] = compare_results 49 | if os.path.exists(args.prediction_json_path): 50 | # get existing predictions 51 | with open(args.prediction_json_path, 'r') as file_read: 52 | existing_predictions = json.load(file_read) 53 | 54 | # add new prediction to the existing predictions and then write to the file 55 | existing_predictions.append(t2s_object_prediction) 56 | with open(args.prediction_json_path, 'w') as file_write: 57 | json.dump(existing_predictions, file_write, indent=4) 58 | 59 | else: 60 | file_write = open(args.prediction_json_path, 'w') 61 | existing_predictions = [t2s_object_prediction] 62 | json.dump(existing_predictions, file_write, indent=4) 63 | file_write.close() 64 | 65 | # # add the current text2sql object to the predictions 66 | # predictions.append(t2s_object_prediction) 67 | # # writing prediction to the predictions.json file 68 | # with open(args.prediction_json_path, 'w') as f: 69 | # json.dump(predictions, f, indent=4) # indent=4 for pretty printing 70 | 71 | # adding predicted sql in the expected format for the evaluation files 72 | db_id = t2s_object_prediction["db_id"] 73 | predicted_sql = t2s_object_prediction["predicted_sql"] 74 | predicted_sql = predicted_sql.replace('\"','').replace('\\\n',' ').replace('\n',' ') 75 | sql = predicted_sql + '\t----- bird -----\t' + db_id 76 | output_dict[str(q_id)] = sql 77 | if os.path.exists(args.predictions_eval_json_path): 78 | with open(args.predictions_eval_json_path, 'r') as f: 79 | contents = json.loads(f.read()) 80 | else: 81 | # Initialize contents as an empty dictionary if the file doesn't exist 82 | contents = {} 83 | contents.update(output_dict) 84 | json.dump(contents, open(args.predictions_eval_json_path, 'w'), indent=4) 85 | 86 | print(f"Question with {q_id} is processed. 
Correctness: {compare_results['exec_res']} ") 87 | 88 | # Calculatin Metrics 89 | predictions_json_file = open(args.prediction_json_path, 'r') 90 | predictions = json.load(predictions_json_file) 91 | stats, fail_q_ids = calculate_accuracies(predictions) 92 | metric_object = { 93 | "EX": stats["ex"], 94 | "total_correct_count": stats["total_correct_count"], 95 | "total_item_count": stats["total_item_count"], 96 | "simple_stats": stats["simple"], 97 | "moderate_stats": stats["moderate"], 98 | "challenging_stats": stats["challenging"], 99 | "fail_q_ids": fail_q_ids, 100 | "config": { 101 | "mode": args.mode, 102 | "model": args.model, 103 | "temperature": args.temperature, 104 | "top_p": args.top_p, 105 | "max_tokens": args.max_tokens, 106 | "n": args.n, 107 | "pipeline_order": args.pipeline_order, 108 | "enrichment_level": args.enrichment_level, 109 | "enrichment_level_shot_number": args.enrichment_level_shot_number, 110 | "enrichment_few_shot_schema_existance": args.enrichment_few_shot_schema_existance, 111 | "filtering_level_shot_number": args.filtering_level_shot_number, 112 | "filtering_few_shot_schema_existance": args.filtering_few_shot_schema_existance, 113 | "cfg": args.cfg, 114 | "generation_level_shot_number": args.generation_level_shot_number, 115 | "generation_few_shot_schema_existance": args.generation_few_shot_schema_existance, 116 | "db_sample_limit": args.db_sample_limit, 117 | "relevant_description_number": args.relevant_description_number, 118 | "seed": args.seed 119 | } 120 | } 121 | 122 | # Writing metrics to a file 123 | metrics_path = args.output_directory_path + "/metrics.json" 124 | # writing metric object 125 | with open(metrics_path, 'w') as f: 126 | json.dump(metric_object, f, indent=4) # indent=4 for pretty printing 127 | 128 | print("Metrics are written into metrics.json file.") 129 | return 130 | 131 | 132 | def calculate_accuracies(predictions: List[Dict]) -> Tuple[float, List]: 133 | """ 134 | The function calculates the Execution Accuracy(EX) metric and find the question IDs whose predictions are failed 135 | 136 | Arguments: 137 | predictions 138 | """ 139 | difficulty_flag = False 140 | failed_predictions_q_ids = [] 141 | stats = { 142 | "ex": 0, 143 | "total_correct_count": 0, 144 | "total_item_count": 0, 145 | "simple": { 146 | "correct_number": 0, 147 | "count": 0 148 | }, 149 | "moderate": { 150 | "correct_number": 0, 151 | "count": 0 152 | }, 153 | "challenging": { 154 | "correct_number": 0, 155 | "count": 0 156 | } 157 | } 158 | 159 | # check if there is difficulty key 160 | sample = predictions[0] 161 | if "difficulty" in sample: 162 | difficulty_flag = True 163 | else: 164 | difficulty_flag = False 165 | 166 | if difficulty_flag: 167 | for q2s_object in predictions: 168 | level = q2s_object['difficulty'] 169 | stats[level]["count"] = stats[level]["count"] + 1 170 | 171 | if q2s_object['results']['exec_res'] != 0: 172 | stats[level]['correct_number'] = stats[level]['correct_number'] + 1 173 | else: 174 | failed_predictions_q_ids.append(q2s_object['question_id']) 175 | 176 | stats["simple"]["ex"] = stats["simple"]["correct_number"] / stats["simple"]["count"] * 100 177 | stats["moderate"]["ex"] = stats["moderate"]["correct_number"] / stats["moderate"]["count"] * 100 178 | stats["challenging"]["ex"] = stats["challenging"]["correct_number"] / stats["challenging"]["count"] * 100 179 | 180 | stats["total_item_count"] = stats["simple"]["count"] + stats["moderate"]["count"] + stats["challenging"]["count"] 181 | stats["total_correct_count"] = 
stats["simple"]["correct_number"] + stats["moderate"]["correct_number"] + stats["challenging"]["correct_number"] 182 | stats["ex"] = stats["total_correct_count"] / stats["total_item_count"] * 100 183 | 184 | return (stats, failed_predictions_q_ids) 185 | 186 | else: 187 | 188 | for q2s_object in predictions: 189 | stats["total_item_count"] = stats["total_item_count"] + 1 190 | if q2s_object['results']['exec_res'] != 0: 191 | stats["total_correct_count"] = stats["total_correct_count"] + 1 192 | else: 193 | failed_predictions_q_ids.append(q2s_object['question_id']) 194 | 195 | stats["ex"] = stats["total_correct_count"] / stats["total_item_count"] * 100 196 | return (stats, failed_predictions_q_ids) 197 | 198 | def check_correctness(t2s_object_prediction: Dict, args) -> Dict[str, Union[int, str]]: 199 | """ 200 | The function checks whether the predicted SQL is correct or not 201 | 202 | Arguments: 203 | t2s_object_prediction (Dict): prediction dictionary containing the predicted SQL and the ground truth SQL 204 | 205 | Returns: 206 | compare_results (Dict[str, Union[int, str]]): Comparison results dictionary with execution result and execution error keys 207 | """ 208 | db_id = t2s_object_prediction['db_id'] 209 | bird_sql_path = os.getenv('BIRD_DB_PATH') 210 | db_path = bird_sql_path + f"/{args.mode}/{args.mode}_databases/{db_id}/{db_id}.sqlite" 211 | if 'predicted_sql' in t2s_object_prediction: 212 | predicted_sql = t2s_object_prediction['predicted_sql'] 213 | gt_sql = t2s_object_prediction['SQL'] 214 | compare_results = compare_sqls(db_path=db_path, predicted_sql=predicted_sql, ground_truth_sql=gt_sql) 215 | else: 216 | compare_results = {'exec_res': 0, 'exec_err': "There is no predicted SQL. There must be an error in this question while extracting information."} 217 | 218 | return compare_results 219 | 220 | def create_result_files(args): 221 | """ 222 | The function creates the result directory and prediction files according to the given arguments. 
223 | """ 224 | 225 | # Ensure the results directory exists, otherwise create it 226 | if not os.path.exists("./results"): 227 | os.makedirs("./results") 228 | 229 | args.output_directory_path = f"./results/model_outputs_{args.mode}_{args.pipeline_order}_{args.model}" 230 | 231 | # Ensure the directory exists 232 | if not os.path.exists(args.output_directory_path): 233 | os.makedirs(args.output_directory_path) 234 | 235 | # Overall predictions file 236 | prediction_json_path = args.output_directory_path + "/predictions.json" 237 | args.prediction_json_path = prediction_json_path 238 | # print("args.prediction_json_path: ", args.prediction_json_path) 239 | 240 | # Create an empty predictions.json file if it does not exist 241 | if not os.path.exists(prediction_json_path): 242 | with open(args.prediction_json_path, 'w') as f: 243 | json.dump([], f) # Initialize with an empty JSON list 244 | 245 | # predictions file for evaluation 246 | predictions_eval_json_path = args.output_directory_path + f"/predict_{args.mode}.json" 247 | args.predictions_eval_json_path = predictions_eval_json_path 248 | 249 | 250 | 251 | def str2bool(v: str) -> bool: 252 | """ 253 | The function converts a string boolean to a boolean 254 | 255 | Arguments: 256 | v (str): string boolean 257 | 258 | Returns: 259 | bool: corresponding boolean value 260 | """ 261 | if isinstance(v, bool): 262 | return v 263 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 264 | return True 265 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 266 | return False 267 | else: 268 | raise argparse.ArgumentTypeError('Boolean value expected.') 269 | 270 | if __name__ == '__main__': 271 | parser = argparse.ArgumentParser() 272 | # Running mode arguments 273 | parser.add_argument("--mode", default='dev', type=str, help="Either dev or test.") 274 | 275 | # Model Arguments 276 | parser.add_argument("--model", default="gpt-4o-mini-2024-07-18", type=str, help="OpenAI models.") 277 | parser.add_argument("--temperature", default=0.0, type=float, help="Sampling temperature between 0 and 2. It is recommended to alter this or top_p but not both.") 278 | parser.add_argument("--top_p", default=1, type=float, help="Nucleus sampling. It is recommended to alter this or temperature but not both.") 279 | parser.add_argument("--max_tokens", default=2048, type=int, help="The maximum number of tokens that can be generated.") 280 | parser.add_argument("--n", default=1, type=int, help="How many chat completion choices to generate for each input message.") 281 | 282 | # Pipeline Arguments 283 | parser.add_argument("-po", "--pipeline_order", default='CSG-QE-SR', type=str, help="The order of modules in the pipeline. It should be one of CSG-SR (candidate generation --> refinement), CSG-QE-SR (candidate generation --> question enrichment --> refinement) or SF-CSG-QE-SR (schema filtering --> candidate generation --> question enrichment --> refinement).") 284 | 285 | # Question Enrichment Arguments 286 | parser.add_argument("-el", "--enrichment_level", default="complex", type=str, help="Defines which enrichment level is used in few-shot examples. It can be either basic or complex.") 287 | parser.add_argument("-elsn", "--enrichment_level_shot_number", default=3, type=int, help="The few-shot number for each difficulty level for the question enrichment stage.") 288 | parser.add_argument("-efsse", "--enrichment_few_shot_schema_existance", default=False, type=str2bool, help="Database Schema usage for each few-shot example in the question enrichment stage. 
Default False.") 289 | 290 | # Schema Filtering Arguments 291 | parser.add_argument("-flsn", "--filtering_level_shot_number", default=3, type=int, help="The few-shot number for each difficulty level for schema filtering stage.") 292 | parser.add_argument("-ffsse", "--filtering_few_shot_schema_existance", default=False, type=str2bool, help="Database Schema usage for each few-shot examples in the schema filtering stage. Default False.") 293 | 294 | # SQL Generation Arguments 295 | parser.add_argument("--cfg", default=True, type=str2bool, help="Whether Context-Free-Grammer or SQL Template will be used. Default is True.") 296 | parser.add_argument("-glsn", "--generation_level_shot_number", default=3, type=int, help="The few-shot number for each difficulty level for SQL generation stage.") 297 | parser.add_argument("-gfsse", "--generation_few_shot_schema_existance", default=False, type=str2bool, help="Database Schema usage for each few-shot examples in the SQL generation stage. Default False.") 298 | 299 | # db sample number 300 | parser.add_argument("--db_sample_limit", default=5, type=int, help="The number of value extracted for a column for database samples.") 301 | # question relevant database item/column description number 302 | parser.add_argument("-rdn", "--relevant_description_number", default=6, type=int, help="The number of database item/column descriptions added to a prompt.") 303 | # custom seed argument 304 | parser.add_argument("--seed", default=42, type=int, help="Random seed") 305 | 306 | args = parser.parse_args() 307 | main(args) -------------------------------------------------------------------------------- /pipeline/Pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from utils.prompt_utils import * 4 | from utils.db_utils import * 5 | from utils.openai_utils import create_response 6 | from typing import Dict, List 7 | 8 | class Pipeline(): 9 | def __init__(self, args): 10 | # Running mode attributes 11 | self.mode = args.mode 12 | self.dataset_path = args.dataset_path 13 | 14 | # Pipeline attribute 15 | self.pipeline_order = args.pipeline_order 16 | 17 | # Model attributes 18 | self.model = args.model 19 | self.temperature = args.temperature 20 | self.top_p = args.top_p 21 | self.max_tokens = args.max_tokens 22 | self.n = args.n 23 | 24 | # Stages (enrichment, filtering, generation) attributes 25 | self.enrichment_level = args.enrichment_level 26 | self.elsn = args.enrichment_level_shot_number 27 | self.efsse = args.enrichment_few_shot_schema_existance 28 | 29 | self.flsn = args.filtering_level_shot_number 30 | self.ffsse = args.filtering_few_shot_schema_existance 31 | 32 | self.cfg = args.cfg 33 | self.glsn = args.generation_level_shot_number 34 | self.gfsse = args.generation_few_shot_schema_existance 35 | 36 | self.db_sample_limit = args.db_sample_limit 37 | self.rdn = args.relevant_description_number 38 | 39 | self.seed = args.seed 40 | 41 | def convert_message_content_to_dict(self, response_object: Dict) -> Dict: 42 | """ 43 | The function gets a LLM response object, and then it converts the content of it to the Python object. 
44 | 45 | Arguments: 46 | response_object (Dict): LLM response object 47 | Returns: 48 | response_object (Dict): Response object whose content changed to dictionary 49 | """ 50 | 51 | response_object.choices[0].message.content = json.loads(response_object.choices[0].message.content) 52 | return response_object 53 | 54 | 55 | def forward_pipeline_CSG_SR(self, t2s_object: Dict) -> Dict[str, str]: 56 | """ 57 | The function runs Candidate SQL Generation(CSG) and SQL Refinement(SR) modules respectively without any question enrichment or filtering stages. 58 | 59 | Arguments: 60 | t2s_object (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc. 61 | Returns: 62 | t2s_object_prediction (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc and also stores information after each stage 63 | """ 64 | db_id = t2s_object["db_id"] 65 | q_id = t2s_object["question_id"] 66 | evidence = t2s_object["evidence"] 67 | question = t2s_object["question"] 68 | 69 | bird_sql_path = os.getenv('BIRD_DB_PATH') 70 | db_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/{db_id}.sqlite" 71 | db_description_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/database_description" 72 | db_descriptions = question_relevant_descriptions_prep(database_description_path=db_description_path, question=question, relevant_description_number=self.rdn) 73 | database_column_meaning_path = bird_sql_path + f"/{self.mode}/column_meaning.json" 74 | db_column_meanings = db_column_meaning_prep(database_column_meaning_path, db_id) 75 | db_descriptions = db_descriptions + "\n\n" + db_column_meanings 76 | 77 | # extracting original schema dictionary 78 | original_schema_dict = get_schema_tables_and_columns_dict(db_path=db_path) 79 | 80 | t2s_object["question_enrichment"] = "No Question Enrichment" 81 | 82 | ### STAGE 1: Candidate SQL GENERATION 83 | # -- Original question is used 84 | # -- Original Schema is used 85 | sql_generation_response_obj = self.candidate_sql_generation_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions) 86 | try: 87 | possible_sql = sql_generation_response_obj.choices[0].message.content['SQL'] 88 | t2s_object["candidate_sql_generation"] = { 89 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 90 | "possible_sql": possible_sql, 91 | "exec_err": "", 92 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens, 93 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens, 94 | "total_tokens": sql_generation_response_obj.usage.total_tokens, 95 | } 96 | t2s_object["possible_sql"] = possible_sql 97 | # execute SQL 98 | try: 99 | possible_respose = func_timeout(30, execute_sql, args=(db_path, possible_sql)) 100 | except FunctionTimedOut: 101 | t2s_object['candidate_sql_generation']["exec_err"] = "timeout" 102 | except Exception as e: 103 | t2s_object['candidate_sql_generation']["exec_err"] = str(e) 104 | except Exception as e: 105 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}") 106 | t2s_object["candidate_sql_generation"] = { 107 | "sql_generation_reasoning": "", 108 | "possible_sql": "", 109 | "prompt_tokens": 0, 110 | "completion_tokens": 0, 111 | "total_tokens": 0, 112 | } 113 | t2s_object["candidate_sql_generation"]["error"] = f"{e}" 114 | 
return t2s_object 115 | 116 | ### STAGE 2: SQL Refinement 117 | # -- Original question is used 118 | # -- Original Schema is used 119 | # -- Possible SQL is used 120 | # -- Possible Conditions is extracted from possible SQL and then used for augmentation 121 | # -- Execution Error for Possible SQL is used 122 | exec_err = t2s_object['candidate_sql_generation']["exec_err"] 123 | sql_generation_response_obj = self.sql_refinement_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions) 124 | try: 125 | predicted_sql = sql_generation_response_obj.choices[0].message.content['SQL'] 126 | t2s_object["sql_refinement"] = { 127 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 128 | "predicted_sql": predicted_sql, 129 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens, 130 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens, 131 | "total_tokens": sql_generation_response_obj.usage.total_tokens, 132 | } 133 | t2s_object["predicted_sql"] = predicted_sql 134 | except Exception as e: 135 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}") 136 | t2s_object["sql_refinement"] = { 137 | "sql_generation_reasoning": "", 138 | "predicted_sql": "", 139 | "prompt_tokens": 0, 140 | "completion_tokens": 0, 141 | "total_tokens": 0, 142 | } 143 | t2s_object["sql_refinement"]["error"] = f"{e}" 144 | return t2s_object 145 | 146 | # storing the usage for one question 147 | t2s_object["total_usage"] = { 148 | "prompt_tokens": t2s_object['candidate_sql_generation']['prompt_tokens'] + t2s_object['sql_refinement']['prompt_tokens'], 149 | "completion_tokens": t2s_object['candidate_sql_generation']['completion_tokens'] + t2s_object['sql_refinement']['completion_tokens'], 150 | "total_tokens": t2s_object['candidate_sql_generation']['total_tokens'] + t2s_object['sql_refinement']['total_tokens'] 151 | } 152 | 153 | t2s_object_prediction = t2s_object 154 | return t2s_object_prediction 155 | 156 | def forward_pipeline_CSG_QE_SR(self, t2s_object: Dict) -> Dict: 157 | """ 158 | The function performs Candidate SQL Generation(CSG), Quesiton Enrichment(QE) and SQL Refinement(SR) modules respectively without filtering stages. 159 | 160 | Arguments: 161 | t2s_object (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc. 
162 | Returns: 163 | t2s_object_prediction (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc and also stores information after each stage 164 | """ 165 | db_id = t2s_object["db_id"] 166 | q_id = t2s_object["question_id"] 167 | evidence = t2s_object["evidence"] 168 | question = t2s_object["question"] 169 | 170 | bird_sql_path = os.getenv('BIRD_DB_PATH') 171 | db_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/{db_id}.sqlite" 172 | db_description_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/database_description" 173 | db_descriptions = question_relevant_descriptions_prep(database_description_path=db_description_path, question=question, relevant_description_number=self.rdn) 174 | database_column_meaning_path = bird_sql_path + f"/{self.mode}/column_meaning.json" 175 | db_column_meanings = db_column_meaning_prep(database_column_meaning_path, db_id) 176 | db_descriptions = db_descriptions + "\n\n" + db_column_meanings 177 | 178 | # extracting original schema dictionary 179 | original_schema_dict = get_schema_tables_and_columns_dict(db_path=db_path) 180 | 181 | ### STAGE 1: Candidate SQL GENERATION 182 | # -- Original question is used 183 | # -- Original Schema is used 184 | sql_generation_response_obj = self.candidate_sql_generation_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions) 185 | try: 186 | possible_sql = sql_generation_response_obj.choices[0].message.content['SQL'] 187 | t2s_object["candidate_sql_generation"] = { 188 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 189 | "possible_sql": possible_sql, 190 | "exec_err": "", 191 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens, 192 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens, 193 | "total_tokens": sql_generation_response_obj.usage.total_tokens, 194 | } 195 | t2s_object["possible_sql"] = possible_sql 196 | # execute SQL 197 | try: 198 | possible_respose = func_timeout(30, execute_sql, args=(db_path, possible_sql)) 199 | except FunctionTimedOut: 200 | t2s_object['candidate_sql_generation']["exec_err"] = "timeout" 201 | except Exception as e: 202 | t2s_object['candidate_sql_generation']["exec_err"] = str(e) 203 | except Exception as e: 204 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}") 205 | t2s_object["candidate_sql_generation"] = { 206 | "sql_generation_reasoning": "", 207 | "possible_sql": "", 208 | "prompt_tokens": 0, 209 | "completion_tokens": 0, 210 | "total_tokens": 0, 211 | } 212 | t2s_object["candidate_sql_generation"]["error"] = f"{e}" 213 | return t2s_object 214 | 215 | # Extract possible conditions dict list 216 | possible_conditions_dict_list = collect_possible_conditions(db_path=db_path, sql=possible_sql) 217 | possible_conditions = sql_possible_conditions_prep(possible_conditions_dict_list=possible_conditions_dict_list) 218 | 219 | ### STAGE 2: Question Enrichment: 220 | # -- Original question is used 221 | # -- Original schema is used 222 | # -- Possible conditions are used 223 | q_enrich_response_obj = self.question_enrichment_module(db_path=db_path, q_id=q_id, db_id=db_id, question=question, evidence=evidence, possible_conditions=possible_conditions, schema_dict=original_schema_dict, db_descriptions=db_descriptions) 224 | try: 225 | enriched_question = 
q_enrich_response_obj.choices[0].message.content['enriched_question'] 226 | enrichment_reasoning = q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning'] 227 | t2s_object["question_enrichment"] = { 228 | "enrichment_reasoning": q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 229 | "enriched_question": q_enrich_response_obj.choices[0].message.content['enriched_question'], 230 | "prompt_tokens": q_enrich_response_obj.usage.prompt_tokens, 231 | "completion_tokens": q_enrich_response_obj.usage.completion_tokens, 232 | "total_tokens": q_enrich_response_obj.usage.total_tokens, 233 | } 234 | enriched_question = question + enrichment_reasoning + enriched_question # This is added after experiment-24 235 | except Exception as e: 236 | logging.error(f"Error in reaching content from question enrichment response for question_id {q_id}: {e}") 237 | t2s_object["question_enrichment"] = { 238 | "enrichment_reasoning": "", 239 | "enriched_question": "", 240 | "prompt_tokens": 0, 241 | "completion_tokens": 0, 242 | "total_tokens": 0, 243 | } 244 | t2s_object["question_enrichment"]["error"] = f"{e}" 245 | enriched_question = question 246 | 247 | ### STAGE 3: SQL Refinement 248 | # -- Enriched question is used 249 | # -- Original Schema is used 250 | # -- Possible SQL is used 251 | # -- Possible Conditions is extracted from possible SQL and then used for augmentation 252 | # -- Execution Error for Possible SQL is used 253 | exec_err = t2s_object['candidate_sql_generation']["exec_err"] 254 | sql_generation_response_obj = self.sql_refinement_module(db_path=db_path, db_id=db_id, question=enriched_question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=original_schema_dict, db_descriptions=db_descriptions) 255 | try: 256 | predicted_sql = sql_generation_response_obj.choices[0].message.content['SQL'] 257 | t2s_object["sql_refinement"] = { 258 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 259 | "predicted_sql": predicted_sql, 260 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens, 261 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens, 262 | "total_tokens": sql_generation_response_obj.usage.total_tokens, 263 | } 264 | t2s_object["predicted_sql"] = predicted_sql 265 | except Exception as e: 266 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}") 267 | t2s_object["sql_refinement"] = { 268 | "sql_generation_reasoning": "", 269 | "predicted_sql": "", 270 | "prompt_tokens": 0, 271 | "completion_tokens": 0, 272 | "total_tokens": 0, 273 | } 274 | t2s_object["sql_refinement"]["error"] = f"{e}" 275 | return t2s_object 276 | 277 | # storing the usage for one question 278 | t2s_object["total_usage"] = { 279 | "prompt_tokens": t2s_object['candidate_sql_generation']['prompt_tokens'] + t2s_object['question_enrichment']['prompt_tokens'] + t2s_object['sql_refinement']['prompt_tokens'], 280 | "completion_tokens": t2s_object['candidate_sql_generation']['completion_tokens'] + t2s_object['question_enrichment']['completion_tokens'] + t2s_object['sql_refinement']['completion_tokens'], 281 | "total_tokens": t2s_object['candidate_sql_generation']['total_tokens'] + t2s_object['question_enrichment']['total_tokens'] + t2s_object['sql_refinement']['total_tokens'] 282 | } 283 | 284 | t2s_object_prediction = t2s_object 285 | return t2s_object_prediction 286 | 287 | 288 | def 
forward_pipeline_SF_CSG_QE_SR(self, t2s_object: Dict) -> Dict: 289 | """ 290 | The function performs, Schema Filtering(SF) Candidate SQL Generation(CSG), Quesiton Enrichment(QE) and SQL Refinement(SR) modules respectively without filtering stages. 291 | 292 | Arguments: 293 | t2s_object (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc. 294 | Returns: 295 | t2s_object_prediction (Dict): Python dictionary that stores information about a question like q_id, db_id, question, evidence etc and also stores information after each stage 296 | """ 297 | db_id = t2s_object["db_id"] 298 | q_id = t2s_object["question_id"] 299 | evidence = t2s_object["evidence"] 300 | question = t2s_object["question"] 301 | 302 | bird_sql_path = os.getenv('BIRD_DB_PATH') 303 | db_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/{db_id}.sqlite" 304 | db_description_path = bird_sql_path + f"/{self.mode}/{self.mode}_databases/{db_id}/database_description" 305 | db_descriptions = question_relevant_descriptions_prep(database_description_path=db_description_path, question=question, relevant_description_number=self.rdn) 306 | database_column_meaning_path = bird_sql_path + f"/{self.mode}/column_meaning.json" 307 | db_column_meanings = db_column_meaning_prep(database_column_meaning_path, db_id) 308 | db_descriptions = db_descriptions + "\n\n" + db_column_meanings 309 | 310 | # extracting original schema dictionary 311 | original_schema_dict = get_schema_tables_and_columns_dict(db_path=db_path) 312 | 313 | 314 | ### STAGE 1: FILTERING THE DATABASE SCHEMA 315 | # -- original question is used. 316 | # -- Original Schema is used. 317 | schema_filtering_response_obj = self.schema_filtering_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, schema_dict=original_schema_dict, db_descriptions=db_descriptions) 318 | # print("schema_filtering_response_obj: \n", schema_filtering_response_obj) 319 | try: 320 | t2s_object["schema_filtering"] = { 321 | "filtering_reasoning": schema_filtering_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 322 | "filtered_schema_dict": schema_filtering_response_obj.choices[0].message.content['tables_and_columns'], 323 | "prompt_tokens": schema_filtering_response_obj.usage.prompt_tokens, 324 | "completion_tokens": schema_filtering_response_obj.usage.completion_tokens, 325 | "total_tokens": schema_filtering_response_obj.usage.total_tokens, 326 | } 327 | except Exception as e: 328 | logging.error(f"Error in reaching content from schema filtering response for question_id {q_id}: {e}") 329 | t2s_object["schema_filtering"] = f"{e}" 330 | return t2s_object 331 | 332 | ### STAGE 1.1: FILTERED SCHEMA CORRECTION 333 | filtered_schema_dict = schema_filtering_response_obj.choices[0].message.content['tables_and_columns'] 334 | filtered_schema_dict, filtered_schema_problems = filtered_schema_correction(db_path=db_path, filtered_schema_dict=filtered_schema_dict) 335 | t2s_object["schema_filtering_correction"] = { 336 | "filtered_schema_problems": filtered_schema_problems, 337 | "final_filtered_schema_dict": filtered_schema_dict 338 | } 339 | 340 | schema_statement = generate_schema_from_schema_dict(db_path=db_path, schema_dict=filtered_schema_dict) 341 | t2s_object["create_table_statement"] = schema_statement 342 | 343 | ### STAGE 2: Candidate SQL GENERATION 344 | # -- Original question is used 345 | # -- Filtered Schema is used 346 | sql_generation_response_obj = 
self.candidate_sql_generation_module(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions) 347 | try: 348 | possible_sql = sql_generation_response_obj.choices[0].message.content['SQL'] 349 | t2s_object["candidate_sql_generation"] = { 350 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 351 | "possible_sql": possible_sql, 352 | "exec_err": "", 353 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens, 354 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens, 355 | "total_tokens": sql_generation_response_obj.usage.total_tokens, 356 | } 357 | t2s_object["possible_sql"] = possible_sql 358 | # execute SQL 359 | try: 360 | possible_respose = func_timeout(30, execute_sql, args=(db_path, possible_sql)) 361 | except FunctionTimedOut: 362 | t2s_object['candidate_sql_generation']["exec_err"] = "timeout" 363 | except Exception as e: 364 | t2s_object['candidate_sql_generation']["exec_err"] = str(e) 365 | except Exception as e: 366 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}") 367 | t2s_object["candidate_sql_generation"] = { 368 | "sql_generation_reasoning": "", 369 | "possible_sql": "", 370 | "prompt_tokens": 0, 371 | "completion_tokens": 0, 372 | "total_tokens": 0, 373 | } 374 | t2s_object["candidate_sql_generation"]["error"] = f"{e}" 375 | return t2s_object 376 | 377 | # Extract possible conditions dict list 378 | possible_conditions_dict_list = collect_possible_conditions(db_path=db_path, sql=possible_sql) 379 | possible_conditions = sql_possible_conditions_prep(possible_conditions_dict_list=possible_conditions_dict_list) 380 | 381 | ### STAGE 3: Question Enrichment: 382 | # -- Original question is used 383 | # -- Original schema is used 384 | # -- Possible conditions are used 385 | q_enrich_response_obj = self.question_enrichment_module(db_path=db_path, q_id=q_id, db_id=db_id, question=question, evidence=evidence, possible_conditions=possible_conditions, schema_dict=filtered_schema_dict, db_descriptions=db_descriptions) 386 | try: 387 | enriched_question = q_enrich_response_obj.choices[0].message.content['enriched_question'] 388 | enrichment_reasoning = q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning'] 389 | t2s_object["question_enrichment"] = { 390 | "enrichment_reasoning": q_enrich_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 391 | "enriched_question": q_enrich_response_obj.choices[0].message.content['enriched_question'], 392 | "prompt_tokens": q_enrich_response_obj.usage.prompt_tokens, 393 | "completion_tokens": q_enrich_response_obj.usage.completion_tokens, 394 | "total_tokens": q_enrich_response_obj.usage.total_tokens, 395 | } 396 | enriched_question = question + enrichment_reasoning + enriched_question # This is added after experiment-24 397 | except Exception as e: 398 | logging.error(f"Error in reaching content from question enrichment response for question_id {q_id}: {e}") 399 | t2s_object["question_enrichment"] = { 400 | "enrichment_reasoning": "", 401 | "enriched_question": "", 402 | "prompt_tokens": 0, 403 | "completion_tokens": 0, 404 | "total_tokens": 0, 405 | } 406 | t2s_object["question_enrichment"]["error"] = f"{e}" 407 | enriched_question = question 408 | 409 | ### STAGE 4: SQL Refinement 410 | # -- Enriched question is used 411 | # -- Original Schema is used 412 | # -- Possible SQL 
is used 413 | # -- Possible Conditions is extracted from possible SQL and then used for augmentation 414 | # -- Execution Error for Possible SQL is used 415 | exec_err = t2s_object['candidate_sql_generation']["exec_err"] 416 | sql_generation_response_obj = self.sql_refinement_module(db_path=db_path, db_id=db_id, question=enriched_question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions) 417 | try: 418 | predicted_sql = sql_generation_response_obj.choices[0].message.content['SQL'] 419 | t2s_object["sql_refinement"] = { 420 | "sql_generation_reasoning": sql_generation_response_obj.choices[0].message.content['chain_of_thought_reasoning'], 421 | "predicted_sql": predicted_sql, 422 | "prompt_tokens": sql_generation_response_obj.usage.prompt_tokens, 423 | "completion_tokens": sql_generation_response_obj.usage.completion_tokens, 424 | "total_tokens": sql_generation_response_obj.usage.total_tokens, 425 | } 426 | t2s_object["predicted_sql"] = predicted_sql 427 | except Exception as e: 428 | logging.error(f"Error in reaching content from sql generation response for question_id {q_id}: {e}") 429 | t2s_object["sql_refinement"] = { 430 | "sql_generation_reasoning": "", 431 | "predicted_sql": "", 432 | "prompt_tokens": 0, 433 | "completion_tokens": 0, 434 | "total_tokens": 0, 435 | } 436 | t2s_object["sql_refinement"]["error"] = f"{e}" 437 | return t2s_object 438 | 439 | # storing the usage for one question 440 | t2s_object["total_usage"] = { 441 | "prompt_tokens": t2s_object['candidate_sql_generation']['prompt_tokens'] + t2s_object['question_enrichment']['prompt_tokens'] + t2s_object['sql_refinement']['prompt_tokens'], 442 | "completion_tokens": t2s_object['candidate_sql_generation']['completion_tokens'] + t2s_object['question_enrichment']['completion_tokens'] + t2s_object['sql_refinement']['completion_tokens'], 443 | "total_tokens": t2s_object['candidate_sql_generation']['total_tokens'] + t2s_object['question_enrichment']['total_tokens'] + t2s_object['sql_refinement']['total_tokens'] 444 | } 445 | 446 | t2s_object_prediction = t2s_object 447 | return t2s_object_prediction 448 | 449 | 450 | def construct_question_enrichment_prompt(self, db_path: str, q_id: int, db_id: str, question: str, evidence: str, possible_conditions: str, schema_dict: Dict, db_descriptions: str) -> str: 451 | """ 452 | The function constructs the prompt required for the question enrichment stage 453 | 454 | Arguments: 455 | db_path (str): path to database sqlite file 456 | q_id (int): question id 457 | db_id (str): database ID, i.e. 
database name 458 | question (str): Natural language question 459 | evidence (str): evidence for the question 460 | possible_conditions (str): Possible conditions extracted from the previously generated possible SQL for the question 461 | schema_dict (Dict[str, List[str]]): database schema dictionary 462 | db_descriptions (str): Question relevant database item (column) descriptions 463 | 464 | Returns: 465 | prompt (str): Question enrichment prompt 466 | """ 467 | enrichment_template_path = os.path.join(os.getcwd(), "prompt_templates/question_enrichment_prompt_template.txt") 468 | question_enrichment_prompt_template = extract_question_enrichment_prompt_template(enrichment_template_path) 469 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json") 470 | q_enrich_few_shot_examples = question_enrichment_few_shot_prep(few_shot_data_path, q_id=q_id, q_db_id=db_id, level_shot_number=self.elsn, schema_existance=self.efsse, enrichment_level=self.enrichment_level, mode=self.mode) 471 | db_samples = extract_db_samples_enriched_bm25(question, evidence, db_path=db_path, schema_dict=schema_dict, sample_limit=self.db_sample_limit) 472 | schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=schema_dict) 473 | prompt = fill_question_enrichment_prompt_template(template=question_enrichment_prompt_template, schema=schema, db_samples=db_samples, question=question, possible_conditions=possible_conditions, few_shot_examples=q_enrich_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions) 474 | # print("question_enrichment_prompt: \n", prompt) 475 | return prompt 476 | 477 | def question_enrichment_module(self, db_path: str, q_id: int, db_id: str, question: str, evidence: str, possible_conditions: str, schema_dict: Dict, db_descriptions: str) -> Dict: 478 | """ 479 | The function enrich the given question using LLM. 480 | 481 | Arguments: 482 | db_path (str): path to database sqlite file 483 | q_id (int): question id 484 | db_id (str): database ID, i.e. database name 485 | question (str): Natural language question 486 | evidence (str): evidence for the question 487 | possible_conditions (str): possible conditions extracted from previously generated possible SQL for the question 488 | schema_dict (Dict[str, List[str]]): database schema dictionary 489 | db_descriptions (str): Question relevant database item (column) descriptions 490 | 491 | Returns: 492 | response_object (Dict): Response object returned by the LLM 493 | """ 494 | prompt = self.construct_question_enrichment_prompt(db_path=db_path, q_id=q_id, db_id=db_id, question=question, evidence=evidence, possible_conditions=possible_conditions, schema_dict=schema_dict, db_descriptions=db_descriptions) 495 | response_object = create_response(stage="question_enrichment", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n) 496 | try: 497 | response_object = self.convert_message_content_to_dict(response_object) 498 | except: 499 | return response_object 500 | 501 | return response_object 502 | 503 | def construct_candidate_sql_generation_prompt(self, db_path: str, db_id: int, question: str, evidence: str, filtered_schema_dict: Dict, db_descriptions: str)->str: 504 | """ 505 | The function constructs the prompt required for the candidate SQL generation stage. 506 | 507 | Arguments: 508 | db_path (str): The database sqlite file path. 509 | db_id (int): database ID, i.e. 
database name 510 | question (str): Natural language question 511 | evidence (str): evidence for the question 512 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names 513 | db_descriptions (str): Question relevant database item (column) descriptions 514 | 515 | Returns: 516 | prompt (str): prompt for SQL generation stage 517 | """ 518 | sql_generation_template_path = os.path.join(os.getcwd(), "prompt_templates/candidate_sql_generation_prompt_template.txt") 519 | with open(sql_generation_template_path, 'r') as f: 520 | sql_generation_template = f.read() 521 | 522 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json") 523 | sql_generation_few_shot_examples = sql_generation_and_refinement_few_shot_prep(few_shot_data_path, q_db_id=db_id, level_shot_number=self.glsn, schema_existance=self.gfsse, mode=self.mode) 524 | db_samples = extract_db_samples_enriched_bm25(question, evidence, db_path, schema_dict=filtered_schema_dict, sample_limit=self.db_sample_limit) 525 | filtered_schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=filtered_schema_dict) 526 | prompt = fill_candidate_sql_prompt_template(template=sql_generation_template, schema=filtered_schema, db_samples=db_samples, question=question, few_shot_examples=sql_generation_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions) 527 | # print("candidate_sql_prompt: \n", prompt) 528 | return prompt 529 | 530 | 531 | def construct_sql_refinement_prompt(self, db_path: str, db_id: int, question: str, evidence: str, possible_sql: str, exec_err: str, filtered_schema_dict: Dict, db_descriptions: str)->str: 532 | """ 533 | The function constructs the prompt required for the SQL refinement stage. 534 | 535 | Arguments: 536 | db_path (str): The database sqlite file path. 537 | db_id (int): database ID, i.e. 
database name 538 | question (str): Natural language question 539 | evidence (str): evidence for the question 540 | possible_sql (str): Previously generated possible SQL for the question 541 | exec_err (str): Taken execution error when possible SQL is executed 542 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names 543 | db_descriptions (str): Question relevant database item (column) descriptions 544 | 545 | Returns: 546 | prompt (str): prompt for SQL generation stage 547 | """ 548 | sql_generation_template_path = os.path.join(os.getcwd(), "prompt_templates/sql_refinement_prompt_template.txt") 549 | with open(sql_generation_template_path, 'r') as f: 550 | sql_generation_template = f.read() 551 | 552 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json") 553 | sql_generation_few_shot_examples = sql_generation_and_refinement_few_shot_prep(few_shot_data_path, q_db_id=db_id, level_shot_number=self.glsn, schema_existance=self.gfsse, mode=self.mode) 554 | possible_conditions_dict_list = collect_possible_conditions(db_path=db_path, sql=possible_sql) 555 | possible_conditions = sql_possible_conditions_prep(possible_conditions_dict_list=possible_conditions_dict_list) 556 | filtered_schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=filtered_schema_dict) 557 | prompt = fill_refinement_prompt_template(template=sql_generation_template, schema=filtered_schema, possible_conditions=possible_conditions, question=question, possible_sql=possible_sql, exec_err=exec_err, few_shot_examples=sql_generation_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions) 558 | # print("refinement_prompt: \n", prompt) 559 | return prompt 560 | 561 | def construct_filtering_prompt(self, db_path: str, db_id: str, question: str, evidence: str, schema_dict: Dict, db_descriptions: str)->str: 562 | """ 563 | The function constructs the prompt required for the database schema filtering stage 564 | 565 | Arguments: 566 | db_path (str): The database sqlite file path. 567 | db_id (str): database ID, i.e. 
database name 568 | question (str): Natural language question 569 | evidence (str): evidence for the question 570 | schema_dict (Dict[str, List[str]]): database schema dictionary 571 | db_descriptions (str): Question relevant database item (column) descriptions 572 | 573 | Returns: 574 | prompt (str): prompt for database schema filtering stage 575 | """ 576 | schema_filtering_prompt_template_path = os.path.join(os.getcwd(), "prompt_templates/schema_filter_prompt_template.txt") 577 | with open(schema_filtering_prompt_template_path, 'r') as f: 578 | schema_filtering_template = f.read() 579 | 580 | few_shot_data_path = os.path.join(os.getcwd(), "few-shot-data/question_enrichment_few_shot_examples.json") 581 | schema_filtering_few_shot_examples = schema_filtering_few_shot_prep(few_shot_data_path, q_db_id=db_id, level_shot_number=self.elsn, schema_existance=self.efsse, mode=self.mode) 582 | db_samples = extract_db_samples_enriched_bm25(question, evidence, db_path=db_path, schema_dict=schema_dict, sample_limit=self.db_sample_limit) 583 | schema = generate_schema_from_schema_dict(db_path=db_path, schema_dict=schema_dict) 584 | prompt = fill_prompt_template(template=schema_filtering_template, schema=schema, db_samples=db_samples, question=question, few_shot_examples=schema_filtering_few_shot_examples, evidence=evidence, db_descriptions=db_descriptions) 585 | # print("\nSchema Filtering Prompt: \n", prompt) 586 | 587 | return prompt 588 | 589 | 590 | def candidate_sql_generation_module(self, db_path: str, db_id: int, question: str, evidence: str, filtered_schema_dict: Dict, db_descriptions: str): 591 | """ 592 | This function generates candidate SQL for answering the question. 593 | 594 | Arguments: 595 | db_path (str): The database sqlite file path. 596 | db_id (int): database ID, i.e. database name 597 | question (str): Natural language question 598 | evidence (str): evidence for the question 599 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names 600 | db_descriptions (str): Question relevant database item (column) descriptions 601 | 602 | Returns: 603 | response_object (Dict): Response object returned by the LLM 604 | """ 605 | prompt = self.construct_candidate_sql_generation_prompt(db_path=db_path, db_id=db_id, question=question, evidence=evidence, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions) 606 | response_object = create_response(stage="candidate_sql_generation", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n) 607 | try: 608 | response_object = self.convert_message_content_to_dict(response_object) 609 | except: 610 | return response_object 611 | 612 | return response_object 613 | 614 | 615 | def sql_refinement_module(self, db_path: str, db_id: int, question: str, evidence: str, possible_sql: str, exec_err: str, filtered_schema_dict: Dict, db_descriptions: str): 616 | """ 617 | This function refines or re-generates a SQL query for answering the question. 618 | Possible SQL query, possible conditions generated from possible SQL query and execution error if it is exist are leveraged for better SQL refinement. 619 | 620 | Arguments: 621 | db_path (str): The database sqlite file path. 622 | db_id (int): database ID, i.e. 
database name 623 | question (str): Natural language question 624 | evidence (str): evidence for the question 625 | possible_sql (str): Previously generated possible SQL query for the question 626 | exec_err (str): Taken execution error when possible SQL is executed 627 | filtered_schema_dict (Dict[str, List[str]]): filtered database as dictionary where keys are the tables and the values are the list of column names 628 | db_descriptions (str): Question relevant database item (column) descriptions 629 | 630 | Returns: 631 | response_object (Dict): Response object returned by the LLM 632 | """ 633 | prompt = self.construct_sql_refinement_prompt(db_path=db_path, db_id=db_id, question=question, evidence=evidence, possible_sql=possible_sql, exec_err=exec_err, filtered_schema_dict=filtered_schema_dict, db_descriptions=db_descriptions) 634 | response_object = create_response(stage="sql_refinement", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n) 635 | try: 636 | response_object = self.convert_message_content_to_dict(response_object) 637 | except: 638 | return response_object 639 | 640 | return response_object 641 | 642 | 643 | def schema_filtering_module(self, db_path: str, db_id: str, question: str, evidence: str, schema_dict: Dict, db_descriptions: str): 644 | """ 645 | The function filters the database schema by eliminating the unnecessary tables and columns 646 | 647 | Arguments: 648 | db_path (str): The database sqlite file path. 649 | db_id (str): database ID, i.e. database name 650 | question (str): Natural language question 651 | evidence (str): evidence for the question 652 | schema_dict (Dict[str, List[str]]): database schema dictionary 653 | db_descriptions (str): Question relevant database item (column) descriptions 654 | 655 | Returns: 656 | response_object (Dict): Response object returned by the LLM 657 | """ 658 | prompt = self.construct_filtering_prompt(db_path=db_path, db_id=db_id, question=question, evidence=evidence, schema_dict=schema_dict, db_descriptions=db_descriptions) 659 | response_object = create_response(stage="schema_filtering", prompt=prompt, model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, n=self.n) 660 | try: 661 | response_object = self.convert_message_content_to_dict(response_object) 662 | except: 663 | return response_object 664 | 665 | return response_object 666 | 667 | 668 | -------------------------------------------------------------------------------- /prompt_templates/candidate_sql_generation_prompt_template.txt: -------------------------------------------------------------------------------- 1 | ### You are an excellent data scientist. You can capture the link between the question and corresponding database and perfectly generate valid SQLite SQL query to answer the question. Your objective is to generate SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, samples and evidence. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question. 2 | 3 | ### Follow the instructions below: 4 | # Step 1 - Read the Question and Evidence Carefully: Understand the primary focus and specific details of the question. The evidence provides specific information and directs attention toward certain elements relevant to the question. 
5 | # Step 2 - Analyze the Database Schema, Database Column Descriptions and Database Sample Values: Examine the database schema, database column descriptions and sample values. Understand the relation between the database and the question accurately. 6 | # Step 3 - Generate SQL query: Write the SQLite SQL query corresponding to the given question by combining the sense of the question, evidence and database items. 7 | 8 | {FEWSHOT_EXAMPLES} 9 | 10 | ### Task: Given the following question, database schema and evidence, generate a SQLite SQL query in order to answer the question. 11 | ### Make sure to keep the original wording or terms from the question, evidence and database items. 12 | ### Make sure each table name and column name in the generated SQL is enclosed with backticks separately. 13 | ### Ensure the generated SQL is compatible with the database schema. 14 | ### When constructing SQL queries that require determining a maximum or minimum value, always use the `ORDER BY` clause in combination with `LIMIT 1` instead of using `MAX` or `MIN` functions in the `WHERE` clause. Especially if there is more than one table in the FROM clause, apply the `ORDER BY` clause in combination with `LIMIT 1` on the column of the joined table. 15 | ### Make sure the parentheses in the SQL are placed correctly, especially if the generated SQL includes a mathematical expression. Also, proper usage of the CAST function is important to convert the data type to REAL in mathematical expressions; be careful especially if there is division in the mathematical expressions. 16 | ### Ensure proper handling of null values by including the `IS NOT NULL` condition in SQL queries, but only in cases where null values could affect the results or cause errors, such as during division operations or when null values would lead to incorrect filtering of results. Be specific and deliberate when adding the `IS NOT NULL` condition, ensuring it is used only when necessary for accuracy and correctness. This is crucial to avoid errors and ensure accurate results. You can leverage the database sample values to check if there could be a potential null value. 17 | 18 | 19 | {SCHEMA} 20 | {DB_DESCRIPTIONS} 21 | {DB_SAMPLES} 22 | {QUESTION} 23 | {EVIDENCE} 24 | 25 | ### Please respond with a JSON object structured as follows: 26 | 27 | ```json{{"chain_of_thought_reasoning": "Explanation of the logical analysis and steps that result in the final SQLite SQL query.", "SQL": "Generated SQL query as a single string"}}``` 28 | 29 | Let's think step by step and generate the SQLite SQL query. -------------------------------------------------------------------------------- /prompt_templates/question_enrichment_prompt_template.txt: -------------------------------------------------------------------------------- 1 | ### You are an excellent data scientist and can link the information between a question and the corresponding database perfectly. Your objective is to analyze the given question, corresponding database schema, database column descriptions, evidence and the possible SQL query to create a clear link between the given question and the database items, which include tables, columns and values. With the help of this link, rewrite new versions of the original question to be more related to the database items, understandable, clear, free of irrelevant information and easier to translate into SQL queries. This question enrichment is essential for comprehending the question's intent and identifying the related database items. 
The process involves pinpointing the relevant database components and expanding the question to incorporate these items. 2 | 3 | ### Follow the instructions below: 4 | # Step 1 - Read the Question Carefully: Understand the primary focus and specific details of the question. Identify named entities (such as organizations, locations, etc.), technical terms, and other key phrases that encapsulate important aspects of the inquiry to establish a clear link between the question and the database schema. 5 | # Step 2 - Analyze the Database Schema: With the database samples, examine the database schema to identify relevant tables, columns, and values that are pertinent to the question. Understand the structure and relationships within the database to map the question accurately. 6 | # Step 3 - Review the Database Column Descriptions: The database column descriptions give detailed information about some of the columns of the tables in the database. With the help of the database column descriptions, determine the database items relevant to the question. Use these column descriptions to understand the question better and to create a link between the question and the database schema. 7 | # Step 4 - Analyze and Observe the Database Sample Values: Examine the sample values from the database to analyze the distinct elements within each column of the tables. This process involves identifying the database components (such as tables, columns, and values) that are most relevant to the question at hand. Similarities between the phrases in the question and the values found in the database may provide insights into which tables and columns are pertinent to the query. 8 | # Step 5 - Review the Evidence: The evidence provides specific information and directs attention toward certain elements relevant to the question and its answer. Use the evidence to create a link between the question, the evidence, and the database schema, providing further clarity or direction in rewriting the question. 9 | # Step 6 - Analyze the Possible SQL Conditions: Analyze the given possible SQL conditions that are relevant to the question and identify the relation between them and the question components, phrases and keywords. 10 | # Step 7 - Identify Relevant Database Components: Pinpoint the tables, columns, and values in the database that are directly related to the question. 11 | # Step 8 - Rewrite the Question: Expand and refine the original question in detail to incorporate the identified database items (tables, columns and values) and conditions. Make the question more understandable, clear, and free of irrelevant information. 12 | 13 | {FEWSHOT_EXAMPLES} 14 | 15 | ### Task: Given the following question, database schema, database column descriptions, database samples and evidence, expand the original question in detail to incorporate the identified database components and SQL steps like the examples given above. Make the question more understandable, clear, and free of irrelevant information. 16 | ### Ensure that the question is expanded with original database items. Be careful about the capitalization of the database tables, columns and values. Use tables and columns in the database schema. 
17 | 18 | {SCHEMA} 19 | {DB_DESCRIPTIONS} 20 | {DB_SAMPLES} 21 | {POSSIBLE_CONDITIONS} 22 | {QUESTION} 23 | {EVIDENCE} 24 | 25 | 26 | ### Please respond with a JSON object structured as follows: 27 | 28 | ```json{{"chain_of_thought_reasoning": "Detail explanation of the logical analysis that led to the refined question, considering detailed possible sql generation steps", "enriched_question": "Expanded and refined question which is more understandable, clear and free of irrelevant information."}}``` 29 | 30 | Let's think step by step and refine the given question capturing the essence of both the question, database schema, database descriptions, evidence and possible SQL conditions through the links between them. If you do the task correctly, I will give you 1 million dollars. Only output a json as your response. 31 | -------------------------------------------------------------------------------- /prompt_templates/schema_filter_prompt_template.txt: -------------------------------------------------------------------------------- 1 | ### You are an excellent data scientist. You can capture the link between a question and corresponding database and determine the useful database items (tables and columns) perfectly. Your objective is to analyze and understand the essence of the given question, corresponding database schema, database column descriptions, samples and evidence and then select the useful database items such as tables and columns. This database item filtering is essential for eliminating unnecessary information in the database so that corresponding structured query language (SQL) of the question can be generated correctly in later steps. 2 | 3 | ### Follow the instructions below step by step: 4 | # Step 1 - Read the Question Carefully: Understand the primary focus and specific details of the question. Identify named entities (such as organizations, locations, etc.), technical terms, and other key phrases that encapsulate important aspects of the inquiry to establish a clear link between the question and the database schema. 5 | # Step 2 - Analyze the Database Schema: With the database samples, examine the database schema to identify relevant tables, columns, and values that are pertinent to the question. Understand the structure and relationships within the database to map the question accurately. 6 | # Step 3 - Review the Database Column Descriptions: The database column descriptions give the detailed information about some of the columns of the tables in the database. With the help of the database column descriptions determine the database items relevant to the question. Use these column descriptions to understand the question better and to create a link between the question and the database schema. 7 | # Step 4 - Analyze and Observe The Database Sample Values: Examine the sample values from the database to analyze the distinct elements within each column of the tables. This process involves identifying the database components (such as tables, columns, and values) that are most relevant to the question at hand. Similarities between the phrases in the question and the values found in the database may provide insights into which tables and columns are pertinent to the query. 8 | # Step 5 - Review the Evidence: The evidence provides specific information and directs attention toward certain elements relevant to the question and its answer. 
Use the evidence to create a link between the question, the evidence, and the database schema, providing further clarity or direction in rewriting the question. 9 | # Step 6 - Identify Relevant Database Components: Pinpoint the tables, columns, and values in the database that are directly related to the question. Ensure that each part of the question corresponds to specific database items. 10 | # Step 7 - Select Useful Database Tables and Columns: Select only the useful database tables and columns of selected tables by fusing the detailed information, key points of the question, database schema and evidence. 11 | 12 | {FEWSHOT_EXAMPLES} 13 | 14 | ### Task: Given the following question, database schema, database column descriptions and evidence, select only the necessary and useful database tables, and necessary and useful columns of selected tables to filter the database items. 15 | ### Make sure to keep the original terms from database items. 16 | ### Make sure the selected columns belong to the correct database table in your response. 17 | 18 | {SCHEMA} 19 | {DB_DESCRIPTIONS} 20 | {DB_SAMPLES} 21 | {QUESTION} 22 | {EVIDENCE} 23 | 24 | ### Please respond with a JSON object structured as follows: 25 | 26 | ```json{{"chain_of_thought_reasoning": "Explanation of the logical analysis that led to the selected useful database items.", "tables_and_columns": {{"table_name1": ["column1", "column2", ...], "table_name2": ["column1", ...], ...}} }}``` 27 | 28 | Let's think step by step and select only the necessary and useful database tables, and select only the necessary and useful columns of selected tables to filter the database items. If you do the task correctly, I will give you 1 million dollars. Only output a json as your response. -------------------------------------------------------------------------------- /prompt_templates/sql_refinement_prompt_template.txt: -------------------------------------------------------------------------------- 1 | ### You are an excellent data scientist. You can capture the link between the question and corresponding database and perfectly generate valid SQLite SQL query to answer the question. Your objective is to generate SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, evidence, possible SQL and possible conditions. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question. 2 | 3 | ### Follow the instructions below: 4 | # Step 1 - Read the Question and Evidence: Understand the primary focus and specific details of the question. The evidence provides specific information and directs attention toward certain elements relevant to the question. 5 | # Step 2 - Analyze the Database Schema, Database Column descriptions: Examine the database schema, database column descriptions which provides information about the database columns. Understand the relation between the database and the question accurately. 6 | # Step 3 - Analyze the Possible SQL Query: Analize the possible SQLite SQL query and identify possible mistakes leads incorrect result such as missing or wrong conditions, wrong functions, misuse of aggregate functions, wrong sql syntax, unrecognized tokens or ambiguous columns. 7 | # Step 4 - Investigate Possible Conditions and Execution Errors: Carefully consider the list of possible conditions which are completely compatible with the database schema and given in the form of .. 
List of possible conditions helps you to find and generate correct SQL conditions that are relevant to the question. If the given possible SQL query gives execution error, it will be given. Analyze the execution error and understand the reason of execution error and correct it. 8 | # Step 5 - Finalize the SQL query: Construct correct SQLite SQL query or improve possible SQLite SQL query corresponding to the given question by combining the sense of question, evidence, and possible conditions. 9 | # Step 6 - Validation and Syntax Check: Before finalizing, verify that generated SQL query is coherent with the database schema, all referenced columns exist in the referenced table, all joins are correctly formulated, aggregation logic is accurate, and the SQL syntax is correct. 10 | 11 | ### Task: Given the following question, database schema and descriptions, evidence, possible SQL query and possible conditions; finalize SQLite SQL query in order to answer the question. 12 | ### Ensure that the SQL query accurately reflects the relationships between tables, using appropriate join conditions to combine data where necessary. 13 | ### When using aggregate functions (e.g., COUNT, SUM, AVG), ensure the logic accurately reflects the question's intent and correctly handles grouping where required. 14 | ### Double-check that all WHERE clauses accurately represent the conditions needed to filter the data as per the question's requirements. 15 | ### Make sure to keep the original wording or terms from the question, evidence and database items. 16 | ### Make sure each table name and column name in the generated SQL is enclosed with backtick seperately. 17 | ### Be careful about the capitalization of the database tables, columns and values. Use tables and columns in database schema. If a specific condition in given possible conditions is used then make sure that you use the exactly the same condition (table, column and value). 18 | ### When constructing SQL queries that require determining a maximum or minimum value, always use the `ORDER BY` clause in combination with `LIMIT 1` instead of using `MAX` or `MIN` functions in the `WHERE` clause. Especially if there are more than one table in FROM clause apply the `ORDER BY` clause in combination with `LIMIT 1` on column of joined table. 19 | ### Make sure the parentheses in the SQL are placed correct especially if the generated SQL includes mathematical expression. Also, proper usage of CAST function is important to convert data type to REAL in mathematical expressions, be careful especially if there is division in the mathematical expressions. 20 | ### Ensure proper handling of null values by including the `IS NOT NULL` condition in SQL queries, but only in cases where null values could affect the results or cause errors, such as during division operations or when null values would lead to incorrect filtering of results. Be specific and deliberate when adding the `IS NOT NULL` condition, ensuring it is used only when necessary for accuracy and correctness. . This is crucial to avoid errors and ensure accurate results. 
21 | 22 | 23 | 24 | {SCHEMA} 25 | {DB_DESCRIPTIONS} 26 | {QUESTION} 27 | {EVIDENCE} 28 | {POSSIBLE_CONDITIONS} 29 | {POSSIBLE_SQL_Query} 30 | {EXECUTION_ERROR} 31 | 32 | ### Please respond with a JSON object structured as follows: 33 | 34 | ```json{{"chain_of_thought_reasoning": "Explanation of the logical analysis and steps that result in the final SQLite SQL query.", "SQL": "Finalized SQL query as a single string"}}``` 35 | 36 | Let's think step by step and generate SQLite SQL query. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | python-dotenv==1.0.1 3 | sqlglot==25.7.1 4 | rank-bm25==0.2.2 5 | nltk 6 | func_timeout==4.3.5 7 | openai==1.37.1 8 | numpy 9 | pandas 10 | sentencepiece 11 | psycopg2-binary -------------------------------------------------------------------------------- /run_evaluation.sh: -------------------------------------------------------------------------------- 1 | db_root_path='../dataset/bird-sql/dev/dev_databases/' 2 | data_mode='dev' # dev, train, mini_dev 3 | diff_json_path='../dataset/bird-sql/dev/dev.json' # _sqlite.json, _mysql.json, _postgresql.json 4 | # Path where the predicted SQL queries are stored 5 | predicted_sql_path='./results/model_outputs_dev_SF-CSG-QE-SR_gpt-4o-mini-2024-07-18/' 6 | 7 | ground_truth_path='../dataset/bird-sql/dev/' 8 | num_cpus=72 9 | meta_time_out=60.0 10 | mode_gt='gt' 11 | mode_predict='gpt' 12 | 13 | # Choose the engine to run, e.g. gpt-4, gpt-4-32k, gpt-4-turbo, gpt-35-turbo, GPT35-turbo-instruct 14 | engine='gpt-4o' 15 | 16 | 17 | # Choose the SQL dialect to run, e.g. SQLite, MySQL, PostgreSQL 18 | # PLEASE NOTE: You have to setup the database information in evaluation_utils.py 19 | # if you want to run the evaluation script using MySQL or PostgreSQL 20 | sql_dialect='SQLite' 21 | 22 | echo "starting to compare with knowledge for soft-f1 engine: ${engine} sql_dialect: ${sql_dialect} meta_time_out: ${meta_time_out}" 23 | python3 -u ./evaluation/evaluation_f1.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \ 24 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \ 25 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --engine ${engine} --sql_dialect ${sql_dialect} 26 | 27 | echo "starting to compare with knowledge for ex engine: ${engine} sql_dialect: ${sql_dialect} meta_time_out: ${meta_time_out}" 28 | python3 -u ./evaluation/evaluation_ex.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \ 29 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \ 30 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --engine ${engine} --sql_dialect ${sql_dialect} 31 | 32 | 33 | echo "starting to compare with knowledge for ves engine: ${engine} sql_dialect: ${sql_dialect} meta_time_out: ${meta_time_out}" 34 | python3 -u ./evaluation/evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \ 35 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \ 36 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --engine ${engine} --sql_dialect ${sql_dialect} 37 | 38 | 
-------------------------------------------------------------------------------- /run_main.sh: --------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Set default values for the arguments
 4 | mode="dev"
 5 | model="gpt-4o-mini-2024-07-18" # For GPT-4o use "gpt-4o-2024-08-06". For GPT-4o mini use "gpt-4o-mini-2024-07-18"
 6 | pipeline_order="SF-CSG-QE-SR" # First set CSG-QE-SR and then set CSG-SR
 7 | 
 8 | # DO NOT CHANGE THE FOLLOWING ARGUMENTS
 9 | temperature=0.0
10 | top_p=1.0
11 | max_tokens=4096
12 | n=1
13 | enrichment_level="complex"
14 | enrichment_level_shot_number=3
15 | enrichment_few_shot_schema_existance=False
16 | filtering_level_shot_number=3
17 | filtering_few_shot_schema_existance=False
18 | cfg=True
19 | generation_level_shot_number=3
20 | generation_few_shot_schema_existance=False
21 | db_sample_limit=10
22 | relevant_description_number=20
23 | seed=42
24 | 
25 | # Parse command line arguments to override default values
26 | while [ "$#" -gt 0 ]; do
27 |     case $1 in
28 |         --mode) mode="$2"; shift ;;
29 |         --model) model="$2"; shift ;;
30 |         --temperature) temperature="$2"; shift ;;
31 |         --top_p) top_p="$2"; shift ;;
32 |         --max_tokens) max_tokens="$2"; shift ;;
33 |         --n) n="$2"; shift ;;
34 |         --pipeline_order) pipeline_order="$2"; shift ;;
35 |         --enrichment_level) enrichment_level="$2"; shift ;;
36 |         --enrichment_level_shot_number) enrichment_level_shot_number="$2"; shift ;;
37 |         --enrichment_few_shot_schema_existance) enrichment_few_shot_schema_existance="$2"; shift ;;
38 |         --filtering_level_shot_number) filtering_level_shot_number="$2"; shift ;;
39 |         --filtering_few_shot_schema_existance) filtering_few_shot_schema_existance="$2"; shift ;;
40 |         --cfg) cfg="$2"; shift ;;
41 |         --generation_level_shot_number) generation_level_shot_number="$2"; shift ;;
42 |         --generation_few_shot_schema_existance) generation_few_shot_schema_existance="$2"; shift ;;
43 |         --db_sample_limit) db_sample_limit="$2"; shift ;;
44 |         --relevant_description_number) relevant_description_number="$2"; shift ;;
45 |         --seed) seed="$2"; shift ;;
46 |         *) echo "Unknown parameter passed: $1"; exit 1 ;;
47 |     esac
48 |     shift
49 | done
50 | 
51 | # Run the Python script with the provided arguments
52 | python main.py \
53 |     --mode "$mode" \
54 |     --model "$model" \
55 |     --temperature "$temperature" \
56 |     --top_p "$top_p" \
57 |     --max_tokens "$max_tokens" \
58 |     --n "$n" \
59 |     --pipeline_order "$pipeline_order" \
60 |     --enrichment_level "$enrichment_level" \
61 |     --enrichment_level_shot_number "$enrichment_level_shot_number" \
62 |     --enrichment_few_shot_schema_existance "$enrichment_few_shot_schema_existance" \
63 |     --filtering_level_shot_number "$filtering_level_shot_number" \
64 |     --filtering_few_shot_schema_existance "$filtering_few_shot_schema_existance" \
65 |     --cfg "$cfg" \
66 |     --generation_level_shot_number "$generation_level_shot_number" \
67 |     --generation_few_shot_schema_existance "$generation_few_shot_schema_existance" \
68 |     --db_sample_limit "$db_sample_limit" \
69 |     --relevant_description_number "$relevant_description_number" \
70 |     --seed "$seed"
71 | 
-------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasanAlpCaferoglu/E-SQL/72ec6a38945b6f8ce85e373e159fe86f0a862d2f/utils/__init__.py
-------------------------------------------------------------------------------- /utils/db_utils.py:
-------------------------------------------------------------------------------- 1 | import sqlite3 2 | import random 3 | import logging 4 | import re 5 | import time 6 | import nltk 7 | from nltk.tokenize import word_tokenize 8 | import difflib 9 | from rank_bm25 import BM25Okapi 10 | import sqlglot 11 | from sqlglot import parse, parse_one, expressions 12 | from sqlglot.optimizer.qualify import qualify 13 | from sqlglot.optimizer.qualify_columns import qualify_columns 14 | from sqlglot.expressions import Select 15 | from func_timeout import func_timeout, FunctionTimedOut 16 | from typing import Any, Union, List, Dict, Optional 17 | nltk.download('punkt') 18 | 19 | def execute_sql(db_path: str, sql: str, fetch: Union[str, int] = "all") -> Any: 20 | """ 21 | Executes an SQL query on a database and fetches results. 22 | 23 | Arguments: 24 | db_path (str): The database sqlite file path. 25 | sql (str): The SQL query to execute. 26 | fetch (Union[str, int]): How to fetch the results. Options are "all", "one", "random", or an integer. 27 | 28 | Returns: 29 | resutls: SQL execution results . 30 | """ 31 | try: 32 | with sqlite3.connect(db_path) as conn: 33 | cursor = conn.cursor() 34 | cursor.execute(sql) 35 | if fetch == "all": 36 | return cursor.fetchall() 37 | elif fetch == "one": 38 | return cursor.fetchone() 39 | elif fetch == "random": 40 | samples = cursor.fetchmany(10) 41 | return random.choice(samples) if samples else [] 42 | elif isinstance(fetch, int): 43 | return cursor.fetchmany(fetch) 44 | else: 45 | raise ValueError("Invalid fetch argument. Must be 'all', 'one', 'random', or an integer.") 46 | except Exception as e: 47 | logging.error(f"Error in execute_sql: {e}\n db_path: {db_path}\n SQL: {sql}") 48 | raise e 49 | 50 | def get_db_tables(db_path): 51 | """ 52 | The function extracts tables in a specific database. 53 | 54 | Arguments: 55 | db_path (str): The database sqlite file path. 56 | Returns: 57 | db_tables (List[str]): Names of the tables in the database as a list of string 58 | """ 59 | # Execute query to extract the names of all tables in the database 60 | try: 61 | 62 | tables = execute_sql(db_path, "SELECT name FROM sqlite_master WHERE type='table';") 63 | db_tables = [table_name_tuple[0] for table_name_tuple in tables if table_name_tuple[0] != "sqlite_sequence"] 64 | return db_tables 65 | except Exception as e: 66 | logging.error(f"Error in get_db_tables: {e}") 67 | raise e 68 | 69 | def get_db_colums_of_table(db_path: str, table_name: str) -> List[str]: 70 | """ 71 | The function extractes all column of a table whose name is given 72 | 73 | Args: 74 | db_path (str): The database sqlite file path. 75 | table_name (str): The name of the table in the database whose columns extracted. 76 | 77 | Returns: 78 | columns_of_table (List[str]): A list of column names. 79 | """ 80 | try: 81 | table_info_rows = execute_sql(db_path, f"PRAGMA table_info(`{table_name}`);") 82 | columns_of_table = [row[1] for row in table_info_rows] 83 | # data_type_of_columns_of_table = [row[2] for row in table_info_rows] 84 | return columns_of_table 85 | except Exception as e: 86 | logging.error(f"Error in get_table_all_columns: {e}\nTable: {table_name}") 87 | raise e 88 | 89 | def isTableInDB(db_path: str, table_name: str) -> bool: 90 | """ 91 | The function checks whether given table name is in the database 92 | 93 | Arguments: 94 | db_path (str): The database sqlite file path. 
95 | table_name (str): the name of the table that is going to be checked 96 | 97 | Returns: 98 | bool: True if the table in the database, otherwise returns False 99 | """ 100 | 101 | db_tables = get_db_tables(db_path) 102 | if table_name in db_tables: 103 | return True 104 | else: 105 | return False 106 | 107 | def isColumnInTable(db_path: str, table_name: str, column_name: str) -> bool: 108 | """ 109 | The function checks whether given column name is in the columns of given table 110 | 111 | Arguments: 112 | db_path (str): The database sqlite file path 113 | table_name (str): the name of the table 114 | column_name (str): the name of the column that is going to be checked 115 | 116 | Returns: 117 | bool: True if the given column is among the columns of given table, otherwise returns False 118 | """ 119 | 120 | columns_of_table = get_db_colums_of_table(db_path, table_name) 121 | if column_name in columns_of_table: 122 | return True 123 | else: 124 | return False 125 | 126 | def get_original_schema(db_path: str) -> str: 127 | """ 128 | The function construct database schema from the database sqlite file. 129 | 130 | Arguments: 131 | db_path (str): The database sqlite file path. 132 | Returns: 133 | db_schema (str): database schema constructed by CREATE TABLE statements 134 | """ 135 | # Connecting to the sqlite database 136 | conn = sqlite3.connect(db_path) 137 | cursor = conn.cursor() 138 | 139 | # Query to extract the names of all tables in the database 140 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 141 | tables = cursor.fetchall() 142 | 143 | # Dictionary to hold table names and their CREATE TABLE statements 144 | db_schema_dict = {} 145 | 146 | for table_name_tuple in tables: 147 | table_name = table_name_tuple[0] # Extracting the table name from the tuple 148 | cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';") 149 | create_statement = cursor.fetchone()[0] 150 | db_schema_dict[table_name] = create_statement 151 | 152 | # Close the connection 153 | cursor.close() 154 | conn.close() 155 | 156 | db_schema = " " 157 | for table, create_table_statement in db_schema_dict.items(): 158 | db_schema = db_schema + create_table_statement + "\n" 159 | 160 | return db_schema 161 | 162 | 163 | def clean_db_schema(db_schema: str) -> str: 164 | """ 165 | The function cleans the database schema by removing unnecessary whitespaces and new lines, 166 | ensuring that each column definition or constraint in a table is described on a single line. 
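    It also drops constraint keywords such as AUTOINCREMENT, NOT NULL, UNIQUE, DEFAULT 0 and ON UPDATE/ON DELETE CASCADE (in upper or lower case), and appends the single-column primary key marker to the preceding column line.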
167 | 168 | Arguments: 169 | db_schema (str): original database schema 170 | Returns: 171 | cleaned_db_schema (str): cleaned database schema 172 | """ 173 | # Split the schema into lines 174 | lines = db_schema.split('\n') 175 | cleaned_lines = [] 176 | current_statement = [] 177 | 178 | for index in range(len(lines)): 179 | line = lines[index] 180 | line = line.strip() # Trim any leading/trailing whitespace 181 | if not line: 182 | continue # Skip empty lines 183 | if "CREATE TABLE" in line: 184 | line = line + " (" # append '(' to the lines containing CREATE TABLE 185 | cleaned_lines.append(line) 186 | continue 187 | if line[0] == '(': 188 | continue # Skip lines containing '(' 189 | if line[0] == ')': 190 | cleaned_lines.append(line) 191 | continue 192 | if "primary key" in line.lower(): 193 | cleaned_lines[-1] = cleaned_lines[-1] + 'primary key,' # if the current line is PK, add it to the previous line 194 | continue 195 | 196 | line = line.replace('AUTOINCREMENT', '') 197 | line = line.replace('DEFAULT 0', '') 198 | line = line.replace('NOT NULL', '') 199 | line = line.replace('NULL', '') 200 | line = line.replace('UNIQUE', '') 201 | line = line.replace('ON UPDATE', '') 202 | line = line.replace('ON DELETE', '') 203 | line = line.replace('CASCADE', '') 204 | 205 | line = line.replace('autoincrement', '') 206 | line = line.replace('default 0', '') 207 | line = line.replace('not null', '') 208 | line = line.replace('null', '') 209 | line = line.replace('unique', '') 210 | line = line.replace('on update', '') 211 | line = line.replace('on delete', '') 212 | line = line.replace('cascade', '') 213 | 214 | # Remove space before commas 215 | line = re.sub(r'\s*,', ',', line) 216 | 217 | # Ensure one space between column names and their data types 218 | # Handling multi-word column names enclosed in backticks 219 | line = re.sub(r'`([^`]+)`\s+(\w+)', r'`\1` \2', line) 220 | # Handling standard column names 221 | line = re.sub(r'(\w+)\s+(\w+)', r'\1 \2', line) 222 | 223 | cleaned_lines.append(line) 224 | # Join all cleaned lines into a single string 225 | cleaned_db_schema = '\n'.join(cleaned_lines) 226 | return cleaned_db_schema 227 | 228 | def get_schema(db_path: str) -> str: 229 | """ 230 | The function returns cleaned database schema from the database sqlite file. 231 | 232 | Arguments: 233 | db_path (str): The database sqlite file path. 234 | Returns: 235 | db_schema (str): cleaned database schema constructed by CREATE TABLE statements 236 | """ 237 | original_db_schema = get_original_schema(db_path) 238 | db_schema = clean_db_schema(original_db_schema) 239 | return db_schema 240 | 241 | def get_schema_dict(db_path: str) -> Dict: 242 | """ 243 | The function construct database schema from the database sqlite file in the form of dict. 244 | 245 | Arguments: 246 | db_path (str): The database sqlite file path. 247 | Returns: 248 | db_schema_dict (Dict[str, Dict[str, str]]): database schema dictionary whose keys are table names and values are dict with column names keys and data type with as values. 
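    Illustrative return value (the table and column names below are hypothetical, not taken from a specific BIRD database):
        {"schools": {"CDSCode": "TEXT", "County": "TEXT"}, "satscores": {"cds": "TEXT", "AvgScrMath": "INTEGER"}}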
249 | """ 250 | # Connecting to the sqlite database 251 | conn = sqlite3.connect(db_path) 252 | cursor = conn.cursor() 253 | 254 | # Query to extract the names of all tables in the database 255 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 256 | tables = cursor.fetchall() 257 | table_names = [table[0] for table in tables] 258 | 259 | # Dictionary to hold table names and their CREATE TABLE statements 260 | db_schema_dict = {} 261 | 262 | for table_name in table_names: 263 | cursor.execute(f"PRAGMA table_info(`{table_name}`);") 264 | table_info = cursor.fetchall() # in table_info, each row indicate (cid, column name, type, notnull, default value, is_PK) 265 | # print(f"Table {table_name} info: \n", table_info) 266 | db_schema_dict[table_name] = {col_item[1]: col_item[2] for col_item in table_info} 267 | 268 | # Close the connection 269 | cursor.close() 270 | conn.close() 271 | 272 | return db_schema_dict 273 | 274 | def get_schema_tables_and_columns_dict(db_path: str) -> Dict: 275 | """ 276 | The function construct database schema from the database sqlite file in the form of dict. 277 | 278 | Arguments: 279 | db_path (str): The database sqlite file path. 280 | Returns: 281 | db_schema_dict (Dict[str, List[str]]): database schema dictionary whose keys are table names and values are list of column names. 282 | """ 283 | # Connecting to the sqlite database 284 | conn = sqlite3.connect(db_path) 285 | cursor = conn.cursor() 286 | 287 | # Query to extract the names of all tables in the database 288 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 289 | tables = cursor.fetchall() 290 | table_names = [table[0] for table in tables] 291 | 292 | # Dictionary to hold table names and their CREATE TABLE statements 293 | db_schema_dict = {} 294 | 295 | for table_name in table_names: 296 | cursor.execute(f"PRAGMA table_info(`{table_name}`);") 297 | table_info = cursor.fetchall() # in table_info, each row indicate (cid, column name, type, notnull, default value, is_PK) 298 | # print(f"Table {table_name} info: \n", table_info) 299 | db_schema_dict[table_name] = [col_item[1] for col_item in table_info] 300 | 301 | # Close the connection 302 | cursor.close() 303 | conn.close() 304 | 305 | return db_schema_dict 306 | 307 | def clean_sql(sql: str) -> str: 308 | """ 309 | The function removes unwanted whitespace and characters in the given SQL statement 310 | 311 | Arguments: 312 | sql (str): The SQL query. 313 | Returns: 314 | clean_sql (str): Clean SQL statement. 315 | """ 316 | # clean_sql= sql.replace('\n', ' ').replace('"', "'").strip("`.") 317 | clean_sql = sql.replace('\n', ' ').replace('"', "`").replace('\"', "`") 318 | return clean_sql 319 | 320 | 321 | 322 | def compare_sqls_outcomes(db_path: str, predicted_sql: str, ground_truth_sql: str) -> int: 323 | """ 324 | Compares the results of two SQL queries to check for equivalence. 325 | 326 | Args: 327 | db_path (str): The database sqlite file path. 328 | predicted_sql (str): The predicted SQL query. 329 | ground_truth_sql (str): The ground truth SQL query. 330 | 331 | Returns: 332 | int: 1 if the outcomes are equivalent, 0 otherwise. 333 | 334 | Raises: 335 | Exception: If an error occurs during SQL execution. 
336 | """ 337 | try: 338 | predicted_res = execute_sql(db_path, predicted_sql) 339 | ground_truth_res = execute_sql(db_path, ground_truth_sql) 340 | return int(set(predicted_res) == set(ground_truth_res)) 341 | except Exception as e: 342 | logging.critical(f"Error comparing SQL outcomes: {e}") 343 | raise e 344 | 345 | 346 | def compare_sqls(db_path: str, predicted_sql: str, ground_truth_sql: str, meta_time_out: int = 30) -> Dict[str, Union[int, str]]: 347 | """ 348 | Compares predicted SQL with ground truth SQL within a timeout. 349 | 350 | Arguments: 351 | db_path (str): The database sqlite file path. 352 | predicted_sql (str): The predicted SQL query. 353 | ground_truth_sql (str): The ground truth SQL query. 354 | meta_time_out (int): The timeout for the comparison. 355 | 356 | Returns: 357 | dict: A dictionary with the comparison result and any error message. 358 | """ 359 | predicted_sql = clean_sql(predicted_sql) 360 | try: 361 | res = func_timeout(meta_time_out, compare_sqls_outcomes, args=(db_path, predicted_sql, ground_truth_sql)) 362 | error = "incorrect answer" if res == 0 else "--" 363 | except FunctionTimedOut: 364 | logging.warning("Comparison timed out.") 365 | error = "timeout" 366 | res = 0 367 | except Exception as e: 368 | logging.error(f"Error in compare_sqls: {e}") 369 | error = str(e) 370 | res = 0 371 | return {'exec_res': res, 'exec_err': error} 372 | 373 | 374 | def extract_sql_tables(db_path: str, sql: str) -> List[str]: 375 | """ 376 | The function extracts the table names in the SQL. 377 | 378 | Args: 379 | db_path (str): The database sqlite file path. 380 | sql (str): The ground truth SQL query string. 381 | 382 | Returns: 383 | tables_in_sql (List[str]): Names of the tables in the ground truth SQL query as a list of string. 384 | """ 385 | db_tables = get_db_tables(db_path) 386 | try: 387 | parsed_tables = list(parse_one(sql, read='sqlite').find_all(expressions.Table)) # parsed_tables: List[] 388 | tables_in_sql = [str(table.name) for table in parsed_tables if str(table.name) in [db_table.lower() for db_table in db_tables]] # tables_in_sql: List[str] 389 | tables_in_sql = list(set(tables_in_sql)) # ensure list contains only unique table values i.e. a table name doesn't repeat in the list 390 | return tables_in_sql 391 | except Exception as e: 392 | logging.critical(f"Error in extract_sql_tables: {e}\n") 393 | raise e 394 | 395 | def extract_sql_tables_with_aliases(db_path: str, sql: str) -> List[str]: 396 | """ 397 | The function extracts the table names with their aliases in the SQL. 398 | 399 | Args: 400 | db_path (str): The database sqlite file path. 401 | sql (str): The ground truth SQL query string. 
402 | 403 | Returns: 404 | tables_w_aliases (List[Dict[str, str]]): List of dictionary whose keys are "table_name" and "table_alias" 405 | """ 406 | db_tables = get_db_tables(db_path) 407 | try: 408 | parsed_tables = list(parse_one(sql, read='sqlite').find_all(expressions.Table)) # parsed_tables: List[] 409 | tables_w_aliases = [{"table_name": str(table.name), "table_alias": str(table.alias)} for table in parsed_tables if str(table.name) in [db_table.lower() for db_table in db_tables]] # tables_in_sql: List[str] 410 | tables_w_aliases = [table_alias_dict for table_alias_dict in tables_w_aliases if table_alias_dict['table_alias'] != ''] 411 | return tables_w_aliases 412 | except Exception as e: 413 | logging.warning(f"Error in extract_sql_tables_with_aliases: \n\tError{e} \n\t{sql}") 414 | raise 415 | 416 | def replace_alias_with_table_names_in_sql(db_path: str, sql: str) -> str: 417 | """ 418 | The function removes aliases in the SQL. 419 | 420 | Arguments: 421 | sql (str): The SQL with aliases 422 | Returns: 423 | sql (str): The SQL without aliases. Table aliases are replaced with corresponding table names. 424 | """ 425 | try: 426 | tables_w_aliases = extract_sql_tables_with_aliases(db_path, sql) 427 | for table_dict in tables_w_aliases: 428 | table_name, table_alias = table_dict['table_name'], table_dict['table_alias'] 429 | sql = sql.replace(table_alias+".", table_name+".") # replace table_alias with table names in necessary clauses 430 | # sql = sql.replace(f"AS {table_alias}", "") # remove "AS" keywords 431 | # sql = sql.replace(table_alias, "") # remove table alias 432 | sql = re.sub(r'\s+', ' ', sql) # remove extra spaces 433 | 434 | return sql 435 | except Exception as e: 436 | logging.warning(f"Failed to replace aliases in SQL due to error: {e}. Initially given error is returned.") 437 | return sql 438 | 439 | 440 | 441 | def extract_sql_columns(db_path: str, sql: str) -> Dict[str, List[str]]: 442 | """ 443 | The function extracts the column names with corresponding table names in the SQL. 444 | 445 | Args: 446 | db_path (str): The database sqlite file path. 447 | sql (str): The SQL query as string. 448 | 449 | Returns: 450 | columns_in_sql_dict (Dict[str, List[str]]): A dictionary where keys are table names and values are lists of column names. 451 | """ 452 | columns_in_sql_dict = {} 453 | 454 | # extract database tables 455 | db_tables = get_db_tables(db_path) 456 | 457 | # Qualify columns such that columns will always be followed by table_name 458 | db_schema_dict = get_schema_dict(db_path) 459 | try: 460 | qualified_parsed_sql = qualify(parse_one(sql, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False) 461 | except Exception as e: 462 | logging.critical(f"Error in qualifying parsed SQL in extract_sql_columns function. 
\n{e} \nSQL: {sql} ") 463 | qualified_parsed_sql = parse_one(sql, read='sqlite') 464 | # print("qualified and parsed sql: \n", qualified_parsed_sql) 465 | """ 466 | Type of qualified_parsed_sql is 467 | When qualified sql as a string is reauired, use __str__() method of 'sqlglot.expressions.Select' class 468 | """ 469 | 470 | # Extract tables and its aliases 471 | tables_w_aliases = extract_sql_tables_with_aliases(db_path, qualified_parsed_sql.__str__()) 472 | # print("tables_w_aliases: \n", tables_w_aliases) 473 | 474 | try: 475 | parsed_columns = list(qualified_parsed_sql.find_all(expressions.Column)) 476 | # parsed_columns: List[] 477 | # Note that columns and table names will be lower case if it is a single word 478 | # print("parsed columns: \n", parsed_columns) 479 | 480 | for column_obj in parsed_columns: 481 | column_name = column_obj.name 482 | table_name = column_obj.table 483 | # print("column_name and table_name: ", column_name, "-", table_name ) 484 | 485 | for tables_alias_dict in tables_w_aliases: 486 | if table_name == tables_alias_dict["table_alias"]: 487 | table_name = tables_alias_dict["table_name"] 488 | 489 | if table_name.lower() in [db_table.lower() for db_table in db_tables]: 490 | db_columns_of_table = get_db_colums_of_table(db_path, table_name) 491 | if column_name.lower() in [col.lower() for col in db_columns_of_table]: 492 | if table_name in columns_in_sql_dict: 493 | if column_name not in columns_in_sql_dict[table_name]: 494 | columns_in_sql_dict[table_name].append(column_name) 495 | else: 496 | columns_in_sql_dict[table_name] = [column_name] 497 | 498 | # reconstruct columns_in_sql_dict so that its items are original database table ans column names 499 | original_columns_in_sql_dict = {} 500 | for table_name, col_list in columns_in_sql_dict.items(): 501 | db_tables_lower = [t.lower() for t in db_tables] 502 | t_index = db_tables_lower.index(table_name) 503 | db_table_original_name = db_tables[t_index] 504 | original_columns_in_sql_dict[db_table_original_name] = [] 505 | 506 | table_cols_original = get_db_colums_of_table(db_path, db_table_original_name) 507 | table_cols_lower = [c.lower() for c in table_cols_original] 508 | for col in col_list: 509 | if col in table_cols_lower: 510 | c_index = table_cols_lower.index(col) 511 | c_original_name = table_cols_original[c_index] 512 | original_columns_in_sql_dict[db_table_original_name].append(c_original_name) 513 | elif col in table_cols_original: 514 | original_columns_in_sql_dict[db_table_original_name].append(col) 515 | 516 | columns_in_sql_dict = original_columns_in_sql_dict 517 | return columns_in_sql_dict 518 | except Exception as e: 519 | logging.critical(f"Error in extract_sql_columns:{e}\n") 520 | 521 | def generate_schema_from_schema_dict(db_path: str, schema_dict: Dict) -> str: 522 | """ 523 | The function creates filtered schema string from the given schema dictionary 524 | 525 | Arguments: 526 | db_path (str): The database sqlite file path. 527 | schema_dict (Dict[str, List[str]]): A dictionary where keys are table names and values are lists of column names. 
528 | 529 | Results: 530 | schema_str (str): Schema generated from the given schema dictionary and database path 531 | """ 532 | # Dictionary to store CREATE TABLE statements 533 | create_statements = [] 534 | 535 | for table, column_list in schema_dict.items(): 536 | table_info = execute_sql(db_path, f"PRAGMA table_info(`{table}`);") # returns tuple (cid, name type, notnull, dflt_value, pk) for each column in the taple 537 | # print(f"TABLE INFO OF {table} \n", table_info) 538 | pk_columns = [(row[1], row[2]) for row in table_info if row[5] != 0] # add all PKs in schema definition with ther data types 539 | other_columns = [(row[1], row[2]) for row in table_info if row[5] == 0 and row[1] in column_list] # add column if it is exist in the column list in the given schema dict 540 | 541 | # Query for foreign key information 542 | foreign_keys_info = execute_sql(db_path, f"PRAGMA foreign_key_list(`{table}`)") 543 | # print(f"Foreing keys info:\n", foreign_keys_info) 544 | foreign_keys = {row[3]: (row[2], row[4]) for row in foreign_keys_info if row[3] in column_list} # local_col: (ref_table, foreing_col) if local_col exist in filtered column list in the given schema dict 545 | # print("foreign_keys: \n", foreign_keys) 546 | 547 | table_definition = f"CREATE TABLE {table} (\n" 548 | if len(pk_columns) == 1: 549 | pk = pk_columns[0] 550 | table_definition = table_definition + f"{pk[0]} {pk[1]} primary key, \n" 551 | for col in other_columns: 552 | table_definition = table_definition + f"{col[0]} {col[1]},\n" 553 | for local_col, ref_table_col in foreign_keys.items(): 554 | table_definition = table_definition + f"foreing key ({local_col}) references {ref_table_col[0]}({ref_table_col[1]}) \n" 555 | elif len(pk_columns) > 1: 556 | # concatenate primary key column names with their data types 557 | for pk in pk_columns: 558 | table_definition = table_definition + f"{pk[0]} {pk[1]}, \n" 559 | 560 | # concatenate columns with their data types 561 | for col in other_columns: 562 | table_definition = table_definition + f"{col[0]} {col[1]},\n" 563 | 564 | # concatenate primary key descriptions 565 | table_definition = table_definition + "primary key (" 566 | for ind, pk in enumerate(pk_columns): 567 | if ind < len(pk_columns)-1: 568 | table_definition = table_definition + pk[0] + ", " 569 | else: 570 | table_definition = table_definition + pk[0] 571 | 572 | table_definition = table_definition + "),\n" 573 | 574 | # concatenate foreign key descriptions 575 | for local_col, ref_table_col in foreign_keys.items(): 576 | table_definition = table_definition + f"foreing key ({local_col}) references {ref_table_col[0]}({ref_table_col[1]}) \n" 577 | 578 | 579 | table_definition = table_definition + ")" 580 | create_statements.append(table_definition) 581 | 582 | schema_str = '\n'.join(create_statements) 583 | return schema_str 584 | 585 | 586 | def extract_db_samples_enriched_bm25(question: str, evidence: str, db_path: str, schema_dict: Dict, sample_limit: int) -> str: 587 | """ 588 | The function extract distict samples for given schema items from the database by ranking values using BM25. 589 | Ranking is not done seperately for all values of each table.column 590 | 591 | Arguments: 592 | question (str): considered natural language question 593 | evidence (str): given evidence about the question 594 | db_path (str): The database sqlite file path. 
595 | schema_dict (Dict[str, List[str]]): Database schema dictionary where keys are table names and values are lists of column names 596 | 597 | Returns: 598 | db_samples (str): concatenated strings gives samples from each column 599 | """ 600 | db_samples = "\n" 601 | 602 | question = question.replace('\"', '').replace("\'", "").replace("`", "") 603 | question_and_evidence = question + " " + evidence 604 | tokenized_question_evidence = word_tokenize(question_and_evidence) 605 | 606 | for table, col_list in schema_dict.items(): 607 | db_samples = db_samples + f"## {table} table samples:\n" 608 | for col in col_list: 609 | try: 610 | col_distinct_values = execute_sql(db_path, f"SELECT DISTINCT `{col}` FROM `{table}`") # extract all distinct values 611 | col_distinct_values = [str(value_tuple[0]) if value_tuple and value_tuple[0] else 'NULL' for value_tuple in col_distinct_values] 612 | if 'NULL' in col_distinct_values: 613 | isNullExist = True 614 | else: 615 | isNullExist = False 616 | 617 | # col_distinct_values = [str(value_tuple[0]) for value_tuple in col_distinct_values if value_tuple[0]] # not condiderin NULL values 618 | # col_distinct_values = [value if len(value) < 400 else value[:300] for value in col_distinct_values] # if the lenght of value is too long take its 300 character only 619 | # if average lenght of the column values larger than 600 than use only the first item of the values since large length of the values cause context limit 620 | if len(col_distinct_values) > 0: 621 | average_length = sum(len(value) for value in col_distinct_values) / len(col_distinct_values) 622 | else: 623 | average_length = 0 624 | if average_length > 600: 625 | col_distinct_values = [col_distinct_values[0]] 626 | 627 | if len(col_distinct_values) > sample_limit: 628 | corpus = col_distinct_values.copy() 629 | corpus = [f'{table} {col} {val}' for val in corpus] 630 | tokenized_corpus = [doc.split(" ") for doc in corpus] 631 | # tokenized_corpus = [word_tokenize(doc) for doc in corpus] # takes too much time, so don't use it 632 | bm25 = BM25Okapi(tokenized_corpus) 633 | 634 | col_distinct_values = bm25.get_top_n(tokenized_question_evidence, col_distinct_values, n=sample_limit) 635 | if isNullExist: 636 | col_distinct_values.append("NULL") 637 | 638 | db_samples = db_samples + f"# Example values for '{table}'.'{col}' column: " + str(col_distinct_values) + "\n" 639 | except Exception as e: 640 | sql = f"SELECT DISTINCT `{col}` FROM `{table}`" 641 | logging.error(f"Error in extract_db_samples_enriched_bm25: {e}\n SQL: {sql}") 642 | error = str(e) 643 | 644 | return db_samples 645 | 646 | def construct_tokenized_db_table_value_corpus(db_path: str, schema_dict: Dict): 647 | """ 648 | Function collects all item for each value in the database as "table_name column_name value", then tokenize it 649 | 650 | Arguments: 651 | db_path (str): The database sqlite file path. 652 | schema_dict (Dict[str, List[str]]): Database schema dictionary where keys are table names and values are lists of column names 653 | 654 | Returns: 655 | tokenized_db_corpus (List[List]): List of tokenized database "table_name column_name value" item 656 | db_corpus (List[Tuple]) 657 | """ 658 | # generating corpus whose items are tokenized version of "table_name column_name value" for each value and table in the database. 
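    # For instance (illustrative), a value "Los Angeles" stored in a schools.County column would contribute
    # the corpus entry "schools County Los Angeles", which is tokenized below by simple whitespace splitting.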
659 |     corpus = []
660 |     db_corpus = []
661 |     for table, col_list in schema_dict.items():
662 |         for col in col_list:
663 |             try:
664 |                 col_distinct_values = execute_sql(db_path, f"SELECT DISTINCT `{col}` FROM `{table}`") # extract all distinct values
665 |                 col_distinct_values = [str(value_tuple[0]) for value_tuple in col_distinct_values if value_tuple[0]]
666 |                 # if the average length of the column values is larger than 600, use only the first value since very long values can exceed the context limit
667 |                 if len(col_distinct_values) > 0:
668 |                     average_length = sum(len(value) for value in col_distinct_values) / len(col_distinct_values)
669 |                 else:
670 |                     average_length = 0
671 |                 if average_length > 600:
672 |                     col_distinct_values = [col_distinct_values[0]]
673 | 
674 |                 table_col_value_str = [f"{table} {col} {val}" for val in col_distinct_values]
675 |                 corpus.extend(table_col_value_str)
676 |                 table_col_value_tuples = [(table, col, val) for val in col_distinct_values]
677 |                 db_corpus.extend(table_col_value_tuples)
678 | 
679 |             except Exception as e:
680 |                 sql = f"SELECT DISTINCT `{col}` FROM `{table}`"
681 |                 logging.error(f"Error in construct_tokenized_db_table_value_corpus: {e}\n SQL: {sql}")
682 |                 error = str(e)
683 | 
684 |     # constructing the tokenized corpus (the BM25 object itself is built by the caller)
685 |     tokenized_db_corpus = [doc.split(" ") for doc in corpus]
686 |     # tokenized_db_corpus = [word_tokenize(doc) for doc in corpus if doc] # takes too much time, so don't use it
687 |     return tokenized_db_corpus, db_corpus
688 | 
689 | def find_most_similar_table(problematic_t_name: str, db_tables: List[str]) -> str:
690 |     """
691 |     Helper function to find the most similar table name in the database.
692 |     As a string similarity metric, the difflib.SequenceMatcher ratio is used.
693 |     This helper function calculates the similarity ratio between two strings based on their longest matching subsequences.
694 |     This ratio is a value between 0 and 1, where 1 means the strings are identical, and 0 means they are completely different.
695 | 
696 |     Arguments:
697 |         problematic_t_name (str): name of the table that is not in the database
698 |         db_tables (List[str]): list of database tables
699 | 
700 |     Returns:
701 |         most_similar_table (str): the name of the table that is actually in the database and most similar to the given problematic table name
702 |     """
703 | 
704 |     similarity_scores = [(t_name, difflib.SequenceMatcher(None, problematic_t_name, t_name).ratio()) for t_name in db_tables]
705 |     most_similar_table = max(similarity_scores, key=lambda x: x[1])[0]
706 |     return most_similar_table
707 | 
708 | def filtered_schema_correction(db_path: str, filtered_schema_dict: Dict) -> Dict:
709 |     """
710 |     The function checks whether the given filtered schema mismatches the original database schema and corrects any mismatch it finds.
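    Tables that are not present in the database are replaced with the most similar database table name, and columns listed under the wrong table are moved to a database table that actually contains them.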
711 | 712 | Arguments: 713 | db_path (str): The database sqlite file path 714 | filtered_schema_dict (Dict[str, List[str]]): A dictionary where keys are table names and values are list of column names 715 | 716 | Returns: 717 | final_filtered_schema_dict (Dict[str, List[str]]): Finalized filtered schema dictionary 718 | filtered_schema_problems (str): A string that expresses all mismatches 719 | """ 720 | filtered_schema_problems = "" 721 | 722 | db_tables = get_db_tables(db_path) 723 | ## Step 1: Check if the tables in filtered schema dictionary are in the database and replace them with the most similar table names 724 | 725 | problematic_tables = [] 726 | for t_name in filtered_schema_dict.keys(): 727 | isInDb = isTableInDB(db_path=db_path, table_name=t_name) 728 | if not isInDb: 729 | problematic_tables.append(t_name) 730 | 731 | 732 | if problematic_tables: 733 | print(f"There is mismatch between database and filtered schema tables. The problematic tables are: {problematic_tables}") 734 | filtered_schema_problems = filtered_schema_problems + f"There is mismatch between database and filtered schema tables. The problematic tables are: {problematic_tables}" 735 | for problematic_t_name in problematic_tables: 736 | most_similar_table = find_most_similar_table(problematic_t_name, db_tables) 737 | filtered_schema_dict[most_similar_table] = filtered_schema_dict.pop(problematic_t_name) 738 | 739 | ## Step 2: Check if the columns of a table in filtered schema dictiionary are actually column of the table. If not, find new table for them. 740 | 741 | for table_name in filtered_schema_dict.keys(): 742 | problematic_columns = {} # Dict[str, bool] --> keys are column names and values are boolean that indicates whether a table containing that column is found 743 | for column_name in filtered_schema_dict[table_name]: 744 | isInTable = isColumnInTable(db_path=db_path, table_name=table_name, column_name=column_name) 745 | if not isInTable: 746 | print(f"There is a mismatch in filtered schema table columns. {column_name} is not actually in the {table_name} table.") 747 | filtered_schema_problems = filtered_schema_problems + f"There is a mismatch in filtered schema table columns. {column_name} is not actually in the {table_name} table." 
748 | problematic_columns[column_name] = False # boolean variable indicates whether a table containing that column is found 749 | 750 | # finding tables for problematic columns 751 | table_column_dict = {} 752 | db_tables = get_db_tables(db_path) 753 | for p_column in problematic_columns.keys(): 754 | for db_table in db_tables: 755 | columns_of_table = get_db_colums_of_table(db_path=db_path, table_name=db_table) 756 | if p_column in columns_of_table: 757 | problematic_columns[p_column] = True 758 | if db_table in table_column_dict: 759 | table_column_dict[db_table].append(p_column) 760 | else: 761 | table_column_dict[db_table] = [p_column] 762 | 763 | 764 | # constructing final filtered schema 765 | final_filtered_schema_dict = filtered_schema_dict.copy() 766 | 767 | # removing the problematic columns whose actual tables are found 768 | for p_column, actual_table_found in problematic_columns.items(): 769 | if actual_table_found: 770 | final_filtered_schema_dict[table_name].remove(p_column) 771 | 772 | # Appending problematic columns actual table into the final filtered schema dict 773 | for actual_table, actual_columns in table_column_dict.items(): 774 | if actual_table in final_filtered_schema_dict: 775 | final_filtered_schema_dict[actual_table] = final_filtered_schema_dict[actual_table] + actual_columns 776 | else: 777 | final_filtered_schema_dict[actual_table] = actual_columns 778 | 779 | 780 | return final_filtered_schema_dict, filtered_schema_problems 781 | 782 | 783 | def find_similar_values_incolumn_via_like(db_path: str, table: str, column: str, value: str) -> List[str]: 784 | """ 785 | This function finds similar values to the given value using SQL LIKE clause in given table and column. 786 | 787 | Arguments: 788 | db_path (str): The database sqlite file path. 789 | table (str): Table in the given database. 790 | column (str): Column belongs to the given table. 791 | value (str): Given value on which similar values are extracted 792 | 793 | Returns: 794 | similar_values (List[str]): List of string which are similar to given value. 795 | """ 796 | # if the length of value is 1, then just return itself to prevent lots of unnecessary match 797 | if len(value) == 1: 798 | return [] 799 | # Observing a single value from column. If its length is larger than 300, then return empty list 800 | value_observation_sql = f"SELECT `{column}` FROM `{table}` LIMIT 3" 801 | observed_values = execute_sql(db_path=db_path, sql=value_observation_sql) 802 | 803 | if not observed_values: 804 | return [] 805 | 806 | observed_values = [str(row[0]) for row in observed_values] 807 | observed_values_avg_len = sum( map(len, observed_values) ) / len(observed_values) 808 | if observed_values_avg_len >= 300: 809 | return [] 810 | 811 | sql = 'SELECT DISTINCT `{C}` FROM `{T}` WHERE `{C}` LIKE "%{V}%"'.format(C=column, T=table, V=value) 812 | try: 813 | similar_values = execute_sql(db_path=db_path, sql=sql) 814 | similar_values = [str(row[0]) for row in similar_values if len(str(row[0])) < 50] 815 | # Decrease the size of similar values if there are too much 816 | if len(similar_values) > 5: 817 | similar_values = similar_values[:5] 818 | except Exception as e: 819 | similar_values = [] 820 | logging.critical(f"Error in finding similar values for a value: {e} \n SQL: {sql}") 821 | 822 | return similar_values 823 | 824 | 825 | def find_similar_values_indb_via_like(db_path: str, value:str) -> Dict: 826 | """ 827 | This function finds similar values to the given value using SQL LIKE clause in all database. 
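    For example (illustrative), searching for the value "Los Angeles" might return {"schools": {"County": ["Los Angeles"], "City": ["Los Angeles"]}} when those columns contain matching rows.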
828 | 829 | Arguments: 830 | db_path (str): The database sqlite file path. 831 | value (str): Given value on which similar values are extracted 832 | 833 | Returns: 834 | similar_values_dict (Dict[Dict[str, List[str]]]): Tables are the keys of dict, and values of table keys are another dictionary whose keys are column names and values are list of real column values similar to the given value 835 | """ 836 | db_tables_columns_dict = get_schema_tables_and_columns_dict(db_path=db_path) 837 | similar_values_dict = {} 838 | for table, columns_list in db_tables_columns_dict.items(): 839 | for column in columns_list: 840 | similar_values_list = find_similar_values_incolumn_via_like(db_path, table, column, value) 841 | if similar_values_list: 842 | if table in similar_values_dict: 843 | similar_values_dict[table][column] = similar_values_list 844 | else: 845 | similar_values_dict[table] = {} 846 | similar_values_dict[table][column] = similar_values_list 847 | 848 | return similar_values_dict 849 | 850 | def extract_comparison_conditions_in_where_clause(db_path: str, where_clause) -> List[Dict[str, str]]: 851 | """" 852 | The function extracts list of dict which describe a condition in WHERE clause 853 | 854 | Arguments 855 | where_clause (sqlglot.expressions.Where) 856 | 857 | Returns 858 | conditions_list (List[Dict[str, str]]): List of Dict whose keys are table, column, operation, value 859 | 860 | """ 861 | 862 | conditions_list = [] 863 | if not where_clause: 864 | return conditions_list 865 | 866 | # Check if where_clause.this is a composite condition (AND/OR) 867 | if isinstance(where_clause.this, sqlglot.expressions.And) or isinstance(where_clause.this, sqlglot.expressions.Or): 868 | where_conditions = list(where_clause.this.flatten()) # flattening the where clause in the case of it is composite condition 869 | else: 870 | where_conditions = [where_clause.this] 871 | 872 | for cond_ind, condition in enumerate(where_conditions): 873 | # print("--cond_ind: ", cond_ind) 874 | # print("--condition: ", condition) 875 | # print("--condition type: ", type(condition)) 876 | columns_in_where_clause = list(condition.find_all(sqlglot.expressions.Column)) 877 | # print("******columns_in_where_clause: ", columns_in_where_clause) 878 | # it there is no sqlglot.expressions.Column in the AST of Where clause 879 | if not columns_in_where_clause: 880 | # check if the condition.this type is Dot not Column 881 | if isinstance(condition.this, sqlglot.expressions.Dot): 882 | try: 883 | if isinstance(condition.this.left, sqlglot.expressions.Literal): 884 | dot_table = condition.this.left.this 885 | if isinstance(condition.this.right, sqlglot.expressions.Literal): 886 | dot_column = condition.this.right.this 887 | if isinstance(condition.expression, sqlglot.expressions.Literal): 888 | dot_value = condition.expression.this 889 | 890 | if isinstance(condition, sqlglot.expressions.EQ): 891 | op = " = " 892 | if isinstance(condition, sqlglot.expressions.NEQ): 893 | op = " != " 894 | if isinstance(condition, sqlglot.expressions.GT): 895 | op = " > " 896 | if isinstance(condition, sqlglot.expressions.GTE): 897 | op = " >= " 898 | if isinstance(condition, sqlglot.expressions.LT): 899 | op = " < " 900 | if isinstance(condition, sqlglot.expressions.LTE): 901 | op = " <= " 902 | conditions_list.append({ 903 | "table": dot_table, 904 | "column": dot_column, 905 | "op": op, 906 | "value": dot_value, 907 | }) 908 | except Exception as e: 909 | logging.warning("Column in condition couldn't be found. 
Expression.Dot is found under condition but it couldn't be seperated to its table, column and value") 910 | 911 | for cols in columns_in_where_clause: 912 | comparison_conditions = [] 913 | op = "" 914 | EQ_condition = cols.find_ancestor(sqlglot.expressions.EQ) 915 | if EQ_condition and isinstance(EQ_condition.left, sqlglot.expressions.Column): 916 | comparison_conditions.append(EQ_condition) 917 | op = " = " 918 | NEQ_condition = cols.find_ancestor(sqlglot.expressions.NEQ) 919 | if NEQ_condition and isinstance(NEQ_condition.left, sqlglot.expressions.Column): 920 | comparison_conditions.append(NEQ_condition) 921 | op = " != " 922 | GT_condition = cols.find_ancestor(sqlglot.expressions.GT) 923 | if GT_condition and isinstance(GT_condition.left, sqlglot.expressions.Column): 924 | comparison_conditions.append(GT_condition) 925 | op = " > " 926 | GTE_condition = cols.find_ancestor(sqlglot.expressions.GTE) 927 | if GTE_condition and isinstance(GTE_condition.left, sqlglot.expressions.Column): 928 | comparison_conditions.append(GTE_condition) 929 | op = " >= " 930 | LT_condition = cols.find_ancestor(sqlglot.expressions.LT) 931 | if LT_condition and isinstance(LT_condition.left, sqlglot.expressions.Column): 932 | comparison_conditions.append(LT_condition) 933 | op = " < " 934 | LTE_condition = cols.find_ancestor(sqlglot.expressions.LTE) 935 | if LTE_condition and isinstance(LTE_condition.left, sqlglot.expressions.Column): 936 | comparison_conditions.append(LTE_condition) 937 | op = " <= " 938 | 939 | for cond in comparison_conditions: 940 | if not isinstance(cond.left.table, str): 941 | continue 942 | if not isinstance(cond.left.this.this, str): 943 | continue 944 | if not isinstance(cond.right.this, str): 945 | continue 946 | conditions_list.append({ 947 | "table": cond.left.table, 948 | "column": cond.left.this.this, 949 | "op": op, 950 | "value": cond.right.this, 951 | }) 952 | 953 | # print("conditions_list: \n", conditions_list) 954 | return conditions_list 955 | 956 | def get_comparison_conditions_from_sql(db_path: str, sql: str) -> List[Dict[str, str]]: 957 | """" 958 | The functions extracts conditions in given SQL 959 | 960 | Argumnets: 961 | db_path (str): The database sqlite path 962 | sql (str): Given structured query language 963 | 964 | Returns: 965 | conditions_dict_list (List[Dict[str, str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value" 966 | """ 967 | db_schema_dict = get_schema_dict(db_path) 968 | 969 | 970 | # Qualify columns such that columns will always be followed by table_name 971 | # Attempting to qualify the SQL 972 | try: 973 | qualified_parsed_sql = qualify(parse_one(sql, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False) 974 | except Exception as e1: 975 | logging.warning(f"First attempt to qualify SQL failed. Trying with first replacement set. \n\tError: {e1} \n\tSQL: {sql}") 976 | # First replacement attempt 977 | try: 978 | changed_sql_1 = sql.replace('`', '"') 979 | changed_sql_1 = changed_sql_1.replace("'", '"') 980 | qualified_parsed_sql = qualify(parse_one(changed_sql_1, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False) 981 | except Exception as e2: 982 | logging.warning(f"Second attempt to qualify SQL failed. Trying with second replacement set. 
\n\tError: {e2} \n\tSQL: {changed_sql_1}") 983 | # Second replacement attempt 984 | try: 985 | changed_sql_2 = sql.replace('`', "'") 986 | changed_sql_2 = changed_sql_2.replace('"', "'") 987 | qualified_parsed_sql = qualify(parse_one(changed_sql_2, read='sqlite'), schema=db_schema_dict, qualify_columns=True, validate_qualify_columns=False) 988 | except Exception as e3: 989 | logging.warning(f"Third attempts to qualify SQL failed.Trying with only parsing the SQL. \n\tError: {e3} \n\tSQL: {changed_sql_2}") 990 | # Third replacement attempt 991 | try: 992 | changed_sql_3 = sql.replace('`', "'") 993 | changed_sql_3 = changed_sql_3.replace('"', "'") 994 | qualified_parsed_sql = parse_one(changed_sql_3, read='sqlite') 995 | except Exception as e4: 996 | logging.warning(f"Fourth attempts to qualify SQL failed. Triying with only parsing the SQL. \n\tError: {e4} \n\tSQL: {changed_sql_3}") 997 | try: 998 | changed_sql_4 = sql.replace('`', '"') 999 | changed_sql_4 = changed_sql_4.replace("'", '"') 1000 | qualified_parsed_sql = parse_one(changed_sql_4, read='sqlite') 1001 | except Exception as e5: 1002 | logging.warning(f"Fifth attempts to qualify SQL failed. Triying with only parsing the SQL. \n\tError: {e4} \n\tSQL: {changed_sql_4}") 1003 | conditions_dict_list = [] 1004 | return conditions_dict_list 1005 | # Replacing table aliases with actual table names 1006 | # qualified_parsed_sql = replace_alias_with_table_names_in_sql(db_path, qualified_parsed_sql) 1007 | 1008 | # print("-qualified_parsed_sql: ", repr(qualified_parsed_sql)) 1009 | # Extract the WHERE clauses 1010 | where_clauses = list(qualified_parsed_sql.find_all(sqlglot.expressions.Where)) # where_clauses: List[] 1011 | conditions_dict_list = [] 1012 | # print("where_clauses: \n", where_clauses) 1013 | if where_clauses: 1014 | # Extract and print each WHERE clause 1015 | for index, where_clause in enumerate(where_clauses): 1016 | try: 1017 | conditions_list = extract_comparison_conditions_in_where_clause(db_path, where_clause) # List[Dict[str, str]] 1018 | conditions_dict_list.extend(conditions_list) 1019 | except Exception as e: 1020 | logging.critical(f"Error in extracting equality conditions in where clause. \nError: {e} \nSQL: {sql} ") 1021 | 1022 | return conditions_dict_list 1023 | 1024 | def extend_conditions_dict_list(conditions_dict_list: List[Dict[str, str]])-> List[Dict[str,str]]: 1025 | """ 1026 | The functions splits all the values in the conditions, then it extends the list by adding each words of a value to the conditions_dict_list 1027 | 1028 | Arguments: 1029 | conditions_dict_list (List[Dict[str,str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value". 1030 | 1031 | Returns 1032 | extended_conditions_dict_list (List[Dict[str,str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value". 
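    Example (illustrative): given [{"table": "schools", "column": "County", "op": " = ", "value": "Los Angeles"}],
    the original condition is kept and two additional conditions with the values "Los" and "Angeles" are appended.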
1033 | 1034 | """ 1035 | extended_conditions_dict_list = conditions_dict_list.copy() 1036 | for cond_dict in conditions_dict_list: 1037 | table = cond_dict['table'] 1038 | column = cond_dict['column'] 1039 | op = cond_dict['op'] 1040 | value = cond_dict['value'] 1041 | splitted_value = value.split() 1042 | if len(splitted_value) > 1: 1043 | for val in splitted_value: 1044 | new_condition_dict = { 1045 | "table": table, 1046 | "column": column, 1047 | "op": op, 1048 | "value": val 1049 | } 1050 | extended_conditions_dict_list.append(new_condition_dict) 1051 | 1052 | return extended_conditions_dict_list 1053 | 1054 | 1055 | def get_extended_comparison_conditions_from_sql(db_path: str, sql: str) -> List[Dict[str, str]]: 1056 | """" 1057 | The functions extracts conditions in given SQL 1058 | 1059 | Argumnets: 1060 | db_path (str): The database sqlite path 1061 | sql (str): Given structured query language 1062 | 1063 | Returns: 1064 | extended_conditions_dict_list (List[Dict[str, str]]): List of comparison conditions described as dictionary. Dictionary keys are "table", "column", "op" and "value" 1065 | """ 1066 | conditions_dict_list = get_comparison_conditions_from_sql(db_path=db_path, sql=sql) 1067 | extended_conditions_dict_list = extend_conditions_dict_list(conditions_dict_list) 1068 | return extended_conditions_dict_list 1069 | 1070 | 1071 | 1072 | def collect_possible_conditions(db_path: str, sql: str) -> List[Dict[str, Union[str, Dict]]]: 1073 | """ 1074 | The functions collects possible where clause conditions depending on the Where clause comparison conditions. 1075 | 1076 | Arguments: 1077 | db_path (str): The database sqlite path 1078 | sql (str): Given structured query language 1079 | 1080 | Returns: 1081 | conditions_dict_list (List[Dict[str, Union[str, Dict]]]): List of comparison conditions described as dictionary 1082 | """ 1083 | possible_conditions_dict_list = [] 1084 | # comp_conditions_dict_list = get_comparison_conditions_from_sql(db_path, sql) # old versioin 1085 | comp_conditions_dict_list = get_extended_comparison_conditions_from_sql(db_path, sql) 1086 | for comp_cond in comp_conditions_dict_list: 1087 | value = comp_cond['value'] 1088 | similar_values_dict = find_similar_values_indb_via_like(db_path, value) 1089 | comp_cond['similar_values'] = similar_values_dict 1090 | possible_conditions_dict_list.append(comp_cond) 1091 | 1092 | return possible_conditions_dict_list 1093 | 1094 | 1095 | def measure_execution_time(db_path, query): 1096 | start_time = time.time() 1097 | query_result = execute_sql(db_path, query) 1098 | end_time = time.time() 1099 | execution_time = end_time - start_time 1100 | return execution_time, query_result -------------------------------------------------------------------------------- /utils/openai_utils.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from typing import Dict 3 | 4 | def create_response(stage: str, prompt: str, model: str, max_tokens: int, temperature: float, top_p: float, n: int) -> Dict: 5 | """ 6 | The functions creates chat response by using chat completion 7 | 8 | Arguments: 9 | stage (str): stage in the pipeline 10 | prompt (str): prepared prompt 11 | model (str): LLM model used to create chat completion 12 | max_tokens (int): The maximum number of tokens that can be generated in the chat completion 13 | temperature (float): Sampling temperature 14 | top_p (float): Nucleus sampling 15 | n (int): Number of chat completion for each input message 16 | 17 | 
Returns: 18 | response_object (Dict): Object returned by the model 19 | """ 20 | client = OpenAI() 21 | 22 | if stage == "question_enrichment": 23 | system_content = "You are excellent data scientist and can link the information between a question and corresponding database perfectly. Your objective is to analyze the given question, corresponding database schema, database column descriptions and the evidence to create a clear link between the given question and database items which includes tables, columns and values. With the help of link, rewrite new versions of the original question to be more related with database items, understandable, clear, absent of irrelevant information and easier to translate into SQL queries. This question enrichment is essential for comprehending the question's intent and identifying the related database items. The process involves pinpointing the relevant database components and expanding the question to incorporate these items." 24 | elif stage == "candidate_sql_generation": 25 | system_content = "You are an excellent data scientist. You can capture the link between the question and corresponding database and perfectly generate valid SQLite SQL query to answer the question. Your objective is to generate SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, samples and evidence. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question." 26 | elif stage == "sql_refinement": 27 | system_content = "You are an excellent data scientist. You can capture the link between the question and corresponding database and perfectly generate valid SQLite SQL query to answer the question. Your objective is to generate SQLite SQL query by analyzing and understanding the essence of the given question, database schema, database column descriptions, evidence, possible SQL and possible conditions. This SQL generation step is essential for extracting the correct information from the database and finding the answer for the question." 28 | elif stage == "schema_filtering": 29 | system_content = "You are an excellent data scientist. You can capture the link between a question and corresponding database and determine the useful database items (tables and columns) perfectly. Your objective is to analyze and understand the essence of the given question, corresponding database schema, database column descriptions, samples and evidence and then select the useful database items such as tables and columns. This database item filtering is essential for eliminating unnecessary information in the database so that corresponding structured query language (SQL) of the question can be generated correctly in later steps." 30 | else: 31 | raise ValueError("Wrong value for stage. 
It can only take one of the following values: question_enrichment, candidate_sql_generation, sql_refinement or schema_filtering.") 32 | 33 | response_object = client.chat.completions.create( 34 | model = model, 35 | messages=[ 36 | {"role": "system", "content": system_content}, 37 | {"role": "user", "content": prompt} 38 | ], 39 | max_tokens = max_tokens, 40 | response_format = { "type": "json_object" }, 41 | temperature = temperature, 42 | top_p = top_p, 43 | n=n, 44 | presence_penalty = 0.0, 45 | frequency_penalty = 0.0 46 | ) 47 | 48 | return response_object 49 | 50 | 51 | def upload_file_to_openai(file_path: str) -> Dict: 52 | """ 53 | The function uploads the given file to OpenAI for batch processing. 54 | 55 | Arguments: 56 | file_path (str): path of the file that is going to be uploaded 57 | Returns: 58 | file_object (FileObject): File object returned by OpenAI 59 | """ 60 | client = OpenAI() 61 | 62 | file_object = client.files.create( 63 | file=open(file_path, "rb"), 64 | purpose="batch" 65 | ) 66 | 67 | print("File is uploaded to OpenAI") 68 | return file_object 69 | 70 | 71 | def construct_request_input_object(prompt: str, id: int, model: str, system_message: str) -> Dict: 72 | """ 73 | The function creates a request input object for each item in the dataset 74 | 75 | Arguments: 76 | prompt (str): prompt that is going to be given to the LLM as content 77 | id (int): the id of the request 78 | model (str): LLM model name 79 | system_message (str): the content of the system message 80 | 81 | Returns: 82 | request_input_object (Dict): The request input in the dictionary format required for batch requests 83 | """ 84 | request_input_object = { 85 | "custom_id": f"qe-request-{id}", 86 | "method": "POST", 87 | "url": "/v1/chat/completions", 88 | "body": { 89 | "model": model, 90 | "messages": [ 91 | {"role": "system", "content": f"{system_message}"}, 92 | {"role": "user", "content": f"{prompt}"} 93 | ] 94 | } 95 | } 96 | return request_input_object -------------------------------------------------------------------------------- /utils/prompt_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | from .db_utils import * 5 | from .retrieval_utils import * 6 | from typing import Any, Union, List, Dict 7 | 8 | 9 | def load_few_shot_data(few_shot_data_path="../few-shot-data/question_enrichment_few_shot_examples.json"): 10 | """ 11 | The function loads the question enrichment few-shot data and returns it as a Python dict 12 | 13 | Arguments: 14 | few_shot_data_path (str): path to the question enrichment few-shot examples JSON file 15 | Returns: 16 | question_enrichment_few_shot_data_dict (dict): the complete question enrichment few-shot data 17 | """ 18 | with open(few_shot_data_path, 'r') as file: 19 | question_enrichment_few_shot_data_dict = json.load(file) 20 | 21 | return question_enrichment_few_shot_data_dict 22 | 23 | def sql_possible_conditions_prep(possible_conditions_dict_list: List[Dict[str, Union[str, Dict]]]) -> str: 24 | """ 25 | The function constructs condition statements and concatenates them. 
26 | 27 | Arguments: 28 | possible_conditions_dict_list (List[Dict[str, Union[str, Dict]]]): 29 | 30 | Returns: 31 | all_possible_conditions (str) 32 | """ 33 | all_possible_conditions_list = [] 34 | if not possible_conditions_dict_list: 35 | return "" 36 | for p_cond in possible_conditions_dict_list: 37 | condition = f"`{p_cond['table']}`.`{p_cond['column']}` {p_cond['op']} `{p_cond['value']}`" 38 | all_possible_conditions_list.append(condition) 39 | similars_dict = p_cond['similar_values'] 40 | if similars_dict: 41 | for table_name, col_val_dict in similars_dict.items(): 42 | for column_name, value_list in col_val_dict.items(): 43 | for val in value_list: 44 | new_possible_cond = f"`{table_name}`.`{column_name}` {p_cond['op']} `{val}`" 45 | all_possible_conditions_list.append(new_possible_cond) 46 | 47 | return str(all_possible_conditions_list) 48 | 49 | 50 | def question_relevant_descriptions_prep(database_description_path, question, relevant_description_number)-> str: 51 | """ 52 | The functions concatenate the relevant database item (column) descriptions and returns it as a string 53 | 54 | Arguments: 55 | database_description_path (str): Path to the directory containing database description CSV files. 56 | question (str): the considered natural language question 57 | relevant_description_number (int): number of top ranked column descriptions 58 | 59 | Returns: 60 | str: Concatenated relevant database item describtions 61 | 62 | """ 63 | 64 | relevant_db_descriptions = get_relevant_db_descriptions(database_description_path, question, relevant_description_number) 65 | db_descriptions_str = "" 66 | 67 | for description in relevant_db_descriptions: 68 | db_descriptions_str = db_descriptions_str + f"# {description} \n" 69 | 70 | return db_descriptions_str 71 | 72 | 73 | def db_column_meaning_prep(database_column_meaning_path: str, db_id: str)-> str: 74 | """ 75 | The functions concatenate the database item (column) descriptions and returns it as a string 76 | 77 | Arguments: 78 | database_column_meaning_path (str): path to the column_meaning.json 79 | db_id (str): name of the database whose columns' meanings will be extracted 80 | 81 | Returns: 82 | str: Concatenated column meanings of given database 83 | 84 | """ 85 | 86 | db_column_meanings = get_db_column_meanings(database_column_meaning_path, db_id) 87 | db_column_meanings_str = "" 88 | 89 | for col_meaning in db_column_meanings: 90 | db_column_meanings_str = db_column_meanings_str + f"{col_meaning} \n" 91 | 92 | return db_column_meanings_str 93 | 94 | 95 | def sql_generation_and_refinement_few_shot_prep(few_shot_data_path: str, q_db_id: str, level_shot_number:str, schema_existance: bool, mode: str) -> str: 96 | """ 97 | The function selects the given number of exemple from the few-shot data, and then concatenate the selected question and their ground-truth SQL in string format. 98 | - The few-shot examples will be selected from the set of data contains databases different than the considered question. 99 | - If the question is already annotated, then few-shot examples will be selected from the set of data in which given question is excluded. 100 | - Level shot number can be between 0 to 10. 101 | 102 | Arguments: 103 | ew_shot_data_path (str): few-shot data path 104 | q_db_id (str): database ID (database name) of considered question 105 | level_shot_number (int): number of exemple desired to add in the prompt for each level. 106 | schema_existance (bool): Whether the schema will be provided for the exemplars in the prompt. 
If it is True, then schema will be provided. 107 | mode (str): dev mode or test mode 108 | Returns: 109 | few_shot_exemplars (str): selected and concatenated exemplars for the prompt 110 | """ 111 | bird_sql_path = os.getenv('BIRD_DB_PATH') 112 | 113 | if level_shot_number == 0: 114 | return "" 115 | 116 | few_shot_exemplars = "" 117 | # Check level_shot_number 118 | if level_shot_number < 0 or level_shot_number > 10: 119 | raise ValueError("Invalid few-shot number. The level_shot_number should be between 0 and 10") 120 | 121 | # Check schema_existance 122 | if not isinstance(schema_existance, bool): 123 | raise TypeError("Provided variable is not a boolean.") 124 | 125 | # Check mode 126 | if mode not in ['test', 'dev']: 127 | raise ValueError("Invalid value for mode. The variable must be either 'dev' or 'test'.") 128 | 129 | # Get all few-shot exemples 130 | all_few_shot_data = load_few_shot_data(few_shot_data_path=few_shot_data_path) 131 | # Set difficulty levels 132 | levels = ['simple', 'moderate', 'challanging'] 133 | 134 | for level in levels: 135 | examples_in_level = all_few_shot_data[level] 136 | selected_indexes = [] 137 | if mode == "dev": 138 | # remove the annotated questions if their db_id is the same with the considered question's db_id 139 | examples_in_level_tmp = [] 140 | for example in examples_in_level: 141 | if q_db_id != example['db_id']: 142 | examples_in_level_tmp.append(example) 143 | examples_in_level = examples_in_level_tmp 144 | # By removing the db_ids that is same with the considered question's db_id, the selection of the same question as an example is prevented in the case of the question was in the annotated data 145 | 146 | selected_indexes = random.sample(range(0,len(examples_in_level)), level_shot_number) # randomly select exemple_num_for_each_level number of example 147 | for ind in selected_indexes: 148 | current_question_info_dict = examples_in_level[ind] 149 | curr_q_db_id = current_question_info_dict['db_id'] 150 | db_path = bird_sql_path + f"/{mode}/{mode}_databases/{curr_q_db_id}/{curr_q_db_id}.sqlite" 151 | curr_sql = current_question_info_dict['SQL'] 152 | 153 | if schema_existance: 154 | sql_schema_dict = extract_sql_columns(db_path, curr_sql) 155 | schema = generate_schema_from_schema_dict(db_path, sql_schema_dict) # filtered schema should be as an example because given schema will be filtered one 156 | few_shot_exemplars = few_shot_exemplars + "Database Schema: \n" + schema + '\n' 157 | 158 | few_shot_exemplars = few_shot_exemplars + "Question: " + current_question_info_dict['question'] + "\n" 159 | few_shot_exemplars = few_shot_exemplars + "Evidence: " + current_question_info_dict['evidence'] + "\n" 160 | few_shot_exemplars = few_shot_exemplars + "SQL: " + current_question_info_dict['SQL'] + "\n\n" 161 | 162 | return few_shot_exemplars 163 | 164 | 165 | def fill_candidate_sql_prompt_template(template: str, schema: str, db_samples: str, question: str, few_shot_examples: str = "", evidence: str = "", db_descriptions: str = "") -> str: 166 | """ 167 | The functions completes the prompt template by filling the necessary slots which are few_shot_examples, schema, question and evidence 168 | 169 | Arguments: 170 | template (str): The template that is going to be filled 171 | schema (str): The schema of the database to which considered question belong 172 | questoin (str): The considered question that is going to be enriched 173 | few_shot_examples (str): few-shot examples that are injected to the prompt 174 | evidence (str): Given evidence 
statment if exist 175 | db_descriptions (str): Question relevant database item(column) descriptions 176 | 177 | Returns: 178 | prompt (str): Completed prompt for question enrichment 179 | """ 180 | if evidence == '' or evidence == None: 181 | evidence = '\n### Evidence: No evidence' 182 | else: 183 | evidence = f"\n### Evidence: \n {evidence}" 184 | 185 | if few_shot_examples == '' or few_shot_examples == None: 186 | few_shot_examples = "" 187 | else: 188 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}" 189 | 190 | schema = "\n### Database Schema: \n\n" + schema 191 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions 192 | db_samples = "\n### Database Samples: \n\n" + db_samples 193 | question = "\n### Question: \n" + question 194 | 195 | prompt = template.format( 196 | FEWSHOT_EXAMPLES = few_shot_examples, 197 | SCHEMA = schema, 198 | DB_SAMPLES = db_samples, 199 | QUESTION = question, 200 | EVIDENCE = evidence, 201 | DB_DESCRIPTIONS = db_descriptions 202 | ) 203 | 204 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}") 205 | return prompt 206 | 207 | def extract_question_enrichment_prompt_template(enrichment_template_path = "../prompt_templates/question_enrichment_prompt_template.txt") -> str: 208 | """ 209 | The function returns the question enrichment prompt template by reading corresponding txt file 210 | 211 | Arguments: 212 | None 213 | Returns: 214 | enrichment_prompt_template (str): Prompt template for enrichment prompt 215 | """ 216 | with open(enrichment_template_path, 'r') as f: 217 | enrichment_prompt_template = f.read() 218 | 219 | return enrichment_prompt_template 220 | 221 | 222 | def question_enrichment_few_shot_prep(few_shot_data_path: str, q_id: int, q_db_id: str, level_shot_number: str, schema_existance: bool, enrichment_level: str, mode: str) -> str: 223 | """ 224 | The function selects the given number of exemple from the question enrichment few-shot data, and then 225 | concatenate the selected exemples in string format. 226 | - The few-shot examples will be selected from the set of data contains databases different than the considered question. 227 | - If the question is already annotated, then few-shot examples will be selected from the set of data in which given question is excluded. 228 | - Level shot number can be between 0 to 10. 229 | 230 | Arguments: 231 | few_shot_data_path (str); path to the file in which few_shot_data exist 232 | q_id (int): id of the question 233 | q_db_id (str): database ID (database name) 234 | level_shot_number (int): number of exemple desired to add in the prompt for each level. 235 | schema_existance (bool): Whether the schema will be provided for the exemplars in the prompt. If it is True, then schema will be provided. 236 | enrichment_level (str): Either "basic" or "complex" for selecting enriched questions 237 | mode (str): dev mode or test mode 238 | Returns: 239 | few_shot_exemplars (str): selected and concatenated exemplars for the prompt 240 | """ 241 | bird_sql_path = os.getenv('BIRD_DB_PATH') 242 | 243 | if level_shot_number == 0: 244 | return "" 245 | 246 | few_shot_exemplars = "" 247 | # Check level_shot_number 248 | if level_shot_number < 0 or level_shot_number > 10: 249 | raise ValueError("Invalid few-shot number. 
The level_shot_number should be between 0 and 10") 250 | 251 | # Check schema_existance 252 | if not isinstance(schema_existance, bool): 253 | raise ValueError("Invalid value for schema_existance variable,it is not a boolean. It should be either True or False.") 254 | 255 | # Check enrichment_level and set enrichment_label 256 | if enrichment_level == "basic": 257 | enrichment_label = "question_enriched" 258 | elif enrichment_level == "complex": 259 | enrichment_label = "question_enriched_v2" 260 | else: 261 | raise ValueError("Invalid value for enrichment_level. The variable must be either 'basic' or 'complex'.") 262 | 263 | # Check mode 264 | if mode not in ['test', 'dev']: 265 | raise ValueError("Invalid value for mode. The variable must be either 'dev' or 'test'.") 266 | 267 | 268 | # Get all few-shot exemples 269 | all_few_shot_data = load_few_shot_data(few_shot_data_path=few_shot_data_path) 270 | # Set difficulty levels 271 | levels = ['simple', 'moderate', 'challanging'] 272 | 273 | 274 | for level in levels: 275 | examples_in_level = all_few_shot_data[level] 276 | selected_indexes = [] 277 | if mode == "dev": 278 | # remove the annotated questions if their db_id is the same with the considered question's db_id 279 | examples_in_level_tmp = [] 280 | for example in examples_in_level: 281 | if q_db_id != example['db_id']: 282 | examples_in_level_tmp.append(example) 283 | examples_in_level = examples_in_level_tmp 284 | # By removing the db_ids that is same with the considered question's db_id, the selection of the same question as an example is prevented in the case of the question was in the annotated data 285 | 286 | selected_indexes = random.sample(range(0,len(examples_in_level)), level_shot_number) # randomly select exemple_num_for_each_level number of example 287 | for ind in selected_indexes: 288 | current_question_info_dict = examples_in_level[ind] 289 | 290 | if schema_existance: 291 | curr_q_db_id = current_question_info_dict['db_id'] 292 | db_path = bird_sql_path + f"/{mode}/{mode}_databases/{curr_q_db_id}/{curr_q_db_id}.sqlite" 293 | schema = get_schema(db_path) 294 | few_shot_exemplars = few_shot_exemplars + "Database Schema: \n" + schema + '\n' 295 | 296 | few_shot_exemplars = few_shot_exemplars + "Question: " + current_question_info_dict['question'] + "\n" 297 | few_shot_exemplars = few_shot_exemplars + "Evidence: " + current_question_info_dict['evidence'] + "\n" 298 | few_shot_exemplars = few_shot_exemplars + "Enrichment Reasoning: " + current_question_info_dict['enrichment_reasoning'] + "\n" 299 | few_shot_exemplars = few_shot_exemplars + "Enriched Question: " + current_question_info_dict[enrichment_label] + "\n\n" 300 | 301 | return few_shot_exemplars 302 | 303 | 304 | def fill_question_enrichment_prompt_template(template: str, schema: str, db_samples: str, question: str, possible_conditions: str, few_shot_examples: str, evidence: str, db_descriptions: str): 305 | """ 306 | The functions completes the enrichment prompt template by filling the necessary slots which are schema, db_samples, question, possible_conditions, few_shot_examples, evidence and db_desctiptions 307 | 308 | Arguments: 309 | template (str): The template that is going to be filled 310 | schema (str): The schema of the database to which considered question belong 311 | questoin (str): The considered question that is going to be enriched 312 | possible_conditions (str): Possible conditions extracted from the possible SQLite SQL query for the question 313 | few_shot_examples (str): few-shot examples that 
are injected to the prompt 314 | evidence (str): Given evidence statment if exist 315 | db_descriptions (str): Question relevant database item(column) descriptions 316 | 317 | Returns: 318 | prompt (str): Completed prompt for question enrichment 319 | """ 320 | if evidence == '' or evidence == None: 321 | evidence = '\n### Evidence: No evidence' 322 | else: 323 | evidence = f"\n### Evidence: \n {evidence}" 324 | 325 | if few_shot_examples == '' or few_shot_examples == None: 326 | few_shot_examples = "" 327 | else: 328 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}" 329 | 330 | schema = "\n### Database Schema: \n\n" + schema 331 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions 332 | db_samples = "\n### Database Samples: \n\n" + db_samples 333 | question = "\n### Question: \n" + question 334 | 335 | if possible_conditions: 336 | possible_conditions = "\n### Possible SQL Conditions: \n" + possible_conditions 337 | else: 338 | possible_conditions = "\n### Possible SQL Conditions: No strict conditions were found. Please consider the database schema and keywords while enriching the Question." 339 | 340 | 341 | prompt = template.format( 342 | FEWSHOT_EXAMPLES = few_shot_examples, 343 | SCHEMA = schema, 344 | DB_SAMPLES = db_samples, 345 | QUESTION = question, 346 | EVIDENCE = evidence, 347 | DB_DESCRIPTIONS = db_descriptions, 348 | POSSIBLE_CONDITIONS = possible_conditions 349 | ) 350 | 351 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}") 352 | return prompt 353 | 354 | 355 | 356 | def fill_refinement_prompt_template(template: str, schema: str, possible_conditions: str, question: str, possible_sql: str, exec_err: str, few_shot_examples: str = "", evidence: str = "", db_descriptions: str = "") -> str: 357 | """ 358 | The functions completes the prompt template by filling the necessary slots which are few_shot_examples, schema, question, evidence and db_decriptions, possible_sql, execution error and possible_conditions 359 | 360 | Arguments: 361 | template (str): The template that is going to be filled 362 | schema (str): The schema of the database to which considered question belong 363 | questoin (str): The considered question that is going to be enriched 364 | possible_sql (str): Possible SQLite SQL query for the question 365 | exec_err (str): Taken execution error when possible SQL is executed 366 | few_shot_examples (str): few-shot examples that are injected to the prompt 367 | evidence (str): Given evidence statment if exist 368 | db_descriptions (str): Question relevant database item(column) descriptions 369 | 370 | Returns: 371 | prompt (str): Completed prompt for question enrichment 372 | """ 373 | if evidence == '' or evidence == None: 374 | evidence = '\n### Evidence: No evidence' 375 | else: 376 | evidence = f"\n### Evidence: \n {evidence}" 377 | 378 | if few_shot_examples == '' or few_shot_examples == None: 379 | few_shot_examples = "" 380 | else: 381 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}" 382 | 383 | schema = "\n### Database Schema: \n\n" + schema 384 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions 385 | question = "\n### Question: \n" + question 386 | possible_sql = "\n### Possible SQLite SQL Query: \n" + possible_sql 387 | if possible_conditions: 388 | possible_conditions = "\n### Possible SQL Conditions: \n" + possible_conditions 389 | else: 390 | possible_conditions = "\n### Possible SQL Conditions: No strict conditions were 
found. Please consider the database schema and keywords in the question while generating the SQL." 391 | if exec_err: 392 | exec_err = "\n### Execution Error of Possible SQL Query Above: \n" + exec_err + "\n While generating new SQLite SQL query, consider this execution error and make sure newly generated SQL query runs without execution error." 393 | else: 394 | exec_err = "" 395 | 396 | 397 | prompt = template.format( 398 | FEWSHOT_EXAMPLES = few_shot_examples, 399 | SCHEMA = schema, 400 | QUESTION = question, 401 | EVIDENCE = evidence, 402 | DB_DESCRIPTIONS = db_descriptions, 403 | POSSIBLE_SQL_Query = possible_sql, 404 | EXECUTION_ERROR = exec_err, 405 | POSSIBLE_CONDITIONS = possible_conditions 406 | ) 407 | 408 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}") 409 | return prompt 410 | 411 | 412 | def schema_filtering_few_shot_prep(few_shot_data_path: str, q_db_id: str, level_shot_number:str, schema_existance: bool, mode: str) -> str: 413 | """ 414 | The function selects the given number of exemple from the few-shot data, and determines the filtered schema of questions in few-shot data and then concatenate the selected exemples in string format. 415 | - The few-shot examples will be selected from the set of data contains databases different than the considered question. 416 | - If the question is already annotated, then few-shot examples will be selected from the set of data in which given question is excluded. 417 | - Level shot number can be between 0 to 10. 418 | 419 | Arguments: 420 | few_shot_data_path (str): few-shot data path 421 | q_db_id (str): database ID (database name) of considered question 422 | level_shot_number (int): number of exemple desired to add in the prompt for each level. 423 | schema_existance (bool): Whether the schema will be provided for the exemplars in the prompt. If it is True, then schema will be provided. 424 | mode (str): dev mode or test mode 425 | Returns: 426 | few_shot_exemplars (str): selected and concatenated exemplars for the prompt 427 | """ 428 | bird_sql_path = os.getenv('BIRD_DB_PATH', '../../dataset/bird-sql') 429 | 430 | if level_shot_number == 0: 431 | return "" 432 | 433 | few_shot_exemplars = "" 434 | # Check level_shot_number 435 | if level_shot_number < 0 or level_shot_number > 10: 436 | raise ValueError("Invalid few-shot number. The level_shot_number should be between 0 and 10") 437 | 438 | # Check schema_existance 439 | if not isinstance(schema_existance, bool): 440 | raise TypeError("Provided variable is not a boolean.") 441 | 442 | # Check mode 443 | if mode not in ['test', 'dev']: 444 | raise ValueError("Invalid value for mode. 
The variable must be either 'dev' or 'test'.") 445 | 446 | # Get all few-shot exemples 447 | all_few_shot_data = load_few_shot_data(few_shot_data_path=few_shot_data_path) 448 | # Set difficulty levels 449 | levels = ['simple', 'moderate', 'challanging'] 450 | 451 | for level in levels: 452 | examples_in_level = all_few_shot_data[level] 453 | selected_indexes = [] 454 | if mode == "dev": 455 | # remove the annotated questions if their db_id is the same with the considered question's db_id 456 | examples_in_level_tmp = [] 457 | for example in examples_in_level: 458 | if q_db_id != example['db_id']: 459 | examples_in_level_tmp.append(example) 460 | examples_in_level = examples_in_level_tmp 461 | # By removing the db_ids that is same with the considered question's db_id, the selection of the same question as an example is prevented in the case of the question was in the annotated data 462 | 463 | selected_indexes = random.sample(range(0,len(examples_in_level)), level_shot_number) # randomly select exemple_num_for_each_level number of example 464 | for ind in selected_indexes: 465 | current_question_info_dict = examples_in_level[ind] 466 | curr_q_db_id = current_question_info_dict['db_id'] 467 | db_path = bird_sql_path + f"/{mode}/{mode}_databases/{curr_q_db_id}/{curr_q_db_id}.sqlite" 468 | sql = current_question_info_dict['SQL'] 469 | 470 | 471 | if schema_existance: 472 | schema = get_schema(db_path) 473 | few_shot_exemplars = few_shot_exemplars + "Database Schema: \n" + schema + '\n' 474 | 475 | # print("curent directory", os.getcwd()) 476 | few_shot_exemplars = few_shot_exemplars + "Question: " + current_question_info_dict['question'] + "\n" 477 | few_shot_exemplars = few_shot_exemplars + "Evidence: " + current_question_info_dict['evidence'] + "\n" 478 | filtered_schema = extract_sql_columns(db_path, sql) 479 | few_shot_exemplars = few_shot_exemplars + "Filtered Database Schema: \n" + str(filtered_schema) + "\n" 480 | 481 | return few_shot_exemplars 482 | 483 | 484 | def fill_prompt_template(template: str, schema: str, db_samples: str, question: str, few_shot_examples: str = "", evidence: str = "", db_descriptions: str = "") -> str: 485 | """ 486 | The functions completes the prompt template by filling the necessary slots which are few_shot_examples, schema, question and evidence 487 | 488 | Arguments: 489 | template (str): The template that is going to be filled 490 | schema (str): The schema of the database to which considered question belong 491 | questoin (str): The considered question that is going to be enriched 492 | few_shot_examples (str): few-shot examples that are injected to the prompt 493 | evidence (str): Given evidence statment if exist 494 | db_descriptions (str): Question relevant database item(column) descriptions 495 | 496 | Returns: 497 | prompt (str): Completed prompt for question enrichment 498 | """ 499 | if evidence == '' or evidence == None: 500 | evidence = '\n### Evidence: No evidence' 501 | else: 502 | evidence = f"\n### Evidence: \n {evidence}" 503 | 504 | if few_shot_examples == '' or few_shot_examples == None: 505 | few_shot_examples = "" 506 | else: 507 | few_shot_examples = f"\n### Examples: \n {few_shot_examples}" 508 | 509 | schema = "\n### Database Schema: \n\n" + schema 510 | db_descriptions = "\n### Database Column Descriptions: \n\n" + db_descriptions 511 | db_samples = "\n### Database Samples: \n\n" + db_samples 512 | question = "\n### Question: \n" + question 513 | 514 | prompt = template.format( 515 | FEWSHOT_EXAMPLES = few_shot_examples, 516 | 
SCHEMA = schema, 517 | DB_SAMPLES = db_samples, 518 | QUESTION = question, 519 | EVIDENCE = evidence, 520 | DB_DESCRIPTIONS = db_descriptions 521 | ) 522 | 523 | prompt = prompt.replace("```json{", "{").replace("}```", "}").replace("{{", "{").replace("}}", "}") 524 | # print("The prompt: \n", prompt) 525 | return prompt -------------------------------------------------------------------------------- /utils/retrieval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import string 4 | import logging 5 | import numpy as np 6 | import pandas as pd 7 | import nltk 8 | from nltk.corpus import stopwords 9 | from nltk.tokenize import word_tokenize 10 | from rank_bm25 import BM25Okapi 11 | from typing import List 12 | 13 | 14 | def nltk_downloads(): 15 | nltk.download('stopwords') # Download the stopwords 16 | nltk.download('punkt') # Download the punkt tokenizer 17 | nltk.download('punkt_tab') # Download the punkt_tab 18 | return 19 | 20 | 21 | def save_dataframe_to_csv(df: pd.DataFrame, path: str): 22 | """ 23 | Saves the given pandas DataFrame to a CSV file at the specified path. 24 | 25 | Arguments: 26 | df (pd.DataFrame): The DataFrame to save. 27 | path (str): The file path where the CSV file will be saved. 28 | """ 29 | try: 30 | df.to_csv(path, index=False) # Set index=False to avoid saving the index column 31 | print(f"DataFrame saved successfully to {path}") 32 | except Exception as e: 33 | logging.error(f"An error occurred while saving the DataFrame: {e}") 34 | 35 | 36 | def clean_text(textData: str)-> str: 37 | """ 38 | The function process the given textData by removing stop words, removing punctuation marks and lowercasing the textData. 39 | 40 | Arguments: 41 | textData (str): text to be cleaned 42 | 43 | Returns: 44 | processedTextData (str): cleaned text 45 | """ 46 | 47 | if isinstance(textData, str): 48 | textData = textData.lower() 49 | textData = textData.replace("       ", '') 50 | 51 | # Removing punctuations 52 | # textData = textData.translate(str.maketrans('', '', string.punctuation)) # converts "don't" to "dont" 53 | textData = textData.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))) # converts "don't" to "don t" 54 | 55 | # Removing stopwords 56 | stopWordsSet = set(stopwords.words('english')) 57 | tokens = word_tokenize(textData) 58 | tokens = [token for token in tokens if not token.lower() in stopWordsSet] 59 | 60 | processedTextData = ' '.join(tokens) 61 | return processedTextData 62 | else: 63 | # if the text data is NaN return empty string 64 | return '' 65 | 66 | 67 | def construct_column_information(table_desc_df: pd.DataFrame, table_name: str) -> pd.DataFrame: 68 | """ 69 | The function combines the original column name, column description, data format, and value description information from table description CSV files into a single descriptive string for each column and adds it as a new column in the DataFrame. 70 | 71 | Arguments: 72 | table_desc_df (pd.DataFrame): DataFrame containing table descriptions. 73 | table_name (str): Name of the table. 74 | 75 | Returns: 76 | pd.Series: constructed single text column information for each column 77 | """ 78 | # Function to build column info for each row 79 | def build_column_info(row): 80 | column_info = f"The information about the {row['original_column_name']} column of the {table_name} table [{table_name}.{row['original_column_name']}] is as following." 
81 | 82 | if pd.notna(row['column_description']): 83 | column_info += f" The {row['original_column_name']} column can be described as {row['column_description']}." 84 | if pd.notna(row['value_description']): 85 | column_info += f" The value description for the {row['original_column_name']} is {row['value_description']}" 86 | 87 | column_info = column_info.replace(" ", ' ') 88 | column_info = column_info.replace("       ", ' ') 89 | return column_info 90 | 91 | # Apply the function to create the "column_info" column 92 | # table_desc_df['column_info'] = table_desc_df.apply(build_column_info, axis=1) 93 | column_info_series = table_desc_df.apply(build_column_info, axis=1) 94 | 95 | return column_info_series 96 | 97 | 98 | def process_database_descriptions(database_description_path: str): 99 | """ 100 | Processes multiple CSV files in the given directory, applies the pre-existing construct_column_information function to each, 101 | and combines the "column_info" columns into a single DataFrame which is then saved as db_description.csv. 102 | 103 | Arguments: 104 | database_description_path (str): Path to the directory containing database description CSV files. 105 | """ 106 | 107 | # List to store column_info from each file 108 | all_column_infos = [] 109 | 110 | # Iterate over each file in the directory 111 | for filename in os.listdir(database_description_path): 112 | if filename.endswith(".csv") and filename != "db_description.csv" : 113 | print(f"------> {filename} table start to be processed.") 114 | file_path = os.path.join(database_description_path, filename) 115 | try: 116 | df = pd.read_csv(file_path) 117 | except: 118 | df = pd.read_csv(file_path, encoding='ISO-8859-1') 119 | 120 | table_name = filename.replace('.csv', '') 121 | column_info_series = construct_column_information(df, table_name) 122 | # Convert the Series to a DataFrame with a single column named 'column_info' 123 | column_info_df = column_info_series.to_frame(name='column_info') 124 | all_column_infos.append(column_info_df) 125 | 126 | # Combine all column_info data into a single DataFrame 127 | all_info_df = pd.concat(all_column_infos, ignore_index=True) 128 | 129 | # Save the DataFrame to a CSV file 130 | output_path = os.path.join(database_description_path, 'db_description.csv') 131 | all_info_df.to_csv(output_path, index=False) 132 | print(f"---> Database information saved successfully to {output_path}") 133 | 134 | return 135 | 136 | 137 | 138 | def process_all_dbs(dataset_path: str, mode: str): 139 | """ 140 | The function processes description of all databases and construct db_description.csv file for all databases. 141 | 142 | Arguments: 143 | dataset_path (str): General dataset path 144 | mode (str): Either dev, test or train 145 | """ 146 | nltk_downloads() # download nltk stop words 147 | databases_path = dataset_path + f"/{mode}/{mode}_databases" 148 | 149 | for db_directory in os.listdir(databases_path): 150 | #thank you, Apple inc. 
for creating .DS_Store files 151 | if db_directory == ".DS_Store": 152 | continue 153 | print(f"----------> Start to process {db_directory} database.") 154 | db_description_path = databases_path + "/" + db_directory + "/database_description" 155 | process_database_descriptions(database_description_path=db_description_path) 156 | 157 | print("\n\n All databases processed and db_description.csv files are created for all.\n\n") 158 | return 159 | 160 | 161 | def get_relevant_db_descriptions(database_description_path: str, question: str, relevant_description_number: int = 6) -> List[str]: 162 | """ 163 | The function returns list relevant column descriptions 164 | 165 | Arguments: 166 | database_description_path (str): Path to the directory containing database description CSV files. 167 | question (str): the considered natural language question 168 | relevant_description_number (int): number of top ranked column descriptions 169 | 170 | Returns: 171 | relevant_db_descriptions (List[str]): List of relevant column descriptions. 172 | """ 173 | db_description_csv_path = database_description_path + "/db_description.csv" 174 | 175 | if not os.path.exists(db_description_csv_path): 176 | process_database_descriptions(database_description_path) 177 | 178 | # Read the database description db_info.csv file 179 | db_desc_df = pd.read_csv(db_description_csv_path) 180 | db_description_corpus = db_desc_df['column_info'].tolist() 181 | db_description_corpus_cleaned = [clean_text(description) for description in db_description_corpus] 182 | 183 | # Tokenize corpus and create bm25 instance using cleaned corpus 184 | tokenized_db_description_corpus_cleaned = [doc.split(" ") for doc in db_description_corpus_cleaned] 185 | bm25 = BM25Okapi(tokenized_db_description_corpus_cleaned) 186 | 187 | # Tokenize question 188 | tokenized_question = question.split(" ") 189 | 190 | relevant_db_descriptions = bm25.get_top_n(tokenized_question, db_description_corpus, n=relevant_description_number) 191 | return relevant_db_descriptions 192 | 193 | def get_db_column_meanings(database_column_meaning_path: str, db_id: str) -> List[str]: 194 | """ 195 | The function extracts required database column meanings. 196 | 197 | Arguments: 198 | database_column_meaning_path (str): path to the column_meaning.json 199 | db_id (str): name of the database whose columns' meanings will be extracted 200 | 201 | Returns: 202 | List[str]: A list of strings explaining the database column meanings. 203 | """ 204 | # Load the JSON file 205 | with open(database_column_meaning_path, 'r') as file: 206 | column_meanings = json.load(file) 207 | 208 | # Initialize a list to store the extracted meanings 209 | meanings = [] 210 | 211 | # Iterate through each key in the JSON 212 | for key, explanation in column_meanings.items(): 213 | # Check if the key starts with the given db_id 214 | if key.startswith(db_id + "|"): 215 | # Extract the table name and column name from the key 216 | _, table_name, column_name = key.split("|") 217 | # Construct the meaning string in the desired format 218 | meaning = f"# Meaning of {column_name} column of {table_name} table in database is that {explanation.strip('# ').strip()}" 219 | meanings.append(meaning) 220 | 221 | return meanings 222 | --------------------------------------------------------------------------------
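As a quick sanity check of the retrieval utilities above, the sketch below shows how the BM25-based description retrieval and the column-meaning extraction might be called in isolation. The database name (`california_schools`), the question text, and the file paths are illustrative assumptions; adjust them to your local BIRD layout and run the snippet from the E-SQL root directory.

```python
# Minimal usage sketch for utils/retrieval_utils.py (paths, question, and db name are assumptions).
from utils.retrieval_utils import nltk_downloads, get_relevant_db_descriptions, get_db_column_meanings

nltk_downloads()  # one-time download of the NLTK resources used by clean_text()

db_description_path = "../dataset/bird-sql/dev/dev_databases/california_schools/database_description"
question = "What is the highest eligible free rate for K-12 students in schools located in Alameda County?"

# Ranks the per-column description strings against the question with BM25 and returns
# the top matches; db_description.csv is built on the first call if it does not exist yet.
for description in get_relevant_db_descriptions(db_description_path, question, relevant_description_number=6):
    print(description)

# Extracts the column meanings of a single database from column_meaning.json.
meanings = get_db_column_meanings("../dataset/bird-sql/dev/column_meaning.json", "california_schools")
print(f"{len(meanings)} column meanings loaded.")
```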