├── README.md
├── assets
│   └── Workflow.png
├── evaluation
│   ├── evaluation_bird_ex.py
│   ├── evaluation_bird_ves.py
│   ├── evaluation_spider.py
│   ├── exec_eval.py
│   ├── parse.py
│   └── process_sql.py
├── evaluation_bird_ex_ves.sh
├── main_scripts
│   ├── MAG.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── MAG.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── bridge_content_encoder.cpython-39.pyc
│   │   ├── chat_manager.cpython-39.pyc
│   │   ├── const.cpython-39.pyc
│   │   ├── llm.cpython-39.pyc
│   │   └── utils.cpython-39.pyc
│   ├── bridge_content_encoder.py
│   ├── chat_manager.py
│   ├── const.py
│   ├── llm.py
│   └── utils.py
├── output
│   ├── gpt4_mag.json
│   └── log
│       └── log
│           ├── api_trace.json
│           └── txt
├── requirements.txt
├── run.py
└── run.sh
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## 📖 Introduction
3 |
4 | Thank you for your interest in our work: MAG-SQL: Multi-Agent Generative Approach with Soft Schema Linking and Iterative Sub-SQL Refinement for Text-to-SQL
5 |
6 | You can find our paper [[here]](https://arxiv.org/abs/2408.07930). In this paper, we propose MAG-SQL, a multi-agent generative approach with soft schema linking and iterative Sub-SQL refinement. In our framework, an entity-based method with table summaries is used to select columns from the database, and a novel targets-conditions decomposition method is introduced to decompose complex questions. Additionally, we build an iterative generation module consisting of a Sub-SQL Generator and a Sub-SQL Refiner, which introduces external oversight for each step of generation. The workflow is shown below:
7 |
8 | ![Workflow](./assets/Workflow.png)
9 |
10 | ## 🌍 Environment
11 |
12 | 1. Configure your local environment.
13 |
14 | ```bash
15 | conda create -n magsql python=3.9 -y
16 | conda activate magsql
17 | cd MAG-SQL
18 | pip install -r requirements.txt
19 | python -c "import nltk; nltk.download('punkt')"
20 | ```
21 |
22 | Note: we use `openai==0.28.1`, which calls the API via `openai.ChatCompletion.create`.
23 |
24 | 2. Edit the OpenAI config in **main_scripts/llm.py** and set the related Azure OpenAI API environment variables; a minimal configuration sketch is shown below.
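
This sketch assumes `openai==0.28.1` with an Azure endpoint; the environment variable names and the deployment name are placeholders and must match what `main_scripts/llm.py` actually reads.

```python
import os
import openai

# Azure-style configuration for openai==0.28.1 (variable names are placeholders).
openai.api_type = "azure"
openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"]        # e.g. https://<resource>.openai.azure.com/
openai.api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2023-05-15")
openai.api_key = os.environ["AZURE_OPENAI_API_KEY"]

# openai==0.28.1 calls chat models via ChatCompletion.create;
# with Azure, `engine` is the deployment name.
response = openai.ChatCompletion.create(
    engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4"),
    messages=[{"role": "user", "content": "Return the SQL keyword SELECT."}],
    temperature=0,
)
print(response["choices"][0]["message"]["content"])
```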
25 |
26 |
27 | ## 📜 Data Preparation
28 | The main dataset we use is [BIRD](https://bird-bench.github.io/).
29 | For evaluation, the data should be placed in **data/bird/dev**; a sketch of the expected layout is shown below.
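
Mirroring the paths used in the run command below (and the database layout expected by the evaluation scripts), the dev data is laid out roughly as follows; this is a sketch based on the official BIRD dev release:

```
data/bird/dev/
├── dev.json
├── dev_tables.json
└── dev_databases/
    └── <db_id>/
        └── <db_id>.sqlite
```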
30 |
31 |
32 | ## 🚀 Run
33 |
34 | Run the following command from the command line, changing the file paths as needed:
35 | ```bash
36 | python run.py --dataset_name bird --dataset_mode dev --input_file ./data/bird/dev/dev.json --db_path ./data/bird/dev/dev_databases/ --tables_json_path ./data/bird/dev/dev_tables.json --output_file ./output/gpt4_mag.json --log_file ./output/log/log/txt --start_pos 0
37 | ```
38 |
39 | ## 📑 Evaluation
40 |
41 | We provide the evaluation script **evaluation_bird_ex_ves.sh** for Linux systems.
42 | The underlying Python evaluators in **evaluation/** can also be invoked directly; a sketch is shown below.
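
For example, one possible invocation of the execution-accuracy (EX) evaluator, using the arguments defined in `evaluation/evaluation_bird_ex.py` (the ground-truth SQL file name is an assumption; point it at your local BIRD dev gold file):

```bash
python evaluation/evaluation_bird_ex.py \
  --predicted_sql_json_path ./output/gpt4_mag.json \
  --ground_truth_sql_path ./data/bird/dev/dev.sql \
  --data_mode dev \
  --db_root_path ./data/bird/dev/dev_databases/ \
  --diff_json_path ./data/bird/dev/dev.json \
  --num_cpus 8 \
  --meta_time_out 30.0
```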
43 |
44 | ## 🎈 Useful Resources
45 |
46 | Our project is built based on MAC-SQL: [[Repo]](https://github.com/wbbeyourself/MAC-SQL)
47 |
48 | ## 📰 Citation
49 | If you find our work helpful, please cite it as:
50 | ```text
51 | @article{xie2024mag,
52 | title={MAG-SQL: Multi-Agent Generative Approach with Soft Schema Linking and Iterative Sub-SQL Refinement for Text-to-SQL},
53 | author={Xie, Wenxuan and Wu, Gaochen and Zhou, Bowen},
54 | journal={arXiv preprint arXiv:2408.07930},
55 | year={2024}
56 | }
57 | ```
--------------------------------------------------------------------------------
/assets/Workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/assets/Workflow.png
--------------------------------------------------------------------------------
/evaluation/evaluation_bird_ex.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | import json
5 | import argparse
6 | import sqlite3
7 | import multiprocessing as mp
8 | from func_timeout import func_timeout, FunctionTimedOut
9 |
10 | def replace_multiple_spaces(text):
11 |     # Define a regex pattern that matches one or more whitespace characters
12 | pattern = r'\s+'
13 |     # Replace runs of whitespace with a single space
14 | new_text = re.sub(pattern, ' ', text)
15 | return new_text
16 |
17 | def load_json(dir):
18 | with open(dir, 'r', encoding='utf8') as j:
19 | contents = json.loads(j.read())
20 | return contents
21 |
22 | def save_json_file(path, data):
23 | with open(path, 'w', encoding='utf-8') as f:
24 | json.dump(data, f, indent=2, ensure_ascii=False)
25 | print(f"save json file to {path}")
26 |
27 | def result_callback(result):
28 | exec_result.append(result)
29 |
30 |
31 | def execute_sql(predicted_sql,ground_truth, db_path):
32 | conn = sqlite3.connect(db_path)
33 |     # Create a cursor object
34 | cursor = conn.cursor()
35 | cursor.execute(predicted_sql)
36 | predicted_res = cursor.fetchall()
37 | cursor.execute(ground_truth)
38 | ground_truth_res = cursor.fetchall()
39 | res = 0
40 | # todo: this should permute column order!
41 | if set(predicted_res) == set(ground_truth_res):
42 | res = 1
43 | return res
44 |
45 |
46 |
47 | def execute_model(predicted_sql,ground_truth, db_place, idx, meta_time_out):
48 | try:
49 | res = func_timeout(meta_time_out, execute_sql,
50 | args=(predicted_sql, ground_truth, db_place))
51 | except KeyboardInterrupt:
52 | sys.exit(0)
53 | except FunctionTimedOut:
54 | result = [(f'timeout',)]
55 | res = 0
56 | except Exception as e:
57 | result = [(f'error',)] # possibly len(query) > 512 or not executable
58 | res = 0
59 | # print(result)
60 | # result = str(set([ret[0] for ret in result]))
61 | result = {'sql_idx': idx, 'res': res}
62 | # print(result)
63 | return result
64 |
65 |
66 | def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
67 | clean_sqls = []
68 | db_path_list = []
69 | if mode == 'gpt':
70 | sql_data = json.load(open(sql_path, 'r', encoding='utf8'))
71 | for idx, sql_str in sql_data: # .items()
72 | if type(sql_str) == str:
73 | sql, db_name = sql_str.split('\t----- bird -----\t')
74 | else:
75 | sql, db_name = " ", "financial"
76 | clean_sqls.append(sql)
77 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
78 |
79 | elif mode == 'gt':
80 | sqls = open(sql_path, encoding='utf8')
81 | sql_txt = sqls.readlines()
82 | for idx, sql_str in enumerate(sql_txt):
83 | sql, db_name = sql_str.strip().split('\t')
84 | clean_sqls.append(sql)
85 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
86 |
87 | return clean_sqls, db_path_list
88 |
89 | def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
90 | pool = mp.Pool(processes=num_cpus)
91 | for i, sql_pair in enumerate(sqls):
92 |
93 | predicted_sql, ground_truth = sql_pair
94 | pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), callback=result_callback)
95 | pool.close()
96 | pool.join()
97 |
98 | def sort_results(list_of_dicts):
99 | return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
100 |
101 | def compute_acc_by_diff(exec_results, diff_json_path):
102 | num_queries = len(exec_results)
103 | results = [res['res'] for res in exec_results]
104 | contents = load_json(diff_json_path)
105 | simple_results, moderate_results, challenging_results = [], [], []
106 |
107 | for i, content in enumerate(contents):
108 | difficulty = content.get('difficulty', 'simple')
109 | if difficulty == 'simple':
110 | try:
111 | simple_results.append(exec_results[i])
112 | except Exception as e:
113 | print(e)
114 | import pdb
115 | pdb.set_trace()
116 |
117 | if difficulty == 'moderate':
118 | moderate_results.append(exec_results[i])
119 |
120 | if difficulty == 'challenging':
121 | challenging_results.append(exec_results[i])
122 |
123 | simple_acc = sum([res['res'] for res in simple_results])/len(simple_results)
124 |
125 | if len(moderate_results) == 0:
126 | moderate_acc = 0
127 | else:
128 | moderate_acc = sum([res['res'] for res in moderate_results])/len(moderate_results)
129 |
130 | if len(challenging_results) == 0:
131 | challenging_acc = 0
132 | else:
133 | challenging_acc = sum([res['res'] for res in challenging_results])/len(challenging_results)
134 |
135 | all_acc = sum(results)/num_queries
136 | count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
137 | return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
138 |
139 |
140 |
141 | def print_data(score_lists,count_lists):
142 | levels = ['simple', 'moderate', 'challenging', 'total']
143 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
144 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
145 |
146 | print('====================================== ACCURACY =====================================')
147 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
148 |
149 |
150 | if __name__ == '__main__':
151 | args_parser = argparse.ArgumentParser()
152 | args_parser.add_argument('--predicted_sql_json_path', type=str, required=True)
153 | args_parser.add_argument('--ground_truth_sql_path', type=str, required=True)
154 | args_parser.add_argument('--data_mode', type=str, required=True, default='dev', choices=['train', 'dev', 'test'])
155 | args_parser.add_argument('--db_root_path', type=str, required=True)
156 | args_parser.add_argument('--num_cpus', type=int, default=1)
157 | args_parser.add_argument('--meta_time_out', type=float, default=30.0)
158 | args_parser.add_argument('--mode_predict', type=str, default='gpt')
159 | args_parser.add_argument('--difficulty',type=str, default='simple')
160 | args_parser.add_argument('--diff_json_path',type=str,default='./data/bird/dev/dev.json')
161 | args = args_parser.parse_args()
162 | exec_result = []
163 |
164 | pred_queries, db_paths = package_sqls(args.predicted_sql_json_path, args.db_root_path,
165 | mode=args.mode_predict, data_mode=args.data_mode)
166 | if len(pred_queries) == 0:
167 | raise ValueError(f'Empty data in {args.predicted_sql_json_path}')
168 | # generate gt sqls:
169 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_sql_path, args.db_root_path, mode='gt',
170 | data_mode=args.data_mode)
171 | print(len(pred_queries),len(gt_queries))
172 | assert len(pred_queries) == len(gt_queries), "len(pred_queries) != len(gt_queries)"
173 | query_pairs = list(zip(pred_queries, gt_queries))
174 | run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
175 | exec_result = sort_results(exec_result)
176 |
177 | # save ex results
178 | out_dir = os.path.dirname(args.predicted_sql_json_path)
179 | if not os.path.exists(out_dir):
180 | os.makedirs(out_dir, exist_ok=True)
181 | result_json_path = os.path.join(out_dir, f'eval_result_{args.data_mode}.json')
182 |
183 | # relocate idx of exec_result
184 | raw_json_data = load_json(args.diff_json_path)
185 | pred_sqls = [replace_multiple_spaces(s) for s in pred_queries]
186 | result_json_lst = []
187 | for i, item in enumerate(raw_json_data):
188 | item['pred'] = pred_sqls[i]
189 | item['gold'] = replace_multiple_spaces(item.get('SQL', ''))
190 | if 'SQL' in item:
191 | del item['SQL']
192 | item['res'] = exec_result[i]['res']
193 | result_json_lst.append(item)
194 | save_json_file(result_json_path, result_json_lst)
195 |
196 | print('start calculate')
197 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
198 | compute_acc_by_diff(exec_result, args.diff_json_path)
199 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
200 | print_data(score_lists,count_lists)
201 | print('===========================================================================================')
202 | print("Finished evaluation")
203 |
--------------------------------------------------------------------------------
/evaluation/evaluation_bird_ves.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import sys
4 | import json
5 | import numpy as np
6 | import argparse
7 | import sqlite3
8 | import multiprocessing as mp
9 | from func_timeout import func_timeout, FunctionTimedOut
10 | import time
11 | import math
12 |
13 |
14 | def result_callback(result):
15 | exec_result.append(result)
16 |
17 |
18 | def clean_abnormal(input):
19 | input = np.asarray(input)
20 | processed_list = []
21 | mean = np.mean(input, axis=0)
22 | std = np.std(input, axis=0)
23 | for x in input:
24 | if x < mean + 3 * std and x > mean - 3 * std:
25 | processed_list.append(x)
26 | return processed_list
27 |
28 |
29 | def execute_sql(sql, db_path):
30 | # Connect to the database
31 | conn = sqlite3.connect(db_path)
32 | # Create a cursor object
33 | cursor = conn.cursor()
34 | start_time = time.time()
35 | cursor.execute(sql)
36 | exec_time = time.time() - start_time
37 | return exec_time
38 |
39 |
40 | def iterated_execute_sql(predicted_sql, ground_truth, db_path, iterate_num):
41 | conn = sqlite3.connect(db_path)
42 | diff_list = []
43 | cursor = conn.cursor()
44 | cursor.execute(predicted_sql)
45 | predicted_res = cursor.fetchall()
46 | cursor.execute(ground_truth)
47 | ground_truth_res = cursor.fetchall()
48 | time_ratio = 0
49 | if set(predicted_res) == set(ground_truth_res):
50 | for i in range(iterate_num):
51 | predicted_time = execute_sql(predicted_sql, db_path)
52 | ground_truth_time = execute_sql(ground_truth, db_path)
53 | diff_list.append(ground_truth_time / predicted_time)
54 | processed_diff_list = clean_abnormal(diff_list)
55 | time_ratio = sum(processed_diff_list) / len(processed_diff_list)
56 | return time_ratio
57 |
58 |
59 | def execute_model(predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out):
60 | try:
61 | # you can personalize the total timeout number
62 | # larger timeout leads to more stable ves
63 | # while it needs more your patience....
64 | print(idx, file=sys.stdout, flush=True)
65 | time_ratio = func_timeout(meta_time_out * iterate_num, iterated_execute_sql,
66 | args=(predicted_sql, ground_truth, db_place, iterate_num))
67 | # print([idx, math.sqrt(time_ratio)])
68 | except KeyboardInterrupt:
69 | sys.exit(0)
70 | except FunctionTimedOut:
71 | result = [(f'timeout',)]
72 | time_ratio = 0
73 | except Exception as e:
74 | result = [(f'error',)] # possibly len(query) > 512 or not executable
75 | time_ratio = 0
76 | result = {'sql_idx': idx, 'time_ratio': time_ratio}
77 | return result
78 |
79 |
80 | def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
81 | clean_sqls = []
82 | db_path_list = []
83 | if mode == 'gpt':
84 | sql_data = json.load(open(sql_path, 'r', encoding='utf8'))
85 | for idx, sql_str in sql_data:
86 | if type(sql_str) == str:
87 | sql, db_name = sql_str.split('\t----- bird -----\t')
88 | else:
89 | sql, db_name = " ", "financial"
90 | clean_sqls.append(sql)
91 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
92 |
93 | elif mode == 'gt':
94 | sqls = open(sql_path, encoding='utf8')
95 | sql_txt = sqls.readlines()
96 | for idx, sql_str in enumerate(sql_txt):
97 | sql, db_name = sql_str.strip().split('\t')
98 | clean_sqls.append(sql)
99 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
100 |
101 | return clean_sqls, db_path_list
102 |
103 |
104 | def run_sqls_parallel(sqls, db_places, num_cpus=1, iterate_num=100, meta_time_out=30.0):
105 | pool = mp.Pool(processes=num_cpus)
106 | for i, sql_pair in enumerate(sqls):
107 | predicted_sql, ground_truth = sql_pair
108 | pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, iterate_num, meta_time_out),
109 | callback=result_callback)
110 | pool.close()
111 | pool.join()
112 |
113 |
114 | def sort_results(list_of_dicts):
115 | return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
116 |
117 |
118 | def compute_ves(exec_results):
119 | num_queries = len(exec_results)
120 | if num_queries == 0:
121 | return 0
122 | total_ratio = 0
123 | count = 0
124 |
125 | for i, result in enumerate(exec_results):
126 | if result['time_ratio'] != 0:
127 | count += 1
128 | total_ratio += math.sqrt(result['time_ratio']) * 100
129 | ves = (total_ratio / num_queries)
130 | return ves
131 |
132 |
133 | def load_json(dir):
134 | with open(dir, 'r', encoding='utf8') as j:
135 | contents = json.loads(j.read())
136 | return contents
137 |
138 |
139 | def compute_ves_by_diff(exec_results, diff_json_path):
140 | num_queries = len(exec_results)
141 | contents = load_json(diff_json_path)
142 | simple_results, moderate_results, challenging_results = [], [], []
143 | for i, content in enumerate(contents):
144 | if content['difficulty'] == 'simple':
145 | simple_results.append(exec_results[i])
146 | if content['difficulty'] == 'moderate':
147 | moderate_results.append(exec_results[i])
148 | if content['difficulty'] == 'challenging':
149 | challenging_results.append(exec_results[i])
150 | simple_ves = compute_ves(simple_results)
151 | moderate_ves = compute_ves(moderate_results)
152 | challenging_ves = compute_ves(challenging_results)
153 | all_ves = compute_ves(exec_results)
154 | count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
155 | return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists
156 |
157 |
158 | def print_data(score_lists, count_lists):
159 | levels = ['simple', 'moderate', 'challenging', 'total']
160 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
161 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
162 |
163 | print('========================================= VES ========================================')
164 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('ves', *score_lists))
165 |
166 |
167 | if __name__ == '__main__':
168 | args_parser = argparse.ArgumentParser()
169 | args_parser.add_argument('--predicted_sql_json_path', type=str, required=True, default='')
170 | args_parser.add_argument('--ground_truth_sql_path', type=str, required=True, default='')
171 | args_parser.add_argument('--data_mode', type=str, required=True, default='dev')
172 | args_parser.add_argument('--db_root_path', type=str, required=True, default='')
173 | args_parser.add_argument('--num_cpus', type=int, default=1)
174 | args_parser.add_argument('--meta_time_out', type=float, default=30.0)
175 | args_parser.add_argument('--mode_gt', type=str, default='gt')
176 | args_parser.add_argument('--mode_predict', type=str, default='gpt')
177 | args_parser.add_argument('--diff_json_path', type=str, required=True, default='')
178 | args = args_parser.parse_args()
179 | exec_result = []
180 |
181 | pred_queries, db_paths = package_sqls(args.predicted_sql_json_path, args.db_root_path,
182 | mode=args.mode_predict, data_mode=args.data_mode)
183 | if len(pred_queries) == 0:
184 | raise ValueError(f'Empty data in {args.predicted_sql_json_path}')
185 | # generate gt sqls:
186 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_sql_path, args.db_root_path, mode='gt',
187 | data_mode=args.data_mode)
188 |
189 | assert len(pred_queries) == len(gt_queries), "len(pred_queries) != len(gt_queries)"
190 | query_pairs = list(zip(pred_queries, gt_queries))
191 | run_sqls_parallel(query_pairs, iterate_num=100, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
192 | exec_result = sort_results(exec_result)
193 | print('start calculate')
194 | simple_ves, moderate_ves, challenging_ves, ves, count_lists = \
195 | compute_ves_by_diff(exec_result, args.diff_json_path)
196 | score_lists = [simple_ves, moderate_ves, challenging_ves, ves]
197 | print_data(score_lists, count_lists)
198 | print('===========================================================================================')
199 | print("Finished evaluation")
200 |
201 |
202 |
--------------------------------------------------------------------------------
/evaluation/evaluation_spider.py:
--------------------------------------------------------------------------------
1 | ################################
2 | # val: number(float)/string(str)/sql(dict)
3 | # col_unit: (agg_id, col_id, isDistinct(bool))
4 | # val_unit: (unit_op, col_unit1, col_unit2)
5 | # table_unit: (table_type, col_unit/sql)
6 | # cond_unit: (not_op, op_id, val_unit, val1, val2)
7 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
8 | # sql {
9 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
10 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
11 | # 'where': condition
12 | # 'groupBy': [col_unit1, col_unit2, ...]
13 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
14 | # 'having': condition
15 | # 'limit': None/limit value
16 | # 'intersect': None/sql
17 | # 'except': None/sql
18 | # 'union': None/sql
19 | # }
20 | ################################
21 |
22 | import os
23 | import json
24 | import sqlite3
25 | import argparse
26 |
27 | from process_sql import get_schema, Schema, get_sql
28 | from exec_eval import eval_exec_match
29 |
30 | # Flag to disable value evaluation
31 | DISABLE_VALUE = True
32 | # Flag to disable distinct in select evaluation
33 | DISABLE_DISTINCT = True
34 |
35 |
36 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
37 | JOIN_KEYWORDS = ('join', 'on', 'as')
38 |
39 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
40 | UNIT_OPS = ('none', '-', '+', "*", '/')
41 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
42 | TABLE_TYPE = {
43 | 'sql': "sql",
44 | 'table_unit': "table_unit",
45 | }
46 |
47 | COND_OPS = ('and', 'or')
48 | SQL_OPS = ('intersect', 'union', 'except')
49 | ORDER_OPS = ('desc', 'asc')
50 |
51 |
52 | HARDNESS = {
53 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
54 | "component2": ('except', 'union', 'intersect')
55 | }
56 |
57 |
58 | def condition_has_or(conds):
59 | return 'or' in conds[1::2]
60 |
61 |
62 | def condition_has_like(conds):
63 | return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
64 |
65 |
66 | def condition_has_sql(conds):
67 | for cond_unit in conds[::2]:
68 | val1, val2 = cond_unit[3], cond_unit[4]
69 | if val1 is not None and type(val1) is dict:
70 | return True
71 | if val2 is not None and type(val2) is dict:
72 | return True
73 | return False
74 |
75 |
76 | def val_has_op(val_unit):
77 | return val_unit[0] != UNIT_OPS.index('none')
78 |
79 |
80 | def has_agg(unit):
81 | return unit[0] != AGG_OPS.index('none')
82 |
83 |
84 | def accuracy(count, total):
85 | if count == total:
86 | return 1
87 | return 0
88 |
89 |
90 | def recall(count, total):
91 | if count == total:
92 | return 1
93 | return 0
94 |
95 |
96 | def F1(acc, rec):
97 | if (acc + rec) == 0:
98 | return 0
99 | return (2. * acc * rec) / (acc + rec)
100 |
101 |
102 | def get_scores(count, pred_total, label_total):
103 | if pred_total != label_total:
104 | return 0,0,0
105 | elif count == pred_total:
106 | return 1,1,1
107 | return 0,0,0
108 |
109 |
110 | def eval_sel(pred, label):
111 | pred_sel = pred['select'][1]
112 | label_sel = label['select'][1]
113 | label_wo_agg = [unit[1] for unit in label_sel]
114 | pred_total = len(pred_sel)
115 | label_total = len(label_sel)
116 | cnt = 0
117 | cnt_wo_agg = 0
118 |
119 | for unit in pred_sel:
120 | if unit in label_sel:
121 | cnt += 1
122 | label_sel.remove(unit)
123 | if unit[1] in label_wo_agg:
124 | cnt_wo_agg += 1
125 | label_wo_agg.remove(unit[1])
126 |
127 | return label_total, pred_total, cnt, cnt_wo_agg
128 |
129 |
130 | def eval_where(pred, label):
131 | pred_conds = [unit for unit in pred['where'][::2]]
132 | label_conds = [unit for unit in label['where'][::2]]
133 | label_wo_agg = [unit[2] for unit in label_conds]
134 | pred_total = len(pred_conds)
135 | label_total = len(label_conds)
136 | cnt = 0
137 | cnt_wo_agg = 0
138 |
139 | for unit in pred_conds:
140 | if unit in label_conds:
141 | cnt += 1
142 | label_conds.remove(unit)
143 | if unit[2] in label_wo_agg:
144 | cnt_wo_agg += 1
145 | label_wo_agg.remove(unit[2])
146 |
147 | return label_total, pred_total, cnt, cnt_wo_agg
148 |
149 |
150 | def eval_group(pred, label):
151 | pred_cols = [unit[1] for unit in pred['groupBy']]
152 | label_cols = [unit[1] for unit in label['groupBy']]
153 | pred_total = len(pred_cols)
154 | label_total = len(label_cols)
155 | cnt = 0
156 | pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
157 | label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
158 | for col in pred_cols:
159 | if col in label_cols:
160 | cnt += 1
161 | label_cols.remove(col)
162 | return label_total, pred_total, cnt
163 |
164 |
165 | def eval_having(pred, label):
166 | pred_total = label_total = cnt = 0
167 | if len(pred['groupBy']) > 0:
168 | pred_total = 1
169 | if len(label['groupBy']) > 0:
170 | label_total = 1
171 |
172 | pred_cols = [unit[1] for unit in pred['groupBy']]
173 | label_cols = [unit[1] for unit in label['groupBy']]
174 | if pred_total == label_total == 1 \
175 | and pred_cols == label_cols \
176 | and pred['having'] == label['having']:
177 | cnt = 1
178 |
179 | return label_total, pred_total, cnt
180 |
181 |
182 | def eval_order(pred, label):
183 | pred_total = label_total = cnt = 0
184 | if len(pred['orderBy']) > 0:
185 | pred_total = 1
186 | if len(label['orderBy']) > 0:
187 | label_total = 1
188 | if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
189 | ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
190 | cnt = 1
191 | return label_total, pred_total, cnt
192 |
193 |
194 | def eval_and_or(pred, label):
195 | pred_ao = pred['where'][1::2]
196 | label_ao = label['where'][1::2]
197 | pred_ao = set(pred_ao)
198 | label_ao = set(label_ao)
199 |
200 | if pred_ao == label_ao:
201 | return 1,1,1
202 | return len(pred_ao),len(label_ao),0
203 |
204 |
205 | def get_nestedSQL(sql):
206 | nested = []
207 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
208 | if type(cond_unit[3]) is dict:
209 | nested.append(cond_unit[3])
210 | if type(cond_unit[4]) is dict:
211 | nested.append(cond_unit[4])
212 | if sql['intersect'] is not None:
213 | nested.append(sql['intersect'])
214 | if sql['except'] is not None:
215 | nested.append(sql['except'])
216 | if sql['union'] is not None:
217 | nested.append(sql['union'])
218 | return nested
219 |
220 |
221 | def eval_nested(pred, label):
222 | label_total = 0
223 | pred_total = 0
224 | cnt = 0
225 | if pred is not None:
226 | pred_total += 1
227 | if label is not None:
228 | label_total += 1
229 | if pred is not None and label is not None:
230 | cnt += Evaluator().eval_exact_match(pred, label)
231 | return label_total, pred_total, cnt
232 |
233 |
234 | def eval_IUEN(pred, label):
235 | lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
236 | lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
237 | lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
238 | label_total = lt1 + lt2 + lt3
239 | pred_total = pt1 + pt2 + pt3
240 | cnt = cnt1 + cnt2 + cnt3
241 | return label_total, pred_total, cnt
242 |
243 |
244 | def get_keywords(sql):
245 | res = set()
246 | if len(sql['where']) > 0:
247 | res.add('where')
248 | if len(sql['groupBy']) > 0:
249 | res.add('group')
250 | if len(sql['having']) > 0:
251 | res.add('having')
252 | if len(sql['orderBy']) > 0:
253 | res.add(sql['orderBy'][0])
254 | res.add('order')
255 | if sql['limit'] is not None:
256 | res.add('limit')
257 | if sql['except'] is not None:
258 | res.add('except')
259 | if sql['union'] is not None:
260 | res.add('union')
261 | if sql['intersect'] is not None:
262 | res.add('intersect')
263 |
264 | # or keyword
265 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
266 | if len([token for token in ao if token == 'or']) > 0:
267 | res.add('or')
268 |
269 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
270 | # not keyword
271 | if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
272 | res.add('not')
273 |
274 | # in keyword
275 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
276 | res.add('in')
277 |
278 | # like keyword
279 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
280 | res.add('like')
281 |
282 | return res
283 |
284 |
285 | def eval_keywords(pred, label):
286 | pred_keywords = get_keywords(pred)
287 | label_keywords = get_keywords(label)
288 | pred_total = len(pred_keywords)
289 | label_total = len(label_keywords)
290 | cnt = 0
291 |
292 | for k in pred_keywords:
293 | if k in label_keywords:
294 | cnt += 1
295 | return label_total, pred_total, cnt
296 |
297 |
298 | def count_agg(units):
299 | return len([unit for unit in units if has_agg(unit)])
300 |
301 |
302 | def count_component1(sql):
303 | count = 0
304 | if len(sql['where']) > 0:
305 | count += 1
306 | if len(sql['groupBy']) > 0:
307 | count += 1
308 | if len(sql['orderBy']) > 0:
309 | count += 1
310 | if sql['limit'] is not None:
311 | count += 1
312 | if len(sql['from']['table_units']) > 0: # JOIN
313 | count += len(sql['from']['table_units']) - 1
314 |
315 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
316 | count += len([token for token in ao if token == 'or'])
317 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
318 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
319 |
320 | return count
321 |
322 |
323 | def count_component2(sql):
324 | nested = get_nestedSQL(sql)
325 | return len(nested)
326 |
327 |
328 | def count_others(sql):
329 | count = 0
330 | # number of aggregation
331 | agg_count = count_agg(sql['select'][1])
332 | agg_count += count_agg(sql['where'][::2])
333 | agg_count += count_agg(sql['groupBy'])
334 | if len(sql['orderBy']) > 0:
335 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
336 | [unit[2] for unit in sql['orderBy'][1] if unit[2]])
337 | agg_count += count_agg(sql['having'])
338 | if agg_count > 1:
339 | count += 1
340 |
341 | # number of select columns
342 | if len(sql['select'][1]) > 1:
343 | count += 1
344 |
345 | # number of where conditions
346 | if len(sql['where']) > 1:
347 | count += 1
348 |
349 | # number of group by clauses
350 | if len(sql['groupBy']) > 1:
351 | count += 1
352 |
353 | return count
354 |
355 |
356 | class Evaluator:
357 | """A simple evaluator"""
358 | def __init__(self):
359 | self.partial_scores = None
360 |
361 | def eval_hardness(self, sql):
362 | count_comp1_ = count_component1(sql)
363 | count_comp2_ = count_component2(sql)
364 | count_others_ = count_others(sql)
365 |
366 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
367 | return "easy"
368 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
369 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
370 | return "medium"
371 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
372 | (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
373 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
374 | return "hard"
375 | else:
376 | return "extra"
377 |
378 | def eval_exact_match(self, pred, label):
379 | partial_scores = self.eval_partial_match(pred, label)
380 | self.partial_scores = partial_scores
381 |
382 | for key, score in partial_scores.items():
383 | if score['f1'] != 1:
384 | return 0
385 |
386 | if len(label['from']['table_units']) > 0:
387 | label_tables = sorted(label['from']['table_units'])
388 | pred_tables = sorted(pred['from']['table_units'])
389 | return label_tables == pred_tables
390 | return 1
391 |
392 | def eval_partial_match(self, pred, label):
393 | res = {}
394 |
395 | label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
396 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
397 | res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
398 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
399 | res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
400 |
401 | label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
402 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
403 | res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
404 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
405 | res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
406 |
407 | label_total, pred_total, cnt = eval_group(pred, label)
408 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
409 | res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
410 |
411 | label_total, pred_total, cnt = eval_having(pred, label)
412 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
413 | res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
414 |
415 | label_total, pred_total, cnt = eval_order(pred, label)
416 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
417 | res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
418 |
419 | label_total, pred_total, cnt = eval_and_or(pred, label)
420 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
421 | res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
422 |
423 | label_total, pred_total, cnt = eval_IUEN(pred, label)
424 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
425 | res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
426 |
427 | label_total, pred_total, cnt = eval_keywords(pred, label)
428 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
429 | res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
430 |
431 | return res
432 |
433 |
434 | def isValidSQL(sql, db):
435 | conn = sqlite3.connect(db)
436 | cursor = conn.cursor()
437 | try:
438 | cursor.execute(sql)
439 | except:
440 | return False
441 | return True
442 |
443 |
444 |
445 | def print_formated_s(row_name, l, element_format):
446 | template = "{:20} " + ' '.join([element_format] * len(l))
447 | print(template.format(row_name, *l))
448 |
449 |
450 | def print_scores(scores, etype, include_turn_acc=True):
451 | turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4']
452 | levels = ['easy', 'medium', 'hard', 'extra', 'all']
453 | if include_turn_acc:
454 | levels.append('joint_all')
455 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
456 | 'group', 'order', 'and/or', 'IUEN', 'keywords']
457 |
458 | print_formated_s("", levels, '{:20}')
459 | counts = [scores[level]['count'] for level in levels]
460 | print_formated_s("count", counts, '{:<20d}')
461 |
462 | if etype in ["all", "exec"]:
463 | print ('===================== EXECUTION ACCURACY =====================')
464 | exec_scores = [scores[level]['exec'] for level in levels]
465 | print_formated_s("execution", exec_scores, '{:<20.3f}')
466 |
467 | if etype in ["all", "match"]:
468 | print ('\n====================== EXACT MATCHING ACCURACY =====================')
469 | exact_scores = [scores[level]['exact'] for level in levels]
470 | print_formated_s("exact match", exact_scores, '{:<20.3f}')
471 | print ('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
472 | for type_ in partial_types:
473 | this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
474 | print_formated_s(type_, this_scores, '{:<20.3f}')
475 |
476 | print ('---------------------- PARTIAL MATCHING RECALL ----------------------')
477 | for type_ in partial_types:
478 | this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
479 | print_formated_s(type_, this_scores, '{:<20.3f}')
480 |
481 | print ('---------------------- PARTIAL MATCHING F1 --------------------------')
482 | for type_ in partial_types:
483 | this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
484 | print_formated_s(type_, this_scores, '{:<20.3f}')
485 |
486 | if include_turn_acc:
487 | print()
488 | print()
489 | print_formated_s("", turns, '{:20}')
490 | counts = [scores[turn]['count'] for turn in turns]
491 | print_formated_s("count", counts, "{:<20d}")
492 |
493 | if etype in ["all", "exec"]:
494 | print ('===================== TURN EXECUTION ACCURACY =====================')
495 | exec_scores = [scores[turn]['exec'] for turn in turns]
496 | print_formated_s("execution", exec_scores, '{:<20.3f}')
497 |
498 | if etype in ["all", "match"]:
499 | print ('\n====================== TURN EXACT MATCHING ACCURACY =====================')
500 | exact_scores = [scores[turn]['exact'] for turn in turns]
501 | print_formated_s("exact match", exact_scores, '{:<20.3f}')
502 |
503 |
504 | def evaluate(gold, predict, db_dir, etype, kmaps, plug_value, keep_distinct, progress_bar_for_each_datapoint):
505 |
506 | with open(gold) as f:
507 | glist = []
508 | gseq_one = []
509 | for l in f.readlines():
510 | if len(l.strip()) == 0:
511 | glist.append(gseq_one)
512 | gseq_one = []
513 | else:
514 | lstrip = l.strip().split('\t')
515 | gseq_one.append(lstrip)
516 |
517 | # include the last session
518 | # this was previously ignored in the SParC evaluation script
519 | # which might lead to slight differences in scores
520 | if len(gseq_one) != 0:
521 | glist.append(gseq_one)
522 |
523 | # spider formatting indicates that there is only one "single turn"
524 | # do not report "turn accuracy" for SPIDER
525 | include_turn_acc = len(glist) > 1
526 |
527 | with open(predict) as f:
528 | plist = []
529 | pseq_one = []
530 | for l in f.readlines():
531 | if len(l.strip()) == 0:
532 | plist.append(pseq_one)
533 | pseq_one = []
534 | else:
535 | pseq_one.append(l.strip().split('\t'))
536 |
537 | if len(pseq_one) != 0:
538 | plist.append(pseq_one)
539 |
540 | assert len(plist) == len(glist), "number of sessions must equal"
541 |
542 | evaluator = Evaluator()
543 | turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4']
544 | levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
545 |
546 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
547 | 'group', 'order', 'and/or', 'IUEN', 'keywords']
548 | entries = []
549 | scores = {}
550 |
551 | for turn in turns:
552 | scores[turn] = {'count': 0, 'exact': 0.}
553 | scores[turn]['exec'] = 0
554 |
555 | for level in levels:
556 | scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
557 | scores[level]['exec'] = 0
558 | for type_ in partial_types:
559 | scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
560 |
561 | gold_pred_map_lst = []
562 |
563 | for i, (p, g) in enumerate(zip(plist, glist)):
564 | if (i + 1) % 10 == 0:
565 | print('Evaluating %dth prediction' % (i + 1))
566 | scores['joint_all']['count'] += 1
567 | turn_scores = {"exec": [], "exact": []}
568 |
569 | print(f"len(p): {len(p)}; len(g): {len(g)}")
570 | for idx, pg in enumerate(zip(p, g)):
571 | gold_pred_map = {
572 | 'idx': idx,
573 | 'db_id': '',
574 | 'question': '',
575 | 'gold': '',
576 | 'pred': '',
577 | 'exec_result': 0
578 | }
579 | p, g = pg
580 | p_str = p[0]
581 | p_str = p_str.replace("value", "1")
582 | g_str, db = g
583 |
584 | gold_pred_map['pred'] = p_str
585 | gold_pred_map['gold'] = g_str
586 | gold_pred_map['db_id'] = db
587 |
588 | db_name = db
589 | db = os.path.join(db_dir, db, db + ".sqlite")
590 | schema = Schema(get_schema(db))
591 | g_sql = get_sql(schema, g_str)
592 | hardness = evaluator.eval_hardness(g_sql)
593 | if idx > 3:
594 | idx = "> 4"
595 | else:
596 | idx += 1
597 | turn_id = "turn " + str(idx)
598 | scores[turn_id]['count'] += 1
599 | scores[hardness]['count'] += 1
600 | scores['all']['count'] += 1
601 |
602 | try:
603 | p_sql = get_sql(schema, p_str)
604 | except:
605 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
606 | p_sql = {
607 | "except": None,
608 | "from": {
609 | "conds": [],
610 | "table_units": []
611 | },
612 | "groupBy": [],
613 | "having": [],
614 | "intersect": None,
615 | "limit": None,
616 | "orderBy": [],
617 | "select": [
618 | False,
619 | []
620 | ],
621 | "union": None,
622 | "where": []
623 | }
624 |
625 | if etype in ["all", "exec"]:
626 | exec_score = eval_exec_match(db=db, p_str=p_str, g_str=g_str, plug_value=plug_value,
627 | keep_distinct=keep_distinct, progress_bar_for_each_datapoint=progress_bar_for_each_datapoint)
628 | if exec_score:
629 | scores[hardness]['exec'] += 1
630 | scores[turn_id]['exec'] += 1
631 | scores['all']['exec'] += 1
632 | turn_scores['exec'].append(1)
633 |
634 | gold_pred_map['exec_result'] = 1
635 | else:
636 | turn_scores['exec'].append(0)
637 | gold_pred_map_lst.append(gold_pred_map)
638 |
639 | if etype in ["all", "match"]:
640 | # rebuild sql for value evaluation
641 | kmap = kmaps[db_name]
642 | g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
643 | g_sql = rebuild_sql_val(g_sql)
644 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
645 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
646 | p_sql = rebuild_sql_val(p_sql)
647 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
648 | exact_score = evaluator.eval_exact_match(p_sql, g_sql)
649 | partial_scores = evaluator.partial_scores
650 | if exact_score == 0:
651 | turn_scores['exact'].append(0)
652 | print("{} pred: {}".format(hardness, p_str))
653 | print("{} gold: {}".format(hardness, g_str))
654 | print("")
655 | else:
656 | turn_scores['exact'].append(1)
657 | scores[turn_id]['exact'] += exact_score
658 | scores[hardness]['exact'] += exact_score
659 | scores['all']['exact'] += exact_score
660 | for type_ in partial_types:
661 | if partial_scores[type_]['pred_total'] > 0:
662 | scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
663 | scores[hardness]['partial'][type_]['acc_count'] += 1
664 | if partial_scores[type_]['label_total'] > 0:
665 | scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
666 | scores[hardness]['partial'][type_]['rec_count'] += 1
667 | scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
668 | if partial_scores[type_]['pred_total'] > 0:
669 | scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
670 | scores['all']['partial'][type_]['acc_count'] += 1
671 | if partial_scores[type_]['label_total'] > 0:
672 | scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
673 | scores['all']['partial'][type_]['rec_count'] += 1
674 | scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
675 |
676 | entries.append({
677 | 'predictSQL': p_str,
678 | 'goldSQL': g_str,
679 | 'hardness': hardness,
680 | 'exact': exact_score,
681 | 'partial': partial_scores
682 | })
683 |
684 | if all(v == 1 for v in turn_scores["exec"]):
685 | scores['joint_all']['exec'] += 1
686 |
687 | if all(v == 1 for v in turn_scores["exact"]):
688 | scores['joint_all']['exact'] += 1
689 |
690 | # export evaluation result
691 | out_dir = os.path.dirname(predict)
692 | out_evaluation_json_path = os.path.join(out_dir, "evaluation.json")
693 | if not os.path.exists(out_dir):
694 | os.makedirs(out_dir, exist_ok=True)
695 |
696 | total_cnt = len(gold_pred_map_lst)
697 | err_cnt = len([m for m in gold_pred_map_lst if m['exec_result']== 0])
698 | print(f"total_cnt: {total_cnt}, err_cnt: {err_cnt}")
699 |
700 | # save json file
701 | with open(out_evaluation_json_path, 'w') as f:
702 | json.dump(gold_pred_map_lst, f, indent=2)
703 | print("save evaluation result to {}".format(out_evaluation_json_path))
704 |
705 | for turn in turns:
706 | if scores[turn]['count'] == 0:
707 | continue
708 | if etype in ["all", "exec"]:
709 | scores[turn]['exec'] /= scores[turn]['count']
710 |
711 | if etype in ["all", "match"]:
712 | scores[turn]['exact'] /= scores[turn]['count']
713 |
714 | for level in levels:
715 | if scores[level]['count'] == 0:
716 | continue
717 | if etype in ["all", "exec"]:
718 | scores[level]['exec'] /= scores[level]['count']
719 |
720 | if etype in ["all", "match"]:
721 | scores[level]['exact'] /= scores[level]['count']
722 | for type_ in partial_types:
723 | if scores[level]['partial'][type_]['acc_count'] == 0:
724 | scores[level]['partial'][type_]['acc'] = 0
725 | else:
726 | scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
727 | scores[level]['partial'][type_]['acc_count'] * 1.0
728 | if scores[level]['partial'][type_]['rec_count'] == 0:
729 | scores[level]['partial'][type_]['rec'] = 0
730 | else:
731 | scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
732 | scores[level]['partial'][type_]['rec_count'] * 1.0
733 | if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
734 | scores[level]['partial'][type_]['f1'] = 1
735 | else:
736 | scores[level]['partial'][type_]['f1'] = \
737 | 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
738 | scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
739 |
740 | print_scores(scores, etype, include_turn_acc=include_turn_acc)
741 |
742 |
743 | # Rebuild SQL functions for value evaluation
744 | def rebuild_cond_unit_val(cond_unit):
745 | if cond_unit is None or not DISABLE_VALUE:
746 | return cond_unit
747 |
748 | not_op, op_id, val_unit, val1, val2 = cond_unit
749 | if type(val1) is not dict:
750 | val1 = None
751 | else:
752 | val1 = rebuild_sql_val(val1)
753 | if type(val2) is not dict:
754 | val2 = None
755 | else:
756 | val2 = rebuild_sql_val(val2)
757 | return not_op, op_id, val_unit, val1, val2
758 |
759 |
760 | def rebuild_condition_val(condition):
761 | if condition is None or not DISABLE_VALUE:
762 | return condition
763 |
764 | res = []
765 | for idx, it in enumerate(condition):
766 | if idx % 2 == 0:
767 | res.append(rebuild_cond_unit_val(it))
768 | else:
769 | res.append(it)
770 | return res
771 |
772 |
773 | def rebuild_sql_val(sql):
774 | if sql is None or not DISABLE_VALUE:
775 | return sql
776 |
777 | sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
778 | sql['having'] = rebuild_condition_val(sql['having'])
779 | sql['where'] = rebuild_condition_val(sql['where'])
780 | sql['intersect'] = rebuild_sql_val(sql['intersect'])
781 | sql['except'] = rebuild_sql_val(sql['except'])
782 | sql['union'] = rebuild_sql_val(sql['union'])
783 |
784 | return sql
785 |
786 |
787 | # Rebuild SQL functions for foreign key evaluation
788 | def build_valid_col_units(table_units, schema):
789 | col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
790 | prefixs = [col_id[:-2] for col_id in col_ids]
791 | valid_col_units= []
792 | for value in schema.idMap.values():
793 | if '.' in value and value[:value.index('.')] in prefixs:
794 | valid_col_units.append(value)
795 | return valid_col_units
796 |
797 |
798 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
799 | if col_unit is None:
800 | return col_unit
801 |
802 | agg_id, col_id, distinct = col_unit
803 | if col_id in kmap and col_id in valid_col_units:
804 | col_id = kmap[col_id]
805 | if DISABLE_DISTINCT:
806 | distinct = None
807 | return agg_id, col_id, distinct
808 |
809 |
810 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
811 | if val_unit is None:
812 | return val_unit
813 |
814 | unit_op, col_unit1, col_unit2 = val_unit
815 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
816 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
817 | return unit_op, col_unit1, col_unit2
818 |
819 |
820 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
821 | if table_unit is None:
822 | return table_unit
823 |
824 | table_type, col_unit_or_sql = table_unit
825 | if isinstance(col_unit_or_sql, tuple):
826 | col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
827 | return table_type, col_unit_or_sql
828 |
829 |
830 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
831 | if cond_unit is None:
832 | return cond_unit
833 |
834 | not_op, op_id, val_unit, val1, val2 = cond_unit
835 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
836 | return not_op, op_id, val_unit, val1, val2
837 |
838 |
839 | def rebuild_condition_col(valid_col_units, condition, kmap):
840 | for idx in range(len(condition)):
841 | if idx % 2 == 0:
842 | condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
843 | return condition
844 |
845 |
846 | def rebuild_select_col(valid_col_units, sel, kmap):
847 | if sel is None:
848 | return sel
849 | distinct, _list = sel
850 | new_list = []
851 | for it in _list:
852 | agg_id, val_unit = it
853 | new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
854 | if DISABLE_DISTINCT:
855 | distinct = None
856 | return distinct, new_list
857 |
858 |
859 | def rebuild_from_col(valid_col_units, from_, kmap):
860 | if from_ is None:
861 | return from_
862 |
863 | from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
864 | from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
865 | return from_
866 |
867 |
868 | def rebuild_group_by_col(valid_col_units, group_by, kmap):
869 | if group_by is None:
870 | return group_by
871 |
872 | return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
873 |
874 |
875 | def rebuild_order_by_col(valid_col_units, order_by, kmap):
876 | if order_by is None or len(order_by) == 0:
877 | return order_by
878 |
879 | direction, val_units = order_by
880 | new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
881 | return direction, new_val_units
882 |
883 |
884 | def rebuild_sql_col(valid_col_units, sql, kmap):
885 | if sql is None:
886 | return sql
887 |
888 | sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
889 | sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
890 | sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
891 | sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
892 | sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
893 | sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
894 | sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
895 | sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
896 | sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
897 |
898 | return sql
899 |
900 |
901 | def build_foreign_key_map(entry):
902 | cols_orig = entry["column_names_original"]
903 | tables_orig = entry["table_names_original"]
904 |
905 | # rebuild cols corresponding to idmap in Schema
906 | cols = []
907 | for col_orig in cols_orig:
908 | if col_orig[0] >= 0:
909 | t = tables_orig[col_orig[0]]
910 | c = col_orig[1]
911 | cols.append("__" + t.lower() + "." + c.lower() + "__")
912 | else:
913 | cols.append("__all__")
914 |
915 | def keyset_in_list(k1, k2, k_list):
916 | for k_set in k_list:
917 | if k1 in k_set or k2 in k_set:
918 | return k_set
919 | new_k_set = set()
920 | k_list.append(new_k_set)
921 | return new_k_set
922 |
923 | foreign_key_list = []
924 | foreign_keys = entry["foreign_keys"]
925 | for fkey in foreign_keys:
926 | key1, key2 = fkey
927 | key_set = keyset_in_list(key1, key2, foreign_key_list)
928 | key_set.add(key1)
929 | key_set.add(key2)
930 |
931 | foreign_key_map = {}
932 | for key_set in foreign_key_list:
933 | sorted_list = sorted(list(key_set))
934 | midx = sorted_list[0]
935 | for idx in sorted_list:
936 | foreign_key_map[cols[idx]] = cols[midx]
937 |
938 | return foreign_key_map
939 |
940 |
941 | def build_foreign_key_map_from_json(table):
942 | with open(table) as f:
943 | data = json.load(f)
944 | tables = {}
945 | for entry in data:
946 | tables[entry['db_id']] = build_foreign_key_map(entry)
947 | return tables
948 |
949 |
950 | if __name__ == "__main__":
951 | parser = argparse.ArgumentParser()
952 | parser.add_argument('--gold', dest='gold', type=str, help="the path to the gold queries")
953 | parser.add_argument('--pred', dest='pred', type=str, help="the path to the predicted queries")
954 | parser.add_argument('--db', dest='db', type=str, help="the directory that contains all the databases and test suites")
955 | parser.add_argument('--table', dest='table', type=str, help="the tables.json schema file")
956 | parser.add_argument('--etype', dest='etype', type=str, default='exec',
957 | help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy",
958 | choices=('all', 'exec', 'match'))
959 | parser.add_argument('--plug_value', default=False, action='store_true',
960 | help='whether to plug in the gold value into the predicted query; suitable if your model does not predict values.')
961 | parser.add_argument('--keep_distinct', default=False, action='store_true',
962 | help='whether to keep distinct keyword during evaluation. default is false.')
963 | parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
964 | help='whether to print progress bar of running test inputs for each datapoint')
965 | args = parser.parse_args()
966 |
967 |     # only evaluating exact match needs this argument
968 | kmaps = None
969 | if args.etype in ['all', 'match']:
970 | assert args.table is not None, 'table argument must be non-None if exact set match is evaluated'
971 | kmaps = build_foreign_key_map_from_json(args.table)
972 |
973 | evaluate(args.gold, args.pred, args.db, args.etype, kmaps, args.plug_value, args.keep_distinct, args.progress_bar_for_each_datapoint)
974 |
--------------------------------------------------------------------------------
/evaluation/exec_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import asyncio
4 | import sqlite3
5 | import threading
6 | from typing import Tuple, Any, List, Set
7 | from itertools import product
8 | from collections import defaultdict
9 | import tqdm
10 | import random
11 | from parse import get_all_preds_for_execution, remove_distinct
12 | import time
13 | import pickle as pkl
14 | import subprocess
15 | from itertools import chain
16 |
17 |
18 |
19 | threadLock = threading.Lock()
20 | TIMEOUT = 60
21 | EXEC_TMP_DIR = 'tmp/'
22 |
23 | def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
24 | assert len(element) == len(perm)
25 | return tuple([element[i] for i in perm])
26 |
27 |
28 | def unorder_row(row: Tuple) -> Tuple:
29 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))
30 |
31 |
32 | # unorder each row in the table
33 | # [result_1 and result_2 has the same bag of unordered row]
34 | # is a necessary condition of
35 | # [result_1 and result_2 are equivalent in denotation]
36 | def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
37 | s1 = [unorder_row(row) for row in result1]
38 | s2 = [unorder_row(row) for row in result2]
39 | if order_matters:
40 | return s1 == s2
41 | else:
42 | return set(s1) == set(s2)
43 |
44 |
45 | # return whether two bag of relations are equivalent
46 | def multiset_eq(l1: List, l2: List) -> bool:
47 | if len(l1) != len(l2):
48 | return False
49 | d = defaultdict(int)
50 | for e in l1:
51 | d[e] = d[e] + 1
52 | for e in l2:
53 | d[e] = d[e] - 1
54 | if d[e] < 0:
55 | return False
56 | return True
57 |
58 |
59 | def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
60 | num_cols = len(result2[0])
61 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
62 | if num_cols <= 3:
63 | return product(*perm_constraints)
64 |
65 | # we sample 20 rows and constrain the space of permutations
66 | for _ in range(20):
67 | random_tab2_row = random.choice(result2)
68 |
69 | for tab1_col in range(num_cols):
70 | for tab2_col in set(perm_constraints[tab1_col]):
71 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
72 | perm_constraints[tab1_col].remove(tab2_col)
73 | return product(*perm_constraints)
74 |
75 |
76 | # check whether two denotations are correct
77 | def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
78 | if len(result1) == 0 and len(result2) == 0:
79 | return True
80 |
81 | # if the lengths are not the same, then they are definitely different bags of rows
82 | if len(result1) != len(result2):
83 | return False
84 |
85 | num_cols = len(result1[0])
86 |
87 | # if the results do not have the same number of columns, they are different
88 | if len(result2[0]) != num_cols:
89 | return False
90 |
91 | # unorder each row and compare whether the denotation is the same
92 | # this can already find most pairs of denotations that are different
93 | if not quick_rej(result1, result2, order_matters):
94 | return False
95 |
96 | # the rest of the problem is in fact more complicated than one might think
97 | # we want to find a permutation of column order and a permutation of row order,
98 | # s.t. result_1 is the same as result_2
99 | # we return true if we can find such column & row permutations
100 | # and false if we cannot
101 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]
102 |
103 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2
104 | # we decrease the size of the column permutation space by the function get_constraint_permutation
105 | # if one of the permutations makes result_1 and result_2 equivalent, then they are equivalent
106 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
107 | if len(perm) != len(set(perm)):
108 | continue
109 | if num_cols == 1:
110 | result2_perm = result2
111 | else:
112 | result2_perm = [permute_tuple(element, perm) for element in result2]
113 | if order_matters:
114 | if result1 == result2_perm:
115 | return True
116 | else:
117 | # in fact the first condition must hold if the second condition holds
118 | # but the first is way more efficient implementation-wise
119 | # and we use it to quickly reject impossible candidates
120 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
121 | return True
122 | return False
123 |
124 |
125 | def replace_cur_year(query: str) -> str:
126 | return re.sub(
127 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
128 | )
129 |
130 |
131 | # get the database cursor for a sqlite database path
132 | def get_cursor_from_path(sqlite_path: str):
133 | try:
134 | if not os.path.exists(sqlite_path):
135 | print("Openning a new connection %s" % sqlite_path)
136 | connection = sqlite3.connect(sqlite_path)
137 | except Exception as e:
138 | print(sqlite_path)
139 | raise e
140 | connection.text_factory = lambda b: b.decode(errors="ignore")
141 | cursor = connection.cursor()
142 | return cursor
143 |
144 |
145 | async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
146 | query = replace_cur_year(query)
147 | cursor = get_cursor_from_path(sqlite_path)
148 | try:
149 | cursor.execute(query)
150 | result = cursor.fetchall()
151 | cursor.close()
152 | cursor.connection.close()
153 | return "result", result
154 | except Exception as e:
155 | cursor.close()
156 | cursor.connection.close()
157 | return "exception", e
158 |
159 | async def exec_on_db(
160 | sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT
161 | ) -> Tuple[str, Any]:
162 | try:
163 | return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout)
164 | except asyncio.TimeoutError:
165 | return ('exception', TimeoutError)
166 | except Exception as e:
167 | return ("exception", e)
168 |
169 |
170 | # postprocess the model predictions to avoid execution errors
171 | # e.g. removing spaces between ">" and "="
172 | def postprocess(query: str) -> str:
173 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=')
174 | return query
175 |
176 |
177 | # approximate whether p_str and g_str are semantically equivalent
178 | # db is the database path
179 | # we are going to evaluate whether they are equivalent in all the databases
180 | # that are in the same directory as db
181 | # 1 if denotationally equivalent
182 | # 0 otherwise
183 | # the meaning of each auxiliary argument can be seen in the parser definition in evaluation.py
184 | def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int:
185 | # post-process the prediction.
186 | # e.g. removing spaces between ">" and "="
187 | p_str, g_str = postprocess(p_str), postprocess(g_str)
188 | if not keep_distinct:
189 | p_str = remove_distinct(p_str)
190 | g_str = remove_distinct(g_str)
191 |
192 | # we decide whether two denotations are equivalent based on "bag semantics"
193 | # https://courses.cs.washington.edu/courses/cse444/10sp/lectures/lecture16.pdf
194 | # if there is order by in query, then we assume order of the rows matter
195 | # order by might also be used to find the max/min instead of sorting,
196 | # but in that case the result mostly only contains one row and hence order_matters does not make a difference
197 | order_matters = 'order by' in g_str.lower()
198 |
199 | # find all databases in the same directory
200 | db_dir = os.path.dirname(db)
201 | db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if '.sqlite' in basename]
202 |
203 | preds = [p_str]
204 | # if plug_value is set (i.e. we do not consider value prediction correctness),
205 | # enumerate all ways to plug the values from the gold query into the model prediction
206 | # otherwise, we only evaluate the predicted query with its own value prediction
207 | if plug_value:
208 | _, preds = get_all_preds_for_execution(g_str, p_str)
209 | # we did not add this line in our EMNLP work
210 | # this reduces "false negatives" when value is substituted
211 | preds = chain([p_str], preds)
212 |
213 | for pred in preds:
214 |
215 | pred_passes = 1
216 | # compare the gold and predicted denotations on each database in the directory
217 | # wrap with progress bar if required
218 | if progress_bar_for_each_datapoint:
219 | ranger = tqdm.tqdm(db_paths)
220 | else:
221 | ranger = db_paths
222 |
223 | for db_path in ranger:
224 | g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str))
225 | p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred))
226 |
227 | # we should expect the gold to be successfully executed on the database
228 | assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path)
229 |
230 | # wrong if execution fails
231 | if p_flag == 'exception':
232 | pred_passes = 0
233 |
234 | # if denotations are not equivalent, the prediction must be wrong
235 | elif not result_eq(g_denotation, p_denotation, order_matters=order_matters):
236 | pred_passes = 0
237 | if pred_passes == 0:
238 | break
239 |
240 | # the model prediction has the same denotation as the gold for all databases
241 | if pred_passes == 1:
242 | return 1
243 |
244 | # none of the predictions passed
245 | return 0
246 |
--------------------------------------------------------------------------------
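For reference, here is a minimal sketch of how `result_eq` above treats execution results as bags of rows whose columns may be permuted. The toy rows are illustrative, and the snippet assumes it is run from the `evaluation/` directory so that `parse.py` (imported by `exec_eval.py`) is on the path:

```python
# toy check of result_eq from evaluation/exec_eval.py (illustrative values)
from exec_eval import result_eq

a = [(1, 'x'), (2, 'y')]
b = [('y', 2), ('x', 1)]  # same rows, with columns swapped and rows reordered

print(result_eq(a, b, order_matters=False))  # True: equivalent under bag semantics
print(result_eq(a, b, order_matters=True))   # False: row order differs
```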
/evaluation/parse.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sqlparse
3 | from typing import List, Tuple, Set, Iterator, Dict, Any, Union
4 | from sqlparse.sql import Comparison, Identifier
5 | from sqlparse.tokens import Whitespace
6 | import itertools
7 | from collections import namedtuple
8 |
9 | Token = namedtuple('Token', ['ttype', 'value'])
10 | VALUE_NUM_SYMBOL = 'VALUERARE'
11 | QUOTE_CHARS = {'`', '\'', '"'}
12 |
13 |
14 | def tokenize(query: str) -> List[Token]:
15 | tokens = list([Token(t.ttype, t.value) for t in sqlparse.parse(query)[0].flatten()])
16 | return tokens
17 |
18 |
19 | def join_tokens(tokens: List[Token]) -> str:
20 | return ''.join([x.value for x in tokens]).strip().replace('  ', ' ')
21 |
22 |
23 | def round_trip_test(query: str) -> None:
24 | tokens = tokenize(query)
25 | reconstructed = ''.join([token.value for token in tokens])
26 | assert query == reconstructed, "Round trip test fails for string %s" % query
27 |
28 |
29 | def postprocess(query: str) -> str:
30 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=')
31 | return query
32 |
33 |
34 | # strip_query, reformat_query and replace_values
35 | # were implemented by Yu Tao for processing CoSQL
36 | def strip_query(query: str) -> Tuple[List[str], List[str]]:
37 | query_keywords, all_values = [], []
38 |
39 | # replace everything enclosed in quotes with VALUE_NUM_SYMBOL so it is marked as a value
40 |
41 | # Tao's implementation is commented out here.
42 | """
43 | str_1 = re.findall("\"[^\"]*\"", query)
44 | str_2 = re.findall("\'[^\']*\'", query)
45 | values = str_1 + str_2
46 | """
47 |
48 | toks = sqlparse.parse(query)[0].flatten()
49 | values = [t.value for t in toks if t.ttype == sqlparse.tokens.Literal.String.Single or t.ttype == sqlparse.tokens.Literal.String.Symbol]
50 |
51 |
52 | for val in values:
53 | all_values.append(val)
54 | query = query.replace(val.strip(), VALUE_NUM_SYMBOL)
55 |
56 | query_tokenized = query.split()
57 | float_nums = re.findall("[-+]?\d*\.\d+", query)
58 | all_values += [qt for qt in query_tokenized if qt in float_nums]
59 | query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized]
60 |
61 | query = " ".join(query_tokenized)
62 | int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)]
63 |
64 | all_values += [qt for qt in query_tokenized if qt in int_nums]
65 | query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized]
66 | # print int_nums, query, query_tokenized
67 |
68 | for tok in query_tokenized:
69 | if "." in tok:
70 | table = re.findall("[Tt]\d+\.", tok)
71 | if len(table) > 0:
72 | to = tok.replace(".", " . ").split()
73 | to = [t.lower() for t in to if len(t) > 0]
74 | query_keywords.extend(to)
75 | else:
76 | query_keywords.append(tok.lower())
77 |
78 | elif len(tok) > 0:
79 | query_keywords.append(tok.lower())
80 | return query_keywords, all_values
81 |
82 |
83 | def reformat_query(query: str) -> str:
84 | query = query.strip().replace(";", "").replace("\t", "")
85 | query = ' '.join([t.value for t in tokenize(query) if t.ttype != sqlparse.tokens.Whitespace])
86 | t_stars = ["t1.*", "t2.*", "t3.*", "T1.*", "T2.*", "T3.*"]
87 | for ts in t_stars:
88 | query = query.replace(ts, "*")
89 | return query
90 |
91 |
92 | def replace_values(sql: str) -> Tuple[List[str], Set[str]]:
93 | sql = sqlparse.format(sql, reindent=False, keyword_case='upper')
94 | # sql = re.sub(r"(<=|>=|!=|=|<|>|,)", r" \1 ", sql)
95 | sql = re.sub(r"(T\d+\.)\s", r"\1", sql)
96 | query_toks_no_value, values = strip_query(sql)
97 | return query_toks_no_value, set(values)
98 |
99 |
100 | # extract the non-value tokens and the set of values
101 | # from a sql query
102 | def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]:
103 | reformated = reformat_query(query=sql)
104 | query_value_replaced, values = replace_values(reformated)
105 | return query_value_replaced, values
106 |
107 |
108 | # plug in the values into query with value slots
109 | def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str:
110 | q_length = len(query_value_replaced)
111 | query_w_values = query_value_replaced[:]
112 | value_idx = [idx for idx in range(q_length) if query_value_replaced[idx] == VALUE_NUM_SYMBOL.lower()]
113 | assert len(value_idx) == len(values_in_order)
114 |
115 | for idx, value in zip(value_idx, values_in_order):
116 | query_w_values[idx] = value
117 | return ' '.join(query_w_values)
118 |
119 |
120 | # a generator generating all possible ways of
121 | # filling values into predicted query
122 | def plugin_all_permutations(query_value_replaced: List[str], values: Set[str]) -> Iterator[str]:
123 | num_slots = len([v for v in query_value_replaced if v == VALUE_NUM_SYMBOL.lower()])
124 | for values in itertools.product(*[list(values) for _ in range(num_slots)]):
125 | yield plugin(query_value_replaced, list(values))
126 |
127 |
128 | # given the gold query and the model prediction
129 | # extract values from the gold, extract predicted sql with value slots
130 | # return 1) number of possible ways to plug in gold values and 2) an iterator of predictions with value plugged in
131 | def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]:
132 | _, gold_values = extract_query_values(gold)
133 | pred_query_value_replaced, _ = extract_query_values(pred)
134 | num_slots = len([v for v in pred_query_value_replaced if v == VALUE_NUM_SYMBOL.lower()])
135 | num_alternatives = len(gold_values) ** num_slots
136 | return num_alternatives, plugin_all_permutations(pred_query_value_replaced, gold_values)
137 |
138 |
139 | def remove_distinct(s):
140 | toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())]
141 | return ''.join([t for t in toks if t.lower() != 'distinct'])
142 |
143 |
144 | def extract_all_comparison_from_node(node: Token) -> List[Comparison]:
145 | comparison_list = []
146 | if hasattr(node, 'tokens'):
147 | for t in node.tokens:
148 | comparison_list.extend(extract_all_comparison_from_node(t))
149 | if type(node) == Comparison:
150 | comparison_list.append(node)
151 | return comparison_list
152 |
153 |
154 | def extract_all_comparison(query: str) -> List[Comparison]:
155 | tree = sqlparse.parse(query)[0]
156 | comparison_list = extract_all_comparison_from_node(tree)
157 | return comparison_list
158 |
159 |
160 | def extract_toks_from_comparison(comparison_node: Comparison) -> List[Token]:
161 | tokens = [t for t in comparison_node.tokens if t.ttype != Whitespace]
162 | return tokens
163 |
164 |
165 | def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]:
166 | tokens = extract_toks_from_comparison(comparison_node)
167 | left, op, right = tokens
168 |
169 | returned_dict = {
170 | 'left': left,
171 | 'op': op.value,
172 | 'right': right
173 | }
174 |
175 | if type(left) != Identifier:
176 | return returned_dict
177 |
178 | table = None
179 | if len(left.tokens) == 3 and re.match('^[tT][0-9]$', left.tokens[0].value) is None:
180 | table = left.tokens[0].value.lower()
181 | col = left.tokens[-1].value
182 |
183 | if type(right) == Identifier:
184 | if len(right.tokens) == 1 and type(right.tokens[0]) == sqlparse.sql.Token:
185 | right_val = right.tokens[0].value
186 | else:
187 | return returned_dict
188 | elif type(right) == sqlparse.sql.Token:
189 | right_val = right.value
190 | else:
191 | return returned_dict
192 |
193 | returned_dict['table_col'], returned_dict['val'] = (table, col.upper()), process_str_value(right_val)
194 |
195 | return returned_dict
196 |
197 |
198 | def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]:
199 | comparison_list = extract_all_comparison(query)
200 | return [extract_info_from_comparison(c) for c in comparison_list]
201 |
202 |
203 | def extract_typed_value_in_comparison_from_query(query: str) -> List[Tuple[Tuple[Union[str, None], str], str]]:
204 | cmps = extract_all_comparison_from_query(query)
205 | typed_values = [(cmp['table_col'], cmp['val']) for cmp in cmps if 'table_col' in cmp]
206 | for table, col, val1, val2 in re.findall('(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)', query, re.IGNORECASE):
207 | if table == '':
208 | table = None
209 | else:
210 | table = table.lower()
211 | col = col.upper()
212 | for v in [val1, val2]:
213 | typed_values.append(((table, col), v))
214 | return typed_values
215 |
216 |
217 | def process_str_value(v: str) -> str:
218 | if len(v) > 0 and v[0] in QUOTE_CHARS:
219 | v = v[1:]
220 | if len(v) > 0 and v[-1] in QUOTE_CHARS:
221 | v = v[:-1]
222 | for c in QUOTE_CHARS:
223 | v = v.replace(c + c, c)
224 | return v
225 |
--------------------------------------------------------------------------------
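A small sketch of the value-plugging helpers above; the queries are illustrative, and the snippet assumes it is run from the `evaluation/` directory:

```python
# illustrative use of the value-plugging helpers in evaluation/parse.py
from parse import extract_query_values, get_all_preds_for_execution

gold = "SELECT name FROM singer WHERE age > 30"
pred = "SELECT name FROM singer WHERE age > 25"

toks, values = extract_query_values(pred)
print(values)  # values extracted from the prediction, e.g. {'25'}

n, plugged = get_all_preds_for_execution(gold, pred)
print(n)              # number of ways to plug gold values into the prediction
print(list(plugged))  # e.g. ['select name from singer where age > 30']
```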
/evaluation/process_sql.py:
--------------------------------------------------------------------------------
1 | ################################
2 | # Assumptions:
3 | # 1. sql is correct
4 | # 2. only table name has alias
5 | # 3. only one intersect/union/except
6 | #
7 | # val: number(float)/string(str)/sql(dict)
8 | # col_unit: (agg_id, col_id, isDistinct(bool))
9 | # val_unit: (unit_op, col_unit1, col_unit2)
10 | # table_unit: (table_type, col_unit/sql)
11 | # cond_unit: (not_op, op_id, val_unit, val1, val2)
12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
13 | # sql {
14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
16 | # 'where': condition
17 | # 'groupBy': [col_unit1, col_unit2, ...]
18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
19 | # 'having': condition
20 | # 'limit': None/limit value
21 | # 'intersect': None/sql
22 | # 'except': None/sql
23 | # 'union': None/sql
24 | # }
25 | ################################
26 |
27 | import json
28 | import sqlite3
29 | from nltk import word_tokenize
30 |
31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
32 | JOIN_KEYWORDS = ('join', 'on', 'as')
33 |
34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
35 | UNIT_OPS = ('none', '-', '+', "*", '/')
36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
37 | TABLE_TYPE = {
38 | 'sql': "sql",
39 | 'table_unit': "table_unit",
40 | }
41 |
42 | COND_OPS = ('and', 'or')
43 | SQL_OPS = ('intersect', 'union', 'except')
44 | ORDER_OPS = ('desc', 'asc')
45 |
46 |
47 |
48 | class Schema:
49 | """
50 | Simple schema which maps table&column to a unique identifier
51 | """
52 | def __init__(self, schema):
53 | self._schema = schema
54 | self._idMap = self._map(self._schema)
55 |
56 | @property
57 | def schema(self):
58 | return self._schema
59 |
60 | @property
61 | def idMap(self):
62 | return self._idMap
63 |
64 | def _map(self, schema):
65 | idMap = {'*': "__all__"}
66 | id = 1
67 | for key, vals in schema.items():
68 | for val in vals:
69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
70 | id += 1
71 |
72 | for key in schema:
73 | idMap[key.lower()] = "__" + key.lower() + "__"
74 | id += 1
75 |
76 | return idMap
77 |
78 |
79 | def get_schema(db):
80 | """
81 | Get database's schema, which is a dict with table name as key
82 | and list of column names as value
83 | :param db: database path
84 | :return: schema dict
85 | """
86 |
87 | schema = {}
88 | conn = sqlite3.connect(db)
89 | cursor = conn.cursor()
90 |
91 | # fetch table names
92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
93 | tables = [str(table[0].lower()) for table in cursor.fetchall()]
94 |
95 | # fetch table info
96 | for table in tables:
97 | cursor.execute("PRAGMA table_info({})".format(table))
98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
99 |
100 | return schema
101 |
102 |
103 | def get_schema_from_json(fpath):
104 | with open(fpath) as f:
105 | data = json.load(f)
106 |
107 | schema = {}
108 | for entry in data:
109 | table = str(entry['table'].lower())
110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']]
111 | schema[table] = cols
112 |
113 | return schema
114 |
115 |
116 | def tokenize(string):
117 | string = str(string)
118 | string = string.replace("\'", "\"") # ensure all string values are wrapped in "" (this may be problematic)
119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote"
121 |
122 | # keep string value as token
123 | vals = {}
124 | for i in range(len(quote_idxs)-1, -1, -2):
125 | qidx1 = quote_idxs[i-1]
126 | qidx2 = quote_idxs[i]
127 | val = string[qidx1: qidx2+1]
128 | key = "__val_{}_{}__".format(qidx1, qidx2)
129 | string = string[:qidx1] + key + string[qidx2+1:]
130 | vals[key] = val
131 |
132 | toks = [word.lower() for word in word_tokenize(string)]
133 | # replace with string value token
134 | for i in range(len(toks)):
135 | if toks[i] in vals:
136 | toks[i] = vals[toks[i]]
137 |
138 | # find if there exists !=, >=, <=
139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
140 | eq_idxs.reverse()
141 | prefix = ('!', '>', '<')
142 | for eq_idx in eq_idxs:
143 | pre_tok = toks[eq_idx-1]
144 | if pre_tok in prefix:
145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
146 |
147 | return toks
148 |
149 |
150 | def scan_alias(toks):
151 | """Scan the index of 'as' and build the map for all alias"""
152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
153 | alias = {}
154 | for idx in as_idxs:
155 | alias[toks[idx+1]] = toks[idx-1]
156 | return alias
157 |
158 |
159 | def get_tables_with_alias(schema, toks):
160 | tables = scan_alias(toks)
161 | for key in schema:
162 | assert key not in tables, "Alias {} has the same name as a table".format(key)
163 | tables[key] = key
164 | return tables
165 |
166 |
167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
168 | """
169 | :returns next idx, column id
170 | """
171 | tok = toks[start_idx]
172 | if tok == "*":
173 | return start_idx + 1, schema.idMap[tok]
174 |
175 | if '.' in tok: # if token is a composite
176 | alias, col = tok.split('.')
177 | key = tables_with_alias[alias] + "." + col
178 | return start_idx+1, schema.idMap[key]
179 |
180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
181 |
182 | for alias in default_tables:
183 | table = tables_with_alias[alias]
184 | if tok in schema.schema[table]:
185 | key = table + "." + tok
186 | return start_idx+1, schema.idMap[key]
187 |
188 | assert False, "Error col: {}".format(tok)
189 |
190 |
191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
192 | """
193 | :returns next idx, (agg_op id, col_id)
194 | """
195 | idx = start_idx
196 | len_ = len(toks)
197 | isBlock = False
198 | isDistinct = False
199 | if toks[idx] == '(':
200 | isBlock = True
201 | idx += 1
202 |
203 | if toks[idx] in AGG_OPS:
204 | agg_id = AGG_OPS.index(toks[idx])
205 | idx += 1
206 | assert idx < len_ and toks[idx] == '('
207 | idx += 1
208 | if toks[idx] == "distinct":
209 | idx += 1
210 | isDistinct = True
211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
212 | assert idx < len_ and toks[idx] == ')'
213 | idx += 1
214 | return idx, (agg_id, col_id, isDistinct)
215 |
216 | if toks[idx] == "distinct":
217 | idx += 1
218 | isDistinct = True
219 | agg_id = AGG_OPS.index("none")
220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
221 |
222 | if isBlock:
223 | assert toks[idx] == ')'
224 | idx += 1 # skip ')'
225 |
226 | return idx, (agg_id, col_id, isDistinct)
227 |
228 |
229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
230 | idx = start_idx
231 | len_ = len(toks)
232 | isBlock = False
233 | if toks[idx] == '(':
234 | isBlock = True
235 | idx += 1
236 |
237 | col_unit1 = None
238 | col_unit2 = None
239 | unit_op = UNIT_OPS.index('none')
240 |
241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
242 | if idx < len_ and toks[idx] in UNIT_OPS:
243 | unit_op = UNIT_OPS.index(toks[idx])
244 | idx += 1
245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
246 |
247 | if isBlock:
248 | assert toks[idx] == ')'
249 | idx += 1 # skip ')'
250 |
251 | return idx, (unit_op, col_unit1, col_unit2)
252 |
253 |
254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema):
255 | """
256 | :returns next idx, table id, table name
257 | """
258 | idx = start_idx
259 | len_ = len(toks)
260 | key = tables_with_alias[toks[idx]]
261 |
262 | if idx + 1 < len_ and toks[idx+1] == "as":
263 | idx += 3
264 | else:
265 | idx += 1
266 |
267 | return idx, schema.idMap[key], key
268 |
269 |
270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
271 | idx = start_idx
272 | len_ = len(toks)
273 |
274 | isBlock = False
275 | if toks[idx] == '(':
276 | isBlock = True
277 | idx += 1
278 |
279 | if toks[idx] == 'select':
280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema)
281 | elif "\"" in toks[idx]: # token is a string value
282 | val = toks[idx]
283 | idx += 1
284 | else:
285 | try:
286 | val = float(toks[idx])
287 | idx += 1
288 | except:
289 | end_idx = idx
290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
292 | end_idx += 1
293 |
294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
295 | idx = end_idx
296 |
297 | if isBlock:
298 | assert toks[idx] == ')'
299 | idx += 1
300 |
301 | return idx, val
302 |
303 |
304 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
305 | idx = start_idx
306 | len_ = len(toks)
307 | conds = []
308 |
309 | while idx < len_:
310 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
311 | not_op = False
312 | if toks[idx] == 'not':
313 | not_op = True
314 | idx += 1
315 |
316 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
317 | op_id = WHERE_OPS.index(toks[idx])
318 | idx += 1
319 | val1 = val2 = None
320 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
321 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
322 | assert toks[idx] == 'and'
323 | idx += 1
324 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
325 | else: # normal case: single value
326 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
327 | val2 = None
328 |
329 | conds.append((not_op, op_id, val_unit, val1, val2))
330 |
331 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
332 | break
333 |
334 | if idx < len_ and toks[idx] in COND_OPS:
335 | conds.append(toks[idx])
336 | idx += 1 # skip and/or
337 |
338 | return idx, conds
339 |
340 |
341 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
342 | idx = start_idx
343 | len_ = len(toks)
344 |
345 | assert toks[idx] == 'select', "'select' not found"
346 | idx += 1
347 | isDistinct = False
348 | if idx < len_ and toks[idx] == 'distinct':
349 | idx += 1
350 | isDistinct = True
351 | val_units = []
352 |
353 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
354 | agg_id = AGG_OPS.index("none")
355 | if toks[idx] in AGG_OPS:
356 | agg_id = AGG_OPS.index(toks[idx])
357 | idx += 1
358 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
359 | val_units.append((agg_id, val_unit))
360 | if idx < len_ and toks[idx] == ',':
361 | idx += 1 # skip ','
362 |
363 | return idx, (isDistinct, val_units)
364 |
365 |
366 | def parse_from(toks, start_idx, tables_with_alias, schema):
367 | """
368 | Assume in the from clause, all table units are combined with join
369 | """
370 | assert 'from' in toks[start_idx:], "'from' not found"
371 |
372 | len_ = len(toks)
373 | idx = toks.index('from', start_idx) + 1
374 | default_tables = []
375 | table_units = []
376 | conds = []
377 |
378 | while idx < len_:
379 | isBlock = False
380 | if toks[idx] == '(':
381 | isBlock = True
382 | idx += 1
383 |
384 | if toks[idx] == 'select':
385 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
386 | table_units.append((TABLE_TYPE['sql'], sql))
387 | else:
388 | if idx < len_ and toks[idx] == 'join':
389 | idx += 1 # skip join
390 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
391 | table_units.append((TABLE_TYPE['table_unit'],table_unit))
392 | default_tables.append(table_name)
393 | if idx < len_ and toks[idx] == "on":
394 | idx += 1 # skip on
395 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
396 | if len(conds) > 0:
397 | conds.append('and')
398 | conds.extend(this_conds)
399 |
400 | if isBlock:
401 | assert toks[idx] == ')'
402 | idx += 1
403 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
404 | break
405 |
406 | return idx, table_units, conds, default_tables
407 |
408 |
409 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
410 | idx = start_idx
411 | len_ = len(toks)
412 |
413 | if idx >= len_ or toks[idx] != 'where':
414 | return idx, []
415 |
416 | idx += 1
417 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
418 | return idx, conds
419 |
420 |
421 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
422 | idx = start_idx
423 | len_ = len(toks)
424 | col_units = []
425 |
426 | if idx >= len_ or toks[idx] != 'group':
427 | return idx, col_units
428 |
429 | idx += 1
430 | assert toks[idx] == 'by'
431 | idx += 1
432 |
433 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
434 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
435 | col_units.append(col_unit)
436 | if idx < len_ and toks[idx] == ',':
437 | idx += 1 # skip ','
438 | else:
439 | break
440 |
441 | return idx, col_units
442 |
443 |
444 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
445 | idx = start_idx
446 | len_ = len(toks)
447 | val_units = []
448 | order_type = 'asc' # default type is 'asc'
449 |
450 | if idx >= len_ or toks[idx] != 'order':
451 | return idx, val_units
452 |
453 | idx += 1
454 | assert toks[idx] == 'by'
455 | idx += 1
456 |
457 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
458 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
459 | val_units.append(val_unit)
460 | if idx < len_ and toks[idx] in ORDER_OPS:
461 | order_type = toks[idx]
462 | idx += 1
463 | if idx < len_ and toks[idx] == ',':
464 | idx += 1 # skip ','
465 | else:
466 | break
467 |
468 | return idx, (order_type, val_units)
469 |
470 |
471 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
472 | idx = start_idx
473 | len_ = len(toks)
474 |
475 | if idx >= len_ or toks[idx] != 'having':
476 | return idx, []
477 |
478 | idx += 1
479 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
480 | return idx, conds
481 |
482 |
483 | def parse_limit(toks, start_idx):
484 | idx = start_idx
485 | len_ = len(toks)
486 |
487 | if idx < len_ and toks[idx] == 'limit':
488 | idx += 2
489 | # make the limit value work; we cannot assume 1 as a fake limit number
490 | if type(toks[idx-1]) != int:
491 | return idx, 1
492 |
493 | return idx, int(toks[idx-1])
494 |
495 | return idx, None
496 |
497 |
498 | def parse_sql(toks, start_idx, tables_with_alias, schema):
499 | isBlock = False # indicate whether this is a block of sql/sub-sql
500 | len_ = len(toks)
501 | idx = start_idx
502 |
503 | sql = {}
504 | if toks[idx] == '(':
505 | isBlock = True
506 | idx += 1
507 |
508 | # parse from clause in order to get default tables
509 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
510 | sql['from'] = {'table_units': table_units, 'conds': conds}
511 | # select clause
512 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
513 | idx = from_end_idx
514 | sql['select'] = select_col_units
515 | # where clause
516 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
517 | sql['where'] = where_conds
518 | # group by clause
519 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
520 | sql['groupBy'] = group_col_units
521 | # having clause
522 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
523 | sql['having'] = having_conds
524 | # order by clause
525 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
526 | sql['orderBy'] = order_col_units
527 | # limit clause
528 | idx, limit_val = parse_limit(toks, idx)
529 | sql['limit'] = limit_val
530 |
531 | idx = skip_semicolon(toks, idx)
532 | if isBlock:
533 | assert toks[idx] == ')'
534 | idx += 1 # skip ')'
535 | idx = skip_semicolon(toks, idx)
536 |
537 | # intersect/union/except clause
538 | for op in SQL_OPS: # initialize IUE
539 | sql[op] = None
540 | if idx < len_ and toks[idx] in SQL_OPS:
541 | sql_op = toks[idx]
542 | idx += 1
543 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
544 | sql[sql_op] = IUE_sql
545 | return idx, sql
546 |
547 |
548 | def load_data(fpath):
549 | with open(fpath) as f:
550 | data = json.load(f)
551 | return data
552 |
553 |
554 | def get_sql(schema, query):
555 | toks = tokenize(query)
556 | tables_with_alias = get_tables_with_alias(schema.schema, toks)
557 | _, sql = parse_sql(toks, 0, tables_with_alias, schema)
558 |
559 | return sql
560 |
561 |
562 | def skip_semicolon(toks, start_idx):
563 | idx = start_idx
564 | while idx < len(toks) and toks[idx] == ";":
565 | idx += 1
566 | return idx
567 |
--------------------------------------------------------------------------------
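As a quick illustration of the parsed representation documented in the header comment of this file, here is a minimal sketch with a hand-written schema (it assumes the NLTK `punkt` tokenizer is installed, as set up in the environment section, and that the snippet is run from the `evaluation/` directory):

```python
# illustrative use of Schema/get_sql from evaluation/process_sql.py
from process_sql import Schema, get_sql

# a tiny hand-written schema: table name -> column names
schema = Schema({"singer": ["singer_id", "name", "age"]})

parsed = get_sql(schema, "SELECT count(*) FROM singer WHERE age > 30")
print(parsed["select"])  # (False, [(3, (0, (0, '__all__', False), None))])
print(parsed["where"])   # [(False, 3, (0, (0, '__singer.age__', False), None), 30.0, None)]
```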
/evaluation_bird_ex_ves.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | db_root_path="./data/bird/dev/dev_databases/"
4 | data_mode="dev"
5 | diff_json_path="./data/bird/dev/dev.json"
6 | predicted_sql_json_path="./outputs/bird/predict_dev.json"
7 | ground_truth_sql_path="./data/bird/dev/dev_gold.sql"
8 | num_cpus=12
9 | meta_time_out=30.0
10 | time_out=60
11 | mode_gt="gt"
12 | mode_predict="gpt"
13 |
14 | # evaluate EX
15 | echo "Evaluate BIRD EX begin!"
16 | python ./evaluation/evaluation_bird_ex.py --db_root_path $db_root_path \
17 | --predicted_sql_json_path $predicted_sql_json_path \
18 | --data_mode $data_mode \
19 | --ground_truth_sql_path $ground_truth_sql_path \
20 | --num_cpus $num_cpus \
21 | --mode_predict $mode_predict \
22 | --diff_json_path $diff_json_path \
23 | --meta_time_out $meta_time_out
24 | echo "Evaluate EX done!"
25 |
26 | # evaluate VES
27 | echo "Evaluate BIRD VES begin!"
28 | python ./evaluation/evaluation_bird_ves.py \
29 | --db_root_path $db_root_path \
30 | --predicted_sql_json_path $predicted_sql_json_path \
31 | --data_mode $data_mode \
32 | --ground_truth_sql_path $ground_truth_sql_path \
33 | --num_cpus $num_cpus --meta_time_out $time_out \
34 | --mode_gt $mode_gt --mode_predict $mode_predict \
35 | --diff_json_path $diff_json_path
36 | echo "Evaluate VES done!"
--------------------------------------------------------------------------------
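A minimal way to use the script above, assuming the path variables at its top have been adjusted to your local BIRD dev data and prediction file:

```bash
# run EX and VES evaluation for BIRD dev predictions
bash evaluation_bird_ex_ves.sh
```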
/main_scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__init__.py
--------------------------------------------------------------------------------
/main_scripts/__pycache__/MAG.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/MAG.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/__pycache__/bridge_content_encoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/bridge_content_encoder.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/__pycache__/chat_manager.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/chat_manager.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/__pycache__/const.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/const.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/__pycache__/llm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/llm.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotXWX/MAG-SQL/38905841c4622eb7279744fc022080ec2331d3b4/main_scripts/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/main_scripts/bridge_content_encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2020, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 |
7 | Encode DB content.
8 | """
9 |
10 | import difflib
11 | from typing import List, Optional, Tuple
12 | from rapidfuzz import fuzz
13 | import sqlite3
14 | import functools
15 |
16 | # fmt: off
17 | _stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again',
18 | 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but',
19 | "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is',
20 | "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being',
21 | 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be',
22 | 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor',
23 | 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because',
24 | 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have',
25 | 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours',
26 | 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his',
27 | "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what',
28 | 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan',
29 | 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here',
30 | 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself',
31 | 'then', 'did', 'just', "aren't"}
32 | # fmt: on
33 |
34 | _commonwords = {"no", "yes", "many"}
35 |
36 |
37 | def is_number(s: str) -> bool:
38 | try:
39 | float(s.replace(",", ""))
40 | return True
41 | except:
42 | return False
43 |
44 |
45 | def is_stopword(s: str) -> bool:
46 | return s.strip() in _stopwords
47 |
48 |
49 | def is_commonword(s: str) -> bool:
50 | return s.strip() in _commonwords
51 |
52 |
53 | def is_common_db_term(s: str) -> bool:
54 | return s.strip() in ["id"]
55 |
56 |
57 | class Match(object):
58 | def __init__(self, start: int, size: int) -> None:
59 | self.start = start
60 | self.size = size
61 |
62 |
63 | def is_span_separator(c: str) -> bool:
64 | return c in "'\"()`,.?! "
65 |
66 |
67 | def split(s: str) -> List[str]:
68 | return [c.lower() for c in s.strip()]
69 |
70 |
71 | def prefix_match(s1: str, s2: str) -> bool:
72 | i, j = 0, 0
73 | for i in range(len(s1)):
74 | if not is_span_separator(s1[i]):
75 | break
76 | for j in range(len(s2)):
77 | if not is_span_separator(s2[j]):
78 | break
79 | if i < len(s1) and j < len(s2):
80 | return s1[i] == s2[j]
81 | elif i >= len(s1) and j >= len(s2):
82 | return True
83 | else:
84 | return False
85 |
86 |
87 | def get_effective_match_source(s: str, start: int, end: int) -> Match:
88 | _start = -1
89 |
90 | for i in range(start, start - 2, -1):
91 | if i < 0:
92 | _start = i + 1
93 | break
94 | if is_span_separator(s[i]):
95 | _start = i
96 | break
97 |
98 | if _start < 0:
99 | return None
100 |
101 | _end = -1
102 | for i in range(end - 1, end + 3):
103 | if i >= len(s):
104 | _end = i - 1
105 | break
106 | if is_span_separator(s[i]):
107 | _end = i
108 | break
109 |
110 | if _end < 0:
111 | return None
112 |
113 | while _start < len(s) and is_span_separator(s[_start]):
114 | _start += 1
115 | while _end >= 0 and is_span_separator(s[_end]):
116 | _end -= 1
117 |
118 | return Match(_start, _end - _start + 1)
119 |
120 |
121 | def get_matched_entries(
122 | s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85
123 | ) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]:
124 | if not field_values:
125 | return None
126 |
127 | if isinstance(s, str):
128 | n_grams = split(s)
129 | else:
130 | n_grams = s
131 |
132 | matched = dict()
133 | for field_value in field_values:
134 | if not isinstance(field_value, str):
135 | continue
136 | fv_tokens = split(field_value)
137 | sm = difflib.SequenceMatcher(None, n_grams, fv_tokens)
138 | match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens))
139 | if match.size > 0:
140 | source_match = get_effective_match_source(
141 | n_grams, match.a, match.a + match.size
142 | )
143 | if source_match: # and source_match.size > 1
144 | match_str = field_value[match.b : match.b + match.size]
145 | source_match_str = s[
146 | source_match.start : source_match.start + source_match.size
147 | ]
148 | c_match_str = match_str.lower().strip()
149 | c_source_match_str = source_match_str.lower().strip()
150 | c_field_value = field_value.lower().strip()
151 | if c_match_str and not is_common_db_term(c_match_str): # and not is_number(c_match_str)
152 | if (
153 | is_stopword(c_match_str)
154 | or is_stopword(c_source_match_str)
155 | or is_stopword(c_field_value)
156 | ):
157 | continue
158 | if c_source_match_str.endswith(c_match_str + "'s"):
159 | match_score = 1.0
160 | else:
161 | if prefix_match(c_field_value, c_source_match_str):
162 | match_score = fuzz.ratio(c_field_value, c_source_match_str) / 100
163 | else:
164 | match_score = 0
165 | if (
166 | is_commonword(c_match_str)
167 | or is_commonword(c_source_match_str)
168 | or is_commonword(c_field_value)
169 | ) and match_score < 1:
170 | continue
171 | s_match_score = match_score
172 | if match_score >= m_theta and s_match_score >= s_theta:
173 | if field_value.isupper() and match_score * s_match_score < 1:
174 | continue
175 | matched[match_str] = (
176 | field_value,
177 | source_match_str,
178 | match_score,
179 | s_match_score,
180 | match.size,
181 | )
182 |
183 | if not matched:
184 | return None
185 | else:
186 | return sorted(
187 | matched.items(),
188 | key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]),
189 | reverse=True,
190 | )
191 |
192 |
193 | @functools.lru_cache(maxsize=1000, typed=False)
194 | def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list:
195 | fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name)
196 | try:
197 | conn = sqlite3.connect(db_path)
198 | conn.text_factory = bytes
199 | c = conn.cursor()
200 | c.execute(fetch_sql)
201 | picklist = set()
202 | for x in c.fetchall():
203 | if isinstance(x[0], str):
204 | picklist.add(x[0].encode("utf-8"))
205 | elif isinstance(x[0], bytes):
206 | try:
207 | picklist.add(x[0].decode("utf-8"))
208 | except UnicodeDecodeError:
209 | picklist.add(x[0].decode("latin-1"))
210 | else:
211 | picklist.add(x[0])
212 | picklist = list(picklist)
213 | except Exception as e:
214 | picklist = []
215 | finally:
216 | conn.close()
217 | return picklist
218 |
219 |
220 | def get_database_matches(
221 | question: str,
222 | table_name: str,
223 | column_name: str,
224 | db_path: str,
225 | top_k_matches: int = 2,
226 | match_threshold: float = 0.85,
227 | ) -> List[str]:
228 | picklist = get_column_picklist(
229 | table_name=table_name, column_name=column_name, db_path=db_path
230 | )
231 | # only keep values of ``str'' type
232 | picklist = [ele.strip() for ele in picklist if isinstance(ele, str)]
233 | # picklist is unordered; sort it to ensure reproducibility
234 | picklist = sorted(picklist)
235 |
236 | matches = []
237 | if picklist and isinstance(picklist[0], str):
238 | matched_entries = get_matched_entries(
239 | s=question,
240 | field_values=picklist,
241 | m_theta=match_threshold,
242 | s_theta=match_threshold,
243 | )
244 |
245 | if matched_entries:
246 | num_values_inserted = 0
247 | for _match_str, (
248 | field_value,
249 | _s_match_str,
250 | match_score,
251 | s_match_score,
252 | _match_size,
253 | ) in matched_entries:
254 | if "name" in column_name and match_score * s_match_score < 1:
255 | continue
256 | if table_name != "sqlite_sequence": # Spider database artifact
257 | matches.append(field_value.strip())
258 | num_values_inserted += 1
259 | if num_values_inserted >= top_k_matches:
260 | break
261 | return matches
--------------------------------------------------------------------------------
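A sketch of how this module can be used to retrieve the 【Matched values】 that are later fed to the schema linker. The table name, column name, and database path below are illustrative assumptions, not fixed names from the repository:

```python
# illustrative call to get_database_matches from main_scripts/bridge_content_encoder.py
from main_scripts.bridge_content_encoder import get_database_matches

matches = get_database_matches(
    question="List the schools in Alameda County",
    table_name="schools",    # hypothetical table
    column_name="County",    # hypothetical column
    db_path="./data/bird/dev/dev_databases/california_schools/california_schools.sqlite",  # hypothetical path
    top_k_matches=2,
    match_threshold=0.85,
)
print(matches)  # e.g. ['Alameda'] when the value appears in that column
```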
/main_scripts/chat_manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from main_scripts.MAG import Soft_Schema_linker, Decomposer, Generator, Refiner
3 | from main_scripts.const import MAX_ROUND, SYSTEM_NAME, SCHEMALINKER_NAME, DECOMPOSER_NAME, REFINER_NAME
4 |
5 | INIT_LOG__PATH_FUNC = None
6 | LLM_API_FUC = None
7 | try:
8 | from main_scripts import llm
9 | LLM_API_FUC = llm.safe_call_llm
10 | INIT_LOG__PATH_FUNC = llm.init_log_path
11 | print(f"Use func from main_scripts.llm in chat_manager.py")
12 | except:
13 | print("Please prepare the code related to the large language model.")
14 | import time
15 | from pprint import pprint
16 |
17 |
18 | class ChatManager(object):
19 | def __init__(self, data_path: str, tables_json_path: str, log_path: str, model_name: str, dataset_name:str, dataset_path:str, lazy: bool=False, without_selector: bool=False):
20 | self.data_path = data_path # root path to database dir, including all databases
21 | self.tables_json_path = tables_json_path # path to table description json file
22 | self.log_path = log_path # path to record important printed content during running
23 | self.model_name = model_name # name of base LLM called by agent
24 | self.dataset_name = dataset_name
25 | self.ping_network()
26 |
27 | self.chat_group = [
28 | Soft_Schema_linker(data_path=self.data_path, tables_json_path=self.tables_json_path, model_name=self.model_name, dataset_name=dataset_name, dataset_path=dataset_path, lazy=lazy, without_selector=without_selector),
29 | Decomposer(),
30 | Generator(dataset_name=dataset_name),
31 | Refiner(data_path=self.data_path, dataset_name=dataset_name)
32 | ]
33 | INIT_LOG__PATH_FUNC(log_path)
34 |
35 | def ping_network(self):
36 | # check network status
37 | print("Checking network status...", flush=True)
38 | try:
39 | _ = LLM_API_FUC("Hello world!")
40 | print("Network is available", flush=True)
41 | except Exception as e:
42 | raise Exception(f"Network is not available: {e}")
43 |
44 | def _chat_single_round(self, message: dict):
45 | # we use `dict` type so value can be changed in the function
46 | for agent in self.chat_group: # check each agent in the group
47 | if message['send_to'] == agent.name:
48 | agent.talk(message)
49 |
50 | def start(self, user_message: dict):
51 | # we use `dict` type so value can be changed in the function
52 | start_time = time.time()
53 | if user_message['send_to'] == SYSTEM_NAME: # in the first round, pass the message to the schema linker
54 | #user_message['send_to'] = SELECTOR_NAME
55 | user_message['send_to'] = SCHEMALINKER_NAME
56 | for _ in range(MAX_ROUND): # start chat in group
57 | self._chat_single_round(user_message)
58 | if user_message['send_to'] == SYSTEM_NAME: # should terminate chat
59 | break
60 | end_time = time.time()
61 | exec_time = end_time - start_time
62 | print(f"\033[0;34mExecute {exec_time} seconds\033[0m", flush=True)
63 |
--------------------------------------------------------------------------------
/main_scripts/const.py:
--------------------------------------------------------------------------------
1 | MAX_ROUND = 8 # MAX_ROUND + 2 is the maximum number of tries for one task
2 |
3 |
4 | SCHEMALINKER_NAME = 'Soft_Schema_Linker'
5 | DECOMPOSER_NAME = 'Decomposer'
6 | GENERATOR_NAME = 'Generator'
7 | REFINER_NAME = 'Refiner'
8 | SYSTEM_NAME = 'System'
9 |
10 |
11 | schema_linker_template = """
12 | As an experienced and professional database administrator, your task is to analyze a user question and a database schema to provide relevant information. The database schema consists of table descriptions, each containing multiple column descriptions. Your goal is to extract the entities from question and identify the relevant tables and columns based on these entities and the evidence provided.
13 |
14 | [Instruction]:
15 | 1. Extract the mentioned entities from the user question. Make sure all of the entities are extracted.
16 | 2. For each entity, keep at least 3 related columns.
17 | 3. Your output should include entity extraction, analysis and related database schema.
18 | 4. The related database schema should be in JSON format.
19 | 5. Each column's information in provided 【Schema】 is in this format: (column, description. Value examples)
20 |
21 | [Requirements]:
22 | 1. Sort the related columns in each list corresponding to each entity in descending order of relevance.
23 | 2. The chosen columns should be in this format: table_name.column_name.
24 | 3. Make sure each chosen list is not empty. The value [] will be punished.
25 | 4.【Matched values】 may contain redundant or useless information in addition to the correct matching values, so you need to select the useful information in conjunction with the specific column names and descriptions.
26 | 5. An entity may not have a corresponding evidence, which requires you to find the relevant columns yourself through your understanding of the database schema.
27 | 6. There are columns with similar names but different meanings, you can find the most accurate columns in the right table based on the 【Summary of each table】.
28 |
29 | Here is a typical example:
30 |
31 | ==========
32 | 【DB_ID】 banking_system
33 | 【Schema】
34 | # Table: account
35 | [
36 | (account_id , the id of the account. Value examples: [11382, 11362, 2, 1, 2367].),
37 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].),
38 | (frequency , frequency of the acount. Value examples: ['POPLATEK MESICNE', 'POPLATEK TYDNE', 'POPLATEK PO OBRATU'].),
39 | (date , the creation date of the account. Value examples: ['1997-12-29', '1997-12-28'].)
40 | ]
41 | # Table: client
42 | [
43 | (client_id , the unique number. Value examples: [13998, 13971, 2, 1, 2839].),
44 | (gender , gender. Value examples: ['M', 'F']. And F:female . M:male ),
45 | (birth_date , birth date. Value examples: ['1987-09-27', '1986-08-13'].),
46 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].),
47 | (first_name , first_name.),
48 | (last_name , last_name.)
49 | ]
50 | # Table: loan
51 | [
52 | (loan_id , the id number identifying the loan data. Value examples: [4959, 4960, 4961].),
53 | (account_id , the id number identifying the account. Value examples: [10, 80, 55, 43].),
54 | (date , the date when the loan is approved. Value examples: ['1998-07-12', '1998-04-19'].),
55 | (amount , the id number identifying the loan data. Value examples: [1567, 7877, 9988].),
56 | (duration , the id number identifying the loan data. Value examples: [60, 48, 24, 12, 36].),
57 | (payments , the id number identifying the loan data. Value examples: [3456, 8972, 9845].),
58 | (status , the id number identifying the loan data. Value examples: ['C', 'A', 'D', 'B'].)
59 | ]
60 | # Table: district
61 | [
62 | (district_id , location of branch. Value examples: [77, 76].),
63 | (A2 , area in square kilometers. Value examples: [50.5, 48.9].),
64 | (A4 , number of inhabitants. Value examples: [95907, 95616].),
65 | (A5 , number of households. Value examples: [35678, 34892].),
66 | (A6 , literacy rate. Value examples: [95.6, 92.3, 89.7].),
67 | (A7 , number of entrepreneurs. Value examples: [1234, 1456].),
68 | (A8 , number of cities. Value examples: [5, 4].),
69 | (A9 , number of schools. Value examples: [15, 12, 10].),
70 | (A10 , number of hospitals. Value examples: [8, 6, 4].),
71 | (A11 , average salary. Value examples: [12541.5, 11277].),
72 | (A12 , poverty rate. Value examples: [12.4, 9.8].),
73 | (A13 , unemployment rate. Value examples: [8.2, 7.9].),
74 | (A15 , number of crimes. Value examples: [256, 189].)
75 | ]
76 | 【Primary keys】
77 | account.`account_id` | client.`client_id` | loan.`loan_id` | district.`district_id`
78 | 【Foreign keys】
79 | client.`district_id` = district.`district_id`
80 | 【Question】
81 | What is the gender of the youngest client who opened account in the lowest average salary branch and when did this client open the account? Please list their full name.
82 | 【Evidence】
83 | Later birthdate refers to younger age; A11 refers to average salary; Full name refers to first_name, last_name
84 | 【Matched values】
85 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you in selecting columns (You need to ignore matches that are not relevant to the question):
86 | No matched values.
87 | 【Summary of each table】
88 | # account: Specific information for each account
89 | # client: Basic information about each client
90 | # loan: Detailed records of each loan
91 | # district: Various data recorded in each district
92 | 【Answer】
93 | The entities extracted from the 【Question】 are:
94 | 1. gender;
95 | 2. youngest client;
96 | 3. account;
97 | 4. lowest average salary branch;
98 | 5. when did this client open the account;
99 | 6. full name
100 |
101 | Extract the related evidence in 【Evidence】 for entities if 【Evidence】 is not None:
102 | 1. gender --no related evidence **the columns about gender are related**
103 | 2. youngest client --Later birthdate refers to younger age; **the columns about birth date of client and the id of client are related**
104 | 3. account --no related evidence **the columns about account or account ids are related**
105 | 4. lowest average salary branch --A11 refers to average salary **the columns about average salary and branch are related**
106 | 5. when did this client open the account --no related evidence **the columns about time or date of the account are related**
107 | 6. full name --Full name refers to first_name, last_name **the columns about first_name, last_name of the client are related**
108 |
109 | Therefore, we can select the related database schema based on these entities with 【Evidence】:
110 | ```json
111 | {{
112 | "gender": ["client.gender","client.client_id","loan.status"],
113 | "youngest client": ["client.birth_date","client.client_id","account.date","loan.date"],
114 | "account": ["account.account_id","loan.account_id","account.date"],
115 | "lowest average salary branch": ["district.A11","district.district_id","district.A13"],
116 | "when did this client open the account": ["account.date","loan.date","client.birth_date"],
117 | "full name": ["client.first_name","client.last_name","client.client_id"]
118 | }}
119 | ```
120 | Question Solved.
121 |
122 | ==========
123 |
124 | Here is a new example, please start answering:
125 |
126 | 【DB_ID】 {db_id}
127 | 【Schema】
128 | {desc_str}
129 | 【Primary keys】
130 | {pk_str}
131 | 【Foreign keys】
132 | {fk_str}
133 | 【Question】
134 | {query}
135 | 【Evidence】
136 | {evidence}
137 | 【Matched values】
138 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you in selecting columns (You need to ignore matches that are not relevant to the question):
139 | {matched_list}
140 | 【Summary of each table】
141 | {summary_str}
142 | 【Answer】
143 | """
144 |
145 | summarizer_template = """
146 | [instruction]
147 | Given the database schema, you need to summarise the data stored in each table in one sentence, based on the name of the table and the columns in the table.
148 |
149 | [Requirements]
150 | - Your output should be in json format
151 |
152 | Here is an example:
153 | ==========
154 |
155 | 【DB_ID】 banking_system
156 | 【Schema】
157 | # Table: account
158 | [
159 | (account_id , the id of the account. Value examples: [11382, 11362, 2, 1, 2367].),
160 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].),
161 | (frequency , frequency of the account. Value examples: ['POPLATEK MESICNE', 'POPLATEK TYDNE', 'POPLATEK PO OBRATU'].),
162 | (date , the creation date of the account. Value examples: ['1997-12-29', '1997-12-28'].)
163 | ]
164 | # Table: client
165 | [
166 | (client_id , the unique number. Value examples: [13998, 13971, 2, 1, 2839].),
167 | (gender , gender. Value examples: ['M', 'F']. And F:female . M:male ),
168 | (birth_date , birth date. Value examples: ['1987-09-27', '1986-08-13'].),
169 | (district_id , location of branch. Value examples: [77, 76, 2, 1, 39].),
170 | (first_name , first_name.),
171 | (last_name , last_name.)
172 | ]
173 | # Table: loan
174 | [
175 | (loan_id , the id number identifying the loan data. Value examples: [4959, 4960, 4961].),
176 | (account_id , the id number identifying the account. Value examples: [10, 80, 55, 43].),
177 | (date , the date when the loan is approved. Value examples: ['1998-07-12', '1998-04-19'].),
178 | (amount , the id number identifying the loan data. Value examples: [1567, 7877, 9988].),
179 | (duration , the id number identifying the loan data. Value examples: [60, 48, 24, 12, 36].),
180 | (payments , the id number identifying the loan data. Value examples: [3456, 8972, 9845].),
181 | (status , the id number identifying the loan data. Value examples: ['C', 'A', 'D', 'B'].)
182 | ]
183 | # Table: district
184 | [
185 | (district_id , location of branch. Value examples: [77, 76].),
186 | (A2 , area in square kilometers. Value examples: [50.5, 48.9].),
187 | (A4 , number of inhabitants. Value examples: [95907, 95616].),
188 | (A5 , number of households. Value examples: [35678, 34892].),
189 | (A6 , literacy rate. Value examples: [95.6, 92.3, 89.7].),
190 | (A7 , number of entrepreneurs. Value examples: [1234, 1456].),
191 | (A8 , number of cities. Value examples: [5, 4].),
192 | (A9 , number of schools. Value examples: [15, 12, 10].),
193 | (A10 , number of hospitals. Value examples: [8, 6, 4].),
194 | (A11 , average salary. Value examples: [12541.5, 11277].),
195 | (A12 , poverty rate. Value examples: [12.4, 9.8].),
196 | (A13 , unemployment rate. Value examples: [8.2, 7.9].),
197 | (A15 , number of crimes. Value examples: [256, 189].)
198 | ]
199 |
200 | 【Summary】
201 | ```json
202 | {{
203 | "account":"Specific information for each account",
204 | "client":"Basic information about each client",
205 | "loan":"Detailed records of each loan",
206 | "district":"Various data recorded in each district",
207 |
208 | }}
209 | ```
210 | ==========
211 |
212 | Here is a new case:
213 | 【DB_ID】 {db_id}
214 | 【Schema】
215 | {desc_str}
216 |
217 | 【Summary】
218 | """
219 |
220 |
221 | subq_pattern = r"Sub question\s*\d+\s*:"
222 |
223 |
224 | pure_decomposer_template = """
225 | [Instruction]
226 | Given a 【query】, you need to understand the intent of the Query, and then decompose it into Targets and Conditions. Then you need to combine Targets and Conditions into Subqueries step by step.
227 | For the case where Conditions is NULL, consider Targets as the final Subquery directly.
228 | For the case where Conditions are not NULL, combine Targets and the first Condition to get the first Subquery, then combine this Subquery and the next Condition into a new Subquery, until all Conditions are used (which means the content of the last Subquery is the same as the original Query).
229 |
230 | [Requirements]
231 | - Try not to overlap Targets and Conditions.
232 | - Make sure the decomposed Target and Condition can cover all of the information in Query.
233 | - Don't change any information (specific value) in Query!
234 | - Mark each Subquery with ## in front of it.
235 |
236 | Here are some examples:
237 | ==========
238 |
239 | 【Query】
240 | Show the stadium name and the number of concerts in each stadium. Please also list the year the stadium was built.
241 | 【Evidence】
242 | NULL
243 |
244 | 【Decomposition】
245 | Targets: List the stadium name, the year the stadium was built and the number of concerts in each stadium
246 | Conditions: NULL
247 |
248 | Subqueries:
249 | 1. Since Conditions is NULL, the final Subquery is the Targets.
250 | ##Subquery: List the stadium name, the year the stadium was built and the number of concerts in each stadium
251 | ==========
252 |
253 | 【Query】
254 | What is the qualification rate for the H-11 products produced in 2023/11/2?
255 | 【Evidence】
256 | qualification rate = `Numqualified(H-11)` / `production(H-11)`
257 |
258 | 【Decomposition】
259 | Targets: List the qualification rate for the H-11 Products
260 | Conditions:
261 | 1. produced in 2023/11/2 --Condition_1
262 |
263 | Subqueries:
264 | 1. Combine Targets and Condition_1 to get the first Subquery.
265 | ##Subquery: List the qualification rate for the H-11 Products produced in 2023/11/2
266 |
267 | ==========
268 |
269 | 【Query】
270 | List the race of institutions in Alabama with number of students greater than the 90% of average number of students of all institutions?
271 | 【Evidence】
272 | Alabama refers to state = 'Alabama'; number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count
273 |
274 | 【Decomposition】
275 | Targets: List the race of institutions
276 | Conditions:
277 | 1. in Alabama --Condition_1
278 | 2. number of students greater than the 90% of average number of students of all institutions --Condition_2
279 |
280 | Subqueries:
281 | 1. Combine Targets and Condition_1 to get the first Subquery.
282 | ##Subquery: List the race of institutions in Alabama
283 | 2. Combine the first Subquery and Condition_2 to get the second Subquery.
284 | ##Subquery: List the race of institutions in Alabama with number of students greater than the 90% of average number of students of all institutions
285 |
286 | ==========
287 |
288 | 【Query】
289 | Which president got a higher approval rating, Joseph Biden or Donald Trump?
290 | 【Evidence】
291 | NULL
292 |
293 | 【Decomposition】
294 | Targets: List the name of the president who got a higher approval rating between Joseph Biden or Donald Trump
295 | Conditions:
296 | NULL
297 |
298 | Subqueries:
299 | 1. Since Conditions is NULL, the final Subquery is the Targets.
300 | ##Subquery: List the name of the president who got a higher approval rating between Joseph Biden or Donald Trump
301 |
302 | ==========
303 |
304 | 【Query】
305 | For movie id 1269, how many users, who was a paying subscriber and was eligible for trial when he rated the movie, gave the movie a rating score of less than or equal to 2?
306 | 【Evidence】
307 | NULL
308 |
309 | 【Decomposition】
310 | Targets: List the number of users
311 | Conditions:
312 | 1. was a paying subscriber --Condition_1
313 | 2. was eligible for trial --Condition_2
314 | 3. for movie id 1269, gave the movie a rating score of less than or equal to 2 --Condition_3
315 | 
316 | Subqueries:
317 | 1. Combine Targets and Condition_1 to get the first Subquery.
318 | ##Subquery: List the number of users who was a paying subscriber
319 | 2. Combine the first Subquery and Condition_2 to get the second Subquery.
320 | ##Subquery: List the number of users who was a paying subscriber and was eligible for trial
321 | 3. Combine the second Subquery and Condition_3 to get the third Subquery.
322 | ##Subquery: List the number of users who was a paying subscriber and was eligible for trial and gave the movie whose id is 1269 a rating score of less than or equal to 2
323 |
324 | ==========
325 |
326 | Here is a new query need to be decomposed:
327 |
328 | 【Query】
329 | {query}
330 | 【Evidence】
331 | {evidence}
332 |
333 | 【Decomposition】
334 | """
335 |
336 | soft_schema_initial_generator_template = """
337 | Given a 【Database schema】 description, a knowledge 【Evidence】 and a 【Question】, you need to use valid SQLite and understand the database and knowledge so that you can generate a good SQL for the 【Question】.
338 | When generating SQL, we should always consider constraints:
339 | 【Constraints】
340 | - In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
341 | - In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
342 | - If use max or min func, `JOIN <table>` FIRST, THEN use `SELECT MAX(<column>)` or `SELECT MIN(<column>)`
343 | - If [Value examples] of <column> has 'None' or None, use `JOIN <table>` or `WHERE <column> is NOT NULL` is better
344 | - If use `ORDER BY <column> ASC|DESC`, add `GROUP BY <column>` before to select distinct values
345 | - If include more than one table, use `JOIN <table>`
346 | - If use `JOIN <table>`, the connected columns should be in the 【Foreign keys】
347 | - If evidence gives a formula for calculating a value, try to use that formula
348 | - If use `ORDER BY <column> ASC LIMIT <number>`, please use `ORDER BY <column> ASC NULLS LAST LIMIT <number>` to make sure the null values will not be selected
349 |
350 | ==========
351 |
352 | 【Database schema】
353 | # stadium: [Stadium_ID (INTEGER), Location (TEXT), Name (TEXT)]
354 | # concert: [concert_ID (INTEGER), concert_Name (TEXT), Stadium_ID (INTEGER)]
355 | 【Primary keys】
356 | stadium.`Stadium_ID` | concert.`concert_ID`
357 | 【Foreign keys】
358 | concert.`Stadium_ID` = stadium.`Stadium_ID`
359 | 【Detailed descriptions of tables and columns】
360 | stadium.`Stadium_ID`: The column 'Stadium_ID' in Table has column descriptions of "stadium id". Value examples: [1, 2, 3, 4, 5, 6].
361 | stadium.`Name`: The column 'Name' in Table has column descriptions of "name of stadium". Value examples: ["Stark's Park", 'Somerset Park', 'Recreation Park', 'Hampden Park', 'Glebe Park', 'Gayfield Park'].
362 | stadium.`Location`: The column 'Location' in Table has column descriptions of "location of stadium". Value examples: ['Stirling Albion', 'Raith Rovers', "Queen's Park", 'Peterhead', 'East Fife', 'Brechin City'].
363 | concert.`concert_ID`: The column 'concert_ID' in Table has column descriptions of "concert id". Value examples: [1, 2, 3, 4, 5, 6].
364 | concert.`concert_Name`: The column 'concert_Name' in Table has column descriptions of "concert name". Value examples: ['Week 1', 'Week 2', 'Super bootcamp', 'Home Visits', 'Auditions'].
365 | concert.`Stadium_ID`: The column 'Stadium_ID' in Table has column descriptions of "stadium id". Value examples: [2, 9, 7, 10, 1].
366 | 【Evidence】
367 | NULL
368 | 【Question】
369 | Show the stadium name and the number of concerts in each stadium.
370 | 【Matched values】
371 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
372 | No matched values.
373 |
374 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step:
375 | 【Question】: Show the stadium name and the number of concerts in each stadium.
376 | Targets for `SELECT`: the stadium name and the number of concerts
377 | hints from 【Evidence】:NULL
378 | For the entities in the 【Question】, get corresponding columns from 【Database schema】 with hints: the stadium name refers to stadium.`Name`, the number of concerts refers to COUNT(concert.`concert_ID`), each stadium refers to stadium.`Stadium_ID`
379 | Connection of tables: include tables <stadium> and <concert>, get connected keys from 【Foreign keys】: concert.`Stadium_ID` = stadium.`Stadium_ID`
380 | Final SQL:
381 | ```sql
382 | SELECT T1.`Name`, COUNT(T2.concert_ID) AS num_concerts
383 | FROM stadium AS T1
384 | JOIN concert AS T2
385 | ON T1.`Stadium_ID` = T2.`Stadium_ID`
386 | GROUP BY T1.`Stadium_ID`
387 | ```
388 |
389 | Question Solved.
390 |
391 | ==========
392 |
393 | 【Database schema】
394 | # country: [origin (INTEGER), country(TEXT)]
395 | # price: [ID (INTEGER), price (REAL)]
396 | # data: [ID (INTEGER), mpg (REAL), cylinders (INTEGER), displacement (TEXT), horsepower (REAL), weight (REAL), acceleration (REAL), model (TEXT), car_name (TEXT)]
397 | # production: [ID (INTEGER), model_year (INTEGER), country (INTEGER)]
398 | 【Primary keys】
399 | country.`origin` | price.`ID` | data.`ID` | production.`ID`
400 | 【Foreign keys】
401 | data.`ID` = price.`ID`
402 | production.`ID` = price.`ID`
403 | production.`ID` = data.`ID`
404 | production.`country` = country.`origin`
405 | 【Detailed descriptions of tables and columns】
406 | country.`origin`: The column 'origin' in Table has column descriptions of "the unique identifier for the origin country". Value examples: [1, 2, 3].
407 | country.`country`: The column 'country' in Table has column descriptions of "the origin country of the car". Value examples: ['USA', 'Japan', 'Europe'].
408 | data.`horsepower`: The column 'horsepower' in Table has column descriptions of "horse power associated with the car".
409 | data.`acceleration`: The column 'acceleration' in Table has column descriptions of "acceleration of the car in miles per squared hour".
410 | production.`country`: The column 'country' in Table has column descriptions of "country id to which the car belongs". Value examples: [1, 2, 3].
411 | 【Evidence】
412 | the fastest refers to max(horsepower); name of the car refers to car_name
413 | 【Question】
414 | What is the fastest car made by Japan?
415 | 【Matched values】
416 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
417 | country.`country` = 'Japan'
418 |
419 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step:
420 | 【Question】: What is the fastest car made by Japan?
421 | Targets for `SELECT`: the name of the fastest car
422 | hints from 【Evidence】: the fastest refers to max(horsepower); name of the car refers to car_name
423 | For the entities in the 【Question】, get corresponding columns from 【Database schema】 with hints: car refers to data.`car_name`, the fastest refers to MAX(data.`horsepower`), Japan refers to country.`country`
424 | Connection of tables: includes tables <data> and <country>, since <data> and <country> are not connected directly, use <production> as a bridge, get connected keys from 【Foreign keys】: production.`ID` = data.`ID`, production.`country` = country.`origin`
425 | Final SQL:
426 | ```sql
427 | SELECT T1.`car_name`
428 | FROM data AS T1
429 | INNER JOIN production AS T2
430 | ON T1.`ID` = T2.`ID`
431 | INNER JOIN country AS T3
432 | ON T3.`origin` = T2.`country`
433 | WHERE T3.`country` = 'Japan'
434 | ORDER BY T1.`horsepower` DESC LIMIT 1
435 | ```
436 |
437 | Question Solved.
438 |
439 | ==========
440 |
441 | 【Database schema】
442 | # institution_details: [unitid (INTEGER), chronname (TEXT), city (TEXT), state (TEXT), site (TEXT), student_count(TEXT)]
443 | # institution_grads: [unitid (INTEGER), gender (TEXT), race (TEXT), cohhort (TEXT)]
444 | 【Primary keys】
445 | institution_details.`unitid` | institution_grads.`unitid`
446 | 【Foreign keys】
447 | institution_grads.`unitid` = institution_details.`unitid`
448 | 【Detailed descriptions of tables and columns】
449 | institution_details.`unitid`: The column 'unitid' in Table has column descriptions of "Education Unit ID number".
450 | institution_details.`chronname`: The column 'chronname' in Table has column descriptions of "Institution name".
451 | institution_details.`student_count`: The column 'student_count' in Table has column descriptions of "Total number of undergraduates in 2010".
452 | 【Evidence】
453 | number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count
454 | 【Question】
455 | List the chronname of institutions with number of students greater than the 90% of average number of students of all institutions?
456 | 【Matched values】
457 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
458 | No matched values.
459 |
460 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step:
461 | 【Question】: List the chronname of institutions with number of students greater than the 90% of average number of students of all institutions?
462 | Targets for `SELECT`: the chronname of institutions
463 | hints from 【Evidence】: number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count
464 | There is a formula in 【Evidence】, so we can turn it into SQL format: institution_details.`student_count` > ( SELECT AVG(`student_count`) * 0.9 FROM institution_details )
465 | For the entities in the 【Question】, get corresponding columns from 【Database schema】 with hints: the chronname of institutions refers to institution_details.`chronname`, number of students refers to institution_details.`student_count`
466 | Connection of tables: include only one table <institution_details>, no connection
467 | Final SQL:
468 | ```sql
469 | SELECT DISTINCT `chronname`
470 | FROM institution_details
471 | WHERE `student_count` > (
472 | SELECT AVG(`student_count`) * 0.9 FROM institution_details
473 | )
474 | ```
475 |
476 | Question Solved.
477 |
478 | ==========
479 |
480 | 【Database schema】
481 | {desc_str}
482 | 【Primary keys】
483 | {pk_str}
484 | 【Foreign keys】
485 | {fk_str}
486 | 【Detailed descriptions of tables and columns】
487 | {detailed_str}
488 | 【Evidence】
489 | {evidence}
490 | 【Question】
491 | {query}
492 | 【Matched values】
493 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
494 | {matched_list}
495 |
496 | Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, select columns from 【Database schema】 and then generate SQL for 【Question】, you need to think step by step:
497 | """
498 |
499 | soft_schema_continuous_generator_template = """
500 | Given a 【Database schema】 description, a knowledge 【Evidence】, a 【Question】, a 【Subquestion】 extracted from 【Question】 and a 【Sub-SQL】 for the 【Subquestion】, you need to use valid SQLite and understand the database and knowledge, and then generate a complete SQL for the 【Question】 based on 【Sub-SQL】.
501 | When generating SQL, we should always consider constraints:
502 | 【Constraints】
503 | - In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
504 | - In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
505 | - If use max or min func, `JOIN <table>` FIRST, THEN use `SELECT MAX(<column>)` or `SELECT MIN(<column>)`
506 | - If [Value examples] of <column> has 'None' or None, use `JOIN <table>` or `WHERE <column> is NOT NULL` is better
507 | - If use `ORDER BY <column> ASC|DESC`, add `GROUP BY <column>` before to select distinct values
508 | - If include more than one table, use `JOIN <table>`
509 | - If use `JOIN <table>`, the connected columns should be in the 【Foreign keys】
510 | - If evidence gives a formula for calculating a value, try to use that formula
511 | - If use `ORDER BY <column> ASC LIMIT <number>`, please use `ORDER BY <column> ASC NULLS LAST LIMIT <number>` to make sure the null values will not be selected
512 |
513 | ==========
514 |
515 | 【Database schema】
516 | # country: [origin (INTEGER), country(TEXT)]
517 | # price: [ID (INTEGER), price (REAL)]
518 | # data: [ID (INTEGER), mpg (REAL), cylinders (INTEGER), displacement (TEXT), horsepower (REAL), weight (REAL), acceleration (REAL), model (TEXT), car_name (TEXT)]
519 | # production: [ID (INTEGER), model_year (INTEGER), country (INTEGER)]
520 | 【Primary keys】
521 | country.`origin` | price.`ID` | data.`ID` | production.`ID`
522 | 【Foreign keys】
523 | data.`ID` = price.`ID`
524 | production.`ID` = price.`ID`
525 | production.`ID` = data.`ID`
526 | production.`country` = country.`origin`
527 | 【Detailed descriptions of tables and columns】
528 | country.`origin`: The column 'origin' in Table has column descriptions of "the unique identifier for the origin country". Value examples: [1, 2, 3].
529 | country.`country`: The column 'country' in Table has column descriptions of "the origin country of the car". Value examples: ['USA', 'Japan', 'Europe'].
530 | data.`horsepower`: The column 'horsepower' in Table has column descriptions of "horse power associated with the car".
531 | data.`acceleration`: The column 'acceleration' in Table has column descriptions of "acceleration of the car in miles per squared hour".
532 | production.`country`: The column 'country' in Table has column descriptions of "country id to which the car belongs". Value examples: [1, 2, 3].
533 | 【Evidence】
534 | the fastest refers to max(horsepower); made by Japan refers to country = 'Japan'
535 | 【Question】
536 | What is the price of the fastest car made by Japan?
537 | 【Subquestion】
538 | What is the price of the fastest car?
539 | 【Sub-SQL】
540 | ```
541 | SELECT T1.`price` FROM price AS T1 INNER JOIN data AS T2 ON T2.`ID` = T1.`ID` ORDER BY T2.`horsepower` DESC LIMIT 1
542 | ```
543 | 【Matched values】
544 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
545 | No matched values.
546 |
547 | 【Subquestion】 is decomposed from 【Question】, now we provide Sub-SQL corresponding to Subquestion, you just need to complete the remaining conditions based on Sub-SQL to generate SQL for 【Question】. Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, and then generate SQL after thinking step by step:
548 | hints from 【Evidence】: the fastest refers to max(horsepower); made by Japan refers to country = 'Japan';
549 | Find necessary columns from 【Database schema】 in addition to Sub-SQL: Japan refers to country.`country`
550 | Connection of tables: the connection of <data> and <country> is needed, since <data> and <country> are not connected directly, use <production> as a bridge, get connected keys from 【Foreign keys】: production.`ID` = data.`ID`, production.`country` = country.`origin`
551 | Final SQL:
552 | ```sql
553 | SELECT T1.`price`
554 | FROM price AS T1
555 | INNER JOIN data AS T2
556 | on T2.`ID` = T1.`ID`
557 | INNER JOIN production AS T3
558 | ON T3.`ID` = T2.`ID`
559 | INNER JOIN country AS T4
560 | ON T4.`origin` = T3.`country`
561 | WHERE T4.`country` = 'Japan'
562 | ORDER BY T2.`horsepower` DESC LIMIT 1
563 | ```
564 |
565 | Question solved.
566 |
567 | ==========
568 |
569 | 【Database schema】
570 | # institution_details: [unitid (INTEGER), chronname (TEXT), city (TEXT), state (TEXT), site (TEXT), student_count(TEXT)]
571 | # institution_grads: [unitid (INTEGER), gender (TEXT), race (TEXT), cohhort (TEXT)]
572 | 【Primary keys】
573 | institution_details.`unitid` | institution_grads.`unitid`
574 | 【Foreign keys】
575 | institution_grads.`unitid` = institution_details.`unitid`
576 | 【Detailed descriptions of tables and columns】
577 | institution_details.`unitid`: The column 'unitid' in Table has column descriptions of "Education Unit ID number".
578 | institution_details.`state`: The column 'state' in Table has column descriptions of "Institution state".
579 | institution_details.`student_count`: The column 'student_count' in Table has column descriptions of "Total number of undergraduates in 2010".
580 | institution_grads.`race`: The column 'race' in Table has column descriptions of "race/ethnicity of students". Value examples: ['X', 'Ai', 'A', 'B', 'H'].
581 | 【Evidence】
582 | Alabama refers to state = 'Alabama'; number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count
583 | 【Question】
584 | List the race of institutions in Alabama with number of students greater than the 90% of average number of students of all institutions?
585 | 【Subquestion】
586 | List the race of institutions in 'Alabama'.
587 | 【Sub-SQL】
588 | ```
589 | SELECT DISTINCT T1.`race` FROM institution_grads AS T1 INNER JOIN institution_details AS T2 ON T1.`unitid` = T2.`unitid` WHERE T2.`state` = 'Alabama'
590 | ```
591 | 【Matched values】
592 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
593 | institution_details.`city` = 'Alabama'; institution_details.`state` = 'Alabama'
594 |
595 | 【Subquestion】 is decomposed from 【Question】, now we provide Sub-SQL corresponding to Subquestion, you just need to complete the remaining conditions based on Sub-SQL to generate SQL for 【Question】. Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, and then generate SQL after thinking step by step:
596 | hints from 【Evidence】: Alabama refers to state = 'Alabama'; number of students greater than the 90% of average = MULTIPLY(AVG(student_count), 90%) < student_count
597 | There is a formula in 【Evidence】, so we can turn it into SQL format: institution_details.`student_count` > ( SELECT AVG(`student_count`) * 0.9 FROM institution_details )
598 | Find necessary columns from 【Database schema】 in addition to Sub-SQL: number of students refers to institution_details.`student_count`
599 | Connection of tables: the table <institution_details> is already in Sub-SQL, no more connection is needed
600 | Final SQL:
601 | ```sql
602 | SELECT DISTINCT T1.`race`
603 | FROM institution_grads AS T1
604 | INNER JOIN institution_details AS T2
605 | ON T1.`unitid` = T2.`unitid`
606 | WHERE T2.`student_count` > (
607 | SELECT AVG(`student_count`) * 0.9 FROM institution_details
608 | )
609 | AND T2.`state` = 'Alabama'
610 | ```
611 |
612 | Question solved.
613 |
614 | ==========
615 |
616 | 【Database schema】
617 | {desc_str}
618 | 【Primary keys】
619 | {pk_str}
620 | 【Foreign keys】
621 | {fk_str}
622 | 【Detailed descriptions of tables and columns】
623 | {detailed_str}
624 | 【Evidence】
625 | {evidence}
626 | 【Question】
627 | {query}
628 | 【Subquestion】
629 | {subquery}
630 | 【Sub-SQL】
631 | {subsql}
632 | 【Matched values】
633 | Since some of the specific values in Question and evidence match the data in the database, here are some matches retrieved from the database that may help you to generate SQL (Matched values may contain useless information and you should ignore matches that are not relevant to the question):
634 | {matched_list}
635 |
636 | 【Subquestion】 is decomposed from 【Question】, now we provide Sub-SQL corresponding to Subquestion, you just need to complete the remaining conditions based on Sub-SQL to generate SQL for 【Question】. Consider 【Constraints】, extract hints from 【Evidence】 if 【Evidence】 is related to the 【Question】, and then generate SQL after thinking step by step:
637 | """
638 |
639 |
640 |
641 | baseline_template = """
642 | Given a 【Database schema】 description, a knowledge 【Evidence】 and the 【Question】, you need to use valid SQLite and understand the database and knowledge, and then generate SQL.
643 | You can write the answer in script blocks, and indicate the script type in it, like this:
644 | ```sql
645 | SELECT column_a
646 | FROM table_b
647 | ```
648 |
649 | 【Database schema】
650 | {desc_str}
651 | 【Question】
652 | {query}
653 | 【Evidence】
654 | {evidence}
655 | 【Answer】
656 | """
657 |
658 |
659 | refiner_template = """
660 | 【Instruction】
661 | When executing the SQL below, some errors occurred; please fix up the SQL based on the query and database info.
662 | Solve the task step by step if you need to. Use SQL format in the code block, and indicate the script type in the code block.
663 | When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
664 | 【Constraints】
665 | - The SQL should start with 'SELECT'
666 | - In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
667 | - In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
668 | - If use `JOIN <table>`, the connected columns should be in the Foreign keys of 【Database schema】
669 | 【Response format】
670 | Your response should be in this format:
671 | Analysis:
672 | **[Your analysis]**
673 | Correct SQL:
674 | ```sql
675 | [the fixed SQL]
676 | ```
677 | 【Attention】
678 | Only SQL statements are allowed in [the fixed SQL], do not add any comments.
679 |
680 | 【Evidence】
681 | {evidence}
682 | 【Query】
683 | -- {query}
684 | 【Database info】
685 | {desc_str}
686 | 【Primary keys】
687 | {pk_str}
688 | 【Foreign keys】
689 | {fk_str}
690 | 【Detailed descriptions of tables and columns】
691 | {detailed_str}
692 | 【old SQL】
693 | ```sql
694 | {sql}
695 | ```
696 | 【SQLite error】
697 | {sqlite_error}
698 | 【Exception class】
699 | {exception_class}
700 |
701 | Now please fix up the old SQL and generate a new SQL.
702 | 【correct SQL】
703 | """
704 |
705 | nested_refiner_template = """
706 | 【Instruction】
707 | When executing the SQL below, some errors occurred; please fix up the SQL based on the query and database info.
708 | Solve the task step by step if you need to. Use SQL format in the code block, and indicate the script type in the code block.
709 | When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
710 | 【Possible error】
711 | When using nested SQL, using the MIN or MAX function in the sub-SQL may result in no data being returned in the end. Because joined tables are connected by foreign keys and share only the data they have in common, the maximum or minimum value of a column in one table may not appear in the joined table.
712 | 【Constraints】
713 | - The SQL should start with 'SELECT'
714 | - In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
715 | - In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
716 | - If use `JOIN <table>`, the connected columns should be in the Foreign keys of 【Database schema】
717 | - If use `ORDER BY <column> ASC LIMIT <number>`, please use `ORDER BY <column> ASC NULLS LAST LIMIT <number>` to make sure the null values will not be selected
718 | - There are columns with similar names but different meanings; you can find the most accurate columns in the right table based on the 【Summary of each table】
719 | 【Response format】
720 | Your response should be in this format:
721 | Analysis:
722 | **[Your analysis]**
723 | Correct SQL:
724 | ```sql
725 | [the fixed SQL]
726 | ```
727 | 【Attention】
728 | Only SQL statements are allowed in [the fixed SQL], do not add any comments.
729 | 【Typical examples】
730 | For the error called "no data selected", here are some typical examples.
731 | 1. Wrong SQL:
732 | ```
733 | SELECT T1.`date` FROM recordtime AS T1 INNER JOIN information AS T2 ON T1.`id` = T2.`id` WHERE T2.`viewers` = ( SELECT MAX(`viewers`) FROM information)
734 | ```
735 | Analysis: **Using the MAX function in a nested SQL statement will result in a mismatch because the maximum or minimum value of the column may not be in the joined table, so use `ORDER BY`!**
736 | Correct SQL:
737 | ```sql
738 | SELECT T1.`date`
739 | FROM recordtime AS T1
740 | INNER JOIN information AS T2
741 | ON T1.`id` = T2.`id`
742 | ORDER BY T2.`viewers` DESC
743 | LIMIT 1
744 | ```
745 | 2. Wrong SQL:
746 | ```
747 | SELECT T1.`idea` FROM work AS T1 INNER JOIN student AS T2 ON T1.`ID` = T2.`ID` WHERE T2.`final_score` = ( SELECT MIN(`final_score`) FROM student)
748 | ```
749 | Analysis: **Using the MIN function in a nested SQL statement will result in a mismatch because the maximum or minimum value of the column may not be in the joined table, so use `ORDER BY`!**
750 | Correct SQL:
751 | ```sql
752 | SELECT T1.`idea`
753 | FROM work AS T1
754 | INNER JOIN student AS T2
755 | ON T1.`ID` = T2.`ID`
756 | ORDER BY T2.`final_score` ASC
757 | LIMIT 1
758 | ```
759 |
760 | Here is a new case for you.
761 | 【Evidence】
762 | {evidence}
763 | 【Query】
764 | -- {query}
765 | 【Database info】
766 | {desc_str}
767 | 【Primary keys】
768 | {pk_str}
769 | 【Foreign keys】
770 | {fk_str}
771 | 【Detailed descriptions of tables and columns】
772 | {detailed_str}
773 | 【Matched values】
774 | {matched_content}
775 | 【old SQL】
776 | ```sql
777 | {sql}
778 | ```
779 | 【SQLite error】
780 | {sqlite_error}
781 | 【Exception class】
782 | {exception_class}
783 |
784 | Now please fix up the old SQL and generate a new SQL.
785 | 【correct SQL】
786 | """
787 |
--------------------------------------------------------------------------------
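The `{db_id}`, `{desc_str}`, `{pk_str}`, `{fk_str}`, `{query}`, `{evidence}`, `{matched_list}` and similar fields in the templates above are filled in at runtime with Python's `str.format`, which is also why the JSON examples inside the templates write `{{` and `}}`: doubled braces survive formatting as literal braces. Below is a minimal sketch of the mechanism, using a shortened stand-in template rather than the real ones (the actual call sites live in main_scripts/chat_manager.py, which is not shown here):

```python
# Minimal sketch of how the prompt templates are rendered with str.format.
# `mini_template` is a shortened, hypothetical stand-in for the templates in
# main_scripts/const.py; only the escaping/formatting behaviour is the point.
mini_template = (
    "【DB_ID】 {db_id}\n"
    "【Schema】\n"
    "{desc_str}\n"
    "【Summary】\n"
    '{{"account": "Specific information for each account"}}'
)

prompt = mini_template.format(
    db_id="banking_system",
    desc_str="# Table: account\n[...]",
)
print(prompt)  # the doubled braces come out as literal { and }
```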
/main_scripts/llm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import time
4 |
5 | import openai
6 | import os
7 |
8 |
9 | # Your api settings
10 | API_Key = "API Key"
11 | Base_url = "The base url of your API"
12 |
13 | #MODEL_NAME = 'gpt-4'
14 | MODEL_NAME = 'gpt-3.5-turbo-1106'
15 | #MODEL_NAME = 'gpt-4-0125-preview'
16 | MAX_TRY = 5
17 |
18 | # Used to pass an external dict into this module
19 | world_dict = {}
20 |
21 | log_path = None
22 | api_trace_json_path = None
23 | total_prompt_tokens = 0
24 | total_response_tokens = 0
25 |
26 |
27 | def init_log_path(my_log_path):
28 | global total_prompt_tokens
29 | global total_response_tokens
30 | global log_path
31 | global api_trace_json_path
32 | log_path = my_log_path
33 | total_prompt_tokens = 0
34 | total_response_tokens = 0
35 | dir_name = os.path.dirname(log_path)
36 | os.makedirs(dir_name, exist_ok=True)
37 |
38 |     # A separate file that records API calls
39 | api_trace_json_path = os.path.join(dir_name, 'api_trace.json')
40 |
41 |
42 | def api_func(prompt:str):
43 | global MODEL_NAME
44 | print(f"\nUse OpenAI model: {MODEL_NAME}\n")
45 |
46 | client = openai.OpenAI(api_key=API_Key,base_url=Base_url)
47 |
48 | response = client.chat.completions.create(
49 | model=MODEL_NAME,
50 | messages=[{"role": "user", "content": prompt}],
51 | n = 1,
52 | temperature=0.1,
53 | max_tokens=1000
54 | )
55 |
56 | text = response.choices[0].message.content.strip()
57 | prompt_token = response.usage.prompt_tokens
58 | response_token = response.usage.completion_tokens
59 |
60 |
61 | return text, prompt_token, response_token
62 |
63 |
64 | def safe_call_llm(input_prompt, **kwargs) -> str:
65 | """
66 |     Given the prompt, return the content generated by the LLM (automatically retries up to 5 times internally; raises an exception after 5 failures)
67 | """
68 | global MODEL_NAME
69 | global log_path
70 | global api_trace_json_path
71 | global total_prompt_tokens
72 | global total_response_tokens
73 | global world_dict
74 |
75 | for i in range(5):
76 | try:
77 | if log_path is None:
78 | # print(input_prompt)
79 | sys_response, prompt_token, response_token = api_func(input_prompt)
80 | print(f"\nsys_response: \n{sys_response}")
81 | print(f'\n prompt_token,response_token: {prompt_token} {response_token}\n')
82 | else:
83 | # check log_path and api_trace_json_path is not None
84 | if (log_path is None) or (api_trace_json_path is None):
85 | raise FileExistsError('log_path or api_trace_json_path is None, init_log_path first!')
86 | with open(log_path, 'a+', encoding='utf8') as log_fp, open(api_trace_json_path, 'a+', encoding='utf8') as trace_json_fp:
87 | print('\n' + f'*'*20 +'\n', file=log_fp)
88 | print(input_prompt, file=log_fp)
89 | print('\n' + f'='*20 +'\n', file=log_fp)
90 | sys_response, prompt_token, response_token = api_func(input_prompt)
91 | print(sys_response, file=log_fp)
92 | print(f'\n prompt_token,response_token: {prompt_token} {response_token}\n', file=log_fp)
93 |
94 | if len(world_dict) > 0:
95 | world_dict = {}
96 |
97 | if len(kwargs) > 0:
98 | world_dict = {}
99 | for k, v in kwargs.items():
100 | world_dict[k] = v
101 | # prompt response to world_dict
102 | world_dict['response'] = '\n' + sys_response.strip() + '\n'
103 | world_dict['input_prompt'] = input_prompt.strip() + '\n'
104 |
105 | world_dict['prompt_token'] = prompt_token
106 | world_dict['response_token'] = response_token
107 |
108 |
109 | total_prompt_tokens += prompt_token
110 | total_response_tokens += response_token
111 |
112 | world_dict['cur_total_prompt_tokens'] = total_prompt_tokens
113 | world_dict['cur_total_response_tokens'] = total_response_tokens
114 |
115 | # world_dict to json str
116 | world_json_str = json.dumps(world_dict, ensure_ascii=False)
117 | print(world_json_str, file=trace_json_fp)
118 |
119 | world_dict = {}
120 | world_json_str = ''
121 |
122 | print(f'\n total_prompt_tokens,total_response_tokens: {total_prompt_tokens} {total_response_tokens}\n', file=log_fp)
123 | print(f'\n total_prompt_tokens,total_response_tokens: {total_prompt_tokens} {total_response_tokens}\n')
124 | return sys_response
125 | except Exception as ex:
126 | print(ex)
127 |             print(f'Request {MODEL_NAME} failed (attempt {i + 1}). Sleep 10 secs.')
128 | time.sleep(10)
129 |
130 | raise ValueError('safe_call_llm error!')
131 |
132 |
133 | if __name__ == "__main__":
134 |     res = safe_call_llm('Can you handle the Text-to-SQL task well?')
135 | print(res)
136 |
--------------------------------------------------------------------------------
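A minimal usage sketch for main_scripts/llm.py (assumptions: `API_Key`, `Base_url` and `MODEL_NAME` at the top of the file have been set to working values, and the script is run from the repository root; the log path below is only an example):

```python
# Minimal sketch, assuming API_Key / Base_url / MODEL_NAME in llm.py are configured.
from main_scripts import llm

# Optional: set a log file; api_trace.json is created in the same directory.
llm.init_log_path("./output/log/log/txt")

# Extra keyword arguments are recorded alongside the prompt and response in
# api_trace.json. safe_call_llm retries up to 5 times before raising ValueError.
reply = llm.safe_call_llm(
    "Write a SQLite query that counts the rows of the client table.",
    idx=0,
    db_id="banking_system",
)
print(reply)
```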
/main_scripts/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import re
4 | import random
5 | import json
6 | import time
7 | import sqlite3
8 | from main_scripts.const import subq_pattern
9 | from typing import Dict, List
10 |
11 |
12 | def is_valid_date(date_str):
13 | if (not isinstance(date_str, str)):
14 | return False
15 | date_str = date_str.split()[0]
16 | if len(date_str) != 10:
17 | return False
18 | pattern = r'^\d{4}-\d{2}-\d{2}$'
19 | if re.match(pattern, date_str):
20 | year, month, day = map(int, date_str.split('-'))
21 | if year < 1 or month < 1 or month > 12 or day < 1 or day > 31:
22 | return False
23 | else:
24 | return True
25 | else:
26 | return False
27 |
28 |
29 | def is_valid_date_column(col_value_lst):
30 | for col_value in col_value_lst:
31 | if not is_valid_date(col_value):
32 | return False
33 | return True
34 |
35 |
36 | def rename_file(file_path, new_name):
37 | """
38 |     Rename a file, given the original file path and a new file name
39 | 
40 |     @param file_path: original file path, e.g. /home/user/test.txt
41 |     @param new_name: new file name, e.g. backup
42 |     @return: the new file path
43 | """
44 |     # Get the file's directory and extension
45 | dir_name = os.path.dirname(file_path)
46 | file_name, file_ext = os.path.splitext(os.path.basename(file_path))
47 |
48 |     # Get the current timestamp
49 | timestamp = str(int(time.time()))
50 |
51 |     # Build the new file name
52 | new_file_name = new_name + '_' + timestamp + file_ext
53 |
54 |     # Build the new file path
55 | new_file_path = os.path.join(dir_name, new_file_name)
56 |
57 |     # Rename the file
58 | os.rename(file_path, new_file_path)
59 |
60 | return new_file_path
61 |
62 |
63 | def is_email(string):
64 | pattern = r'^[\w\.-]+@[\w\.-]+\.\w+$'
65 | match = re.match(pattern, string)
66 | if match:
67 | return True
68 | else:
69 | return False
70 |
71 |
72 |
73 | def extract_world_info(message_dict: dict):
74 | info_dict = {}
75 | info_dict['idx'] = message_dict['idx']
76 | info_dict['db_id'] = message_dict['db_id']
77 | info_dict['query'] = message_dict['query']
78 | info_dict['evidence'] = message_dict.get('evidence', '')
79 | info_dict['difficulty'] = message_dict.get('difficulty', '')
80 | info_dict['ground_truth'] = message_dict.get('ground_truth', '')
81 | info_dict['send_to'] = message_dict.get('send_to', '')
82 | return info_dict
83 |
84 |
85 | def replace_multiple_spaces(text):
86 |     # Regular expression matching runs of whitespace
87 | pattern = r'\s+'
88 |     # Replace runs of whitespace with a single space
89 | new_text = re.sub(pattern, ' ', text)
90 | return new_text
91 |
92 |
93 | # SQL parsing
94 | def extract_table_names(sql_query):
95 |     # Use a regular expression to extract table names from the FROM clause
96 |     # (and from JOIN clauses)
97 |     # Assume table names follow the FROM keyword, with no special characters or spaces
98 | sql_query = sql_query.replace('`', '')
99 | table_names = re.findall(r'FROM\s+([\w]+)', sql_query, re.IGNORECASE) + \
100 | re.findall(r'JOIN\s+([\w]+)', sql_query, re.IGNORECASE)
101 | return set(table_names)
102 |
103 |
104 | def get_used_tables(sql, db_path) -> dict: # table_name -> chosen columns & discarded columns
105 | table_names = extract_table_names(sql)
106 | sch = {}
107 | conn = sqlite3.connect(db_path)
108 | conn.text_factory = lambda b: b.decode(errors="ignore")
109 | cursor = conn.cursor()
110 | for table_name in table_names:
111 | cursor.execute(f"PRAGMA table_info(`{table_name}`)")
112 | columns = cursor.fetchall()
113 | column_names = [cinfo[1] for cinfo in columns]
114 | sch[table_name] = {
115 | "chosen columns": column_names,
116 | "discarded columns": []
117 | }
118 | return sch
119 |
120 |
121 | def get_all_tables(db_path) -> dict:
122 | conn = sqlite3.connect(db_path)
123 | conn.text_factory = lambda b: b.decode(errors="ignore")
124 | cursor = conn.cursor()
125 | cursor.execute("SELECT name FROM sqlite_master WHERE type=\'table\'")
126 | tables = cursor.fetchall()
127 | table_names = [a[0] for a in tables if a[0] != 'sqlite_sequence']
128 | sch = {}
129 | for table_name in table_names:
130 | cursor.execute(f"PRAGMA table_info(`{table_name}`)")
131 | columns = cursor.fetchall()
132 | column_names = [cinfo[1] for cinfo in columns]
133 | sch[table_name] = {
134 | "chosen columns": column_names,
135 | "discarded columns": []
136 | }
137 | return sch
138 |
139 |
140 | gold_schema = []
141 |
142 |
143 | def get_gold_columns(idx, db_path) -> dict:
144 | global gold_schema
145 | if gold_schema == []:
146 | input_file = "data/bird/dev_gold_schema.json"
147 | with open(input_file, encoding='utf8') as f:
148 | gold_schema = json.load(f)
149 | table2cols = gold_schema[idx]["columns_map"]
150 |
151 | sch = {}
152 | conn = sqlite3.connect(db_path)
153 | conn.text_factory = lambda b: b.decode(errors="ignore")
154 | cursor = conn.cursor()
155 | cursor.execute("SELECT name FROM sqlite_master WHERE type=\'table\'")
156 | tables = cursor.fetchall()
157 | table_names = [a[0] for a in tables if a[0] != 'sqlite_sequence']
158 | for table_name in table_names:
159 | cursor.execute(f"PRAGMA table_info(`{table_name}`)")
160 | columns = cursor.fetchall()
161 | all_columns = [cinfo[1] for cinfo in columns]
162 | gold_columns = table2cols.get(table_name, [])
163 | gold_columns = [str(item).replace('`', '') for item in gold_columns]
164 | unused_columns = list(set(all_columns).difference(set(gold_columns)))
165 | random.shuffle(unused_columns)
166 | sch[table_name] = {
167 | "chosen columns": gold_columns + unused_columns[:3], # used golden cols + unused random 3 cols
168 | "discarded columns": []
169 | }
170 | return sch
171 |
172 |
173 | # GPT result parsing
174 |
175 |
176 | # def parse_json(res: str) -> dict:
177 | # lines = res.split('\n')
178 | # start_idx, end_idx = -1, -1
179 | # for idx in range(0, len(lines)):
180 | # if '```json' in lines[idx]:
181 | # start_idx = idx
182 | # break
183 | # if start_idx == -1: return {}
184 | # for idx in range(start_idx + 1, len(lines)):
185 | # if '```' in lines[idx]:
186 | # end_idx = idx
187 | # break
188 | # if end_idx == -1: return {}
189 | # jstr = " ".join(lines[start_idx + 1: end_idx])
190 | # return json.loads(jstr)
191 |
192 |
193 | # parse json output
194 | def parse_json(res: str) -> dict:
195 | # lines = res.split('\n')
196 | # start_idx, end_idx = -1, -1
197 | # for idx in range(0, len(lines)):
198 | # if '```json' in lines[idx]:
199 | # start_idx = idx
200 | # break
201 | # if start_idx == -1: return {}
202 | # for idx in range(start_idx + 1, len(lines)):
203 | # if '```' in lines[idx]:
204 | # end_idx = idx
205 | # break
206 | # if end_idx == -1: return {}
207 | # jstr = " ".join(lines[start_idx + 1: end_idx])
208 | # return json.loads(jstr)
209 | # todo: for debug
210 | return {}
211 |
212 |
213 | # check if valid format
214 | def check_selector_response(json_data: Dict) -> bool:
215 | FLAGS = ['keep_all', 'drop_all']
216 | for k, v in json_data.items():
217 | if isinstance(v, str):
218 | if v not in FLAGS:
219 | print(f"error: invalid table flag: {v}\n")
220 | print(f"json_data: {json_data}\n\n")
221 | return False
222 | elif isinstance(v, list):
223 | pass
224 | else:
225 | print(f"error: invalid flag type: {v}\n")
226 | print(f"json_data: {json_data}\n\n")
227 | return False
228 | return True
229 |
230 | def check_letter(s):
231 | for char in s:
232 | if char.isalpha():
233 | return True
234 | return False
235 |
236 | def contain_value(s):
237 | pattern = r"\s+'(.*?)'"
238 | r = re.compile(pattern)
239 | if r.search(s) != None:
240 | return re.findall(pattern, s)
241 | else:
242 | return None
243 |
244 |
245 | def get_files(root, suffix):
246 | """
247 |     Get all files with the given suffix under the given directory
248 |     :param root: directory to search, str, e.g. '.'
249 |     :param suffix: file suffix, str, e.g. '.txt'
250 |     :return: list of file paths
251 | """
252 | import os
253 | import glob
254 | if not os.path.exists(root):
255 | raise FileNotFoundError(f'path {root} not found.')
256 | res = glob.glob(f'{root}/**/*{suffix}', recursive=True)
257 | res = [os.path.abspath(p) for p in res]
258 | return res
259 |
260 |
261 | # read txt file to string list and strip empty lines
262 | def read_txt_file(path):
263 | with open(path, 'r', encoding='utf-8') as f:
264 | print(f"load txt file from {path}")
265 | return [line.strip() for line in f if line.strip()!= '']
266 |
267 | def load_json_file(path):
268 | with open(path, 'r', encoding='utf-8') as f:
269 | print(f"load json file from {path}")
270 | return json.load(f)
271 |
272 |
273 | def load_jsonl_file(path):
274 | with open(path, 'r', encoding='utf-8') as f:
275 | data = []
276 | for line in f:
277 | js_str = line.strip()
278 | if js_str == '':
279 | continue
280 | js = json.loads(js_str)
281 | data.append(js)
282 | print(f"load jsonl file from {path}")
283 | return data
284 |
285 |
286 | def save_file(path, string_lst):
287 | """
288 |     Save a file
289 |     :param path: file path, str
290 |     :param string_lst: list of strings, each including its newline character
291 | """
292 | with open(path, 'w', encoding='utf-8') as f:
293 | f.writelines(string_lst)
294 | print(f"save file to {path}")
295 |
296 |
297 | def save_json_file(path, data):
298 | with open(path, 'w', encoding='utf-8') as f:
299 | json.dump(data, f, indent=2, ensure_ascii=False)
300 | print(f"save json file to {path}")
301 |
302 |
303 | def save_jsonl_file(path, data):
304 | with open(path, 'w', encoding='utf-8') as f:
305 | for js in data:
306 | f.write(json.dumps(js, ensure_ascii=False) + '\n')
307 | print(f"save jsonl file to {path}")
308 |
309 |
310 | def parse_json(text: str) -> dict:
311 |     # Look for a JSON block in the string
312 | if "json" in text:
313 | start = text.find("```json")
314 | end = text.find("```", start + 7)
315 |         # If a JSON block was found
316 | if start != -1 and end != -1:
317 | json_string = text[start + 7: end]
318 | print(json_string)
319 | try:
320 |                 # Parse the JSON string
321 | json_data = json.loads(json_string)
322 | #valid = check_selector_response(json_data)
323 | return json_data
324 | except:
325 | print(f"error: parse json error!\n")
326 | print(f"json_string: {json_string}\n\n")
327 | pass
328 | elif "```" in text:
329 | start = text.find("```")
330 | end = text.find("```", start + 3)
331 | if start != -1 and end != -1:
332 | json_string = text[start + 3: end]
333 |
334 | try:
335 |                 # Parse the JSON string
336 | json_data = json.loads(json_string)
337 | return json_data
338 | except:
339 | print(f"error: parse json error!\n")
340 | print(f"json_string: {json_string}\n\n")
341 | pass
342 | else:
343 | start = text.find("{")
344 | end = text.find("}", start + 1)
345 | if start != -1:
346 | json_string = text[start: end + 1]
347 | try:
348 |                 # Parse the JSON string
349 | json_data = json.loads(json_string)
350 | return json_data
351 | except:
352 | print(f"error: parse json error!\n")
353 | print(f"json_string: {json_string}\n\n")
354 | pass
355 | """
356 |     # If a JSON block was found
357 | if start != -1 and end != -1:
358 | json_string = text[start + 7: end]
359 |
360 | try:
361 |             # Parse the JSON string
362 | json_data = json.loads(json_string)
363 | valid = check_selector_response(json_data)
364 | if valid:
365 | return json_data
366 | else:
367 | return {}
368 | except:
369 | print(f"error: parse json error!\n")
370 | print(f"json_string: {json_string}\n\n")
371 | pass
372 | """
373 | return {}
374 |
375 | def get_chosen_schema(raw_linked_schema:dict) -> dict:
376 | if raw_linked_schema == {}:
377 | return {}
378 | try:
379 | linked_schema = dict()
380 | for key,value in raw_linked_schema.items():
381 | for column in value:
382 | table = column.split('.')[0]
383 | col = column.split('.')[1]
384 | if linked_schema.get(table) == None:
385 | linked_schema[table] = {col}
386 | else:
387 | linked_schema[table].add(col)
388 | for key in linked_schema.keys():
389 | linked_schema[key] = list(linked_schema[key])
390 | return linked_schema
391 | except:
392 | return {}
393 |
394 | def parse_sql(res: str) -> str:
395 | """Only need SQL(startswith `SELECT`) of LLM result"""
396 | if 'SELECT' not in res and 'select' not in res:
397 | res = 'SELECT ' + res
398 | # match = re.search(parse_pattern, res, re.IGNORECASE | re.DOTALL)
399 | # if match:
400 | # sql = match.group().strip()
401 | # sql = sql.replace('```', '') # TODO
402 | # sql = sql.replace('\n', ' ') # TODO
403 | # return True, sql
404 | # else:
405 | # return False, ""
406 | res = res.replace('\n', ' ')
407 | return res.strip()
408 |
409 |
410 | def parse_single_sql(res: str) -> str: # if do not need decompose, just one code block is OK!
411 | """Return SQL in markdown block"""
412 | lines = res.split('\n')
413 | iter, start_idx, end_idx = -1, -1, -1
414 | for idx in range(iter + 1, len(lines)):
415 | if '```' in lines[idx]:
416 | start_idx = idx
417 | break
418 | if start_idx == -1: return ""
419 | for idx in range(start_idx + 1, len(lines)):
420 | if '```' in lines[idx]:
421 | end_idx = idx
422 | break
423 | if end_idx == -1: return ""
424 |
425 | return " ".join(lines[start_idx + 1: end_idx])
426 |
427 |
428 | def parse_qa_pairs(res: str, end_pos=2333) -> list:
429 | lines = res.split('\n')
430 | qa_pairs = []
431 | # end_pos = -1
432 | # for idx, line in enumerate(lines):
433 | # if 'final SQL' in line or 'final sql' in line:
434 | # end_pos = idx
435 | # if end_pos == -1: return []
436 | end_pos = len(lines) if (end_pos == 2333) else end_pos
437 | for idx in range(0, end_pos):
438 | if re.findall(subq_pattern, lines[idx], re.IGNORECASE) != []:
439 | query = lines[idx]
440 | start_idx = -1
441 | for idx2 in range(idx + 1, end_pos):
442 | if '```' in lines[idx2]:
443 | start_idx = idx2
444 | break
445 | if start_idx == -1: return []
446 | for idx3 in range(start_idx + 1, end_pos):
447 | if '```' in lines[idx3]:
448 | end_idx = idx3
449 | break
450 | if end_idx == -1: return []
451 | answer = " ".join(lines[start_idx + 1: end_idx])
452 | qa_pairs.append((str(query), str(answer)))
453 | idx = end_idx
454 | return qa_pairs
455 |
456 | def extract_sql(a: str) -> str:
457 | k = a.split("```")
458 | check_format = False
459 | sql_list = []
460 | for statement in k:
461 | if 'sql' in statement:
462 | sql_list.append(statement)
463 | check_format = True
464 | if not check_format:
465 | for statement in k:
466 | if 'SELECT' in statement:
467 | sql_list.append(statement)
468 |
469 | if len(sql_list) > 0:
470 | sql = sql_list[-1]
471 | token_list = sql.split('\n')
472 | for i in range(len(token_list)):
473 | if "--" in token_list[i]:
474 | token_list[i] = token_list[i].split("--")[0]
475 | token_list[i] = token_list[i].strip()
476 | final_sql = ' '.join(token_list).replace('sql ','')
477 | else:
478 | final_sql = None
479 | return final_sql
480 |
481 | def extract_subquery(reply:str) -> list:
482 | try:
483 | answer = reply.split('\n')
484 | subqueries = []
485 | for i in range(len(answer)):
486 | if "##" in answer[i]:
487 | sublist = answer[i].split(":")
488 | subqueries.append(sublist[1].strip())
489 | except:
490 | subqueries = []
491 | return subqueries
492 |
493 | def divide_targets_conditions(reply: str) -> list:
494 | decomposition = "Targets:" + reply.split("Targets:")[1]
495 | temp = decomposition.split("Conditions:")
496 | targets = temp[0].strip()
497 | conditions = "Conditions:" + temp[1].strip()
498 | return [targets,conditions]
499 |
500 | def extract_new_CoT(reply: str) -> str:
501 | #extract the new chain of thoughts
502 | if "**" in reply:
503 | return reply.split("**")[2]
504 | else:
505 | return reply
506 |
507 | def parse_subq(res: str) -> list:
508 | """Only sub questions after decomposition"""
509 | res = '-- ' + res
510 |     sub_questions = []
511 |     sub_questions += res.split('-- ')
512 |     sub_questions = [q.strip() for q in sub_questions if len(q) > 1]
513 |     return sub_questions
514 |
515 |
516 | def add_prefix(sql):
517 | if not sql.startswith('SELECT') and not sql.startswith('select'):
518 |         sql = 'SELECT ' + sql
519 | return sql
520 |
521 | def detect_special_char(name):
522 | for special_char in ['(', '-', ')', ' ', '/']:
523 | if special_char in name:
524 | return True
525 |
526 | return False
527 |
528 | def add_quotation_mark(s):
529 | return "`" + s + "`"
530 |
531 | def get_matched_content_sequence(matched_contents):
532 | content_sequence = ""
533 | if matched_contents == None:
534 | content_sequence = "NULL"
535 | elif len(matched_contents) != 0:
536 | for tc_name, contents in matched_contents.items():
537 | table_name = tc_name.split(".")[0]
538 | column_name = tc_name.split(".")[1]
539 | #if detect_special_char(table_name):
540 | #table_name = add_quotation_mark(table_name)
541 | #if detect_special_char(column_name):
542 | column_name = add_quotation_mark(column_name)
543 | content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
544 | else:
545 | content_sequence = "NULL"
546 |
547 | return content_sequence.strip()
548 |
549 | # Spider data preprocess
550 |
551 |
552 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
553 | JOIN_KEYWORDS = ('join', 'on', 'as')
554 |
555 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
556 | UNIT_OPS = ('none', '-', '+', "*", '/')
557 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
558 | TABLE_TYPE = {
559 | 'sql': "sql",
560 | 'table_unit': "table_unit",
561 | }
562 |
563 | COND_OPS = ('and', 'or')
564 | SQL_OPS = ('intersect', 'union', 'except')
565 | ORDER_OPS = ('desc', 'asc')
566 |
567 |
568 | HARDNESS = {
569 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
570 | "component2": ('except', 'union', 'intersect')
571 | }
572 |
573 |
574 | def get_nestedSQL(sql):
575 | nested = []
576 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
577 | if type(cond_unit[3]) is dict:
578 | nested.append(cond_unit[3])
579 | if type(cond_unit[4]) is dict:
580 | nested.append(cond_unit[4])
581 | if sql['intersect'] is not None:
582 | nested.append(sql['intersect'])
583 | if sql['except'] is not None:
584 | nested.append(sql['except'])
585 | if sql['union'] is not None:
586 | nested.append(sql['union'])
587 | return nested
588 |
589 |
590 | def has_agg(unit):
591 | return unit[0] != AGG_OPS.index('none')
592 |
593 |
594 | def count_agg(units):
595 | return len([unit for unit in units if has_agg(unit)])
596 |
597 |
598 | def count_component1(sql):
599 | count = 0
600 | if len(sql['where']) > 0:
601 | count += 1
602 | if len(sql['groupBy']) > 0:
603 | count += 1
604 | if len(sql['orderBy']) > 0:
605 | count += 1
606 | if sql['limit'] is not None:
607 | count += 1
608 | if len(sql['from']['table_units']) > 0: # JOIN
609 | count += len(sql['from']['table_units']) - 1
610 |
611 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
612 | count += len([token for token in ao if token == 'or'])
613 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
614 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
615 |
616 | return count
617 |
618 |
619 | def count_component2(sql):
620 | nested = get_nestedSQL(sql)
621 | return len(nested)
622 |
623 |
624 | def count_others(sql):
625 | count = 0
626 | # number of aggregation
627 | agg_count = count_agg(sql['select'][1])
628 | agg_count += count_agg(sql['where'][::2])
629 | agg_count += count_agg(sql['groupBy'])
630 | if len(sql['orderBy']) > 0:
631 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
632 | [unit[2] for unit in sql['orderBy'][1] if unit[2]])
633 | agg_count += count_agg(sql['having'])
634 | if agg_count > 1:
635 | count += 1
636 |
637 | # number of select columns
638 | if len(sql['select'][1]) > 1:
639 | count += 1
640 |
641 | # number of where conditions
642 | if len(sql['where']) > 1:
643 | count += 1
644 |
645 | # number of group by clauses
646 | if len(sql['groupBy']) > 1:
647 | count += 1
648 |
649 | return count
650 |
651 |
652 | def eval_hardness(sql):
653 | count_comp1_ = count_component1(sql)
654 | count_comp2_ = count_component2(sql)
655 | count_others_ = count_others(sql)
656 |
657 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
658 | return "easy"
659 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
660 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
661 | return "medium"
662 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
663 | (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
664 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
665 | return "hard"
666 | else:
667 | return "extra"
668 |
--------------------------------------------------------------------------------
/output/gpt4_mag.json:
--------------------------------------------------------------------------------
1 | {"idx": 0, "db_id": "california_schools", "query": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?", "evidence": "Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`", "extracted_schema": {"frpm": ["County Name", "Enrollment (K-12)", "Free Meal Count (K-12)"], "satscores": ["enroll12", "cname"], "schools": ["CDSCode", "County", "City"]}, "ground_truth": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1", "difficulty": "simple", "matched_list": ["frpm.`County Name` = 'Alameda'", "schools.`MailCity` = 'Alameda'", "satscores.`cname` = 'Alameda'", "schools.`County` = 'Alameda'", "schools.`GSoffered` = 'K-12'", "schools.`GSserved` = 'K-12'", "schools.`City` = 'Alameda'"], "complete_desc_str": "# Table: frpm\n[\n (CDSCode , CDSCode.),\n (Academic Year , Academic Year. Value examples: ['2014-2015'].),\n (County Code , County Code. Value examples: ['19', '37', '30', '36', '33', '43'].),\n (District Code , District Code.),\n (School Code , School Code. Value examples: ['0000000', '9010745', '6121081', '6121073', '6121016', '6120968'].),\n (County Name , County Name. Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Riverside', 'Santa Clara'].),\n (District Name , District Name.),\n (School Name , School Name.),\n (District Type , District Type. Value examples: ['Unified School District', 'Elementary School District', 'High School District', 'County Office of Education (COE)', 'State Board of Education', 'Statewide Benefit Charter'].),\n (School Type , School Type. Value examples: [None, 'Elementary Schools (Public)', 'High Schools (Public)', 'Intermediate/Middle Schools (Public)', 'Continuation High Schools', 'Alternative Schools of Choice', 'K-12 Schools (Public)'].),\n (Educational Option Type , Educational Option Type. Value examples: [None, 'Traditional', 'Continuation School', 'Alternative School of Choice', 'Community Day School', 'Special Education School', 'County Community School'].),\n (NSLP Provision Status , NSLP Provision Status. Value examples: [None, 'Provision 2', 'Breakfast Provision 2', 'CEP', 'Multiple Provision Types', 'Provision 1', 'Lunch Provision 2'].),\n (Charter School (Y/N) , Charter School (Y/N). Value examples: [None, 0, 1].),\n (Charter School Number , Charter School Number. Value examples: [None, '00D2', '0756', '00D4', '00D1', '1700', '1696'].),\n (Charter Funding Type , Charter Funding Type. Value examples: [None, 'Directly funded', 'Locally funded', 'Not in CS funding model'].),\n (IRC , IRC. Value examples: [None, 0, 1].),\n (Low Grade , Low Grade. Value examples: ['K', '9', '6', '7', 'P', '5'].),\n (High Grade , High Grade. 
Value examples: ['12', '5', '8', '6', '3', '4'].),\n (Enrollment (K-12) , Enrollment (K-12).),\n (Free Meal Count (K-12) , Free Meal Count (K-12).),\n (Percent (%) Eligible Free (K-12) , Percent (%) Eligible Free (K-12).),\n (FRPM Count (K-12) , Free or Reduced Price Meal Count (K-12).),\n (Percent (%) Eligible FRPM (K-12) , Percent (%) Eligible FRPM (K-12).),\n (Enrollment (Ages 5-17) , Enrollment (Ages 5-17).),\n (Free Meal Count (Ages 5-17) , Free Meal Count (Ages 5-17).),\n (Percent (%) Eligible Free (Ages 5-17) , Percent (%) Eligible Free (Ages 5-17).),\n (FRPM Count (Ages 5-17) , FRPM Count (Ages 5-17).),\n (Percent (%) Eligible FRPM (Ages 5-17) , Percent (%) Eligible FRPM (Ages 5-17).),\n (2013-14 CALPADS Fall 1 Certification Status , 2013-14 CALPADS Fall 1 Certification Status. Value examples: [1].)\n]\n# Table: satscores\n[\n (cds , California Department Schools.),\n (rtype , rtype. Value examples: ['S', 'D'].),\n (sname , school name. Value examples: [None, 'Middle College High', 'John F. Kennedy High', 'Independence High', 'Foothill High', 'Washington High', 'Redwood High'].),\n (dname , district segment. Value examples: ['Los Angeles Unified', 'San Diego Unified', 'Oakland Unified', 'San Francisco Unified', 'Kern High', 'East Side Union High'].),\n (cname , county name. Value examples: ['Los Angeles', 'San Diego', 'San Bernardino', 'Riverside', 'Orange', 'Alameda'].),\n (enroll12 , enrollment (1st-12nd grade).),\n (NumTstTakr , Number of Test Takers in this school.),\n (AvgScrRead , average scores in Reading.),\n (AvgScrMath , average scores in Math.),\n (AvgScrWrite , average scores in writing.),\n (NumGE1500 , Number of Test Takers Whose Total SAT Scores Are Greater or Equal to 1500.)\n]\n# Table: schools\n[\n (CDSCode , CDSCode.),\n (NCESDist , This field represents the 7-digit National Center for Educational Statistics (NCES) school district identification number. The first 2 digits identify the state and the last 5 digits identify the school district. Combined, they make a unique 7-digit ID for each school district.. Value examples: [None, '0622710', '0634320', '0628050', '0634410', '0614550', '0633840'].),\n (NCESSchool , This field represents the 5-digit NCES school identification number. The NCESSchool combined with the NCESDist form a unique 12-digit ID for each school.. Value examples: [None, '12271', '13785', '13747', '12909', '12311', '12003'].),\n (StatusType , This field identifies the status of the district.. Value examples: ['Active', 'Closed', 'Merged', 'Pending'].),\n (County , County name. Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Santa Clara', 'Alameda'].),\n (District , District.),\n (School , School.),\n (Street , Street.),\n (StreetAbr , The abbreviated street address of the school, district, or administrative authority’s physical location..),\n (City , City. Value examples: [None, 'Los Angeles', 'San Diego', 'San Jose', 'Sacramento', 'Fresno', 'Oakland'].),\n (Zip , Zip. Value examples: [None, '91723', '95959', '94558', '92345', '90278', '90266'].),\n (State , State. Value examples: [None, 'CA'].),\n (MailStreet , MailStreet.),\n (MailStrAbr , mailing street address.),\n (MailCity , mailing city. Value examples: [None, 'Los Angeles', 'San Diego', 'San Jose', 'Sacramento', 'Oakland', 'Fresno'].),\n (MailZip , mailing zip. Value examples: [None, '94606', '91710-4130', '92626-4300', '90266', '95336-0032', '92084'].),\n (MailState , mailing state. Value examples: [None, 'CA'].),\n (Phone , Phone. 
Value examples: [None, '(805) 581-0202', '(619) 243-5014', '(562) 698-8121', '(559) 683-4667', '(650) 329-2800', '(530) 625-5600'].),\n (Ext , The phone number extension of the school, district, or administrative authority.. Value examples: [None, '200', '201', '111', '352', '221', '203'].),\n (Website , The website address of the school, district, or administrative authority..),\n (OpenDate , The date the school opened.. Value examples: [None, '1980-07-01'].),\n (ClosedDate , The date the school closed.. Value examples: [None, '1989-06-30'].),\n (Charter , This field identifies a charter school.. Value examples: [None, 0, 1].),\n (CharterNum , The charter school number,. Value examples: [None, '00D2', '0756', '00D4', '00D1', '1826', '1710'].),\n (FundingType , Indicates the charter school funding type. Value examples: [None, 'Directly funded', 'Locally funded', 'Not in CS funding model'].),\n (DOC , District Ownership Code. Value examples: ['54', '52', '00', '56', '98', '02'].),\n (DOCType , The District Ownership Code Type is the text description of the DOC category.. Value examples: ['Unified School District', 'Elementary School District', 'County Office of Education (COE)', 'High School District', 'Regional Occupation Center/Program (ROC/P)', 'State Board of Education'].),\n (SOC , The School Ownership Code is a numeric code used to identify the type of school.. Value examples: [None, '60', '66', '62', '68', '09', '69'].),\n (SOCType , The School Ownership Code Type is the text description of the type of school.. Value examples: [None, 'Elementary Schools (Public)', 'High Schools (Public)', 'Intermediate/Middle Schools (Public)', 'Continuation High Schools', 'Special Education Schools (Public)', 'District Community Day Schools'].),\n (EdOpsCode , The Education Option Code is a short text description of the type of education offered.. Value examples: [None, 'TRAD', 'CON', 'COMMDAY', 'ALTSOC', 'SPEC', 'COMM'].),\n (EdOpsName , Educational Option Name. Value examples: [None, 'Traditional', 'Continuation School', 'Community Day School', 'Alternative School of Choice', 'Special Education School', 'County Community School'].),\n (EILCode , The Educational Instruction Level Code is a short text description of the institution's type relative to the grade range served.. Value examples: [None, 'ELEM', 'HS', 'UG', 'INTMIDJR', 'ELEMHIGH', 'A'].),\n (EILName , The Educational Instruction Level Name is the long text description of the institution’s type relative to the grade range served.. Value examples: [None, 'Elementary', 'High School', 'Ungraded', 'Intermediate/Middle/Junior High', 'Elementary-High Combination', 'Adult'].),\n (GSoffered , The grade span offered is the lowest grade and the highest grade offered or supported by the school, district, or administrative authority. This field might differ from the grade span served as reported in the most recent certified California Longitudinal Pupil Achievement (CALPADS) Fall 1 data collection.. Value examples: [None, 'K-5', '9-12', 'K-6', 'K-8', '6-8', 'K-12'].),\n (GSserved , It is the lowest grade and the highest grade of student enrollment as reported in the most recent certified CALPADS Fall 1 data collection. Only K–12 enrollment is reported through CALPADS. This field may differ from the grade span offered.. Value examples: [None, 'K-5', 'K-6', '9-12', 'K-8', '6-8', '7-8'].),\n (Virtual , This field identifies the type of virtual instruction offered by the school. 
Virtual instruction is instruction in which students and teachers are separated by time and/or location, and interaction occurs via computers and/or telecommunications technologies.. Value examples: [None, 'N', 'P', 'F'].),\n (Magnet , This field identifies whether a school is a magnet school and/or provides a magnet program.. Value examples: [None, 0, 1].),\n (Latitude , The angular distance (expressed in degrees) between the location of the school, district, or administrative authority and the equator measured north to south..),\n (Longitude , The angular distance (expressed in degrees) between the location of the school, district, or administrative authority and the prime meridian (Greenwich, England) measured from west to east..),\n (AdmFName1 , administrator's first name 1. Value examples: [None, 'David', 'Michael', 'John', 'Jennifer', 'Lisa', 'Robert'].),\n (AdmLName1 , administrator's last name 1.),\n (AdmEmail1 , administrator's email address 1.),\n (AdmFName2 , administrator's first name 2. Value examples: [None, 'Steve', 'Delphine', 'Tim', 'Michael', 'Lisa', 'Jeff'].),\n (AdmLName2 , administrator's last name 2. Value examples: [None, 'Sherman', 'Jones', 'Wilson', 'Petersen', 'Landry', 'Hambright'].),\n (AdmEmail2 , administrator's email address 2.),\n (AdmFName3 , administrator's first name 3. Value examples: [None, 'Mark', 'Eric', 'Vickie', 'Tracie', 'Shawn', 'Sarah'].),\n (AdmLName3 , administrator's last name 3. Value examples: [None, 'Zerpoli', 'Yount', 'Woodruff', 'Wood', 'Wolk', 'Spotts'].),\n (AdmEmail3 , administrator's email address 3.),\n (LastUpdate , Last Update. Value examples: ['2015-06-18'].)\n]", "summary_str": "# frpm: Data on schools' Free and Reduced Price Meal (FRPM) eligibility and enrollment across various educational options and provisions.\n# satscores: SAT performance metrics for schools including average scores and number of high-scoring test takers.\n# schools: Detailed information about schools, including their addresses, contact details, and administrative data.", "chosen_db_schem_dict": {"frpm": ["CDSCode", "Academic Year", "County Code", "District Code", "School Code", "County Name", "District Name", "School Name", "District Type", "School Type", "Educational Option Type", "NSLP Provision Status", "Charter School (Y/N)", "Charter School Number", "Charter Funding Type", "IRC", "Low Grade", "High Grade", "Enrollment (K-12)", "Free Meal Count (K-12)", "Percent (%) Eligible Free (K-12)", "FRPM Count (K-12)", "Percent (%) Eligible FRPM (K-12)", "Enrollment (Ages 5-17)", "Free Meal Count (Ages 5-17)", "Percent (%) Eligible Free (Ages 5-17)", "FRPM Count (Ages 5-17)", "Percent (%) Eligible FRPM (Ages 5-17)", "2013-14 CALPADS Fall 1 Certification Status"], "satscores": ["cds", "rtype", "sname", "dname", "cname", "enroll12", "NumTstTakr", "AvgScrRead", "AvgScrMath", "AvgScrWrite", "NumGE1500"], "schools": ["CDSCode", "NCESDist", "NCESSchool", "StatusType", "County", "District", "School", "Street", "StreetAbr", "City", "Zip", "State", "MailStreet", "MailStrAbr", "MailCity", "MailZip", "MailState", "Phone", "Ext", "Website", "OpenDate", "ClosedDate", "Charter", "CharterNum", "FundingType", "DOC", "DOCType", "SOC", "SOCType", "EdOpsCode", "EdOpsName", "EILCode", "EILName", "GSoffered", "GSserved", "Virtual", "Magnet", "Latitude", "Longitude", "AdmFName1", "AdmLName1", "AdmEmail1", "AdmFName2", "AdmLName2", "AdmEmail2", "AdmFName3", "AdmLName3", "AdmEmail3", "LastUpdate"]}, "desc_str": "# frpm: [CDSCode (TEXT), , Academic Year (TEXT), , County Code (TEXT), , 
District Code (INTEGER), , School Code (TEXT), , County Name (TEXT), , District Name (TEXT), , School Name (TEXT), , District Type (TEXT), , School Type (TEXT), , Educational Option Type (TEXT), , NSLP Provision Status (TEXT), , Charter School (Y/N) (INTEGER), , Charter School Number (TEXT), , Charter Funding Type (TEXT), , IRC (INTEGER), , Low Grade (TEXT), , High Grade (TEXT), , Enrollment (K-12) (REAL), , Free Meal Count (K-12) (REAL), , Percent (%) Eligible Free (K-12) (REAL), , FRPM Count (K-12) (REAL), , Percent (%) Eligible FRPM (K-12) (REAL), , Enrollment (Ages 5-17) (REAL), , Free Meal Count (Ages 5-17) (REAL), , Percent (%) Eligible Free (Ages 5-17) (REAL), , FRPM Count (Ages 5-17) (REAL), , Percent (%) Eligible FRPM (Ages 5-17) (REAL), , 2013-14 CALPADS Fall 1 Certification Status (INTEGER), ]\n# satscores: [cds (TEXT), , rtype (TEXT), , sname (TEXT), , dname (TEXT), , cname (TEXT), , enroll12 (INTEGER), , NumTstTakr (INTEGER), , AvgScrRead (INTEGER), , AvgScrMath (INTEGER), , AvgScrWrite (INTEGER), , NumGE1500 (INTEGER), ]\n# schools: [CDSCode (TEXT), , NCESDist (TEXT), , NCESSchool (TEXT), , StatusType (TEXT), , County (TEXT), , District (TEXT), , School (TEXT), , Street (TEXT), , StreetAbr (TEXT), , City (TEXT), , Zip (TEXT), , State (TEXT), , MailStreet (TEXT), , MailStrAbr (TEXT), , MailCity (TEXT), , MailZip (TEXT), , MailState (TEXT), , Phone (TEXT), , Ext (TEXT), , Website (TEXT), , OpenDate (DATE), , ClosedDate (DATE), , Charter (INTEGER), , CharterNum (TEXT), , FundingType (TEXT), , DOC (TEXT), , DOCType (TEXT), , SOC (TEXT), , SOCType (TEXT), , EdOpsCode (TEXT), , EdOpsName (TEXT), , EILCode (TEXT), , EILName (TEXT), , GSoffered (TEXT), , GSserved (TEXT), , Virtual (TEXT), , Magnet (INTEGER), , Latitude (REAL), , Longitude (REAL), , AdmFName1 (TEXT), , AdmLName1 (TEXT), , AdmEmail1 (TEXT), , AdmFName2 (TEXT), , AdmLName2 (TEXT), , AdmEmail2 (TEXT), , AdmFName3 (TEXT), , AdmLName3 (TEXT), , AdmEmail3 (TEXT), , LastUpdate (DATE), ]", "fk_str": "frpm.`CDSCode` = schools.`CDSCode`\nsatscores.`cds` = schools.`CDSCode`", "pk_str": "frpm.`CDSCode` | satscores.`cds` | schools.`CDSCode`", "pruned": true, "match_content_str": "NULL", "columns_details_str": "frpm.`County Name`: The column 'County Name' in Table has column descriptions of \"County Name\". Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Riverside', 'Santa Clara'].\nfrpm.`Enrollment (K-12)`: The column 'Enrollment (K-12)' in Table has column descriptions of \"Enrollment (K-12)\". \nfrpm.`Free Meal Count (K-12)`: The column 'Free Meal Count (K-12)' in Table has column descriptions of \"Free Meal Count (K-12)\". \nsatscores.`cname`: The column 'cname' in Table has column descriptions of \"county name\". Value examples: ['Los Angeles', 'San Diego', 'San Bernardino', 'Riverside', 'Orange', 'Alameda'].\nsatscores.`enroll12`: The column 'enroll12' in Table has column descriptions of \"enrollment (1st-12nd grade)\". \nschools.`CDSCode`: The column 'CDSCode' in Table has column descriptions of \"CDSCode\". \nschools.`County`: The column 'County' in Table has column descriptions of \"County name\". Value examples: ['Los Angeles', 'San Diego', 'Orange', 'San Bernardino', 'Santa Clara', 'Alameda'].\nschools.`City`: The column 'City' in Table has column descriptions of \"City\". 
Value examples: [None, 'Los Angeles', 'San Diego', 'San Jose', 'Sacramento', 'Fresno', 'Oakland'].\n", "subquery_list": ["What is the highest eligible free rate for K-12 students in the schools in Alameda County?"], "initial_state": false, "old_chain_of_thoughts": "To answer the question \"What is the highest eligible free rate for K-12 students in the schools in Alameda County?\", we need to follow these steps:\n\n1. **Identify the relevant columns**: \n - The eligible free rate for K-12 is calculated using the formula: `Free Meal Count (K-12) / Enrollment (K-12)`.\n - We need to select the maximum of this calculated value.\n - We also need to filter the results to only include schools in Alameda County.\n\n2. **Identify the tables involved**: \n - We will use the `frpm` table to filter by county and to access the `Free Meal Count (K-12)` and `Enrollment (K-12)` columns.\n - We will also need to join with the `schools` table to ensure we are only considering schools in Alameda County.\n\n3. **Establish the connection between tables**: \n - The `frpm` table is connected to the `schools` table through the `CDSCode` column.\n\n4. **Construct the SQL query**:\n - We will calculate the eligible free rate and then find the maximum value for schools in Alameda County.\n\nHere is the SQL query that accomplishes this:\n\n```sql\nSELECT MAX(FRM.`Free Meal Count (K-12)` / FRM.`Enrollment (K-12)`) AS highest_eligible_free_rate\nFROM frpm AS FRM\nJOIN schools AS SCH ON FRM.`CDSCode` = SCH.`CDSCode`\nWHERE FRM.`County Name` = 'Alameda'\n```\n\n### Explanation of the SQL:\n- We are selecting the maximum eligible free rate calculated from the `Free Meal Count (K-12)` divided by `Enrollment (K-12)`.\n- We join the `frpm` table with the `schools` table on the `CDSCode` to ensure we are looking at the correct schools.\n- We filter the results to only include those from Alameda County.\n\nThis SQL query will return the highest eligible free rate for K-12 students in schools located in Alameda County.", "final_sql": "SELECT MAX(FRM.`Free Meal Count (K-12)` / FRM.`Enrollment (K-12)`) AS highest_eligible_free_rate FROM frpm AS FRM JOIN schools AS SCH ON FRM.`CDSCode` = SCH.`CDSCode` WHERE FRM.`County Name` = 'Alameda' ", "fixed": false, "try_times": 1, "pred": "SELECT MAX(FRM.`Free Meal Count (K-12)` / FRM.`Enrollment (K-12)`) AS highest_eligible_free_rate FROM frpm AS FRM JOIN schools AS SCH ON FRM.`CDSCode` = SCH.`CDSCode` WHERE FRM.`County Name` = 'Alameda' "}
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | func_timeout==4.3.5
2 | nltk==3.8.1
3 | numpy==1.26.4
4 | openai==0.28.1  # matches the README; the code calls the legacy openai.ChatCompletion.create API
5 | pandas==2.2.2
6 | protobuf==5.26.1
7 | sql_metadata==2.11.0
8 | sqlparse==0.5.0
9 | tiktoken==0.5.2
10 | tqdm==4.66.1
11 | transformers==4.40.1
12 | rapidfuzz
13 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from main_scripts.utils import *
3 | from main_scripts.chat_manager import ChatManager
4 | from main_scripts.utils import get_gold_columns
5 | from main_scripts.const import SYSTEM_NAME
6 | import re
7 | from tqdm import tqdm
8 | import time
9 | import argparse
10 | import sys
11 | import os
12 | import json
13 | import traceback
14 | from nltk.tokenize import word_tokenize
15 | from nltk import ngrams
16 |
17 |
18 |
19 | def init_spider_message(idx: int, item: dict) -> dict:
20 | """
21 | Construct message for text-to-SQL task
22 | :param idx: start from 0
23 | :param item: one sample of dataset
24 | :return: initial message object of group chat
25 | """
26 |     db_id, query, evidence, gt = item['db_id'], item['question'], "", item['query']
27 | difficulty = eval_hardness(item['sql'])
28 | user_message = {
29 | "idx": idx,
30 | "db_id": db_id,
31 | "query": query,
32 | "evidence": evidence,
33 | "extracted_schema": {},
34 | "ground_truth": gt,
35 | "difficulty": difficulty,
36 | "send_to": SYSTEM_NAME
37 | }
38 | return user_message
39 |
40 |
41 | def init_bird_message(idx: int, item: dict, db_path: str=None, use_gold_schema: bool = False) -> dict:
42 | """
43 | Construct message for text-to-SQL task
44 | :param idx: start from 0
45 | :param item: one sample of dataset
46 | :return: initial message object of group chat
47 | """
48 |     db_id, query, evidence, gt, difficulty = item['db_id'], \
49 |                                               item['question'], \
50 |                                               item.get('evidence', 'NULL'), \
51 |                                               item.get('SQL', ''), \
52 |                                               item.get('difficulty', 'simple')
53 |
54 |
55 | gold_schema_path = './data/bird/dev_gold_schema.json'
56 | gold_schema = {}
57 | key = f"{db_id.strip()}\t{query.strip()}"
58 | if use_gold_schema:
59 | if os.path.exists(gold_schema_path):
60 | all_gold_schema_dict = load_json_file(gold_schema_path)
61 | if key in all_gold_schema_dict:
62 | gold_schema = all_gold_schema_dict[key]
63 | else:
64 | raise ValueError(f"Can't find gold schema for {key}")
65 |
66 | user_message = {
67 | "idx": idx,
68 | "db_id": db_id,
69 | "query": query,
70 | "evidence": evidence,
71 | "extracted_schema": gold_schema if gold_schema else {},
72 | "ground_truth": gt,
73 | "difficulty": difficulty,
74 | "send_to": SYSTEM_NAME
75 | }
76 | return user_message
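# Illustrative shape of the returned message (values assumed, cf. output/gpt4_mag.json):
# {"idx": 0, "db_id": "california_schools", "query": "...", "evidence": "...",
#  "extracted_schema": {}, "ground_truth": "SELECT ...", "difficulty": "simple",
#  "send_to": SYSTEM_NAME}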
77 |
78 | def contain_date(s):
79 |     r = re.compile(r".*'\d{4}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])'.*")
80 |     m = r.match(s)
81 |     if m is not None:
82 |         return m.group()
83 |     return None
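# Illustrative check (assumed inputs): contain_date("WHERE OpenDate = '1980-07-01'")
# returns the matched string, while contain_date("WHERE County = 'Alameda'") returns None.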
84 |
85 | def run_batch(dataset_name, input_file, output_file, db_path, tables_json_path, start_pos=0, log_file=None, dataset_mode='dev', use_gold_schema=False, without_selector=False):
86 | chat_manager = ChatManager(data_path=db_path,
87 | tables_json_path=tables_json_path,
88 | log_path=log_file,
89 | dataset_name=dataset_name,
90 | dataset_path=input_file,
91 | model_name='gpt-4',
92 | lazy=True,
93 | without_selector=without_selector)
94 | # load dataset
95 | batch = load_json_file(input_file)
96 | # resume from last checkpoint
97 | finished_ids = set()
98 | """
99 | file = open(output_file, 'r', encoding='utf-8')
100 | output_data_lst = []
101 | for line in file.readlines():
102 | dic = json.loads(line)
103 | output_data_lst.append(dic)
104 | for o in output_data_lst:
105 | finished_ids.add(o['idx'])
106 | """
107 | if os.path.exists(output_file):
108 | output_data_lst = load_jsonl_file(output_file)
109 | for o in output_data_lst:
110 | finished_ids.add(o['idx'])
111 |
112 | unfinished_ids = [n for n in range(len(batch)) if n not in finished_ids and n >= start_pos]
113 | print(f"unsolved idxs are : {unfinished_ids}")
114 |
115 | # add question_id if needed
116 | for k, item in enumerate(batch):
117 | if 'question_id' not in item:
118 | item['question_id'] = k
119 |
120 | # skip some json data
121 | excluded_db_ids = []
122 | if dataset_mode == 'train':
123 | exclude_txt = './data/bird_train/excluded_db_ids.txt'
124 | excluded_db_ids = read_txt_file(exclude_txt)
125 | new_batch = []
126 | exclude_db_json_cnt = 0 # for exclude some dbs in bird train set
127 | for k, item in enumerate(batch):
128 | q_id = item['question_id']
129 | if q_id not in unfinished_ids:
130 | continue
131 | if dataset_mode == 'train':
132 | # skip excluded db_id
133 | if item['db_id'] in excluded_db_ids:
134 | exclude_db_json_cnt += 1
135 | continue
136 | new_batch.append(item)
137 |
138 | if exclude_db_json_cnt:
139 |         print(f"excluded {exclude_db_json_cnt} db json items from the bird train set")
140 | time.sleep(2)
141 | batch = new_batch
142 |
143 |
144 | # generate SQL one by one, and save result one by one
145 | with open(output_file, 'a+', encoding='utf-8') as fp:
146 | total_num = len(batch)
147 | for cur_idx, item in tqdm(enumerate(batch)):
148 | idx = item['question_id']
149 | db_id = item['db_id']
150 | print(f"\n\nprocessing: {cur_idx}/{total_num}\n\n", flush=True)
151 | if idx not in unfinished_ids: continue
152 | if dataset_name == "spider":
153 | user_message = init_spider_message(idx, item) # imitate user send a question to system
154 | elif dataset_name == "bird":
155 | user_message = init_bird_message(idx, item, db_path=db_path, use_gold_schema=use_gold_schema) # imitate user send a question to system
156 | try:
157 | chat_manager.start(user_message)
158 | try:
159 | """
160 | del user_message['desc_str']
161 | del user_message['fk_str']
162 | """
163 | del user_message['send_to']
164 | del user_message['db_content_dict']
165 | except:
166 | pass
167 | print(json.dumps(user_message, ensure_ascii=False), file=fp, flush=True)
168 | except Exception as e:
169 | # for debug
170 | traceback.print_exc()
171 | print(f"Exception: {e}, sleep 20 seconds.", flush=True)
172 | time.sleep(20)
173 | # raise Exception(str(e))
174 | print(f"\n\ndeal {cur_idx+1}/{total_num} done!\n\n")
175 | print(f"Result dump into {output_file}", file=sys.stdout, flush=True)
176 |
177 | # export evaluation results
178 | out_dir = os.path.dirname(output_file)
179 |
180 | # transfer SQL result to supportable BIRD format
181 | if dataset_name == "bird":
182 |
183 | evaluation_file_path = f"{out_dir}/predict_{dataset_mode}_new.json"
184 | with open(evaluation_file_path, 'w', encoding='utf8') as fout:
185 | output_json_list = load_jsonl_file(output_file)
186 | output_json_list = sorted(output_json_list, key=lambda i: i['idx'])
187 | eval_tuple_lst = []
188 | for o in output_json_list:
189 |                 if o.get('pred') is None:
190 | pred_sql = o['sub_sql'].strip()
191 | else:
192 | pred_sql = o['pred'].strip()
193 | if ' || \' \' || ' in pred_sql:
194 | pred_sql = pred_sql.replace(' || \' \' || ',', ')
195 | pred_sql = replace_multiple_spaces(pred_sql)
196 | sql_and_db_str = pred_sql + '\t' + '----- bird -----' + '\t' + o['db_id']
197 | obj = [o['query'], sql_and_db_str]
198 | eval_tuple_lst.append(obj)
199 | json.dump(eval_tuple_lst, fp=fout, ensure_ascii=False, indent=2)
200 | print(f"BIRD format file dump into {evaluation_file_path}")
201 | """
202 | #evaluation_file_path = f"{out_dir}/mac_pred_codellama7b_epoch5.sql"
203 | evaluation_file_path = f"{out_dir}/newprompt_mac.sql"
204 | with open(evaluation_file_path, "w") as f:
205 | output_json_list = load_jsonl_file(output_file)
206 | output_json_list = sorted(output_json_list, key=lambda i: i['idx'])
207 | for output in output_json_list:
208 | pred_sql = output['pred'].strip()
209 | pred_sql = replace_multiple_spaces(pred_sql)
210 | f.write(pred_sql.replace("\n", " ") + "\n")
211 | """
212 |
213 | elif dataset_name == "spider":
214 | evaluation_file_path = f"{out_dir}/pred_{dataset_mode}.sql"
215 | spider_sql_lst = []
216 | output_json_lst = load_jsonl_file(output_file)
217 | for output_json in output_json_lst:
218 | pred_sql = output_json['pred']
219 | pred_sql = replace_multiple_spaces(pred_sql)
220 | spider_sql_lst.append(pred_sql.strip() + '\n')
221 | save_file(evaluation_file_path, spider_sql_lst)
222 | print(f"Spider format file dump into {evaluation_file_path}")
223 | else:
224 | raise NotImplementedError
225 |
226 |
227 | def check_all_paths(args):
228 | if not os.path.exists(args.input_file):
229 | raise FileNotFoundError(f"Input file {args.input_file} not found")
230 | if not os.path.exists(args.db_path):
231 | raise FileNotFoundError(f"Database path {args.db_path} not found")
232 | if not os.path.exists(args.tables_json_path):
233 | raise FileNotFoundError(f"Tables json path {args.tables_json_path} not found")
234 |
235 | def obtain_n_grams(sequence, max_n):
236 | '''
237 |     Return all n-grams of `sequence` with n from 1 up to `max_n`.
238 | '''
239 | tokens = word_tokenize(sequence)
240 | all_grams = []
241 | for n in range(1, max_n + 1):
242 | all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])
243 |
244 | return all_grams
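# Illustrative example (assumed input): obtain_n_grams("free rate for K-12", 2)
# -> roughly ['free', 'rate', 'for', 'K-12', 'free rate', 'rate for', 'for K-12']
# (exact tokens depend on the NLTK punkt tokenizer)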
245 |
246 |
247 |
248 | if __name__ == "__main__":
249 | parser = argparse.ArgumentParser()
250 | parser.add_argument('--dataset_name', type=str, default='spider', choices=['spider', 'bird'], help='dataset name')
251 | parser.add_argument('--dataset_mode', type=str, default='dev', choices=['train', 'dev', 'test'], help='dataset mode')
252 | parser.add_argument('--input_file', type=str, required=True, help='path to dataset input')
253 | parser.add_argument('--db_path', type=str, required=True, help='path to databases in dataset')
254 | parser.add_argument('--tables_json_path', type=str, default=None, help='path to tables.json')
255 | parser.add_argument('--output_file', type=str, required=True, help='path to predicted output')
256 | parser.add_argument('--log_file', type=str, default='', help='path to log file if needed')
257 | parser.add_argument('--start_pos', type=int, default=0, help='start position of a batch')
258 | parser.add_argument('--use_gold_schema', action='store_true', default=False)
259 | parser.add_argument('--without_selector', action='store_true', default=False)
260 | args = parser.parse_args()
261 |
262 | # print each key argument and corresponding value
263 | for key, value in vars(args).items():
264 | print(f"{key}: {value}")
265 |
266 | check_all_paths(args)
267 |
268 | # print the args in clean json format
269 | args_json_str = json.dumps(vars(args), indent=2, ensure_ascii=False)
270 |
271 | print(f"args:\n{args_json_str}")
272 | time.sleep(3)
273 |
274 |
275 | run_batch(
276 | dataset_name=args.dataset_name,
277 | dataset_mode=args.dataset_mode,
278 | input_file=args.input_file,
279 | output_file=args.output_file,
280 | db_path=args.db_path,
281 | tables_json_path= args.tables_json_path,
282 | log_file=args.log_file,
283 | start_pos=args.start_pos,
284 | use_gold_schema=args.use_gold_schema,
285 | without_selector=args.without_selector,
286 | )
287 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python ./run.py --dataset_name bird \
4 | --dataset_mode dev \
5 | --input_file ./data/bird/dev/dev.json \
6 | --db_path ./data/bird/dev/dev_databases/ \
7 | --tables_json_path ./data/bird/dev/dev_tables.json \
8 | --output_file ./output/gpt4_mag.json \
9 | --log_file ./output/log/log/txt \
10 | --start_pos 0
11 |
12 | echo "Generate SQL on bird dev data done!"
13 |
14 |
15 |
--------------------------------------------------------------------------------