├── README.md ├── assets └── Workflow.png ├── evaluation ├── evaluation_bird_ex.py ├── evaluation_bird_ves.py ├── evaluation_spider.py ├── exec_eval.py ├── parse.py └── process_sql.py ├── evaluation_bird_ex_ves.sh ├── main_scripts ├── MAG.py ├── __init__.py ├── __pycache__ │ ├── MAG.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── bridge_content_encoder.cpython-39.pyc │ ├── chat_manager.cpython-39.pyc │ ├── const.cpython-39.pyc │ ├── llm.cpython-39.pyc │ └── utils.cpython-39.pyc ├── bridge_content_encoder.py ├── chat_manager.py ├── const.py ├── llm.py └── utils.py ├── output ├── gpt4_mag.json └── log │ └── log │ ├── api_trace.json │ └── txt ├── requirements.txt ├── run.py └── run.sh /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## 📖 Introduction 3 | 4 | Thank you for your interest in our work: MAG-SQL: Multi-Agent Generative Approach with Soft Schema Linking and Iterative Sub-SQL Refinement for Text-to-SQL 5 | 6 | You can find our paper [[here]](https://arxiv.org/abs/2408.07930). In this paper, we propose MAG-SQL, a multi-agent generative approach with soft schema linking and iterative Sub-SQL refinement. In our framework, an entity-based method with table summaries is used to select the relevant columns in the database, and a novel targets-conditions decomposition method is introduced to decompose complex questions. Additionally, we build an iterative generation module, which includes a Sub-SQL Generator and a Sub-SQL Refiner, introducing external oversight for each step of generation. The workflow is shown below: 7 | 8 | 9 | 10 | ## 🌍 Environment 11 | 12 | 1. Configure your local environment. 13 | 14 | ```bash 15 | conda create -n magsql python=3.9 -y 16 | conda activate magsql 17 | cd MAG-SQL 18 | pip install -r requirements.txt 19 | python -c "import nltk; nltk.download('punkt')" 20 | ``` 21 | 22 | Note: we use `openai==0.28.1`, which uses `openai.ChatCompletion.create` to call the API. 23 | 24 | 1.
Edit the OpenAI config in **main_scripts/llm.py** and set the related environment variables for the Azure OpenAI API. 25 | 26 | 27 | ## 📜 Data Preparation 28 | The main dataset we use is [BIRD](https://bird-bench.github.io/). 29 | For evaluation, the data should be placed in **data/bird/dev**. 30 | 31 | 32 | ## 🚀 Run 33 | 34 | Run the following command to run the pipeline on the BIRD dev set (adjust the file paths as needed): 35 | ```bash 36 | python run.py --dataset_name bird --dataset_mode dev --input_file ./data/bird/dev/dev.json --db_path ./data/bird/dev/dev_databases/ --tables_json_path ./data/bird/dev/dev_tables.json --output_file ./output/gpt4_mag.json --log_file ./output/log/log/txt --start_pos 0 37 | ``` 38 | 39 | ## 📑 Evaluation 40 | 41 | We provide an evaluation script for Linux systems: 42 | **evaluation_bird_ex_ves.sh** 43 | 44 | ## 🎈 Useful Resources 45 | 46 | Our project is built on MAC-SQL: [[Repo]](https://github.com/wbbeyourself/MAC-SQL) 47 | 48 | ## 📰 Citation 49 | If you find our work helpful, please cite it as: 50 | ```text 51 | @article{xie2024mag, 52 | title={MAG-SQL: Multi-Agent Generative Approach with Soft Schema Linking and Iterative Sub-SQL Refinement for Text-to-SQL}, 53 | author={Xie, Wenxuan and Wu, Gaochen and Zhou, Bowen}, 54 | journal={arXiv preprint arXiv:2408.07930}, 55 | year={2024} 56 | } 57 | ``` -------------------------------------------------------------------------------- /assets/Workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/assets/Workflow.png -------------------------------------------------------------------------------- /evaluation/evaluation_bird_ex.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import json 5 | import argparse 6 | import sqlite3 7 | import multiprocessing as mp 8 | from func_timeout import
# func_timeout, FunctionTimedOut
# ^ tail of line 8 of the listing above:
#   `from func_timeout import func_timeout, FunctionTimedOut`.
#   The code below relies on the module-level imports from that listing:
#   os, re, sys, json, argparse, sqlite3, multiprocessing (as mp) and
#   func_timeout / FunctionTimedOut.


def replace_multiple_spaces(text):
    """Collapse every run of whitespace in ``text`` into a single space."""
    # \s+ matches any run of whitespace characters (spaces, tabs, newlines).
    pattern = r'\s+'
    new_text = re.sub(pattern, ' ', text)
    return new_text


def load_json(dir):
    """Parse and return the JSON document stored at the path ``dir``."""
    with open(dir, 'r', encoding='utf8') as j:
        contents = json.loads(j.read())
    return contents


def save_json_file(path, data):
    """Write ``data`` to ``path`` as indented UTF-8 JSON and report where."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    print(f"save json file to {path}")


def result_callback(result):
    """Pool callback: append one worker result to the global ``exec_result``.

    ``exec_result`` is created in the ``__main__`` section of this script.
    """
    exec_result.append(result)


def execute_sql(predicted_sql, ground_truth, db_path):
    """Run both queries on the SQLite file ``db_path`` and compare results.

    Returns 1 when the two result sets are equal, else 0. Rows are compared
    as sets, so row order and duplicate rows are ignored. Errors raised by
    either query propagate to the caller; the connection is closed in every
    case.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        # Previously never closed: one leaked handle per evaluated query.
        conn.close()
    # TODO: this should permute column order!
    return 1 if set(predicted_res) == set(ground_truth_res) else 0


def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
    """Compare one prediction/gold pair with a timeout of ``meta_time_out`` seconds.

    Returns ``{'sql_idx': idx, 'res': 0 or 1}``. A timeout or any execution
    error (e.g. an unexecutable predicted query) scores 0.
    """
    try:
        res = func_timeout(meta_time_out, execute_sql,
                           args=(predicted_sql, ground_truth, db_place))
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        res = 0
    except Exception:
        # possibly len(query) > 512 or not executable
        res = 0
    return {'sql_idx': idx, 'res': res}


def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
    """Load queries and resolve the SQLite database file for each one.

    mode='gpt': ``sql_path`` is a JSON file of predictions, either a list of
        ``[idx, sql_str]`` pairs or a dict mapping idx to ``sql_str``, where
        ``sql_str`` is ``SQL`` + a tab-delimited ``----- bird -----`` marker +
        ``db_id``. A non-string entry (missing prediction) becomes a blank
        query against the "financial" database.
    mode='gt': ``sql_path`` is a text file with one ``SQL<TAB>db_id`` per
        line; blank lines are skipped.

    ``data_mode`` is accepted for interface compatibility and is unused.
    Returns ``(sqls, db_paths)`` as two parallel lists, where each path is
    ``<db_root_path>/<db_id>/<db_id>.sqlite``.
    """
    clean_sqls = []
    db_path_list = []
    if mode == 'gpt':
        with open(sql_path, 'r', encoding='utf8') as f:
            sql_data = json.load(f)
        # Iterating a dict directly would try to unpack each key string, so
        # accept both the list-of-pairs and the {idx: sql_str} layouts.
        entries = sql_data.items() if isinstance(sql_data, dict) else sql_data
        for _idx, sql_str in entries:
            if isinstance(sql_str, str):
                sql, db_name = sql_str.split('\t----- bird -----\t')
            else:
                sql, db_name = " ", "financial"
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, db_name + '.sqlite'))

    elif mode == 'gt':
        with open(sql_path, encoding='utf8') as f:
            sql_txt = f.readlines()
        for sql_str in sql_txt:
            if not sql_str.strip():
                # A blank line used to crash the tuple unpacking below.
                continue
            sql, db_name = sql_str.strip().split('\t')
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, db_name + '.sqlite'))

    return clean_sqls, db_path_list


def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
    """Evaluate every ``(predicted, gold)`` pair in ``sqls`` on a process pool.

    Results are delivered through ``result_callback`` in completion order,
    which is why the caller sorts them afterwards.
    """
    pool = mp.Pool(processes=num_cpus)
    for i, (predicted_sql, ground_truth) in enumerate(sqls):
        pool.apply_async(execute_model,
                         args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
                         callback=result_callback)
    pool.close()
    pool.join()


def sort_results(list_of_dicts):
    """Return the results ordered by their ``'sql_idx'`` key."""
    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])


def _accuracy(results):
    """Mean of the ``'res'`` flags in ``results``; 0 for an empty list."""
    if not results:
        return 0
    return sum(res['res'] for res in results) / len(results)


def compute_acc_by_diff(exec_results, diff_json_path):
    """Compute execution accuracy (in percent) per BIRD difficulty level.

    ``exec_results`` must be sorted by ``sql_idx`` and aligned with the
    questions in ``diff_json_path``. A question without a ``difficulty`` key
    counts as 'simple'; unknown difficulty labels are ignored per level but
    still count towards the total.

    Returns ``(simple, moderate, challenging, total, counts)``, where
    ``counts`` is ``[n_simple, n_moderate, n_challenging, n_total]``.

    Raises:
        ValueError: if the difficulty file lists more questions than there
            are execution results.
    """
    num_queries = len(exec_results)
    contents = load_json(diff_json_path)
    if len(contents) > num_queries:
        # Previously dropped into pdb (or raised IndexError) mid-evaluation.
        raise ValueError(
            f'{diff_json_path} lists {len(contents)} questions but only '
            f'{num_queries} execution results are available')

    buckets = {'simple': [], 'moderate': [], 'challenging': []}
    for i, content in enumerate(contents):
        difficulty = content.get('difficulty', 'simple')
        if difficulty in buckets:
            buckets[difficulty].append(exec_results[i])

    simple_results = buckets['simple']
    moderate_results = buckets['moderate']
    challenging_results = buckets['challenging']

    # Every level is guarded against being empty (simple used not to be).
    simple_acc = _accuracy(simple_results)
    moderate_acc = _accuracy(moderate_results)
    challenging_acc = _accuracy(challenging_results)
    all_acc = _accuracy(exec_results)

    count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
    return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists


def print_data(score_lists, count_lists):
    """Print per-difficulty counts and accuracies as a fixed-width table."""
    levels = ['simple', 'moderate', 'challenging', 'total']
    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))

    print('====================================== ACCURACY =====================================')
    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--predicted_sql_json_path', type=str, required=True)
    args_parser.add_argument('--ground_truth_sql_path', type=str, required=True)
    args_parser.add_argument('--data_mode', type=str, required=True, default='dev', choices=['train', 'dev', 'test'])
    args_parser.add_argument('--db_root_path', type=str, required=True)
    args_parser.add_argument('--num_cpus', type=int, default=1)
    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
    args_parser.add_argument('--mode_predict', type=str, default='gpt')
    args_parser.add_argument('--difficulty', type=str, default='simple')
    args_parser.add_argument('--diff_json_path', type=str, default='./data/bird/dev/dev.json')
    args = args_parser.parse_args()
    # Filled by result_callback as pool workers finish.
    exec_result = []

    pred_queries, db_paths = package_sqls(args.predicted_sql_json_path, args.db_root_path,
                                          mode=args.mode_predict, data_mode=args.data_mode)
    if len(pred_queries) == 0:
        raise ValueError(f'Empty data in {args.predicted_sql_json_path}')
    # generate gt sqls:
    gt_queries, db_paths_gt = package_sqls(args.ground_truth_sql_path, args.db_root_path, mode='gt',
                                           data_mode=args.data_mode)
    print(len(pred_queries), len(gt_queries))
    assert len(pred_queries) == len(gt_queries), "len(pred_queries) != len(gt_queries)"
    query_pairs = list(zip(pred_queries, gt_queries))
    run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
    exec_result = sort_results(exec_result)

    # save ex results
    out_dir = os.path.dirname(args.predicted_sql_json_path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    result_json_path = os.path.join(out_dir, f'eval_result_{args.data_mode}.json')

    # relocate idx of exec_result
    raw_json_data = load_json(args.diff_json_path)
    pred_sqls = [replace_multiple_spaces(s) for s in pred_queries]
    result_json_lst = []
    for i, item in enumerate(raw_json_data):
        item['pred'] = pred_sqls[i]
        item['gold'] = replace_multiple_spaces(item.get('SQL', ''))
        if 'SQL' in item:
            del item['SQL']
        item['res'] = exec_result[i]['res']
        result_json_lst.append(item)
    save_json_file(result_json_path, result_json_lst)

    print('start calculate')
    simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
        compute_acc_by_diff(exec_result, args.diff_json_path)
    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
    print_data(score_lists, count_lists)
    print('===========================================================================================')
    print("Finished evaluation")

# --------------------------------------------------------------------------------
# /evaluation/evaluation_bird_ves.py:
# --------------------------------------------------------------------------------
# 1 | import os 2 | import
# pdb
# ^ tail of line 2 of the listing above (`import pdb`). Lines 1-2 of this
#   file (`import os`, `import pdb`) live in that listing; neither name is
#   used by the code below.
import sys
import json
import numpy as np
import argparse
import sqlite3
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut
import time
import math


def result_callback(result):
    """Pool callback: append one worker result to the global ``exec_result``.

    ``exec_result`` is created in the ``__main__`` section of this script.
    """
    exec_result.append(result)


def clean_abnormal(input):
    """Return the values lying strictly within three standard deviations of the mean.

    When every value is identical (std == 0), all of them are kept. The strict
    3-sigma test would otherwise reject everything, and the caller's average
    would divide by zero, silently scoring a correct query as 0.
    """
    values = np.asarray(input)
    mean = np.mean(values, axis=0)
    std = np.std(values, axis=0)
    if std == 0:
        return list(values)
    return [x for x in values if mean - 3 * std < x < mean + 3 * std]


def execute_sql(sql, db_path):
    """Return the elapsed seconds of ``cursor.execute(sql)`` on a fresh connection.

    Only ``execute`` is timed, not a subsequent fetch. ``perf_counter`` is used
    because it is monotonic and has the highest available resolution for
    measuring short durations.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        start_time = time.perf_counter()
        cursor.execute(sql)
        exec_time = time.perf_counter() - start_time
    finally:
        # Previously never closed: 2 * iterate_num leaked handles per query.
        conn.close()
    return exec_time


def iterated_execute_sql(predicted_sql, ground_truth, db_path, iterate_num):
    """Return the mean gold/predicted execution-time ratio, or 0 if results differ.

    Results are first compared as sets. Only when they match are both queries
    timed ``iterate_num`` times; outlier ratios are then dropped with
    ``clean_abnormal`` before averaging.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        conn.close()

    if set(predicted_res) != set(ground_truth_res):
        return 0

    diff_list = []
    for _ in range(iterate_num):
        predicted_time = execute_sql(predicted_sql, db_path)
        ground_truth_time = execute_sql(ground_truth, db_path)
        diff_list.append(ground_truth_time / predicted_time)
    processed_diff_list = clean_abnormal(diff_list)
    return sum(processed_diff_list) / len(processed_diff_list)


def execute_model(predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out):
    """Compute the time ratio for one pair, bounded by ``meta_time_out * iterate_num`` seconds.

    Returns ``{'sql_idx': idx, 'time_ratio': ratio}``. A timeout or any
    execution error yields a ratio of 0.
    """
    try:
        # you can personalize the total timeout number
        # larger timeout leads to more stable ves
        # while it needs more your patience....
        print(idx, file=sys.stdout, flush=True)
        time_ratio = func_timeout(meta_time_out * iterate_num, iterated_execute_sql,
                                  args=(predicted_sql, ground_truth, db_place, iterate_num))
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        time_ratio = 0
    except Exception:
        # possibly len(query) > 512 or not executable
        time_ratio = 0
    return {'sql_idx': idx, 'time_ratio': time_ratio}


def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
    """Load queries and resolve the SQLite database file for each one.

    mode='gpt': ``sql_path`` is a JSON file of predictions, either a list of
        ``[idx, sql_str]`` pairs or a dict mapping idx to ``sql_str``, where
        ``sql_str`` is ``SQL`` + a tab-delimited ``----- bird -----`` marker +
        ``db_id``. A non-string entry (missing prediction) becomes a blank
        query against the "financial" database.
    mode='gt': ``sql_path`` is a text file with one ``SQL<TAB>db_id`` per
        line; blank lines are skipped.

    ``data_mode`` is accepted for interface compatibility and is unused.
    Returns ``(sqls, db_paths)`` as two parallel lists.
    """
    import os  # local: the module-level `import os` sits in the listing above

    clean_sqls = []
    db_path_list = []
    if mode == 'gpt':
        with open(sql_path, 'r', encoding='utf8') as f:
            sql_data = json.load(f)
        # Iterating a dict directly would try to unpack each key string.
        entries = sql_data.items() if isinstance(sql_data, dict) else sql_data
        for _idx, sql_str in entries:
            if isinstance(sql_str, str):
                sql, db_name = sql_str.split('\t----- bird -----\t')
            else:
                sql, db_name = " ", "financial"
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, db_name + '.sqlite'))

    elif mode == 'gt':
        with open(sql_path, encoding='utf8') as f:
            sql_txt = f.readlines()
        for sql_str in sql_txt:
            if not sql_str.strip():
                continue
            sql, db_name = sql_str.strip().split('\t')
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, db_name + '.sqlite'))

    return clean_sqls, db_path_list


def run_sqls_parallel(sqls, db_places, num_cpus=1, iterate_num=100, meta_time_out=30.0):
    """Time every ``(predicted, gold)`` pair in ``sqls`` on a process pool.

    Results arrive through ``result_callback`` in completion order.
    """
    pool = mp.Pool(processes=num_cpus)
    for i, (predicted_sql, ground_truth) in enumerate(sqls):
        pool.apply_async(execute_model,
                         args=(predicted_sql, ground_truth, db_places[i], i, iterate_num, meta_time_out),
                         callback=result_callback)
    pool.close()
    pool.join()


def sort_results(list_of_dicts):
    """Return the results ordered by their ``'sql_idx'`` key."""
    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])


def compute_ves(exec_results):
    """Return the VES score: mean of ``sqrt(time_ratio) * 100`` (0 for no results)."""
    num_queries = len(exec_results)
    if num_queries == 0:
        return 0
    total_ratio = 0
    for result in exec_results:
        if result['time_ratio'] != 0:
            total_ratio += math.sqrt(result['time_ratio']) * 100
    ves = (total_ratio / num_queries)
    return ves


def load_json(dir):
    """Parse and return the JSON document stored at the path ``dir``."""
    with open(dir, 'r', encoding='utf8') as j:
        contents = json.loads(j.read())
    return contents


def compute_ves_by_diff(exec_results, diff_json_path):
    """Compute VES per BIRD difficulty level.

    ``exec_results`` must be sorted by ``sql_idx`` and aligned with the
    questions in ``diff_json_path``. A missing ``difficulty`` key counts as
    'simple', matching the EX script.

    Returns ``(simple, moderate, challenging, total, counts)``.

    Raises:
        ValueError: if the difficulty file lists more questions than there
            are execution results.
    """
    num_queries = len(exec_results)
    contents = load_json(diff_json_path)
    if len(contents) > num_queries:
        raise ValueError(
            f'{diff_json_path} lists {len(contents)} questions but only '
            f'{num_queries} execution results are available')

    buckets = {'simple': [], 'moderate': [], 'challenging': []}
    for i, content in enumerate(contents):
        difficulty = content.get('difficulty', 'simple')
        if difficulty in buckets:
            buckets[difficulty].append(exec_results[i])

    simple_results = buckets['simple']
    moderate_results = buckets['moderate']
    challenging_results = buckets['challenging']
    simple_ves = compute_ves(simple_results)
    moderate_ves = compute_ves(moderate_results)
    challenging_ves = compute_ves(challenging_results)
    all_ves = compute_ves(exec_results)
    count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
    return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists


def print_data(score_lists, count_lists):
    """Print per-difficulty counts and VES scores as a fixed-width table."""
    levels = ['simple', 'moderate', 'challenging', 'total']
    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))

    print('========================================= VES ========================================')
    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('ves', *score_lists))


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--predicted_sql_json_path', type=str, required=True, default='')
    args_parser.add_argument('--ground_truth_sql_path', type=str, required=True, default='')
    args_parser.add_argument('--data_mode', type=str, required=True, default='dev')
    args_parser.add_argument('--db_root_path', type=str, required=True, default='')
    args_parser.add_argument('--num_cpus', type=int, default=1)
    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
    args_parser.add_argument('--mode_gt', type=str, default='gt')
    args_parser.add_argument('--mode_predict', type=str, default='gpt')
    args_parser.add_argument('--diff_json_path', type=str, required=True, default='')
    args = args_parser.parse_args()
    # Filled by result_callback as pool workers finish.
    exec_result = []

    pred_queries, db_paths = package_sqls(args.predicted_sql_json_path, args.db_root_path,
                                          mode=args.mode_predict, data_mode=args.data_mode)
    if len(pred_queries) == 0:
        raise ValueError(f'Empty data in {args.predicted_sql_json_path}')
    # generate gt sqls:
    gt_queries, db_paths_gt = package_sqls(args.ground_truth_sql_path, args.db_root_path, mode='gt',
                                           data_mode=args.data_mode)

    assert len(pred_queries) == len(gt_queries), "len(pred_queries) != len(gt_queries)"
    query_pairs = list(zip(pred_queries, gt_queries))
    run_sqls_parallel(query_pairs, iterate_num=100, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
    exec_result = sort_results(exec_result)
    print('start calculate')
    simple_ves, moderate_ves, challenging_ves, ves, count_lists = \
        compute_ves_by_diff(exec_result, args.diff_json_path)
    score_lists = [simple_ves, moderate_ves, challenging_ves, ves]
    print_data(score_lists, count_lists)
    print('===========================================================================================')
    print("Finished evaluation")
# --------------------------------------------------------------------------------
# /evaluation/evaluation_spider.py:
# --------------------------------------------------------------------------------
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################

import os
import json
import sqlite3
import argparse

from process_sql import get_schema, Schema, get_sql
from exec_eval import eval_exec_match

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True


CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}


def condition_has_or(conds):
    """Return True if any connective in the condition list is 'or'."""
    return 'or' in conds[1::2]


def condition_has_like(conds):
    """Return True if any cond_unit in ``conds`` uses the LIKE operator."""
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    """Return True if any cond_unit compares against a nested SQL (dict) value."""
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if isinstance(val1, dict) or isinstance(val2, dict):
            return True
    return False


def val_has_op(val_unit):
    """Return True if the val_unit applies an arithmetic operator."""
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    """Return True if the unit's first element is a non-'none' aggregation id."""
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    """All-or-nothing accuracy: 1 if every item matched, else 0."""
    if count == total:
        return 1
    return 0


def recall(count, total):
    """All-or-nothing recall: 1 if every item matched, else 0."""
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    """Harmonic mean of ``acc`` and ``rec`` (0 when both are 0)."""
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    """Return (acc, rec, f1): all 1 only on a complete match of equal-sized sets."""
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0


