├── models ├── __init__.py ├── net_utils.py ├── andor_predictor.py ├── having_predictor.py ├── root_teminal_predictor.py ├── desasc_limit_predictor.py ├── multisql_predictor.py ├── agg_predictor.py ├── keyword_predictor.py ├── op_predictor.py └── col_predictor.py ├── requirements.txt ├── merge_jsons.py ├── train_all.sh ├── test_gen.sh ├── test.py ├── README.md ├── train.py ├── word_embedding.py ├── get_data_wikisql.py ├── utils.py ├── generate_wikisql_augment.py ├── process_sql.py ├── preprocess_train_dev_data.py └── supermodel.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Babel==2.5.0 2 | nltk==3.2.1 3 | numpy==1.14.0 4 | pandas==0.18.1 5 | records==0.5.0 6 | torch==0.2.0.post2 7 | et-xmlfile==1.0.1 8 | jdcal==1.3 9 | odfpy==1.3.5 10 | olefile==0.44 11 | openpyxl==2.4.9 12 | pkg-resources==0.0.0 13 | pytz==2017.3 14 | records==0.5.2 15 | SQLAlchemy==1.1.14 16 | tablib==0.12.1 17 | unicodecsv==0.14.1 18 | xlrd==1.1.0 19 | xlwt==1.3.0 20 | scikit-learn==0.18.1 21 | -------------------------------------------------------------------------------- /merge_jsons.py: -------------------------------------------------------------------------------- 1 | import re 2 | import io 3 | import json 4 | import numpy as np 5 | 6 | 7 | 8 | def merge_files(f1, f2, output_file): 9 | with open(f1) as inf1: 10 | data_1 = json.load(inf1) 11 | with open(f2) as inf2: 12 | data_2 = json.load(inf2) 13 | 14 | data = data_1 + data_2 15 | 16 | with open(output_file, 'wt') as out: 17 | json.dump(data, out, sort_keys=True, indent=4, separators=(',', ': ')) 18 | 19 | 20 | 21 | merge_files("/data/projects/nl2sql/datasets/data_add_wikisql/wikisql_tables.json", "/data/projects/nl2sql/datasets/data/tables.json", "/data/projects/nl2sql/datasets/data_add_wikisql/all_tables.json") 22 | -------------------------------------------------------------------------------- /train_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ## full + aug 4 | # hs=full 5 | # tbl=std 6 | # d_type="_augment" 7 | 8 | # ## - aug 9 | hs=full 10 | tbl=std 11 | d_type="" 12 | 13 | ## - aug - table 14 | # hs=full 15 | # tbl=no 16 | # d_type="" 17 | 18 | # ## - aug - table -history 19 | # hs=no 20 | # tbl=no 21 | # d_type="" 22 | 23 | 24 | # toy="--toy" 25 | toy="" 26 | # epoch=1 # 600 for spider, 200 for +aug 27 | 28 | DATE=`date '+%Y-%m-%d-%H:%M:%S'` 29 | 30 | data_root=generated_datasets/generated_data${d_type} 31 | save_dir="${data_root}/saved_models_hs=${hs}_tbl=${tbl}_${DATE}" 32 | log_dir=${save_dir}/train_log 33 | mkdir -p ${save_dir} 34 | mkdir -p ${log_dir} 35 | 36 | 37 | export CUDA_VISIBLE_DEVICES=2 38 | module=col 39 | epoch=600 40 | python train.py \ 41 | --data_root ${data_root} \ 42 | --save_dir ${save_dir} \ 43 | --history_type ${hs} \ 44 | --table_type ${tbl} \ 45 | --train_component ${module} \ 46 | --epoch ${epoch} \ 47 | ${toy} \ 48 | > "${log_dir}/train_${d_type}_hs=${hs}_tbl=${tbl}_${module}_${DATE}.txt" \ 49 | 2>&1 & 50 | 51 | export CUDA_VISIBLE_DEVICES=3 52 | epoch=300 53 | for module in multi_sql keyword op agg root_tem des_asc having andor 54 | do 55 | python train.py \ 56 | --data_root ${data_root} \ 57 | --save_dir ${save_dir} \ 58 | --history_type ${hs} \ 59 | --table_type ${tbl} \ 60 | --train_component ${module} \ 61 | 
--epoch ${epoch} \ 62 | ${toy} \ 63 | > "${log_dir}/train_${d_type}_hs=${hs}_tbl=${tbl}_${module}_${DATE}.txt" \ 64 | 2>&1 & 65 | done 66 | -------------------------------------------------------------------------------- /test_gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=2,3 4 | TEST_DATA=data/dev.json 5 | 6 | # # full + aug 7 | SAVE_PATH=generated_datasets/generated_data_augment/saved_models_hs=full_tbl=std 8 | python test.py \ 9 | --test_data_path ${TEST_DATA} \ 10 | --models ${SAVE_PATH} \ 11 | --output_path ${SAVE_PATH}/dev_result.txt \ 12 | --history_type full \ 13 | --table_type std \ 14 | > ${SAVE_PATH}/dev_result.out.txt 2>&1 & 15 | 16 | 17 | # - aug 18 | SAVE_PATH=generated_datasets/generated_data/saved_models_hs=full_tbl=std 19 | python test.py \ 20 | --test_data_path ${TEST_DATA} \ 21 | --models ${SAVE_PATH} \ 22 | --output_path ${SAVE_PATH}/dev_result.txt \ 23 | --history_type full \ 24 | --table_type std \ 25 | > ${SAVE_PATH}/dev_result.out.txt 2>&1 & 26 | 27 | 28 | # - aug - table 29 | SAVE_PATH=generated_datasets/generated_data/saved_models_hs=full_tbl=no 30 | python test.py \ 31 | --test_data_path ${TEST_DATA} \ 32 | --models ${SAVE_PATH} \ 33 | --output_path ${SAVE_PATH}/dev_result.txt \ 34 | --history_type full \ 35 | --table_type no \ 36 | > ${SAVE_PATH}/dev_result.out.txt 2>&1 & 37 | 38 | 39 | # - aug - table - history 40 | SAVE_PATH=generated_datasets/generated_data/saved_models_hs=no_tbl=no 41 | python test.py \ 42 | --test_data_path ${TEST_DATA} \ 43 | --models ${SAVE_PATH} \ 44 | --output_path ${SAVE_PATH}/dev_result.txt \ 45 | --history_type no \ 46 | --table_type no \ 47 | > ${SAVE_PATH}/dev_result.out.txt 2>&1 & 48 | -------------------------------------------------------------------------------- /models/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.autograd import Variable 5 | 6 | def run_lstm(lstm, inp, inp_len, hidden=None): 7 | # Run the LSTM using packed sequence. 8 | # This requires to first sort the input according to its length. 9 | sort_perm = np.array(sorted(range(len(inp_len)), 10 | key=lambda k:inp_len[k], reverse=True)) 11 | sort_inp_len = inp_len[sort_perm] 12 | sort_perm_inv = np.argsort(sort_perm) 13 | if inp.is_cuda: 14 | sort_perm = torch.LongTensor(sort_perm).cuda() 15 | sort_perm_inv = torch.LongTensor(sort_perm_inv).cuda() 16 | 17 | lstm_inp = nn.utils.rnn.pack_padded_sequence(inp[sort_perm], 18 | sort_inp_len, batch_first=True) 19 | if hidden is None: 20 | lstm_hidden = None 21 | else: 22 | lstm_hidden = (hidden[0][:, sort_perm], hidden[1][:, sort_perm]) 23 | 24 | sort_ret_s, sort_ret_h = lstm(lstm_inp, lstm_hidden) 25 | ret_s = nn.utils.rnn.pad_packed_sequence( 26 | sort_ret_s, batch_first=True)[0][sort_perm_inv] 27 | ret_h = (sort_ret_h[0][:, sort_perm_inv], sort_ret_h[1][:, sort_perm_inv]) 28 | return ret_s, ret_h 29 | 30 | 31 | def col_name_encode(name_inp_var, name_len, col_len, enc_lstm): 32 | #Encode the columns. 33 | #The embedding of a column name is the last state of its LSTM output. 
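    # name_inp_var: (num_of_all_cols_in_batch, max_name_len, N_word) -- every column
    #   name in the batch is packed into one flat sequence batch, and name_len gives
    #   the token count of each name.
    # col_len[i] is the number of columns of example i; the per-name encodings are
    #   scattered back into a (B, max(col_len), hid_dim) tensor below.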
34 | name_hidden, _ = run_lstm(enc_lstm, name_inp_var, name_len) 35 | name_out = name_hidden[tuple(range(len(name_len))), name_len-1] 36 | ret = torch.FloatTensor( 37 | len(col_len), max(col_len), name_out.size()[1]).zero_() 38 | if name_out.is_cuda: 39 | ret = ret.cuda() 40 | 41 | st = 0 42 | for idx, cur_len in enumerate(col_len): 43 | ret[idx, :cur_len] = name_out.data[st:st+cur_len] 44 | st += cur_len 45 | ret_var = Variable(ret) 46 | 47 | return ret_var, col_len 48 | 49 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import datetime 4 | import argparse 5 | import numpy as np 6 | from utils import * 7 | from supermodel import SuperModel 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--train_emb', action='store_true', 12 | help='Train word embedding.') 13 | parser.add_argument('--toy', action='store_true', 14 | help='If set, use small data; used for fast debugging.') 15 | parser.add_argument('--models', type=str, help='path to saved model') 16 | parser.add_argument('--test_data_path',type=str) 17 | parser.add_argument('--output_path', type=str) 18 | parser.add_argument('--history_type', type=str, default='full', choices=['full','part','no'], help='full, part, or no history') 19 | parser.add_argument('--table_type', type=str, default='std', choices=['std','hier','no'], help='standard, hierarchical, or no table info') 20 | args = parser.parse_args() 21 | use_hs = True 22 | if args.history_type == "no": 23 | args.history_type = "full" 24 | use_hs = False 25 | 26 | N_word=300 27 | B_word=42 28 | N_h = 300 29 | N_depth=2 30 | # if args.part: 31 | # part = True 32 | # else: 33 | # part = False 34 | if args.toy: 35 | USE_SMALL=True 36 | GPU=True 37 | BATCH_SIZE=2 #20 38 | else: 39 | USE_SMALL=False 40 | GPU=True 41 | BATCH_SIZE=2 #64 42 | # TRAIN_ENTRY=(False, True, False) # (AGG, SEL, COND) 43 | # TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY 44 | learning_rate = 1e-4 45 | 46 | #TODO 47 | data = json.load(open(args.test_data_path)) 48 | # dev_data = load_train_dev_dataset(args.train_component, "dev", args.history) 49 | 50 | word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \ 51 | load_used=args.train_emb, use_small=USE_SMALL) 52 | # dev_data = load_train_dev_dataset(args.train_component, "dev", args.history) 53 | #word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt") 54 | 55 | model = SuperModel(word_emb, N_word=N_word, gpu=GPU, trainable_emb = args.train_emb, table_type=args.table_type, use_hs=use_hs) 56 | 57 | # agg_m, sel_m, cond_m = best_model_name(args) 58 | # torch.save(model.state_dict(), "saved_models/{}_models.dump".format(args.train_component)) 59 | 60 | print "Loading from modules..." 
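    # Restore every sub-module from the checkpoints written by train.py
    # (one "<component>_models.dump" file per module under the --models directory).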
61 | model.multi_sql.load_state_dict(torch.load("{}/multi_sql_models.dump".format(args.models))) 62 | model.key_word.load_state_dict(torch.load("{}/keyword_models.dump".format(args.models))) 63 | model.col.load_state_dict(torch.load("{}/col_models.dump".format(args.models))) 64 | model.op.load_state_dict(torch.load("{}/op_models.dump".format(args.models))) 65 | model.agg.load_state_dict(torch.load("{}/agg_models.dump".format(args.models))) 66 | model.root_teminal.load_state_dict(torch.load("{}/root_tem_models.dump".format(args.models))) 67 | model.des_asc.load_state_dict(torch.load("{}/des_asc_models.dump".format(args.models))) 68 | model.having.load_state_dict(torch.load("{}/having_models.dump".format(args.models))) 69 | 70 | test_acc(model, BATCH_SIZE, data, args.output_path) 71 | #test_exec_acc() 72 | -------------------------------------------------------------------------------- /models/andor_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class AndOrPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(AndOrPredictor, self).__init__() 13 | self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.q_att = nn.Linear(N_h, N_h) 26 | self.hs_att = nn.Linear(N_h, N_h) 27 | self.ao_out_q = nn.Linear(N_h, N_h) 28 | self.ao_out_hs = nn.Linear(N_h, N_h) 29 | self.ao_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for and/or 30 | 31 | self.softmax = nn.Softmax() #dim=1 32 | self.CE = nn.CrossEntropyLoss() 33 | self.log_softmax = nn.LogSoftmax() 34 | self.mlsml = nn.MultiLabelSoftMarginLoss() 35 | self.bce_logit = nn.BCEWithLogitsLoss() 36 | self.sigm = nn.Sigmoid() 37 | if gpu: 38 | self.cuda() 39 | 40 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len): 41 | max_q_len = max(q_len) 42 | max_hs_len = max(hs_len) 43 | B = len(q_len) 44 | 45 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 46 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 47 | 48 | att_np_q = np.ones((B, max_q_len)) 49 | att_val_q = torch.from_numpy(att_np_q).float() 50 | att_val_q = Variable(att_val_q.cuda()) 51 | for idx, num in enumerate(q_len): 52 | if num < max_q_len: 53 | att_val_q[idx, num:] = -100 54 | att_prob_q = self.softmax(att_val_q) 55 | q_weighted = (q_enc * att_prob_q.unsqueeze(2)).sum(1) 56 | 57 | # Same as the above, compute SQL history embedding weighted by column attentions 58 | att_np_h = np.ones((B, max_hs_len)) 59 | att_val_h = torch.from_numpy(att_np_h).float() 60 | att_val_h = Variable(att_val_h.cuda()) 61 | for idx, num in enumerate(hs_len): 62 | if num < max_hs_len: 63 | att_val_h[idx, num:] = -100 64 | att_prob_h = self.softmax(att_val_h) 65 | hs_weighted = (hs_enc * att_prob_h.unsqueeze(2)).sum(1) 66 | # ao_score: (B, 2) 67 | ao_score = self.ao_out(self.ao_out_q(q_weighted) + int(self.use_hs)* self.ao_out_hs(hs_weighted)) 68 | 69 | return ao_score 70 | 71 | 72 | def loss(self, score, truth): 73 | loss = 0 74 | data = torch.from_numpy(np.array(truth)) 75 | 
truth_var = Variable(data.cuda()) 76 | loss = self.CE(score, truth_var) 77 | 78 | return loss 79 | 80 | 81 | def check_acc(self, score, truth): 82 | err = 0 83 | B = len(score) 84 | pred = [] 85 | for b in range(B): 86 | pred.append(np.argmax(score[b].data.cpu().numpy())) 87 | for b, (p, t) in enumerate(zip(pred, truth)): 88 | if p != t: 89 | err += 1 90 | 91 | return err 92 | -------------------------------------------------------------------------------- /models/having_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class HavingPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(HavingPredictor, self).__init__() 13 | self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 26 | num_layers=N_depth, batch_first=True, 27 | dropout=0.3, bidirectional=True) 28 | 29 | self.q_att = nn.Linear(N_h, N_h) 30 | self.hs_att = nn.Linear(N_h, N_h) 31 | self.hv_out_q = nn.Linear(N_h, N_h) 32 | self.hv_out_hs = nn.Linear(N_h, N_h) 33 | self.hv_out_c = nn.Linear(N_h, N_h) 34 | self.hv_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for having/none 35 | 36 | self.softmax = nn.Softmax() #dim=1 37 | self.CE = nn.CrossEntropyLoss() 38 | self.log_softmax = nn.LogSoftmax() 39 | self.mlsml = nn.MultiLabelSoftMarginLoss() 40 | self.bce_logit = nn.BCEWithLogitsLoss() 41 | self.sigm = nn.Sigmoid() 42 | if gpu: 43 | self.cuda() 44 | 45 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): 46 | max_q_len = max(q_len) 47 | max_hs_len = max(hs_len) 48 | max_col_len = max(col_len) 49 | B = len(q_len) 50 | 51 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 52 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 53 | col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) 54 | 55 | # get target/predicted column's embedding 56 | # col_emb: (B, hid_dim) 57 | col_emb = [] 58 | for b in range(B): 59 | col_emb.append(col_enc[b, gt_col[b]]) 60 | col_emb = torch.stack(col_emb) 61 | att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B,-1) 62 | for idx, num in enumerate(q_len): 63 | if num < max_q_len: 64 | att_val_qc[idx, num:] = -100 65 | att_prob_qc = self.softmax(att_val_qc) 66 | q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) 67 | 68 | # Same as the above, compute SQL history embedding weighted by column attentions 69 | att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B,-1) 70 | for idx, num in enumerate(hs_len): 71 | if num < max_hs_len: 72 | att_val_hc[idx, num:] = -100 73 | att_prob_hc = self.softmax(att_val_hc) 74 | hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) 75 | # hv_score: (B, 2) 76 | hv_score = self.hv_out(self.hv_out_q(q_weighted) + int(self.use_hs)* self.hv_out_hs(hs_weighted) + self.hv_out_c(col_emb)) 77 | 78 | return hv_score 79 | 80 | 81 | def loss(self, score, 
truth): 82 | loss = 0 83 | data = torch.from_numpy(np.array(truth)) 84 | truth_var = Variable(data.cuda()) 85 | loss = self.CE(score, truth_var) 86 | 87 | return loss 88 | 89 | 90 | def check_acc(self, score, truth): 91 | err = 0 92 | B = len(score) 93 | pred = [] 94 | for b in range(B): 95 | pred.append(np.argmax(score[b].data.cpu().numpy())) 96 | for b, (p, t) in enumerate(zip(pred, truth)): 97 | if p != t: 98 | err += 1 99 | 100 | return err 101 | -------------------------------------------------------------------------------- /models/root_teminal_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class RootTeminalPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(RootTeminalPredictor, self).__init__() 13 | self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 26 | num_layers=N_depth, batch_first=True, 27 | dropout=0.3, bidirectional=True) 28 | 29 | self.q_att = nn.Linear(N_h, N_h) 30 | self.hs_att = nn.Linear(N_h, N_h) 31 | self.rt_out_q = nn.Linear(N_h, N_h) 32 | self.rt_out_hs = nn.Linear(N_h, N_h) 33 | self.rt_out_c = nn.Linear(N_h, N_h) 34 | self.rt_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for 2 operators 35 | 36 | self.softmax = nn.Softmax() #dim=1 37 | self.CE = nn.CrossEntropyLoss() 38 | self.log_softmax = nn.LogSoftmax() 39 | self.mlsml = nn.MultiLabelSoftMarginLoss() 40 | self.bce_logit = nn.BCEWithLogitsLoss() 41 | self.sigm = nn.Sigmoid() 42 | if gpu: 43 | self.cuda() 44 | 45 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): 46 | max_q_len = max(q_len) 47 | max_hs_len = max(hs_len) 48 | max_col_len = max(col_len) 49 | B = len(q_len) 50 | 51 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 52 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 53 | col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) 54 | 55 | # get target/predicted column's embedding 56 | # col_emb: (B, hid_dim) 57 | col_emb = [] 58 | for b in range(B): 59 | col_emb.append(col_enc[b, gt_col[b]]) 60 | col_emb = torch.stack(col_emb) 61 | att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B,-1) 62 | for idx, num in enumerate(q_len): 63 | if num < max_q_len: 64 | att_val_qc[idx, num:] = -100 65 | att_prob_qc = self.softmax(att_val_qc) 66 | q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) 67 | 68 | # Same as the above, compute SQL history embedding weighted by column attentions 69 | att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B,-1) 70 | for idx, num in enumerate(hs_len): 71 | if num < max_hs_len: 72 | att_val_hc[idx, num:] = -100 73 | att_prob_hc = self.softmax(att_val_hc) 74 | hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) 75 | # rt_score: (B, 2) 76 | rt_score = self.rt_out(self.rt_out_q(q_weighted) + int(self.use_hs)* self.rt_out_hs(hs_weighted) 
+ self.rt_out_c(col_emb)) 77 | 78 | return rt_score 79 | 80 | 81 | def loss(self, score, truth): 82 | loss = 0 83 | data = torch.from_numpy(np.array(truth)) 84 | truth_var = Variable(data.cuda()) 85 | loss = self.CE(score, truth_var) 86 | 87 | return loss 88 | 89 | 90 | def check_acc(self, score, truth): 91 | err = 0 92 | B = len(score) 93 | pred = [] 94 | for b in range(B): 95 | pred.append(np.argmax(score[b].data.cpu().numpy())) 96 | for b, (p, t) in enumerate(zip(pred, truth)): 97 | if p != t: 98 | err += 1 99 | 100 | return err 101 | -------------------------------------------------------------------------------- /models/desasc_limit_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class DesAscLimitPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(DesAscLimitPredictor, self).__init__() 13 | self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 26 | num_layers=N_depth, batch_first=True, 27 | dropout=0.3, bidirectional=True) 28 | 29 | 30 | self.q_att = nn.Linear(N_h, N_h) 31 | self.hs_att = nn.Linear(N_h, N_h) 32 | self.dat_out_q = nn.Linear(N_h, N_h) 33 | self.dat_out_hs = nn.Linear(N_h, N_h) 34 | self.dat_out_c = nn.Linear(N_h, N_h) 35 | self.dat_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 4)) #for 4 desc/asc limit/none combinations 36 | 37 | self.softmax = nn.Softmax() #dim=1 38 | self.CE = nn.CrossEntropyLoss() 39 | self.log_softmax = nn.LogSoftmax() 40 | self.mlsml = nn.MultiLabelSoftMarginLoss() 41 | self.bce_logit = nn.BCEWithLogitsLoss() 42 | self.sigm = nn.Sigmoid() 43 | if gpu: 44 | self.cuda() 45 | 46 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): 47 | max_q_len = max(q_len) 48 | max_hs_len = max(hs_len) 49 | max_col_len = max(col_len) 50 | B = len(q_len) 51 | 52 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 53 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 54 | col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) 55 | 56 | # get target/predicted column's embedding 57 | # col_emb: (B, hid_dim) 58 | col_emb = [] 59 | for b in range(B): 60 | col_emb.append(col_enc[b, gt_col[b]]) 61 | col_emb = torch.stack(col_emb) # [B, dim] 62 | # self.q_att(q_enc).transpose(1, 2): [B, dim, max_q_len] 63 | att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1) 64 | for idx, num in enumerate(q_len): 65 | if num < max_q_len: 66 | att_val_qc[idx, num:] = -100 67 | att_prob_qc = self.softmax(att_val_qc) 68 | q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) 69 | 70 | # Same as the above, compute SQL history embedding weighted by column attentions 71 | att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1) 72 | for idx, num in enumerate(hs_len): 73 | if num < max_hs_len: 74 | att_val_hc[idx, num:] = -100 75 | att_prob_hc = 
self.softmax(att_val_hc) 76 | hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) 77 | # dat_score: (B, 4) 78 | dat_score = self.dat_out(self.dat_out_q(q_weighted) + int(self.use_hs)* self.dat_out_hs(hs_weighted) + self.dat_out_c(col_emb)) 79 | 80 | return dat_score 81 | 82 | 83 | def loss(self, score, truth): 84 | loss = 0 85 | data = torch.from_numpy(np.array(truth)) 86 | truth_var = Variable(data.cuda()) 87 | loss = self.CE(score, truth_var) 88 | 89 | return loss 90 | 91 | 92 | def check_acc(self, score, truth): 93 | err = 0 94 | B = len(score) 95 | pred = [] 96 | for b in range(B): 97 | pred.append(np.argmax(score[b].data.cpu().numpy())) 98 | for b, (p, t) in enumerate(zip(pred, truth)): 99 | if p != t: 100 | err += 1 101 | 102 | return err 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## SyntaxSQLNet: Syntax Tree Networks for Complex and Cross-Domain Text-to-SQL Task 2 | 3 | Source code of our EMNLP 2018 paper: [SyntaxSQLNet: Syntax Tree Networks for Complex and Cross-DomainText-to-SQL Task 4 | ](https://arxiv.org/abs/1810.05237). 5 | 6 | :+1: `03/20/2022`: **We open-sourced a simple but SOTA model (just T5) for 20 tasks including text-to-SQL! Please check out our code in the [UnifiedSKG repo](https://github.com/hkunlp/unifiedskg)!!** 7 | 8 | ### Citation 9 | 10 | ``` 11 | @InProceedings{Yu&al.18.emnlp.syntax, 12 | author = {Tao Yu and Michihiro Yasunaga and Kai Yang and Rui Zhang and Dongxu Wang and Zifan Li and Dragomir Radev}, 13 | title = {SyntaxSQLNet: Syntax Tree Networks for Complex and Cross-Domain Text-to-SQL Task}, 14 | year = {2018}, 15 | booktitle = {Proceedings of EMNLP}, 16 | publisher = {Association for Computational Linguistics}, 17 | } 18 | ``` 19 | 20 | #### Environment Setup 21 | 22 | 1. The code uses Python 2.7 and [Pytorch 0.2.0](https://pytorch.org/previous-versions/) GPU. 23 | 2. Install Python dependency: `pip install -r requirements.txt` 24 | 25 | #### Download Data, Embeddings, Scripts, and Pretrained Models 26 | 1. Download the dataset from [the Spider task website](https://yale-lily.github.io/spider) to be updated, and put `tables.json`, `train.json`, and `dev.json` under `data/` directory. 27 | 2. Download the pretrained [Glove](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip), and put it as `glove/glove.%dB.%dd.txt` 28 | 3. Download `evaluation.py` and `process_sql.py` from [the Spider github page](https://github.com/taoyds/spider) 29 | 4. Download preprocessed train/dev datasets and pretrained models from [here](https://drive.google.com/file/d/1FHEcceYuf__PLhtD5QzJvexM7SNGnoBu/view?usp=sharing). It contains: 30 | -`generated_datasets/` 31 | - ``generated_data`` for original Spider training datasets, pretrained models can be found at `generated_data/saved_models` 32 | - ``generated_data_augment`` for original Spider + augmented training datasets, pretrained models can be found at `generated_data_augment/saved_models` 33 | 34 | #### Generating Train/dev Data for Modules 35 | You could find preprocessed train/dev data in ``generated_datasets/``. 
36 | 37 | To generate them by yourself, update dirs under `TODO` in `preprocess_train_dev_data.py`, and run the following command to generate training files for each module: 38 | ``` 39 | python preprocess_train_dev_data.py train|dev 40 | ``` 41 | 42 | #### Folder/File Description 43 | - ``data/`` contains raw train/dev/test data and table file 44 | - ``generated_datasets/`` described as above 45 | - ``models/`` contains the code for each module. 46 | - ``evaluation.py`` is for evaluation. It uses ``process_sql.py``. 47 | - ``train.py`` is the main file for training. Use ``train_all.sh`` to train all the modules (see below). 48 | - ``test.py`` is the main file for testing. It uses ``supermodel.sh`` to call the trained modules and generate SQL queries. In practice, and use ``test_gen.sh`` to generate SQL queries. 49 | - `generate_wikisql_augment.py` for cross-domain data augmentation 50 | 51 | 52 | #### Training 53 | Run ``train_all.sh`` to train all the modules. 54 | It looks like: 55 | ``` 56 | python train.py \ 57 | --data_root path/to/generated_data \ 58 | --save_dir path/to/save/trained/module \ 59 | --history_type full|no \ 60 | --table_type std|no \ 61 | --train_component \ 62 | --epoch 63 | ``` 64 | 65 | #### Testing 66 | Run ``test_gen.sh`` to generate SQL queries. 67 | ``test_gen.sh`` looks like: 68 | ``` 69 | SAVE_PATH=generated_datasets/generated_data/saved_models_hs=full_tbl=std 70 | python test.py \ 71 | --test_data_path path/to/raw/test/data \ 72 | --models path/to/trained/module \ 73 | --output_path path/to/print/generated/SQL \ 74 | --history_type full|no \ 75 | --table_type std|no \ 76 | ``` 77 | 78 | #### Evaluation 79 | Follow the general evaluation process in [the Spider github page](https://github.com/taoyds/spider). 80 | 81 | #### Cross-Domain Data Augmentation 82 | You could find preprocessed augmented data at `generated_datasets/generated_data_augment`. 83 | 84 | If you would like to run data augmentation by yourself, first download `wikisql_tables.json` and `train_patterns.json` from [here](https://drive.google.com/file/d/13I_EqnAR4v2aE-CWhJ0XQ8c-UlGS9oic/view?usp=sharing), and then run ```python generate_wikisql_augment.py``` to generate more training data. Second, run `get_data_wikisql.py` to generate WikiSQL augment json file. Finally, use `merge_jsons.py` to generate the final spider + wikisql + wikisql augment dataset. 85 | 86 | #### Acknowledgement 87 | 88 | The implementation is based on [SQLNet](https://github.com/xiaojunxu/SQLNet). Please cite it too if you use this code. 
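As a quick reference, the three steps from the Cross-Domain Data Augmentation section above chain together roughly as follows (a sketch only; note that `merge_jsons.py` hard-codes its input/output paths, so update them in the script before running):
```
python generate_wikisql_augment.py   # needs wikisql_tables.json and train_patterns.json
python get_data_wikisql.py           # builds the WikiSQL augment json file
python merge_jsons.py                # merges Spider + WikiSQL augment data/tables
```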
89 | -------------------------------------------------------------------------------- /models/multisql_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class MultiSqlPredictor(nn.Module): 11 | '''Predict if the next token is (multi SQL key words): 12 | NONE, EXCEPT, INTERSECT, or UNION.''' 13 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 14 | super(MultiSqlPredictor, self).__init__() 15 | self.N_h = N_h 16 | self.gpu = gpu 17 | self.use_hs = use_hs 18 | 19 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 20 | num_layers=N_depth, batch_first=True, 21 | dropout=0.3, bidirectional=True) 22 | 23 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 24 | num_layers=N_depth, batch_first=True, 25 | dropout=0.3, bidirectional=True) 26 | 27 | self.mkw_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 28 | num_layers=N_depth, batch_first=True, 29 | dropout=0.3, bidirectional=True) 30 | 31 | self.q_att = nn.Linear(N_h, N_h) 32 | self.hs_att = nn.Linear(N_h, N_h) 33 | self.multi_out_q = nn.Linear(N_h, N_h) 34 | self.multi_out_hs = nn.Linear(N_h, N_h) 35 | self.multi_out_c = nn.Linear(N_h, N_h) 36 | self.multi_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) 37 | 38 | self.softmax = nn.Softmax() #dim=1 39 | self.CE = nn.CrossEntropyLoss() 40 | self.log_softmax = nn.LogSoftmax() 41 | self.mlsml = nn.MultiLabelSoftMarginLoss() 42 | self.bce_logit = nn.BCEWithLogitsLoss() 43 | self.sigm = nn.Sigmoid() 44 | 45 | if gpu: 46 | self.cuda() 47 | 48 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len): 49 | # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size())) 50 | max_q_len = max(q_len) 51 | max_hs_len = max(hs_len) 52 | B = len(q_len) 53 | 54 | # q_enc: (B, max_q_len, hid_dim) 55 | # hs_enc: (B, max_hs_len, hid_dim) 56 | # mkw: (B, 4, hid_dim) 57 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 58 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 59 | mkw_enc, _ = run_lstm(self.mkw_lstm, mkw_emb_var, mkw_len) 60 | 61 | # Compute attention values between multi SQL key words and question tokens. 
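        # mkw_enc rows correspond to the four candidate keywords
        # (NONE / EXCEPT / INTERSECT / UNION); each keyword attends over the question tokens.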
62 | # qmkw_att(q_enc).transpose(1, 2): (B, hid_dim, max_q_len) 63 | # att_val_qmkw: (B, 4, max_q_len) 64 | # print("mkw_enc {} q_enc {}".format(mkw_enc.size(), self.q_att(q_enc).transpose(1, 2).size())) 65 | att_val_qmkw = torch.bmm(mkw_enc, self.q_att(q_enc).transpose(1, 2)) 66 | # assign appended positions values -100 67 | for idx, num in enumerate(q_len): 68 | if num < max_q_len: 69 | att_val_qmkw[idx, :, num:] = -100 70 | # att_prob_qmkw: (B, 4, max_q_len) 71 | att_prob_qmkw = self.softmax(att_val_qmkw.view((-1, max_q_len))).view(B, -1, max_q_len) 72 | # q_enc.unsqueeze(1): (B, 1, max_q_len, hid_dim) 73 | # att_prob_qmkw.unsqueeze(3): (B, 4, max_q_len, 1) 74 | # q_weighted: (B, 4, hid_dim) 75 | q_weighted = (q_enc.unsqueeze(1) * att_prob_qmkw.unsqueeze(3)).sum(2) 76 | 77 | # Same as the above, compute SQL history embedding weighted by key words attentions 78 | att_val_hsmkw = torch.bmm(mkw_enc, self.hs_att(hs_enc).transpose(1, 2)) 79 | for idx, num in enumerate(hs_len): 80 | if num < max_hs_len: 81 | att_val_hsmkw[idx, :, num:] = -100 82 | att_prob_hsmkw = self.softmax(att_val_hsmkw.view((-1, max_hs_len))).view(B, -1, max_hs_len) 83 | hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hsmkw.unsqueeze(3)).sum(2) 84 | 85 | # Compute prediction scores 86 | # self.multi_out.squeeze(): (B, 4, 1) -> (B, 4) 87 | mulit_score = self.multi_out(self.multi_out_q(q_weighted) + int(self.use_hs)* self.multi_out_hs(hs_weighted) + self.multi_out_c(mkw_enc)).view(B,-1) 88 | 89 | return mulit_score 90 | 91 | 92 | def loss(self, score, truth): 93 | data = torch.from_numpy(np.array(truth)) 94 | truth_var = Variable(data.cuda()) 95 | loss = self.CE(score, truth_var) 96 | 97 | return loss 98 | 99 | 100 | def check_acc(self, score, truth): 101 | err = 0 102 | B = len(score) 103 | pred = [] 104 | for b in range(B): 105 | pred.append(np.argmax(score[b].data.cpu().numpy())) 106 | for b, (p, t) in enumerate(zip(pred, truth)): 107 | if p != t: 108 | err += 1 109 | 110 | return err 111 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import datetime 4 | import argparse 5 | import numpy as np 6 | from utils import * 7 | from word_embedding import WordEmbedding 8 | from models.agg_predictor import AggPredictor 9 | from models.col_predictor import ColPredictor 10 | from models.desasc_limit_predictor import DesAscLimitPredictor 11 | from models.having_predictor import HavingPredictor 12 | from models.keyword_predictor import KeyWordPredictor 13 | from models.multisql_predictor import MultiSqlPredictor 14 | from models.op_predictor import OpPredictor 15 | from models.root_teminal_predictor import RootTeminalPredictor 16 | from models.andor_predictor import AndOrPredictor 17 | 18 | TRAIN_COMPONENTS = ('multi_sql','keyword','col','op','agg','root_tem','des_asc','having','andor') 19 | SQL_TOK = ['', '', 'WHERE', 'AND', 'EQL', 'GT', 'LT', ''] 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--toy', action='store_true', 23 | help='If set, use small data; used for fast debugging.') 24 | parser.add_argument('--save_dir', type=str, default='', 25 | help='set model save directory.') 26 | parser.add_argument('--data_root', type=str, default='', 27 | help='root path for generated_data') 28 | parser.add_argument('--train_emb', action='store_true', 29 | help='Train word embedding.') 30 | 
parser.add_argument('--train_component',type=str,default='', 31 | help='set train components,available:[multi_sql,keyword,col,op,agg,root_tem,des_asc,having,andor]') 32 | parser.add_argument('--epoch',type=int,default=500, 33 | help='number of epoch for training') 34 | parser.add_argument('--history_type', type=str, default='full', choices=['full','part','no'], help='full, part, or no history') 35 | parser.add_argument('--table_type', type=str, default='std', choices=['std','no'], help='standard, hierarchical, or no table info') 36 | args = parser.parse_args() 37 | use_hs = True 38 | if args.history_type == "no": 39 | args.history_type = "full" 40 | use_hs = False 41 | 42 | 43 | N_word=300 44 | B_word=42 45 | N_h = 300 46 | N_depth=2 47 | if args.toy: 48 | USE_SMALL=True 49 | GPU=True 50 | BATCH_SIZE=20 51 | else: 52 | USE_SMALL=False 53 | GPU=True 54 | BATCH_SIZE=64 55 | # TRAIN_ENTRY=(False, True, False) # (AGG, SEL, COND) 56 | # TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY 57 | learning_rate = 1e-4 58 | if args.train_component not in TRAIN_COMPONENTS: 59 | print("Invalid train component") 60 | exit(1) 61 | train_data = load_train_dev_dataset(args.train_component, "train", args.history_type, args.data_root) 62 | dev_data = load_train_dev_dataset(args.train_component, "dev", args.history_type, args.data_root) 63 | # sql_data, table_data, val_sql_data, val_table_data, \ 64 | # test_sql_data, test_table_data, \ 65 | # TRAIN_DB, DEV_DB, TEST_DB = load_dataset(args.dataset, use_small=USE_SMALL) 66 | 67 | word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \ 68 | load_used=args.train_emb, use_small=USE_SMALL) 69 | print("finished load word embedding") 70 | #word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt") 71 | model = None 72 | if args.train_component == "multi_sql": 73 | model = MultiSqlPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 74 | elif args.train_component == "keyword": 75 | model = KeyWordPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 76 | elif args.train_component == "col": 77 | model = ColPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 78 | elif args.train_component == "op": 79 | model = OpPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 80 | elif args.train_component == "agg": 81 | model = AggPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 82 | elif args.train_component == "root_tem": 83 | model = RootTeminalPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 84 | elif args.train_component == "des_asc": 85 | model = DesAscLimitPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 86 | elif args.train_component == "having": 87 | model = HavingPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs) 88 | elif args.train_component == "andor": 89 | model = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) 90 | # model = SQLNet(word_emb, N_word=N_word, gpu=GPU, trainable_emb=args.train_emb) 91 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = 0) 92 | print("finished build model") 93 | 94 | print_flag = False 95 | embed_layer = WordEmbedding(word_emb, N_word, gpu=GPU, 96 | SQL_TOK=SQL_TOK, trainable=args.train_emb) 97 | print("start training") 98 | best_acc = 0.0 99 | for i in range(args.epoch): 100 | print('Epoch %d @ %s'%(i+1, 
datetime.datetime.now())) 101 | print(' Loss = %s'%epoch_train( 102 | model, optimizer, BATCH_SIZE,args.train_component,embed_layer,train_data,table_type=args.table_type)) 103 | acc = epoch_acc(model, BATCH_SIZE, args.train_component,embed_layer,dev_data,table_type=args.table_type) 104 | if acc > best_acc: 105 | best_acc = acc 106 | print("Save model...") 107 | torch.save(model.state_dict(), args.save_dir+"/{}_models.dump".format(args.train_component)) 108 | -------------------------------------------------------------------------------- /models/agg_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class AggPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(AggPredictor, self).__init__() 13 | self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 26 | num_layers=N_depth, batch_first=True, 27 | dropout=0.3, bidirectional=True) 28 | 29 | self.q_num_att = nn.Linear(N_h, N_h) 30 | self.hs_num_att = nn.Linear(N_h, N_h) 31 | self.agg_num_out_q = nn.Linear(N_h, N_h) 32 | self.agg_num_out_hs = nn.Linear(N_h, N_h) 33 | self.agg_num_out_c = nn.Linear(N_h, N_h) 34 | self.agg_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 4)) #for 0-3 agg num 35 | 36 | self.q_att = nn.Linear(N_h, N_h) 37 | self.hs_att = nn.Linear(N_h, N_h) 38 | self.agg_out_q = nn.Linear(N_h, N_h) 39 | self.agg_out_hs = nn.Linear(N_h, N_h) 40 | self.agg_out_c = nn.Linear(N_h, N_h) 41 | self.agg_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 5)) #for 1-5 aggregators 42 | 43 | self.softmax = nn.Softmax() #dim=1 44 | self.CE = nn.CrossEntropyLoss() 45 | self.log_softmax = nn.LogSoftmax() 46 | self.mlsml = nn.MultiLabelSoftMarginLoss() 47 | self.bce_logit = nn.BCEWithLogitsLoss() 48 | self.sigm = nn.Sigmoid() 49 | if gpu: 50 | self.cuda() 51 | 52 | 53 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): 54 | max_q_len = max(q_len) 55 | max_hs_len = max(hs_len) 56 | max_col_len = max(col_len) 57 | B = len(q_len) 58 | 59 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 60 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 61 | col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) 62 | 63 | col_emb = [] 64 | for b in range(B): 65 | col_emb.append(col_enc[b, gt_col[b]]) 66 | col_emb = torch.stack(col_emb) 67 | 68 | # Predict agg number 69 | att_val_qc_num = torch.bmm(col_emb.unsqueeze(1), self.q_num_att(q_enc).transpose(1, 2)).view(B, -1) 70 | for idx, num in enumerate(q_len): 71 | if num < max_q_len: 72 | att_val_qc_num[idx, num:] = -100 73 | att_prob_qc_num = self.softmax(att_val_qc_num) 74 | q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1) 75 | 76 | # Same as the above, compute SQL history embedding weighted by column attentions 77 | att_val_hc_num = torch.bmm(col_emb.unsqueeze(1), self.hs_num_att(hs_enc).transpose(1, 2)).view(B, -1) 78 | for idx, num in 
enumerate(hs_len): 79 | if num < max_hs_len: 80 | att_val_hc_num[idx, num:] = -100 81 | att_prob_hc_num = self.softmax(att_val_hc_num) 82 | hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1) 83 | # agg_num_score: (B, 4) 84 | agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_num) + int(self.use_hs)* self.agg_num_out_hs(hs_weighted_num) + self.agg_num_out_c(col_emb)) 85 | 86 | # Predict aggregators 87 | att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1) 88 | for idx, num in enumerate(q_len): 89 | if num < max_q_len: 90 | att_val_qc[idx, num:] = -100 91 | att_prob_qc = self.softmax(att_val_qc) 92 | q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) 93 | 94 | # Same as the above, compute SQL history embedding weighted by column attentions 95 | att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1) 96 | for idx, num in enumerate(hs_len): 97 | if num < max_hs_len: 98 | att_val_hc[idx, num:] = -100 99 | att_prob_hc = self.softmax(att_val_hc) 100 | hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) 101 | # agg_score: (B, 5) 102 | agg_score = self.agg_out(self.agg_out_q(q_weighted) + int(self.use_hs)* self.agg_out_hs(hs_weighted) + self.agg_out_c(col_emb)) 103 | 104 | score = (agg_num_score, agg_score) 105 | 106 | return score 107 | 108 | 109 | def loss(self, score, truth): 110 | loss = 0 111 | B = len(truth) 112 | agg_num_score, agg_score = score 113 | #loss for the column number 114 | truth_num = [len(t) for t in truth] # double check truth format and for test cases 115 | data = torch.from_numpy(np.array(truth_num)) 116 | truth_num_var = Variable(data.cuda()) 117 | loss += self.CE(agg_num_score, truth_num_var) 118 | #loss for the key words 119 | T = len(agg_score[0]) 120 | truth_prob = np.zeros((B, T), dtype=np.float32) 121 | for b in range(B): 122 | truth_prob[b][truth[b]] = 1 123 | data = torch.from_numpy(truth_prob) 124 | truth_var = Variable(data.cuda()) 125 | #loss += self.mlsml(agg_score, truth_var) 126 | #loss += self.bce_logit(agg_score, truth_var) # double check no sigmoid 127 | pred_prob = self.sigm(agg_score) 128 | bce_loss = -torch.mean( 3*(truth_var * \ 129 | torch.log(pred_prob+1e-10)) + \ 130 | (1-truth_var) * torch.log(1-pred_prob+1e-10) ) 131 | loss += bce_loss 132 | 133 | return loss 134 | 135 | 136 | def check_acc(self, score, truth): 137 | num_err, err, tot_err = 0, 0, 0 138 | B = len(truth) 139 | pred = [] 140 | agg_num_score, agg_score = [x.data.cpu().numpy() for x in score] 141 | for b in range(B): 142 | cur_pred = {} 143 | agg_num = np.argmax(agg_num_score[b]) #double check 144 | cur_pred['agg_num'] = agg_num 145 | cur_pred['agg'] = np.argsort(-agg_score[b])[:agg_num] 146 | pred.append(cur_pred) 147 | 148 | for b, (p, t) in enumerate(zip(pred, truth)): 149 | agg_num, agg = p['agg_num'], p['agg'] 150 | flag = True 151 | if agg_num != len(t): # double check truth format and for test cases 152 | num_err += 1 153 | flag = False 154 | if flag and set(agg) != set(t): 155 | err += 1 156 | flag = False 157 | if not flag: 158 | tot_err += 1 159 | 160 | return np.array((num_err, err, tot_err)) 161 | -------------------------------------------------------------------------------- /models/keyword_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import 
run_lstm, col_name_encode 8 | 9 | 10 | class KeyWordPredictor(nn.Module): 11 | '''Predict if the next token is (SQL key words): 12 | WHERE, GROUP BY, ORDER BY. excluding SELECT (it is a must)''' 13 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 14 | super(KeyWordPredictor, self).__init__() 15 | self.N_h = N_h 16 | self.gpu = gpu 17 | self.use_hs = use_hs 18 | 19 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 20 | num_layers=N_depth, batch_first=True, 21 | dropout=0.3, bidirectional=True) 22 | 23 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 24 | num_layers=N_depth, batch_first=True, 25 | dropout=0.3, bidirectional=True) 26 | 27 | self.kw_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 28 | num_layers=N_depth, batch_first=True, 29 | dropout=0.3, bidirectional=True) 30 | 31 | self.q_num_att = nn.Linear(N_h, N_h) 32 | self.hs_num_att = nn.Linear(N_h, N_h) 33 | self.kw_num_out_q = nn.Linear(N_h, N_h) 34 | self.kw_num_out_hs = nn.Linear(N_h, N_h) 35 | self.kw_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 4)) # num of key words: 0-3 36 | 37 | self.q_att = nn.Linear(N_h, N_h) 38 | self.hs_att = nn.Linear(N_h, N_h) 39 | self.kw_out_q = nn.Linear(N_h, N_h) 40 | self.kw_out_hs = nn.Linear(N_h, N_h) 41 | self.kw_out_kw = nn.Linear(N_h, N_h) 42 | self.kw_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) 43 | 44 | self.softmax = nn.Softmax() #dim=1 45 | self.CE = nn.CrossEntropyLoss() 46 | self.log_softmax = nn.LogSoftmax() 47 | self.mlsml = nn.MultiLabelSoftMarginLoss() 48 | self.bce_logit = nn.BCEWithLogitsLoss() 49 | self.sigm = nn.Sigmoid() 50 | if gpu: 51 | self.cuda() 52 | 53 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var, kw_len): 54 | max_q_len = max(q_len) 55 | max_hs_len = max(hs_len) 56 | B = len(q_len) 57 | 58 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 59 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 60 | kw_enc, _ = run_lstm(self.kw_lstm, kw_emb_var, kw_len) 61 | 62 | # Predict key words number: 0-3 63 | att_val_qkw_num = torch.bmm(kw_enc, self.q_num_att(q_enc).transpose(1, 2)) 64 | for idx, num in enumerate(q_len): 65 | if num < max_q_len: 66 | att_val_qkw_num[idx, :, num:] = -100 67 | att_prob_qkw_num = self.softmax(att_val_qkw_num.view((-1, max_q_len))).view(B, -1, max_q_len) 68 | # q_weighted: (B, hid_dim) 69 | q_weighted_num = (q_enc.unsqueeze(1) * att_prob_qkw_num.unsqueeze(3)).sum(2).sum(1) 70 | 71 | # Same as the above, compute SQL history embedding weighted by key words attentions 72 | att_val_hskw_num = torch.bmm(kw_enc, self.hs_num_att(hs_enc).transpose(1, 2)) 73 | for idx, num in enumerate(hs_len): 74 | if num < max_hs_len: 75 | att_val_hskw_num[idx, :, num:] = -100 76 | att_prob_hskw_num = self.softmax(att_val_hskw_num.view((-1, max_hs_len))).view(B, -1, max_hs_len) 77 | hs_weighted_num = (hs_enc.unsqueeze(1) * att_prob_hskw_num.unsqueeze(3)).sum(2).sum(1) 78 | # Compute prediction scores 79 | # self.kw_num_out: (B, 4) 80 | kw_num_score = self.kw_num_out(self.kw_num_out_q(q_weighted_num) + int(self.use_hs)* self.kw_num_out_hs(hs_weighted_num)) 81 | 82 | # Predict key words: WHERE, GROUP BY, ORDER BY. 
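        # Second stage: score each of the three candidate keywords (WHERE, GROUP BY,
        # ORDER BY) with the same question/history attention scheme used for the
        # keyword-number prediction above.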
83 | att_val_qkw = torch.bmm(kw_enc, self.q_att(q_enc).transpose(1, 2)) 84 | for idx, num in enumerate(q_len): 85 | if num < max_q_len: 86 | att_val_qkw[idx, :, num:] = -100 87 | att_prob_qkw = self.softmax(att_val_qkw.view((-1, max_q_len))).view(B, -1, max_q_len) 88 | # q_weighted: (B, 3, hid_dim) 89 | q_weighted = (q_enc.unsqueeze(1) * att_prob_qkw.unsqueeze(3)).sum(2) 90 | 91 | # Same as the above, compute SQL history embedding weighted by key words attentions 92 | att_val_hskw = torch.bmm(kw_enc, self.hs_att(hs_enc).transpose(1, 2)) 93 | for idx, num in enumerate(hs_len): 94 | if num < max_hs_len: 95 | att_val_hskw[idx, :, num:] = -100 96 | att_prob_hskw = self.softmax(att_val_hskw.view((-1, max_hs_len))).view(B, -1, max_hs_len) 97 | hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hskw.unsqueeze(3)).sum(2) 98 | # Compute prediction scores 99 | # self.kw_out.squeeze(): (B, 3) 100 | kw_score = self.kw_out(self.kw_out_q(q_weighted) + int(self.use_hs)* self.kw_out_hs(hs_weighted) + self.kw_out_kw(kw_enc)).view(B,-1) 101 | 102 | score = (kw_num_score, kw_score) 103 | 104 | return score 105 | 106 | def loss(self, score, truth): 107 | loss = 0 108 | B = len(truth) 109 | kw_num_score, kw_score = score 110 | #loss for the key word number 111 | truth_num = [len(t) for t in truth] # double check to exclude select 112 | data = torch.from_numpy(np.array(truth_num)) 113 | truth_num_var = Variable(data.cuda()) 114 | loss += self.CE(kw_num_score, truth_num_var) 115 | #loss for the key words 116 | T = len(kw_score[0]) 117 | truth_prob = np.zeros((B, T), dtype=np.float32) 118 | for b in range(B): 119 | truth_prob[b][truth[b]] = 1 120 | data = torch.from_numpy(truth_prob) 121 | truth_var = Variable(data.cuda()) 122 | #loss += self.mlsml(kw_score, truth_var) 123 | #loss += self.bce_logit(kw_score, truth_var) # double check no sigmoid for kw 124 | pred_prob = self.sigm(kw_score) 125 | bce_loss = -torch.mean( 3*(truth_var * \ 126 | torch.log(pred_prob+1e-10)) + \ 127 | (1-truth_var) * torch.log(1-pred_prob+1e-10) ) 128 | loss += bce_loss 129 | 130 | return loss 131 | 132 | 133 | def check_acc(self, score, truth): 134 | num_err, err, tot_err = 0, 0, 0 135 | B = len(truth) 136 | pred = [] 137 | kw_num_score, kw_score = [x.data.cpu().numpy() for x in score] 138 | for b in range(B): 139 | cur_pred = {} 140 | kw_num = np.argmax(kw_num_score[b]) 141 | cur_pred['kw_num'] = kw_num 142 | cur_pred['kw'] = np.argsort(-kw_score[b])[:kw_num] 143 | pred.append(cur_pred) 144 | 145 | for b, (p, t) in enumerate(zip(pred, truth)): 146 | kw_num, kw = p['kw_num'], p['kw'] 147 | flag = True 148 | if kw_num != len(t): # double check to excluding select 149 | num_err += 1 150 | flag = False 151 | if flag and set(kw) != set(t): 152 | err += 1 153 | flag = False 154 | if not flag: 155 | tot_err += 1 156 | 157 | return np.array((num_err, err, tot_err)) 158 | -------------------------------------------------------------------------------- /models/op_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class OpPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(OpPredictor, self).__init__() 13 | self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | 
num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 26 | num_layers=N_depth, batch_first=True, 27 | dropout=0.3, bidirectional=True) 28 | 29 | self.q_num_att = nn.Linear(N_h, N_h) 30 | self.hs_num_att = nn.Linear(N_h, N_h) 31 | self.op_num_out_q = nn.Linear(N_h, N_h) 32 | self.op_num_out_hs = nn.Linear(N_h, N_h) 33 | self.op_num_out_c = nn.Linear(N_h, N_h) 34 | self.op_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for 1-2 op num, could be changed 35 | 36 | self.q_att = nn.Linear(N_h, N_h) 37 | self.hs_att = nn.Linear(N_h, N_h) 38 | self.op_out_q = nn.Linear(N_h, N_h) 39 | self.op_out_hs = nn.Linear(N_h, N_h) 40 | self.op_out_c = nn.Linear(N_h, N_h) 41 | self.op_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 11)) #for 11 operators 42 | 43 | self.softmax = nn.Softmax() #dim=1 44 | self.CE = nn.CrossEntropyLoss() 45 | self.log_softmax = nn.LogSoftmax() 46 | self.mlsml = nn.MultiLabelSoftMarginLoss() 47 | self.bce_logit = nn.BCEWithLogitsLoss() 48 | self.sigm = nn.Sigmoid() 49 | if gpu: 50 | self.cuda() 51 | 52 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): 53 | max_q_len = max(q_len) 54 | max_hs_len = max(hs_len) 55 | max_col_len = max(col_len) 56 | B = len(q_len) 57 | 58 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 59 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 60 | col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) 61 | 62 | # get target/predicted column's embedding 63 | # col_emb: (B, hid_dim) 64 | col_emb = [] 65 | for b in range(B): 66 | col_emb.append(col_enc[b, gt_col[b]]) 67 | col_emb = torch.stack(col_emb) 68 | 69 | # Predict op number 70 | att_val_qc_num = torch.bmm(col_emb.unsqueeze(1), self.q_num_att(q_enc).transpose(1, 2)).view(B,-1) 71 | for idx, num in enumerate(q_len): 72 | if num < max_q_len: 73 | att_val_qc_num[idx, num:] = -100 74 | att_prob_qc_num = self.softmax(att_val_qc_num) 75 | q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1) 76 | 77 | # Same as the above, compute SQL history embedding weighted by column attentions 78 | att_val_hc_num = torch.bmm(col_emb.unsqueeze(1), self.hs_num_att(hs_enc).transpose(1, 2)).view(B,-1) 79 | for idx, num in enumerate(hs_len): 80 | if num < max_hs_len: 81 | att_val_hc_num[idx, num:] = -100 82 | att_prob_hc_num = self.softmax(att_val_hc_num) 83 | hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1) 84 | # op_num_score: (B, 2) 85 | op_num_score = self.op_num_out(self.op_num_out_q(q_weighted_num) + int(self.use_hs)* self.op_num_out_hs(hs_weighted_num) + self.op_num_out_c(col_emb)) 86 | 87 | # Compute attention values between selected column and question tokens. 
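        # Second stage: score the 11 condition operators for the selected column,
        # using question and history representations attended by the column embedding.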
88 | # q_enc.transpose(1, 2): (B, hid_dim, max_q_len) 89 | # col_emb.unsqueeze(1): (B, 1, hid_dim) 90 | # att_val_qc: (B, max_q_len) 91 | # print("col_emb {} q_enc {}".format(col_emb.unsqueeze(1).size(),self.q_att(q_enc).transpose(1, 2).size())) 92 | att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B,-1) 93 | # assign appended positions values -100 94 | for idx, num in enumerate(q_len): 95 | if num < max_q_len: 96 | att_val_qc[idx, num:] = -100 97 | # att_prob_qc: (B, max_q_len) 98 | att_prob_qc = self.softmax(att_val_qc) 99 | # q_enc: (B, max_q_len, hid_dim) 100 | # att_prob_qc.unsqueeze(2): (B, max_q_len, 1) 101 | # q_weighted: (B, hid_dim) 102 | q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) 103 | 104 | # Same as the above, compute SQL history embedding weighted by column attentions 105 | att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B,-1) 106 | for idx, num in enumerate(hs_len): 107 | if num < max_hs_len: 108 | att_val_hc[idx, num:] = -100 109 | att_prob_hc = self.softmax(att_val_hc) 110 | hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) 111 | 112 | # Compute prediction scores 113 | # op_score: (B, 10) 114 | op_score = self.op_out(self.op_out_q(q_weighted) + int(self.use_hs)* self.op_out_hs(hs_weighted) + self.op_out_c(col_emb)) 115 | 116 | score = (op_num_score, op_score) 117 | 118 | return score 119 | 120 | 121 | def loss(self, score, truth): 122 | loss = 0 123 | B = len(truth) 124 | op_num_score, op_score = score 125 | truth = [t if len(t) <= 2 else t[:2] for t in truth] 126 | # loss for the op number 127 | truth_num = [len(t)-1 for t in truth] #num_score 0 maps to 1 in truth 128 | data = torch.from_numpy(np.array(truth_num)) 129 | truth_num_var = Variable(data.cuda()) 130 | loss += self.CE(op_num_score, truth_num_var) 131 | # loss for op 132 | T = len(op_score[0]) 133 | truth_prob = np.zeros((B, T), dtype=np.float32) 134 | for b in range(B): 135 | truth_prob[b][truth[b]] = 1 136 | data = torch.from_numpy(np.array(truth_prob)) 137 | truth_var = Variable(data.cuda()) 138 | #loss += self.mlsml(op_score, truth_var) 139 | #loss += self.bce_logit(op_score, truth_var) 140 | pred_prob = self.sigm(op_score) 141 | bce_loss = -torch.mean( 3*(truth_var * \ 142 | torch.log(pred_prob+1e-10)) + \ 143 | (1-truth_var) * torch.log(1-pred_prob+1e-10) ) 144 | loss += bce_loss 145 | 146 | return loss 147 | 148 | 149 | def check_acc(self, score, truth): 150 | num_err, err, tot_err = 0, 0, 0 151 | B = len(truth) 152 | pred = [] 153 | op_num_score, op_score = [x.data.cpu().numpy() for x in score] 154 | for b in range(B): 155 | cur_pred = {} 156 | op_num = np.argmax(op_num_score[b]) + 1 #num_score 0 maps to 1 in truth, must have at least one op 157 | cur_pred['op_num'] = op_num 158 | cur_pred['op'] = np.argsort(-op_score[b])[:op_num] 159 | pred.append(cur_pred) 160 | 161 | for b, (p, t) in enumerate(zip(pred, truth)): 162 | op_num, op = p['op_num'], p['op'] 163 | flag = True 164 | if op_num != len(t): 165 | num_err += 1 166 | flag = False 167 | if flag and set(op) != set(t): 168 | err += 1 169 | flag = False 170 | if not flag: 171 | tot_err += 1 172 | 173 | return np.array((num_err, err, tot_err)) 174 | -------------------------------------------------------------------------------- /word_embedding.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 
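# Note: the WordEmbedding module below maps question tokens, SQL-history items, and
# column-name tokens to fixed GloVe vectors (or to a trainable nn.Embedding when
# --train_emb is set); multi-word history items are averaged, while column names stay
# token sequences for the downstream column LSTM encoders.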
6 | import numpy as np 7 | 8 | AGG_OPS = ('none', 'maximum', 'minimum', 'count', 'sum', 'average') 9 | class WordEmbedding(nn.Module): 10 | def __init__(self, word_emb, N_word, gpu, SQL_TOK, 11 | trainable=False): 12 | super(WordEmbedding, self).__init__() 13 | self.trainable = trainable 14 | self.N_word = N_word 15 | self.gpu = gpu 16 | self.SQL_TOK = SQL_TOK 17 | 18 | if trainable: 19 | print "Using trainable embedding" 20 | self.w2i, word_emb_val = word_emb 21 | # tranable when using pretrained model, init embedding weights using prev embedding 22 | self.embedding = nn.Embedding(len(self.w2i), N_word) 23 | self.embedding.weight = nn.Parameter(torch.from_numpy(word_emb_val.astype(np.float32))) 24 | else: 25 | # else use word2vec or glove 26 | self.word_emb = word_emb 27 | print "Using fixed embedding" 28 | 29 | def gen_x_q_batch(self, q): 30 | B = len(q) 31 | val_embs = [] 32 | val_len = np.zeros(B, dtype=np.int64) 33 | for i, one_q in enumerate(q): 34 | q_val = [] 35 | for ws in one_q: 36 | q_val.append(self.word_emb.get(ws, np.zeros(self.N_word, dtype=np.float32))) 37 | 38 | val_embs.append([np.zeros(self.N_word, dtype=np.float32)] + q_val + [np.zeros(self.N_word, dtype=np.float32)]) # and 39 | val_len[i] = 1 + len(q_val) + 1 40 | max_len = max(val_len) 41 | 42 | val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32) 43 | for i in range(B): 44 | for t in range(len(val_embs[i])): 45 | val_emb_array[i, t, :] = val_embs[i][t] 46 | val_inp = torch.from_numpy(val_emb_array) 47 | if self.gpu: 48 | val_inp = val_inp.cuda() 49 | val_inp_var = Variable(val_inp) 50 | 51 | return val_inp_var, val_len 52 | 53 | def gen_x_history_batch(self, history): 54 | B = len(history) 55 | val_embs = [] 56 | val_len = np.zeros(B, dtype=np.int64) 57 | for i, one_history in enumerate(history): 58 | history_val = [] 59 | for item in one_history: 60 | #col 61 | if isinstance(item, list) or isinstance(item, tuple): 62 | emb_list = [] 63 | ws = item[0].split() + item[1].split() 64 | ws_len = len(ws) 65 | for w in ws: 66 | emb_list.append(self.word_emb.get(w, np.zeros(self.N_word, dtype=np.float32))) 67 | if ws_len == 0: 68 | raise Exception("word list should not be empty!") 69 | elif ws_len == 1: 70 | history_val.append(emb_list[0]) 71 | else: 72 | history_val.append(sum(emb_list) / float(ws_len)) 73 | #ROOT 74 | elif isinstance(item,basestring): 75 | if item == "ROOT": 76 | item = "root" 77 | elif item == "asc": 78 | item = "ascending" 79 | elif item == "desc": 80 | item == "descending" 81 | if item in ( 82 | "none", "select", "from", "where", "having", "limit", "intersect", "except", "union", 'not', 83 | 'between', '=', '>', '<', 'in', 'like', 'is', 'exists', 'root', 'ascending', 'descending'): 84 | history_val.append(self.word_emb.get(item, np.zeros(self.N_word, dtype=np.float32))) 85 | elif item == "orderBy": 86 | history_val.append((self.word_emb.get("order", np.zeros(self.N_word, dtype=np.float32)) + 87 | self.word_emb.get("by", np.zeros(self.N_word, dtype=np.float32))) / 2) 88 | elif item == "groupBy": 89 | history_val.append((self.word_emb.get("group", np.zeros(self.N_word, dtype=np.float32)) + 90 | self.word_emb.get("by", np.zeros(self.N_word, dtype=np.float32))) / 2) 91 | elif item in ('>=', '<=', '!='): 92 | history_val.append((self.word_emb.get(item[0], np.zeros(self.N_word, dtype=np.float32)) + 93 | self.word_emb.get(item[1], np.zeros(self.N_word, dtype=np.float32))) / 2) 94 | elif isinstance(item,int): 95 | history_val.append(self.word_emb.get(AGG_OPS[item], np.zeros(self.N_word, 
dtype=np.float32))) 96 | else: 97 | print("Warning: unsupported data type in history! {}".format(item)) 98 | 99 | val_embs.append(history_val) 100 | val_len[i] = len(history_val) 101 | max_len = max(val_len) 102 | 103 | val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32) 104 | for i in range(B): 105 | for t in range(len(val_embs[i])): 106 | val_emb_array[i, t, :] = val_embs[i][t] 107 | val_inp = torch.from_numpy(val_emb_array) 108 | if self.gpu: 109 | val_inp = val_inp.cuda() 110 | val_inp_var = Variable(val_inp) 111 | 112 | return val_inp_var, val_len 113 | 114 | 115 | def gen_word_list_embedding(self,words,B): 116 | val_emb_array = np.zeros((B,len(words), self.N_word), dtype=np.float32) 117 | for i,word in enumerate(words): 118 | if len(word.split()) == 1: 119 | emb = self.word_emb.get(word, np.zeros(self.N_word, dtype=np.float32)) 120 | else: 121 | word = word.split() 122 | emb = (self.word_emb.get(word[0], np.zeros(self.N_word, dtype=np.float32)) 123 | +self.word_emb.get(word[1], np.zeros(self.N_word, dtype=np.float32)) )/2 124 | for b in range(B): 125 | val_emb_array[b,i,:] = emb 126 | val_inp = torch.from_numpy(val_emb_array) 127 | if self.gpu: 128 | val_inp = val_inp.cuda() 129 | val_inp_var = Variable(val_inp) 130 | return val_inp_var 131 | 132 | 133 | def gen_col_batch(self, cols): 134 | ret = [] 135 | col_len = np.zeros(len(cols), dtype=np.int64) 136 | 137 | names = [] 138 | for b, one_cols in enumerate(cols): 139 | names = names + one_cols 140 | col_len[b] = len(one_cols) 141 | #TODO: what is the diff bw name_len and col_len? 142 | name_inp_var, name_len = self.str_list_to_batch(names) 143 | return name_inp_var, name_len, col_len 144 | 145 | def str_list_to_batch(self, str_list): 146 | """get a list var of wemb of words in each column name in current bactch""" 147 | B = len(str_list) 148 | 149 | val_embs = [] 150 | val_len = np.zeros(B, dtype=np.int64) 151 | for i, one_str in enumerate(str_list): 152 | if self.trainable: 153 | val = [self.w2i.get(x, 0) for x in one_str] 154 | else: 155 | val = [self.word_emb.get(x, np.zeros( 156 | self.N_word, dtype=np.float32)) for x in one_str] 157 | val_embs.append(val) 158 | val_len[i] = len(val) 159 | max_len = max(val_len) 160 | 161 | if self.trainable: 162 | val_tok_array = np.zeros((B, max_len), dtype=np.int64) 163 | for i in range(B): 164 | for t in range(len(val_embs[i])): 165 | val_tok_array[i,t] = val_embs[i][t] 166 | val_tok = torch.from_numpy(val_tok_array) 167 | if self.gpu: 168 | val_tok = val_tok.cuda() 169 | val_tok_var = Variable(val_tok) 170 | val_inp_var = self.embedding(val_tok_var) 171 | else: 172 | val_emb_array = np.zeros( 173 | (B, max_len, self.N_word), dtype=np.float32) 174 | for i in range(B): 175 | for t in range(len(val_embs[i])): 176 | val_emb_array[i,t,:] = val_embs[i][t] 177 | val_inp = torch.from_numpy(val_emb_array) 178 | if self.gpu: 179 | val_inp = val_inp.cuda() 180 | val_inp_var = Variable(val_inp) 181 | 182 | return val_inp_var, val_len 183 | -------------------------------------------------------------------------------- /models/col_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | class ColPredictor(nn.Module): 11 | def __init__(self, N_word, N_h, N_depth, gpu, use_hs): 12 | super(ColPredictor, self).__init__() 13 | 
self.N_h = N_h 14 | self.gpu = gpu 15 | self.use_hs = use_hs 16 | 17 | self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 18 | num_layers=N_depth, batch_first=True, 19 | dropout=0.3, bidirectional=True) 20 | 21 | self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 22 | num_layers=N_depth, batch_first=True, 23 | dropout=0.3, bidirectional=True) 24 | 25 | self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 26 | num_layers=N_depth, batch_first=True, 27 | dropout=0.3, bidirectional=True) 28 | 29 | self.q_num_att = nn.Linear(N_h, N_h) 30 | self.hs_num_att = nn.Linear(N_h, N_h) 31 | self.col_num_out_q = nn.Linear(N_h, N_h) 32 | self.col_num_out_hs = nn.Linear(N_h, N_h) 33 | self.col_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 6)) # num of cols: 1-3 34 | 35 | self.q_att = nn.Linear(N_h, N_h) 36 | self.hs_att = nn.Linear(N_h, N_h) 37 | self.col_out_q = nn.Linear(N_h, N_h) 38 | self.col_out_c = nn.Linear(N_h, N_h) 39 | self.col_out_hs = nn.Linear(N_h, N_h) 40 | self.col_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) 41 | 42 | self.softmax = nn.Softmax() #dim=1 43 | self.CE = nn.CrossEntropyLoss() 44 | self.log_softmax = nn.LogSoftmax() 45 | self.mlsml = nn.MultiLabelSoftMarginLoss() 46 | self.bce_logit = nn.BCEWithLogitsLoss() 47 | self.sigm = nn.Sigmoid() 48 | if gpu: 49 | self.cuda() 50 | 51 | def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len): 52 | 53 | max_q_len = max(q_len) 54 | max_hs_len = max(hs_len) 55 | max_col_len = max(col_len) 56 | B = len(q_len) 57 | 58 | q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) 59 | hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) 60 | col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) 61 | 62 | # Predict column number: 1-3 63 | # att_val_qc_num: (B, max_col_len, max_q_len) 64 | att_val_qc_num = torch.bmm(col_enc, self.q_num_att(q_enc).transpose(1, 2)) 65 | for idx, num in enumerate(col_len): 66 | if num < max_col_len: 67 | att_val_qc_num[idx, num:, :] = -100 68 | for idx, num in enumerate(q_len): 69 | if num < max_q_len: 70 | att_val_qc_num[idx, :, num:] = -100 71 | att_prob_qc_num = self.softmax(att_val_qc_num.view((-1, max_q_len))).view(B, -1, max_q_len) 72 | # q_weighted_num: (B, hid_dim) 73 | q_weighted_num = (q_enc.unsqueeze(1) * att_prob_qc_num.unsqueeze(3)).sum(2).sum(1) 74 | 75 | # Same as the above, compute SQL history embedding weighted by column attentions 76 | # att_val_hc_num: (B, max_col_len, max_hs_len) 77 | att_val_hc_num = torch.bmm(col_enc, self.hs_num_att(hs_enc).transpose(1, 2)) 78 | for idx, num in enumerate(hs_len): 79 | if num < max_hs_len: 80 | att_val_hc_num[idx, :, num:] = -100 81 | for idx, num in enumerate(col_len): 82 | if num < max_col_len: 83 | att_val_hc_num[idx, num:, :] = -100 84 | att_prob_hc_num = self.softmax(att_val_hc_num.view((-1, max_hs_len))).view(B, -1, max_hs_len) 85 | hs_weighted_num = (hs_enc.unsqueeze(1) * att_prob_hc_num.unsqueeze(3)).sum(2).sum(1) 86 | # self.col_num_out: (B, 3) 87 | col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num) + int(self.use_hs)* self.col_num_out_hs(hs_weighted_num)) 88 | 89 | # Predict columns. 
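        # col_enc: (B, max_col_len, hid_dim) -- one encoding per column name
        # self.q_att(q_enc).transpose(1, 2): (B, hid_dim, max_q_len)
        # att_val_qc: (B, max_col_len, max_q_len) -- one score per (column, question token)
        # pair; padded question positions are masked to -100 before the softmax below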
90 | att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2)) 91 | for idx, num in enumerate(q_len): 92 | if num < max_q_len: 93 | att_val_qc[idx, :, num:] = -100 94 | att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(B, -1, max_q_len) 95 | # q_weighted: (B, max_col_len, hid_dim) 96 | q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2) 97 | 98 | # Same as the above, compute SQL history embedding weighted by column attentions 99 | att_val_hc = torch.bmm(col_enc, self.hs_att(hs_enc).transpose(1, 2)) 100 | for idx, num in enumerate(hs_len): 101 | if num < max_hs_len: 102 | att_val_hc[idx, :, num:] = -100 103 | att_prob_hc = self.softmax(att_val_hc.view((-1, max_hs_len))).view(B, -1, max_hs_len) 104 | hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hc.unsqueeze(3)).sum(2) 105 | # Compute prediction scores 106 | # self.col_out.squeeze(): (B, max_col_len) 107 | col_score = self.col_out(self.col_out_q(q_weighted) + int(self.use_hs)* self.col_out_hs(hs_weighted) + self.col_out_c(col_enc)).view(B,-1) 108 | 109 | for idx, num in enumerate(col_len): 110 | if num < max_col_len: 111 | col_score[idx, num:] = -100 112 | 113 | score = (col_num_score, col_score) 114 | 115 | return score 116 | 117 | def loss(self, score, truth): 118 | #here suppose truth looks like [[[1, 4], 3], [], ...] 119 | loss = 0 120 | B = len(truth) 121 | col_num_score, col_score = score 122 | #loss for the column number 123 | truth_num = [len(t) - 1 for t in truth] # double check truth format and for test cases 124 | data = torch.from_numpy(np.array(truth_num)) 125 | truth_num_var = Variable(data.cuda()) 126 | loss += self.CE(col_num_score, truth_num_var) 127 | #loss for the key words 128 | T = len(col_score[0]) 129 | # print("T {}".format(T)) 130 | truth_prob = np.zeros((B, T), dtype=np.float32) 131 | for b in range(B): 132 | gold_l = [] 133 | for t in truth[b]: 134 | if isinstance(t, list): 135 | gold_l.extend(t) 136 | else: 137 | gold_l.append(t) 138 | truth_prob[b][gold_l] = 1 139 | data = torch.from_numpy(truth_prob) 140 | # print("data {}".format(data)) 141 | # print("data {}".format(data.cuda())) 142 | truth_var = Variable(data.cuda()) 143 | #loss += self.mlsml(col_score, truth_var) 144 | #loss += self.bce_logit(col_score, truth_var) # double check no sigmoid 145 | pred_prob = self.sigm(col_score) 146 | bce_loss = -torch.mean( 3*(truth_var * \ 147 | torch.log(pred_prob+1e-10)) + \ 148 | (1-truth_var) * torch.log(1-pred_prob+1e-10) ) 149 | loss += bce_loss 150 | 151 | return loss 152 | 153 | 154 | def check_acc(self, score, truth): 155 | num_err, err, tot_err = 0, 0, 0 156 | B = len(truth) 157 | pred = [] 158 | col_num_score, col_score = [x.data.cpu().numpy() for x in score] 159 | for b in range(B): 160 | cur_pred = {} 161 | col_num = np.argmax(col_num_score[b]) + 1 #double check 162 | cur_pred['col_num'] = col_num 163 | cur_pred['col'] = np.argsort(-col_score[b])[:col_num] 164 | pred.append(cur_pred) 165 | 166 | for b, (p, t) in enumerate(zip(pred, truth)): 167 | col_num, col = p['col_num'], p['col'] 168 | flag = True 169 | if col_num != len(t): # double check truth format and for test cases 170 | num_err += 1 171 | flag = False 172 | #to eval col predicts, if the gold sql has JOIN and foreign key col, then both fks are acceptable 173 | fk_list = [] 174 | regular = [] 175 | for l in t: 176 | if isinstance(l, list): 177 | fk_list.append(l) 178 | else: 179 | regular.append(l) 180 | 181 | if flag: #double check 182 | for c in col: 183 | for fk in fk_list: 184 | if c in fk: 185 | 
fk_list.remove(fk) 186 | for r in regular: 187 | if c == r: 188 | regular.remove(r) 189 | 190 | if len(fk_list) != 0 or len(regular) != 0: 191 | err += 1 192 | flag = False 193 | 194 | if not flag: 195 | tot_err += 1 196 | 197 | return np.array((num_err, err, tot_err)) 198 | -------------------------------------------------------------------------------- /get_data_wikisql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | import re 4 | import sys 5 | import json 6 | import sqlite3 7 | import sqlparse 8 | from os import listdir, makedirs 9 | from collections import OrderedDict 10 | from nltk import word_tokenize, tokenize 11 | from os.path import isfile, isdir, join, split, exists, splitext 12 | 13 | from process_sql import get_sql 14 | 15 | VALUE_NUM_SYMBOL = 'VALUE' 16 | 17 | class Schema: 18 | """ 19 | Simple schema which maps table&column to a unique identifier 20 | """ 21 | def __init__(self, schema, table): 22 | self._schema = schema 23 | self._table = table 24 | self._idMap = self._map(self._schema, self._table) 25 | 26 | @property 27 | def schema(self): 28 | return self._schema 29 | 30 | @property 31 | def idMap(self): 32 | return self._idMap 33 | 34 | def _map(self, schema, table): 35 | column_names_original = table['column_names_original'] 36 | table_names_original = table['table_names_original'] 37 | #print 'column_names_original: ', column_names_original 38 | #print 'table_names_original: ', table_names_original 39 | for i, (tab_id, col) in enumerate(column_names_original): 40 | if tab_id == -1: 41 | idMap = {'*': i} 42 | else: 43 | key = table_names_original[tab_id].lower() 44 | val = col.lower() 45 | idMap[key + "." + val] = i 46 | 47 | for i, tab in enumerate(table_names_original): 48 | key = tab.lower() 49 | idMap[key] = i 50 | 51 | return idMap 52 | 53 | 54 | def strip_query(query): 55 | ''' 56 | return keywords of sql query 57 | ''' 58 | query_keywords = [] 59 | query = query.strip().replace(";","").replace("\t","") 60 | query = query.replace("(", " ( ").replace(")", " ) ") 61 | query = query.replace(">=", " >= ").replace("=", " = ").replace("<=", " <= ").replace("!=", " != ") 62 | 63 | 64 | # then replace all stuff enclosed by "" with a numerical value to get it marked as {VALUE} 65 | str_1 = re.findall("\"[^\"]*\"", query) 66 | str_2 = re.findall("\'[^\']*\'", query) 67 | values = str_1 + str_2 68 | for val in values: 69 | query = query.replace(val.strip(), VALUE_NUM_SYMBOL) 70 | 71 | query_tokenized = query.split() 72 | float_nums = re.findall("[-+]?\d*\.\d+", query) 73 | query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized] 74 | query = " ".join(query_tokenized) 75 | int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)] 76 | 77 | 78 | query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized] 79 | # print int_nums, query, query_tokenized 80 | 81 | for tok in query_tokenized: 82 | if "." in tok: 83 | table = re.findall("[Tt]\d+\.", tok) 84 | if len(table)>0: 85 | to = tok.replace(".", " . 
").split() 86 | to = [t.lower() for t in to if len(t)>0] 87 | query_keywords.extend(to) 88 | else: 89 | query_keywords.append(tok.lower()) 90 | 91 | elif len(tok) > 0: 92 | query_keywords.append(tok.lower()) 93 | 94 | return query_keywords 95 | 96 | 97 | def get_schemas_from_json(fpath): 98 | with open(fpath) as f: 99 | data = json.load(f) 100 | db_names = [db['db_id'] for db in data] 101 | 102 | tables = {} 103 | schemas = {} 104 | for db in data: 105 | db_id = db['db_id'] 106 | schema = {} #{'table': [col.lower, ..., ]} * -> __all__ 107 | column_names_original = db['column_names_original'] 108 | table_names_original = db['table_names_original'] 109 | tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original} 110 | for i, tabn in enumerate(table_names_original): 111 | table = str(tabn.encode("utf8").lower()) 112 | cols = [str(col.encode("utf8").lower()) for td, col in column_names_original if td == i] 113 | schema[table] = cols 114 | schemas[db_id] = schema 115 | 116 | return schemas, db_names, tables 117 | 118 | 119 | def parse_file_and_sql(filepath, schema, db_id): 120 | f = open(filepath,"r") 121 | ret = [] 122 | lines = list(f.readlines()) 123 | f.close() 124 | i = 0 125 | questions = [] 126 | has_prefix = False 127 | while i < len(lines): 128 | line = lines[i].lstrip().rstrip() 129 | line = line.replace("\r","") 130 | line = line.replace("\n","") 131 | if len(line) == 0: 132 | i += 1 133 | continue 134 | if ord('0') <= ord(line[0]) <= ord('9'): 135 | #remove question number 136 | if len(questions) != 0: 137 | print '\n-----------------------------wrong indexing!-----------------------------------\n' 138 | print 'questions: ', questions 139 | sys.exit() 140 | index = line.find(".") 141 | if index != -1: 142 | line = line[index+1:] 143 | if line != '' and len(line) != 0: 144 | questions.append(line.lstrip().rstrip()) 145 | i += 1 146 | continue 147 | if line.startswith("P:"): 148 | index = line.find("P:") 149 | line = line[index+2:] 150 | if line != '' and len(line) != 0: 151 | questions.append(line.lstrip().rstrip()) 152 | has_prefix = True 153 | if (line.startswith("select") or line.startswith("SELECT") or line.startswith("Select") or \ 154 | line.startswith("with") or line.startswith("With") or line.startswith("WITH")) and has_prefix: 155 | sql = [line] 156 | i += 1 157 | while i < len(lines): 158 | line = lines[i] 159 | line = lines[i].lstrip().rstrip() 160 | line = line.replace("\r","") 161 | line = line.replace("\n","") 162 | if len(line) == 0 or len(line.strip()) == 0 or ord('0') <= ord(line[0]) <= ord('9') or \ 163 | not (line[0].isalpha() or line[0] in ['(',')','=','<','>', '+', '-','!','\'','\"','%']): 164 | break 165 | sql.append(line) 166 | i += 1 167 | sql = " ".join(sql) 168 | sql = sqlparse.format(sql, reindent=False, keyword_case='upper') 169 | sql = re.sub(r"(<=|>=|=|<|>|,)",r" \1 ",sql) 170 | # sql = sql.replace("\"","'") 171 | sql = re.sub(r"(T\d+\.)\s",r"\1",sql) 172 | #if len(questions) != 2: 173 | # print '\n-----------------------------wrong indexing!-----------------------------------\n' 174 | # print 'questions: ', questions 175 | # sys.exit() 176 | for ix, q in enumerate(questions): 177 | try: 178 | q = q.encode("utf8") 179 | sql = sql.encode("utf8") 180 | q_toks = word_tokenize(q) 181 | query_toks = word_tokenize(sql) 182 | query_toks_no_value = strip_query(sql) 183 | sql_label = None 184 | 185 | sql_label = get_sql(schema, sql) 186 | #print("query: {}".format(sql)) 187 | #print("\ndb_id: {}".format(db_id)) 
188 | #print("query: {}".format(sql)) 189 | ret.append({'question': q, 190 | 'question_toks': q_toks, 191 | 'query': sql, 192 | 'query_toks': query_toks, 193 | 'query_toks_no_value': query_toks_no_value, 194 | 'sql': sql_label, 195 | 'db_id': db_id}) 196 | except Exception as e: 197 | #print("query: {}".format(sql)) 198 | #print(e) 199 | pass 200 | questions = [] 201 | has_prefix = False 202 | continue 203 | 204 | i += 1 205 | 206 | return ret 207 | 208 | 209 | if __name__ == '__main__': 210 | if len(sys.argv) < 3: 211 | print "Usage: python get_data.py [dir containing reviewed files] [processed table json file] [output file name e.g. output.json]" 212 | sys.exit() 213 | input_dir = sys.argv[1] 214 | table_file = sys.argv[2] 215 | output_file = sys.argv[3] 216 | 217 | schemas, db_names, tables = get_schemas_from_json(table_file) 218 | db_files = [f for f in listdir(input_dir) if f.endswith('.txt')] 219 | fn_map = {} 220 | for f in db_files: 221 | flag = True 222 | for db in db_names: 223 | if db.lower() in f.lower(): 224 | flag = False 225 | fn_map[f] = db 226 | continue 227 | if flag == True: 228 | print "db not found: ", f 229 | if len(db_files) != len(fn_map.keys()): 230 | tab_db_files = [f.lower() for f in fn_map.keys()] 231 | print 'Warning: misspelled files: ', [f for f in db_files if f.lower() not in tab_db_files] 232 | sys.exit() 233 | 234 | data = [] 235 | for f, db_id in fn_map.items(): 236 | raw_file = join(input_dir, f) 237 | #print 'reading labeled file for db: ', db_id 238 | schema = schemas[db_id] 239 | table = tables[db_id] 240 | schema = Schema(schema, table) 241 | data_one = parse_file_and_sql(raw_file, schema, db_id) 242 | data.extend(data_one) 243 | with open(output_file, 'wt') as out: 244 | json.dump(data, out, sort_keys=True, indent=4, separators=(',', ': ')) 245 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import io 3 | import json 4 | import numpy as np 5 | import os 6 | import signal 7 | from preprocess_train_dev_data import get_table_dict 8 | 9 | 10 | def load_train_dev_dataset(component,train_dev,history, root): 11 | return json.load(open("{}/{}_{}_{}_dataset.json".format(root, history,train_dev,component))) 12 | 13 | 14 | def to_batch_seq(data, idxes, st, ed): 15 | q_seq = [] 16 | history = [] 17 | label = [] 18 | for i in range(st, ed): 19 | q_seq.append(data[idxes[i]]['question_tokens']) 20 | history.append(data[idxes[i]]["history"]) 21 | label.append(data[idxes[i]]["label"]) 22 | return q_seq,history,label 23 | 24 | # CHANGED 25 | def to_batch_tables(data, idxes, st,ed, table_type): 26 | # col_lens = [] 27 | col_seq = [] 28 | for i in range(st, ed): 29 | ts = data[idxes[i]]["ts"] 30 | tname_toks = [x.split(" ") for x in ts[0]] 31 | col_type = ts[2] 32 | cols = [x.split(" ") for xid, x in ts[1]] 33 | tab_seq = [xid for xid, x in ts[1]] 34 | cols_add = [] 35 | for tid, col, ct in zip(tab_seq, cols, col_type): 36 | col_one = [ct] 37 | if tid == -1: 38 | tabn = ["all"] 39 | else: 40 | if table_type=="no": tabn = [] 41 | else: tabn = tname_toks[tid] 42 | for t in tabn: 43 | if t not in col: 44 | col_one.append(t) 45 | col_one.extend(col) 46 | cols_add.append(col_one) 47 | col_seq.append(cols_add) 48 | 49 | return col_seq 50 | 51 | ## used for training in train.py 52 | def epoch_train(model, optimizer, batch_size, component,embed_layer,data, table_type): 53 | model.train() 54 | perm=np.random.permutation(len(data)) 55 | 
cum_loss = 0.0 56 | st = 0 57 | 58 | while st < len(data): 59 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 60 | q_seq, history,label = to_batch_seq(data, perm, st, ed) 61 | q_emb_var, q_len = embed_layer.gen_x_q_batch(q_seq) 62 | hs_emb_var, hs_len = embed_layer.gen_x_history_batch(history) 63 | score = 0.0 64 | loss = 0.0 65 | if component == "multi_sql": 66 | mkw_emb_var = embed_layer.gen_word_list_embedding(["none","except","intersect","union"],(ed-st)) 67 | mkw_len = np.full(q_len.shape, 4,dtype=np.int64) 68 | # print("mkw_emb:{}".format(mkw_emb_var.size())) 69 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var=mkw_emb_var, mkw_len=mkw_len) 70 | elif component == "keyword": 71 | #where group by order by 72 | # [[0,1,2]] 73 | kw_emb_var = embed_layer.gen_word_list_embedding(["where", "group by", "order by"],(ed-st)) 74 | mkw_len = np.full(q_len.shape, 3, dtype=np.int64) 75 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var=kw_emb_var, kw_len=mkw_len) 76 | elif component == "col": 77 | #col word embedding 78 | # [[0,1,3]] 79 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 80 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 81 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len) 82 | 83 | elif component == "op": 84 | #B*index 85 | gt_col = np.zeros(q_len.shape,dtype=np.int64) 86 | index = 0 87 | for i in range(st,ed): 88 | # print(i) 89 | gt_col[index] = data[perm[i]]["gt_col"] 90 | index += 1 91 | 92 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 93 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 94 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 95 | 96 | elif component == "agg": 97 | # [[0,1,3]] 98 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 99 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 100 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 101 | # print(ed) 102 | index = 0 103 | for i in range(st, ed): 104 | # print(i) 105 | gt_col[index] = data[perm[i]]["gt_col"] 106 | index += 1 107 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 108 | 109 | elif component == "root_tem": 110 | #B*0/1 111 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 112 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 113 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 114 | # print(ed) 115 | index = 0 116 | for i in range(st, ed): 117 | # print(data[perm[i]]["history"]) 118 | gt_col[index] = data[perm[i]]["gt_col"] 119 | index += 1 120 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 121 | 122 | elif component == "des_asc": 123 | # B*0/1 124 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 125 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 126 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 127 | # print(ed) 128 | index = 0 129 | for i in range(st, ed): 130 | # print(i) 131 | gt_col[index] = data[perm[i]]["gt_col"] 132 | index += 1 133 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 134 | 135 | elif component == 'having': 136 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 137 | col_emb_var, col_name_len, col_len = 
embed_layer.gen_col_batch(col_seq) 138 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 139 | # print(ed) 140 | index = 0 141 | for i in range(st, ed): 142 | # print(i) 143 | gt_col[index] = data[perm[i]]["gt_col"] 144 | index += 1 145 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 146 | 147 | elif component == "andor": 148 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len) 149 | # score = model.forward(q_seq, col_seq, col_num, pred_entry, 150 | # gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq) 151 | # print("label {}".format(label)) 152 | loss = model.loss(score, label) 153 | # print("loss {}".format(loss.data.cpu().numpy())) 154 | cum_loss += loss.data.cpu().numpy()[0]*(ed - st) 155 | optimizer.zero_grad() 156 | loss.backward() 157 | optimizer.step() 158 | 159 | st = ed 160 | 161 | return cum_loss / len(data) 162 | 163 | ## used for development evaluation in train.py 164 | def epoch_acc(model, batch_size, component, embed_layer,data, table_type, error_print=False, train_flag = False): 165 | model.eval() 166 | perm = list(range(len(data))) 167 | st = 0 168 | total_number_error = 0.0 169 | total_p_error = 0.0 170 | total_error = 0.0 171 | print("dev data size {}".format(len(data))) 172 | while st < len(data): 173 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 174 | 175 | q_seq, history, label = to_batch_seq(data, perm, st, ed) 176 | q_emb_var, q_len = embed_layer.gen_x_q_batch(q_seq) 177 | hs_emb_var, hs_len = embed_layer.gen_x_history_batch(history) 178 | score = 0.0 179 | 180 | if component == "multi_sql": 181 | #none, except, intersect,union 182 | #truth B*index(0,1,2,3) 183 | # print("hs_len:{}".format(hs_len)) 184 | # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size())) 185 | mkw_emb_var = embed_layer.gen_word_list_embedding(["none","except","intersect","union"],(ed-st)) 186 | mkw_len = np.full(q_len.shape, 4,dtype=np.int64) 187 | # print("mkw_emb:{}".format(mkw_emb_var.size())) 188 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var=mkw_emb_var, mkw_len=mkw_len) 189 | elif component == "keyword": 190 | #where group by order by 191 | # [[0,1,2]] 192 | kw_emb_var = embed_layer.gen_word_list_embedding(["where", "group by", "order by"],(ed-st)) 193 | mkw_len = np.full(q_len.shape, 3, dtype=np.int64) 194 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var=kw_emb_var, kw_len=mkw_len) 195 | elif component == "col": 196 | #col word embedding 197 | # [[0,1,3]] 198 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 199 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 200 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len) 201 | elif component == "op": 202 | #B*index 203 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 204 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 205 | gt_col = np.zeros(q_len.shape,dtype=np.int64) 206 | # print(ed) 207 | index = 0 208 | for i in range(st,ed): 209 | # print(i) 210 | gt_col[index] = data[perm[i]]["gt_col"] 211 | index += 1 212 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 213 | 214 | elif component == "agg": 215 | # [[0,1,3]] 216 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 217 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 218 | gt_col = 
np.zeros(q_len.shape, dtype=np.int64) 219 | # print(ed) 220 | index = 0 221 | for i in range(st, ed): 222 | # print(i) 223 | gt_col[index] = data[perm[i]]["gt_col"] 224 | index += 1 225 | 226 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 227 | 228 | elif component == "root_tem": 229 | #B*0/1 230 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 231 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 232 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 233 | # print(ed) 234 | index = 0 235 | for i in range(st, ed): 236 | # print(data[perm[i]]["history"]) 237 | gt_col[index] = data[perm[i]]["gt_col"] 238 | index += 1 239 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 240 | 241 | elif component == "des_asc": 242 | # B*0/1 243 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 244 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 245 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 246 | # print(ed) 247 | index = 0 248 | for i in range(st, ed): 249 | # print(i) 250 | gt_col[index] = data[perm[i]]["gt_col"] 251 | index += 1 252 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 253 | 254 | elif component == 'having': 255 | col_seq = to_batch_tables(data, perm, st, ed, table_type) 256 | col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch(col_seq) 257 | gt_col = np.zeros(q_len.shape, dtype=np.int64) 258 | # print(ed) 259 | index = 0 260 | for i in range(st, ed): 261 | # print(i) 262 | gt_col[index] = data[perm[i]]["gt_col"] 263 | index += 1 264 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col=gt_col) 265 | 266 | elif component == "andor": 267 | score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len) 268 | # print("label {}".format(label)) 269 | if component in ("agg","col","keyword","op"): 270 | num_err, p_err, err = model.check_acc(score, label) 271 | total_number_error += num_err 272 | total_p_error += p_err 273 | total_error += err 274 | else: 275 | err = model.check_acc(score, label) 276 | total_error += err 277 | st = ed 278 | 279 | if component in ("agg","col","keyword","op"): 280 | print("Dev {} acc number predict acc:{} partial acc: {} total acc: {}".format(component,1 - total_number_error*1.0/len(data),1 - total_p_error*1.0/len(data), 1 - total_error*1.0/len(data))) 281 | return 1 - total_error*1.0/len(data) 282 | else: 283 | print("Dev {} acc total acc: {}".format(component,1 - total_error*1.0/len(data))) 284 | return 1 - total_error*1.0/len(data) 285 | 286 | 287 | def timeout_handler(num, stack): 288 | print("Received SIGALRM") 289 | raise Exception("Timeout") 290 | 291 | ## used in test.py 292 | def test_acc(model, batch_size, data,output_path): 293 | table_dict = get_table_dict("./data/tables.json") 294 | f = open(output_path,"w") 295 | for item in data[:]: 296 | db_id = item["db_id"] 297 | if db_id not in table_dict: print "Error %s not in table_dict" % db_id 298 | # signal.signal(signal.SIGALRM, timeout_handler) 299 | # signal.alarm(2) # set timer to prevent infinite recursion in SQL generation 300 | sql = model.forward([item["question_toks"]]*batch_size,[],table_dict[db_id]) 301 | if sql is not None: 302 | print(sql) 303 | sql = model.gen_sql(sql,table_dict[db_id]) 304 | else: 305 | sql = "select a from b" 306 | print(sql) 307 | print("") 308 | 
f.write("{}\n".format(sql)) 309 | f.close() 310 | 311 | 312 | def load_word_emb(file_name, load_used=False, use_small=False): 313 | if not load_used: 314 | print ('Loading word embedding from %s'%file_name) 315 | ret = {} 316 | with open(file_name) as inf: 317 | for idx, line in enumerate(inf): 318 | if (use_small and idx >= 5000): 319 | break 320 | info = line.strip().split(' ') 321 | if info[0].lower() not in ret: 322 | ret[info[0]] = np.array(map(lambda x:float(x), info[1:])) 323 | return ret 324 | else: 325 | print ('Load used word embedding') 326 | with open('../alt/glove/word2idx.json') as inf: 327 | w2i = json.load(inf) 328 | with open('../alt/glove/usedwordemb.npy') as inf: 329 | word_emb_val = np.load(inf) 330 | return w2i, word_emb_val 331 | -------------------------------------------------------------------------------- /generate_wikisql_augment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import re 4 | import traceback 5 | import os 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 10 | cond_ops = ['=', '>', '<', 'OP'] 11 | 12 | random.seed(0) 13 | 14 | 15 | class Column: 16 | ATTRIBUTE_TXT = "TXT" 17 | ATTRIBUTE_NUM = "NUM" 18 | ATTRIBUTE_GROUP_BY_ABLE = "GROUPBY" 19 | 20 | def __init__(self, name, natural_name, table=None, attributes=None): 21 | self.name = name 22 | self.natural_name = natural_name 23 | self.table = table 24 | if attributes is not None: 25 | self.attributes = attributes 26 | 27 | def __str__(self): 28 | return self.name + "||" + self.natural_name + "||" + str(self.attributes) 29 | 30 | found_path_error = 0 31 | 32 | class Table(object): 33 | def __init__(self, name, natural_name): 34 | self.name = name 35 | self.natural_name = natural_name 36 | self.foreign_keys = [] 37 | 38 | def add_foreign_key_to(self, my_col, their_col, that_table): 39 | self.foreign_keys.append((my_col, their_col, that_table)) 40 | 41 | def get_foreign_keys(self): 42 | return self.foreign_keys 43 | 44 | def __str__(self): 45 | return self.name + "||" + self.natural_name 46 | 47 | def __repr__(self): 48 | return self.name + "||" + self.natural_name 49 | 50 | def __hash__(self): 51 | val = 0 52 | for c in self.name: 53 | val = val * 10 + ord(c) 54 | return val 55 | 56 | def __eq__(self, rhs): 57 | return self.name == rhs.name 58 | 59 | def __ne__(self, rhs): 60 | return not self.name == rhs.name 61 | # return self.name + "||" + self.natural_name 62 | 63 | # as the column "*" in the data format is marked to be not belonging to any table 64 | # so here's a dummy table for that :( 65 | class DummyTable(Table): 66 | def add_foreign_key_to(self, my_col, their_col, that_table): 67 | pass 68 | 69 | def get_foreign_keys(self): 70 | return [] 71 | 72 | 73 | # models a requirement for a column, will be "attached" to a real column that satisfies the given attribute criteria 74 | class ColumnPlaceholder: 75 | # e.g. 
{COLUMN,2,TXT} 76 | def __init__(self, id_in_pattern, attributes): 77 | self.id_in_pattern = id_in_pattern 78 | self.attributes = attributes 79 | self.column = None 80 | 81 | def attach_to_column(self, column): 82 | self.column = column 83 | 84 | 85 | # modelling a SQL pattern along with a bunch of question patterns 86 | class Pattern: 87 | def __init__(self, schema, json_data): 88 | self.schema = schema 89 | self.raw_sql = json_data['SQL Pattern'] 90 | self.raw_questions = json_data['Question Patterns'] 91 | reference_id_to_original_id = json_data['Column Identity'] 92 | self.column_identity = {} 93 | 94 | for reference, original in reference_id_to_original_id.items(): 95 | rid = int(reference) 96 | oid = int(original) 97 | 98 | self.column_identity[rid] = oid 99 | 100 | raw_column_attributes = json_data['Column Attributes'] 101 | sorted_column_attributes = sorted( 102 | [(int(column_id), attributes) for column_id, attributes in raw_column_attributes.items()]) 103 | 104 | self.column_id_to_column_placeholders = {} 105 | self.column_placeholders = [] 106 | 107 | for column_id, attributes in sorted_column_attributes: 108 | # see if this references another column 109 | original_column_id = self.column_identity.get(column_id, None) 110 | if original_column_id is not None: 111 | self.column_id_to_column_placeholders[column_id] = self.column_id_to_column_placeholders[ 112 | original_column_id] 113 | continue 114 | 115 | # if this does not reference an existing column 116 | column_placeholder = ColumnPlaceholder(column_id, attributes) 117 | self.column_placeholders.append(column_placeholder) 118 | self.column_id_to_column_placeholders[column_id] = column_placeholder 119 | 120 | # given this pattern and a schema, see what new SQL-question pairs can we generate 121 | def populate(self): 122 | if self.raw_sql == "SELECT * {FROM, 0}": 123 | table_name = random.choice(self.schema.orginal_table) 124 | sql = "SELECT * FROM {}".format(table_name) 125 | return sql,[ 126 | "list all information about {} .".format(table_name), 127 | "Show everything on {}".format(table_name), 128 | "Return all columns in {} .".format(table_name) 129 | ] 130 | # find a column for each placeholder 131 | for column_placeholder in self.column_placeholders: 132 | all_permissible_columns = self.schema.get_columns_with_attributes(column_placeholder.attributes) 133 | if len(all_permissible_columns) == 0: 134 | raise Exception("No possible column found for column {} with required attributes: {}".format( 135 | column_placeholder.id_in_pattern, 136 | column_placeholder.attributes 137 | )) 138 | chosen_column = random.choice(all_permissible_columns) 139 | column_placeholder.attach_to_column(chosen_column) 140 | 141 | column_id_to_tn = {} 142 | 143 | ## generate processed SQL 144 | # start with the original (and replace stuff) 145 | generated_sql = self.raw_sql[:] 146 | 147 | # first identify the FROM replacement tokens 148 | replacements = [] 149 | for match in re.finditer("{FROM,[,0-9]+}", self.raw_sql): 150 | raw_from_token = match.group() 151 | split = raw_from_token[1:-1].split(',')[1:] # strip the brackets, then the "FROM" 152 | id_of_columns_involved = [int(x) for x in split] 153 | # print(id_of_columns_involved) 154 | # print(self.column_id_to_column_placeholders) 155 | # print(self.raw_sql) 156 | placeholders_of_columns_involved = [self.column_id_to_column_placeholders[x] for x in id_of_columns_involved] 157 | columns_used_for_this_from_clause = [x.column for x in placeholders_of_columns_involved] 158 | try: 159 | from_clause, 
table_to_tn = self.schema.generate_from_clause(columns_used_for_this_from_clause) 160 | except: 161 | # traceback.print_exc() 162 | # print("error generated join") 163 | # continue 164 | return "",[] 165 | # replace this {FROM..} with the generated FROM clause 166 | replacements.append((raw_from_token, from_clause)) 167 | 168 | # add the table_to_tn to our column_id to tn dict 169 | for column_id in id_of_columns_involved: 170 | column = self.column_id_to_column_placeholders[column_id].column 171 | try: 172 | tn = table_to_tn[column.table] 173 | except: 174 | global found_path_error 175 | found_path_error += 1 176 | # print("find path error {}".format(found_path_error)) 177 | # print "\n-----------------------" 178 | # print column 179 | # print column.table 180 | # print table_to_tn 181 | return "",[] 182 | # print column_id 183 | column_id_to_tn[column_id] = tn 184 | # print("column_identity:{}".format(self.column_identity)) 185 | # print("sql template:{}".format(generated_sql)) 186 | # print("column_id_to_tn {}".format(column_id_to_tn)) 187 | 188 | for original, new in replacements: 189 | generated_sql = re.sub(original, new, generated_sql) 190 | 191 | # then replace the column tokens 192 | replacements = [] 193 | val = None 194 | table_name = None 195 | # if self.raw_sql == "SELECT * {FROM, 0}": 196 | # print generated_sql 197 | for match in re.finditer("{[A-Z]+,[,0-9]+}", generated_sql): 198 | raw_column_token = match.group() 199 | type, column_id = raw_column_token[1:-1].split(',') 200 | column_id = int(column_id) 201 | 202 | if type == "COLUMN": 203 | # find out tn 204 | if column_id not in column_id_to_tn: 205 | column_id = self.column_identity[column_id] 206 | tn = column_id_to_tn[column_id] 207 | # find out column name 208 | column_name = self.column_id_to_column_placeholders[column_id].column.name 209 | result = "t{}.{}".format(tn, column_name) 210 | elif type == "VALUE": 211 | if column_id == 1: 212 | result = str(random.randint(1,101)) 213 | val = result 214 | elif type == "COLUMN_NAME": 215 | natural_name = self.column_id_to_column_placeholders[column_id].column.natural_name 216 | result = natural_name 217 | elif type == "TABLE_NAME": 218 | try: 219 | natural_name = self.column_id_to_column_placeholders[column_id].column.table.natural_name 220 | result = natural_name 221 | except: 222 | result = random.choice(self.schema.orginal_table) 223 | table_name = result 224 | else: 225 | raise Exception("Unknown type {} in type field".format(type)) 226 | 227 | replacements.append((raw_column_token, result)) 228 | 229 | for original, new in replacements: 230 | # print(original,new,generated_sql) 231 | generated_sql = re.sub(original, new, generated_sql) 232 | 233 | # up to this point, SQL processing is complete 234 | ## start processing questions 235 | generated_questions = [] 236 | for question_pattern in self.raw_questions: 237 | generated_question = question_pattern[:] 238 | replacements = [] 239 | for match in re.finditer("{[_A-Z]+,[0-9]+}", generated_question): 240 | raw_column_token = match.group() 241 | type, column_id = raw_column_token[1:-1].split(',') 242 | column_id = int(column_id) 243 | 244 | if type == "COLUMN": 245 | # find out tn 246 | tn = column_id_to_tn[column_id] 247 | # find out column name 248 | column_name = self.column_id_to_column_placeholders[column_id].column.name 249 | result = "t{}.{}".format(tn, column_name) 250 | elif type == "VALUE": 251 | result = val 252 | elif type == "COLUMN_NAME": 253 | natural_name = 
self.column_id_to_column_placeholders[column_id].column.natural_name 254 | result = natural_name 255 | elif type == "TABLE_NAME": 256 | try: 257 | natural_name = self.column_id_to_column_placeholders[column_id].column.table.natural_name 258 | result = natural_name 259 | except: 260 | if table_name: 261 | result = table_name 262 | else: 263 | result = random.choice(self.schema.orginal_table) 264 | else: 265 | raise Exception("Unknown type {} in type field".format(type)) 266 | 267 | replacements.append((raw_column_token, result)) 268 | 269 | for original, new in replacements: 270 | generated_question = re.sub(original, new, generated_question) 271 | 272 | generated_questions.append(generated_question) 273 | 274 | return generated_sql, generated_questions 275 | 276 | 277 | class Schema: 278 | def __init__(self, json_data): 279 | tables = [] 280 | table_index_to_table_object = {} 281 | table_name_to_table_object = {} 282 | next_table_index = 0 283 | self.orginal_table = json_data['table_names_original'] 284 | # dummy_table = DummyTable("dummy", "dummy") 285 | # table_index_to_table_object[-1] = dummy_table 286 | # tables.append(dummy_table) 287 | 288 | for table_name, table_name_natural in zip(json_data['table_names_original'], json_data['table_names']): 289 | table = Table(table_name, table_name_natural) 290 | tables.append(table) 291 | table_index_to_table_object[next_table_index] = table 292 | table_name_to_table_object[table_name] = table 293 | next_table_index += 1 294 | columns = [] 295 | column_and_table_name_to_column_object = {} # use table name as well to avoid collision 296 | for (table_index, column_name), column_type, column_names_natural in zip(json_data['column_names_original'], 297 | json_data['column_types'], 298 | json_data['column_names']): 299 | if table_index == -1: 300 | continue 301 | its_table = table_index_to_table_object[table_index] 302 | if column_type == "text": 303 | attributes = [Column.ATTRIBUTE_TXT] 304 | elif column_type == "number": 305 | attributes = [Column.ATTRIBUTE_NUM] 306 | else: 307 | attributes = [] 308 | column = Column(column_name, column_names_natural[1], table=its_table, attributes=attributes) 309 | column_and_table_name_to_column_object[(column_name, its_table.name)] = column 310 | columns.append(column) 311 | # print table_name_to_table_object 312 | for (from_table_name, from_column_name), (to_table_name, to_column_name) in json_data['foreign_keys']: 313 | from_table = table_name_to_table_object[from_table_name] 314 | from_column = column_and_table_name_to_column_object[(from_column_name, from_table_name)] 315 | to_table = table_name_to_table_object[to_table_name] 316 | to_column = column_and_table_name_to_column_object[(to_column_name, to_table_name)] 317 | 318 | from_table.add_foreign_key_to(from_column, to_column, to_table) 319 | to_table.add_foreign_key_to(to_column, from_column, from_table) 320 | 321 | self.all_columns = columns 322 | self.all_tables = tables 323 | 324 | # e.g. 
get all the numerical columns that can be group-by'ed over 325 | def get_columns_with_attributes(self, column_attributes=[]): 326 | results = [] 327 | for column in self.all_columns: 328 | # if the column has all the desired attributes 329 | if all([attribute in column.attributes for attribute in column_attributes]): 330 | results.append(column) 331 | 332 | return results 333 | 334 | class Join: 335 | def __init__(self, schema, starting_table): 336 | self.schema = schema 337 | self.starting_table = starting_table 338 | self.table_to_tn = {starting_table: 1} 339 | self.joins = [] 340 | 341 | def find_a_way_to_join(self, table): 342 | # if this table is already in our join 343 | if table in self.table_to_tn: 344 | return 345 | 346 | # BFS 347 | frontier = [] 348 | visited_tables = set() 349 | found_path = None 350 | for table in self.table_to_tn.keys(): 351 | visited_tables.add(table) 352 | for from_column, to_column, to_table in table.get_foreign_keys(): 353 | frontier.append((table, from_column, to_column, to_table, [])) 354 | while len(frontier) > 0: 355 | from_table, from_column, to_column, to_table, path = frontier.pop(0) 356 | # check if this foreign keys connects to the destination 357 | path.append((from_table, from_column, to_column, to_table)) 358 | if to_table == table: 359 | found_path = path 360 | break 361 | else: 362 | for next_from_column, next_to_column, next_to_table in to_table.get_foreign_keys(): 363 | frontier.append((to_table, next_from_column, next_to_column, next_to_table, path)) 364 | 365 | if found_path is None: 366 | # if a path is not found 367 | raise Exception( 368 | "A path could not be found from the current join {} to table {}".format(self.table_to_tn.keys(), 369 | table)) 370 | 371 | for from_table, from_column, to_column, to_table in found_path: 372 | # allocate a number like "t3" for the next table if necessary 373 | if to_table not in self.table_to_tn: 374 | self.table_to_tn[to_table] = len(self.table_to_tn) + 1 375 | self.joins.append((from_table, from_column, to_column, to_table)) 376 | 377 | def generate_from_clause(self): 378 | # if no join was needed (only one table) 379 | if len(self.joins) == 0: 380 | return "from {} as t1".format(self.starting_table.name) 381 | 382 | from_clause = "from {} as t{} ".format(self.joins[0][0].name, self.table_to_tn[self.joins[0][0]]) 383 | for from_table, from_column, to_column, to_table in self.joins[1:]: 384 | from_clause += ("join {} as t{}\non t{}.{} = t{}.{}".format( 385 | to_table.name, 386 | self.table_to_tn[to_table], 387 | self.table_to_tn[from_table], 388 | from_column.name, 389 | self.table_to_tn[to_table], 390 | to_column.name 391 | )) 392 | 393 | return from_clause 394 | 395 | # e.g. I used, doc_name and user_id, how should I write a from clause? 
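    # (illustrative sketch, hypothetical schema): for a single table "singer",
    # generate_from_clause returns ("from singer as t1", {<Table singer>: 1});
    # when the columns span several tables, the Join helper above walks foreign
    # keys (BFS) to bring each table into the join before numbering it.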
396 | # not only returning the from clause constructed, but also the mapping from doc_id to t1.doc_id 397 | def generate_from_clause(self, columns): 398 | join = self.Join(self, columns[0].table) 399 | for next_column in columns[1:]: 400 | join.find_a_way_to_join(next_column.table) 401 | 402 | return join.generate_from_clause(), join.table_to_tn 403 | 404 | 405 | def load_database_schema(path): 406 | data = json.load(open(path, "r")) 407 | schema = Schema(random.choice(data)) 408 | 409 | return schema 410 | 411 | 412 | def load_patterns(path, schema): 413 | data = json.load(open(path, "r")) 414 | patterns = [] 415 | for pattern_data in data: 416 | patterns.append(Pattern(schema, pattern_data)) 417 | 418 | return patterns 419 | 420 | def generate_every_db(db): 421 | db_name = db["db_id"] 422 | col_types = db["column_types"] 423 | if "number" in col_types: 424 | try: 425 | schema = Schema(db) 426 | except: 427 | traceback.print_exc() 428 | print("skip db {}".format(db_name)) 429 | return 430 | f = open("data_augment/{}.txt".format(db_name),"w") 431 | 432 | 433 | idx = 0 434 | patterns = load_patterns("data_augment/train_patterns.json", schema) 435 | 436 | while idx < 10: 437 | pattern = random.choice(patterns) 438 | try: 439 | sql, questions = pattern.populate() 440 | #for q in questions: 441 | if len(questions) != 0: 442 | f.write("{}. {}\n".format(1,random.choice(questions).encode("utf8"))) 443 | f.write("P:\n\n") 444 | f.write("{}\n\n".format(sql.encode("utf8"))) 445 | idx += 1 446 | except: 447 | pass 448 | f.close() 449 | 450 | # for pattern in patterns: 451 | # try: 452 | # sql, questions = pattern.populate() 453 | # except: 454 | # continue 455 | # # for q in questions: 456 | # if len(questions) == 0: 457 | # continue 458 | # f.write("{}. {}\n".format(idx,random.choice(questions))) 459 | # f.write("P:\n\n") 460 | # f.write("{}\n\n".format(sql)) 461 | # idx += 1 462 | 463 | if __name__ == "__main__": 464 | dbs = json.load(open("data_augment/wikisql_tables.json")) 465 | count = 0 466 | for db in dbs[:]: 467 | if count % 1000 == 0: 468 | print("processed {} files...".format(float(count)/len(dbs))) 469 | generate_every_db(db) 470 | count += 1 471 | -------------------------------------------------------------------------------- /process_sql.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Assumptions: 3 | # 1. sql is correct 4 | # 2. only table name has alias 5 | # 3. only one intersect/union/except 6 | # 7 | # val: number(float)/string(str)/sql(dict) 8 | # col_unit: (agg_id, col_id, isDistinct(bool)) 9 | # val_unit: (unit_op, col_unit1, col_unit2) 10 | # table_unit: (table_type, col_unit/sql) 11 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 13 | # sql { 14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 16 | # 'where': condition 17 | # 'groupBy': [col_unit1, col_unit2, ...] 
18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 19 | # 'having': condition 20 | # 'limit': None/limit value 21 | # 'intersect': None/sql 22 | # 'except': None/sql 23 | # 'union': None/sql 24 | # } 25 | ################################ 26 | 27 | import json 28 | import sqlite3 29 | from nltk import word_tokenize 30 | 31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 32 | JOIN_KEYWORDS = ('join', 'on', 'as') 33 | 34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 35 | UNIT_OPS = ('none', '-', '+', "*", '/') 36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 37 | TABLE_TYPE = { 38 | 'sql': "sql", 39 | 'table_unit': "table_unit", 40 | } 41 | 42 | COND_OPS = ('and', 'or') 43 | SQL_OPS = ('intersect', 'union', 'except') 44 | ORDER_OPS = ('desc', 'asc') 45 | 46 | 47 | 48 | class Schema: 49 | """ 50 | Simple schema which maps table&column to a unique identifier 51 | """ 52 | def __init__(self, schema): 53 | self._schema = schema 54 | self._idMap = self._map(self._schema) 55 | 56 | @property 57 | def schema(self): 58 | return self._schema 59 | 60 | @property 61 | def idMap(self): 62 | return self._idMap 63 | 64 | def _map(self, schema): 65 | idMap = {'*': "__all__"} 66 | id = 1 67 | for key, vals in schema.iteritems(): 68 | for val in vals: 69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" 70 | id += 1 71 | 72 | for key in schema: 73 | idMap[key.lower()] = "__" + key.lower() + "__" 74 | id += 1 75 | 76 | return idMap 77 | 78 | 79 | def get_schema(db): 80 | """ 81 | Get database's schema, which is a dict with table name as key 82 | and list of column names as value 83 | :param db: database path 84 | :return: schema dict 85 | """ 86 | 87 | schema = {} 88 | conn = sqlite3.connect(db) 89 | cursor = conn.cursor() 90 | 91 | # fetch table names 92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 93 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 94 | 95 | # fetch table info 96 | for table in tables: 97 | cursor.execute("PRAGMA table_info({})".format(table)) 98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 99 | 100 | return schema 101 | 102 | 103 | def get_schema_from_json(fpath): 104 | with open(fpath) as f: 105 | data = json.load(f) 106 | 107 | schema = {} 108 | for entry in data: 109 | table = str(entry['table'].lower()) 110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']] 111 | schema[table] = cols 112 | 113 | return schema 114 | 115 | 116 | def tokenize(string): 117 | string = str(string) 118 | string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 
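    # Example (hypothetical query):
    #   tokenize('SELECT name FROM singer WHERE age != 20')
    #   -> ['select', 'name', 'from', 'singer', 'where', 'age', '!=', '20']
    # Quoted string values are kept as single tokens via the placeholder scheme below.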
119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 121 | 122 | # keep string value as token 123 | vals = {} 124 | for i in range(len(quote_idxs)-1, -1, -2): 125 | qidx1 = quote_idxs[i-1] 126 | qidx2 = quote_idxs[i] 127 | val = string[qidx1: qidx2+1] 128 | key = "__val_{}_{}__".format(qidx1, qidx2) 129 | string = string[:qidx1] + key + string[qidx2+1:] 130 | vals[key] = val 131 | 132 | toks = [word.lower() for word in word_tokenize(string)] 133 | # replace with string value token 134 | for i in range(len(toks)): 135 | if toks[i] in vals: 136 | toks[i] = vals[toks[i]] 137 | 138 | # find if there exists !=, >=, <= 139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 140 | eq_idxs.reverse() 141 | prefix = ('!', '>', '<') 142 | for eq_idx in eq_idxs: 143 | pre_tok = toks[eq_idx-1] 144 | if pre_tok in prefix: 145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 146 | 147 | return toks 148 | 149 | 150 | def scan_alias(toks): 151 | """Scan the index of 'as' and build the map for all alias""" 152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 153 | alias = {} 154 | for idx in as_idxs: 155 | alias[toks[idx+1]] = toks[idx-1] 156 | return alias 157 | 158 | 159 | def get_tables_with_alias(schema, toks): 160 | tables = scan_alias(toks) 161 | for key in schema: 162 | assert key not in tables, "Alias {} has the same name in table".format(key) 163 | tables[key] = key 164 | return tables 165 | 166 | 167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 168 | """ 169 | :returns next idx, column id 170 | """ 171 | tok = toks[start_idx] 172 | if tok == "*": 173 | return start_idx + 1, schema.idMap[tok] 174 | 175 | if '.' in tok: # if token is a composite 176 | alias, col = tok.split('.') 177 | key = tables_with_alias[alias] + "." + col 178 | return start_idx+1, schema.idMap[key] 179 | 180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 181 | 182 | for alias in default_tables: 183 | table = tables_with_alias[alias] 184 | if tok in schema.schema[table]: 185 | key = table + "." 
+ tok 186 | return start_idx+1, schema.idMap[key] 187 | 188 | assert False, "Error col: {}".format(tok) 189 | 190 | 191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 192 | """ 193 | :returns next idx, (agg_op id, col_id) 194 | """ 195 | idx = start_idx 196 | len_ = len(toks) 197 | isBlock = False 198 | isDistinct = False 199 | if toks[idx] == '(': 200 | isBlock = True 201 | idx += 1 202 | 203 | if toks[idx] in AGG_OPS: 204 | agg_id = AGG_OPS.index(toks[idx]) 205 | idx += 1 206 | assert idx < len_ and toks[idx] == '(' 207 | idx += 1 208 | if toks[idx] == "distinct": 209 | idx += 1 210 | isDistinct = True 211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 212 | assert idx < len_ and toks[idx] == ')' 213 | idx += 1 214 | return idx, (agg_id, col_id, isDistinct) 215 | 216 | if toks[idx] == "distinct": 217 | idx += 1 218 | isDistinct = True 219 | agg_id = AGG_OPS.index("none") 220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 221 | 222 | if isBlock: 223 | assert toks[idx] == ')' 224 | idx += 1 # skip ')' 225 | 226 | return idx, (agg_id, col_id, isDistinct) 227 | 228 | 229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 230 | idx = start_idx 231 | len_ = len(toks) 232 | isBlock = False 233 | if toks[idx] == '(': 234 | isBlock = True 235 | idx += 1 236 | 237 | col_unit1 = None 238 | col_unit2 = None 239 | unit_op = UNIT_OPS.index('none') 240 | 241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 242 | if idx < len_ and toks[idx] in UNIT_OPS: 243 | unit_op = UNIT_OPS.index(toks[idx]) 244 | idx += 1 245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 246 | 247 | if isBlock: 248 | assert toks[idx] == ')' 249 | idx += 1 # skip ')' 250 | 251 | return idx, (unit_op, col_unit1, col_unit2) 252 | 253 | 254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 255 | """ 256 | :returns next idx, table id, table name 257 | """ 258 | idx = start_idx 259 | len_ = len(toks) 260 | key = tables_with_alias[toks[idx]] 261 | 262 | if idx + 1 < len_ and toks[idx+1] == "as": 263 | idx += 3 264 | else: 265 | idx += 1 266 | 267 | return idx, schema.idMap[key], key 268 | 269 | 270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 271 | idx = start_idx 272 | len_ = len(toks) 273 | 274 | isBlock = False 275 | if toks[idx] == '(': 276 | isBlock = True 277 | idx += 1 278 | 279 | if toks[idx] == 'select': 280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 281 | elif "\"" in toks[idx]: # token is a string value 282 | val = toks[idx] 283 | idx += 1 284 | else: 285 | try: 286 | val = float(toks[idx]) 287 | idx += 1 288 | except: 289 | end_idx = idx 290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS: 292 | end_idx += 1 293 | 294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 295 | idx = end_idx 296 | 297 | if isBlock: 298 | assert toks[idx] == ')' 299 | idx += 1 300 | 301 | return idx, val 302 | # def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 303 | # idx = start_idx 304 | # len_ = len(toks) 305 | # 306 | # isBlock = False 307 | # if toks[idx] == '(': 308 | # isBlock = True 309 | # idx += 1 310 | # 311 | # if toks[idx] == 'select': 312 | # idx, val = 
parse_sql(toks, idx, tables_with_alias, schema) 313 | # elif "\"" in toks[idx]: # token is a string value 314 | # val = toks[idx] 315 | # idx += 1 316 | # else: 317 | # end_idx = idx 318 | # while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 319 | # and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS: 320 | # end_idx += 1 321 | # 322 | # tok = "".join(toks[idx: end_idx]) 323 | # val = tok 324 | # 325 | # try: 326 | # idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 327 | # except: 328 | # # print "Value is not a column" 329 | # try: 330 | # val = float(val) 331 | # except: 332 | # pass 333 | # # print "Value is not a number" 334 | # idx = end_idx 335 | # 336 | # if isBlock: 337 | # assert toks[idx] == ')' 338 | # idx += 1 339 | # 340 | # return idx, val 341 | 342 | 343 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 344 | idx = start_idx 345 | len_ = len(toks) 346 | conds = [] 347 | 348 | while idx < len_: 349 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 350 | not_op = False 351 | if toks[idx] == 'not': 352 | not_op = True 353 | idx += 1 354 | 355 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 356 | op_id = WHERE_OPS.index(toks[idx]) 357 | idx += 1 358 | val1 = val2 = None 359 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 360 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 361 | assert toks[idx] == 'and' 362 | idx += 1 363 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 364 | else: # normal case: single value 365 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 366 | val2 = None 367 | 368 | conds.append((not_op, op_id, val_unit, val1, val2)) 369 | 370 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 371 | break 372 | 373 | if idx < len_ and toks[idx] in COND_OPS: 374 | conds.append(toks[idx]) 375 | idx += 1 # skip and/or 376 | 377 | return idx, conds 378 | 379 | 380 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 381 | idx = start_idx 382 | len_ = len(toks) 383 | 384 | assert toks[idx] == 'select', "'select' not found" 385 | idx += 1 386 | isDistinct = False 387 | if idx < len_ and toks[idx] == 'distinct': 388 | idx += 1 389 | isDistinct = True 390 | val_units = [] 391 | 392 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 393 | agg_id = AGG_OPS.index("none") 394 | if toks[idx] in AGG_OPS: 395 | agg_id = AGG_OPS.index(toks[idx]) 396 | idx += 1 397 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 398 | val_units.append((agg_id, val_unit)) 399 | if idx < len_ and toks[idx] == ',': 400 | idx += 1 # skip ',' 401 | 402 | return idx, (isDistinct, val_units) 403 | 404 | 405 | def parse_from(toks, start_idx, tables_with_alias, schema): 406 | """ 407 | Assume in the from clause, all table units are combined with join 408 | """ 409 | assert 'from' in toks[start_idx:], "'from' not found" 410 | 411 | len_ = len(toks) 412 | idx = toks.index('from', start_idx) + 1 413 | default_tables = [] 414 | table_units = [] 415 | conds = [] 416 | 417 | while idx < len_: 418 | isBlock = False 419 | if toks[idx] == '(': 420 | isBlock = True 421 | idx += 1 422 | 423 | if toks[idx] == 'select': 424 | idx, sql = parse_sql(toks, idx, tables_with_alias, 
schema) 425 | table_units.append((TABLE_TYPE['sql'], sql)) 426 | else: 427 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 428 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 429 | default_tables.append(table_name) 430 | if idx < len_ and toks[idx] == 'join': 431 | idx += 1 # skip join 432 | if idx < len_ and toks[idx] == "on": 433 | idx += 1 # skip on 434 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 435 | conds.extend(this_conds) 436 | 437 | if isBlock: 438 | assert toks[idx] == ')' 439 | idx += 1 440 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 441 | break 442 | 443 | return idx, table_units, conds, default_tables 444 | 445 | 446 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 447 | idx = start_idx 448 | len_ = len(toks) 449 | 450 | if idx >= len_ or toks[idx] != 'where': 451 | return idx, [] 452 | 453 | idx += 1 454 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 455 | return idx, conds 456 | 457 | 458 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 459 | idx = start_idx 460 | len_ = len(toks) 461 | col_units = [] 462 | 463 | if idx >= len_ or toks[idx] != 'group': 464 | return idx, col_units 465 | 466 | idx += 1 467 | assert toks[idx] == 'by' 468 | idx += 1 469 | 470 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 471 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 472 | col_units.append(col_unit) 473 | if idx < len_ and toks[idx] == ',': 474 | idx += 1 # skip ',' 475 | else: 476 | break 477 | 478 | return idx, col_units 479 | 480 | 481 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 482 | idx = start_idx 483 | len_ = len(toks) 484 | val_units = [] 485 | order_type = 'asc' # default type is 'asc' 486 | 487 | if idx >= len_ or toks[idx] != 'order': 488 | return idx, val_units 489 | 490 | idx += 1 491 | assert toks[idx] == 'by' 492 | idx += 1 493 | 494 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 495 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 496 | val_units.append(val_unit) 497 | if idx < len_ and toks[idx] in ORDER_OPS: 498 | order_type = toks[idx] 499 | idx += 1 500 | if idx < len_ and toks[idx] == ',': 501 | idx += 1 # skip ',' 502 | else: 503 | break 504 | 505 | return idx, (order_type, val_units) 506 | 507 | 508 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 509 | idx = start_idx 510 | len_ = len(toks) 511 | 512 | if idx >= len_ or toks[idx] != 'having': 513 | return idx, [] 514 | 515 | idx += 1 516 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 517 | return idx, conds 518 | 519 | 520 | def parse_limit(toks, start_idx): 521 | idx = start_idx 522 | len_ = len(toks) 523 | 524 | if idx < len_ and toks[idx] == 'limit': 525 | idx += 2 526 | return idx, int(toks[idx-1]) 527 | 528 | return idx, None 529 | 530 | 531 | def parse_sql(toks, start_idx, tables_with_alias, schema): 532 | isBlock = False # indicate whether this is a block of sql/sub-sql 533 | len_ = len(toks) 534 | idx = start_idx 535 | 536 | sql = {} 537 | if toks[idx] == '(': 538 | isBlock = True 539 | idx += 1 540 | 541 | # parse from clause in order to get default tables 542 | from_end_idx, table_units, conds, default_tables = 
parse_from(toks, start_idx, tables_with_alias, schema) 543 | sql['from'] = {'table_units': table_units, 'conds': conds} 544 | # select clause 545 | _, select_col_units = parse_select(toks, start_idx, tables_with_alias, schema, default_tables) 546 | idx = from_end_idx 547 | sql['select'] = select_col_units 548 | # where clause 549 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 550 | sql['where'] = where_conds 551 | # group by clause 552 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) 553 | sql['groupBy'] = group_col_units 554 | # order by clause 555 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) 556 | sql['orderBy'] = order_col_units 557 | # having clause 558 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) 559 | sql['having'] = having_conds 560 | # limit clause 561 | idx, limit_val = parse_limit(toks, idx) 562 | sql['limit'] = limit_val 563 | 564 | if isBlock: 565 | assert toks[idx] == ')' 566 | idx += 1 # skip ')' 567 | 568 | # intersect/union/except clause 569 | for op in SQL_OPS: # initialize IUE 570 | sql[op] = None 571 | if idx < len_ and toks[idx] in SQL_OPS: 572 | sql_op = toks[idx] 573 | idx += 1 574 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 575 | sql[sql_op] = IUE_sql 576 | return idx, sql 577 | 578 | 579 | def load_data(fpath): 580 | with open(fpath) as f: 581 | data = json.load(f) 582 | return data 583 | 584 | 585 | def get_sql(schema, query): 586 | toks = tokenize(query) 587 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 588 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 589 | 590 | return sql 591 | 592 | if __name__ == '__main__': 593 | # print get_schema('art_1.sqlite') 594 | # fpath = '/Users/zilinzhang/Workspace/Github/nl2sql/Data/Initial/table/art_1_table.json' 595 | # print schema 596 | 597 | # schema = Schema(get_schema('art_1.sqlite')) 598 | # print schema.schema 599 | schema = {"paragraphs": ["paragraph_text","paragraph_id", "document_id"], "documents": ["document_id", "document_name"]} 600 | schema = Schema(schema) 601 | # print schema.idMap 602 | data = ["test1"] 603 | # data = load_data("/Users/zilinzhang/Workspace/Github/nl2sql/Data/Processed/train/art_1_processed.json") 604 | for ix, entry in enumerate(data): 605 | # query = entry["query"] 606 | # query = "SELECT template_id FROM Templates WHERE template_type_code = \"PP\" OR template_type_code = \"PPT\"" 607 | # query = "SELECT count(*) FROM Paragraphs AS T1 JOIN Documents AS T2 ON T1.document_ID = T2.document_ID WHERE T2.document_name = 'Summer Show'" 608 | query = "SELECT T1.paragraph_id , T1.paragraph_text FROM Paragraphs AS T1 JOIN Documents AS T2 ON T1.document_id = T2.document_id WHERE T2.Document_Name = 'Welcome to NY'" 609 | toks = tokenize(query) 610 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 611 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 612 | print sql 613 | break 614 | -------------------------------------------------------------------------------- /preprocess_train_dev_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | python3 preprocess_train_dev_data.py train|dev (full|part) 3 | ''' 4 | 5 | import json 6 | import sys 7 | from collections import defaultdict 8 | 9 | ###TODO: change dirs 10 | train_data_path = "./data/train.json" 11 | table_data_path = "./data/tables.json" 12 | if train_dev == 
"dev": 13 | train_data_path = "./data/dev.json" 14 | 15 | train_dev = "train" 16 | if len(sys.argv) > 1: 17 | train_dev = sys.argv[1] 18 | train_data = json.load(open(train_data_path)) 19 | history_option = "full" 20 | if len(sys.argv) > 2: 21 | history_option = sys.argv[2] 22 | 23 | OLD_WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 24 | NEW_WHERE_OPS = ('=','>','<','>=','<=','!=','like','not in','in','between','is') 25 | NEW_WHERE_DICT = { 26 | '=': 0, 27 | '>': 1, 28 | '<': 2, 29 | '>=': 3, 30 | '<=': 4, 31 | '!=': 5, 32 | 'like': 6, 33 | 'not in': 7, 34 | 'in': 8, 35 | 'between': 9, 36 | 'is':10 37 | } 38 | # SQL_OPS = ('none','intersect', 'union', 'except') 39 | SQL_OPS = { 40 | 'none': 0, 41 | 'intersect': 1, 42 | 'union': 2, 43 | 'except': 3 44 | } 45 | KW_DICT = { 46 | 'where': 0, 47 | 'groupBy': 1, 48 | 'orderBy': 2 49 | } 50 | ORDER_OPS = { 51 | 'desc': 0, 52 | 'asc': 1} 53 | AGG_OPS = ('none','max', 'min', 'count', 'sum', 'avg') 54 | 55 | COND_OPS = { 56 | 'and':0, 57 | 'or':1 58 | } 59 | 60 | def convert_to_op_index(is_not,op): 61 | op = OLD_WHERE_OPS[op] 62 | if is_not and op == "in": 63 | return 7 64 | try: 65 | return NEW_WHERE_DICT[op] 66 | except: 67 | print("Unsupport op: {}".format(op)) 68 | return -1 69 | 70 | def index_to_column_name(index, table): 71 | column_name = table["column_names"][index][1] 72 | table_index = table["column_names"][index][0] 73 | table_name = table["table_names"][table_index] 74 | return table_name, column_name, index 75 | 76 | 77 | def get_label_cols(with_join,fk_dict,labels): 78 | # list(set([l[1][i][0][2] for i in range(min(len(l[1]), 3))])) 79 | cols = set() 80 | ret = [] 81 | for i in range(len(labels)): 82 | cols.add(labels[i][0][2]) 83 | if len(cols) > 3: 84 | break 85 | for col in cols: 86 | # ret.append([col]) 87 | if with_join and len(fk_dict[col]) > 0: 88 | ret.append([col]+fk_dict[col]) 89 | else: 90 | ret.append(col) 91 | return ret 92 | 93 | class MultiSqlPredictor: 94 | def __init__(self, question, sql, history): 95 | self.sql = sql 96 | self.question = question 97 | self.history = history 98 | self.keywords = ('intersect', 'except', 'union') 99 | 100 | def generate_output(self): 101 | for key in self.sql: 102 | if key in self.keywords and self.sql[key]: 103 | return self.history + ['root'], key, self.sql[key] 104 | return self.history + ['root'], 'none', self.sql 105 | 106 | 107 | class KeyWordPredictor: 108 | def __init__(self, question, sql, history): 109 | self.sql = sql 110 | self.question = question 111 | self.history = history 112 | self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'limit', 'having') 113 | 114 | def generate_output(self): 115 | sql_keywords = [] 116 | for key in self.sql: 117 | if key in self.keywords and self.sql[key]: 118 | sql_keywords.append(key) 119 | return self.history, [len(sql_keywords), sql_keywords], self.sql 120 | 121 | 122 | class ColPredictor: 123 | def __init__(self, question, sql, table, history,kw=None): 124 | self.sql = sql 125 | self.question = question 126 | self.history = history 127 | self.table = table 128 | self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having') 129 | self.kw = kw 130 | 131 | def generate_output(self): 132 | ret = [] 133 | candidate_keys = self.sql.keys() 134 | if self.kw: 135 | candidate_keys = [self.kw] 136 | for key in candidate_keys: 137 | if key in self.keywords and self.sql[key]: 138 | cols = [] 139 | sqls = [] 140 | if key == 'groupBy': 141 | sql_cols = self.sql[key] 142 | for col in sql_cols: 
143 | cols.append((index_to_column_name(col[1], self.table), col[2])) 144 | sqls.append(col) 145 | elif key == 'orderBy': 146 | sql_cols = self.sql[key][1] 147 | for col in sql_cols: 148 | cols.append((index_to_column_name(col[1][1], self.table), col[1][2])) 149 | sqls.append(col) 150 | elif key == 'select': 151 | sql_cols = self.sql[key][1] 152 | for col in sql_cols: 153 | cols.append((index_to_column_name(col[1][1][1], self.table), col[1][1][2])) 154 | sqls.append(col) 155 | elif key == 'where' or key == 'having': 156 | sql_cols = self.sql[key] 157 | for col in sql_cols: 158 | if not isinstance(col, list): 159 | continue 160 | try: 161 | cols.append((index_to_column_name(col[2][1][1], self.table), col[2][1][2])) 162 | except: 163 | print("Key:{} Col:{} Question:{}".format(key, col, self.question)) 164 | sqls.append(col) 165 | ret.append(( 166 | self.history + [key], (len(cols), cols), sqls 167 | )) 168 | return ret 169 | # ret.append(history+[key],) 170 | 171 | 172 | class OpPredictor: 173 | def __init__(self, question, sql, history): 174 | self.sql = sql 175 | self.question = question 176 | self.history = history 177 | # self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having') 178 | 179 | def generate_output(self): 180 | return self.history, convert_to_op_index(self.sql[0],self.sql[1]), (self.sql[3], self.sql[4]) 181 | 182 | 183 | class AggPredictor: 184 | def __init__(self, question, sql, history,kw=None): 185 | self.sql = sql 186 | self.question = question 187 | self.history = history 188 | self.kw = kw 189 | def generate_output(self): 190 | label = -1 191 | if self.kw: 192 | key = self.kw 193 | else: 194 | key = self.history[-2] 195 | if key == 'select': 196 | label = self.sql[0] 197 | elif key == 'orderBy': 198 | label = self.sql[1][0] 199 | elif key == 'having': 200 | label = self.sql[2][1][0] 201 | return self.history, label 202 | 203 | 204 | # class RootTemPredictor: 205 | # def __init__(self, question, sql): 206 | # self.sql = sql 207 | # self.question = question 208 | # self.keywords = ('intersect', 'except', 'union') 209 | # 210 | # def generate_output(self): 211 | # for key in self.sql: 212 | # if key in self.keywords: 213 | # return ['ROOT'], key, self.sql[key] 214 | # return ['ROOT'], 'none', self.sql 215 | 216 | 217 | class DesAscPredictor: 218 | def __init__(self, question, sql, table, history): 219 | self.sql = sql 220 | self.question = question 221 | self.history = history 222 | self.table = table 223 | 224 | def generate_output(self): 225 | for key in self.sql: 226 | if key == "orderBy" and self.sql[key]: 227 | # self.history.append(key) 228 | try: 229 | col = self.sql[key][1][0][1][1] 230 | except: 231 | print("question:{} sql:{}".format(self.question, self.sql)) 232 | # self.history.append(index_to_column_name(col, self.table)) 233 | # self.history.append(self.sql[key][1][0][1][0]) 234 | if self.sql[key][0] == "asc" and self.sql["limit"]: 235 | label = 0 236 | elif self.sql[key][0] == "asc" and not self.sql["limit"]: 237 | label = 1 238 | elif self.sql[key][0] == "desc" and self.sql["limit"]: 239 | label = 2 240 | else: 241 | label = 3 242 | return self.history+[index_to_column_name(col, self.table),self.sql[key][1][0][1][0]], label 243 | 244 | 245 | class AndOrPredictor: 246 | def __init__(self, question, sql, table, history): 247 | self.sql = sql 248 | self.question = question 249 | self.history = history 250 | self.table = table 251 | 252 | def generate_output(self): 253 | if 'where' in self.sql and self.sql['where'] and len(self.sql['where']) > 1: 254 
| return self.history,COND_OPS[self.sql['where'][1]] 255 | return self.history,-1 256 | 257 | 258 | def parser_item_with_long_history(question_tokens, sql, table, history, dataset): 259 | table_schema = [ 260 | table["table_names"], 261 | table["column_names"], 262 | table["column_types"] 263 | ] 264 | stack = [("root",sql)] 265 | with_join = False 266 | fk_dict = defaultdict(list) 267 | for fk in table["foreign_keys"]: 268 | fk_dict[fk[0]].append(fk[1]) 269 | fk_dict[fk[1]].append(fk[0]) 270 | while len(stack) > 0: 271 | node = stack.pop() 272 | if node[0] == "root": 273 | history, label, sql = MultiSqlPredictor(question_tokens, node[1], history).generate_output() 274 | dataset['multi_sql_dataset'].append({ 275 | "question_tokens": question_tokens, 276 | "ts": table_schema, 277 | "history": history[:], 278 | "label": SQL_OPS[label] 279 | }) 280 | history.append(label) 281 | if label == "none": 282 | stack.append((label,sql)) 283 | else: 284 | node[1][label] = None 285 | stack.append((label, node[1],sql)) 286 | # if label != "none": 287 | # stack.append(("none",node[1])) 288 | elif node[0] in ('intersect', 'except', 'union'): 289 | stack.append(("root",node[1])) 290 | stack.append(("root",node[2])) 291 | elif node[0] == "none": 292 | with_join = len(node[1]["from"]["table_units"]) > 1 293 | history, label, sql = KeyWordPredictor(question_tokens, node[1], history).generate_output() 294 | label_idxs = [] 295 | for item in label[1]: 296 | if item in KW_DICT: 297 | label_idxs.append(KW_DICT[item]) 298 | label_idxs.sort() 299 | dataset['keyword_dataset'].append({ 300 | "question_tokens": question_tokens, 301 | "ts": table_schema, 302 | "history": history[:], 303 | "label": label_idxs 304 | }) 305 | if "having" in label[1]: 306 | stack.append(("having",node[1])) 307 | if "orderBy" in label[1]: 308 | stack.append(("orderBy",node[1])) 309 | if "groupBy" in label[1]: 310 | if "having" in label[1]: 311 | dataset['having_dataset'].append({ 312 | "question_tokens": question_tokens, 313 | "ts": table_schema, 314 | "history": history[:], 315 | "gt_col":node[1]["groupBy"][0][1], 316 | "label": 1 317 | }) 318 | else: 319 | dataset['having_dataset'].append({ 320 | "question_tokens": question_tokens, 321 | "ts": table_schema, 322 | "history": history[:], 323 | "gt_col":node[1]["groupBy"][0][1], 324 | "label": 0 325 | }) 326 | stack.append(("groupBy",node[1])) 327 | if "where" in label[1]: 328 | stack.append(("where",node[1])) 329 | if "select" in label[1]: 330 | stack.append(("select",node[1])) 331 | elif node[0] in ("select","having","orderBy"): 332 | # if node[0] != "orderBy": 333 | history.append(node[0]) 334 | if node[0] == "orderBy": 335 | orderby_ret = DesAscPredictor(question_tokens, node[1], table, history).generate_output() 336 | if orderby_ret: 337 | dataset['des_asc_dataset'].append({ 338 | "question_tokens": question_tokens, 339 | "ts": table_schema, 340 | "history": orderby_ret[0], 341 | "gt_col":node[1]["orderBy"][1][0][1][1], 342 | "label": orderby_ret[1] 343 | }) 344 | # history.append(orderby_ret[1]) 345 | col_ret = ColPredictor(question_tokens, node[1], table, history,node[0]).generate_output() 346 | agg_col_dict = dict() 347 | op_col_dict = dict() 348 | for h, l, s in col_ret: 349 | if l[0] == 0: 350 | print("Warning: predicted 0 columns!") 351 | continue 352 | dataset['col_dataset'].append({ 353 | "question_tokens": question_tokens, 354 | "ts": table_schema, 355 | "history": history[:], 356 | "label":get_label_cols(with_join,fk_dict,l[1]) 357 | }) 358 | for col, sql_item in zip(l[1], s): 
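                # Group the labeled sql items per distinct column so that the agg/op
                # sub-tasks below are emitted once per column; the key concatenates the
                # (table_name, column_name, column_index) triple from index_to_column_name.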
359 | key = "{}{}{}".format(col[0][0],col[0][1],col[0][2]) 360 | if key not in agg_col_dict: 361 | agg_col_dict[key] = [(sql_item,col[0])] 362 | else: 363 | agg_col_dict[key].append((sql_item,col[0])) 364 | if key not in op_col_dict: 365 | op_col_dict[key] = [(sql_item,col[0])] 366 | else: 367 | op_col_dict[key].append((sql_item,col[0])) 368 | for key in agg_col_dict: 369 | stack.append(("col",node[0],agg_col_dict[key],op_col_dict[key])) 370 | elif node[0] == "col": 371 | history.append(node[2][0][1]) 372 | if node[1] == "where": 373 | stack.append(("op",node[2],"where")) 374 | else: 375 | labels = [] 376 | for sql_item,col in node[2]: 377 | _, label = AggPredictor(question_tokens, sql_item, history,node[1]).generate_output() 378 | if label-1 >= 0: 379 | labels.append(label-1) 380 | 381 | # print(node[2][0][1][2]) 382 | dataset['agg_dataset'].append({ 383 | "question_tokens": question_tokens, 384 | "ts": table_schema, 385 | "history": history[:], 386 | "gt_col":node[2][0][1][2], 387 | "label": labels[:min(len(labels),3)] 388 | }) 389 | if node[1] == "having": 390 | stack.append(("op", node[2], "having")) 391 | # if len(labels) == 0: 392 | # history.append("none") 393 | # else: 394 | if len(labels) > 0: 395 | history.append(AGG_OPS[labels[0]+1]) 396 | elif node[0] == "op": 397 | # history.append(node[1][0][1]) 398 | labels = [] 399 | # if len(labels) > 2: 400 | # print(question_tokens) 401 | dataset['op_dataset'].append({ 402 | "question_tokens": question_tokens, 403 | "ts": table_schema, 404 | "history": history[:], 405 | "gt_col": node[1][0][1][2], 406 | "label": labels 407 | }) 408 | 409 | for sql_item,col in node[1]: 410 | _, label, s = OpPredictor(question_tokens, sql_item, history).generate_output() 411 | if label != -1: 412 | labels.append(label) 413 | history.append(NEW_WHERE_OPS[label]) 414 | if isinstance(s[0], dict): 415 | stack.append(("root",s[0])) 416 | # history.append("root") 417 | dataset['root_tem_dataset'].append({ 418 | "question_tokens": question_tokens, 419 | "ts": table_schema, 420 | "history": history[:], 421 | "gt_col": node[1][0][1][2], 422 | "label": 0 423 | }) 424 | else: 425 | dataset['root_tem_dataset'].append({ 426 | "question_tokens": question_tokens, 427 | "ts": table_schema, 428 | "history": history[:], 429 | "gt_col": node[1][0][1][2], 430 | "label": 1 431 | }) 432 | # history.append("terminal") 433 | if len(labels) > 2: 434 | print(question_tokens) 435 | dataset['op_dataset'][-1]["label"] = labels 436 | elif node[0] == "where": 437 | history.append(node[0]) 438 | hist, label = AndOrPredictor(question_tokens, node[1], table, history).generate_output() 439 | if label != -1: 440 | dataset['andor_dataset'].append({ 441 | "question_tokens": question_tokens, 442 | "ts": table_schema, 443 | "history": history[:], 444 | "label":label 445 | }) 446 | col_ret = ColPredictor(question_tokens, node[1], table, history, "where").generate_output() 447 | op_col_dict = dict() 448 | for h, l, s in col_ret: 449 | if l[0] == 0: 450 | print("Warning: predicted 0 columns!") 451 | continue 452 | dataset['col_dataset'].append({ 453 | "question_tokens": question_tokens, 454 | "ts": table_schema, 455 | "history": history[:], 456 | "label": get_label_cols(with_join,fk_dict,l[1]) 457 | }) 458 | for col, sql_item in zip(l[1], s): 459 | key = "{}{}{}".format(col[0][0], col[0][1], col[0][2]) 460 | if key not in op_col_dict: 461 | op_col_dict[key] = [(sql_item, col[0])] 462 | else: 463 | op_col_dict[key].append((sql_item, col[0])) 464 | for key in op_col_dict: 465 | stack.append(("col", 
"where", op_col_dict[key])) 466 | elif node[0] == "groupBy": 467 | history.append(node[0]) 468 | col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output() 469 | agg_col_dict = dict() 470 | for h, l, s in col_ret: 471 | if l[0] == 0: 472 | print("Warning: predicted 0 columns!") 473 | continue 474 | dataset['col_dataset'].append({ 475 | "question_tokens": question_tokens, 476 | "ts": table_schema, 477 | "history": history[:], 478 | "label": get_label_cols(with_join,fk_dict,l[1]) 479 | }) 480 | for col, sql_item in zip(l[1], s): 481 | key = "{}{}{}".format(col[0][0], col[0][1], col[0][2]) 482 | if key not in agg_col_dict: 483 | agg_col_dict[key] = [(sql_item, col[0])] 484 | else: 485 | agg_col_dict[key].append((sql_item, col[0])) 486 | for key in agg_col_dict: 487 | stack.append(("col", node[0], agg_col_dict[key])) 488 | 489 | 490 | 491 | def parser_item(question_tokens, sql, table, history, dataset): 492 | # try: 493 | # question_tokens = item['question_toks'] 494 | # except: 495 | # print(item) 496 | # sql = item['sql'] 497 | table_schema = [ 498 | table["table_names"], 499 | table["column_names"], 500 | table["column_types"] 501 | ] 502 | history, label, sql = MultiSqlPredictor(question_tokens, sql, history).generate_output() 503 | dataset['multi_sql_dataset'].append({ 504 | "question_tokens": question_tokens, 505 | "ts": table_schema, 506 | "history": history[:], 507 | "label": SQL_OPS[label] 508 | }) 509 | history.append(label) 510 | history, label, sql = KeyWordPredictor(question_tokens, sql, history).generate_output() 511 | label_idxs = [] 512 | for item in label[1]: 513 | if item in KW_DICT: 514 | label_idxs.append(KW_DICT[item]) 515 | label_idxs.sort() 516 | dataset['keyword_dataset'].append({ 517 | "question_tokens": question_tokens, 518 | "ts": table_schema, 519 | "history": history[:], 520 | "label": label_idxs 521 | }) 522 | hist,label = AndOrPredictor(question_tokens,sql,table,history).generate_output() 523 | if label != -1: 524 | dataset['andor_dataset'].append({ 525 | "question_tokens": question_tokens, 526 | "ts": table_schema, 527 | "history": hist[:]+["where"], 528 | "label": label 529 | }) 530 | orderby_ret = DesAscPredictor(question_tokens, sql, table, history).generate_output() 531 | if orderby_ret: 532 | dataset['des_asc_dataset'].append({ 533 | "question_tokens": question_tokens, 534 | "ts": table_schema, 535 | "history": orderby_ret[0][:], 536 | "label": orderby_ret[1] 537 | }) 538 | col_ret = ColPredictor(question_tokens, sql, table, history).generate_output() 539 | agg_candidates = [] 540 | op_candidates = [] 541 | for h, l, s in col_ret: 542 | if l[0] == 0: 543 | print("Warning: predicted 0 columns!") 544 | continue 545 | dataset['col_dataset'].append({ 546 | "question_tokens": question_tokens, 547 | "ts": table_schema, 548 | "history": h[:], 549 | "label": list(set([l[1][i][0][2] for i in range(min(len(l[1]),3))])) 550 | }) 551 | for col, sql_item in zip(l[1], s): 552 | if h[-1] in ('where', 'having'): 553 | op_candidates.append((h + [col[0]], sql_item)) 554 | if h[-1] in ('select', 'orderBy', 'having'): 555 | agg_candidates.append((h + [col[0]], sql_item)) 556 | if h[-1] == "groupBy": 557 | label = 0 558 | if sql["having"]: 559 | label = 1 560 | dataset['having_dataset'].append({ 561 | "question_tokens": question_tokens, 562 | "ts": table_schema, 563 | "history": h[:] + [col[0]], 564 | "label": label 565 | }) 566 | 567 | op_col_dict = dict() 568 | for h, sql_item in op_candidates: 569 | _, label, s = OpPredictor(question_tokens, 
sql_item, h).generate_output() 570 | if label == -1: 571 | continue 572 | key = "{}{}".format(h[-2], h[-1][2]) 573 | label = NEW_WHERE_OPS[label] 574 | if key in op_col_dict: 575 | op_col_dict[key][1].append(label) 576 | else: 577 | op_col_dict[key] = [h[:], [label]] 578 | # dataset['op_dataset'].append({ 579 | # "question_tokens": question_tokens, 580 | # "ts": table_schema, 581 | # "history": h[:], 582 | # "label": label 583 | # }) 584 | if isinstance(s[0], dict): 585 | dataset['root_tem_dataset'].append({ 586 | "question_tokens": question_tokens, 587 | "ts": table_schema, 588 | "history": h[:] + [label], 589 | "label": 0 590 | }) 591 | parser_item(question_tokens, s[0], table, h[:] + [label], dataset) 592 | else: 593 | dataset['root_tem_dataset'].append({ 594 | "question_tokens": question_tokens, 595 | "ts": table_schema, 596 | "history": h[:] + [label], 597 | "label": 1 598 | }) 599 | for key in op_col_dict: 600 | # if len(op_col_dict[key][1]) > 1: 601 | # print("same col has mult op ") 602 | dataset['op_dataset'].append({ 603 | "question_tokens": question_tokens, 604 | "ts": table_schema, 605 | "history": op_col_dict[key][0], 606 | "label": op_col_dict[key][1] 607 | }) 608 | agg_col_dict = dict() 609 | for h, sql_item in agg_candidates: 610 | _, label = AggPredictor(question_tokens, sql_item, h).generate_output() 611 | if label != 5: 612 | key = "{}{}".format(h[-2], h[-1][2]) 613 | if key in agg_col_dict: 614 | agg_col_dict[key][1].append(label) 615 | else: 616 | agg_col_dict[key] = [h[:], [label]] 617 | for key in agg_col_dict: 618 | # if 5 in agg_col_dict[key][1]: 619 | # print("none in agg label!!!") 620 | dataset['agg_dataset'].append({ 621 | "question_tokens": question_tokens, 622 | "ts": table_schema, 623 | "history": agg_col_dict[key][0], 624 | "label": agg_col_dict[key][1] 625 | }) 626 | 627 | 628 | def get_table_dict(table_data_path): 629 | data = json.load(open(table_data_path)) 630 | table = dict() 631 | for item in data: 632 | table[item["db_id"]] = item 633 | return table 634 | 635 | 636 | def parse_data(data): 637 | dataset = { 638 | "multi_sql_dataset": [], 639 | "keyword_dataset": [], 640 | "col_dataset": [], 641 | "op_dataset": [], 642 | "agg_dataset": [], 643 | "root_tem_dataset": [], 644 | "des_asc_dataset": [], 645 | "having_dataset": [], 646 | "andor_dataset":[] 647 | } 648 | table_dict = get_table_dict(table_data_path) 649 | for item in data: 650 | if history_option == "full": 651 | # parser_item(item["question_toks"], item["sql"], table_dict[item["db_id"]], [], dataset) 652 | parser_item_with_long_history(item["question_toks"], item["sql"], table_dict[item["db_id"]], [], dataset) 653 | else: 654 | parser_item(item["question_toks"], item["sql"], table_dict[item["db_id"]], [], dataset) 655 | print("finished preprocess") 656 | for key in dataset: 657 | print("dataset:{} size:{}".format(key, len(dataset[key]))) 658 | json.dump(dataset[key], open("./generated_data/{}_{}_{}.json".format(history_option,train_dev, key), "w"), indent=2) 659 | 660 | 661 | if __name__ == '__main__': 662 | parse_data(train_data) 663 | -------------------------------------------------------------------------------- /supermodel.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import datetime 4 | import time 5 | import argparse 6 | import numpy as np 7 | import torch.nn as nn 8 | import traceback 9 | from collections import defaultdict 10 | 11 | from utils import * 12 | from word_embedding import WordEmbedding 13 | from 
models.agg_predictor import AggPredictor 14 | from models.col_predictor import ColPredictor 15 | from models.desasc_limit_predictor import DesAscLimitPredictor 16 | from models.having_predictor import HavingPredictor 17 | from models.keyword_predictor import KeyWordPredictor 18 | from models.multisql_predictor import MultiSqlPredictor 19 | from models.root_teminal_predictor import RootTeminalPredictor 20 | from models.andor_predictor import AndOrPredictor 21 | from models.op_predictor import OpPredictor 22 | from preprocess_train_dev_data import index_to_column_name 23 | 24 | 25 | SQL_OPS = ('none','intersect', 'union', 'except') 26 | KW_OPS = ('where','groupBy','orderBy') 27 | AGG_OPS = ('max', 'min', 'count', 'sum', 'avg') 28 | ROOT_TERM_OPS = ("root","terminal") 29 | COND_OPS = ("and","or") 30 | DEC_ASC_OPS = (("asc",True),("asc",False),("desc",True),("desc",False)) 31 | NEW_WHERE_OPS = ('=','>','<','>=','<=','!=','like','not in','in','between') 32 | KW_WITH_COL = ("select","where","groupBy","orderBy","having") 33 | class Stack: 34 | def __init__(self): 35 | self.items = [] 36 | 37 | def isEmpty(self): 38 | return self.items == [] 39 | 40 | def push(self, item): 41 | self.items.append(item) 42 | 43 | def pop(self): 44 | return self.items.pop() 45 | 46 | def peek(self): 47 | return self.items[len(self.items)-1] 48 | 49 | def size(self): 50 | return len(self.items) 51 | 52 | def insert(self,i,x): 53 | return self.items.insert(i,x) 54 | 55 | 56 | def to_batch_tables(tables, B, table_type): 57 | # col_lens = [] 58 | col_seq = [] 59 | ts = [tables["table_names"],tables["column_names"],tables["column_types"]] 60 | tname_toks = [x.split(" ") for x in ts[0]] 61 | col_type = ts[2] 62 | cols = [x.split(" ") for xid, x in ts[1]] 63 | tab_seq = [xid for xid, x in ts[1]] 64 | cols_add = [] 65 | for tid, col, ct in zip(tab_seq, cols, col_type): 66 | col_one = [ct] 67 | if tid == -1: 68 | tabn = ["all"] 69 | else: 70 | if table_type=="no": tabn = [] 71 | else: tabn = tname_toks[tid] 72 | for t in tabn: 73 | if t not in col: 74 | col_one.append(t) 75 | col_one.extend(col) 76 | cols_add.append(col_one) 77 | 78 | col_seq = [cols_add] * B 79 | 80 | return col_seq 81 | 82 | class SuperModel(nn.Module): 83 | def __init__(self, word_emb, N_word, N_h=300, N_depth=2, gpu=True, trainable_emb=False, table_type="std", use_hs=True): 84 | super(SuperModel, self).__init__() 85 | self.gpu = gpu 86 | self.N_h = N_h 87 | self.N_depth = N_depth 88 | self.trainable_emb = trainable_emb 89 | self.table_type = table_type 90 | self.use_hs = use_hs 91 | self.SQL_TOK = ['', '', 'WHERE', 'AND', 'EQL', 'GT', 'LT', ''] 92 | 93 | # word embedding layer 94 | self.embed_layer = WordEmbedding(word_emb, N_word, gpu, 95 | self.SQL_TOK, trainable=trainable_emb) 96 | 97 | # initial all modules 98 | self.multi_sql = MultiSqlPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 99 | self.multi_sql.eval() 100 | 101 | self.key_word = KeyWordPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 102 | self.key_word.eval() 103 | 104 | self.col = ColPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 105 | self.col.eval() 106 | 107 | self.op = OpPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 108 | self.op.eval() 109 | 110 | self.agg = AggPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 111 | self.agg.eval() 112 | 113 | self.root_teminal = RootTeminalPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 114 | 
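        # Like the other sub-modules above, this predictor is put in eval() immediately:
        # SuperModel only wires the separately trained modules together for inference-time
        # SQL generation and does not train them itself.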
self.root_teminal.eval() 115 | 116 | self.des_asc = DesAscLimitPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 117 | self.des_asc.eval() 118 | 119 | self.having = HavingPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=gpu, use_hs=use_hs) 120 | self.having.eval() 121 | 122 | self.andor = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) 123 | self.andor.eval() 124 | 125 | self.softmax = nn.Softmax() #dim=1 126 | self.CE = nn.CrossEntropyLoss() 127 | self.log_softmax = nn.LogSoftmax() 128 | self.mlsml = nn.MultiLabelSoftMarginLoss() 129 | self.bce_logit = nn.BCEWithLogitsLoss() 130 | self.sigm = nn.Sigmoid() 131 | if gpu: 132 | self.cuda() 133 | self.path_not_found = 0 134 | 135 | def forward(self,q_seq,history,tables): 136 | # if self.part: 137 | # return self.part_forward(q_seq,history,tables) 138 | # else: 139 | return self.full_forward(q_seq, history, tables) 140 | 141 | def full_forward(self, q_seq, history, tables): 142 | B = len(q_seq) 143 | # print("q_seq:{}".format(q_seq)) 144 | # print("Batch size:{}".format(B)) 145 | q_emb_var, q_len = self.embed_layer.gen_x_q_batch(q_seq) 146 | col_seq = to_batch_tables(tables, B, self.table_type) 147 | col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col_seq) 148 | 149 | mkw_emb_var = self.embed_layer.gen_word_list_embedding(["none","except","intersect","union"],(B)) 150 | mkw_len = np.full(q_len.shape, 4,dtype=np.int64) 151 | kw_emb_var = self.embed_layer.gen_word_list_embedding(["where", "group by", "order by"], (B)) 152 | kw_len = np.full(q_len.shape, 3, dtype=np.int64) 153 | 154 | stack = Stack() 155 | stack.push(("root",None)) 156 | history = [["root"]]*B 157 | andor_cond = "" 158 | has_limit = False 159 | # sql = {} 160 | current_sql = {} 161 | sql_stack = [] 162 | idx_stack = [] 163 | kw_stack = [] 164 | kw = "" 165 | nested_label = "" 166 | has_having = False 167 | 168 | timeout = time.time() + 2 # set timer to prevent infinite recursion in SQL generation 169 | failed = False 170 | while not stack.isEmpty(): 171 | if time.time() > timeout: failed=True; break 172 | vet = stack.pop() 173 | # print(vet) 174 | hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history) 175 | if len(idx_stack) > 0 and stack.size() < idx_stack[-1]: 176 | # print("pop!!!!!!!!!!!!!!!!!!!!!!") 177 | idx_stack.pop() 178 | current_sql = sql_stack.pop() 179 | kw = kw_stack.pop() 180 | # current_sql = current_sql["sql"] 181 | # history.append(vet) 182 | # print("hs_emb:{} hs_len:{}".format(hs_emb_var.size(),hs_len.size())) 183 | if isinstance(vet,tuple) and vet[0] == "root": 184 | if history[0][-1] != "root": 185 | history[0].append("root") 186 | hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history) 187 | if vet[1] != "original": 188 | idx_stack.append(stack.size()) 189 | sql_stack.append(current_sql) 190 | kw_stack.append(kw) 191 | else: 192 | idx_stack.append(stack.size()) 193 | sql_stack.append(sql_stack[-1]) 194 | kw_stack.append(kw) 195 | if "sql" in current_sql: 196 | current_sql["nested_sql"] = {} 197 | current_sql["nested_label"] = nested_label 198 | current_sql = current_sql["nested_sql"] 199 | elif isinstance(vet[1],dict): 200 | vet[1]["sql"] = {} 201 | current_sql = vet[1]["sql"] 202 | elif vet[1] != "original": 203 | current_sql["sql"] = {} 204 | current_sql = current_sql["sql"] 205 | # print("q_emb_var:{} hs_emb_var:{} mkw_emb_var:{}".format(q_emb_var.size(),hs_emb_var.size(),mkw_emb_var.size())) 206 | if vet[1] == "nested" or vet[1] == "original": 207 | 
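                    # This root was created by splitting an intersect/union/except pair
                    # (see the 'intersect'/'except'/'union' branch below), so the multi-sql
                    # decision has already been made; push "none" directly instead of
                    # querying self.multi_sql again.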
stack.push("none") 208 | history[0].append("none") 209 | else: 210 | score = self.multi_sql.forward(q_emb_var,q_len,hs_emb_var,hs_len,mkw_emb_var,mkw_len) 211 | label = np.argmax(score[0].data.cpu().numpy()) 212 | label = SQL_OPS[label] 213 | history[0].append(label) 214 | stack.push(label) 215 | if label != "none": 216 | nested_label = label 217 | 218 | elif vet in ('intersect', 'except', 'union'): 219 | stack.push(("root","nested")) 220 | stack.push(("root","original")) 221 | # history[0].append("root") 222 | elif vet == "none": 223 | score = self.key_word.forward(q_emb_var,q_len,hs_emb_var,hs_len,kw_emb_var,kw_len) 224 | kw_num_score, kw_score = [x.data.cpu().numpy() for x in score] 225 | # print("kw_num_score:{}".format(kw_num_score)) 226 | # print("kw_score:{}".format(kw_score)) 227 | num_kw = np.argmax(kw_num_score[0]) 228 | kw_score = list(np.argsort(-kw_score[0])[:num_kw]) 229 | kw_score.sort(reverse=True) 230 | # print("num_kw:{}".format(num_kw)) 231 | for kw in kw_score: 232 | stack.push(KW_OPS[kw]) 233 | stack.push("select") 234 | elif vet in ("select","orderBy","where","groupBy","having"): 235 | kw = vet 236 | current_sql[kw] = [] 237 | history[0].append(vet) 238 | stack.push(("col",vet)) 239 | # score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len) 240 | # label = score[0].data.cpu().numpy() 241 | # andor_cond = COND_OPS[label] 242 | # history.append("") 243 | # elif vet == "groupBy": 244 | # score = self.having.forward(q_emb_var,q_len,hs_emb_var,hs_len,col_emb_var,col_len,) 245 | elif isinstance(vet,tuple) and vet[0] == "col": 246 | # print("q_emb_var:{} hs_emb_var:{} col_emb_var:{}".format(q_emb_var.size(), hs_emb_var.size(),col_emb_var.size())) 247 | score = self.col.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len) 248 | col_num_score, col_score = [x.data.cpu().numpy() for x in score] 249 | col_num = np.argmax(col_num_score[0]) + 1 # double check 250 | cols = np.argsort(-col_score[0])[:col_num] 251 | # print(col_num) 252 | # print("col_num_score:{}".format(col_num_score)) 253 | # print("col_score:{}".format(col_score)) 254 | for col in cols: 255 | if vet[1] == "where": 256 | stack.push(("op","where",col)) 257 | elif vet[1] != "groupBy": 258 | stack.push(("agg",vet[1],col)) 259 | elif vet[1] == "groupBy": 260 | history[0].append(index_to_column_name(col, tables)) 261 | current_sql[kw].append(index_to_column_name(col, tables)) 262 | #predict and or or when there is multi col in where condition 263 | if col_num > 1 and vet[1] == "where": 264 | score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len) 265 | label = np.argmax(score[0].data.cpu().numpy()) 266 | andor_cond = COND_OPS[label] 267 | current_sql[kw].append(andor_cond) 268 | if vet[1] == "groupBy" and col_num > 0: 269 | score = self.having.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, cols[0],dtype=np.int64)) 270 | label = np.argmax(score[0].data.cpu().numpy()) 271 | if label == 1: 272 | has_having = (label == 1) 273 | # stack.insert(-col_num,"having") 274 | stack.push("having") 275 | # history.append(index_to_column_name(cols[-1], tables[0])) 276 | elif isinstance(vet,tuple) and vet[0] == "agg": 277 | history[0].append(index_to_column_name(vet[2], tables)) 278 | if vet[1] not in ("having","orderBy"): #DEBUG-ed 20180817 279 | try: 280 | current_sql[kw].append(index_to_column_name(vet[2], tables)) 281 | except Exception as e: 282 | # print(e) 283 | traceback.print_exc() 284 | print("history:{},current_sql:{} 
stack:{}".format(history[0], current_sql,stack.items)) 285 | print("idx_stack:{}".format(idx_stack)) 286 | print("sql_stack:{}".format(sql_stack)) 287 | exit(1) 288 | hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history) 289 | 290 | score = self.agg.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[2],dtype=np.int64)) 291 | agg_num_score, agg_score = [x.data.cpu().numpy() for x in score] 292 | agg_num = np.argmax(agg_num_score[0]) # double check 293 | agg_idxs = np.argsort(-agg_score[0])[:agg_num] 294 | # print("agg:{}".format([AGG_OPS[agg] for agg in agg_idxs])) 295 | if len(agg_idxs) > 0: 296 | history[0].append(AGG_OPS[agg_idxs[0]]) 297 | if vet[1] not in ("having", "orderBy"): 298 | current_sql[kw].append(AGG_OPS[agg_idxs[0]]) 299 | elif vet[1] == "orderBy": 300 | stack.push(("des_asc", vet[2], AGG_OPS[agg_idxs[0]])) #DEBUG-ed 20180817 301 | else: 302 | stack.push(("op","having",vet[2],AGG_OPS[agg_idxs[0]])) 303 | for agg in agg_idxs[1:]: 304 | history[0].append(index_to_column_name(vet[2], tables)) 305 | history[0].append(AGG_OPS[agg]) 306 | if vet[1] not in ("having", "orderBy"): 307 | current_sql[kw].append(index_to_column_name(vet[2], tables)) 308 | current_sql[kw].append(AGG_OPS[agg]) 309 | elif vet[1] == "orderBy": 310 | stack.push(("des_asc", vet[2], AGG_OPS[agg])) 311 | else: 312 | stack.push(("op", "having", vet[2], agg_idxs)) 313 | if len(agg_idxs) == 0: 314 | if vet[1] not in ("having", "orderBy"): 315 | current_sql[kw].append("none_agg") 316 | elif vet[1] == "orderBy": 317 | stack.push(("des_asc", vet[2], "none_agg")) 318 | else: 319 | stack.push(("op", "having", vet[2], "none_agg")) 320 | # current_sql[kw].append([AGG_OPS[agg] for agg in agg_idxs]) 321 | # if vet[1] == "having": 322 | # stack.push(("op","having",vet[2],agg_idxs)) 323 | # if vet[1] == "orderBy": 324 | # stack.push(("des_asc",vet[2],agg_idxs)) 325 | # if vet[1] == "groupBy" and has_having: 326 | # stack.push("having") 327 | elif isinstance(vet,tuple) and vet[0] == "op": 328 | if vet[1] == "where": 329 | # current_sql[kw].append(index_to_column_name(vet[2], tables)) 330 | history[0].append(index_to_column_name(vet[2], tables)) 331 | hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history) 332 | 333 | score = self.op.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[2],dtype=np.int64)) 334 | 335 | op_num_score, op_score = [x.data.cpu().numpy() for x in score] 336 | op_num = np.argmax(op_num_score[0]) + 1 # num_score 0 maps to 1 in truth, must have at least one op 337 | ops = np.argsort(-op_score[0])[:op_num] 338 | # current_sql[kw].append([NEW_WHERE_OPS[op] for op in ops]) 339 | if op_num > 0: 340 | history[0].append(NEW_WHERE_OPS[ops[0]]) 341 | if vet[1] == "having": 342 | stack.push(("root_teminal", vet[2],vet[3],ops[0])) 343 | else: 344 | stack.push(("root_teminal", vet[2],ops[0])) 345 | # current_sql[kw].append(NEW_WHERE_OPS[ops[0]]) 346 | for op in ops[1:]: 347 | history[0].append(index_to_column_name(vet[2], tables)) 348 | history[0].append(NEW_WHERE_OPS[op]) 349 | # current_sql[kw].append(index_to_column_name(vet[2], tables)) 350 | # current_sql[kw].append(NEW_WHERE_OPS[op]) 351 | if vet[1] == "having": 352 | stack.push(("root_teminal", vet[2],vet[3],op)) 353 | else: 354 | stack.push(("root_teminal", vet[2],op)) 355 | # stack.push(("root_teminal",vet[2])) 356 | elif isinstance(vet,tuple) and vet[0] == "root_teminal": 357 | score = self.root_teminal.forward(q_emb_var, q_len, hs_emb_var, 
hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[1],dtype=np.int64)) 358 | 359 | label = np.argmax(score[0].data.cpu().numpy()) 360 | label = ROOT_TERM_OPS[label] 361 | if len(vet) == 4: 362 | current_sql[kw].append(index_to_column_name(vet[1], tables)) 363 | current_sql[kw].append(vet[2]) 364 | current_sql[kw].append(NEW_WHERE_OPS[vet[3]]) 365 | else: 366 | # print("kw:{}".format(kw)) 367 | try: 368 | current_sql[kw].append(index_to_column_name(vet[1], tables)) 369 | except Exception as e: 370 | # print(e) 371 | traceback.print_exc() 372 | print("history:{},current_sql:{} stack:{}".format(history[0], current_sql, stack.items)) 373 | print("idx_stack:{}".format(idx_stack)) 374 | print("sql_stack:{}".format(sql_stack)) 375 | exit(1) 376 | current_sql[kw].append(NEW_WHERE_OPS[vet[2]]) 377 | if label == "root": 378 | history[0].append("root") 379 | current_sql[kw].append({}) 380 | # current_sql = current_sql[kw][-1] 381 | stack.push(("root",current_sql[kw][-1])) 382 | else: 383 | current_sql[kw].append("terminal") 384 | elif isinstance(vet,tuple) and vet[0] == "des_asc": 385 | current_sql[kw].append(index_to_column_name(vet[1], tables)) 386 | current_sql[kw].append(vet[2]) 387 | score = self.des_asc.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[1],dtype=np.int64)) 388 | label = np.argmax(score[0].data.cpu().numpy()) 389 | dec_asc,has_limit = DEC_ASC_OPS[label] 390 | history[0].append(dec_asc) 391 | current_sql[kw].append(dec_asc) 392 | current_sql[kw].append(has_limit) 393 | # print("{}".format(current_sql)) 394 | 395 | if failed: return None 396 | print("history:{}".format(history[0])) 397 | if len(sql_stack) > 0: 398 | current_sql = sql_stack[0] 399 | # print("{}".format(current_sql)) 400 | return current_sql 401 | 402 | 403 | def gen_col(self,col,table,table_alias_dict): 404 | colname = table["column_names_original"][col[2]][1] 405 | table_idx = table["column_names_original"][col[2]][0] 406 | if table_idx not in table_alias_dict: 407 | return colname 408 | return "T{}.{}".format(table_alias_dict[table_idx],colname) 409 | 410 | def gen_group_by(self,sql,kw,table,table_alias_dict): 411 | ret = [] 412 | for i in range(0,len(sql)): 413 | # if len(sql[i+1]) == 0: 414 | # if sql[i+1] == "none_agg": 415 | ret.append(self.gen_col(sql[i],table,table_alias_dict)) 416 | # else: 417 | # ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict))) 418 | # for agg in sql[i+1]: 419 | # ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict))) 420 | return "{} {}".format(kw,",".join(ret)) 421 | 422 | def gen_select(self,sql,kw,table,table_alias_dict): 423 | ret = [] 424 | for i in range(0,len(sql),2): 425 | # if len(sql[i+1]) == 0: 426 | if sql[i+1] == "none_agg" or not isinstance(sql[i+1],basestring): #DEBUG-ed 20180817 427 | ret.append(self.gen_col(sql[i],table,table_alias_dict)) 428 | else: 429 | ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict))) 430 | # for agg in sql[i+1]: 431 | # ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict))) 432 | return "{} {}".format(kw,",".join(ret)) 433 | 434 | def gen_where(self,sql,table,table_alias_dict): 435 | if len(sql) == 0: 436 | return "" 437 | start_idx = 0 438 | andor = "and" 439 | if isinstance(sql[0],basestring): 440 | start_idx += 1 441 | andor = sql[0] 442 | ret = [] 443 | for i in range(start_idx,len(sql),3): 444 | col = self.gen_col(sql[i],table,table_alias_dict) 445 | op = sql[i+1] 446 | val = sql[i+2] 
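            # As built in full_forward, current_sql['where'] is a flat list: an optional
            # leading 'and'/'or' connector (detected by the isinstance check above),
            # followed by repeated (column, operator, value) triples, where value is either
            # the placeholder string "terminal" or a nested sql dict rendered recursively
            # via gen_sql below.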
447 | where_item = "" 448 | if val == "terminal": 449 | where_item = "{} {} '{}'".format(col,op,val) 450 | else: 451 | val = self.gen_sql(val,table) 452 | where_item = "{} {} ({})".format(col,op,val) 453 | if op == "between": 454 | #TODO temprarily fixed 455 | where_item += " and 'terminal'" 456 | ret.append(where_item) 457 | return "where {}".format(" {} ".format(andor).join(ret)) 458 | 459 | def gen_orderby(self,sql,table,table_alias_dict): 460 | ret = [] 461 | limit = "" 462 | if sql[-1] == True: 463 | limit = "limit 1" 464 | for i in range(0,len(sql),4): 465 | if sql[i+1] == "none_agg" or not isinstance(sql[i+1],basestring): #DEBUG-ed 20180817 466 | ret.append("{} {}".format(self.gen_col(sql[i],table,table_alias_dict), sql[i+2])) 467 | else: 468 | ret.append("{}({}) {}".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict),sql[i+2])) 469 | return "order by {} {}".format(",".join(ret),limit) 470 | 471 | def gen_having(self,sql,table,table_alias_dict): 472 | ret = [] 473 | for i in range(0,len(sql),4): 474 | if sql[i+1] == "none_agg": 475 | col = self.gen_col(sql[i],table,table_alias_dict) 476 | else: 477 | col = "{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict)) 478 | op = sql[i+2] 479 | val = sql[i+3] 480 | if val == "terminal": 481 | ret.append("{} {} '{}'".format(col,op,val)) 482 | else: 483 | val = self.gen_sql(val, table) 484 | ret.append("{} {} ({})".format(col, op, val)) 485 | return "having {}".format(",".join(ret)) 486 | 487 | def find_shortest_path(self,start,end,graph): 488 | stack = [[start,[]]] 489 | visited = set() 490 | while len(stack) > 0: 491 | ele,history = stack.pop() 492 | if ele == end: 493 | return history 494 | for node in graph[ele]: 495 | if node[0] not in visited: 496 | stack.append((node[0],history+[(node[0],node[1])])) 497 | visited.add(node[0]) 498 | print("table {} table {}".format(start,end)) 499 | # print("could not find path!!!!!{}".format(self.path_not_found)) 500 | self.path_not_found += 1 501 | # return [] 502 | def gen_from(self,candidate_tables,table): 503 | def find(d,col): 504 | if d[col] == -1: 505 | return col 506 | return find(d,d[col]) 507 | def union(d,c1,c2): 508 | r1 = find(d,c1) 509 | r2 = find(d,c2) 510 | if r1 == r2: 511 | return 512 | d[r1] = r2 513 | 514 | ret = "" 515 | if len(candidate_tables) <= 1: 516 | if len(candidate_tables) == 1: 517 | ret = "from {}".format(table["table_names_original"][list(candidate_tables)[0]]) 518 | else: 519 | ret = "from {}".format(table["table_names_original"][0]) 520 | #TODO: temporarily settings 521 | return {},ret 522 | # print("candidate:{}".format(candidate_tables)) 523 | table_alias_dict = {} 524 | uf_dict = {} 525 | for t in candidate_tables: 526 | uf_dict[t] = -1 527 | idx = 1 528 | graph = defaultdict(list) 529 | for acol,bcol in table["foreign_keys"]: 530 | t1 = table["column_names"][acol][0] 531 | t2 = table["column_names"][bcol][0] 532 | graph[t1].append((t2,(acol,bcol))) 533 | graph[t2].append((t1,(bcol, acol))) 534 | # if t1 in candidate_tables and t2 in candidate_tables: 535 | # r1 = find(uf_dict,t1) 536 | # r2 = find(uf_dict,t2) 537 | # if r1 == r2: 538 | # continue 539 | # union(uf_dict,t1,t2) 540 | # if len(ret) == 0: 541 | # ret = "from {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(table["table_names"][t1],idx,table["table_names"][t2], 542 | # idx+1,idx,table["column_names_original"][acol][1],idx+1, 543 | # table["column_names_original"][bcol][1]) 544 | # table_alias_dict[t1] = idx 545 | # table_alias_dict[t2] = idx+1 546 | # idx += 2 547 | # else: 
548 | # if t1 in table_alias_dict: 549 | # old_t = t1 550 | # new_t = t2 551 | # acol,bcol = bcol,acol 552 | # elif t2 in table_alias_dict: 553 | # old_t = t2 554 | # new_t = t1 555 | # else: 556 | # ret = "{} join {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(ret,table["table_names"][t1], idx, 557 | # table["table_names"][t2], 558 | # idx + 1, idx, 559 | # table["column_names_original"][acol][1], 560 | # idx + 1, 561 | # table["column_names_original"][bcol][1]) 562 | # table_alias_dict[t1] = idx 563 | # table_alias_dict[t2] = idx + 1 564 | # idx += 2 565 | # continue 566 | # ret = "{} join {} as T{} on T{}.{}=T{}.{}".format(ret,new_t,idx,idx,table["column_names_original"][acol][1], 567 | # table_alias_dict[old_t],table["column_names_original"][bcol][1]) 568 | # table_alias_dict[new_t] = idx 569 | # idx += 1 570 | # visited = set() 571 | candidate_tables = list(candidate_tables) 572 | start = candidate_tables[0] 573 | table_alias_dict[start] = idx 574 | idx += 1 575 | ret = "from {} as T1".format(table["table_names_original"][start]) 576 | try: 577 | for end in candidate_tables[1:]: 578 | if end in table_alias_dict: 579 | continue 580 | path = self.find_shortest_path(start, end, graph) 581 | prev_table = start 582 | if not path: 583 | table_alias_dict[end] = idx 584 | idx += 1 585 | ret = "{} join {} as T{}".format(ret, table["table_names_original"][end], 586 | table_alias_dict[end], 587 | ) 588 | continue 589 | for node, (acol, bcol) in path: 590 | if node in table_alias_dict: 591 | prev_table = node 592 | continue 593 | table_alias_dict[node] = idx 594 | idx += 1 595 | ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(ret, table["table_names_original"][node], 596 | table_alias_dict[node], 597 | table_alias_dict[prev_table], 598 | table["column_names_original"][acol][1], 599 | table_alias_dict[node], 600 | table["column_names_original"][bcol][1]) 601 | prev_table = node 602 | except: 603 | traceback.print_exc() 604 | print("db:{}".format(table["db_id"])) 605 | # print(table["db_id"]) 606 | return table_alias_dict,ret 607 | # if len(candidate_tables) != len(table_alias_dict): 608 | # print("error in generate from clause!!!!!") 609 | return table_alias_dict,ret 610 | 611 | def gen_sql(self, sql,table): 612 | select_clause = "" 613 | from_clause = "" 614 | groupby_clause = "" 615 | orderby_clause = "" 616 | having_clause = "" 617 | where_clause = "" 618 | nested_clause = "" 619 | cols = {} 620 | candidate_tables = set() 621 | nested_sql = {} 622 | nested_label = "" 623 | parent_sql = sql 624 | # if "sql" in sql: 625 | # sql = sql["sql"] 626 | if "nested_label" in sql: 627 | nested_label = sql["nested_label"] 628 | nested_sql = sql["nested_sql"] 629 | sql = sql["sql"] 630 | elif "sql" in sql: 631 | sql = sql["sql"] 632 | for key in sql: 633 | if key not in KW_WITH_COL: 634 | continue 635 | for item in sql[key]: 636 | if isinstance(item,tuple) and len(item) == 3: 637 | if table["column_names"][item[2]][0] != -1: 638 | candidate_tables.add(table["column_names"][item[2]][0]) 639 | table_alias_dict,from_clause = self.gen_from(candidate_tables,table) 640 | ret = [] 641 | if "select" in sql: 642 | select_clause = self.gen_select(sql["select"],"select",table,table_alias_dict) 643 | if len(select_clause) > 0: 644 | ret.append(select_clause) 645 | else: 646 | print("select not found:{}".format(parent_sql)) 647 | else: 648 | print("select not found:{}".format(parent_sql)) 649 | if len(from_clause) > 0: 650 | ret.append(from_clause) 651 | if "where" in sql: 652 | where_clause = 
self.gen_where(sql["where"],table,table_alias_dict) 653 | if len(where_clause) > 0: 654 | ret.append(where_clause) 655 | if "groupBy" in sql: ## DEBUG-ed order 656 | groupby_clause = self.gen_group_by(sql["groupBy"],"group by",table,table_alias_dict) 657 | if len(groupby_clause) > 0: 658 | ret.append(groupby_clause) 659 | if "orderBy" in sql: 660 | orderby_clause = self.gen_orderby(sql["orderBy"],table,table_alias_dict) 661 | if len(orderby_clause) > 0: 662 | ret.append(orderby_clause) 663 | if "having" in sql: 664 | having_clause = self.gen_having(sql["having"],table,table_alias_dict) 665 | if len(having_clause) > 0: 666 | ret.append(having_clause) 667 | if len(nested_label) > 0: 668 | nested_clause = "{} {}".format(nested_label,self.gen_sql(nested_sql,table)) 669 | if len(nested_clause) > 0: 670 | ret.append(nested_clause) 671 | return " ".join(ret) 672 | 673 | def check_acc(self, pred_sql, gt_sql): 674 | pass 675 | --------------------------------------------------------------------------------