├── Dockerfile ├── LICENSE ├── README.md ├── conf └── wikisql.conf ├── evaluator.py ├── featurizer.py ├── main.py ├── modeling ├── base_model.py ├── model_factory.py └── torch_model.py ├── requirements.txt ├── utils.py ├── wikisql_evaluate.py ├── wikisql_gendata.py ├── wikisql_lib ├── __init__.py ├── common.py ├── dbengine.py ├── query.py └── table.py └── wikisql_prediction.py /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime 2 | COPY . /app 3 | WORKDIR /app 4 | RUN pip install -r requirements.txt 5 | 6 | RUN mkdir data 7 | RUN apt-get update && apt-get install -y wget && apt-get install -y git 8 | RUN git clone https://github.com/salesforce/WikiSQL && tar xvjf WikiSQL/data.tar.bz2 -C WikiSQL 9 | RUN python wikisql_gendata.py 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hybrid Ranking Network for Text-to-SQL 2 | Code for our paper [Hybrid Ranking Network for Text-to-SQL](https://arxiv.org/abs/2008.04759) 3 | 4 | ## Environment Setup 5 | 6 | * `Python 3.8` 7 | * `Pytorch 1.7.1` or higher 8 | * `pip install -r requirements.txt` 9 | 10 | We can also run experiments with docker image: 11 | `docker build -t hydranet -f Dockerfile .` 12 | 13 | The built image above contains processed data and is ready for training and evaluation. 14 | 15 | ## Data Preprocessing 16 | 1. Create data folder and output folder first: `mkdir data && mkdir output` 17 | 2. Clone WikiSQL repo: 18 | `git clone https://github.com/salesforce/WikiSQL && tar xvjf WikiSQL/data.tar.bz2 -C WikiSQL` 19 | 3. Preprocess data: 20 | `python wikisql_gendata.py` 21 | 22 | ## Training 23 | 1. Run `python main.py train --conf conf/wikisql.conf --gpu 0,1,2,3 --note "some note"`. 24 | 2. Model will be saved to `output` folder, named by training start datetime. 25 | 26 | ## Evaluation 27 | 1. Modify model, input and output settings in `wikisql_prediction.py` and run it. 28 | 2. 
Run the WikiSQL evaluation script to get the official numbers: `cd WikiSQL && python evaluate.py data/test.jsonl data/test.db ../output/test_out.jsonl` 29 | 30 | Note: the WikiSQL evaluation script encounters an error when run on Windows, so we include a fixed version for Windows users (run from the repo root): `python wikisql_evaluate.py WikiSQL/data/test.jsonl WikiSQL/data/test.db output/test_out.jsonl` 31 | 32 | 33 | ## Trained Model 34 | A trained model that reproduces the reported numbers on the WikiSQL leaderboard is attached in the releases (see "Releases" in the right column). Model prediction outputs are also attached. -------------------------------------------------------------------------------- /conf/wikisql.conf: -------------------------------------------------------------------------------- 1 | model_type pytorch 2 | 3 | #DEBUG 1 4 | SAVE 1 5 | train_data_path data/wikitrain.jsonl 6 | dev_data_path data/wikidev.jsonl 7 | test_data_path data/wikitest.jsonl 8 | 9 | base_class roberta 10 | base_name large 11 | max_total_length 96 12 | where_column_num 4 13 | op_num 4 14 | agg_num 6 15 | 16 | drop_rate 0.2 17 | learning_rate 3e-5 18 | decay 0.01 19 | epochs 5 20 | batch_size 256 21 | num_warmup_steps 400 -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import utils 4 | from modeling.base_model import BaseModel 5 | from modeling.model_factory import create_model 6 | from featurizer import InputFeature, HydraFeaturizer, SQLDataset 7 | 8 | class HydraEvaluator(): 9 | def __init__(self, output_path, config, hydra_featurizer: HydraFeaturizer, model:BaseModel, note=""): 10 | self.config = config 11 | self.model = model 12 | self.eval_history_file = os.path.join(output_path, "eval.log") 13 | self.bad_case_dir = os.path.join(output_path, "bad_cases") 14 | if "DEBUG" not in config: 15 | os.mkdir(self.bad_case_dir) 16 | with open(self.eval_history_file, "w", encoding="utf8") as f: 17 | f.write(note.rstrip() + "\n") 18 | 19 | self.eval_data = {} 20 | for eval_path in config["dev_data_path"].split("|") + config["test_data_path"].split("|"): 21 | eval_data = SQLDataset(eval_path, config, hydra_featurizer, True) 22 | self.eval_data[os.path.basename(eval_path)] = eval_data 23 | 24 | print("Eval Data file {0} loaded, sample num = {1}".format(eval_path, len(eval_data))) 25 | 26 | def _eval_imp(self, eval_data: SQLDataset, get_sq=True): 27 | items = ["overall", "agg", "sel", "wn", "wc", "op", "val"] 28 | acc = {k:0.0 for k in items} 29 | sq = [] 30 | cnt = 0 31 | model_outputs = self.model.dataset_inference(eval_data) 32 | for input_feature, model_output in zip(eval_data.input_features, model_outputs): 33 | cur_acc = {k:1 for k in acc if k != "overall"} 34 | 35 | select_label = np.argmax(input_feature.select) 36 | agg_label = input_feature.agg[select_label] 37 | wn_label = input_feature.where_num[0] 38 | wc_label = [i for i, w in enumerate(input_feature.where) if w == 1] 39 | 40 | agg, select, where, conditions = self.model.parse_output(input_feature, model_output, wc_label) 41 | if agg != agg_label: 42 | cur_acc["agg"] = 0 43 | if select != select_label: 44 | cur_acc["sel"] = 0 45 | if len(where) != wn_label: 46 | cur_acc["wn"] = 0 47 | if set(where) != set(wc_label): 48 | cur_acc["wc"] = 0 49 | 50 | for w in wc_label: 51 | _, op, vs, ve = conditions[w] 52 | if op != input_feature.op[w]: 53 | cur_acc["op"] = 0 54 | 55 | if vs 
!= input_feature.value_start[w] or ve != input_feature.value_end[w]: 56 | cur_acc["val"] = 0 57 | 58 | for k in cur_acc: 59 | acc[k] += cur_acc[k] 60 | 61 | all_correct = 0 if 0 in cur_acc.values() else 1 62 | acc["overall"] += all_correct 63 | 64 | if ("DEBUG" in self.config or get_sq) and not all_correct: 65 | try: 66 | true_sq = input_feature.output_SQ() 67 | pred_sq = input_feature.output_SQ(agg=agg, sel=select, conditions=[conditions[w] for w in where]) 68 | task_cor_text = "".join([str(cur_acc[k]) for k in items if k in cur_acc]) 69 | sq.append([str(cnt), input_feature.question, "|".join([task_cor_text, pred_sq, true_sq])]) 70 | except: 71 | pass 72 | cnt += 1 73 | 74 | result_str = [] 75 | for item in items: 76 | result_str.append(item + ":{0:.1f}".format(acc[item] * 100.0 / cnt)) 77 | 78 | result_str = ", ".join(result_str) 79 | 80 | return result_str, sq 81 | 82 | def eval(self, epochs): 83 | print(self.bad_case_dir) 84 | for eval_file in self.eval_data: 85 | result_str, sq = self._eval_imp(self.eval_data[eval_file]) 86 | print(eval_file + ": " + result_str) 87 | 88 | if "DEBUG" in self.config: 89 | for text in sq: 90 | print(text[0] + ":" + text[1] + "\t" + text[2]) 91 | else: 92 | with open(self.eval_history_file, "a+", encoding="utf8") as f: 93 | f.write("[{0}, epoch {1}] ".format(eval_file, epochs) + result_str + "\n") 94 | 95 | bad_case_file = os.path.join(self.bad_case_dir, 96 | "{0}_epoch_{1}.log".format(eval_file, epochs)) 97 | with open(bad_case_file, "w", encoding="utf8") as f: 98 | for text in sq: 99 | f.write(text[0] + ":" + text[1] + "\t" + text[2] + "\n") 100 | 101 | if __name__ == "__main__": 102 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 103 | config = utils.read_conf(os.path.join("conf", "wikisql.conf")) 104 | config["DEBUG"] = 1 105 | config["num_train_steps"] = 1000 106 | config["num_warmup_steps"] = 100 107 | 108 | featurizer = HydraFeaturizer(config) 109 | model = create_model(config, is_train=True, num_gpu=1) 110 | evaluator = HydraEvaluator("output", config, featurizer, model, "debug evaluator") 111 | evaluator.eval(0) -------------------------------------------------------------------------------- /featurizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | import utils 5 | import torch.utils.data as torch_data 6 | from wikisql_gendata import SQLExample 7 | from collections import defaultdict 8 | from typing import List 9 | 10 | stats = defaultdict(int) 11 | 12 | class InputFeature(object): 13 | def __init__(self, 14 | question, 15 | table_id, 16 | tokens, 17 | word_to_char_start, 18 | word_to_subword, 19 | subword_to_word, 20 | input_ids, 21 | input_mask, 22 | segment_ids): 23 | self.question = question 24 | self.table_id = table_id 25 | self.tokens = tokens 26 | self.word_to_char_start = word_to_char_start 27 | self.word_to_subword = word_to_subword 28 | self.subword_to_word = subword_to_word 29 | self.input_ids = input_ids 30 | self.input_mask = input_mask 31 | self.segment_ids = segment_ids 32 | 33 | self.columns = None 34 | self.agg = None 35 | self.select = None 36 | self.where_num = None 37 | self.where = None 38 | self.op = None 39 | self.value_start = None 40 | self.value_end = None 41 | 42 | def output_SQ(self, agg = None, sel = None, conditions = None, return_str=True): 43 | agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 44 | cond_ops = ['=', '>', '<', 'OP'] 45 | 46 | if agg is None and sel is None and conditions is None: 47 | sel = 
np.argmax(self.select) 48 | agg = self.agg[sel] 49 | conditions = [] 50 | for i in range(len(self.where)): 51 | if self.where[i] == 0: 52 | continue 53 | conditions.append((i, self.op[i], self.value_start[i], self.value_end[i])) 54 | 55 | agg_text = agg_ops[agg] 56 | select_text = self.columns[sel] 57 | cond_texts = [] 58 | for wc, op, vs, ve in conditions: 59 | column_text = self.columns[wc] 60 | op_text = cond_ops[op] 61 | word_start, word_end = self.subword_to_word[wc][vs], self.subword_to_word[wc][ve] 62 | char_start = self.word_to_char_start[word_start] 63 | char_end = len(self.question) if word_end + 1 >= len(self.word_to_char_start) else self.word_to_char_start[word_end + 1] 64 | value_span_text = self.question[char_start:char_end] 65 | cond_texts.append(column_text + op_text + value_span_text.rstrip()) 66 | 67 | if return_str: 68 | sq = agg_text + ", " + select_text + ", " + " AND ".join(cond_texts) 69 | else: 70 | sq = (agg_text, select_text, set(cond_texts)) 71 | 72 | return sq 73 | 74 | class HydraFeaturizer(object): 75 | def __init__(self, config): 76 | self.config = config 77 | self.tokenizer = utils.create_tokenizer(config) 78 | self.colType2token = { 79 | "string": "[unused1]", 80 | "real": "[unused2]"} 81 | 82 | def get_input_feature(self, example: SQLExample, config): 83 | max_total_length = int(config["max_total_length"]) 84 | 85 | input_feature = InputFeature( 86 | example.question, 87 | example.table_id, 88 | [], 89 | example.word_to_char_start, 90 | [], 91 | [], 92 | [], 93 | [], 94 | [] 95 | ) 96 | 97 | for column, col_type, _ in example.column_meta: 98 | # get query tokens 99 | tokens = [] 100 | word_to_subword = [] 101 | subword_to_word = [] 102 | for i, query_token in enumerate(example.tokens): 103 | if self.config["base_class"] == "roberta": 104 | sub_tokens = self.tokenizer.tokenize(query_token, add_prefix_space=True) 105 | else: 106 | sub_tokens = self.tokenizer.tokenize(query_token) 107 | cur_pos = len(tokens) 108 | if len(sub_tokens) > 0: 109 | word_to_subword += [(cur_pos, cur_pos + len(sub_tokens))] 110 | tokens.extend(sub_tokens) 111 | subword_to_word.extend([i] * len(sub_tokens)) 112 | 113 | if self.config["base_class"] == "roberta": 114 | tokenize_result = self.tokenizer.encode_plus( 115 | col_type + " " + column, 116 | tokens, 117 | padding="max_length", 118 | max_length=max_total_length, 119 | truncation=True, 120 | add_prefix_space=True 121 | ) 122 | else: 123 | tokenize_result = self.tokenizer.encode_plus( 124 | col_type + " " + column, 125 | tokens, 126 | padding="max_length", 127 | max_length=max_total_length, 128 | truncation_strategy="longest_first", 129 | truncation=True, 130 | ) 131 | 132 | input_ids = tokenize_result["input_ids"] 133 | input_mask = tokenize_result["attention_mask"] 134 | 135 | tokens = self.tokenizer.convert_ids_to_tokens(input_ids) 136 | column_token_length = 0 137 | if self.config["base_class"] == "roberta": 138 | for i, token_id in enumerate(input_ids): 139 | if token_id == self.tokenizer.sep_token_id: 140 | column_token_length = i + 2 141 | break 142 | segment_ids = [0] * max_total_length 143 | for i in range(column_token_length, max_total_length): 144 | if input_mask[i] == 0: 145 | break 146 | segment_ids[i] = 1 147 | else: 148 | for i, token_id in enumerate(input_ids): 149 | if token_id == self.tokenizer.sep_token_id: 150 | column_token_length = i + 1 151 | break 152 | segment_ids = tokenize_result["token_type_ids"] 153 | 154 | subword_to_word = [0] * column_token_length + subword_to_word 155 | word_to_subword = 
[(pos[0]+column_token_length, pos[1]+column_token_length) for pos in word_to_subword] 156 | 157 | assert len(input_ids) == max_total_length 158 | assert len(input_mask) == max_total_length 159 | assert len(segment_ids) == max_total_length 160 | 161 | input_feature.tokens.append(tokens) 162 | input_feature.word_to_subword.append(word_to_subword) 163 | input_feature.subword_to_word.append(subword_to_word) 164 | input_feature.input_ids.append(input_ids) 165 | input_feature.input_mask.append(input_mask) 166 | input_feature.segment_ids.append(segment_ids) 167 | 168 | return input_feature 169 | 170 | def fill_label_feature(self, example: SQLExample, input_feature: InputFeature, config): 171 | max_total_length = int(config["max_total_length"]) 172 | 173 | columns = [c[0] for c in example.column_meta] 174 | col_num = len(columns) 175 | input_feature.columns = columns 176 | 177 | input_feature.agg = [0] * col_num 178 | input_feature.agg[example.select] = example.agg 179 | input_feature.where_num = [len(example.conditions)] * col_num 180 | 181 | input_feature.select = [0] * len(columns) 182 | input_feature.select[example.select] = 1 183 | 184 | input_feature.where = [0] * len(columns) 185 | input_feature.op = [0] * len(columns) 186 | input_feature.value_start = [0] * len(columns) 187 | input_feature.value_end = [0] * len(columns) 188 | 189 | for colidx, op, _ in example.conditions: 190 | input_feature.where[colidx] = 1 191 | input_feature.op[colidx] = op 192 | for colidx, column_meta in enumerate(example.column_meta): 193 | if column_meta[-1] == None: 194 | continue 195 | se = example.value_start_end[column_meta[-1]] 196 | try: 197 | s = input_feature.word_to_subword[colidx][se[0]][0] 198 | input_feature.value_start[colidx] = s 199 | e = input_feature.word_to_subword[colidx][se[1]-1][1]-1 200 | input_feature.value_end[colidx] = e 201 | assert s < max_total_length and input_feature.input_mask[colidx][s] == 1 202 | assert e < max_total_length and input_feature.input_mask[colidx][e] == 1 203 | 204 | except: 205 | print("value span is out of range") 206 | return False 207 | 208 | # feature_sq = input_feature.output_SQ(return_str=False) 209 | # example_sq = example.output_SQ(return_str=False) 210 | # if feature_sq != example_sq: 211 | # print(example.qid, feature_sq, example_sq) 212 | return True 213 | 214 | def load_data(self, data_paths, config, include_label=False): 215 | model_inputs = {k: [] for k in ["input_ids", "input_mask", "segment_ids"]} 216 | if include_label: 217 | for k in ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]: 218 | model_inputs[k] = [] 219 | 220 | pos = [] 221 | input_features = [] 222 | for data_path in data_paths.split("|"): 223 | cnt = 0 224 | for line in open(data_path, encoding="utf8"): 225 | example = SQLExample.load_from_json(line) 226 | if not example.valid and include_label == True: 227 | continue 228 | 229 | input_feature = self.get_input_feature(example, config) 230 | if include_label: 231 | success = self.fill_label_feature(example, input_feature, config) 232 | if not success: 233 | continue 234 | 235 | # sq = input_feature.output_SQ() 236 | input_features.append(input_feature) 237 | 238 | cur_start = len(model_inputs["input_ids"]) 239 | cur_sample_num = len(input_feature.input_ids) 240 | pos.append((cur_start, cur_start + cur_sample_num)) 241 | 242 | model_inputs["input_ids"].extend(input_feature.input_ids) 243 | model_inputs["input_mask"].extend(input_feature.input_mask) 244 | model_inputs["segment_ids"].extend(input_feature.segment_ids) 
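# Each SQLExample expands into one row per table column (HydraNet pairs the
# question with every column of the table), so the three lists above grow by
# col_num rows per example and `pos` keeps the [start, end) slice needed to
# regroup the flat model outputs per example. A minimal sketch of that
# regrouping, as done later in BaseModel.dataset_inference (local names here
# are illustrative only):
#
#     start, end = pos[example_idx]
#     per_column_agg = model_outputs["agg"][start:end, :]  # one row per column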
245 | if include_label: 246 | model_inputs["agg"].extend(input_feature.agg) 247 | model_inputs["select"].extend(input_feature.select) 248 | model_inputs["where_num"].extend(input_feature.where_num) 249 | model_inputs["where"].extend(input_feature.where) 250 | model_inputs["op"].extend(input_feature.op) 251 | model_inputs["value_start"].extend(input_feature.value_start) 252 | model_inputs["value_end"].extend(input_feature.value_end) 253 | 254 | cnt += 1 255 | if cnt % 5000 == 0: 256 | print(cnt) 257 | 258 | if "DEBUG" in config and cnt > 100: 259 | break 260 | 261 | for k in model_inputs: 262 | model_inputs[k] = np.array(model_inputs[k], dtype=np.int64) 263 | 264 | return input_features, model_inputs, pos 265 | 266 | class SQLDataset(torch_data.Dataset): 267 | def __init__(self, data_paths, config, featurizer, include_label=False): 268 | self.config = config 269 | self.featurizer = featurizer 270 | self.input_features, self.model_inputs, self.pos = self.featurizer.load_data(data_paths, config, include_label) 271 | 272 | print("{0} loaded. Data shapes:".format(data_paths)) 273 | for k, v in self.model_inputs.items(): 274 | print(k, v.shape) 275 | 276 | def __len__(self): 277 | return self.model_inputs["input_ids"].shape[0] 278 | 279 | def __getitem__(self, idx): 280 | return {k: v[idx] for k, v in self.model_inputs.items()} 281 | 282 | 283 | if __name__ == "__main__": 284 | vocab = "vocab/baseTrue.txt" 285 | config = {} 286 | for line in open("conf/wikisql.conf", encoding="utf8"): 287 | if line.strip() == "" or line[0] == "#": 288 | continue 289 | fields = line.strip().split("\t") 290 | config[fields[0]] = fields[1] 291 | # config["DEBUG"] = 1 292 | 293 | featurizer = HydraFeaturizer(config) 294 | train_data = SQLDataset(config["train_data_path"], config, featurizer, True) 295 | train_data_loader = torch_data.DataLoader(train_data, batch_size=128, shuffle=True, pin_memory=True) 296 | for batch_id, batch in enumerate(train_data_loader): 297 | print(batch_id, {k: v.shape for k, v in batch.items()}) 298 | 299 | for k, v in stats.items(): 300 | print(k, v) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import shutil 5 | import datetime 6 | import utils 7 | from modeling.model_factory import create_model 8 | from featurizer import HydraFeaturizer, SQLDataset 9 | from evaluator import HydraEvaluator 10 | import torch.utils.data as torch_data 11 | 12 | parser = argparse.ArgumentParser(description='HydraNet training script') 13 | parser.add_argument("job", type=str, choices=["train"], 14 | help="job can be train") 15 | parser.add_argument("--conf", help="conf file path") 16 | parser.add_argument("--output_path", type=str, default="output", help="folder path for all outputs") 17 | parser.add_argument("--model_path", help="trained model folder path (used in eval, predict and export mode)") 18 | parser.add_argument("--epoch", help="epochs to restore (used in eval, predict and export mode)") 19 | parser.add_argument("--gpu", type=str, default=None, help="gpu id") 20 | parser.add_argument("--note", type=str) 21 | 22 | args = parser.parse_args() 23 | 24 | if args.job == "train": 25 | if args.gpu is not None: 26 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 27 | # os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1' 28 | conf_path = os.path.abspath(args.conf) 29 | config = utils.read_conf(conf_path) 30 | 31 | note = args.note if 
args.note else "" 32 | 33 | script_path = os.path.dirname(os.path.abspath(sys.argv[0])) 34 | output_path = args.output_path 35 | model_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 36 | model_path = os.path.join(output_path, model_name) 37 | 38 | if "DEBUG" not in config: 39 | if not os.path.exists(output_path): 40 | os.mkdir(output_path) 41 | if not os.path.exists(model_path): 42 | os.mkdir(model_path) 43 | 44 | shutil.copyfile(conf_path, os.path.join(model_path, "model.conf")) 45 | for pyfile in ["featurizer.py"]: 46 | shutil.copyfile(pyfile, os.path.join(model_path, pyfile)) 47 | if config["model_type"] == "pytorch": 48 | shutil.copyfile("modeling/torch_model.py", os.path.join(model_path, "torch_model.py")) 49 | elif config["model_type"] == "tf": 50 | shutil.copyfile("modeling/tf_model.py", os.path.join(model_path, "tf_model.py")) 51 | else: 52 | raise Exception("model_type is not supported") 53 | 54 | featurizer = HydraFeaturizer(config) 55 | train_data = SQLDataset(config["train_data_path"], config, featurizer, True) 56 | train_data_loader = torch_data.DataLoader(train_data, batch_size=int(config["batch_size"]), shuffle=True, pin_memory=True) 57 | 58 | num_samples = len(train_data) 59 | config["num_train_steps"] = int(num_samples * int(config["epochs"]) / int(config["batch_size"])) 60 | step_per_epoch = num_samples / int(config["batch_size"]) 61 | print("total_steps: {0}, warm_up_steps: {1}".format(config["num_train_steps"], config["num_warmup_steps"])) 62 | 63 | model = create_model(config, is_train=True) 64 | evaluator = HydraEvaluator(model_path, config, featurizer, model, note) 65 | print("start training") 66 | loss_avg, step, epoch = 0.0, 0, 0 67 | while True: 68 | for batch_id, batch in enumerate(train_data_loader): 69 | # print(batch_id) 70 | cur_loss = model.train_on_batch(batch) 71 | loss_avg = (loss_avg * step + cur_loss) / (step + 1) 72 | step += 1 73 | if batch_id % 100 == 0: 74 | currentDT = datetime.datetime.now() 75 | print("[{3}] epoch {0}, batch {1}, batch_loss={2:.4f}".format(epoch, batch_id, cur_loss, 76 | currentDT.strftime("%m-%d %H:%M:%S"))) 77 | if args.note: 78 | print(args.note) 79 | model.save(model_path, epoch) 80 | evaluator.eval(epoch) 81 | epoch += 1 82 | if epoch >= int(config["epochs"]): 83 | break 84 | 85 | else: 86 | raise Exception("Job type {0} is not supported for now".format(args.job)) -------------------------------------------------------------------------------- /modeling/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from typing import List 4 | from featurizer import SQLDataset 5 | 6 | class BaseModel(object): 7 | """Define common interfaces for HydraNet models""" 8 | def train_on_batch(self, batch): 9 | raise NotImplementedError() 10 | 11 | def save(self, model_path, epoch): 12 | raise NotImplementedError() 13 | 14 | def load(self, model_path, epoch): 15 | raise NotImplementedError() 16 | 17 | def model_inference(self, model_inputs): 18 | """model prediction on processed features""" 19 | raise NotImplementedError() 20 | 21 | def dataset_inference(self, dataset: SQLDataset): 22 | print("model prediction start") 23 | start_time = time.time() 24 | model_outputs = self.model_inference(dataset.model_inputs) 25 | 26 | final_outputs = [] 27 | for pos in dataset.pos: 28 | final_output = {} 29 | for k in model_outputs: 30 | final_output[k] = model_outputs[k][pos[0]:pos[1], :] 31 | final_outputs.append(final_output) 32 | print("model prediction end, 
time elapse: {0}".format(time.time() - start_time)) 33 | assert len(dataset.input_features) == len(final_outputs) 34 | 35 | return final_outputs 36 | 37 | def predict_SQL(self, dataset: SQLDataset, model_outputs=None): 38 | if model_outputs is None: 39 | model_outputs = self.dataset_inference(dataset) 40 | sqls = [] 41 | for input_feature, model_output in zip(dataset.input_features, model_outputs): 42 | agg, select, where, conditions = self.parse_output(input_feature, model_output, []) 43 | 44 | conditions_with_value_texts = [] 45 | for wc in where: 46 | _, op, vs, ve = conditions[wc] 47 | word_start, word_end = input_feature.subword_to_word[wc][vs], input_feature.subword_to_word[wc][ve] 48 | char_start = input_feature.word_to_char_start[word_start] 49 | char_end = len(input_feature.question) 50 | if word_end + 1 < len(input_feature.word_to_char_start): 51 | char_end = input_feature.word_to_char_start[word_end + 1] 52 | value_span_text = input_feature.question[char_start:char_end].rstrip() 53 | conditions_with_value_texts.append((wc, op, value_span_text)) 54 | 55 | sqls.append((agg, select, conditions_with_value_texts)) 56 | 57 | return sqls 58 | 59 | def predict_SQL_with_EG(self, engine, dataset: SQLDataset, beam_size=5, model_outputs=None): 60 | if model_outputs is None: 61 | model_outputs = self.dataset_inference(dataset) 62 | sqls = [] 63 | for input_feature, model_output in zip(dataset.input_features, model_outputs): 64 | agg, select, where_num, conditions = self.beam_parse_output(input_feature, model_output, beam_size) 65 | query = {"agg": agg, "sel": select, "conds": []} 66 | wcs = set() 67 | conditions_with_value_texts = [] 68 | for condition in conditions: 69 | if len(wcs) >= where_num: 70 | break 71 | _, wc, op, vs, ve = condition 72 | if wc in wcs: 73 | continue 74 | 75 | word_start, word_end = input_feature.subword_to_word[wc][vs], input_feature.subword_to_word[wc][ve] 76 | char_start = input_feature.word_to_char_start[word_start] 77 | char_end = len(input_feature.question) 78 | if word_end + 1 < len(input_feature.word_to_char_start): 79 | char_end = input_feature.word_to_char_start[word_end + 1] 80 | value_span_text = input_feature.question[char_start:char_end].rstrip() 81 | 82 | query["conds"] = [[int(wc), int(op), value_span_text]] 83 | result, sql = engine.execute_dict_query(input_feature.table_id, query) 84 | if not result or 'ERROR: ' in result: 85 | continue 86 | 87 | conditions_with_value_texts.append((wc, op, value_span_text)) 88 | wcs.add(wc) 89 | 90 | sqls.append((agg, select, conditions_with_value_texts)) 91 | 92 | return sqls 93 | 94 | def _get_where_num(self, output): 95 | # wn = np.argmax(output["where_num"], -1) 96 | # max_num = 0 97 | # max_cnt = np.sum(wn == 0) 98 | # for num in range(1, 5): 99 | # cur_cnt = np.sum(wn==num) 100 | # if cur_cnt > max_cnt: 101 | # max_cnt = cur_cnt 102 | # max_num = num 103 | # def sigmoid(x): 104 | # return 1/(1 + np.exp(-x)) 105 | relevant_prob = 1 - np.exp(output["column_func"][:, 2]) 106 | where_num_scores = np.average(output["where_num"], axis=0, weights=relevant_prob) 107 | where_num = int(np.argmax(where_num_scores)) 108 | 109 | return where_num 110 | 111 | def parse_output(self, input_feature, model_output, where_label = []): 112 | def get_span(i): 113 | offset = 0 114 | segment_ids = np.array(input_feature.segment_ids[i]) 115 | for j in range(len(segment_ids)): 116 | if segment_ids[j] == 1: 117 | offset = j 118 | break 119 | 120 | value_start, value_end = model_output["value_start"][i, segment_ids == 1], 
model_output["value_end"][i, segment_ids == 1] 121 | l = len(value_start) 122 | sum_mat = value_start.reshape((l, 1)) + value_end.reshape((1, l)) 123 | span = (0, 0) 124 | for cur_span, _ in sorted(np.ndenumerate(sum_mat), key=lambda x:x[1], reverse=True): 125 | if cur_span[1] < cur_span[0] or cur_span[0] == l - 1 or cur_span[1] == l - 1: 126 | continue 127 | span = cur_span 128 | break 129 | 130 | return (span[0]+offset, span[1]+offset) 131 | 132 | select_id_prob = sorted(enumerate(model_output["column_func"][:, 0]), key=lambda x:x[1], reverse=True) 133 | select = select_id_prob[0][0] 134 | agg = np.argmax(model_output["agg"][select, :]) 135 | 136 | where_id_prob = sorted(enumerate(model_output["column_func"][:, 1]), key=lambda x:x[1], reverse=True) 137 | where_num = self._get_where_num(model_output) 138 | where = [i for i, _ in where_id_prob[:where_num]] 139 | conditions = {} 140 | for idx in set(where + where_label): 141 | span = get_span(idx) 142 | op = np.argmax(model_output["op"][idx, :]) 143 | conditions[idx] = (idx, op, span[0], span[1]) 144 | 145 | return agg, select, where, conditions 146 | 147 | def beam_parse_output(self, input_feature, model_output, beam_size=5): 148 | def get_span(i): 149 | offset = 0 150 | segment_ids = np.array(input_feature.segment_ids[i]) 151 | for j in range(len(segment_ids)): 152 | if segment_ids[j] == 1: 153 | offset = j 154 | break 155 | 156 | value_start, value_end = model_output["value_start"][i, segment_ids == 1], model_output["value_end"][i, segment_ids == 1] 157 | l = len(value_start) 158 | sum_mat = value_start.reshape((l, 1)) + value_end.reshape((1, l)) 159 | spans = [] 160 | for cur_span, sum_logp in sorted(np.ndenumerate(sum_mat), key=lambda x:x[1], reverse=True): 161 | if cur_span[1] < cur_span[0] or cur_span[0] == l - 1 or cur_span[1] == l - 1: 162 | continue 163 | spans.append((cur_span[0]+offset, cur_span[1]+offset, sum_logp)) 164 | if len(spans) >= beam_size: 165 | break 166 | 167 | return spans 168 | 169 | select_id_prob = sorted(enumerate(model_output["column_func"][:, 0]), key=lambda x:x[1], reverse=True) 170 | select = select_id_prob[0][0] 171 | agg = np.argmax(model_output["agg"][select, :]) 172 | 173 | where_id_prob = sorted(enumerate(model_output["column_func"][:, 1]), key=lambda x:x[1], reverse=True) 174 | where_num = self._get_where_num(model_output) 175 | conditions = [] 176 | for idx, wlogp in where_id_prob[:beam_size]: 177 | op = np.argmax(model_output["op"][idx, :]) 178 | for span in get_span(idx): 179 | conditions.append((wlogp+span[2], idx, op, span[0], span[1])) 180 | conditions.sort(key=lambda x:x[0], reverse=True) 181 | return agg, select, where_num, conditions -------------------------------------------------------------------------------- /modeling/model_factory.py: -------------------------------------------------------------------------------- 1 | from modeling.base_model import BaseModel 2 | from modeling.torch_model import HydraTorch 3 | 4 | def create_model(config, is_train = False) -> BaseModel: 5 | if config["model_type"] == "pytorch": 6 | return HydraTorch(config) 7 | # elif config["model_type"] == "tf": 8 | # return HydraTensorFlow(config, is_train, num_gpu) 9 | else: 10 | raise NotImplementedError("model type {0} is not supported".format(config["model_type"])) -------------------------------------------------------------------------------- /modeling/torch_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import 
transformers 5 | import utils 6 | from modeling.base_model import BaseModel 7 | from torch import nn 8 | 9 | class HydraTorch(BaseModel): 10 | def __init__(self, config): 11 | self.config = config 12 | self.model = HydraNet(config) 13 | if torch.cuda.device_count() > 1: 14 | self.model = nn.DataParallel(self.model) 15 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | self.model.to(self.device) 17 | 18 | self.optimizer, self.scheduler = None, None 19 | 20 | def train_on_batch(self, batch): 21 | if self.optimizer is None: 22 | no_decay = ["bias", "LayerNorm.weight"] 23 | optimizer_grouped_parameters = [ 24 | { 25 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 26 | "weight_decay": float(self.config["decay"]), 27 | }, 28 | {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 29 | "weight_decay": 0.0}, 30 | ] 31 | self.optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=float(self.config["learning_rate"])) 32 | self.scheduler = transformers.get_cosine_schedule_with_warmup( 33 | self.optimizer, 34 | num_warmup_steps=int(self.config["num_warmup_steps"]), 35 | num_training_steps=int(self.config["num_train_steps"])) 36 | self.optimizer.zero_grad() 37 | 38 | self.model.train() 39 | for k, v in batch.items(): 40 | batch[k] = v.to(self.device) 41 | batch_loss = torch.mean(self.model(**batch)["loss"]) 42 | batch_loss.backward() 43 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 44 | self.optimizer.step() 45 | self.scheduler.step() 46 | self.optimizer.zero_grad() 47 | 48 | return batch_loss.cpu().detach().numpy() 49 | 50 | def model_inference(self, model_inputs): 51 | self.model.eval() 52 | model_outputs = {} 53 | batch_size = 512 54 | for start_idx in range(0, model_inputs["input_ids"].shape[0], batch_size): 55 | input_tensor = {k: torch.from_numpy(model_inputs[k][start_idx:start_idx+batch_size]).to(self.device) for k in ["input_ids", "input_mask", "segment_ids"]} 56 | with torch.no_grad(): 57 | model_output = self.model(**input_tensor) 58 | for k, out_tensor in model_output.items(): 59 | if out_tensor is None: 60 | continue 61 | if k not in model_outputs: 62 | model_outputs[k] = [] 63 | model_outputs[k].append(out_tensor.cpu().detach().numpy()) 64 | 65 | for k in model_outputs: 66 | model_outputs[k] = np.concatenate(model_outputs[k], 0) 67 | 68 | return model_outputs 69 | 70 | def save(self, model_path, epoch): 71 | if "SAVE" in self.config and "DEBUG" not in self.config: 72 | save_path = os.path.join(model_path, "model_{0}.pt".format(epoch)) 73 | if torch.cuda.device_count() > 1: 74 | torch.save(self.model.module.state_dict(), save_path) 75 | else: 76 | torch.save(self.model.state_dict(), save_path) 77 | print("Model saved in path: %s" % save_path) 78 | 79 | def load(self, model_path, epoch): 80 | pt_path = os.path.join(model_path, "model_{0}.pt".format(epoch)) 81 | loaded_dict = torch.load(pt_path, map_location=torch.device(self.device)) 82 | if torch.cuda.device_count() > 1: 83 | self.model.module.load_state_dict(loaded_dict) 84 | else: 85 | self.model.load_state_dict(loaded_dict) 86 | print("PyTorch model loaded from {0}".format(pt_path)) 87 | 88 | class HydraNet(nn.Module): 89 | def __init__(self, config): 90 | super(HydraNet, self).__init__() 91 | self.config = config 92 | self.base_model = utils.create_base_model(config) 93 | 94 | # #=====Hack for RoBERTa model==== 95 | # self.base_model.config.type_vocab_size = 2 96 | # single_emb = 
self.base_model.embeddings.token_type_embeddings 97 | # self.base_model.embeddings.token_type_embeddings = torch.nn.Embedding(2, single_emb.embedding_dim) 98 | # self.base_model.embeddings.token_type_embeddings.weight = torch.nn.Parameter(single_emb.weight.repeat([2, 1]), requires_grad=True) 99 | # #==================================== 100 | 101 | drop_rate = float(config["drop_rate"]) if "drop_rate" in config else 0.0 102 | self.dropout = nn.Dropout(drop_rate) 103 | 104 | bert_hid_size = self.base_model.config.hidden_size 105 | self.column_func = nn.Linear(bert_hid_size, 3) 106 | self.agg = nn.Linear(bert_hid_size, int(config["agg_num"])) 107 | self.op = nn.Linear(bert_hid_size, int(config["op_num"])) 108 | self.where_num = nn.Linear(bert_hid_size, int(config["where_column_num"]) + 1) 109 | self.start_end = nn.Linear(bert_hid_size, 2) 110 | 111 | def forward(self, input_ids, input_mask, segment_ids, agg=None, select=None, where=None, where_num=None, op=None, value_start=None, value_end=None): 112 | # print("[inner] input_ids size:", input_ids.size()) 113 | if self.config["base_class"] == "roberta": 114 | bert_output, pooled_output = self.base_model( 115 | input_ids=input_ids, 116 | attention_mask=input_mask, 117 | token_type_ids=None, 118 | return_dict=False) 119 | else: 120 | bert_output, pooled_output = self.base_model( 121 | input_ids=input_ids, 122 | attention_mask=input_mask, 123 | token_type_ids=segment_ids, 124 | return_dict=False) 125 | 126 | bert_output = self.dropout(bert_output) 127 | pooled_output = self.dropout(pooled_output) 128 | 129 | column_func_logit = self.column_func(pooled_output) 130 | agg_logit = self.agg(pooled_output) 131 | op_logit = self.op(pooled_output) 132 | where_num_logit = self.where_num(pooled_output) 133 | start_end_logit = self.start_end(bert_output) 134 | value_span_mask = input_mask.to(dtype=bert_output.dtype) 135 | # value_span_mask[:, 0] = 1 136 | start_logit = start_end_logit[:, :, 0] * value_span_mask - 1000000.0 * (1 - value_span_mask) 137 | end_logit = start_end_logit[:, :, 1] * value_span_mask - 1000000.0 * (1 - value_span_mask) 138 | 139 | loss = None 140 | if select is not None: 141 | bceloss = nn.BCEWithLogitsLoss(reduction="none") 142 | cross_entropy = nn.CrossEntropyLoss(reduction="none") 143 | 144 | loss = cross_entropy(agg_logit, agg) * select.float() 145 | loss += bceloss(column_func_logit[:, 0], select.float()) 146 | loss += bceloss(column_func_logit[:, 1], where.float()) 147 | loss += bceloss(column_func_logit[:, 2], (1-select.float()) * (1-where.float())) 148 | loss += cross_entropy(where_num_logit, where_num) 149 | loss += cross_entropy(op_logit, op) * where.float() 150 | loss += cross_entropy(start_logit, value_start) 151 | loss += cross_entropy(end_logit, value_end) 152 | 153 | 154 | # return loss, column_func_logit, agg_logit, op_logit, where_num_logit, start_logit, end_logit 155 | log_sigmoid = nn.LogSigmoid() 156 | 157 | return {"column_func": log_sigmoid(column_func_logit), 158 | "agg": agg_logit.log_softmax(1), 159 | "op": op_logit.log_softmax(1), 160 | "where_num": where_num_logit.log_softmax(1), 161 | "value_start": start_logit.log_softmax(1), 162 | "value_end": end_logit.log_softmax(1), 163 | "loss": loss} 164 | 165 | if __name__ == "__main__": 166 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 167 | config = {} 168 | config["num_train_steps"] = 1000 169 | config["num_warmup_steps"] = 100 170 | for line in open("../conf/wikisql.conf", encoding="utf8"): 171 | if line.strip() == "" or line[0] == "#": 172 | continue 173 | 
fields = line.strip().split("\t") 174 | config[fields[0]] = fields[1] 175 | 176 | model = HydraTorch(config) 177 | tokenizer = utils.create_tokenizer(config) 178 | inputs = tokenizer.encode_plus("Here is some text to encode", text_pair="hello world!", add_special_tokens=True) 179 | batch_size = 16 180 | inputs = { 181 | "input_ids": torch.tensor([inputs["input_ids"]]*batch_size), 182 | "input_mask": torch.tensor([inputs["attention_mask"]]*batch_size), 183 | "segment_ids": torch.tensor([inputs["token_type_ids"]]*batch_size) 184 | } 185 | inputs["agg"] = torch.tensor([0] * batch_size) 186 | inputs["select"] = torch.tensor([0] * batch_size) 187 | inputs["where_num"] = torch.tensor([0] * batch_size) 188 | inputs["where"] = torch.tensor([0] * batch_size) 189 | inputs["op"] = torch.tensor([0] * batch_size) 190 | inputs["value_start"] = torch.tensor([0] * batch_size) 191 | inputs["value_end"] = torch.tensor([0] * batch_size) 192 | 193 | print("===========train=============") 194 | batch_loss = model.train_on_batch(inputs) 195 | print(batch_loss) 196 | batch_loss = model.train_on_batch(inputs) 197 | print(batch_loss) 198 | 199 | print("===========infer=============") 200 | model_output = model.model_inference(inputs) 201 | for k in model_output: 202 | print(k, model_output[k].shape) 203 | print("done") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.30.0 2 | sqlalchemy==1.3.23 3 | tqdm 4 | records 5 | babel 6 | tabulate -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import transformers 4 | 5 | pretrained_weights = { 6 | ("bert", "base"): "bert-base-uncased", 7 | ("bert", "large"): "bert-large-uncased-whole-word-masking", 8 | ("roberta", "base"): "roberta-base", 9 | ("roberta", "large"): "roberta-large", 10 | ("albert", "xlarge"): "albert-xlarge-v2" 11 | } 12 | 13 | 14 | def read_jsonl(jsonl): 15 | for line in open(jsonl, encoding="utf8"): 16 | sample = json.loads(line.rstrip()) 17 | yield sample 18 | 19 | def read_conf(conf_path): 20 | config = {} 21 | for line in open(conf_path, encoding="utf8"): 22 | if line.strip() == "" or line[0] == "#": 23 | continue 24 | fields = line.strip().split("\t") 25 | config[fields[0]] = fields[1] 26 | config["train_data_path"] = os.path.abspath(config["train_data_path"]) 27 | config["dev_data_path"] = os.path.abspath(config["dev_data_path"]) 28 | 29 | return config 30 | 31 | def create_base_model(config): 32 | weights_name = pretrained_weights[(config["base_class"], config["base_name"])] 33 | if config["base_class"] == "bert": 34 | return transformers.BertModel.from_pretrained(weights_name) 35 | elif config["base_class"] == "roberta": 36 | return transformers.RobertaModel.from_pretrained(weights_name) 37 | elif config["base_class"] == "albert": 38 | return transformers.AlbertModel.from_pretrained(weights_name) 39 | else: 40 | raise Exception("base_class {0} not supported".format(config["base_class"])) 41 | 42 | def create_tokenizer(config): 43 | weights_name = pretrained_weights[(config["base_class"], config["base_name"])] 44 | if config["base_class"] == "bert": 45 | return transformers.BertTokenizer.from_pretrained(weights_name) 46 | elif config["base_class"] == "roberta": 47 | return transformers.RobertaTokenizer.from_pretrained(weights_name) 48 | elif 
config["base_class"] == "albert": 49 | return transformers.AlbertTokenizer.from_pretrained(weights_name) 50 | else: 51 | raise Exception("base_class {0} not supported".format(config["base_class"])) 52 | 53 | if __name__ == "__main__": 54 | qtokens = ['Tell', 'me', 'what', 'the', 'notes', 'are', 'for', 'South', 'Australia'] 55 | column = "string School/Club Team" 56 | 57 | tokenizer = create_tokenizer({"base_class": "roberta", "base_name": "large"}) 58 | 59 | qsubtokens = [] 60 | for t in qtokens: 61 | qsubtokens += tokenizer.tokenize(t, add_prefix_space=True) 62 | print(qsubtokens) 63 | result = tokenizer.encode_plus(column, qsubtokens, add_prefix_space=True) 64 | for k in result: 65 | print(k, result[k]) 66 | print(tokenizer.convert_ids_to_tokens(result["input_ids"])) 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /wikisql_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from wikisql_lib.dbengine import DBEngine 6 | from wikisql_lib.query import Query 7 | from wikisql_lib.common import count_lines 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = ArgumentParser() 12 | parser.add_argument('source_file', help='source file for the prediction') 13 | parser.add_argument('db_file', help='source database for the prediction') 14 | parser.add_argument('pred_file', help='predictions by the model') 15 | parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions') 16 | args = parser.parse_args() 17 | 18 | engine = DBEngine(args.db_file) 19 | exact_match = [] 20 | with open(args.source_file) as fs, open(args.pred_file) as fp: 21 | grades = [] 22 | for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): 23 | eg = json.loads(ls) 24 | ep = json.loads(lp) 25 | qg = Query.from_dict(eg['sql'], ordered=args.ordered) 26 | gold = engine.execute_query(eg['table_id'], qg, lower=True) 27 | pred = ep.get('error', None) 28 | qp = None 29 | if not ep.get('error', None): 30 | try: 31 | qp = Query.from_dict(ep['query'], ordered=args.ordered) 32 | pred = engine.execute_query(eg['table_id'], qp, lower=True) 33 | except Exception as e: 34 | pred = repr(e) 35 | correct = pred == gold 36 | match = qp == qg 37 | grades.append(correct) 38 | exact_match.append(match) 39 | print(json.dumps({ 40 | 'ex_accuracy': sum(grades) / len(grades), 41 | 'lf_accuracy': sum(exact_match) / len(exact_match), 42 | }, indent=2)) 43 | -------------------------------------------------------------------------------- /wikisql_gendata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import string 4 | import unicodedata 5 | import utils 6 | 7 | def is_whitespace(c): 8 | if c == " " or c == "\t" or c == "\n" or c == "\r": 9 | return True 10 | cat = unicodedata.category(c) 11 | if cat == "Zs": 12 | return True 13 | return False 14 | 15 | def is_punctuation(c): 16 | """Checks whether `chars` is a punctuation character.""" 17 | cp = ord(c) 18 | # We treat all non-letter/number ASCII as punctuation. 19 | # Characters such as "^", "$", and "`" are not in the Unicode 20 | # Punctuation class but we treat them as punctuation anyways, for 21 | # consistency. 
22 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 23 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 24 | return True 25 | cat = unicodedata.category(c) 26 | if cat.startswith("P") or cat.startswith("S"): 27 | return True 28 | return False 29 | 30 | def basic_tokenize(doc): 31 | doc_tokens = [] 32 | char_to_word = [] 33 | word_to_char_start = [] 34 | prev_is_whitespace = True 35 | prev_is_punc = False 36 | prev_is_num = False 37 | for pos, c in enumerate(doc): 38 | if is_whitespace(c): 39 | prev_is_whitespace = True 40 | prev_is_punc = False 41 | else: 42 | if prev_is_whitespace or is_punctuation(c) or prev_is_punc or (prev_is_num and not str(c).isnumeric()): 43 | doc_tokens.append(c) 44 | word_to_char_start.append(pos) 45 | else: 46 | doc_tokens[-1] += c 47 | prev_is_whitespace = False 48 | prev_is_punc = is_punctuation(c) 49 | prev_is_num = str(c).isnumeric() 50 | char_to_word.append(len(doc_tokens) - 1) 51 | 52 | return doc_tokens, char_to_word, word_to_char_start 53 | 54 | class SQLExample(object): 55 | def __init__(self, 56 | qid, 57 | question, 58 | table_id, 59 | column_meta, 60 | agg=None, 61 | select=None, 62 | conditions=None, 63 | tokens=None, 64 | char_to_word=None, 65 | word_to_char_start=None, 66 | value_start_end=None, 67 | valid=True): 68 | self.qid = qid 69 | self.question = question 70 | self.table_id = table_id 71 | self.column_meta = column_meta 72 | self.agg = agg 73 | self.select = select 74 | self.conditions = conditions 75 | self.valid = valid 76 | if tokens is None: 77 | self.tokens, self.char_to_word, self.word_to_char_start = basic_tokenize(question) 78 | self.value_start_end = {} 79 | if conditions is not None and len(conditions) > 0: 80 | cur_start = None 81 | for cond in conditions: 82 | value = cond[-1] 83 | value_tokens, _, _ = basic_tokenize(value) 84 | val_len = len(value_tokens) 85 | for i in range(len(self.tokens)): 86 | if " ".join(self.tokens[i:i+val_len]).lower() != " ".join(value_tokens).lower(): 87 | continue 88 | s = self.word_to_char_start[i] 89 | e = len(question) if i + val_len >= len(self.word_to_char_start) else self.word_to_char_start[i + val_len] 90 | recovered_answer_text = question[s:e].strip() 91 | if value.lower() == recovered_answer_text.lower(): 92 | cur_start = i 93 | break 94 | 95 | if cur_start is None: 96 | self.valid = False 97 | print([value, value_tokens, question, self.tokens]) 98 | # for c in question: 99 | # print((c, ord(c), unicodedata.category(c))) 100 | # raise Exception() 101 | else: 102 | self.value_start_end[value] = (cur_start, cur_start + val_len) 103 | else: 104 | self.tokens, self.char_to_word, self.word_to_char_start, self.value_start_end = tokens, char_to_word, word_to_char_start, value_start_end 105 | 106 | @staticmethod 107 | def load_from_json(s): 108 | d = json.loads(s) 109 | keys = ["qid", "question", "table_id", "column_meta", "agg", "select", "conditions", "tokens", "char_to_word", "word_to_char_start", "value_start_end", "valid"] 110 | 111 | return SQLExample(*[d[k] for k in keys]) 112 | 113 | def dump_to_json(self): 114 | d = {} 115 | d["qid"] = self.qid 116 | d["question"] = self.question 117 | d["table_id"] = self.table_id 118 | d["column_meta"] = self.column_meta 119 | d["agg"] = self.agg 120 | d["select"] = self.select 121 | d["conditions"] = self.conditions 122 | d["tokens"] = self.tokens 123 | d["char_to_word"] = self.char_to_word 124 | d["word_to_char_start"] = self.word_to_char_start 125 | d["value_start_end"] = self.value_start_end 126 | d["valid"] = self.valid 127 | 128 | 
return json.dumps(d) 129 | 130 | def output_SQ(self, return_str=True): 131 | agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 132 | cond_ops = ['=', '>', '<', 'OP'] 133 | 134 | agg_text = agg_ops[self.agg] 135 | select_text = self.column_meta[self.select][0] 136 | cond_texts = [] 137 | for wc, op, value_text in self.conditions: 138 | column_text = self.column_meta[wc][0] 139 | op_text = cond_ops[op] 140 | cond_texts.append(column_text + op_text + value_text) 141 | 142 | if return_str: 143 | sq = agg_text + ", " + select_text + ", " + " AND ".join(cond_texts) 144 | else: 145 | sq = (agg_text, select_text, set(cond_texts)) 146 | return sq 147 | 148 | def get_schema(tables): 149 | schema, headers, colTypes, naturalMap = {}, {}, {}, {} 150 | for table in tables: 151 | values = [set() for _ in range(len(table["header"]))] 152 | for row in table["rows"]: 153 | for i, value in enumerate(row): 154 | values[i].add(str(value).lower()) 155 | columns = {column: values[i] for i, column in enumerate(table["header"])} 156 | 157 | trans = {"text": "string", "real": "real"} 158 | colTypes[table["id"]] = {col:trans[ty] for ty, col in zip(table["types"], table["header"])} 159 | schema[table["id"]] = columns 160 | naturalMap[table["id"]] = {col: col for col in columns} 161 | headers[table["id"]] = table["header"] 162 | 163 | return schema, headers, colTypes, naturalMap 164 | 165 | 166 | 167 | if __name__ == "__main__": 168 | data_path = os.path.join("WikiSQL", "data") 169 | for phase in ["train", "dev", "test"]: 170 | src_file = os.path.join(data_path, phase + ".jsonl") 171 | schema_file = os.path.join(data_path, phase + ".tables.jsonl") 172 | output_file = os.path.join("data", "wiki" + phase + ".jsonl") 173 | schema, headers, colTypes, naturalMap = get_schema(utils.read_jsonl(schema_file)) 174 | 175 | cnt = 0 176 | print("processing {0}...".format(src_file)) 177 | with open(output_file, "w", encoding="utf8") as f: 178 | for raw_sample in utils.read_jsonl(src_file): 179 | table_id = raw_sample["table_id"] 180 | sql = raw_sample["sql"] 181 | 182 | cur_schema = schema[table_id] 183 | header = headers[table_id] 184 | cond_col_values = {header[cond[0]]: str(cond[2]) for cond in sql["conds"]} 185 | column_meta = [] 186 | for col in header: 187 | if col in cond_col_values: 188 | column_meta.append((col, colTypes[table_id][col], cond_col_values[col])) 189 | else: 190 | detected_val = None 191 | # for cond_col_val in cond_col_values.values(): 192 | # if cond_col_val.lower() in cur_schema[col]: 193 | # detected_val = cond_col_val 194 | # break 195 | column_meta.append((col, colTypes[table_id][col], detected_val)) 196 | 197 | example = SQLExample( 198 | cnt, 199 | raw_sample["question"], 200 | table_id, 201 | column_meta, 202 | sql["agg"], 203 | int(sql["sel"]), 204 | [(int(cond[0]), cond[1], str(cond[2])) for cond in sql["conds"]]) 205 | 206 | f.write(example.dump_to_json() + "\n") 207 | cnt += 1 208 | 209 | # if cnt % 1000 == 0 and cnt > 0: 210 | # print(cnt) -------------------------------------------------------------------------------- /wikisql_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyuqin/HydraNet-WikiSQL/1a6896207894cb4ac1c4fda1b014c2394bc86296/wikisql_lib/__init__.py -------------------------------------------------------------------------------- /wikisql_lib/common.py: -------------------------------------------------------------------------------- 1 | def count_lines(fname): 2 | with open(fname) as f: 3 | return 
sum(1 for line in f) 4 | 5 | 6 | def detokenize(tokens): 7 | ret = '' 8 | for g, a in zip(tokens['gloss'], tokens['after']): 9 | ret += g + a 10 | return ret.strip() 11 | -------------------------------------------------------------------------------- /wikisql_lib/dbengine.py: -------------------------------------------------------------------------------- 1 | import records 2 | import re 3 | from babel.numbers import parse_decimal, NumberFormatError 4 | from wikisql_lib.query import Query 5 | 6 | schema_re = re.compile(r'\((.+)\)') 7 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') 8 | 9 | 10 | class DBEngine: 11 | 12 | def __init__(self, fdb): 13 | self.db = records.Database('sqlite:///{}'.format(fdb)) 14 | self.conn = self.db.get_connection() 15 | 16 | def execute_dict_query(self, table_id, query): 17 | try: 18 | query = Query.from_dict(query) 19 | result = self.execute_query(table_id, query, lower=True) 20 | except Exception as e: 21 | result = 'ERROR: ' + repr(e) 22 | return (result, query) 23 | 24 | def execute_query(self, table_id, query, *args, **kwargs): 25 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs) 26 | 27 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True): 28 | if not table_id.startswith('table'): 29 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 30 | table_info = self.conn.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql 31 | schema_str = schema_re.findall(table_info)[0] 32 | schema = {} 33 | for tup in schema_str.split(', '): 34 | c, t = tup.split() 35 | schema[c] = t 36 | select = 'col{}'.format(select_index) 37 | agg = Query.agg_ops[aggregation_index] 38 | if agg: 39 | select = '{}({})'.format(agg, select) 40 | where_clause = [] 41 | where_map = {} 42 | for col_index, op, val in conditions: 43 | if lower and isinstance(val, str): 44 | val = val.lower() 45 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 46 | try: 47 | # print(val, parse_decimal(val, locale="en_US")) 48 | val = float(parse_decimal(val, locale="en_US")) 49 | except NumberFormatError as e: 50 | val = float(num_re.findall(val)[0]) 51 | # except: 52 | # print([val]) 53 | where_clause.append('col{} {} :col{}'.format(col_index, Query.cond_ops[op], col_index)) 54 | where_map['col{}'.format(col_index)] = val 55 | where_str = '' 56 | if where_clause: 57 | where_str = 'WHERE ' + ' AND '.join(where_clause) 58 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 59 | out = self.conn.query(query, **where_map) 60 | return [o.result for o in out] 61 | 62 | if __name__ == "__main__": 63 | engine = DBEngine("../data/wikisql/dev.db") 64 | query = {"agg": 0, "sel": 3, "conds": [[5, 0, "butler cc (ks)"]]} 65 | print(engine.execute_dict_query("1-10015132-11", query)) 66 | 67 | -------------------------------------------------------------------------------- /wikisql_lib/query.py: -------------------------------------------------------------------------------- 1 | from wikisql_lib.common import detokenize 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | import re 5 | 6 | 7 | re_whitespace = re.compile(r'\s+', flags=re.UNICODE) 8 | 9 | 10 | class Query: 11 | 12 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 13 | cond_ops = ['=', '>', '<', 'OP'] 14 | syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS'] 15 | 16 | def 
__init__(self, sel_index, agg_index, conditions=tuple(), ordered=False): 17 | self.sel_index = sel_index 18 | self.agg_index = agg_index 19 | self.conditions = list(conditions) 20 | self.ordered = ordered 21 | 22 | def __eq__(self, other): 23 | if isinstance(other, self.__class__): 24 | indices = self.sel_index == other.sel_index and self.agg_index == other.agg_index 25 | if other.ordered: 26 | conds = [(col, op, str(cond).lower()) for col, op, cond in self.conditions] == [(col, op, str(cond).lower()) for col, op, cond in other.conditions] 27 | else: 28 | conds = set([(col, op, str(cond).lower()) for col, op, cond in self.conditions]) == set([(col, op, str(cond).lower()) for col, op, cond in other.conditions]) 29 | 30 | return indices and conds 31 | return NotImplemented 32 | 33 | def __ne__(self, other): 34 | if isinstance(other, self.__class__): 35 | return not self.__eq__(other) 36 | return NotImplemented 37 | 38 | def __hash__(self): 39 | return hash(tuple(sorted(self.__dict__.items()))) 40 | 41 | def __repr__(self): 42 | rep = 'SELECT {agg} {sel} FROM table'.format( 43 | agg=self.agg_ops[self.agg_index], 44 | sel='col{}'.format(self.sel_index), 45 | ) 46 | if self.conditions: 47 | rep += ' WHERE ' + ' AND '.join(['{} {} {}'.format('col{}'.format(i), self.cond_ops[o], v) for i, o, v in self.conditions]) 48 | return rep 49 | 50 | def to_dict(self): 51 | return {'sel': self.sel_index, 'agg': self.agg_index, 'conds': self.conditions} 52 | 53 | def lower(self): 54 | conds = [] 55 | for col, op, cond in self.conditions: 56 | conds.append([col, op, cond.lower()]) 57 | return self.__class__(self.sel_index, self.agg_index, conds) 58 | 59 | @classmethod 60 | def from_dict(cls, d, ordered=False): 61 | return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'], ordered=ordered) 62 | 63 | @classmethod 64 | def from_tokenized_dict(cls, d): 65 | conds = [] 66 | for col, op, val in d['conds']: 67 | conds.append([col, op, detokenize(val)]) 68 | return cls(d['sel'], d['agg'], conds) 69 | 70 | @classmethod 71 | def from_generated_dict(cls, d): 72 | conds = [] 73 | for col, op, val in d['conds']: 74 | end = len(val['words']) 75 | conds.append([col, op, detokenize(val)]) 76 | return cls(d['sel'], d['agg'], conds) 77 | 78 | @classmethod 79 | def from_sequence(cls, sequence, table, lowercase=True): 80 | sequence = deepcopy(sequence) 81 | if 'symend' in sequence['words']: 82 | end = sequence['words'].index('symend') 83 | for k, v in sequence.items(): 84 | sequence[k] = v[:end] 85 | terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])] 86 | headers = [detokenize(h) for h in table['header']] 87 | 88 | # lowercase everything and truncate sequence 89 | if lowercase: 90 | headers = [h.lower() for h in headers] 91 | for i, t in enumerate(terms): 92 | for k, v in t.items(): 93 | t[k] = v.lower() 94 | headers_no_whitespcae = [re.sub(re_whitespace, '', h) for h in headers] 95 | 96 | # get select 97 | if 'symselect' != terms.pop(0)['word']: 98 | raise Exception('Missing symselect operator') 99 | 100 | # get aggregation 101 | if 'symagg' != terms.pop(0)['word']: 102 | raise Exception('Missing symagg operator') 103 | agg_op = terms.pop(0)['word'] 104 | 105 | if agg_op == 'symcol': 106 | agg_op = '' 107 | else: 108 | if 'symcol' != terms.pop(0)['word']: 109 | raise Exception('Missing aggregation column') 110 | try: 111 | agg_op = cls.agg_ops.index(agg_op.upper()) 112 | except Exception as e: 113 | raise Exception('Invalid agg op 
{}'.format(agg_op)) 114 | 115 | def find_column(name): 116 | return headers_no_whitespcae.index(re.sub(re_whitespace, '', name)) 117 | 118 | def flatten(tokens): 119 | ret = {'words': [], 'after': [], 'gloss': []} 120 | for t in tokens: 121 | ret['words'].append(t['word']) 122 | ret['after'].append(t['after']) 123 | ret['gloss'].append(t['gloss']) 124 | return ret 125 | where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere'] 126 | where_index = where_index[0] if where_index else len(terms) 127 | flat = flatten(terms[:where_index]) 128 | try: 129 | agg_col = find_column(detokenize(flat)) 130 | except Exception as e: 131 | raise Exception('Cannot find aggregation column {}'.format(flat['words'])) 132 | where_terms = terms[where_index+1:] 133 | 134 | # get conditions 135 | conditions = [] 136 | while where_terms: 137 | t = where_terms.pop(0) 138 | flat = flatten(where_terms) 139 | if t['word'] != 'symcol': 140 | raise Exception('Missing conditional column {}'.format(flat['words'])) 141 | try: 142 | op_index = flat['words'].index('symop') 143 | col_tokens = flatten(where_terms[:op_index]) 144 | except Exception as e: 145 | raise Exception('Missing conditional operator {}'.format(flat['words'])) 146 | cond_op = where_terms[op_index+1]['word'] 147 | try: 148 | cond_op = cls.cond_ops.index(cond_op.upper()) 149 | except Exception as e: 150 | raise Exception('Invalid cond op {}'.format(cond_op)) 151 | try: 152 | cond_col = find_column(detokenize(col_tokens)) 153 | except Exception as e: 154 | raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) 155 | try: 156 | val_index = flat['words'].index('symcond') 157 | except Exception as e: 158 | raise Exception('Cannot find conditional value {}'.format(flat['words'])) 159 | 160 | where_terms = where_terms[val_index+1:] 161 | flat = flatten(where_terms) 162 | val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms) 163 | cond_val = detokenize(flatten(where_terms[:val_end_index])) 164 | conditions.append([cond_col, cond_op, cond_val]) 165 | where_terms = where_terms[val_end_index+1:] 166 | q = cls(agg_col, agg_op, conditions) 167 | return q 168 | 169 | @classmethod 170 | def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True): 171 | sequence = deepcopy(sequence) 172 | if 'symend' in sequence['words']: 173 | end = sequence['words'].index('symend') 174 | for k, v in sequence.items(): 175 | sequence[k] = v[:end] 176 | terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])] 177 | headers = [detokenize(h) for h in table['header']] 178 | 179 | # lowercase everything and truncate sequence 180 | if lowercase: 181 | headers = [h.lower() for h in headers] 182 | for i, t in enumerate(terms): 183 | for k, v in t.items(): 184 | t[k] = v.lower() 185 | headers_no_whitespcae = [re.sub(re_whitespace, '', h) for h in headers] 186 | 187 | def find_column(name): 188 | return headers_no_whitespcae.index(re.sub(re_whitespace, '', name)) 189 | 190 | def flatten(tokens): 191 | ret = {'words': [], 'after': [], 'gloss': []} 192 | for t in tokens: 193 | ret['words'].append(t['word']) 194 | ret['after'].append(t['after']) 195 | ret['gloss'].append(t['gloss']) 196 | return ret 197 | where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere'] 198 | where_index = where_index[0] if where_index else len(terms) 199 | where_terms = terms[where_index+1:] 200 | 201 | # get conditions 202 | conditions = [] 
203 | while where_terms: 204 | t = where_terms.pop(0) 205 | flat = flatten(where_terms) 206 | if t['word'] != 'symcol': 207 | raise Exception('Missing conditional column {}'.format(flat['words'])) 208 | try: 209 | op_index = flat['words'].index('symop') 210 | col_tokens = flatten(where_terms[:op_index]) 211 | except Exception as e: 212 | raise Exception('Missing conditional operator {}'.format(flat['words'])) 213 | cond_op = where_terms[op_index+1]['word'] 214 | try: 215 | cond_op = cls.cond_ops.index(cond_op.upper()) 216 | except Exception as e: 217 | raise Exception('Invalid cond op {}'.format(cond_op)) 218 | try: 219 | cond_col = find_column(detokenize(col_tokens)) 220 | except Exception as e: 221 | raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) 222 | try: 223 | val_index = flat['words'].index('symcond') 224 | except Exception as e: 225 | raise Exception('Cannot find conditional value {}'.format(flat['words'])) 226 | 227 | where_terms = where_terms[val_index+1:] 228 | flat = flatten(where_terms) 229 | val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms) 230 | cond_val = detokenize(flatten(where_terms[:val_end_index])) 231 | conditions.append([cond_col, cond_op, cond_val]) 232 | where_terms = where_terms[val_end_index+1:] 233 | q = cls(agg_col, agg_op, conditions) 234 | return q 235 | -------------------------------------------------------------------------------- /wikisql_lib/table.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tabulate import tabulate 3 | from wikisql_lib.query import Query 4 | import random 5 | 6 | 7 | class Table: 8 | 9 | schema_re = re.compile(r'\((.+)\)') 10 | 11 | def __init__(self, table_id, header, types, rows, caption=None): 12 | self.table_id = table_id 13 | self.header = header 14 | self.types = types 15 | self.rows = rows 16 | self.caption = caption 17 | 18 | def __repr__(self): 19 | return 'Table: {id}\nCaption: {caption}\n{tabulate}'.format( 20 | id=self.table_id, 21 | caption=self.caption, 22 | tabulate=tabulate(self.rows, headers=self.header) 23 | ) 24 | 25 | @classmethod 26 | def get_schema(cls, db, table_id): 27 | table_infos = db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=cls.get_id(table_id)).all() 28 | if table_infos: 29 | return table_infos[0].sql 30 | else: 31 | return None 32 | 33 | @classmethod 34 | def get_id(cls, table_id): 35 | return 'table_{}'.format(table_id.replace('-', '_')) 36 | 37 | @classmethod 38 | def from_db(cls, db, table_id): 39 | table_info = cls.get_schema(db, table_id) 40 | if table_info: 41 | schema_str = cls.schema_re.findall(table_info)[0] 42 | header, types = [], [] 43 | for tup in schema_str.split(', '): 44 | c, t = tup.split() 45 | header.append(c) 46 | types.append(t) 47 | rows = [[getattr(r, h) for h in header] for r in db.query('SELECT * from {}'.format(cls.get_id(table_id)))] 48 | return cls(table_id, header, types, rows) 49 | else: 50 | return None 51 | 52 | @property 53 | def name(self): 54 | return self.get_id(self.table_id) 55 | 56 | def create_table(self, db, replace_existing=False, lower=True): 57 | exists = self.get_schema(db, self.table_id) 58 | if exists: 59 | if replace_existing: 60 | db.query('DROP TABLE {}'.format(self.name)) 61 | else: 62 | return 63 | type_str = ', '.join(['col{} {}'.format(i, t) for i, t in enumerate(self.types)]) 64 | db.query('CREATE TABLE {name} ({types})'.format(name=self.name, types=type_str)) 65 | for row in
self.rows: 66 | value_str = ', '.join([':val{}'.format(j) for j, c in enumerate(row)]) 67 | value_dict = {'val{}'.format(j): c for j, c in enumerate(row)} 68 | if lower: 69 | value_dict = {k: v.lower() if isinstance(v, str) else v for k, v in value_dict.items()} 70 | db.query('INSERT INTO {name} VALUES ({values})'.format(name=self.name, values=value_str), **value_dict) 71 | 72 | def execute_query(self, db, query, lower=True): 73 | sel_str = 'col{}'.format(query.sel_index) if query.sel_index >= 0 else '*' 74 | agg_str = sel_str 75 | agg_op = Query.agg_ops[query.agg_index] 76 | if agg_op: 77 | agg_str = '{}({})'.format(agg_op, sel_str) 78 | where_str = ' AND '.join(['col{} {} :col{}'.format(i, Query.cond_ops[o], i) for i, o, v in query.conditions]) 79 | where_map = {'col{}'.format(i): v for i, o, v in query.conditions} 80 | if lower: 81 | where_map = {k: v.lower() if isinstance(v, str) else v for k, v in where_map.items()} 82 | if where_map: 83 | where_str = 'WHERE ' + where_str 84 | 85 | if query.sel_index >= 0: 86 | query_str = 'SELECT {agg_str} AS result FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str) 87 | return [r.result for r in db.query(query_str, **where_map)] 88 | else: 89 | query_str = 'SELECT {agg_str} FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str) 90 | return [[getattr(r, 'col{}'.format(i)) for i in range(len(self.header))] for r in db.query(query_str, **where_map)] 91 | 92 | def query_str(self, query): 93 | agg_str = self.header[query.sel_index] 94 | agg_op = Query.agg_ops[query.agg_index] 95 | if agg_op: 96 | agg_str = '{}({})'.format(agg_op, agg_str) 97 | where_str = ' AND '.join(['{} {} {}'.format(self.header[i], Query.cond_ops[o], v) for i, o, v in query.conditions]) 98 | return 'SELECT {} FROM {} WHERE {}'.format(agg_str, self.name, where_str) 99 | 100 | def generate_query(self, db, max_cond=4): 101 | max_cond = min(len(self.header), max_cond) 102 | # sample a select column 103 | sel_index = random.choice(list(range(len(self.header)))) 104 | # sample where conditions 105 | query = Query(-1, Query.agg_ops.index('')) 106 | results = self.execute_query(db, query) 107 | condition_options = list(range(len(self.header))) 108 | condition_options.remove(sel_index) 109 | for i in range(max_cond): 110 | if not results: 111 | break 112 | cond_index = random.choice(condition_options) 113 | if self.types[cond_index] == 'text': 114 | cond_op = Query.cond_ops.index('=') 115 | else: 116 | cond_op = random.choice(list(range(len(Query.cond_ops)))) 117 | cond_val = random.choice([r[cond_index] for r in results]) 118 | query.conditions.append((cond_index, cond_op, cond_val)) 119 | new_results = self.execute_query(db, query) 120 | if [r[sel_index] for r in new_results] != [r[sel_index] for r in results]: 121 | condition_options.remove(cond_index) 122 | results = new_results 123 | else: 124 | query.conditions.pop() 125 | # sample an aggregation operation 126 | if self.types[sel_index] == 'text': 127 | query.agg_index = Query.agg_ops.index('') 128 | else: 129 | query.agg_index = random.choice(list(range(len(Query.agg_ops)))) 130 | query.sel_index = sel_index 131 | results = self.execute_query(db, query) 132 | return query, results 133 | 134 | def generate_queries(self, db, n=1, max_tries=5, lower=True): 135 | qs = [] 136 | for i in range(n): 137 | n_tries = 0 138 | r = None 139 | while r is None and n_tries < max_tries: 140 | q, r = self.generate_query(db, max_cond=4) 141 | n_tries += 1 142 | if r: 143 | qs.append((q, 
r)) 144 | return qs 145 | -------------------------------------------------------------------------------- /wikisql_prediction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import utils 5 | from modeling.model_factory import create_model 6 | from featurizer import HydraFeaturizer, SQLDataset 7 | from wikisql_lib.dbengine import DBEngine 8 | 9 | def print_metric(label_file, pred_file): 10 | sp = [(json.loads(ls)["sql"], json.loads(lp)["query"]) for ls, lp in zip(open(label_file), open(pred_file))] 11 | 12 | sel_acc = sum(p["sel"] == s["sel"] for s, p in sp) / len(sp) 13 | agg_acc = sum(p["agg"] == s["agg"] for s, p in sp) / len(sp) 14 | wcn_acc = sum(len(p["conds"]) == len(s["conds"]) for s, p in sp) / len(sp) 15 | 16 | def wcc_match(a, b): 17 | a = sorted(a, key=lambda k: k[0]) 18 | b = sorted(b, key=lambda k: k[0]) 19 | return [c[0] for c in a] == [c[0] for c in b] 20 | 21 | def wco_match(a, b): 22 | a = sorted(a, key=lambda k: k[0]) 23 | b = sorted(b, key=lambda k: k[0]) 24 | return [c[1] for c in a] == [c[1] for c in b] 25 | 26 | def wcv_match(a, b): 27 | a = sorted(a, key=lambda k: k[0]) 28 | b = sorted(b, key=lambda k: k[0]) 29 | return [str(c[2]).lower() for c in a] == [str(c[2]).lower() for c in b] 30 | 31 | wcc_acc = sum(wcc_match(p["conds"], s["conds"]) for s, p in sp) / len(sp) 32 | wco_acc = sum(wco_match(p["conds"], s["conds"]) for s, p in sp) / len(sp) 33 | wcv_acc = sum(wcv_match(p["conds"], s["conds"]) for s, p in sp) / len(sp) 34 | 35 | print('sel_acc: {}\nagg_acc: {}\nwcn_acc: {}\nwcc_acc: {}\nwco_acc: {}\nwcv_acc: {}\n' \ 36 | .format(sel_acc, agg_acc, wcn_acc, wcc_acc, wco_acc, wcv_acc)) 37 | 38 | 39 | if __name__ == "__main__": 40 | # in_file = "data/wikidev.jsonl" 41 | # out_file = "output/dev_out.jsonl" 42 | # label_file = "WikiSQL/data/dev.jsonl" 43 | # db_file = "WikiSQL/data/dev.db" 44 | # model_out_file = "output/dev_model_out.pkl" 45 | 46 | in_file = "data/wikitest.jsonl" 47 | out_file = "output/test_out.jsonl" 48 | label_file = "WikiSQL/data/test.jsonl" 49 | db_file = "WikiSQL/data/test.db" 50 | model_out_file = "output/test_model_out.pkl" 51 | 52 | # All Best 53 | model_path = "output/20200207_105347" 54 | epoch = 4 55 | 56 | engine = DBEngine(db_file) 57 | config = utils.read_conf(os.path.join(model_path, "model.conf")) 58 | # config["DEBUG"] = 1 59 | featurizer = HydraFeaturizer(config) 60 | pred_data = SQLDataset(in_file, config, featurizer, False) 61 | print("num of samples: {0}".format(len(pred_data.input_features))) 62 | 63 | model = create_model(config, is_train=False) 64 | model.load(model_path, epoch) 65 | 66 | if "DEBUG" in config: 67 | model_out_file = model_out_file + ".partial" 68 | 69 | if os.path.exists(model_out_file): 70 | model_outputs = pickle.load(open(model_out_file, "rb")) 71 | else: 72 | model_outputs = model.dataset_inference(pred_data) 73 | pickle.dump(model_outputs, open(model_out_file, "wb")) 74 | 75 | print("===HydraNet===") 76 | pred_sqls = model.predict_SQL(pred_data, model_outputs=model_outputs) 77 | with open(out_file, "w") as g: 78 | for pred_sql in pred_sqls: 79 | # print(pred_sql) 80 | result = {"query": {}} 81 | result["query"]["agg"] = int(pred_sql[0]) 82 | result["query"]["sel"] = int(pred_sql[1]) 83 | result["query"]["conds"] = [(int(cond[0]), int(cond[1]), str(cond[2])) for cond in pred_sql[2]] 84 | g.write(json.dumps(result) + "\n") 85 | print_metric(label_file, out_file) 86 | 87 | print("===HydraNet+EG===") 88 | pred_sqls = 
model.predict_SQL_with_EG(engine, pred_data, model_outputs=model_outputs) 89 | with open(out_file + ".eg", "w") as g: 90 | for pred_sql in pred_sqls: 91 | # print(pred_sql) 92 | result = {"query": {}} 93 | result["query"]["agg"] = int(pred_sql[0]) 94 | result["query"]["sel"] = int(pred_sql[1]) 95 | result["query"]["conds"] = [(int(cond[0]), int(cond[1]), str(cond[2])) for cond in pred_sql[2]] 96 | g.write(json.dumps(result) + "\n") 97 | print_metric(label_file, out_file + ".eg") 98 | --------------------------------------------------------------------------------
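For reference, the query and database helpers above compose as follows: Query.from_dict turns a WikiSQL-style dict into a Query whose __repr__ is a readable SQL string and whose __eq__ compares select, aggregation and conditions with lower-cased values (order-insensitive unless ordered=True), while DBEngine.execute_dict_query runs the same dict against the SQLite tables. A minimal usage sketch, assuming the WikiSQL dev database has been extracted to WikiSQL/data/dev.db; the path is an assumption and the table id is borrowed from the __main__ block of dbengine.py:

from wikisql_lib.query import Query
from wikisql_lib.dbengine import DBEngine

# Illustrative path; point this at wherever dev.db actually lives.
engine = DBEngine("WikiSQL/data/dev.db")

gold = {"agg": 0, "sel": 3, "conds": [[5, 0, "butler cc (ks)"]]}
pred = {"agg": 0, "sel": 3, "conds": [[5, 0, "Butler CC (KS)"]]}

# Logical-form match: condition values are lower-cased and, by default,
# condition order is ignored.
print(Query.from_dict(gold) == Query.from_dict(pred))  # True

# Execution match: run both queries and compare the returned result lists.
gold_res, gold_q = engine.execute_dict_query("1-10015132-11", gold)
pred_res, pred_q = engine.execute_dict_query("1-10015132-11", pred)
print(repr(gold_q))          # readable SQL form of the gold query
print(gold_res == pred_res)  # True when the executed results agree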