def eval_sel(pred, label):
    """Match SELECT items; return (label_total, pred_total, cnt, cnt_wo_agg)."""
    pred_sel = pred['select'][1]
    # Work on a copy: matched units are removed below, and removing them from
    # label['select'][1] itself silently mutated the caller's gold SQL.
    label_sel = list(label['select'][1])
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    """Match WHERE cond_units; return (label_total, pred_total, cnt, cnt_wo_agg)."""
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    """Match GROUP BY columns by bare name; return (label_total, pred_total, cnt)."""
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    # Compare column names without any "table." prefix.
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    """Score GROUP BY + HAVING as one unit; return (label_total, pred_total, cnt)."""
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    """Score ORDER BY (and LIMIT presence); return (label_total, pred_total, cnt)."""
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or
             (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    """Compare the sets of WHERE connectives; return (label_total, pred_total, cnt)."""
    pred_ao = set(pred['where'][1::2])
    label_ao = set(label['where'][1::2])

    if pred_ao == label_ao:
        return 1, 1, 1
    # Callers unpack (label_total, pred_total, cnt); these used to be swapped.
    return len(label_ao), len(pred_ao), 0


def get_nestedSQL(sql):
    """Collect nested SQL dicts from conditions and from INTERSECT/EXCEPT/UNION."""
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if isinstance(cond_unit[3], dict):
            nested.append(cond_unit[3])
        if isinstance(cond_unit[4], dict):
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    """Exact-match two optional nested SQLs; return (label_total, pred_total, cnt)."""
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    """Aggregate nested matches for INTERSECT, EXCEPT and UNION."""
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    """Return the set of SQL keywords/features used by the parsed ``sql``."""
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    """Compare keyword sets; return (label_total, pred_total, cnt)."""
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    """Count units whose first element is a non-'none' aggregation id."""
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    """Count WHERE/GROUP/ORDER/LIMIT/JOIN/OR/LIKE components (hardness input)."""
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    """Count nested SQLs (hardness input)."""
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    """Count 'other' complexity signals (hardness input)."""
    count = 0
    # number of aggregation
    agg_count = count_agg(sql['select'][1])
    # NOTE(review): for cond_units, has_agg() inspects the not_op flag, and the
    # 'having' list also contains 'and'/'or' strings. Kept as-is so hardness
    # buckets stay identical to the official Spider evaluator.
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                               [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count


class Evaluator:
    """A simple evaluator"""

    def __init__(self):
        # Partial-match scores from the most recent eval_exact_match call.
        self.partial_scores = None

    def eval_hardness(self, sql):
        """Classify ``sql`` as 'easy', 'medium', 'hard' or 'extra'."""
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        """Return 1 if every partial component matches (and the FROM tables agree), else 0."""
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for score in partial_scores.values():
            if score['f1'] != 1:
                return 0

        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            # int() so every path returns 0/1 rather than a bool on this one.
            return int(label_tables == pred_tables)
        return 1

    def eval_partial_match(self, pred, label):
        """Return per-component {'acc','rec','f1','label_total','pred_total'} scores."""
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        return res


def isValidSQL(sql, db):
    """Return True if ``sql`` executes without error on the SQLite file ``db``.

    The connection is always closed. Only ordinary exceptions count as
    "invalid"; KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    conn = sqlite3.connect(db)
    try:
        conn.cursor().execute(sql)
    except Exception:
        return False
    finally:
        conn.close()
    return True

# 443 | 444 | 445 |
def print_formated_s(row_name, l, element_format): 446 | template = "{:20} " + ' '.join([element_format] * len(l)) 447 | print(template.format(row_name, *l)) 448 | 449 | 450 | def print_scores(scores, etype, include_turn_acc=True): 451 | turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4'] 452 | levels = ['easy', 'medium', 'hard', 'extra', 'all'] 453 | if include_turn_acc: 454 | levels.append('joint_all') 455 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 456 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 457 | 458 | print_formated_s("", levels, '{:20}') 459 | counts = [scores[level]['count'] for level in levels] 460 | print_formated_s("count", counts, '{:<20d}') 461 | 462 | if etype in ["all", "exec"]: 463 | print ('===================== EXECUTION ACCURACY =====================') 464 | exec_scores = [scores[level]['exec'] for level in levels] 465 | print_formated_s("execution", exec_scores, '{:<20.3f}') 466 | 467 | if etype in ["all", "match"]: 468 | print ('\n====================== EXACT MATCHING ACCURACY =====================') 469 | exact_scores = [scores[level]['exact'] for level in levels] 470 | print_formated_s("exact match", exact_scores, '{:<20.3f}') 471 | print ('\n---------------------PARTIAL MATCHING ACCURACY----------------------') 472 | for type_ in partial_types: 473 | this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] 474 | print_formated_s(type_, this_scores, '{:<20.3f}') 475 | 476 | print ('---------------------- PARTIAL MATCHING RECALL ----------------------') 477 | for type_ in partial_types: 478 | this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] 479 | print_formated_s(type_, this_scores, '{:<20.3f}') 480 | 481 | print ('---------------------- PARTIAL MATCHING F1 --------------------------') 482 | for type_ in partial_types: 483 | this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] 484 | print_formated_s(type_, 
this_scores, '{:<20.3f}') 485 | 486 | if include_turn_acc: 487 | print() 488 | print() 489 | print_formated_s("", turns, '{:20}') 490 | counts = [scores[turn]['count'] for turn in turns] 491 | print_formated_s("count", counts, "{:<20d}") 492 | 493 | if etype in ["all", "exec"]: 494 | print ('===================== TURN EXECUTION ACCURACY =====================') 495 | exec_scores = [scores[turn]['exec'] for turn in turns] 496 | print_formated_s("execution", exec_scores, '{:<20.3f}') 497 | 498 | if etype in ["all", "match"]: 499 | print ('\n====================== TURN EXACT MATCHING ACCURACY =====================') 500 | exact_scores = [scores[turn]['exact'] for turn in turns] 501 | print_formated_s("exact match", exact_scores, '{:<20.3f}') 502 | 503 | 504 | def evaluate(gold, predict, db_dir, etype, kmaps, plug_value, keep_distinct, progress_bar_for_each_datapoint): 505 | 506 | with open(gold) as f: 507 | glist = [] 508 | gseq_one = [] 509 | for l in f.readlines(): 510 | if len(l.strip()) == 0: 511 | glist.append(gseq_one) 512 | gseq_one = [] 513 | else: 514 | lstrip = l.strip().split('\t') 515 | gseq_one.append(lstrip) 516 | 517 | # include the last session 518 | # this was previously ignored in the SParC evaluation script 519 | # which might lead to slight differences in scores 520 | if len(gseq_one) != 0: 521 | glist.append(gseq_one) 522 | 523 | # spider formatting indicates that there is only one "single turn" 524 | # do not report "turn accuracy" for SPIDER 525 | include_turn_acc = len(glist) > 1 526 | 527 | with open(predict) as f: 528 | plist = [] 529 | pseq_one = [] 530 | for l in f.readlines(): 531 | if len(l.strip()) == 0: 532 | plist.append(pseq_one) 533 | pseq_one = [] 534 | else: 535 | pseq_one.append(l.strip().split('\t')) 536 | 537 | if len(pseq_one) != 0: 538 | plist.append(pseq_one) 539 | 540 | assert len(plist) == len(glist), "number of sessions must equal" 541 | 542 | evaluator = Evaluator() 543 | turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 
'turn > 4'] 544 | levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all'] 545 | 546 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 547 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 548 | entries = [] 549 | scores = {} 550 | 551 | for turn in turns: 552 | scores[turn] = {'count': 0, 'exact': 0.} 553 | scores[turn]['exec'] = 0 554 | 555 | for level in levels: 556 | scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} 557 | scores[level]['exec'] = 0 558 | for type_ in partial_types: 559 | scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} 560 | 561 | gold_pred_map_lst = [] 562 | 563 | for i, (p, g) in enumerate(zip(plist, glist)): 564 | if (i + 1) % 10 == 0: 565 | print('Evaluating %dth prediction' % (i + 1)) 566 | scores['joint_all']['count'] += 1 567 | turn_scores = {"exec": [], "exact": []} 568 | 569 | print(f"len(p): {len(p)}; len(g): {len(g)}") 570 | for idx, pg in enumerate(zip(p, g)): 571 | gold_pred_map = { 572 | 'idx': idx, 573 | 'db_id': '', 574 | 'question': '', 575 | 'gold': '', 576 | 'pred': '', 577 | 'exec_result': 0 578 | } 579 | p, g = pg 580 | p_str = p[0] 581 | p_str = p_str.replace("value", "1") 582 | g_str, db = g 583 | 584 | gold_pred_map['pred'] = p_str 585 | gold_pred_map['gold'] = g_str 586 | gold_pred_map['db_id'] = db 587 | 588 | db_name = db 589 | db = os.path.join(db_dir, db, db + ".sqlite") 590 | schema = Schema(get_schema(db)) 591 | g_sql = get_sql(schema, g_str) 592 | hardness = evaluator.eval_hardness(g_sql) 593 | if idx > 3: 594 | idx = "> 4" 595 | else: 596 | idx += 1 597 | turn_id = "turn " + str(idx) 598 | scores[turn_id]['count'] += 1 599 | scores[hardness]['count'] += 1 600 | scores['all']['count'] += 1 601 | 602 | try: 603 | p_sql = get_sql(schema, p_str) 604 | except: 605 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 606 | p_sql = { 607 | "except": None, 608 | "from": { 609 | 
"conds": [], 610 | "table_units": [] 611 | }, 612 | "groupBy": [], 613 | "having": [], 614 | "intersect": None, 615 | "limit": None, 616 | "orderBy": [], 617 | "select": [ 618 | False, 619 | [] 620 | ], 621 | "union": None, 622 | "where": [] 623 | } 624 | 625 | if etype in ["all", "exec"]: 626 | exec_score = eval_exec_match(db=db, p_str=p_str, g_str=g_str, plug_value=plug_value, 627 | keep_distinct=keep_distinct, progress_bar_for_each_datapoint=progress_bar_for_each_datapoint) 628 | if exec_score: 629 | scores[hardness]['exec'] += 1 630 | scores[turn_id]['exec'] += 1 631 | scores['all']['exec'] += 1 632 | turn_scores['exec'].append(1) 633 | 634 | gold_pred_map['exec_result'] = 1 635 | else: 636 | turn_scores['exec'].append(0) 637 | gold_pred_map_lst.append(gold_pred_map) 638 | 639 | if etype in ["all", "match"]: 640 | # rebuild sql for value evaluation 641 | kmap = kmaps[db_name] 642 | g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) 643 | g_sql = rebuild_sql_val(g_sql) 644 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) 645 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema) 646 | p_sql = rebuild_sql_val(p_sql) 647 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 648 | exact_score = evaluator.eval_exact_match(p_sql, g_sql) 649 | partial_scores = evaluator.partial_scores 650 | if exact_score == 0: 651 | turn_scores['exact'].append(0) 652 | print("{} pred: {}".format(hardness, p_str)) 653 | print("{} gold: {}".format(hardness, g_str)) 654 | print("") 655 | else: 656 | turn_scores['exact'].append(1) 657 | scores[turn_id]['exact'] += exact_score 658 | scores[hardness]['exact'] += exact_score 659 | scores['all']['exact'] += exact_score 660 | for type_ in partial_types: 661 | if partial_scores[type_]['pred_total'] > 0: 662 | scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] 663 | scores[hardness]['partial'][type_]['acc_count'] += 1 664 | if 
partial_scores[type_]['label_total'] > 0: 665 | scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] 666 | scores[hardness]['partial'][type_]['rec_count'] += 1 667 | scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] 668 | if partial_scores[type_]['pred_total'] > 0: 669 | scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc'] 670 | scores['all']['partial'][type_]['acc_count'] += 1 671 | if partial_scores[type_]['label_total'] > 0: 672 | scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec'] 673 | scores['all']['partial'][type_]['rec_count'] += 1 674 | scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1'] 675 | 676 | entries.append({ 677 | 'predictSQL': p_str, 678 | 'goldSQL': g_str, 679 | 'hardness': hardness, 680 | 'exact': exact_score, 681 | 'partial': partial_scores 682 | }) 683 | 684 | if all(v == 1 for v in turn_scores["exec"]): 685 | scores['joint_all']['exec'] += 1 686 | 687 | if all(v == 1 for v in turn_scores["exact"]): 688 | scores['joint_all']['exact'] += 1 689 | 690 | # export evaluation result 691 | out_dir = os.path.dirname(predict) 692 | out_evaluation_json_path = os.path.join(out_dir, "evaluation.json") 693 | if not os.path.exists(out_dir): 694 | os.makedirs(out_dir, exist_ok=True) 695 | 696 | total_cnt = len(gold_pred_map_lst) 697 | err_cnt = len([m for m in gold_pred_map_lst if m['exec_result']== 0]) 698 | print(f"total_cnt: {total_cnt}, err_cnt: {err_cnt}") 699 | 700 | # save json file 701 | with open(out_evaluation_json_path, 'w') as f: 702 | json.dump(gold_pred_map_lst, f, indent=2) 703 | print("save evaluation result to {}".format(out_evaluation_json_path)) 704 | 705 | for turn in turns: 706 | if scores[turn]['count'] == 0: 707 | continue 708 | if etype in ["all", "exec"]: 709 | scores[turn]['exec'] /= scores[turn]['count'] 710 | 711 | if etype in ["all", "match"]: 712 | scores[turn]['exact'] /= scores[turn]['count'] 713 | 714 | for level in levels: 715 | 
if scores[level]['count'] == 0: 716 | continue 717 | if etype in ["all", "exec"]: 718 | scores[level]['exec'] /= scores[level]['count'] 719 | 720 | if etype in ["all", "match"]: 721 | scores[level]['exact'] /= scores[level]['count'] 722 | for type_ in partial_types: 723 | if scores[level]['partial'][type_]['acc_count'] == 0: 724 | scores[level]['partial'][type_]['acc'] = 0 725 | else: 726 | scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \ 727 | scores[level]['partial'][type_]['acc_count'] * 1.0 728 | if scores[level]['partial'][type_]['rec_count'] == 0: 729 | scores[level]['partial'][type_]['rec'] = 0 730 | else: 731 | scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \ 732 | scores[level]['partial'][type_]['rec_count'] * 1.0 733 | if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0: 734 | scores[level]['partial'][type_]['f1'] = 1 735 | else: 736 | scores[level]['partial'][type_]['f1'] = \ 737 | 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( 738 | scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) 739 | 740 | print_scores(scores, etype, include_turn_acc=include_turn_acc) 741 | 742 | 743 | # Rebuild SQL functions for value evaluation 744 | def rebuild_cond_unit_val(cond_unit): 745 | if cond_unit is None or not DISABLE_VALUE: 746 | return cond_unit 747 | 748 | not_op, op_id, val_unit, val1, val2 = cond_unit 749 | if type(val1) is not dict: 750 | val1 = None 751 | else: 752 | val1 = rebuild_sql_val(val1) 753 | if type(val2) is not dict: 754 | val2 = None 755 | else: 756 | val2 = rebuild_sql_val(val2) 757 | return not_op, op_id, val_unit, val1, val2 758 | 759 | 760 | def rebuild_condition_val(condition): 761 | if condition is None or not DISABLE_VALUE: 762 | return condition 763 | 764 | res = [] 765 | for idx, it in enumerate(condition): 766 | if idx % 2 == 0: 767 | 
res.append(rebuild_cond_unit_val(it)) 768 | else: 769 | res.append(it) 770 | return res 771 | 772 | 773 | def rebuild_sql_val(sql): 774 | if sql is None or not DISABLE_VALUE: 775 | return sql 776 | 777 | sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) 778 | sql['having'] = rebuild_condition_val(sql['having']) 779 | sql['where'] = rebuild_condition_val(sql['where']) 780 | sql['intersect'] = rebuild_sql_val(sql['intersect']) 781 | sql['except'] = rebuild_sql_val(sql['except']) 782 | sql['union'] = rebuild_sql_val(sql['union']) 783 | 784 | return sql 785 | 786 | 787 | # Rebuild SQL functions for foreign key evaluation 788 | def build_valid_col_units(table_units, schema): 789 | col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] 790 | prefixs = [col_id[:-2] for col_id in col_ids] 791 | valid_col_units= [] 792 | for value in schema.idMap.values(): 793 | if '.' in value and value[:value.index('.')] in prefixs: 794 | valid_col_units.append(value) 795 | return valid_col_units 796 | 797 | 798 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap): 799 | if col_unit is None: 800 | return col_unit 801 | 802 | agg_id, col_id, distinct = col_unit 803 | if col_id in kmap and col_id in valid_col_units: 804 | col_id = kmap[col_id] 805 | if DISABLE_DISTINCT: 806 | distinct = None 807 | return agg_id, col_id, distinct 808 | 809 | 810 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap): 811 | if val_unit is None: 812 | return val_unit 813 | 814 | unit_op, col_unit1, col_unit2 = val_unit 815 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) 816 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) 817 | return unit_op, col_unit1, col_unit2 818 | 819 | 820 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap): 821 | if table_unit is None: 822 | return table_unit 823 | 824 | table_type, col_unit_or_sql = table_unit 825 | if isinstance(col_unit_or_sql, tuple): 826 
| col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) 827 | return table_type, col_unit_or_sql 828 | 829 | 830 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): 831 | if cond_unit is None: 832 | return cond_unit 833 | 834 | not_op, op_id, val_unit, val1, val2 = cond_unit 835 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) 836 | return not_op, op_id, val_unit, val1, val2 837 | 838 | 839 | def rebuild_condition_col(valid_col_units, condition, kmap): 840 | for idx in range(len(condition)): 841 | if idx % 2 == 0: 842 | condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) 843 | return condition 844 | 845 | 846 | def rebuild_select_col(valid_col_units, sel, kmap): 847 | if sel is None: 848 | return sel 849 | distinct, _list = sel 850 | new_list = [] 851 | for it in _list: 852 | agg_id, val_unit = it 853 | new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) 854 | if DISABLE_DISTINCT: 855 | distinct = None 856 | return distinct, new_list 857 | 858 | 859 | def rebuild_from_col(valid_col_units, from_, kmap): 860 | if from_ is None: 861 | return from_ 862 | 863 | from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] 864 | from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) 865 | return from_ 866 | 867 | 868 | def rebuild_group_by_col(valid_col_units, group_by, kmap): 869 | if group_by is None: 870 | return group_by 871 | 872 | return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] 873 | 874 | 875 | def rebuild_order_by_col(valid_col_units, order_by, kmap): 876 | if order_by is None or len(order_by) == 0: 877 | return order_by 878 | 879 | direction, val_units = order_by 880 | new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] 881 | return direction, new_val_units 882 | 883 | 884 | def 
rebuild_sql_col(valid_col_units, sql, kmap): 885 | if sql is None: 886 | return sql 887 | 888 | sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) 889 | sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) 890 | sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) 891 | sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) 892 | sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) 893 | sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) 894 | sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) 895 | sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) 896 | sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) 897 | 898 | return sql 899 | 900 | 901 | def build_foreign_key_map(entry): 902 | cols_orig = entry["column_names_original"] 903 | tables_orig = entry["table_names_original"] 904 | 905 | # rebuild cols corresponding to idmap in Schema 906 | cols = [] 907 | for col_orig in cols_orig: 908 | if col_orig[0] >= 0: 909 | t = tables_orig[col_orig[0]] 910 | c = col_orig[1] 911 | cols.append("__" + t.lower() + "." 
+ c.lower() + "__") 912 | else: 913 | cols.append("__all__") 914 | 915 | def keyset_in_list(k1, k2, k_list): 916 | for k_set in k_list: 917 | if k1 in k_set or k2 in k_set: 918 | return k_set 919 | new_k_set = set() 920 | k_list.append(new_k_set) 921 | return new_k_set 922 | 923 | foreign_key_list = [] 924 | foreign_keys = entry["foreign_keys"] 925 | for fkey in foreign_keys: 926 | key1, key2 = fkey 927 | key_set = keyset_in_list(key1, key2, foreign_key_list) 928 | key_set.add(key1) 929 | key_set.add(key2) 930 | 931 | foreign_key_map = {} 932 | for key_set in foreign_key_list: 933 | sorted_list = sorted(list(key_set)) 934 | midx = sorted_list[0] 935 | for idx in sorted_list: 936 | foreign_key_map[cols[idx]] = cols[midx] 937 | 938 | return foreign_key_map 939 | 940 | 941 | def build_foreign_key_map_from_json(table): 942 | with open(table) as f: 943 | data = json.load(f) 944 | tables = {} 945 | for entry in data: 946 | tables[entry['db_id']] = build_foreign_key_map(entry) 947 | return tables 948 | 949 | 950 | if __name__ == "__main__": 951 | parser = argparse.ArgumentParser() 952 | parser.add_argument('--gold', dest='gold', type=str, help="the path to the gold queries") 953 | parser.add_argument('--pred', dest='pred', type=str, help="the path to the predicted queries") 954 | parser.add_argument('--db', dest='db', type=str, help="the directory that contains all the databases and test suites") 955 | parser.add_argument('--table', dest='table', type=str, help="the tables.json schema file") 956 | parser.add_argument('--etype', dest='etype', type=str, default='exec', 957 | help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy", 958 | choices=('all', 'exec', 'match')) 959 | parser.add_argument('--plug_value', default=False, action='store_true', 960 | help='whether to plug in the gold value into the predicted query; suitable if your model does not predict values.') 961 | parser.add_argument('--keep_distinct', default=False, 
action='store_true', 962 | help='whether to keep distinct keyword during evaluation. default is false.') 963 | parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true', 964 | help='whether to print progress bar of running test inputs for each datapoint') 965 | args = parser.parse_args() 966 | 967 | # only evaluting exact match needs this argument 968 | kmaps = None 969 | if args.etype in ['all', 'match']: 970 | assert args.table is not None, 'table argument must be non-None if exact set match is evaluated' 971 | kmaps = build_foreign_key_map_from_json(args.table) 972 | 973 | evaluate(args.gold, args.pred, args.db, args.etype, kmaps, args.plug_value, args.keep_distinct, args.progress_bar_for_each_datapoint) 974 | -------------------------------------------------------------------------------- /evaluation/exec_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import asyncio 4 | import sqlite3 5 | import threading 6 | from typing import Tuple, Any, List, Set 7 | from itertools import product 8 | from collections import defaultdict 9 | import tqdm 10 | import random 11 | from parse import get_all_preds_for_execution, remove_distinct 12 | import time 13 | import pickle as pkl 14 | import subprocess 15 | from itertools import chain 16 | 17 | 18 | 19 | threadLock = threading.Lock() 20 | TIMEOUT = 60 21 | EXEC_TMP_DIR = 'tmp/' 22 | 23 | def permute_tuple(element: Tuple, perm: Tuple) -> Tuple: 24 | assert len(element) == len(perm) 25 | return tuple([element[i] for i in perm]) 26 | 27 | 28 | def unorder_row(row: Tuple) -> Tuple: 29 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x)))) 30 | 31 | 32 | # unorder each row in the table 33 | # [result_1 and result_2 has the same bag of unordered row] 34 | # is a necessary condition of 35 | # [result_1 and result_2 are equivalent in denotation] 36 | def quick_rej(result1: List[Tuple], result2: List[Tuple], 
order_matters: bool) -> bool: 37 | s1 = [unorder_row(row) for row in result1] 38 | s2 = [unorder_row(row) for row in result2] 39 | if order_matters: 40 | return s1 == s2 41 | else: 42 | return set(s1) == set(s2) 43 | 44 | 45 | # return whether two bag of relations are equivalent 46 | def multiset_eq(l1: List, l2: List) -> bool: 47 | if len(l1) != len(l2): 48 | return False 49 | d = defaultdict(int) 50 | for e in l1: 51 | d[e] = d[e] + 1 52 | for e in l2: 53 | d[e] = d[e] - 1 54 | if d[e] < 0: 55 | return False 56 | return True 57 | 58 | 59 | def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]): 60 | num_cols = len(result2[0]) 61 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)] 62 | if num_cols <= 3: 63 | return product(*perm_constraints) 64 | 65 | # we sample 20 rows and constrain the space of permutations 66 | for _ in range(20): 67 | random_tab2_row = random.choice(result2) 68 | 69 | for tab1_col in range(num_cols): 70 | for tab2_col in set(perm_constraints[tab1_col]): 71 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]: 72 | perm_constraints[tab1_col].remove(tab2_col) 73 | return product(*perm_constraints) 74 | 75 | 76 | # check whether two denotations are correct 77 | def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 78 | if len(result1) == 0 and len(result2) == 0: 79 | return True 80 | 81 | # if length is not the same, then they are definitely different bag of rows 82 | if len(result1) != len(result2): 83 | return False 84 | 85 | num_cols = len(result1[0]) 86 | 87 | # if the results do not have the same number of columns, they are different 88 | if len(result2[0]) != num_cols: 89 | return False 90 | 91 | # unorder each row and compare whether the denotation is the same 92 | # this can already find most pair of denotations that are different 93 | if not quick_rej(result1, result2, order_matters): 94 | return False 95 | 96 | # the rest of the 
problem is in fact more complicated than one might think 97 | # we want to find a permutation of column order and a permutation of row order, 98 | # s.t. result_1 is the same as result_2 99 | # we return true if we can find such column & row permutations 100 | # and false if we cannot 101 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)] 102 | 103 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2 104 | # we decrease the size of the column permutation space by the function get_constraint_permutation 105 | # if one of the permutation make result_1, result_2 equivalent, then they are equivalent 106 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2): 107 | if len(perm) != len(set(perm)): 108 | continue 109 | if num_cols == 1: 110 | result2_perm = result2 111 | else: 112 | result2_perm = [permute_tuple(element, perm) for element in result2] 113 | if order_matters: 114 | if result1 == result2_perm: 115 | return True 116 | else: 117 | # in fact the first condition must hold if the second condition holds 118 | # but the first is way more efficient implementation-wise 119 | # and we use it to quickly reject impossible candidates 120 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm): 121 | return True 122 | return False 123 | 124 | 125 | def replace_cur_year(query: str) -> str: 126 | return re.sub( 127 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE 128 | ) 129 | 130 | 131 | # get the database cursor for a sqlite database path 132 | def get_cursor_from_path(sqlite_path: str): 133 | try: 134 | if not os.path.exists(sqlite_path): 135 | print("Openning a new connection %s" % sqlite_path) 136 | connection = sqlite3.connect(sqlite_path) 137 | except Exception as e: 138 | print(sqlite_path) 139 | raise e 140 | connection.text_factory = lambda b: b.decode(errors="ignore") 141 | cursor = connection.cursor() 142 | return 
cursor 143 | 144 | 145 | async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]: 146 | query = replace_cur_year(query) 147 | cursor = get_cursor_from_path(sqlite_path) 148 | try: 149 | cursor.execute(query) 150 | result = cursor.fetchall() 151 | cursor.close() 152 | cursor.connection.close() 153 | return "result", result 154 | except Exception as e: 155 | cursor.close() 156 | cursor.connection.close() 157 | return "exception", e 158 | 159 | async def exec_on_db( 160 | sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT 161 | ) -> Tuple[str, Any]: 162 | try: 163 | return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout) 164 | except asyncio.TimeoutError: 165 | return ('exception', TimeoutError) 166 | except Exception as e: 167 | return ("exception", e) 168 | 169 | 170 | # postprocess the model predictions to avoid execution errors 171 | # e.g. removing spaces between ">" and "=" 172 | def postprocess(query: str) -> str: 173 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=') 174 | return query 175 | 176 | 177 | # approximate whether p_str and g_str are semantically equivalent 178 | # db is the database path 179 | # we are going to evaluate whether they are equivalent in all the databases 180 | # that are in the same directory as db 181 | # 0 if denotationally equivalent 182 | # 1 otherwise 183 | # the meaning of each auxillary argument can be seen in the parser definition in evaluation.py 184 | def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int: 185 | # post-process the prediction. 186 | # e.g. 
removing spaces between ">" and "=" 187 | p_str, g_str = postprocess(p_str), postprocess(g_str) 188 | if not keep_distinct: 189 | p_str = remove_distinct(p_str) 190 | g_str = remove_distinct(g_str) 191 | 192 | # we decide whether two denotations are equivalent based on "bag semantics" 193 | # https://courses.cs.washington.edu/courses/cse444/10sp/lectures/lecture16.pdf 194 | # if there is order by in query, then we assume order of the rows matter 195 | # order by might also be used to find the max/min instead of sorting, 196 | # but in that case the result mostly only contains one row and hence order_matters does not make a difference 197 | order_matters = 'order by' in g_str.lower() 198 | 199 | # find all databases in the same directory 200 | db_dir = os.path.dirname(db) 201 | db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if '.sqlite' in basename] 202 | 203 | preds = [p_str] 204 | # if plug in value (i.e. we do not consider value prediction correctness) 205 | # enumerate all ways to plug in values in the gold query to the model predictions 206 | # otherwise, we only evaluate the predicted query with its own value prediction 207 | if plug_value: 208 | _, preds = get_all_preds_for_execution(g_str, p_str) 209 | # we did not add this line in our EMNLP work 210 | # this reduces "false negatives" when value is substituted 211 | preds = chain([p_str], preds) 212 | 213 | for pred in preds: 214 | 215 | pred_passes = 1 216 | # compare the gold and predicted denotations on each database in the directory 217 | # wrap with progress bar if required 218 | if progress_bar_for_each_datapoint: 219 | ranger = tqdm.tqdm(db_paths) 220 | else: 221 | ranger = db_paths 222 | 223 | for db_path in ranger: 224 | g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str)) 225 | p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred)) 226 | 227 | # we should expect the gold to be succesfully executed on the database 228 | assert g_flag != 'exception', 
'gold query %s has error on database file %s' % (g_str, db_path) 229 | 230 | # wrong if execution fails 231 | if p_flag == 'exception': 232 | pred_passes = 0 233 | 234 | # if denotations are not equivalent, the prediction must be wrong 235 | elif not result_eq(g_denotation, p_denotation, order_matters=order_matters): 236 | pred_passes = 0 237 | if pred_passes == 0: 238 | break 239 | 240 | # the model prediction has the same denotation as the gold for all databases 241 | if pred_passes == 1: 242 | return 1 243 | 244 | # none of the predictions passed 245 | return 0 246 | -------------------------------------------------------------------------------- /evaluation/parse.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sqlparse 3 | from typing import List, Tuple, Set, Iterator, Dict, Any, Union 4 | from sqlparse.sql import Comparison, Identifier 5 | from sqlparse.tokens import Whitespace 6 | import itertools 7 | from collections import namedtuple 8 | 9 | Token = namedtuple('Token', ['ttype', 'value']) 10 | VALUE_NUM_SYMBOL = 'VALUERARE' 11 | QUOTE_CHARS = {'`', '\'', '"'} 12 | 13 | 14 | def tokenize(query: str) -> List[Token]: 15 | tokens = list([Token(t.ttype, t.value) for t in sqlparse.parse(query)[0].flatten()]) 16 | return tokens 17 | 18 | 19 | def join_tokens(tokens: List[Token]) -> str: 20 | return ''.join([x.value for x in tokens]).strip().replace(' ', ' ') 21 | 22 | 23 | def round_trip_test(query: str) -> None: 24 | tokens = tokenize(query) 25 | reconstructed = ''.join([token.value for token in tokens]) 26 | assert query == reconstructed, "Round trip test fails for string %s" % query 27 | 28 | 29 | def postprocess(query: str) -> str: 30 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! 
=', '!=') 31 | return query 32 | 33 | 34 | # strip_query, reformat_query and replace values 35 | # were implemented by Yu Tao for processing CoSQL 36 | def strip_query(query: str) -> Tuple[List[str], List[str]]: 37 | query_keywords, all_values = [], [] 38 | 39 | # then replace all stuff enclosed by "" with a numerical value to get it marked as {VALUE} 40 | 41 | # Tao's implementation is commented out here. 42 | """ 43 | str_1 = re.findall("\"[^\"]*\"", query) 44 | str_2 = re.findall("\'[^\']*\'", query) 45 | values = str_1 + str_2 46 | """ 47 | 48 | toks = sqlparse.parse(query)[0].flatten() 49 | values = [t.value for t in toks if t.ttype == sqlparse.tokens.Literal.String.Single or t.ttype == sqlparse.tokens.Literal.String.Symbol] 50 | 51 | 52 | for val in values: 53 | all_values.append(val) 54 | query = query.replace(val.strip(), VALUE_NUM_SYMBOL) 55 | 56 | query_tokenized = query.split() 57 | float_nums = re.findall("[-+]?\d*\.\d+", query) 58 | all_values += [qt for qt in query_tokenized if qt in float_nums] 59 | query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized] 60 | 61 | query = " ".join(query_tokenized) 62 | int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)] 63 | 64 | all_values += [qt for qt in query_tokenized if qt in int_nums] 65 | query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized] 66 | # print int_nums, query, query_tokenized 67 | 68 | for tok in query_tokenized: 69 | if "." in tok: 70 | table = re.findall("[Tt]\d+\.", tok) 71 | if len(table) > 0: 72 | to = tok.replace(".", " . 
").split() 73 | to = [t.lower() for t in to if len(t) > 0] 74 | query_keywords.extend(to) 75 | else: 76 | query_keywords.append(tok.lower()) 77 | 78 | elif len(tok) > 0: 79 | query_keywords.append(tok.lower()) 80 | return query_keywords, all_values 81 | 82 | 83 | def reformat_query(query: str) -> str: 84 | query = query.strip().replace(";", "").replace("\t", "") 85 | query = ' '.join([t.value for t in tokenize(query) if t.ttype != sqlparse.tokens.Whitespace]) 86 | t_stars = ["t1.*", "t2.*", "t3.*", "T1.*", "T2.*", "T3.*"] 87 | for ts in t_stars: 88 | query = query.replace(ts, "*") 89 | return query 90 | 91 | 92 | def replace_values(sql: str) -> Tuple[List[str], Set[str]]: 93 | sql = sqlparse.format(sql, reindent=False, keyword_case='upper') 94 | # sql = re.sub(r"(<=|>=|!=|=|<|>|,)", r" \1 ", sql) 95 | sql = re.sub(r"(T\d+\.)\s", r"\1", sql) 96 | query_toks_no_value, values = strip_query(sql) 97 | return query_toks_no_value, set(values) 98 | 99 | 100 | # extract the non-value tokens and the set of values 101 | # from a sql query 102 | def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]: 103 | reformated = reformat_query(query=sql) 104 | query_value_replaced, values = replace_values(reformated) 105 | return query_value_replaced, values 106 | 107 | 108 | # plug in the values into query with value slots 109 | def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str: 110 | q_length = len(query_value_replaced) 111 | query_w_values = query_value_replaced[:] 112 | value_idx = [idx for idx in range(q_length) if query_value_replaced[idx] == VALUE_NUM_SYMBOL.lower()] 113 | assert len(value_idx) == len(values_in_order) 114 | 115 | for idx, value in zip(value_idx, values_in_order): 116 | query_w_values[idx] = value 117 | return ' '.join(query_w_values) 118 | 119 | 120 | # a generator generating all possible ways of 121 | # filling values into predicted query 122 | def plugin_all_permutations(query_value_replaced: List[str], values: Set[str]) 
-> Iterator[str]: 123 | num_slots = len([v for v in query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 124 | for values in itertools.product(*[list(values) for _ in range(num_slots)]): 125 | yield plugin(query_value_replaced, list(values)) 126 | 127 | 128 | # given the gold query and the model prediction 129 | # extract values from the gold, extract predicted sql with value slots 130 | # return 1) number of possible ways to plug in gold values and 2) an iterator of predictions with value plugged in 131 | def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]: 132 | _, gold_values = extract_query_values(gold) 133 | pred_query_value_replaced, _ = extract_query_values(pred) 134 | num_slots = len([v for v in pred_query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 135 | num_alternatives = len(gold_values) ** num_slots 136 | return num_alternatives, plugin_all_permutations(pred_query_value_replaced, gold_values) 137 | 138 | 139 | def remove_distinct(s): 140 | toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())] 141 | return ''.join([t for t in toks if t.lower() != 'distinct']) 142 | 143 | 144 | def extract_all_comparison_from_node(node: Token) -> List[Comparison]: 145 | comparison_list = [] 146 | if hasattr(node, 'tokens'): 147 | for t in node.tokens: 148 | comparison_list.extend(extract_all_comparison_from_node(t)) 149 | if type(node) == Comparison: 150 | comparison_list.append(node) 151 | return comparison_list 152 | 153 | 154 | def extract_all_comparison(query: str) -> List[Comparison]: 155 | tree = sqlparse.parse(query)[0] 156 | comparison_list = extract_all_comparison_from_node(tree) 157 | return comparison_list 158 | 159 | 160 | def extract_toks_from_comparison(comparison_node: Comparison) -> List[Token]: 161 | tokens = [t for t in comparison_node.tokens if t.ttype != Whitespace] 162 | return tokens 163 | 164 | 165 | def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]: 166 | tokens = 
extract_toks_from_comparison(comparison_node) 167 | left, op, right = tokens 168 | 169 | returned_dict = { 170 | 'left': left, 171 | 'op': op.value, 172 | 'right': right 173 | } 174 | 175 | if type(left) != Identifier: 176 | return returned_dict 177 | 178 | table = None 179 | if len(left.tokens) == 3 and re.match('^[tT][0-9]$', left.tokens[0].value) is None: 180 | table = left.tokens[0].value.lower() 181 | col = left.tokens[-1].value 182 | 183 | if type(right) == Identifier: 184 | if len(right.tokens) == 1 and type(right.tokens[0]) == sqlparse.sql.Token: 185 | right_val = right.tokens[0].value 186 | else: 187 | return returned_dict 188 | elif type(right) == sqlparse.sql.Token: 189 | right_val = right.value 190 | else: 191 | return returned_dict 192 | 193 | returned_dict['table_col'], returned_dict['val'] = (table, col.upper()), process_str_value(right_val) 194 | 195 | return returned_dict 196 | 197 | 198 | def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]: 199 | comparison_list = extract_all_comparison(query) 200 | return [extract_info_from_comparison(c) for c in comparison_list] 201 | 202 | 203 | def extract_typed_value_in_comparison_from_query(query: str) -> List[Tuple[Tuple[Union[str, None], str], str]]: 204 | cmps = extract_all_comparison_from_query(query) 205 | typed_values = [(cmp['table_col'], cmp['val']) for cmp in cmps if 'table_col' in cmp] 206 | for table, col, val1, val2 in re.findall('(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)', query, re.IGNORECASE): 207 | if table == '': 208 | table = None 209 | else: 210 | table = table.lower() 211 | col = col.upper() 212 | for v in [val1, val2]: 213 | typed_values.append(((table, col), v)) 214 | return typed_values 215 | 216 | 217 | def process_str_value(v: str) -> str: 218 | if len(v) > 0 and v[0] in QUOTE_CHARS: 219 | v = v[1:] 220 | if len(v) > 0 and v[-1] in QUOTE_CHARS: 221 | v = v[:-1] 222 | for c in QUOTE_CHARS: 223 | v = v.replace(c + c, c) 224 | return v 225 | 
-------------------------------------------------------------------------------- /evaluation/process_sql.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Assumptions: 3 | # 1. sql is correct 4 | # 2. only table name has alias 5 | # 3. only one intersect/union/except 6 | # 7 | # val: number(float)/string(str)/sql(dict) 8 | # col_unit: (agg_id, col_id, isDistinct(bool)) 9 | # val_unit: (unit_op, col_unit1, col_unit2) 10 | # table_unit: (table_type, col_unit/sql) 11 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 13 | # sql { 14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 16 | # 'where': condition 17 | # 'groupBy': [col_unit1, col_unit2, ...] 18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 19 | # 'having': condition 20 | # 'limit': None/limit value 21 | # 'intersect': None/sql 22 | # 'except': None/sql 23 | # 'union': None/sql 24 | # } 25 | ################################ 26 | 27 | import json 28 | import sqlite3 29 | from nltk import word_tokenize 30 | 31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 32 | JOIN_KEYWORDS = ('join', 'on', 'as') 33 | 34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 35 | UNIT_OPS = ('none', '-', '+', "*", '/') 36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 37 | TABLE_TYPE = { 38 | 'sql': "sql", 39 | 'table_unit': "table_unit", 40 | } 41 | 42 | COND_OPS = ('and', 'or') 43 | SQL_OPS = ('intersect', 'union', 'except') 44 | ORDER_OPS = ('desc', 'asc') 45 | 46 | 47 | 48 | class Schema: 49 | """ 50 | Simple schema which maps table&column to a unique identifier 51 | """ 52 | def __init__(self, schema): 53 | self._schema = schema 54 | self._idMap = 
self._map(self._schema) 55 | 56 | @property 57 | def schema(self): 58 | return self._schema 59 | 60 | @property 61 | def idMap(self): 62 | return self._idMap 63 | 64 | def _map(self, schema): 65 | idMap = {'*': "__all__"} 66 | id = 1 67 | for key, vals in schema.items(): 68 | for val in vals: 69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" 70 | id += 1 71 | 72 | for key in schema: 73 | idMap[key.lower()] = "__" + key.lower() + "__" 74 | id += 1 75 | 76 | return idMap 77 | 78 | 79 | def get_schema(db): 80 | """ 81 | Get database's schema, which is a dict with table name as key 82 | and list of column names as value 83 | :param db: database path 84 | :return: schema dict 85 | """ 86 | 87 | schema = {} 88 | conn = sqlite3.connect(db) 89 | cursor = conn.cursor() 90 | 91 | # fetch table names 92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 93 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 94 | 95 | # fetch table info 96 | for table in tables: 97 | cursor.execute("PRAGMA table_info({})".format(table)) 98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 99 | 100 | return schema 101 | 102 | 103 | def get_schema_from_json(fpath): 104 | with open(fpath) as f: 105 | data = json.load(f) 106 | 107 | schema = {} 108 | for entry in data: 109 | table = str(entry['table'].lower()) 110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']] 111 | schema[table] = cols 112 | 113 | return schema 114 | 115 | 116 | def tokenize(string): 117 | string = str(string) 118 | string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 
119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 121 | 122 | # keep string value as token 123 | vals = {} 124 | for i in range(len(quote_idxs)-1, -1, -2): 125 | qidx1 = quote_idxs[i-1] 126 | qidx2 = quote_idxs[i] 127 | val = string[qidx1: qidx2+1] 128 | key = "__val_{}_{}__".format(qidx1, qidx2) 129 | string = string[:qidx1] + key + string[qidx2+1:] 130 | vals[key] = val 131 | 132 | toks = [word.lower() for word in word_tokenize(string)] 133 | # replace with string value token 134 | for i in range(len(toks)): 135 | if toks[i] in vals: 136 | toks[i] = vals[toks[i]] 137 | 138 | # find if there exists !=, >=, <= 139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 140 | eq_idxs.reverse() 141 | prefix = ('!', '>', '<') 142 | for eq_idx in eq_idxs: 143 | pre_tok = toks[eq_idx-1] 144 | if pre_tok in prefix: 145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 146 | 147 | return toks 148 | 149 | 150 | def scan_alias(toks): 151 | """Scan the index of 'as' and build the map for all alias""" 152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 153 | alias = {} 154 | for idx in as_idxs: 155 | alias[toks[idx+1]] = toks[idx-1] 156 | return alias 157 | 158 | 159 | def get_tables_with_alias(schema, toks): 160 | tables = scan_alias(toks) 161 | for key in schema: 162 | assert key not in tables, "Alias {} has the same name in table".format(key) 163 | tables[key] = key 164 | return tables 165 | 166 | 167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 168 | """ 169 | :returns next idx, column id 170 | """ 171 | tok = toks[start_idx] 172 | if tok == "*": 173 | return start_idx + 1, schema.idMap[tok] 174 | 175 | if '.' in tok: # if token is a composite 176 | alias, col = tok.split('.') 177 | key = tables_with_alias[alias] + "." 
+ col 178 | return start_idx+1, schema.idMap[key] 179 | 180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 181 | 182 | for alias in default_tables: 183 | table = tables_with_alias[alias] 184 | if tok in schema.schema[table]: 185 | key = table + "." + tok 186 | return start_idx+1, schema.idMap[key] 187 | 188 | assert False, "Error col: {}".format(tok) 189 | 190 | 191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 192 | """ 193 | :returns next idx, (agg_op id, col_id) 194 | """ 195 | idx = start_idx 196 | len_ = len(toks) 197 | isBlock = False 198 | isDistinct = False 199 | if toks[idx] == '(': 200 | isBlock = True 201 | idx += 1 202 | 203 | if toks[idx] in AGG_OPS: 204 | agg_id = AGG_OPS.index(toks[idx]) 205 | idx += 1 206 | assert idx < len_ and toks[idx] == '(' 207 | idx += 1 208 | if toks[idx] == "distinct": 209 | idx += 1 210 | isDistinct = True 211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 212 | assert idx < len_ and toks[idx] == ')' 213 | idx += 1 214 | return idx, (agg_id, col_id, isDistinct) 215 | 216 | if toks[idx] == "distinct": 217 | idx += 1 218 | isDistinct = True 219 | agg_id = AGG_OPS.index("none") 220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 221 | 222 | if isBlock: 223 | assert toks[idx] == ')' 224 | idx += 1 # skip ')' 225 | 226 | return idx, (agg_id, col_id, isDistinct) 227 | 228 | 229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 230 | idx = start_idx 231 | len_ = len(toks) 232 | isBlock = False 233 | if toks[idx] == '(': 234 | isBlock = True 235 | idx += 1 236 | 237 | col_unit1 = None 238 | col_unit2 = None 239 | unit_op = UNIT_OPS.index('none') 240 | 241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 242 | if idx < len_ and toks[idx] in UNIT_OPS: 243 | unit_op = UNIT_OPS.index(toks[idx]) 
244 | idx += 1 245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 246 | 247 | if isBlock: 248 | assert toks[idx] == ')' 249 | idx += 1 # skip ')' 250 | 251 | return idx, (unit_op, col_unit1, col_unit2) 252 | 253 | 254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 255 | """ 256 | :returns next idx, table id, table name 257 | """ 258 | idx = start_idx 259 | len_ = len(toks) 260 | key = tables_with_alias[toks[idx]] 261 | 262 | if idx + 1 < len_ and toks[idx+1] == "as": 263 | idx += 3 264 | else: 265 | idx += 1 266 | 267 | return idx, schema.idMap[key], key 268 | 269 | 270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 271 | idx = start_idx 272 | len_ = len(toks) 273 | 274 | isBlock = False 275 | if toks[idx] == '(': 276 | isBlock = True 277 | idx += 1 278 | 279 | if toks[idx] == 'select': 280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 281 | elif "\"" in toks[idx]: # token is a string value 282 | val = toks[idx] 283 | idx += 1 284 | else: 285 | try: 286 | val = float(toks[idx]) 287 | idx += 1 288 | except: 289 | end_idx = idx 290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: 292 | end_idx += 1 293 | 294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 295 | idx = end_idx 296 | 297 | if isBlock: 298 | assert toks[idx] == ')' 299 | idx += 1 300 | 301 | return idx, val 302 | 303 | 304 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 305 | idx = start_idx 306 | len_ = len(toks) 307 | conds = [] 308 | 309 | while idx < len_: 310 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 311 | not_op = False 312 | if toks[idx] == 'not': 313 | not_op = True 314 | idx += 1 315 | 316 | assert idx < len_ and 
toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 317 | op_id = WHERE_OPS.index(toks[idx]) 318 | idx += 1 319 | val1 = val2 = None 320 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 321 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 322 | assert toks[idx] == 'and' 323 | idx += 1 324 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 325 | else: # normal case: single value 326 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 327 | val2 = None 328 | 329 | conds.append((not_op, op_id, val_unit, val1, val2)) 330 | 331 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): 332 | break 333 | 334 | if idx < len_ and toks[idx] in COND_OPS: 335 | conds.append(toks[idx]) 336 | idx += 1 # skip and/or 337 | 338 | return idx, conds 339 | 340 | 341 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 342 | idx = start_idx 343 | len_ = len(toks) 344 | 345 | assert toks[idx] == 'select', "'select' not found" 346 | idx += 1 347 | isDistinct = False 348 | if idx < len_ and toks[idx] == 'distinct': 349 | idx += 1 350 | isDistinct = True 351 | val_units = [] 352 | 353 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 354 | agg_id = AGG_OPS.index("none") 355 | if toks[idx] in AGG_OPS: 356 | agg_id = AGG_OPS.index(toks[idx]) 357 | idx += 1 358 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 359 | val_units.append((agg_id, val_unit)) 360 | if idx < len_ and toks[idx] == ',': 361 | idx += 1 # skip ',' 362 | 363 | return idx, (isDistinct, val_units) 364 | 365 | 366 | def parse_from(toks, start_idx, tables_with_alias, schema): 367 | """ 368 | Assume in the from clause, all table units are combined with join 369 | """ 370 | assert 'from' in toks[start_idx:], "'from' not found" 371 | 372 | len_ = 
len(toks) 373 | idx = toks.index('from', start_idx) + 1 374 | default_tables = [] 375 | table_units = [] 376 | conds = [] 377 | 378 | while idx < len_: 379 | isBlock = False 380 | if toks[idx] == '(': 381 | isBlock = True 382 | idx += 1 383 | 384 | if toks[idx] == 'select': 385 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 386 | table_units.append((TABLE_TYPE['sql'], sql)) 387 | else: 388 | if idx < len_ and toks[idx] == 'join': 389 | idx += 1 # skip join 390 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 391 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 392 | default_tables.append(table_name) 393 | if idx < len_ and toks[idx] == "on": 394 | idx += 1 # skip on 395 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 396 | if len(conds) > 0: 397 | conds.append('and') 398 | conds.extend(this_conds) 399 | 400 | if isBlock: 401 | assert toks[idx] == ')' 402 | idx += 1 403 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 404 | break 405 | 406 | return idx, table_units, conds, default_tables 407 | 408 | 409 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 410 | idx = start_idx 411 | len_ = len(toks) 412 | 413 | if idx >= len_ or toks[idx] != 'where': 414 | return idx, [] 415 | 416 | idx += 1 417 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 418 | return idx, conds 419 | 420 | 421 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 422 | idx = start_idx 423 | len_ = len(toks) 424 | col_units = [] 425 | 426 | if idx >= len_ or toks[idx] != 'group': 427 | return idx, col_units 428 | 429 | idx += 1 430 | assert toks[idx] == 'by' 431 | idx += 1 432 | 433 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 434 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 435 | 
col_units.append(col_unit) 436 | if idx < len_ and toks[idx] == ',': 437 | idx += 1 # skip ',' 438 | else: 439 | break 440 | 441 | return idx, col_units 442 | 443 | 444 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 445 | idx = start_idx 446 | len_ = len(toks) 447 | val_units = [] 448 | order_type = 'asc' # default type is 'asc' 449 | 450 | if idx >= len_ or toks[idx] != 'order': 451 | return idx, val_units 452 | 453 | idx += 1 454 | assert toks[idx] == 'by' 455 | idx += 1 456 | 457 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 458 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 459 | val_units.append(val_unit) 460 | if idx < len_ and toks[idx] in ORDER_OPS: 461 | order_type = toks[idx] 462 | idx += 1 463 | if idx < len_ and toks[idx] == ',': 464 | idx += 1 # skip ',' 465 | else: 466 | break 467 | 468 | return idx, (order_type, val_units) 469 | 470 | 471 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 472 | idx = start_idx 473 | len_ = len(toks) 474 | 475 | if idx >= len_ or toks[idx] != 'having': 476 | return idx, [] 477 | 478 | idx += 1 479 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 480 | return idx, conds 481 | 482 | 483 | def parse_limit(toks, start_idx): 484 | idx = start_idx 485 | len_ = len(toks) 486 | 487 | if idx < len_ and toks[idx] == 'limit': 488 | idx += 2 489 | # make limit value can work, cannot assume put 1 as a fake limit number 490 | if type(toks[idx-1]) != int: 491 | return idx, 1 492 | 493 | return idx, int(toks[idx-1]) 494 | 495 | return idx, None 496 | 497 | 498 | def parse_sql(toks, start_idx, tables_with_alias, schema): 499 | isBlock = False # indicate whether this is a block of sql/sub-sql 500 | len_ = len(toks) 501 | idx = start_idx 502 | 503 | sql = {} 504 | if toks[idx] == '(': 505 | isBlock = True 506 | idx += 1 507 | 508 | # parse from clause in order 
to get default tables 509 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) 510 | sql['from'] = {'table_units': table_units, 'conds': conds} 511 | # select clause 512 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) 513 | idx = from_end_idx 514 | sql['select'] = select_col_units 515 | # where clause 516 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 517 | sql['where'] = where_conds 518 | # group by clause 519 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) 520 | sql['groupBy'] = group_col_units 521 | # having clause 522 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) 523 | sql['having'] = having_conds 524 | # order by clause 525 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) 526 | sql['orderBy'] = order_col_units 527 | # limit clause 528 | idx, limit_val = parse_limit(toks, idx) 529 | sql['limit'] = limit_val 530 | 531 | idx = skip_semicolon(toks, idx) 532 | if isBlock: 533 | assert toks[idx] == ')' 534 | idx += 1 # skip ')' 535 | idx = skip_semicolon(toks, idx) 536 | 537 | # intersect/union/except clause 538 | for op in SQL_OPS: # initialize IUE 539 | sql[op] = None 540 | if idx < len_ and toks[idx] in SQL_OPS: 541 | sql_op = toks[idx] 542 | idx += 1 543 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 544 | sql[sql_op] = IUE_sql 545 | return idx, sql 546 | 547 | 548 | def load_data(fpath): 549 | with open(fpath) as f: 550 | data = json.load(f) 551 | return data 552 | 553 | 554 | def get_sql(schema, query): 555 | toks = tokenize(query) 556 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 557 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 558 | 559 | return sql 560 | 561 | 562 | def skip_semicolon(toks, start_idx): 563 | idx = start_idx 564 | while idx < 
len(toks) and toks[idx] == ";": 565 | idx += 1 566 | return idx 567 | -------------------------------------------------------------------------------- /evaluation_bird_ex_ves.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | db_root_path="./data/bird/dev/dev_databases/" 4 | data_mode="dev" 5 | diff_json_path="./data/bird/dev/dev.json" 6 | predicted_sql_json_path="./outputs/bird/predict_dev.json" 7 | ground_truth_sql_path="./data/bird/dev/dev_gold.sql" 8 | num_cpus=12 9 | meta_time_out=30.0 10 | time_out=60 11 | mode_gt="gt" 12 | mode_predict="gpt" 13 | 14 | # evaluate EX 15 | echo "Evaluate BIRD EX begin!" 16 | python ./evaluation/evaluation_bird_ex.py --db_root_path $db_root_path \ 17 | --predicted_sql_json_path $predicted_sql_json_path \ 18 | --data_mode $data_mode \ 19 | --ground_truth_sql_path $ground_truth_sql_path \ 20 | --num_cpus $num_cpus \ 21 | --mode_predict $mode_predict \ 22 | --diff_json_path $diff_json_path \ 23 | --meta_time_out $meta_time_out 24 | echo "Evaluate EX done!" 25 | 26 | # evaluate VES 27 | echo "Evaluate BIRD VES begin!" 28 | python ./evaluation/evaluation_bird_ves.py \ 29 | --db_root_path $db_root_path \ 30 | --predicted_sql_json_path $predicted_sql_json_path \ 31 | --data_mode $data_mode \ 32 | --ground_truth_sql_path $ground_truth_sql_path \ 33 | --num_cpus $num_cpus --meta_time_out $time_out \ 34 | --mode_gt $mode_gt --mode_predict $mode_predict \ 35 | --diff_json_path $diff_json_path 36 | echo "Evaluate VES done!" 
-------------------------------------------------------------------------------- /main_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__init__.py -------------------------------------------------------------------------------- /main_scripts/__pycache__/MAG.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/MAG.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/__pycache__/bridge_content_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/bridge_content_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/__pycache__/chat_manager.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/chat_manager.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/__pycache__/const.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/const.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/__pycache__/llm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/llm.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /main_scripts/bridge_content_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Encode DB content. 
8 | """ 9 | 10 | import difflib 11 | from typing import List, Optional, Tuple 12 | from rapidfuzz import fuzz 13 | import sqlite3 14 | import functools 15 | 16 | # fmt: off 17 | _stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again', 18 | 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but', 19 | "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is', 20 | "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being', 21 | 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be', 22 | 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor', 23 | 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because', 24 | 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have', 25 | 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours', 26 | 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his', 27 | "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what', 28 | 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan', 29 | 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here', 30 | 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself', 31 | 'then', 'did', 'just', "aren't"} 32 | # fmt: on 33 | 34 | _commonwords = {"no", "yes", "many"} 35 | 36 | 37 | def is_number(s: str) -> bool: 38 | try: 39 | float(s.replace(",", "")) 40 | return True 41 | except: 42 | return False 43 | 44 | 45 | def is_stopword(s: str) -> bool: 46 | return s.strip() in 
_stopwords 47 | 48 | 49 | def is_commonword(s: str) -> bool: 50 | return s.strip() in _commonwords 51 | 52 | 53 | def is_common_db_term(s: str) -> bool: 54 | return s.strip() in ["id"] 55 | 56 | 57 | class Match(object): 58 | def __init__(self, start: int, size: int) -> None: 59 | self.start = start 60 | self.size = size 61 | 62 | 63 | def is_span_separator(c: str) -> bool: 64 | return c in "'\"()`,.?! " 65 | 66 | 67 | def split(s: str) -> List[str]: 68 | return [c.lower() for c in s.strip()] 69 | 70 | 71 | def prefix_match(s1: str, s2: str) -> bool: 72 | i, j = 0, 0 73 | for i in range(len(s1)): 74 | if not is_span_separator(s1[i]): 75 | break 76 | for j in range(len(s2)): 77 | if not is_span_separator(s2[j]): 78 | break 79 | if i < len(s1) and j < len(s2): 80 | return s1[i] == s2[j] 81 | elif i >= len(s1) and j >= len(s2): 82 | return True 83 | else: 84 | return False 85 | 86 | 87 | def get_effective_match_source(s: str, start: int, end: int) -> Match: 88 | _start = -1 89 | 90 | for i in range(start, start - 2, -1): 91 | if i < 0: 92 | _start = i + 1 93 | break 94 | if is_span_separator(s[i]): 95 | _start = i 96 | break 97 | 98 | if _start < 0: 99 | return None 100 | 101 | _end = -1 102 | for i in range(end - 1, end + 3): 103 | if i >= len(s): 104 | _end = i - 1 105 | break 106 | if is_span_separator(s[i]): 107 | _end = i 108 | break 109 | 110 | if _end < 0: 111 | return None 112 | 113 | while _start < len(s) and is_span_separator(s[_start]): 114 | _start += 1 115 | while _end >= 0 and is_span_separator(s[_end]): 116 | _end -= 1 117 | 118 | return Match(_start, _end - _start + 1) 119 | 120 | 121 | def get_matched_entries( 122 | s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85 123 | ) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]: 124 | if not field_values: 125 | return None 126 | 127 | if isinstance(s, str): 128 | n_grams = split(s) 129 | else: 130 | n_grams = s 131 | 132 | matched = dict() 133 | for field_value in 
field_values: 134 | if not isinstance(field_value, str): 135 | continue 136 | fv_tokens = split(field_value) 137 | sm = difflib.SequenceMatcher(None, n_grams, fv_tokens) 138 | match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens)) 139 | if match.size > 0: 140 | source_match = get_effective_match_source( 141 | n_grams, match.a, match.a + match.size 142 | ) 143 | if source_match: # and source_match.size > 1 144 | match_str = field_value[match.b : match.b + match.size] 145 | source_match_str = s[ 146 | source_match.start : source_match.start + source_match.size 147 | ] 148 | c_match_str = match_str.lower().strip() 149 | c_source_match_str = source_match_str.lower().strip() 150 | c_field_value = field_value.lower().strip() 151 | if c_match_str and not is_common_db_term(c_match_str): # and not is_number(c_match_str) 152 | if ( 153 | is_stopword(c_match_str) 154 | or is_stopword(c_source_match_str) 155 | or is_stopword(c_field_value) 156 | ): 157 | continue 158 | if c_source_match_str.endswith(c_match_str + "'s"): 159 | match_score = 1.0 160 | else: 161 | if prefix_match(c_field_value, c_source_match_str): 162 | match_score = fuzz.ratio(c_field_value, c_source_match_str) / 100 163 | else: 164 | match_score = 0 165 | if ( 166 | is_commonword(c_match_str) 167 | or is_commonword(c_source_match_str) 168 | or is_commonword(c_field_value) 169 | ) and match_score < 1: 170 | continue 171 | s_match_score = match_score 172 | if match_score >= m_theta and s_match_score >= s_theta: 173 | if field_value.isupper() and match_score * s_match_score < 1: 174 | continue 175 | matched[match_str] = ( 176 | field_value, 177 | source_match_str, 178 | match_score, 179 | s_match_score, 180 | match.size, 181 | ) 182 | 183 | if not matched: 184 | return None 185 | else: 186 | return sorted( 187 | matched.items(), 188 | key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]), 189 | reverse=True, 190 | ) 191 | 192 | 193 | @functools.lru_cache(maxsize=1000, typed=False) 194 | def 
get_column_picklist(table_name: str, column_name: str, db_path: str) -> list: 195 | fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name) 196 | try: 197 | conn = sqlite3.connect(db_path) 198 | conn.text_factory = bytes 199 | c = conn.cursor() 200 | c.execute(fetch_sql) 201 | picklist = set() 202 | for x in c.fetchall(): 203 | if isinstance(x[0], str): 204 | picklist.add(x[0].encode("utf-8")) 205 | elif isinstance(x[0], bytes): 206 | try: 207 | picklist.add(x[0].decode("utf-8")) 208 | except UnicodeDecodeError: 209 | picklist.add(x[0].decode("latin-1")) 210 | else: 211 | picklist.add(x[0]) 212 | picklist = list(picklist) 213 | except Exception as e: 214 | picklist = [] 215 | finally: 216 | conn.close() 217 | return picklist 218 | 219 | 220 | def get_database_matches( 221 | question: str, 222 | table_name: str, 223 | column_name: str, 224 | db_path: str, 225 | top_k_matches: int = 2, 226 | match_threshold: float = 0.85, 227 | ) -> List[str]: 228 | picklist = get_column_picklist( 229 | table_name=table_name, column_name=column_name, db_path=db_path 230 | ) 231 | # only maintain data in ``str'' type 232 | picklist = [ele.strip() for ele in picklist if isinstance(ele, str)] 233 | # picklist is unordered, we sort it to ensure the reproduction stability 234 | picklist = sorted(picklist) 235 | 236 | matches = [] 237 | if picklist and isinstance(picklist[0], str): 238 | matched_entries = get_matched_entries( 239 | s=question, 240 | field_values=picklist, 241 | m_theta=match_threshold, 242 | s_theta=match_threshold, 243 | ) 244 | 245 | if matched_entries: 246 | num_values_inserted = 0 247 | for _match_str, ( 248 | field_value, 249 | _s_match_str, 250 | match_score, 251 | s_match_score, 252 | _match_size, 253 | ) in matched_entries: 254 | if "name" in column_name and match_score * s_match_score < 1: 255 | continue 256 | if table_name != "sqlite_sequence": # Spider database artifact 257 | matches.append(field_value.strip()) 258 | num_values_inserted += 1 
259 | if num_values_inserted >= top_k_matches: 260 | break 261 | return matches -------------------------------------------------------------------------------- /main_scripts/chat_manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from main_scripts.MAG import Soft_Schema_linker, Decomposer, Generator, Refiner 3 | from main_scripts.const import MAX_ROUND, SYSTEM_NAME, SCHEMALINKER_NAME, DECOMPOSER_NAME, REFINER_NAME 4 | 5 | INIT_LOG__PATH_FUNC = None 6 | LLM_API_FUC = None 7 | try: 8 | from main_scripts import llm 9 | LLM_API_FUC = llm.safe_call_llm 10 | INIT_LOG__PATH_FUNC = llm.init_log_path 11 | print(f"Use func from main_scripts.llm in chat_manager.py") 12 | except: 13 | print("Please prepare the code related to the large language model.") 14 | import time 15 | from pprint import pprint 16 | 17 | 18 | class ChatManager(object): 19 | def __init__(self, data_path: str, tables_json_path: str, log_path: str, model_name: str, dataset_name:str, dataset_path:str, lazy: bool=False, without_selector: bool=False): 20 | self.data_path = data_path # root path to database dir, including all databases 21 | self.tables_json_path = tables_json_path # path to table description json file 22 | self.log_path = log_path # path to record important printed content during running 23 | self.model_name = model_name # name of base LLM called by agent 24 | self.dataset_name = dataset_name 25 | self.ping_network() 26 | 27 | self.chat_group = [ 28 | Soft_Schema_linker(data_path=self.data_path, tables_json_path=self.tables_json_path, model_name=self.model_name, dataset_name=dataset_name, dataset_path=dataset_path, lazy=lazy, without_selector=without_selector), 29 | Decomposer(), 30 | Generator(dataset_name=dataset_name), 31 | Refiner(data_path=self.data_path, dataset_name=dataset_name) 32 | ] 33 | INIT_LOG__PATH_FUNC(log_path) 34 | 35 | def ping_network(self): 36 | # check network status 37 | print("Checking network 
status...", flush=True) 38 | try: 39 | _ = LLM_API_FUC("Hello world!") 40 | print("Network is available", flush=True) 41 | except Exception as e: 42 | raise Exception(f"Network is not available: {e}") 43 | 44 | def _chat_single_round(self, message: dict): 45 | # we use `dict` type so value can be changed in the function 46 | for agent in self.chat_group: # check each agent in the group 47 | if message['send_to'] == agent.name: 48 | agent.talk(message) 49 | 50 | def start(self, user_message: dict): 51 | # we use `dict` type so value can be changed in the function 52 | start_time = time.time() 53 | if user_message['send_to'] == SYSTEM_NAME: # in the first round, pass message to prune 54 | #user_message['send_to'] = SELECTOR_NAME 55 | user_message['send_to'] = SCHEMALINKER_NAME 56 | for _ in range(MAX_ROUND): # start chat in group 57 | self._chat_single_round(user_message) 58 | if user_message['send_to'] == SYSTEM_NAME: # should terminate chat 59 | break 60 | end_time = time.time() 61 | exec_time = end_time - start_time 62 | print(f"\033[0;34mExecute {exec_time} seconds\033[0m", flush=True) 63 | -------------------------------------------------------------------------------- /main_scripts/const.py: -------------------------------------------------------------------------------- 1 | MAX_ROUND = 8 # MAX_ROUND + 2 is the max try times of one task 2 | 3 | 4 | SCHEMALINKER_NAME = 'Soft_Schema_Linker' 5 | DECOMPOSER_NAME = 'Decomposer' 6 | GENERATOR_NAME = 'Generator' 7 | REFINER_NAME = 'Refiner' 8 | SYSTEM_NAME = 'System' 9 | 10 | 11 | schema_linker_template = """ 12 | As an experienced and professional database administrator, your task is to analyze a user question and a database schema to provide relevant information. The database schema consists of table descriptions, each containing multiple column descriptions. Your goal is to extract the entities from question and identify the relevant tables and columns based on these entities and the evidence provided. 
13 | 14 | [Instruction]: 15 | 1. Extract the mentioned entities from the user question. Make sure all of the entities are extracted. 16 | 2. For each entity, keep at least 3 related columns. 17 | 4. Your output should include entity extraction, analysis and related database schema. 18 | 5. The related database schema should be in JSON format. 19 | 6. Each column's information in provided 【Schema】 is in this format: (column, description. Value examples) 20 | 21 | [Requirements]: 22 | 1. Sort the related columns in each list corresponding to each entity in descending order of relevance. 23 | 2. The chosen columns should be in this format: . 24 | 3. Make sure each chosen list is not empty. The value [] will be punished. 25 | 4.【Matched values】 may contain redundant or useless information in addition to the correct matching values, so you need to select the useful information in conjunction with the specific column names and descriptions. 26 | 5. An entity may not have a corresponding evidence, which requires you to find the relevant columns yourself through your understanding of the database schema. 27 | 6. There are columns with similar names but different meanings, you can find the most accurate columns in the right table based on the 【Summary of each table】. 28 | 29 | Here is a typical example: 30 | 31 | ========== 32 | 【DB_ID】 banking_system 33 | 【Schema】 34 | # Table: account 35 | [ 36 | (account_id , the id of the account. Value examples: [11382, 11362, 2, 1, 2367].), 37 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].), 38 | (frequency , frequency of the acount. Value examples: ['POPLATEK MESICNE', 'POPLATEK TYDNE', 'POPLATEK PO OBRATU'].), 39 | (date , the creation date of the account. Value examples: ['1997-12-29', '1997-12-28'].) 40 | ] 41 | # Table: client 42 | [ 43 | (client_id , the unique number. Value examples: [13998, 13971, 2, 1, 2839].), 44 | (gender , gender. Value examples: ['M', 'F']. And F:female . 
M:male ), 45 | (birth_date , birth date. Value examples: ['1987-09-27', '1986-08-13'].), 46 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].), 47 | (first_name , first_name.), 48 | (last_name , last_name.) 49 | ] 50 | # Table: loan 51 | [ 52 | (loan_id , the id number identifying the loan data. Value examples: [4959, 4960, 4961].), 53 | (account_id , the id number identifying the account. Value examples: [10, 80, 55, 43].), 54 | (date , the date when the loan is approved. Value examples: ['1998-07-12', '1998-04-19'].), 55 | (amount , the id number identifying the loan data. Value examples: [1567, 7877, 9988].), 56 | (duration , the id number identifying the loan data. Value examples: [60, 48, 24, 12, 36].), 57 | (payments , the id number identifying the loan data. Value examples: [3456, 8972, 9845].), 58 | (status , the id number identifying the loan data. Value examples: ['C', 'A', 'D', 'B'].) 59 | ] 60 | # Table: district 61 | [ 62 | (district_id , location of branch. Value examples: [77, 76].), 63 | (A2 , area in square kilometers. Value examples: [50.5, 48.9].), 64 | (A4 , number of inhabitants. Value examples: [95907, 95616].), 65 | (A5 , number of households. Value examples: [35678, 34892].), 66 | (A6 , literacy rate. Value examples: [95.6, 92.3, 89.7].), 67 | (A7 , number of entrepreneurs. Value examples: [1234, 1456].), 68 | (A8 , number of cities. Value examples: [5, 4].), 69 | (A9 , number of schools. Value examples: [15, 12, 10].), 70 | (A10 , number of hospitals. Value examples: [8, 6, 4].), 71 | (A11 , average salary. Value examples: [12541.5, 11277].), 72 | (A12 , poverty rate. Value examples: [12.4, 9.8].), 73 | (A13 , unemployment rate. Value examples: [8.2, 7.9].), 74 | (A15 , number of crimes. Value examples: [256, 189].) 
75 | ] 76 | 【Primary keys】 77 | account.`account_id` | client.`client_id` | loan.`loan_id` | district.`district_id` 78 | 【Foreign keys】 79 | client.`district_id` = district.`district_id` 80 | 【Question】 81 | What is the gender of the youngest client who opened account in the lowest average salary branch and when did this client open the account? Please list their full name. 82 | 【Evidence】 83 | Later birthdate refers to younger age; A11 refers to average salary; Full name refers to first_name, last_name 84 | 【Matched values】 85 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you in selecting columns (You need to ignore matches that are not relevant to the question): 86 | No matched values. 87 | 【Summary of each table】 88 | # account: Specific information for each account 89 | # client: Basic information about each client 90 | # loan: Detailed records of each loan 91 | # district: Various data recorded in each district 92 | 【Answer】 93 | The entities extracted from the 【Question】 are: 94 | 1. gender; 95 | 2. youngest client; 96 | 3. account; 97 | 4. lowest average salary branch; 98 | 5. when did this client open the account; 99 | 6. full name 100 | 101 | Extract the related evidence in 【Evidence】 for entities if 【Evidence】 is not None: 102 | 1. gender --no related evidence **the columns about gender are related** 103 | 2. youngest client --Later birthdate refers to younger age; **the columns about birth date of client and the id of client are related** 104 | 3. account --no related evidence **the columns about account or account ids are related** 105 | 4. lowest average salary branch --A11 refers to average salary **the columns about average salary and branch are related** 106 | 5. when did this client open the account --no related evidence **the columns about time or date of the account are related** 107 | 6. 
full name --Full name refers to first_name, last_name **the columns about first_name, last_name of the client are related** 108 | 109 | Therefore, we can select the related database schema based on these entities with 【Evidence】: 110 | ```json 111 | {{ 112 | "gender": ["client.gender","client.client_id","loan.status"], 113 | "youngest client": ["client.birth_date","client.client_id","account.date","loan.date"], 114 | "account": ["account.account_id","loan.account_id","account.date"], 115 | "lowest average salary branch": ["district.A11","district.district_id","district.A13"], 116 | "when did this client open the account": ["account.date","loan.date","client.birth_date"], 117 | "full name": ["client.first_name","client.last_name","client.client_id"] 118 | }} 119 | ``` 120 | Question Solved. 121 | 122 | ========== 123 | 124 | Here is a new example, please start answering: 125 | 126 | 【DB_ID】 {db_id} 127 | 【Schema】 128 | {desc_str} 129 | 【Primary keys】 130 | {pk_str} 131 | 【Foreign keys】 132 | {fk_str} 133 | 【Question】 134 | {query} 135 | 【Evidence】 136 | {evidence} 137 | 【Matched values】 138 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you in selecting columns (You need to ignore matches that are not relevant to the question): 139 | {matched_list} 140 | 【Summary of each table】 141 | {summary_str} 142 | 【Answer】 143 | """ 144 | 145 | summarizer_template = """ 146 | [instruction] 147 | Given the database schema, you need to summarise the data stored in each table in one sentence, based on the name of the table and the columns in the table. 148 | 149 | [Requirements] 150 | - Your output should be in json format 151 | 152 | Here is an example: 153 | ========== 154 | 155 | 【DB_ID】 banking_system 156 | 【Schema】 157 | # Table: account 158 | [ 159 | (account_id , the id of the account. 
Value examples: [11382, 11362, 2, 1, 2367].), 160 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].), 161 | (frequency , frequency of the acount. Value examples: ['POPLATEK MESICNE', 'POPLATEK TYDNE', 'POPLATEK PO OBRATU'].), 162 | (date , the creation date of the account. Value examples: ['1997-12-29', '1997-12-28'].) 163 | ] 164 | # Table: client 165 | [ 166 | (client_id , the unique number. Value examples: [13998, 13971, 2, 1, 2839].), 167 | (gender , gender. Value examples: ['M', 'F']. And F:female . M:male ), 168 | (birth_date , birth date. Value examples: ['1987-09-27', '1986-08-13'].), 169 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].), 170 | (first_name , first_name.), 171 | (last_name , last_name.) 172 | ] 173 | # Table: loan 174 | [ 175 | (loan_id , the id number identifying the loan data. Value examples: [4959, 4960, 4961].), 176 | (account_id , the id number identifying the account. Value examples: [10, 80, 55, 43].), 177 | (date , the date when the loan is approved. Value examples: ['1998-07-12', '1998-04-19'].), 178 | (amount , the id number identifying the loan data. Value examples: [1567, 7877, 9988].), 179 | (duration , the id number identifying the loan data. Value examples: [60, 48, 24, 12, 36].), 180 | (payments , the id number identifying the loan data. Value examples: [3456, 8972, 9845].), 181 | (status , the id number identifying the loan data. Value examples: ['C', 'A', 'D', 'B'].) 182 | ] 183 | # Table: district 184 | [ 185 | (district_id , location of branch. Value examples: [77, 76].), 186 | (A2 , area in square kilometers. Value examples: [50.5, 48.9].), 187 | (A4 , number of inhabitants. Value examples: [95907, 95616].), 188 | (A5 , number of households. Value examples: [35678, 34892].), 189 | (A6 , literacy rate. Value examples: [95.6, 92.3, 89.7].), 190 | (A7 , number of entrepreneurs. Value examples: [1234, 1456].), 191 | (A8 , number of cities. 
Value examples: [5, 4].), 192 | (A9 , number of schools. Value examples: [15, 12, 10].), 193 | (A10 , number of hospitals. Value examples: [8, 6, 4].), 194 | (A11 , average salary. Value examples: [12541.5, 11277].), 195 | (A12 , poverty rate. Value examples: [12.4, 9.8].), 196 | (A13 , unemployment rate. Value examples: [8.2, 7.9].), 197 | (A15 , number of crimes. Value examples: [256, 189].) 198 | ] 199 | 200 | 【Summary】 201 | ```json 202 | {{ 203 | "account":"Specific information for each account", 204 | "client":"Basic information about each client", 205 | "loan":"Detailed records of each loan", 206 | "district":"Various data recorded in each district", 207 | 208 | }} 209 | ``` 210 | ========== 211 | 212 | Here is a new case: 213 | 【DB_ID】 {db_id} 214 | 【Schema】 215 | {desc_str} 216 | 217 | 【Summary】 218 | """ 219 | 220 | 221 | subq_pattern = r"Sub question\s*\d+\s*:" 222 | 223 | 224 | pure_decomposer_template = """ 225 | [Instruction] 226 | Given a 【query】, you need to understanding the intent of Query, and then deceompose it into Targets and Conditions. Then you need to combine Targets and Conditions into Subquerys step by step. 227 | For the case where Conditions is NULL, consider Targets as the final Subquery directly. 228 | For the case where Conditions are not NULL, combine Targets and the first Condition to get the first Subquery, then combine this Subquery and the next Condition into a new Subquery until all Conditions are used (which means the content of the last Subquery and the original Query is the same). 229 | 230 | [Requirements] 231 | -Try not to overlap Targets and Conditions. 232 | -Make sure the decomposed Target and Condition can cover all of the information in Query. 233 | -Don't change any information (specific value) in Query! 234 | -Mark each Subquery with ## in front of it. 235 | 236 | Here are some examples: 237 | ========== 238 | 239 | 【Query】 240 | Show the stadium name and the number of concerts in each stadium. 
Please also list the year the stadium was built. 241 | 【Evidence】 242 | NULL 243 | 244 | 【Decomposition】 245 | Targets: List the stadium name, the year the stadium built and the number of concerts in each stadium 246 | Conditions: NULL 247 | 248 | Subqueries: 249 | 1. Since Conditions is NULL, the final Subquery is the Targets. 250 | ##Subquery: List the stadium name, the year the stadium built and the number of concerts in each stadium 251 | ========== 252 | 253 | 【Query】 254 | What is the qualification rate for the H-11 products produced in 2023/11/2? 255 | 【Evidence】 256 | qualification rate = `Numqualified(H-11)` / `production(H-11)` 257 | 258 | 【Decomposition】 259 | Targets: List the qualification rate for the H-11 Products 260 | Conditions: 261 | 1. produced in 2023/11/2 --Condition_1 262 | 263 | Subqueries: 264 | 1. Combine Targets and Conditon_1 to get the first Subquery. 265 | ##Subquery: List the qualification rate for the H-11 Products produced in 2023/11/2 266 | 267 | ========== 268 | 269 | 【Query】 270 | List the race of institutions in Alabama with number of students greater than the 90% of average number of students of all institutions? 271 | 【Evidence】 272 | Alabama refers to state = 'Alabama'; number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count 273 | 274 | 【Decomposition】 275 | Targets: List the race of institutions 276 | Conditions: 277 | 1. in Alabama --Condition_1 278 | 2. number of students greater than the 90% of average number of students of all institutions --Condition_2 279 | 280 | Subqueries: 281 | 1. Combine Targets and Condition_1 to get the first Subquery. 282 | ##Subquery: List the race of institution in Alabama 283 | 2. Conbine the first Subquery and Conditon_2 to get the seconed Subquery. 
284 | ##Subquery: List the race of institutions in Alabama with number of number of students greater than the 90% of average number of students of all institutions 285 | 286 | ========== 287 | 288 | 【Query】 289 | Which president got a higher approval rating, Joseph Biden or Donald Trump? 290 | 【Evidence】 291 | NULL 292 | 293 | 【Decomposition】 294 | Targets: List the name of the president who got a higher approval rating between Joseph Biden or Donald Trump 295 | Conditions: 296 | NULL 297 | 298 | Subqueries: 299 | 1. Since Conditions is NULL, the final Subquery is the Targets. 300 | ##Subquery: List the name of the president who got a higher approval rating between Joseph Biden or Donald Trump 301 | 302 | ========== 303 | 304 | 【Query】 305 | For movie id 1269, how many users, who was a paying subscriber and was eligible for trial when he rated the movie, gave the movie a rating score of less than or equal to 2? 306 | 【Evidence】 307 | NULL 308 | 309 | 【Decomposition】 310 | Targets: List the number of users 311 | Conditions: 312 | 1. was a paying subscriber --Condition_1 313 | 2. was eligible for trial --Condition_2 314 | 3. for movie id 1269, gave the movie a rating score of less than or equal to 2 --Condition3 315 | 316 | Subquerys: 317 | 1. Combine Targets and Conditon_1 to get the first Subquery. 318 | ##Subquery: List the number of users who was a paying subscriber 319 | 2. Combine the first Subquery and Condition_2 to get the second Subquery. 320 | ##Subquery: List the number of users who was a paying subscriber and was eligible for trial 321 | 3. Combine the second Subquery and Condition_3 to get the third Subquery. 
322 | ##Subquery: List the number of users who was a paying subscriber and was eligible for trial and gave the movie whose id is 1269 a rating score of less than or equal to 2 323 | 324 | ========== 325 | 326 | Here is a new query need to be decomposed: 327 | 328 | 【Query】 329 | {query} 330 | 【Evidence】 331 | {evidence} 332 | 333 | 【Decomposition】 334 | """ 335 | 336 | soft_schema_initial_generator_template = """ 337 | Given a 【Database schema】 description, a knowledge 【Evidence】and a 【Question】, you need to use valid SQLite and understand the database and knowledge so that you can generate a good SQL for the 【Question】. 338 | When generating SQL, we should always consider constraints: 339 | 【Constraints】 340 | - In `SELECT `, just select needed columns in the 【Question】 without any unnecessary column or value 341 | - In `FROM ` or `JOIN
`, do not include unnecessary table 342 | - If use max or min func, `JOIN
` FIRST, THEN use `SELECT MAX()` or `SELECT MIN()` 343 | - If [Value examples] of has 'None' or None, use `JOIN
` or `WHERE is NOT NULL` is better 344 | - If use `ORDER BY ASC|DESC`, add `GROUP BY ` before to select distinct values 345 | - If include more than one table, use `JOIN
` 346 | - If use `JOIN
`, the connected columns should be in the 【Foreign keys】 347 | - If evidence gives a formula for calculating a value, try to use that formula 348 | - If use `ORDER BY ASC LIMIT `, please use `ORDER BY ASC NULLS LAST LIMIT ` to make sure the null values will not be selected 349 | 350 | ========== 351 | 352 | 【Database schema】 353 | # stadium: [Stadium_ID (INTEGER), Location (TEXT), Name (TEXT)] 354 | # concert: [concert_ID (INTEGER), concert_Name (TEXT), Stadium_ID (INTEGER)] 355 | 【Primary keys】 356 | stadium.`Stadium_ID` | concert.`concert_ID` 357 | 【Foreign keys】 358 | concert.`Stadium_ID` = stadium.`Stadium_ID` 359 | 【Detailed descriptions of tables and columns】 360 | stadium.`Stadium_ID`: The column 'Stadium_ID' in Table has column descriptions of "stadium id". Value examples: [1, 2, 3, 4, 5, 6]. 361 | stadium.`Name`: The column 'Name' in Table has column descriptions of "name of stadium". Value examples: ["Stark's Park", 'Somerset Park', 'Recreation Park', 'Hampden Park', 'Glebe Park', 'Gayfield Park']. 362 | stadium.`Location`: The column 'Location' in Table has column descriptions of "location of stadium". Value examples: ['Stirling Albion', 'Raith Rovers', "Queen's Park", 'Peterhead', 'East Fife', 'Brechin City']. 363 | concert.`concert_ID`: The column 'concert_ID' in Table has column descriptions of "concert id". Value examples: [1, 2, 3, 4, 5, 6]. 364 | concert.`concert_Name`: The column 'concert_Name' in Table has column descriptions of "concert name". Value examples: ['Week 1', 'Week 2', 'Super bootcamp', 'Home Visits', 'Auditions']. 365 | concert.`Stadium_ID`: The column 'Stadium_ID' in Table has column descriptions of "stadium id". Value examples: [2, 9, 7, 10, 1]. 366 | 【Evidence】 367 | NULL 368 | 【Question】 369 | Show the stadium name and the number of concerts in each stadium. 
370 | 【Matched values】 371 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 372 | No matched values. 373 | 374 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step: 375 | 【Question】: Show the stadium name and the number of concerts in each stadium. 376 | Targets for `SELECT`: the stadium name and the number of concerts 377 | hints from 【Evidence】:NULL 378 | For the entities in the 【Question】, get corresponding columns from 【Database schema】 with hints: the stadium name refers to stadium.`Name`, the number of concerts refers to COUNT(concert.`concert_ID`), each stadium refers to stadium.`Stadium_ID` 379 | Connection of tables: include tables and , get connected keys from 【Foreign keys】: concert.`Stadium_ID` = stadium.`Stadium_ID` 380 | Final SQL: 381 | ```sql 382 | SELECT T1.`Name`, COUNT(T2.concert_ID) AS num_concerts 383 | FROM stadium AS T1 384 | JOIN concert AS T2 385 | ON T1.`Stadium_ID` = T2.`Stadium_ID` 386 | GROUP BY T1.`Stadium_ID` 387 | ``` 388 | 389 | Question Solved. 
390 | 391 | ========== 392 | 393 | 【Database schema】 394 | # country: [origin (INTEGER), country(TEXT)] 395 | # price: [ID (INTEGER), price (REAL)] 396 | # data: [ID (INTEGER), mpg (REAL), cylinders (INTEGER), displacement (TEXT), horsepower (REAL), weight (REAL), acceleration (REAL), model (TEXT), car_name (TEXT)] 397 | # production: [ID (INTEGER), model_year (INTEGER), country (INTEGER)] 398 | 【Primary keys】 399 | country.`origin` | price.`ID` | data.`ID` | production.`ID` 400 | 【Foreign keys】 401 | data.`ID` = price.`ID` 402 | production.`ID` = price.`ID` 403 | production.`ID` = data.`ID` 404 | production.`country` = country.`origin` 405 | 【Detailed descriptions of tables and columns】 406 | country.`origin`: The column 'origin' in Table has column descriptions of "the unique identifier for the origin country". Value examples: [1, 2, 3]. 407 | country.`country`: The column 'country' in Table has column descriptions of "the origin country of the car". Value examples: ['USA', 'Japan', 'Europe']. 408 | data.`horsepower`: The column 'horsepower' in Table has column descriptions of "horse power associated with the car". 409 | data.`acceleration`: The column 'acceleration' in Table has column descriptions of "acceleration of the car in miles per squared hour". 410 | production.`country`: The column 'country' in Table has column descriptions of "country id to which the car belongs". Value examples: [1, 2, 3]. 411 | 【Evidence】 412 | the fastest refers to max(horsepower); name of the car refers to car_name 413 | 【Question】 414 | What is the fastest car made by Japan? 
415 | 【Matched values】 416 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 417 | country.`country` = 'Japan' 418 | 419 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step: 420 | 【Question】: What is the fastest car made by Japan? 421 | Targets for `SELECT`: the name of the fastest car 422 | hints from 【Evidence】: the fastest refers to max(horsepower); name of the car refers to car_name 423 | For the entities in the 【Question】, get corresponding columns from 【Database schema】 with hints: car refers to data.`car_name`, the fastest refers to MAX(data.`horsepower`), Japan refers to country.`country` 424 | Connection of tables: includes tables and , since and are not connected directly, use as a bridge, get connected keys from 【Foreign keys】: production.`ID` = data.`ID`, production.`country` = country.`origin` 425 | Final SQL: 426 | ```sql 427 | SELECT T1.`car_name` 428 | FROM data AS T1 429 | INNER JOIN production AS T2 430 | ON T1.`ID` = T2.`ID` 431 | INNER JOIN country AS T3 432 | ON T3.`origin` = T2.`country` 433 | WHERE T3.`country` = 'Japan' 434 | ORDER BY T1.`horsepower` DESC LIMIT 1 435 | ``` 436 | 437 | Question Solved. 
438 | 439 | ========== 440 | 441 | 【Database schema】 442 | # institution_details: [unitid (INTEGER), chronname (TEXT), city (TEXT), state (TEXT), site (TEXT), student_count(TEXT)] 443 | # institution_grads: [unitid (INTEGER), gender (TEXT), race (TEXT), cohhort (TEXT)] 444 | 【Primary keys】 445 | institution_details.`unitid` | institution_grads.`unitid` 446 | 【Foreign keys】 447 | institution_grads.`unitid` = institution_details.`unitid` 448 | 【Detailed descriptions of tables and columns】 449 | institution_details.`unitid`: The column 'unitid' in Table has column descriptions of "Education Unit ID number". 450 | institution_details.`chronname`: The column 'chronname' in Table has column descriptions of "Institution name". 451 | institution_details.`student_count`: The column 'student_count' in Table has column descriptions of "Total number of undergraduates in 2010". 452 | 【Evidence】 453 | number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count 454 | 【Question】 455 | List the chronname of institutions with number of students greater than the 90% of average number of students of all institutions? 456 | 【Matched values】 457 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 458 | No matched values. 459 | 460 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step: 461 | 【Question】: List the chroname of institutions with number of students greater than the 90% of average number of students of all institutions? 
462 | Targets for `SELECT`: the chroname of institutions 463 | hints from 【Evidence】: number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count 464 | There is a formula in 【Evidence】, so we can turn it into SQL format: institution_details.`student_count` > ( SELECT AVG(`student_count`) * 0.9 FROM institution_details ) 465 | For the entities in the 【Question】, get corresponding columns from 【Database schema】 with hints: the chronname of institutions refers to institution_details.`chronname`, number of students refers to institution_details.`student_count` 466 | Connection of tables: include only one table , no connection 467 | Final SQL: 468 | ```sql 469 | SELECT DISTINCT `chroname` 470 | FROM institution_details 471 | WHERE `student_count` > ( 472 | SELECT AVG(`student_count`) * 0.9 FROM institution_details 473 | ) 474 | ``` 475 | 476 | Question Solved. 477 | 478 | ========== 479 | 480 | 【Database schema】 481 | {desc_str} 482 | 【Primary keys】 483 | {pk_str} 484 | 【Foreign keys】 485 | {fk_str} 486 | 【Detailed descriptions of tables and columns】 487 | {detailed_str} 488 | 【Evidence】 489 | {evidence} 490 | 【Question】 491 | {query} 492 | 【Matched values】 493 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 494 | {matched_list} 495 | 496 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step: 497 | """ 498 | 499 | soft_schema_continuous_generator_template = """ 500 | Given a 【Database schema】 description, a knowledge 【Evidence】, a 【Question】, a 【Subquesion】 extracted from 【Question】 and a 【Sub-SQL】 for the 【Subquestion】, you need to 
use valid SQLite and understand the database and knowledge, and then generate a complete SQL for the 【Question】 based on 【Sub-SQL】. 501 | When generating SQL, we should always consider constraints: 502 | 【Constraints】 503 | - In `SELECT `, just select needed columns in the 【Question】 without any unnecessary column or value 504 | - In `FROM
` or `JOIN
`, do not include unnecessary table 505 | - If use max or min func, `JOIN
` FIRST, THEN use `SELECT MAX()` or `SELECT MIN()` 506 | - If [Value examples] of has 'None' or None, use `JOIN
` or `WHERE is NOT NULL` is better 507 | - If use `ORDER BY ASC|DESC`, add `GROUP BY ` before to select distinct values 508 | - If include more than one table, use `JOIN
` 509 | - If use `JOIN
`, the connected columns should be in the 【Foreign keys】 510 | - If evidence gives a formula for calculating a value, try to use that formula 511 | - If use `ORDER BY ASC LIMIT `, please use `ORDER BY ASC NULLS LAST LIMIT ` to make sure the null values will not be selected 512 | 513 | ========== 514 | 515 | 【Database schema】 516 | # country: [origin (INTEGER), country(TEXT)] 517 | # price: [ID (INTEGER), price (REAL)] 518 | # data: [ID (INTEGER), mpg (REAL), cylinders (INTEGER), displacement (TEXT), horsepower (REAL), weight (REAL), acceleration (REAL), model (TEXT), car_name (TEXT)] 519 | # production: [ID (INTEGER), model_year (INTEGER), country (INTEGER)] 520 | 【Primary keys】 521 | country.`origin` | price.`ID` | data.`ID` | production.`ID` 522 | 【Foreign keys】 523 | data.`ID` = price.`ID` 524 | production.`ID` = price.`ID` 525 | production.`ID` = data.`ID` 526 | production.`country` = country.`origin` 527 | 【Detailed descriptions of tables and columns】 528 | country.`origin`: The column 'origin' in Table has column descriptions of "the unique identifier for the origin country". Value examples: [1, 2, 3]. 529 | country.`country`: The column 'country' in Table has column descriptions of "the origin country of the car". Value examples: ['USA', 'Japan', 'Europe']. 530 | data.`horsepower`: The column 'horsepower' in Table has column descriptions of "horse power associated with the car". 531 | data.`acceleration`: The column 'acceleration' in Table has column descriptions of "acceleration of the car in miles per squared hour". 532 | production.`country`: The column 'country' in Table has column descriptions of "country id to which the car belongs". Value examples: [1, 2, 3]. 533 | 【Evidence】 534 | the fastest refers to max(horsepower); made by Japan refers to country = 'Japan' 535 | 【Question】 536 | What is the price of the fastest car made by Japan? 537 | 【Subquestion】 538 | What is the price of the fastest car? 
539 | 【Sub-SQL】 540 | ``` 541 | SELECT T1.`price` FROM price AS T1 INNER JOIN data AS T2 ON T2.`ID` = T1.`ID` ORDER BY T2.`horsepower` DESC LIMIT 1 542 | ``` 543 | 【Matched values】 544 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 545 | No matched values. 546 | 547 | 【Subquestion】 is decomposed from 【Question】, now we provide Sub-SQL corresponding to Subquestion, you just need to complete the remaining conditions based on Sub-SQL to generate SQL for 【Question】. Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, and then generate SQL after thinking step by step: 548 | hints from 【Evidence】: the fastest refers to max(horsepower); made by Japan refers to country = 'Japan'; 549 | Find necessary columns from 【Database schema】 in addition to Sub-SQL: Japan refers to country.`country` 550 | Connection of tables: the connection of and is needed, since and are not connected directly, use as a bridge, get connected keys from 【Foreign keys】: production.`ID` = data.`ID`, production.`country` = country.`origin` 551 | Final SQL: 552 | ```sql 553 | SELECT T1.`price` 554 | FROM price AS T1 555 | INNER JOIN data AS T2 556 | on T2.`ID` = T1.`ID` 557 | INNER JOIN production AS T3 558 | ON T3.`ID` = T2.`ID` 559 | INNER JOIN country AS T4 560 | ON T4.`origin` = T3.`country` 561 | WHERE T4.`country` = 'Japan' 562 | ORDER BY T2.`horsepower` DESC LIMIT 1 563 | ``` 564 | 565 | Question solved. 
566 | 567 | ========== 568 | 569 | 【Database schema】 570 | # institution_details: [unitid (INTEGER), chronname (TEXT), city (TEXT), state (TEXT), site (TEXT), student_count(TEXT)] 571 | # institution_grads: [unitid (INTEGER), gender (TEXT), race (TEXT), cohhort (TEXT)] 572 | 【Primary keys】 573 | institution_details.`unitid` | institution_grads.`unitid` 574 | 【Foreign keys】 575 | institution_grads.`unitid` = institution_details.`unitid` 576 | 【Detailed descriptions of tables and columns】 577 | institution_details.`unitid`: The column 'unitid' in Table has column descriptions of "Education Unit ID number". 578 | institution_details.`state`: The column 'state' in Table has column descriptions of "Institution state". 579 | institution_details.`student_count`: The column 'student_count' in Table has column descriptions of "Total number of undergraduates in 2010". 580 | institution_grads.`race`: The column 'race' in Table has column descriptions of "race/ethnicity of students". Value examples: ['X', 'Ai', 'A', 'B', 'H']. 581 | 【Evidence】 582 | Alabama refers to state = 'Alabama'; number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count 583 | 【Question】 584 | List the race of institutions in Alabama with number of students greater than the 90% of average number of students of all institutions? 585 | 【Subquestion】 586 | List the race of institutions in 'Alabama'. 
587 | 【Sub-SQL】 588 | ``` 589 | SELECT DISTINCT T1.`race` FROM institution_grads AS T1 INNER JOIN institution_details AS T2 ON T1.`unitid` = T2.`unitid` WHERE T2.`state` = 'Alabama' 590 | ``` 591 | 【Matched values】 592 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 593 | institution_details.`city` = 'Alabama'; institution_details.`state` = 'Alabama' 594 | 595 | 【Subquestion】 is decomposed from 【Question】, now we provide Sub-SQL corresponding to Subquestion, you just need to complete the remaining conditions based on Sub-SQL to generate SQL for 【Question】. Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, and then generate SQL after thinking step by step: 596 | hints from 【Evidence】: Alabama refers to state = 'Alabama'; number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count 597 | There is a formula in 【Evidence】, so we can turn it into SQL format: institution_details.`student_count` > ( SELECT AVG(`student_count`) * 0.9 FROM institution_details ) 598 | Find necessary columns from 【Database schema】 in addition to Sub-SQL: number of students refers to institution_details.`student_count` 599 | Connection of tables: the table is already in Sub-SQL, no more connection is needed 600 | Final SQL: 601 | ```sql 602 | SELECT DISTINCT T1.`race` 603 | FROM institution_grads AS T1 604 | INNER JOIN institution_details AS T2 605 | ON T1.`unitid` = T2.`unitid` 606 | WHERE T2.`student_count` > ( 607 | SELECT AVG(`student_count`) * 0.9 FROM institution_details 608 | ) 609 | AND T2.`state` = 'Alabama' 610 | ``` 611 | 612 | Question solved. 
613 | 614 | ========== 615 | 616 | 【Database schema】 617 | {desc_str} 618 | 【Primary keys】 619 | {pk_str} 620 | 【Foreign keys】 621 | {fk_str} 622 | 【Detailed descriptions of tables and columns】 623 | {detailed_str} 624 | 【Evidence】 625 | {evidence} 626 | 【Question】 627 | {query} 628 | 【Subquestion】 629 | {subquery} 630 | 【Sub-SQL】 631 | {subsql} 632 | 【Matched values】 633 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question): 634 | {matched_list} 635 | 636 | 【Subquestion】 is decomposed from 【Question】, now we provide Sub-SQL corresponding to Subquestion, you just need to complete the remaining conditions based on Sub-SQL to generate SQL for 【Question】. Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, and then generate SQL after thinking step by step: 637 | """ 638 | 639 | 640 | 641 | baseline_template = """ 642 | Given a 【Database schema】 description, a knowledge 【Evidence】 and the 【Question】, you need to use valid SQLite and understand the database and knowledge, and then generate SQL. 643 | You can write answer in script blocks, and indicate script type in it, like this: 644 | ```sql 645 | SELECT column_a 646 | FROM table_b 647 | ``` 648 | 649 | 【Database schema】 650 | {desc_str} 651 | 【Question】 652 | {query} 653 | 【Evidence】 654 | {evidence} 655 | 【Answer】 656 | """ 657 | 658 | 659 | refiner_template = """ 660 | 【Instruction】 661 | When executing SQL below, some errors occurred, please fix up SQL based on query and database info. 662 | Solve the task step by step if you need to. Using SQL format in the code block, and indicate script type in the code block. 663 | When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. 
664 | 【Constraints】 665 | - The SQL should start with 'SELECT' 666 | - In `SELECT `, just select needed columns in the 【Question】 without any unnecessary column or value 667 | - In `FROM
` or `JOIN
`, do not include unnecessary table 668 | - If use `JOIN
`, the connected columns should be in the Foreign keys of 【Database schema】 669 | 【Response format】 670 | Your response should be in this format: 671 | Analysis: 672 | **[Your analysis]** 673 | Correct SQL: 674 | ```sql 675 | [the fixed SQL] 676 | ``` 677 | 【Attention】 678 | Only SQL statements are allowed in [the fixed SQL], do not add any comments. 679 | 680 | 【Evidence】 681 | {evidence} 682 | 【Query】 683 | -- {query} 684 | 【Database info】 685 | {desc_str} 686 | 【Primary keys】 687 | {pk_str} 688 | 【Foreign keys】 689 | {fk_str} 690 | 【Detailed descriptions of tables and columns]】 691 | {detailed_str} 692 | 【old SQL】 693 | ```sql 694 | {sql} 695 | ``` 696 | 【SQLite error】 697 | {sqlite_error} 698 | 【Exception class】 699 | {exception_class} 700 | 701 | Now please fixup old SQL and generate new SQL again. 702 | 【correct SQL】 703 | """ 704 | 705 | nested_refiner_template = """ 706 | 【Instruction】 707 | When executing SQL below, some errors occurred, please fix up SQL based on query and database info. 708 | Solve the task step by step if you need to. Using SQL format in the code block, and indicate script type in the code block. 709 | When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. 710 | 【Possible error】 711 | When using nested SQL, using the MIN or MAX function in the sub-SQL may result in null data in the end. Because when multiple tables are joined, they are connected by foreign keys and contain only some data in common, the maximum or minimum value of a column in a table may not be in the joined table. 712 | 【Constraints】 713 | - The SQL should start with 'SELECT' 714 | - In `SELECT `, just select needed columns in the 【Question】 without any unnecessary column or value 715 | - In `FROM
` or `JOIN
`, do not include unnecessary table 716 | - If use `JOIN
`, the connected columns should be in the Foreign keys of 【Database schema】 717 | - If use `ORDER BY ASC LIMIT `, please use `ORDER BY ASC NULLS LAST LIMIT ` to make sure the null values will not be selected 718 | - There are columns with similar names but different meanings, you can find the most accurate columns in the right table based on the 【Summary of each table】 719 | 【Response format】 720 | Your response should be in this format: 721 | Analysis: 722 | **[Your analysis]** 723 | Correct SQL: 724 | ```sql 725 | [the fixed SQL] 726 | ``` 727 | 【Attention】 728 | Only SQL statements are allowed in [the fixed SQL], do not add any comments. 729 | 【Typical examples】 730 | For the error called "no data selected", here are some typical example. 731 | 1. Wrong SQL: 732 | ``` 733 | SELECT T1.`date` FROM recordtime AS T1 INNER JOIN information AS T2 ON T1.`id` = T2.`id` WHERE T2.`viewers` = ( SELECT MAX(`viewers`) FROM information) 734 | ``` 735 | Analysis: **Using the MAX function in a nested SQL statement will result in a mismatch because the maximum or minimum value of the column may not be in the joined table, so use `ORDER BY`!** 736 | Correct SQL: 737 | ```sql 738 | SELECT T1.`date` 739 | FROM recordtime AS T1 740 | INNER JOIN information AS T2 741 | ON T1.`id` = T2.`id` 742 | ORDER BY T2.`viewers` DESC 743 | LIMIT 1 744 | ``` 745 | 2. Wrong SQL: 746 | ``` 747 | SELECT T1.`idea` FROM work AS T1 INNER JOIN student AS T2 ON T1.`ID` = T2.`ID` WHERE T2.`final_score` = ( SELECT MIN(`final_score`) FROM student) 748 | ``` 749 | Analysis: **Using the MIN function in a nested SQL statement will result in a mismatch because the maximum or minimum value of the column may not be in the joined table, so use `ORDER BY`!** 750 | Correct SQL: 751 | ```sql 752 | SELECT T1.`idea` 753 | FROM work AS T1 754 | INNER JOIN student AS T2 755 | ON T1.`ID` = T2.`ID` 756 | ORDER BY T2.`final_score` ASC 757 | LIMIT 1 758 | ``` 759 | 760 | Here is a new case for you. 
761 | 【Evidence】 762 | {evidence} 763 | 【Query】 764 | -- {query} 765 | 【Database info】 766 | {desc_str} 767 | 【Primary keys】 768 | {pk_str} 769 | 【Foreign keys】 770 | {fk_str} 771 | 【Detailed descriptions of tables and columns]】 772 | {detailed_str} 773 | 【Matched values】 774 | {matched_content} 775 | 【old SQL】 776 | ```sql 777 | {sql} 778 | ``` 779 | 【SQLite error】 780 | {sqlite_error} 781 | 【Exception class】 782 | {exception_class} 783 | 784 | Now please fixup old SQL and generate new SQL again. 785 | 【correct SQL】 786 | """ 787 | -------------------------------------------------------------------------------- /main_scripts/llm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import time 4 | 5 | import openai 6 | import os 7 | 8 | 9 | # Your api settings 10 | API_Key = "API Key" 11 | Base_url = "The base url of your API" 12 | 13 | #MODEL_NAME = 'gpt-4' 14 | MODEL_NAME = 'gpt-3.5-turbo-1106' 15 | #MODEL_NAME = 'gpt-4-0125-preview' 16 | MAX_TRY = 5 17 | 18 | # 用来传递外面的字典进来 19 | world_dict = {} 20 | 21 | log_path = None 22 | api_trace_json_path = None 23 | total_prompt_tokens = 0 24 | total_response_tokens = 0 25 | 26 | 27 | def init_log_path(my_log_path): 28 | global total_prompt_tokens 29 | global total_response_tokens 30 | global log_path 31 | global api_trace_json_path 32 | log_path = my_log_path 33 | total_prompt_tokens = 0 34 | total_response_tokens = 0 35 | dir_name = os.path.dirname(log_path) 36 | os.makedirs(dir_name, exist_ok=True) 37 | 38 | # 另外一个记录api调用的文件 39 | api_trace_json_path = os.path.join(dir_name, 'api_trace.json') 40 | 41 | 42 | def api_func(prompt:str): 43 | global MODEL_NAME 44 | print(f"\nUse OpenAI model: {MODEL_NAME}\n") 45 | 46 | client = openai.OpenAI(api_key=API_Key,base_url=Base_url) 47 | 48 | response = client.chat.completions.create( 49 | model=MODEL_NAME, 50 | messages=[{"role": "user", "content": prompt}], 51 | n = 1, 52 | temperature=0.1, 53 | max_tokens=1000 54 
| ) 55 | 56 | text = response.choices[0].message.content.strip() 57 | prompt_token = response.usage.prompt_tokens 58 | response_token = response.usage.completion_tokens 59 | 60 | 61 | return text, prompt_token, response_token 62 | 63 | 64 | def safe_call_llm(input_prompt, **kwargs) -> str: 65 | """ 66 | Function description: Given the prompt, return the content generated by LLM (internal automatic error retry 5 times, 5 errors throw an exception) 67 | """ 68 | global MODEL_NAME 69 | global log_path 70 | global api_trace_json_path 71 | global total_prompt_tokens 72 | global total_response_tokens 73 | global world_dict 74 | 75 | for i in range(5): 76 | try: 77 | if log_path is None: 78 | # print(input_prompt) 79 | sys_response, prompt_token, response_token = api_func(input_prompt) 80 | print(f"\nsys_response: \n{sys_response}") 81 | print(f'\n prompt_token,response_token: {prompt_token} {response_token}\n') 82 | else: 83 | # check log_path and api_trace_json_path is not None 84 | if (log_path is None) or (api_trace_json_path is None): 85 | raise FileExistsError('log_path or api_trace_json_path is None, init_log_path first!') 86 | with open(log_path, 'a+', encoding='utf8') as log_fp, open(api_trace_json_path, 'a+', encoding='utf8') as trace_json_fp: 87 | print('\n' + f'*'*20 +'\n', file=log_fp) 88 | print(input_prompt, file=log_fp) 89 | print('\n' + f'='*20 +'\n', file=log_fp) 90 | sys_response, prompt_token, response_token = api_func(input_prompt) 91 | print(sys_response, file=log_fp) 92 | print(f'\n prompt_token,response_token: {prompt_token} {response_token}\n', file=log_fp) 93 | 94 | if len(world_dict) > 0: 95 | world_dict = {} 96 | 97 | if len(kwargs) > 0: 98 | world_dict = {} 99 | for k, v in kwargs.items(): 100 | world_dict[k] = v 101 | # prompt response to world_dict 102 | world_dict['response'] = '\n' + sys_response.strip() + '\n' 103 | world_dict['input_prompt'] = input_prompt.strip() + '\n' 104 | 105 | world_dict['prompt_token'] = prompt_token 106 | 
world_dict['response_token'] = response_token 107 | 108 | 109 | total_prompt_tokens += prompt_token 110 | total_response_tokens += response_token 111 | 112 | world_dict['cur_total_prompt_tokens'] = total_prompt_tokens 113 | world_dict['cur_total_response_tokens'] = total_response_tokens 114 | 115 | # world_dict to json str 116 | world_json_str = json.dumps(world_dict, ensure_ascii=False) 117 | print(world_json_str, file=trace_json_fp) 118 | 119 | world_dict = {} 120 | world_json_str = '' 121 | 122 | print(f'\n total_prompt_tokens,total_response_tokens: {total_prompt_tokens} {total_response_tokens}\n', file=log_fp) 123 | print(f'\n total_prompt_tokens,total_response_tokens: {total_prompt_tokens} {total_response_tokens}\n') 124 | return sys_response 125 | except Exception as ex: 126 | print(ex) 127 | print(f'Request {MODEL_NAME} failed. try {i} times. Sleep 10 secs.') 128 | time.sleep(10) 129 | 130 | raise ValueError('safe_call_llm error!') 131 | 132 | 133 | if __name__ == "__main__": 134 | res = safe_call_llm('你能做好Text-to-SQL的任务嘛?') 135 | print(res) 136 | -------------------------------------------------------------------------------- /main_scripts/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | import random 5 | import json 6 | import time 7 | import sqlite3 8 | from main_scripts.const import subq_pattern 9 | from typing import Dict, List 10 | 11 | 12 | def is_valid_date(date_str): 13 | if (not isinstance(date_str, str)): 14 | return False 15 | date_str = date_str.split()[0] 16 | if len(date_str) != 10: 17 | return False 18 | pattern = r'^\d{4}-\d{2}-\d{2}$' 19 | if re.match(pattern, date_str): 20 | year, month, day = map(int, date_str.split('-')) 21 | if year < 1 or month < 1 or month > 12 or day < 1 or day > 31: 22 | return False 23 | else: 24 | return True 25 | else: 26 | return False 27 | 28 | 29 | def is_valid_date_column(col_value_lst): 30 | for col_value in 
col_value_lst: 31 | if not is_valid_date(col_value): 32 | return False 33 | return True 34 | 35 | 36 | def rename_file(file_path, new_name): 37 | """ 38 | 给定原文件路径和新文件名,重命名文件 39 | 40 | @param file_path: 原文件路径, 如: /home/user/test.txt 41 | @param new_name: 新文件名, 如: backup 42 | @return: 新文件路径 43 | """ 44 | # 获取文件的目录和后缀名 45 | dir_name = os.path.dirname(file_path) 46 | file_name, file_ext = os.path.splitext(os.path.basename(file_path)) 47 | 48 | # 获取当前时间戳 49 | timestamp = str(int(time.time())) 50 | 51 | # 构建新的文件名 52 | new_file_name = new_name + '_' + timestamp + file_ext 53 | 54 | # 构建新的文件路径 55 | new_file_path = os.path.join(dir_name, new_file_name) 56 | 57 | # 重命名文件 58 | os.rename(file_path, new_file_path) 59 | 60 | return new_file_path 61 | 62 | 63 | def is_email(string): 64 | pattern = r'^[\w\.-]+@[\w\.-]+\.\w+$' 65 | match = re.match(pattern, string) 66 | if match: 67 | return True 68 | else: 69 | return False 70 | 71 | 72 | 73 | def extract_world_info(message_dict: dict): 74 | info_dict = {} 75 | info_dict['idx'] = message_dict['idx'] 76 | info_dict['db_id'] = message_dict['db_id'] 77 | info_dict['query'] = message_dict['query'] 78 | info_dict['evidence'] = message_dict.get('evidence', '') 79 | info_dict['difficulty'] = message_dict.get('difficulty', '') 80 | info_dict['ground_truth'] = message_dict.get('ground_truth', '') 81 | info_dict['send_to'] = message_dict.get('send_to', '') 82 | return info_dict 83 | 84 | 85 | def replace_multiple_spaces(text): 86 | # 定义正则表达式,匹配多个空字符 87 | pattern = r'\s+' 88 | # 将多个空字符替换成一个空格 89 | new_text = re.sub(pattern, ' ', text) 90 | return new_text 91 | 92 | 93 | # SQL parsing 94 | def extract_table_names(sql_query): 95 | # 使用正则表达式提取FROM子句中的表名 96 | # 使用正则表达式提取FROM子句中的表名 97 | # 假设表名位于FROM关键字后面,且没有特殊字符或空格 98 | sql_query = sql_query.replace('`', '') 99 | table_names = re.findall(r'FROM\s+([\w]+)', sql_query, re.IGNORECASE) + \ 100 | re.findall(r'JOIN\s+([\w]+)', sql_query, re.IGNORECASE) 101 | return set(table_names) 102 | 103 | 104 | 
def _open_db(db_path):
    """Open a SQLite connection whose TEXT values are decoded leniently (undecodable bytes dropped)."""
    conn = sqlite3.connect(db_path)
    conn.text_factory = lambda b: b.decode(errors="ignore")
    return conn


def _list_user_tables(cursor):
    """Return all table names from sqlite_master except SQLite's internal `sqlite_sequence`."""
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    return [row[0] for row in cursor.fetchall() if row[0] != 'sqlite_sequence']


def _column_names(cursor, table_name):
    """Return the column names of `table_name` in declaration order."""
    cursor.execute(f"PRAGMA table_info(`{table_name}`)")
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk); index 1 is the name.
    return [cinfo[1] for cinfo in cursor.fetchall()]


def get_used_tables(sql, db_path) -> dict:  # table_name -> chosen columns & discarded columns
    """Build a schema dict for the tables referenced in `sql`.

    Table names are taken from the FROM/JOIN clauses of `sql` (via
    `extract_table_names`). Every column of each table is marked as chosen.

    Args:
        sql: SQL text to scan for table names.
        db_path: path to the SQLite database file.

    Returns:
        {table_name: {"chosen columns": [...all columns...], "discarded columns": []}}
    """
    table_names = extract_table_names(sql)
    conn = _open_db(db_path)
    try:
        cursor = conn.cursor()
        return {
            table_name: {
                "chosen columns": _column_names(cursor, table_name),
                "discarded columns": [],
            }
            for table_name in table_names
        }
    finally:
        # sqlite3's context manager only commits/rolls back; close explicitly to release the file.
        conn.close()


def get_all_tables(db_path) -> dict:
    """Build a schema dict covering every user table in the database, all columns chosen.

    Returns:
        {table_name: {"chosen columns": [...all columns...], "discarded columns": []}}
    """
    conn = _open_db(db_path)
    try:
        cursor = conn.cursor()
        return {
            table_name: {
                "chosen columns": _column_names(cursor, table_name),
                "discarded columns": [],
            }
            for table_name in _list_user_tables(cursor)
        }
    finally:
        conn.close()


# Lazily-loaded cache of the gold-schema annotations used by get_gold_columns.
gold_schema = []


def get_gold_columns(idx, db_path) -> dict:
    """Build a schema dict from gold column annotations plus up to 3 random distractor columns.

    The gold annotations are read once from "data/bird/dev_gold_schema.json"
    (relative to the working directory) and cached in the module-level
    `gold_schema` list.

    Args:
        idx: index of the example in the gold-schema list.
        db_path: path to the SQLite database file.

    Returns:
        {table_name: {"chosen columns": gold cols + <=3 random unused cols,
                      "discarded columns": []}} for every user table.
    """
    global gold_schema
    if not gold_schema:
        input_file = "data/bird/dev_gold_schema.json"
        with open(input_file, encoding='utf8') as f:
            gold_schema = json.load(f)
    table2cols = gold_schema[idx]["columns_map"]

    sch = {}
    conn = _open_db(db_path)
    try:
        cursor = conn.cursor()
        for table_name in _list_user_tables(cursor):
            all_columns = _column_names(cursor, table_name)
            # Gold annotations may carry backtick quoting; strip it to match raw column names.
            gold_columns = [str(item).replace('`', '') for item in table2cols.get(table_name, [])]
            gold_set = set(gold_columns)
            # Keep schema order before shuffling (a set's order depends on the hash seed),
            # so the sampled distractors are reproducible under random.seed().
            unused_columns = [col for col in all_columns if col not in gold_set]
            random.shuffle(unused_columns)
            sch[table_name] = {
                "chosen columns": gold_columns + unused_columns[:3],  # used golden cols + unused random 3 cols
                "discarded columns": [],
            }
    finally:
        conn.close()
    return sch


# GPT result parsing
# NOTE: JSON extraction from LLM output is provided by `parse_json`, defined further down
# in this module. An earlier always-`{}` debug stub that it shadowed has been removed.


# check if valid format
def check_selector_response(json_data: Dict) -> bool:
    """Validate a selector-agent JSON reply.

    Every value must be either one of the table flags 'keep_all' / 'drop_all'
    or a list. The first invalid entry is reported on stdout.

    Returns:
        True if all values are valid, otherwise False.
    """
    valid_flags = ('keep_all', 'drop_all')
    for value in json_data.values():
        if isinstance(value, list):
            continue
        if isinstance(value, str):
            if value in valid_flags:
                continue
            print(f"error: invalid table flag: {value}\n")
        else:
            print(f"error: invalid flag type: {value}\n")
        print(f"json_data: {json_data}\n\n")
        return False
    return True


def check_letter(s):
    """Return True if `s` contains at least one alphabetic character."""
    return any(char.isalpha() for char in s)


# A single-quoted literal preceded by whitespace, e.g. the `'Japan'` in "country = 'Japan'".
_QUOTED_VALUE_RE = re.compile(r"\s+'(.*?)'")


def contain_value(s):
    """Return the single-quoted literal values found in `s`, or None if there are none."""
    matches = _QUOTED_VALUE_RE.findall(s)
    return matches if matches else None


def get_files(root, suffix):
    """
    Recursively collect all files under a directory that end with a given suffix.
    :param root: directory to search, str, e.g. '.'
    :param suffix: file suffix, str, e.g. '.txt'
    :return: list of absolute file paths
    :raises FileNotFoundError: if `root` does not exist
    """
    import glob
    if not os.path.exists(root):
        raise FileNotFoundError(f'path {root} not found.')
    return [os.path.abspath(p) for p in glob.glob(f'{root}/**/*{suffix}', recursive=True)]


# read txt file to string list and strip empty lines
def read_txt_file(path):
    """Return the stripped, non-blank lines of a UTF-8 text file."""
    with open(path, 'r', encoding='utf-8') as f:
        print(f"load txt file from {path}")
        return [stripped for line in f if (stripped := line.strip())]


def load_json_file(path):
    """Load and return the JSON document stored in a UTF-8 file."""
    with open(path, 'r', encoding='utf-8') as f:
        print(f"load json file from {path}")
        return json.load(f)


def load_jsonl_file(path):
    """Load a JSON Lines file into a list of objects, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]
    print(f"load jsonl file from {path}")
    return data


def save_file(path, string_lst):
    """
    Write strings to a UTF-8 file, overwriting it.
    :param path: file path, str
    :param string_lst: list of strings; each must already carry its own newline
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(string_lst)
    print(f"save file to {path}")
| def save_json_file(path, data): 298 | with open(path, 'w', encoding='utf-8') as f: 299 | json.dump(data, f, indent=2, ensure_ascii=False) 300 | print(f"save json file to {path}") 301 | 302 | 303 | def save_jsonl_file(path, data): 304 | with open(path, 'w', encoding='utf-8') as f: 305 | for js in data: 306 | f.write(json.dumps(js, ensure_ascii=False) + '\n') 307 | print(f"save jsonl file to {path}") 308 | 309 | 310 | def parse_json(text: str) -> dict: 311 | # 查找字符串中的 JSON 块 312 | if "json" in text: 313 | start = text.find("```json") 314 | end = text.find("```", start + 7) 315 | # 如果找到了 JSON 块 316 | if start != -1 and end != -1: 317 | json_string = text[start + 7: end] 318 | print(json_string) 319 | try: 320 | # 解析 JSON 字符串 321 | json_data = json.loads(json_string) 322 | #valid = check_selector_response(json_data) 323 | return json_data 324 | except: 325 | print(f"error: parse json error!\n") 326 | print(f"json_string: {json_string}\n\n") 327 | pass 328 | elif "```" in text: 329 | start = text.find("```") 330 | end = text.find("```", start + 3) 331 | if start != -1 and end != -1: 332 | json_string = text[start + 3: end] 333 | 334 | try: 335 | # 解析 JSON 字符串 336 | json_data = json.loads(json_string) 337 | return json_data 338 | except: 339 | print(f"error: parse json error!\n") 340 | print(f"json_string: {json_string}\n\n") 341 | pass 342 | else: 343 | start = text.find("{") 344 | end = text.find("}", start + 1) 345 | if start != -1: 346 | json_string = text[start: end + 1] 347 | try: 348 | # 解析 JSON 字符串 349 | json_data = json.loads(json_string) 350 | return json_data 351 | except: 352 | print(f"error: parse json error!\n") 353 | print(f"json_string: {json_string}\n\n") 354 | pass 355 | """ 356 | # 如果找到了 JSON 块 357 | if start != -1 and end != -1: 358 | json_string = text[start + 7: end] 359 | 360 | try: 361 | # 解析 JSON 字符串 362 | json_data = json.loads(json_string) 363 | valid = check_selector_response(json_data) 364 | if valid: 365 | return json_data 366 | else: 367 
| return {} 368 | except: 369 | print(f"error: parse json error!\n") 370 | print(f"json_string: {json_string}\n\n") 371 | pass 372 | """ 373 | return {} 374 | 375 | def get_chosen_schema(raw_linked_schema:dict) -> dict: 376 | if raw_linked_schema == {}: 377 | return {} 378 | try: 379 | linked_schema = dict() 380 | for key,value in raw_linked_schema.items(): 381 | for column in value: 382 | table = column.split('.')[0] 383 | col = column.split('.')[1] 384 | if linked_schema.get(table) == None: 385 | linked_schema[table] = {col} 386 | else: 387 | linked_schema[table].add(col) 388 | for key in linked_schema.keys(): 389 | linked_schema[key] = list(linked_schema[key]) 390 | return linked_schema 391 | except: 392 | return {} 393 | 394 | def parse_sql(res: str) -> str: 395 | """Only need SQL(startswith `SELECT`) of LLM result""" 396 | if 'SELECT' not in res and 'select' not in res: 397 | res = 'SELECT ' + res 398 | # match = re.search(parse_pattern, res, re.IGNORECASE | re.DOTALL) 399 | # if match: 400 | # sql = match.group().strip() 401 | # sql = sql.replace('```', '') # TODO 402 | # sql = sql.replace('\n', ' ') # TODO 403 | # return True, sql 404 | # else: 405 | # return False, "" 406 | res = res.replace('\n', ' ') 407 | return res.strip() 408 | 409 | 410 | def parse_single_sql(res: str) -> str: # if do not need decompose, just one code block is OK! 
411 | """Return SQL in markdown block""" 412 | lines = res.split('\n') 413 | iter, start_idx, end_idx = -1, -1, -1 414 | for idx in range(iter + 1, len(lines)): 415 | if '```' in lines[idx]: 416 | start_idx = idx 417 | break 418 | if start_idx == -1: return "" 419 | for idx in range(start_idx + 1, len(lines)): 420 | if '```' in lines[idx]: 421 | end_idx = idx 422 | break 423 | if end_idx == -1: return "" 424 | 425 | return " ".join(lines[start_idx + 1: end_idx]) 426 | 427 | 428 | def parse_qa_pairs(res: str, end_pos=2333) -> list: 429 | lines = res.split('\n') 430 | qa_pairs = [] 431 | # end_pos = -1 432 | # for idx, line in enumerate(lines): 433 | # if 'final SQL' in line or 'final sql' in line: 434 | # end_pos = idx 435 | # if end_pos == -1: return [] 436 | end_pos = len(lines) if (end_pos == 2333) else end_pos 437 | for idx in range(0, end_pos): 438 | if re.findall(subq_pattern, lines[idx], re.IGNORECASE) != []: 439 | query = lines[idx] 440 | start_idx = -1 441 | for idx2 in range(idx + 1, end_pos): 442 | if '```' in lines[idx2]: 443 | start_idx = idx2 444 | break 445 | if start_idx == -1: return [] 446 | for idx3 in range(start_idx + 1, end_pos): 447 | if '```' in lines[idx3]: 448 | end_idx = idx3 449 | break 450 | if end_idx == -1: return [] 451 | answer = " ".join(lines[start_idx + 1: end_idx]) 452 | qa_pairs.append((str(query), str(answer))) 453 | idx = end_idx 454 | return qa_pairs 455 | 456 | def extract_sql(a: str) -> str: 457 | k = a.split("```") 458 | check_format = False 459 | sql_list = [] 460 | for statement in k: 461 | if 'sql' in statement: 462 | sql_list.append(statement) 463 | check_format = True 464 | if not check_format: 465 | for statement in k: 466 | if 'SELECT' in statement: 467 | sql_list.append(statement) 468 | 469 | if len(sql_list) > 0: 470 | sql = sql_list[-1] 471 | token_list = sql.split('\n') 472 | for i in range(len(token_list)): 473 | if "--" in token_list[i]: 474 | token_list[i] = token_list[i].split("--")[0] 475 | token_list[i] = 
token_list[i].strip() 476 | final_sql = ' '.join(token_list).replace('sql ','') 477 | else: 478 | final_sql = None 479 | return final_sql 480 | 481 | def extract_subquery(reply:str) -> list: 482 | try: 483 | answer = reply.split('\n') 484 | subqueries = [] 485 | for i in range(len(answer)): 486 | if "##" in answer[i]: 487 | sublist = answer[i].split(":") 488 | subqueries.append(sublist[1].strip()) 489 | except: 490 | subqueries = [] 491 | return subqueries 492 | 493 | def divide_targets_conditions(reply: str) -> list: 494 | decomposition = "Targets:" + reply.split("Targets:")[1] 495 | temp = decomposition.split("Conditions:") 496 | targets = temp[0].strip() 497 | conditions = "Conditions:" + temp[1].strip() 498 | return [targets,conditions] 499 | 500 | def extract_new_CoT(reply: str) -> str: 501 | #extract the new chain of thoughts 502 | if "**" in reply: 503 | return reply.split("**")[2] 504 | else: 505 | return reply 506 | 507 | def parse_subq(res: str) -> list: 508 | """Only sub questions after decomposition""" 509 | res = '-- ' + res 510 | sub_qustions = [] 511 | sub_qustions += res.split('-- ') 512 | sub_qustions = [q.strip() for q in sub_qustions if len(q) > 1] 513 | return sub_qustions 514 | 515 | 516 | def add_prefix(sql): 517 | if not sql.startswith('SELECT') and not sql.startswith('select'): 518 | sql = 'SELECT' + sql 519 | return sql 520 | 521 | def detect_special_char(name): 522 | for special_char in ['(', '-', ')', ' ', '/']: 523 | if special_char in name: 524 | return True 525 | 526 | return False 527 | 528 | def add_quotation_mark(s): 529 | return "`" + s + "`" 530 | 531 | def get_matched_content_sequence(matched_contents): 532 | content_sequence = "" 533 | if matched_contents == None: 534 | content_sequence = "NULL" 535 | elif len(matched_contents) != 0: 536 | for tc_name, contents in matched_contents.items(): 537 | table_name = tc_name.split(".")[0] 538 | column_name = tc_name.split(".")[1] 539 | #if detect_special_char(table_name): 540 | 
#table_name = add_quotation_mark(table_name) 541 | #if detect_special_char(column_name): 542 | column_name = add_quotation_mark(column_name) 543 | content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n" 544 | else: 545 | content_sequence = "NULL" 546 | 547 | return content_sequence.strip() 548 | 549 | # Spider data preprocess 550 | 551 | 552 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 553 | JOIN_KEYWORDS = ('join', 'on', 'as') 554 | 555 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 556 | UNIT_OPS = ('none', '-', '+', "*", '/') 557 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 558 | TABLE_TYPE = { 559 | 'sql': "sql", 560 | 'table_unit': "table_unit", 561 | } 562 | 563 | COND_OPS = ('and', 'or') 564 | SQL_OPS = ('intersect', 'union', 'except') 565 | ORDER_OPS = ('desc', 'asc') 566 | 567 | 568 | HARDNESS = { 569 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), 570 | "component2": ('except', 'union', 'intersect') 571 | } 572 | 573 | 574 | def get_nestedSQL(sql): 575 | nested = [] 576 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: 577 | if type(cond_unit[3]) is dict: 578 | nested.append(cond_unit[3]) 579 | if type(cond_unit[4]) is dict: 580 | nested.append(cond_unit[4]) 581 | if sql['intersect'] is not None: 582 | nested.append(sql['intersect']) 583 | if sql['except'] is not None: 584 | nested.append(sql['except']) 585 | if sql['union'] is not None: 586 | nested.append(sql['union']) 587 | return nested 588 | 589 | 590 | def has_agg(unit): 591 | return unit[0] != AGG_OPS.index('none') 592 | 593 | 594 | def count_agg(units): 595 | return len([unit for unit in units if has_agg(unit)]) 596 | 597 | 598 | def count_component1(sql): 599 | count = 0 600 | if len(sql['where']) > 0: 601 | count += 1 602 | if len(sql['groupBy']) > 0: 603 | count += 1 604 | 
if len(sql['orderBy']) > 0: 605 | count += 1 606 | if sql['limit'] is not None: 607 | count += 1 608 | if len(sql['from']['table_units']) > 0: # JOIN 609 | count += len(sql['from']['table_units']) - 1 610 | 611 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 612 | count += len([token for token in ao if token == 'or']) 613 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 614 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) 615 | 616 | return count 617 | 618 | 619 | def count_component2(sql): 620 | nested = get_nestedSQL(sql) 621 | return len(nested) 622 | 623 | 624 | def count_others(sql): 625 | count = 0 626 | # number of aggregation 627 | agg_count = count_agg(sql['select'][1]) 628 | agg_count += count_agg(sql['where'][::2]) 629 | agg_count += count_agg(sql['groupBy']) 630 | if len(sql['orderBy']) > 0: 631 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + 632 | [unit[2] for unit in sql['orderBy'][1] if unit[2]]) 633 | agg_count += count_agg(sql['having']) 634 | if agg_count > 1: 635 | count += 1 636 | 637 | # number of select columns 638 | if len(sql['select'][1]) > 1: 639 | count += 1 640 | 641 | # number of where conditions 642 | if len(sql['where']) > 1: 643 | count += 1 644 | 645 | # number of group by clauses 646 | if len(sql['groupBy']) > 1: 647 | count += 1 648 | 649 | return count 650 | 651 | 652 | def eval_hardness(sql): 653 | count_comp1_ = count_component1(sql) 654 | count_comp2_ = count_component2(sql) 655 | count_others_ = count_others(sql) 656 | 657 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: 658 | return "easy" 659 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \ 660 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): 661 | return "medium" 662 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ 663 | (2 < count_comp1_ <= 
3 and count_others_ <= 2 and count_comp2_ == 0) or \ 664 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): 665 | return "hard" 666 | else: 667 | return "extra" 668 | -------------------------------------------------------------------------------- /output/gpt4_mag.json: -------------------------------------------------------------------------------- 1 | {"idx": 0, "db_id": "california_schools", "query": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?", "evidence": "Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`", "extracted_schema": {"frpm": ["County Name", "Enrollment (K-12)", "Free Meal Count (K-12)"], "satscores": ["enroll12", "cname"], "schools": ["CDSCode", "County", "City"]}, "ground_truth": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1", "difficulty": "simple", "matched_list": ["frpm.`County Name` = 'Alameda'", "schools.`MailCity` = 'Alameda'", "satscores.`cname` = 'Alameda'", "schools.`County` = 'Alameda'", "schools.`GSoffered` = 'K-12'", "schools.`GSserved` = 'K-12'", "schools.`City` = 'Alameda'"], "complete_desc_str": "# Table: frpm\n[\n (CDSCode , CDSCode.),\n (Academic Year , Academic Year. Value examples: ['2014-2015'].),\n (County Code , County Code. Value examples: ['19', '37', '30', '36', '33', '43'].),\n (District Code , District Code.),\n (School Code , School Code. Value examples: ['0000000', '9010745', '6121081', '6121073', '6121016', '6120968'].),\n (County Name , County Name. Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Riverside', 'Santa Clara'].),\n (District Name , District Name.),\n (School Name , School Name.),\n (District Type , District Type. 
Value examples: ['Unified School District', 'Elementary School District', 'High School District', 'County Office of Education (COE)', 'State Board of Education', 'Statewide Benefit Charter'].),\n (School Type , School Type. Value examples: [None, 'Elementary Schools (Public)', 'High Schools (Public)', 'Intermediate/Middle Schools (Public)', 'Continuation High Schools', 'Alternative Schools of Choice', 'K-12 Schools (Public)'].),\n (Educational Option Type , Educational Option Type. Value examples: [None, 'Traditional', 'Continuation School', 'Alternative School of Choice', 'Community Day School', 'Special Education School', 'County Community School'].),\n (NSLP Provision Status , NSLP Provision Status. Value examples: [None, 'Provision 2', 'Breakfast Provision 2', 'CEP', 'Multiple Provision Types', 'Provision 1', 'Lunch Provision 2'].),\n (Charter School (Y/N) , Charter School (Y/N). Value examples: [None, 0, 1].),\n (Charter School Number , Charter School Number. Value examples: [None, '00D2', '0756', '00D4', '00D1', '1700', '1696'].),\n (Charter Funding Type , Charter Funding Type. Value examples: [None, 'Directly funded', 'Locally funded', 'Not in CS funding model'].),\n (IRC , IRC. Value examples: [None, 0, 1].),\n (Low Grade , Low Grade. Value examples: ['K', '9', '6', '7', 'P', '5'].),\n (High Grade , High Grade. 
Value examples: ['12', '5', '8', '6', '3', '4'].),\n (Enrollment (K-12) , Enrollment (K-12).),\n (Free Meal Count (K-12) , Free Meal Count (K-12).),\n (Percent (%) Eligible Free (K-12) , Percent (%) Eligible Free (K-12).),\n (FRPM Count (K-12) , Free or Reduced Price Meal Count (K-12).),\n (Percent (%) Eligible FRPM (K-12) , Percent (%) Eligible FRPM (K-12).),\n (Enrollment (Ages 5-17) , Enrollment (Ages 5-17).),\n (Free Meal Count (Ages 5-17) , Free Meal Count (Ages 5-17).),\n (Percent (%) Eligible Free (Ages 5-17) , Percent (%) Eligible Free (Ages 5-17).),\n (FRPM Count (Ages 5-17) , FRPM Count (Ages 5-17).),\n (Percent (%) Eligible FRPM (Ages 5-17) , Percent (%) Eligible FRPM (Ages 5-17).),\n (2013-14 CALPADS Fall 1 Certification Status , 2013-14 CALPADS Fall 1 Certification Status. Value examples: [1].)\n]\n# Table: satscores\n[\n (cds , California Department Schools.),\n (rtype , rtype. Value examples: ['S', 'D'].),\n (sname , school name. Value examples: [None, 'Middle College High', 'John F. Kennedy High', 'Independence High', 'Foothill High', 'Washington High', 'Redwood High'].),\n (dname , district segment. Value examples: ['Los Angeles Unified', 'San Diego Unified', 'Oakland Unified', 'San Francisco Unified', 'Kern High', 'East Side Union High'].),\n (cname , county name. Value examples: ['Los Angeles', 'San Diego', 'San Bernardino', 'Riverside', 'Orange', 'Alameda'].),\n (enroll12 , enrollment (1st-12nd grade).),\n (NumTstTakr , Number of Test Takers in this school.),\n (AvgScrRead , average scores in Reading.),\n (AvgScrMath , average scores in Math.),\n (AvgScrWrite , average scores in writing.),\n (NumGE1500 , Number of Test Takers Whose Total SAT Scores Are Greater or Equal to 1500.)\n]\n# Table: schools\n[\n (CDSCode , CDSCode.),\n (NCESDist , This field represents the 7-digit National Center for Educational Statistics (NCES) school district identification number. 
The first 2 digits identify the state and the last 5 digits identify the school district. Combined, they make a unique 7-digit ID for each school district.. Value examples: [None, '0622710', '0634320', '0628050', '0634410', '0614550', '0633840'].),\n (NCESSchool , This field represents the 5-digit NCES school identification number. The NCESSchool combined with the NCESDist form a unique 12-digit ID for each school.. Value examples: [None, '12271', '13785', '13747', '12909', '12311', '12003'].),\n (StatusType , This field identifies the status of the district.. Value examples: ['Active', 'Closed', 'Merged', 'Pending'].),\n (County , County name. Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Santa Clara', 'Alameda'].),\n (District , District.),\n (School , School.),\n (Street , Street.),\n (StreetAbr , The abbreviated street address of the school, district, or administrative authority’s physical location..),\n (City , City. Value examples: [None, 'Los Angeles', 'San Diego', 'San Jose', 'Sacramento', 'Fresno', 'Oakland'].),\n (Zip , Zip. Value examples: [None, '91723', '95959', '94558', '92345', '90278', '90266'].),\n (State , State. Value examples: [None, 'CA'].),\n (MailStreet , MailStreet.),\n (MailStrAbr , mailing street address.),\n (MailCity , mailing city. Value examples: [None, 'Los Angeles', 'San Diego', 'San Jose', 'Sacramento', 'Oakland', 'Fresno'].),\n (MailZip , mailing zip. Value examples: [None, '94606', '91710-4130', '92626-4300', '90266', '95336-0032', '92084'].),\n (MailState , mailing state. Value examples: [None, 'CA'].),\n (Phone , Phone. Value examples: [None, '(805) 581-0202', '(619) 243-5014', '(562) 698-8121', '(559) 683-4667', '(650) 329-2800', '(530) 625-5600'].),\n (Ext , The phone number extension of the school, district, or administrative authority.. 
Value examples: [None, '200', '201', '111', '352', '221', '203'].),\n (Website , The website address of the school, district, or administrative authority..),\n (OpenDate , The date the school opened.. Value examples: [None, '1980-07-01'].),\n (ClosedDate , The date the school closed.. Value examples: [None, '1989-06-30'].),\n (Charter , This field identifies a charter school.. Value examples: [None, 0, 1].),\n (CharterNum , The charter school number,. Value examples: [None, '00D2', '0756', '00D4', '00D1', '1826', '1710'].),\n (FundingType , Indicates the charter school funding type. Value examples: [None, 'Directly funded', 'Locally funded', 'Not in CS funding model'].),\n (DOC , District Ownership Code. Value examples: ['54', '52', '00', '56', '98', '02'].),\n (DOCType , The District Ownership Code Type is the text description of the DOC category.. Value examples: ['Unified School District', 'Elementary School District', 'County Office of Education (COE)', 'High School District', 'Regional Occupation Center/Program (ROC/P)', 'State Board of Education'].),\n (SOC , The School Ownership Code is a numeric code used to identify the type of school.. Value examples: [None, '60', '66', '62', '68', '09', '69'].),\n (SOCType , The School Ownership Code Type is the text description of the type of school.. Value examples: [None, 'Elementary Schools (Public)', 'High Schools (Public)', 'Intermediate/Middle Schools (Public)', 'Continuation High Schools', 'Special Education Schools (Public)', 'District Community Day Schools'].),\n (EdOpsCode , The Education Option Code is a short text description of the type of education offered.. Value examples: [None, 'TRAD', 'CON', 'COMMDAY', 'ALTSOC', 'SPEC', 'COMM'].),\n (EdOpsName , Educational Option Name. 
Value examples: [None, 'Traditional', 'Continuation School', 'Community Day School', 'Alternative School of Choice', 'Special Education School', 'County Community School'].),\n (EILCode , The Educational Instruction Level Code is a short text description of the institution's type relative to the grade range served.. Value examples: [None, 'ELEM', 'HS', 'UG', 'INTMIDJR', 'ELEMHIGH', 'A'].),\n (EILName , The Educational Instruction Level Name is the long text description of the institution’s type relative to the grade range served.. Value examples: [None, 'Elementary', 'High School', 'Ungraded', 'Intermediate/Middle/Junior High', 'Elementary-High Combination', 'Adult'].),\n (GSoffered , The grade span offered is the lowest grade and the highest grade offered or supported by the school, district, or administrative authority. This field might differ from the grade span served as reported in the most recent certified California Longitudinal Pupil Achievement (CALPADS) Fall 1 data collection.. Value examples: [None, 'K-5', '9-12', 'K-6', 'K-8', '6-8', 'K-12'].),\n (GSserved , It is the lowest grade and the highest grade of student enrollment as reported in the most recent certified CALPADS Fall 1 data collection. Only K–12 enrollment is reported through CALPADS. This field may differ from the grade span offered.. Value examples: [None, 'K-5', 'K-6', '9-12', 'K-8', '6-8', '7-8'].),\n (Virtual , This field identifies the type of virtual instruction offered by the school. Virtual instruction is instruction in which students and teachers are separated by time and/or location, and interaction occurs via computers and/or telecommunications technologies.. Value examples: [None, 'N', 'P', 'F'].),\n (Magnet , This field identifies whether a school is a magnet school and/or provides a magnet program.. 
Value examples: [None, 0, 1].),\n (Latitude , The angular distance (expressed in degrees) between the location of the school, district, or administrative authority and the equator measured north to south..),\n (Longitude , The angular distance (expressed in degrees) between the location of the school, district, or administrative authority and the prime meridian (Greenwich, England) measured from west to east..),\n (AdmFName1 , administrator's first name 1. Value examples: [None, 'David', 'Michael', 'John', 'Jennifer', 'Lisa', 'Robert'].),\n (AdmLName1 , administrator's last name 1.),\n (AdmEmail1 , administrator's email address 1.),\n (AdmFName2 , administrator's first name 2. Value examples: [None, 'Steve', 'Delphine', 'Tim', 'Michael', 'Lisa', 'Jeff'].),\n (AdmLName2 , administrator's last name 2. Value examples: [None, 'Sherman', 'Jones', 'Wilson', 'Petersen', 'Landry', 'Hambright'].),\n (AdmEmail2 , administrator's email address 2.),\n (AdmFName3 , administrator's first name 3. Value examples: [None, 'Mark', 'Eric', 'Vickie', 'Tracie', 'Shawn', 'Sarah'].),\n (AdmLName3 , administrator's last name 3. Value examples: [None, 'Zerpoli', 'Yount', 'Woodruff', 'Wood', 'Wolk', 'Spotts'].),\n (AdmEmail3 , administrator's email address 3.),\n (LastUpdate , Last Update. 
Value examples: ['2015-06-18'].)\n]", "summary_str": "# frpm: Data on schools' Free and Reduced Price Meal (FRPM) eligibility and enrollment across various educational options and provisions.\n# satscores: SAT performance metrics for schools including average scores and number of high-scoring test takers.\n# schools: Detailed information about schools, including their addresses, contact details, and administrative data.", "chosen_db_schem_dict": {"frpm": ["CDSCode", "Academic Year", "County Code", "District Code", "School Code", "County Name", "District Name", "School Name", "District Type", "School Type", "Educational Option Type", "NSLP Provision Status", "Charter School (Y/N)", "Charter School Number", "Charter Funding Type", "IRC", "Low Grade", "High Grade", "Enrollment (K-12)", "Free Meal Count (K-12)", "Percent (%) Eligible Free (K-12)", "FRPM Count (K-12)", "Percent (%) Eligible FRPM (K-12)", "Enrollment (Ages 5-17)", "Free Meal Count (Ages 5-17)", "Percent (%) Eligible Free (Ages 5-17)", "FRPM Count (Ages 5-17)", "Percent (%) Eligible FRPM (Ages 5-17)", "2013-14 CALPADS Fall 1 Certification Status"], "satscores": ["cds", "rtype", "sname", "dname", "cname", "enroll12", "NumTstTakr", "AvgScrRead", "AvgScrMath", "AvgScrWrite", "NumGE1500"], "schools": ["CDSCode", "NCESDist", "NCESSchool", "StatusType", "County", "District", "School", "Street", "StreetAbr", "City", "Zip", "State", "MailStreet", "MailStrAbr", "MailCity", "MailZip", "MailState", "Phone", "Ext", "Website", "OpenDate", "ClosedDate", "Charter", "CharterNum", "FundingType", "DOC", "DOCType", "SOC", "SOCType", "EdOpsCode", "EdOpsName", "EILCode", "EILName", "GSoffered", "GSserved", "Virtual", "Magnet", "Latitude", "Longitude", "AdmFName1", "AdmLName1", "AdmEmail1", "AdmFName2", "AdmLName2", "AdmEmail2", "AdmFName3", "AdmLName3", "AdmEmail3", "LastUpdate"]}, "desc_str": "# frpm: [CDSCode (TEXT), , Academic Year (TEXT), , County Code (TEXT), , District Code (INTEGER), , School Code (TEXT), , County Name 
(TEXT), , District Name (TEXT), , School Name (TEXT), , District Type (TEXT), , School Type (TEXT), , Educational Option Type (TEXT), , NSLP Provision Status (TEXT), , Charter School (Y/N) (INTEGER), , Charter School Number (TEXT), , Charter Funding Type (TEXT), , IRC (INTEGER), , Low Grade (TEXT), , High Grade (TEXT), , Enrollment (K-12) (REAL), , Free Meal Count (K-12) (REAL), , Percent (%) Eligible Free (K-12) (REAL), , FRPM Count (K-12) (REAL), , Percent (%) Eligible FRPM (K-12) (REAL), , Enrollment (Ages 5-17) (REAL), , Free Meal Count (Ages 5-17) (REAL), , Percent (%) Eligible Free (Ages 5-17) (REAL), , FRPM Count (Ages 5-17) (REAL), , Percent (%) Eligible FRPM (Ages 5-17) (REAL), , 2013-14 CALPADS Fall 1 Certification Status (INTEGER), ]\n# satscores: [cds (TEXT), , rtype (TEXT), , sname (TEXT), , dname (TEXT), , cname (TEXT), , enroll12 (INTEGER), , NumTstTakr (INTEGER), , AvgScrRead (INTEGER), , AvgScrMath (INTEGER), , AvgScrWrite (INTEGER), , NumGE1500 (INTEGER), ]\n# schools: [CDSCode (TEXT), , NCESDist (TEXT), , NCESSchool (TEXT), , StatusType (TEXT), , County (TEXT), , District (TEXT), , School (TEXT), , Street (TEXT), , StreetAbr (TEXT), , City (TEXT), , Zip (TEXT), , State (TEXT), , MailStreet (TEXT), , MailStrAbr (TEXT), , MailCity (TEXT), , MailZip (TEXT), , MailState (TEXT), , Phone (TEXT), , Ext (TEXT), , Website (TEXT), , OpenDate (DATE), , ClosedDate (DATE), , Charter (INTEGER), , CharterNum (TEXT), , FundingType (TEXT), , DOC (TEXT), , DOCType (TEXT), , SOC (TEXT), , SOCType (TEXT), , EdOpsCode (TEXT), , EdOpsName (TEXT), , EILCode (TEXT), , EILName (TEXT), , GSoffered (TEXT), , GSserved (TEXT), , Virtual (TEXT), , Magnet (INTEGER), , Latitude (REAL), , Longitude (REAL), , AdmFName1 (TEXT), , AdmLName1 (TEXT), , AdmEmail1 (TEXT), , AdmFName2 (TEXT), , AdmLName2 (TEXT), , AdmEmail2 (TEXT), , AdmFName3 (TEXT), , AdmLName3 (TEXT), , AdmEmail3 (TEXT), , LastUpdate (DATE), ]", "fk_str": "frpm.`CDSCode` = schools.`CDSCode`\nsatscores.`cds` = 
schools.`CDSCode`", "pk_str": "frpm.`CDSCode` | satscores.`cds` | schools.`CDSCode`", "pruned": true, "match_content_str": "NULL", "columns_details_str": "frpm.`County Name`: The column 'County Name' in Table has column descriptions of \"County Name\". Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Riverside', 'Santa Clara'].\nfrpm.`Enrollment (K-12)`: The column 'Enrollment (K-12)' in Table has column descriptions of \"Enrollment (K-12)\". \nfrpm.`Free Meal Count (K-12)`: The column 'Free Meal Count (K-12)' in Table has column descriptions of \"Free Meal Count (K-12)\". \nsatscores.`cname`: The column 'cname' in Table has column descriptions of \"county name\". Value examples: ['Los Angeles', 'San Diego', 'San Bernardino', 'Riverside', 'Orange', 'Alameda'].\nsatscores.`enroll12`: The column 'enroll12' in Table has column descriptions of \"enrollment (1st-12nd grade)\". \nschools.`CDSCode`: The column 'CDSCode' in Table has column descriptions of \"CDSCode\". \nschools.`County`: The column 'County' in Table has column descriptions of \"County name\". Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Santa Clara', 'Alameda'].\nschools.`City`: The column 'City' in Table has column descriptions of \"City\". Value examples: [None, 'Los Angeles', 'San Diego', 'San Jose', 'Sacramento', 'Fresno', 'Oakland'].\n", "subquery_list": ["What is the highest eligible free rate for K-12 students in the schools in Alameda County?"], "initial_state": false, "old_chain_of_thoughts": "To answer the question \"What is the highest eligible free rate for K-12 students in the schools in Alameda County?\", we need to follow these steps:\n\n1. **Identify the relevant columns**: \n - The eligible free rate for K-12 is calculated using the formula: `Free Meal Count (K-12) / Enrollment (K-12)`.\n - We need to select the maximum of this calculated value.\n - We also need to filter the results to only include schools in Alameda County.\n\n2. 
**Identify the tables involved**: \n - We will use the `frpm` table to filter by county and to access the `Free Meal Count (K-12)` and `Enrollment (K-12)` columns.\n - We will also need to join with the `schools` table to ensure we are only considering schools in Alameda County.\n\n3. **Establish the connection between tables**: \n - The `frpm` table is connected to the `schools` table through the `CDSCode` column.\n\n4. **Construct the SQL query**:\n - We will calculate the eligible free rate and then find the maximum value for schools in Alameda County.\n\nHere is the SQL query that accomplishes this:\n\n```sql\nSELECT MAX(FRM.`Free Meal Count (K-12)` / FRM.`Enrollment (K-12)`) AS highest_eligible_free_rate\nFROM frpm AS FRM\nJOIN schools AS SCH ON FRM.`CDSCode` = SCH.`CDSCode`\nWHERE FRM.`County Name` = 'Alameda'\n```\n\n### Explanation of the SQL:\n- We are selecting the maximum eligible free rate calculated from the `Free Meal Count (K-12)` divided by `Enrollment (K-12)`.\n- We join the `frpm` table with the `schools` table on the `CDSCode` to ensure we are looking at the correct schools.\n- We filter the results to only include those from Alameda County.\n\nThis SQL query will return the highest eligible free rate for K-12 students in schools located in Alameda County.", "final_sql": "SELECT MAX(FRM.`Free Meal Count (K-12)` / FRM.`Enrollment (K-12)`) AS highest_eligible_free_rate FROM frpm AS FRM JOIN schools AS SCH ON FRM.`CDSCode` = SCH.`CDSCode` WHERE FRM.`County Name` = 'Alameda' ", "fixed": false, "try_times": 1, "pred": "SELECT MAX(FRM.`Free Meal Count (K-12)` / FRM.`Enrollment (K-12)`) AS highest_eligible_free_rate FROM frpm AS FRM JOIN schools AS SCH ON FRM.`CDSCode` = SCH.`CDSCode` WHERE FRM.`County Name` = 'Alameda' "} 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | func_timeout==4.3.5 2 | nltk==3.8.1 3 | 
numpy==1.26.4 4 | openai==1.6.0 5 | pandas==2.2.2 6 | protobuf==5.26.1 7 | sql_metadata==2.11.0 8 | sqlparse==0.5.0 9 | tiktoken==0.5.2 10 | tqdm==4.66.1 11 | transformers==4.40.1 12 | rapidfuzz 13 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from main_scripts.utils import * 3 | from main_scripts.chat_manager import ChatManager 4 | from main_scripts.utils import get_gold_columns 5 | from main_scripts.const import SYSTEM_NAME 6 | import re 7 | from tqdm import tqdm 8 | import time 9 | import argparse 10 | import sys 11 | import os 12 | import json 13 | import traceback 14 | from nltk.tokenize import word_tokenize 15 | from nltk import ngrams 16 | 17 | 18 | 19 | def init_spider_message(idx: int, item: dict) -> dict: 20 | """ 21 | Construct message for text-to-SQL task 22 | :param idx: start from 0 23 | :param item: one sample of dataset 24 | :return: initial message object of group chat 25 | """ 26 | db_id, query, evidence, gt = item['db_id'], item['question'], str(""), item['query'] 27 | difficulty = eval_hardness(item['sql']) 28 | user_message = { 29 | "idx": idx, 30 | "db_id": db_id, 31 | "query": query, 32 | "evidence": evidence, 33 | "extracted_schema": {}, 34 | "ground_truth": gt, 35 | "difficulty": difficulty, 36 | "send_to": SYSTEM_NAME 37 | } 38 | return user_message 39 | 40 | 41 | def init_bird_message(idx: int, item: dict, db_path: str=None, use_gold_schema: bool = False) -> dict: 42 | """ 43 | Construct message for text-to-SQL task 44 | :param idx: start from 0 45 | :param item: one sample of dataset 46 | :return: initial message object of group chat 47 | """ 48 | db_id, query, evidence, gt, difficulty = item['db_id'], \ 49 | item['question'], \ 50 | item.get('evidence','NULL'), \ 51 | item.get('SQL', ''), \ 52 | item.get('difficulty', 'simple'), \ 53 | 54 | 55 | gold_schema_path = 
'./data/bird/dev_gold_schema.json' 56 | gold_schema = {} 57 | key = f"{db_id.strip()}\t{query.strip()}" 58 | if use_gold_schema: 59 | if os.path.exists(gold_schema_path): 60 | all_gold_schema_dict = load_json_file(gold_schema_path) 61 | if key in all_gold_schema_dict: 62 | gold_schema = all_gold_schema_dict[key] 63 | else: 64 | raise ValueError(f"Can't find gold schema for {key}") 65 | 66 | user_message = { 67 | "idx": idx, 68 | "db_id": db_id, 69 | "query": query, 70 | "evidence": evidence, 71 | "extracted_schema": gold_schema if gold_schema else {}, 72 | "ground_truth": gt, 73 | "difficulty": difficulty, 74 | "send_to": SYSTEM_NAME 75 | } 76 | return user_message 77 | 78 | def contain_date(s): 79 | r = re.compile('.*\'\d{4}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])\'.*') 80 | if r.match(s) != None: 81 | return r.match(s).group() 82 | else: 83 | return None 84 | 85 | def run_batch(dataset_name, input_file, output_file, db_path, tables_json_path, start_pos=0, log_file=None, dataset_mode='dev', use_gold_schema=False, without_selector=False): 86 | chat_manager = ChatManager(data_path=db_path, 87 | tables_json_path=tables_json_path, 88 | log_path=log_file, 89 | dataset_name=dataset_name, 90 | dataset_path=input_file, 91 | model_name='gpt-4', 92 | lazy=True, 93 | without_selector=without_selector) 94 | # load dataset 95 | batch = load_json_file(input_file) 96 | # resume from last checkpoint 97 | finished_ids = set() 98 | """ 99 | file = open(output_file, 'r', encoding='utf-8') 100 | output_data_lst = [] 101 | for line in file.readlines(): 102 | dic = json.loads(line) 103 | output_data_lst.append(dic) 104 | for o in output_data_lst: 105 | finished_ids.add(o['idx']) 106 | """ 107 | if os.path.exists(output_file): 108 | output_data_lst = load_jsonl_file(output_file) 109 | for o in output_data_lst: 110 | finished_ids.add(o['idx']) 111 | 112 | unfinished_ids = [n for n in range(len(batch)) if n not in finished_ids and n >= start_pos] 113 | print(f"unsolved idxs are : 
{unfinished_ids}") 114 | 115 | # add question_id if needed 116 | for k, item in enumerate(batch): 117 | if 'question_id' not in item: 118 | item['question_id'] = k 119 | 120 | # skip some json data 121 | excluded_db_ids = [] 122 | if dataset_mode == 'train': 123 | exclude_txt = './data/bird_train/excluded_db_ids.txt' 124 | excluded_db_ids = read_txt_file(exclude_txt) 125 | new_batch = [] 126 | exclude_db_json_cnt = 0 # for exclude some dbs in bird train set 127 | for k, item in enumerate(batch): 128 | q_id = item['question_id'] 129 | if q_id not in unfinished_ids: 130 | continue 131 | if dataset_mode == 'train': 132 | # skip excluded db_id 133 | if item['db_id'] in excluded_db_ids: 134 | exclude_db_json_cnt += 1 135 | continue 136 | new_batch.append(item) 137 | 138 | if exclude_db_json_cnt: 139 | print(f"excluded {exclude_db_json_cnt} excluded db json data") 140 | time.sleep(2) 141 | batch = new_batch 142 | 143 | 144 | # generate SQL one by one, and save result one by one 145 | with open(output_file, 'a+', encoding='utf-8') as fp: 146 | total_num = len(batch) 147 | for cur_idx, item in tqdm(enumerate(batch)): 148 | idx = item['question_id'] 149 | db_id = item['db_id'] 150 | print(f"\n\nprocessing: {cur_idx}/{total_num}\n\n", flush=True) 151 | if idx not in unfinished_ids: continue 152 | if dataset_name == "spider": 153 | user_message = init_spider_message(idx, item) # imitate user send a question to system 154 | elif dataset_name == "bird": 155 | user_message = init_bird_message(idx, item, db_path=db_path, use_gold_schema=use_gold_schema) # imitate user send a question to system 156 | try: 157 | chat_manager.start(user_message) 158 | try: 159 | """ 160 | del user_message['desc_str'] 161 | del user_message['fk_str'] 162 | """ 163 | del user_message['send_to'] 164 | del user_message['db_content_dict'] 165 | except: 166 | pass 167 | print(json.dumps(user_message, ensure_ascii=False), file=fp, flush=True) 168 | except Exception as e: 169 | # for debug 170 | 
traceback.print_exc() 171 | print(f"Exception: {e}, sleep 20 seconds.", flush=True) 172 | time.sleep(20) 173 | # raise Exception(str(e)) 174 | print(f"\n\ndeal {cur_idx+1}/{total_num} done!\n\n") 175 | print(f"Result dump into {output_file}", file=sys.stdout, flush=True) 176 | 177 | # export evaluation results 178 | out_dir = os.path.dirname(output_file) 179 | 180 | # transfer SQL result to supportable BIRD format 181 | if dataset_name == "bird": 182 | 183 | evaluation_file_path = f"{out_dir}/predict_{dataset_mode}_new.json" 184 | with open(evaluation_file_path, 'w', encoding='utf8') as fout: 185 | output_json_list = load_jsonl_file(output_file) 186 | output_json_list = sorted(output_json_list, key=lambda i: i['idx']) 187 | eval_tuple_lst = [] 188 | for o in output_json_list: 189 | if o.get('pred') == None: 190 | pred_sql = o['sub_sql'].strip() 191 | else: 192 | pred_sql = o['pred'].strip() 193 | if ' || \' \' || ' in pred_sql: 194 | pred_sql = pred_sql.replace(' || \' \' || ',', ') 195 | pred_sql = replace_multiple_spaces(pred_sql) 196 | sql_and_db_str = pred_sql + '\t' + '----- bird -----' + '\t' + o['db_id'] 197 | obj = [o['query'], sql_and_db_str] 198 | eval_tuple_lst.append(obj) 199 | json.dump(eval_tuple_lst, fp=fout, ensure_ascii=False, indent=2) 200 | print(f"BIRD format file dump into {evaluation_file_path}") 201 | """ 202 | #evaluation_file_path = f"{out_dir}/mac_pred_codellama7b_epoch5.sql" 203 | evaluation_file_path = f"{out_dir}/newprompt_mac.sql" 204 | with open(evaluation_file_path, "w") as f: 205 | output_json_list = load_jsonl_file(output_file) 206 | output_json_list = sorted(output_json_list, key=lambda i: i['idx']) 207 | for output in output_json_list: 208 | pred_sql = output['pred'].strip() 209 | pred_sql = replace_multiple_spaces(pred_sql) 210 | f.write(pred_sql.replace("\n", " ") + "\n") 211 | """ 212 | 213 | elif dataset_name == "spider": 214 | evaluation_file_path = f"{out_dir}/pred_{dataset_mode}.sql" 215 | spider_sql_lst = [] 216 | 
output_json_lst = load_jsonl_file(output_file) 217 | for output_json in output_json_lst: 218 | pred_sql = output_json['pred'] 219 | pred_sql = replace_multiple_spaces(pred_sql) 220 | spider_sql_lst.append(pred_sql.strip() + '\n') 221 | save_file(evaluation_file_path, spider_sql_lst) 222 | print(f"Spider format file dump into {evaluation_file_path}") 223 | else: 224 | raise NotImplementedError 225 | 226 | 227 | def check_all_paths(args): 228 | if not os.path.exists(args.input_file): 229 | raise FileNotFoundError(f"Input file {args.input_file} not found") 230 | if not os.path.exists(args.db_path): 231 | raise FileNotFoundError(f"Database path {args.db_path} not found") 232 | if not os.path.exists(args.tables_json_path): 233 | raise FileNotFoundError(f"Tables json path {args.tables_json_path} not found") 234 | 235 | def obtain_n_grams(sequence, max_n): 236 | ''' 237 | returns all grams of sequence less than or equal to `max_n` 238 | ''' 239 | tokens = word_tokenize(sequence) 240 | all_grams = [] 241 | for n in range(1, max_n + 1): 242 | all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)]) 243 | 244 | return all_grams 245 | 246 | 247 | 248 | if __name__ == "__main__": 249 | parser = argparse.ArgumentParser() 250 | parser.add_argument('--dataset_name', type=str, default='spider', choices=['spider', 'bird'], help='dataset name') 251 | parser.add_argument('--dataset_mode', type=str, default='dev', choices=['train', 'dev', 'test'], help='dataset mode') 252 | parser.add_argument('--input_file', type=str, required=True, help='path to dataset input') 253 | parser.add_argument('--db_path', type=str, required=True, help='path to databases in dataset') 254 | parser.add_argument('--tables_json_path', type=str, default=None, help='path to tables.json') 255 | parser.add_argument('--output_file', type=str, required=True, help='path to predicted output') 256 | parser.add_argument('--log_file', type=str, default='', help='path to log file if needed') 257 | 
parser.add_argument('--start_pos', type=int, default=0, help='start position of a batch') 258 | parser.add_argument('--use_gold_schema', action='store_true', default=False) 259 | parser.add_argument('--without_selector', action='store_true', default=False) 260 | args = parser.parse_args() 261 | 262 | # print each key argument and corresponding value 263 | for key, value in vars(args).items(): 264 | print(f"{key}: {value}") 265 | 266 | check_all_paths(args) 267 | 268 | # print the args in clean json format 269 | args_json_str = json.dumps(vars(args), indent=2, ensure_ascii=False) 270 | 271 | print(f"args:\n{args_json_str}") 272 | time.sleep(3) 273 | 274 | 275 | run_batch( 276 | dataset_name=args.dataset_name, 277 | dataset_mode=args.dataset_mode, 278 | input_file=args.input_file, 279 | output_file=args.output_file, 280 | db_path=args.db_path, 281 | tables_json_path= args.tables_json_path, 282 | log_file=args.log_file, 283 | start_pos=args.start_pos, 284 | use_gold_schema=args.use_gold_schema, 285 | without_selector=args.without_selector, 286 | ) 287 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python ./run.py --dataset_name bird \ 4 | --dataset_mode dev \ 5 | --input_file ./data/bird/dev/dev.json \ 6 | --db_path ./data/bird/dev/dev_databases/ \ 7 | --tables_json_path ./data/bird/dev/dev_tables.json \ 8 | --output_file ./output/gpt4_mag.json \ 9 | --log_file ./output/log/log/txt \ 10 | --start_pos 0 11 | 12 | echo "Generate SQL on bird dev data done!" 13 | 14 | 15 | --------------------------------------------------------------------------------