├── LICENSE ├── README.md ├── download_glove.sh ├── extract_vocab.py ├── requirements.txt ├── sqlnet ├── lib │ └── dbengine.py ├── model │ ├── modules │ │ ├── aggregator_predict.py │ │ ├── net_utils.py │ │ ├── selection_predict.py │ │ ├── seq2sql_condition_predict.py │ │ ├── sqlnet_condition_predict.py │ │ └── word_embedding.py │ ├── seq2sql.py │ └── sqlnet.py └── utils.py ├── test.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Xiaojun Xu, Chang Liu and Dawn Song 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SQLNet 2 | 3 | This is the code for [this](https://youtu.be/Rw3ewEXOKC8) video on YouTube by Siraj Raval. This repo provides an implementation of the SQLNet and Seq2SQL neural networks for predicting SQL queries on the [WikiSQL dataset](https://github.com/salesforce/WikiSQL). The paper is available [here](https://arxiv.org/abs/1711.04436). 4 | 5 | ## Citation 6 | 7 | > Xiaojun Xu, Chang Liu, Dawn Song. 2017. SQLNet: Generating Structured Queries from Natural Language Without Reinforcement Learning. 8 | 9 | ## BibTeX 10 | 11 | ``` 12 | @article{xu2017sqlnet, 13 | title={SQLNet: Generating Structured Queries From Natural Language Without Reinforcement Learning}, 14 | author={Xu, Xiaojun and Liu, Chang and Song, Dawn}, 15 | journal={arXiv preprint arXiv:1711.04436}, 16 | year={2017} 17 | } 18 | ``` 19 | 20 | ## Installation 21 | The data is in `data.tar.bz2`. Unpack the data by running 22 | ```bash 23 | tar -xjvf data.tar.bz2 24 | ``` 25 | 26 | The code is written using PyTorch in Python 2.7. Check [here](http://pytorch.org/) to install PyTorch. You can install the other dependencies by running 27 | ```bash 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Downloading the GloVe embedding.
32 | Download the pretrained GloVe embedding from [here](https://github.com/stanfordnlp/GloVe) using 33 | ```bash 34 | bash download_glove.sh 35 | ``` 36 | 37 | ## Extract the GloVe embedding for training. 38 | Run the following command to process the pretrained GloVe embedding for training the word embedding (this writes `glove/word2idx.json` and `glove/usedwordemb.npy`): 39 | ```bash 40 | python extract_vocab.py 41 | ``` 42 | 43 | ## Train 44 | The training script is `train.py`. To see the detailed parameters for running it: 45 | ```bash 46 | python train.py -h 47 | ``` 48 | 49 | Some typical usages are listed below: 50 | 51 | Train a SQLNet model with column attention: 52 | ```bash 53 | python train.py --ca 54 | ``` 55 | 56 | Train a SQLNet model with column attention and trainable embedding (requires pretraining without the trainable embedding, i.e., running the command above first): 57 | ```bash 58 | python train.py --ca --train_emb 59 | ``` 60 | 61 | Pretrain a [Seq2SQL model](https://arxiv.org/abs/1709.00103) on the re-split dataset: 62 | ```bash 63 | python train.py --baseline --dataset 1 64 | ``` 65 | 66 | Train a Seq2SQL model with Reinforcement Learning after pretraining: 67 | ```bash 68 | python train.py --baseline --dataset 1 --rl 69 | ``` 70 | 71 | ## Test 72 | The script `test.py` evaluates a trained model on the dev and test splits. The parameters for evaluation are roughly the same as those used for training. For example, the commands for evaluating the models trained above are: 73 | 74 | Test a trained SQLNet model with column attention: 75 | ```bash 76 | python test.py --ca 77 | ``` 78 | 79 | Test a trained SQLNet model with column attention and trainable embedding: 80 | ```bash 81 | python test.py --ca --train_emb 82 | ``` 83 | 84 | Test a trained [Seq2SQL model](https://arxiv.org/abs/1709.00103) without RL on the re-split dataset: 85 | ```bash 86 | python test.py --baseline --dataset 1 87 | ``` 88 | 89 | Test a trained Seq2SQL model with Reinforcement Learning: 90 | ```bash 91 | python test.py --baseline --dataset 1 --rl 92 | ``` 93 | 94 | ## Credits 95 | 96 | Credits for this code go to [xiaojunxu](https://github.com/xiaojunxu/SQLNet). I've merely created a wrapper to get people started. 97 | -------------------------------------------------------------------------------- /download_glove.sh: -------------------------------------------------------------------------------- 1 | if [[ ! -d glove ]]; then 2 | mkdir glove 3 | fi 4 | 5 | cd glove 6 | wget http://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip 7 | unzip glove.42B.300d.zip 8 | -------------------------------------------------------------------------------- /extract_vocab.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from sqlnet.utils import * 4 | import numpy as np 5 | import datetime 6 | 7 | LOCAL_TEST=False 8 | 9 | 10 | if LOCAL_TEST: 11 | N_word=100 12 | B_word=6 13 | USE_SMALL=True 14 | else: 15 | N_word=300 16 | B_word=42 17 | USE_SMALL=False 18 | 19 | sql_data, table_data, val_sql_data, val_table_data,\ 20 | test_sql_data, test_table_data, TRAIN_DB, DEV_DB, TEST_DB = \ 21 | load_dataset(0, use_small=USE_SMALL) 22 | word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), 23 | use_small=USE_SMALL) 24 | print "Length of word vocabulary: %d"%len(word_emb) 25 | 26 | word_to_idx = {'<UNK>':0, '<BEG>':1, '<END>':2} 27 | word_num = 3 28 | embs = [np.zeros(N_word,dtype=np.float32) for _ in range(word_num)] 29 | 30 | def check_and_add(tok): 31 | #Check if the tok is in the vocab. If not, add it.
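    # Note: the three entries created above (indices 0, 1 and 2) are reserved for
    # the special tokens, and their embeddings are initialised as zero vectors in
    # `embs`; every in-vocabulary token found below is appended together with its
    # pretrained GloVe vector, so ordinary words receive indices starting from 3.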
32 | global word_num 33 | if tok not in word_to_idx and tok in word_emb: 34 | word_to_idx[tok] = word_num 35 | word_num += 1 36 | embs.append(word_emb[tok]) 37 | 38 | for sql in sql_data: 39 | for tok in sql['question_tok']: 40 | check_and_add(tok) 41 | for tab in table_data.values(): 42 | for col in tab['header_tok']: 43 | for tok in col: 44 | check_and_add(tok) 45 | for sql in val_sql_data: 46 | for tok in sql['question_tok']: 47 | check_and_add(tok) 48 | for tab in val_table_data.values(): 49 | for col in tab['header_tok']: 50 | for tok in col: 51 | check_and_add(tok) 52 | for sql in test_sql_data: 53 | for tok in sql['question_tok']: 54 | check_and_add(tok) 55 | for tab in test_table_data.values(): 56 | for col in tab['header_tok']: 57 | for tok in col: 58 | check_and_add(tok) 59 | 60 | print "Length of used word vocab: %s"%len(word_to_idx) 61 | 62 | emb_array = np.stack(embs, axis=0) 63 | with open('glove/word2idx.json', 'w') as outf: 64 | json.dump(word_to_idx, outf) 65 | np.save(open('glove/usedwordemb.npy', 'w'), emb_array) 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Babel==2.5.1 2 | docopt==0.6.2 3 | et-xmlfile==1.0.1 4 | jdcal==1.3 5 | odfpy==1.3.5 6 | olefile==0.44 7 | openpyxl==2.4.9 8 | pkg-resources==0.0.0 9 | pytz==2017.3 10 | records==0.5.2 11 | SQLAlchemy==1.1.14 12 | tablib==0.12.1 13 | unicodecsv==0.14.1 14 | xlrd==1.1.0 15 | xlwt==1.3.0 16 | -------------------------------------------------------------------------------- /sqlnet/lib/dbengine.py: -------------------------------------------------------------------------------- 1 | import records 2 | import re 3 | from babel.numbers import parse_decimal, NumberFormatError 4 | 5 | 6 | schema_re = re.compile(r'\((.+)\)') 7 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') 8 | 9 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 10 | cond_ops = ['=', '>', '<', 'OP'] 11 | 12 | class DBEngine: 13 | 14 | def __init__(self, fdb): 15 | #fdb = 'data/test.db' 16 | self.db = records.Database('sqlite:///{}'.format(fdb)) 17 | 18 | def execute_query(self, table_id, query, *args, **kwargs): 19 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs) 20 | 21 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True): 22 | if not table_id.startswith('table'): 23 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 24 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','') 25 | schema_str = schema_re.findall(table_info)[0] 26 | schema = {} 27 | for tup in schema_str.split(', '): 28 | c, t = tup.split() 29 | schema[c] = t 30 | select = 'col{}'.format(select_index) 31 | agg = agg_ops[aggregation_index] 32 | if agg: 33 | select = '{}({})'.format(agg, select) 34 | where_clause = [] 35 | where_map = {} 36 | for col_index, op, val in conditions: 37 | if lower and (isinstance(val, str) or isinstance(val, unicode)): 38 | val = val.lower() 39 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 40 | try: 41 | val = float(parse_decimal(val)) 42 | except NumberFormatError as e: 43 | val = float(num_re.findall(val)[0]) 44 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index)) 45 | where_map['col{}'.format(col_index)] = val 46 | where_str = '' 47 | if where_clause: 48 | where_str = 'WHERE ' + ' AND 
'.join(where_clause) 49 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 50 | #print query 51 | out = self.db.query(query, **where_map) 52 | return [o.result for o in out] 53 | -------------------------------------------------------------------------------- /sqlnet/model/modules/aggregator_predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | 10 | 11 | class AggPredictor(nn.Module): 12 | def __init__(self, N_word, N_h, N_depth, use_ca): 13 | super(AggPredictor, self).__init__() 14 | self.use_ca = use_ca 15 | 16 | self.agg_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 17 | num_layers=N_depth, batch_first=True, 18 | dropout=0.3, bidirectional=True) 19 | if use_ca: 20 | print "Using column attention on aggregator predicting" 21 | self.agg_col_name_enc = nn.LSTM(input_size=N_word, 22 | hidden_size=N_h/2, num_layers=N_depth, 23 | batch_first=True, dropout=0.3, bidirectional=True) 24 | self.agg_att = nn.Linear(N_h, N_h) 25 | else: 26 | print "Not using column attention on aggregator predicting" 27 | self.agg_att = nn.Linear(N_h, 1) 28 | self.agg_out = nn.Sequential(nn.Linear(N_h, N_h), 29 | nn.Tanh(), nn.Linear(N_h, 6)) 30 | self.softmax = nn.Softmax() 31 | 32 | def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None, 33 | col_len=None, col_num=None, gt_sel=None): 34 | B = len(x_emb_var) 35 | max_x_len = max(x_len) 36 | 37 | h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len) 38 | if self.use_ca: 39 | e_col, _ = col_name_encode(col_inp_var, col_name_len, 40 | col_len, self.agg_col_name_enc) 41 | chosen_sel_idx = torch.LongTensor(gt_sel) 42 | aux_range = torch.LongTensor(range(len(gt_sel))) 43 | if x_emb_var.is_cuda: 44 | chosen_sel_idx = chosen_sel_idx.cuda() 45 | aux_range = aux_range.cuda() 46 | chosen_e_col = e_col[aux_range, chosen_sel_idx] 47 | att_val = torch.bmm(self.agg_att(h_enc), 48 | chosen_e_col.unsqueeze(2)).squeeze() 49 | else: 50 | att_val = self.agg_att(h_enc).squeeze() 51 | 52 | for idx, num in enumerate(x_len): 53 | if num < max_x_len: 54 | att_val[idx, num:] = -100 55 | att = self.softmax(att_val) 56 | 57 | K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1) 58 | agg_score = self.agg_out(K_agg) 59 | return agg_score 60 | -------------------------------------------------------------------------------- /sqlnet/model/modules/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.autograd import Variable 5 | 6 | def run_lstm(lstm, inp, inp_len, hidden=None): 7 | # Run the LSTM using packed sequence. 8 | # This requires to first sort the input according to its length. 
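    # A small worked example (hypothetical lengths): if inp_len = [2, 5, 3], then
    # sort_perm = [1, 2, 0] (batch indices ordered by decreasing length) and
    # sort_perm_inv = [2, 0, 1], so the outputs computed in sorted order below are
    # scattered back to the original batch order before being returned.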
9 | sort_perm = np.array(sorted(range(len(inp_len)), 10 | key=lambda k:inp_len[k], reverse=True)) 11 | sort_inp_len = inp_len[sort_perm] 12 | sort_perm_inv = np.argsort(sort_perm) 13 | if inp.is_cuda: 14 | sort_perm = torch.LongTensor(sort_perm).cuda() 15 | sort_perm_inv = torch.LongTensor(sort_perm_inv).cuda() 16 | 17 | lstm_inp = nn.utils.rnn.pack_padded_sequence(inp[sort_perm], 18 | sort_inp_len, batch_first=True) 19 | if hidden is None: 20 | lstm_hidden = None 21 | else: 22 | lstm_hidden = (hidden[0][:, sort_perm], hidden[1][:, sort_perm]) 23 | 24 | sort_ret_s, sort_ret_h = lstm(lstm_inp, lstm_hidden) 25 | ret_s = nn.utils.rnn.pad_packed_sequence( 26 | sort_ret_s, batch_first=True)[0][sort_perm_inv] 27 | ret_h = (sort_ret_h[0][:, sort_perm_inv], sort_ret_h[1][:, sort_perm_inv]) 28 | return ret_s, ret_h 29 | 30 | 31 | def col_name_encode(name_inp_var, name_len, col_len, enc_lstm): 32 | #Encode the columns. 33 | #The embedding of a column name is the last state of its LSTM output. 34 | name_hidden, _ = run_lstm(enc_lstm, name_inp_var, name_len) 35 | name_out = name_hidden[tuple(range(len(name_len))), name_len-1] 36 | ret = torch.FloatTensor( 37 | len(col_len), max(col_len), name_out.size()[1]).zero_() 38 | if name_out.is_cuda: 39 | ret = ret.cuda() 40 | 41 | st = 0 42 | for idx, cur_len in enumerate(col_len): 43 | ret[idx, :cur_len] = name_out.data[st:st+cur_len] 44 | st += cur_len 45 | ret_var = Variable(ret) 46 | 47 | return ret_var, col_len 48 | -------------------------------------------------------------------------------- /sqlnet/model/modules/selection_predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | class SelPredictor(nn.Module): 10 | def __init__(self, N_word, N_h, N_depth, max_tok_num, use_ca): 11 | super(SelPredictor, self).__init__() 12 | self.use_ca = use_ca 13 | self.max_tok_num = max_tok_num 14 | self.sel_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 15 | num_layers=N_depth, batch_first=True, 16 | dropout=0.3, bidirectional=True) 17 | if use_ca: 18 | print "Using column attention on selection predicting" 19 | self.sel_att = nn.Linear(N_h, N_h) 20 | else: 21 | print "Not using column attention on selection predicting" 22 | self.sel_att = nn.Linear(N_h, 1) 23 | self.sel_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 24 | num_layers=N_depth, batch_first=True, 25 | dropout=0.3, bidirectional=True) 26 | self.sel_out_K = nn.Linear(N_h, N_h) 27 | self.sel_out_col = nn.Linear(N_h, N_h) 28 | self.sel_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) 29 | self.softmax = nn.Softmax() 30 | 31 | 32 | def forward(self, x_emb_var, x_len, col_inp_var, 33 | col_name_len, col_len, col_num): 34 | B = len(x_emb_var) 35 | max_x_len = max(x_len) 36 | 37 | e_col, _ = col_name_encode(col_inp_var, col_name_len, 38 | col_len, self.sel_col_name_enc) 39 | 40 | if self.use_ca: 41 | h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len) 42 | att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2)) 43 | for idx, num in enumerate(x_len): 44 | if num < max_x_len: 45 | att_val[idx, :, num:] = -100 46 | att = self.softmax(att_val.view((-1, max_x_len))).view( 47 | B, -1, max_x_len) 48 | K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2) 49 | else: 50 | h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len) 51 | 
att_val = self.sel_att(h_enc).squeeze() 52 | for idx, num in enumerate(x_len): 53 | if num < max_x_len: 54 | att_val[idx, num:] = -100 55 | att = self.softmax(att_val) 56 | K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1) 57 | K_sel_expand=K_sel.unsqueeze(1) 58 | 59 | sel_score = self.sel_out( self.sel_out_K(K_sel_expand) + \ 60 | self.sel_out_col(e_col) ).squeeze() 61 | max_col_num = max(col_num) 62 | for idx, num in enumerate(col_num): 63 | if num < max_col_num: 64 | sel_score[idx, num:] = -100 65 | 66 | return sel_score 67 | -------------------------------------------------------------------------------- /sqlnet/model/modules/seq2sql_condition_predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from net_utils import run_lstm 8 | 9 | class Seq2SQLCondPredictor(nn.Module): 10 | def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, gpu): 11 | super(Seq2SQLCondPredictor, self).__init__() 12 | print "Seq2SQL where prediction" 13 | self.N_h = N_h 14 | self.max_tok_num = max_tok_num 15 | self.max_col_num = max_col_num 16 | self.gpu = gpu 17 | 18 | self.cond_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 19 | num_layers=N_depth, batch_first=True, 20 | dropout=0.3, bidirectional=True) 21 | self.cond_decoder = nn.LSTM(input_size=self.max_tok_num, 22 | hidden_size=N_h, num_layers=N_depth, 23 | batch_first=True, dropout=0.3) 24 | 25 | self.cond_out_g = nn.Linear(N_h, N_h) 26 | self.cond_out_h = nn.Linear(N_h, N_h) 27 | self.cond_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) 28 | 29 | self.softmax = nn.Softmax() 30 | 31 | 32 | def gen_gt_batch(self, tok_seq, gen_inp=True): 33 | # If gen_inp: generate the input token sequence (removing ) 34 | # Otherwise: generate the output token sequence (removing ) 35 | B = len(tok_seq) 36 | ret_len = np.array([len(one_tok_seq)-1 for one_tok_seq in tok_seq]) 37 | max_len = max(ret_len) 38 | ret_array = np.zeros((B, max_len, self.max_tok_num), dtype=np.float32) 39 | for b, one_tok_seq in enumerate(tok_seq): 40 | out_one_tok_seq = one_tok_seq[:-1] if gen_inp else one_tok_seq[1:] 41 | for t, tok_id in enumerate(out_one_tok_seq): 42 | ret_array[b, t, tok_id] = 1 43 | 44 | ret_inp = torch.from_numpy(ret_array) 45 | if self.gpu: 46 | ret_inp = ret_inp.cuda() 47 | ret_inp_var = Variable(ret_inp) #[B, max_len, max_tok_num] 48 | 49 | return ret_inp_var, ret_len 50 | 51 | 52 | def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, 53 | col_num, gt_where, gt_cond, reinforce): 54 | max_x_len = max(x_len) 55 | B = len(x_len) 56 | 57 | h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len) 58 | decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]),dim=2) 59 | for hid in hidden) 60 | if gt_where is not None: 61 | gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True) 62 | g_s, _ = run_lstm(self.cond_decoder, 63 | gt_tok_seq, gt_tok_len, decoder_hidden) 64 | 65 | h_enc_expand = h_enc.unsqueeze(1) 66 | g_s_expand = g_s.unsqueeze(2) 67 | cond_score = self.cond_out( self.cond_out_h(h_enc_expand) + 68 | self.cond_out_g(g_s_expand) ).squeeze() 69 | for idx, num in enumerate(x_len): 70 | if num < max_x_len: 71 | cond_score[idx, :, num:] = -100 72 | else: 73 | h_enc_expand = h_enc.unsqueeze(1) 74 | scores = [] 75 | choices = [] 76 | done_set = set() 77 | 78 | t = 0 79 | init_inp = np.zeros((B, 1, self.max_tok_num), 
dtype=np.float32) 80 | init_inp[:,0,7] = 1 #Set the token 81 | if self.gpu: 82 | cur_inp = Variable(torch.from_numpy(init_inp).cuda()) 83 | else: 84 | cur_inp = Variable(torch.from_numpy(init_inp)) 85 | cur_h = decoder_hidden 86 | while len(done_set) < B and t < 100: 87 | g_s, cur_h = self.cond_decoder(cur_inp, cur_h) 88 | g_s_expand = g_s.unsqueeze(2) 89 | 90 | cur_cond_score = self.cond_out(self.cond_out_h(h_enc_expand) + 91 | self.cond_out_g(g_s_expand)).squeeze() 92 | for b, num in enumerate(x_len): 93 | if num < max_x_len: 94 | cur_cond_score[b, num:] = -100 95 | scores.append(cur_cond_score) 96 | 97 | if not reinforce: 98 | _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1) 99 | ans_tok_var = ans_tok_var.unsqueeze(1) 100 | else: 101 | ans_tok_var = self.softmax(cur_cond_score).multinomial() 102 | choices.append(ans_tok_var) 103 | ans_tok = ans_tok_var.data.cpu() 104 | if self.gpu: #To one-hot 105 | cur_inp = Variable(torch.zeros( 106 | B, self.max_tok_num).scatter_(1, ans_tok, 1).cuda()) 107 | else: 108 | cur_inp = Variable(torch.zeros( 109 | B, self.max_tok_num).scatter_(1, ans_tok, 1)) 110 | cur_inp = cur_inp.unsqueeze(1) 111 | 112 | for idx, tok in enumerate(ans_tok.squeeze()): 113 | if tok == 1: #Find the token 114 | done_set.add(idx) 115 | t += 1 116 | 117 | cond_score = torch.stack(scores, 1) 118 | 119 | if reinforce: 120 | return cond_score, choices 121 | else: 122 | return cond_score 123 | -------------------------------------------------------------------------------- /sqlnet/model/modules/sqlnet_condition_predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from net_utils import run_lstm, col_name_encode 8 | 9 | class SQLNetCondPredictor(nn.Module): 10 | def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu): 11 | super(SQLNetCondPredictor, self).__init__() 12 | self.N_h = N_h 13 | self.max_tok_num = max_tok_num 14 | self.max_col_num = max_col_num 15 | self.gpu = gpu 16 | self.use_ca = use_ca 17 | 18 | self.cond_num_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 19 | num_layers=N_depth, batch_first=True, 20 | dropout=0.3, bidirectional=True) 21 | self.cond_num_att = nn.Linear(N_h, 1) 22 | self.cond_num_out = nn.Sequential(nn.Linear(N_h, N_h), 23 | nn.Tanh(), nn.Linear(N_h, 5)) 24 | self.cond_num_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 25 | num_layers=N_depth, batch_first=True, 26 | dropout=0.3, bidirectional=True) 27 | self.cond_num_col_att = nn.Linear(N_h, 1) 28 | self.cond_num_col2hid1 = nn.Linear(N_h, 2*N_h) 29 | self.cond_num_col2hid2 = nn.Linear(N_h, 2*N_h) 30 | 31 | self.cond_col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 32 | num_layers=N_depth, batch_first=True, 33 | dropout=0.3, bidirectional=True) 34 | if use_ca: 35 | print "Using column attention on where predicting" 36 | self.cond_col_att = nn.Linear(N_h, N_h) 37 | else: 38 | print "Not using column attention on where predicting" 39 | self.cond_col_att = nn.Linear(N_h, 1) 40 | self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 41 | num_layers=N_depth, batch_first=True, 42 | dropout=0.3, bidirectional=True) 43 | self.cond_col_out_K = nn.Linear(N_h, N_h) 44 | self.cond_col_out_col = nn.Linear(N_h, N_h) 45 | self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1)) 46 | 47 | self.cond_op_lstm = nn.LSTM(input_size=N_word, 
hidden_size=N_h/2, 48 | num_layers=N_depth, batch_first=True, 49 | dropout=0.3, bidirectional=True) 50 | if use_ca: 51 | self.cond_op_att = nn.Linear(N_h, N_h) 52 | else: 53 | self.cond_op_att = nn.Linear(N_h, 1) 54 | self.cond_op_out_K = nn.Linear(N_h, N_h) 55 | self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 56 | num_layers=N_depth, batch_first=True, 57 | dropout=0.3, bidirectional=True) 58 | self.cond_op_out_col = nn.Linear(N_h, N_h) 59 | self.cond_op_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(), 60 | nn.Linear(N_h, 3)) 61 | 62 | self.cond_str_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 63 | num_layers=N_depth, batch_first=True, 64 | dropout=0.3, bidirectional=True) 65 | self.cond_str_decoder = nn.LSTM(input_size=self.max_tok_num, 66 | hidden_size=N_h, num_layers=N_depth, 67 | batch_first=True, dropout=0.3) 68 | self.cond_str_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2, 69 | num_layers=N_depth, batch_first=True, 70 | dropout=0.3, bidirectional=True) 71 | self.cond_str_out_g = nn.Linear(N_h, N_h) 72 | self.cond_str_out_h = nn.Linear(N_h, N_h) 73 | self.cond_str_out_col = nn.Linear(N_h, N_h) 74 | self.cond_str_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1)) 75 | 76 | self.softmax = nn.Softmax() 77 | 78 | 79 | def gen_gt_batch(self, split_tok_seq): 80 | B = len(split_tok_seq) 81 | max_len = max([max([len(tok) for tok in tok_seq]+[0]) for 82 | tok_seq in split_tok_seq]) - 1 # The max seq len in the batch. 83 | if max_len < 1: 84 | max_len = 1 85 | ret_array = np.zeros(( 86 | B, 4, max_len, self.max_tok_num), dtype=np.float32) 87 | ret_len = np.zeros((B, 4)) 88 | for b, tok_seq in enumerate(split_tok_seq): 89 | idx = 0 90 | for idx, one_tok_seq in enumerate(tok_seq): 91 | out_one_tok_seq = one_tok_seq[:-1] 92 | ret_len[b, idx] = len(out_one_tok_seq) 93 | for t, tok_id in enumerate(out_one_tok_seq): 94 | ret_array[b, idx, t, tok_id] = 1 95 | if idx < 3: 96 | ret_array[b, idx+1:, 0, 1] = 1 97 | ret_len[b, idx+1:] = 1 98 | 99 | ret_inp = torch.from_numpy(ret_array) 100 | if self.gpu: 101 | ret_inp = ret_inp.cuda() 102 | ret_inp_var = Variable(ret_inp) 103 | 104 | return ret_inp_var, ret_len #[B, IDX, max_len, max_tok_num] 105 | 106 | 107 | def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, 108 | col_len, col_num, gt_where, gt_cond, reinforce): 109 | max_x_len = max(x_len) 110 | B = len(x_len) 111 | if reinforce: 112 | raise NotImplementedError('Our model doesn\'t have RL') 113 | 114 | # Predict the number of conditions 115 | # First use column embeddings to calculate the initial hidden unit 116 | # Then run the LSTM and predict condition number. 
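        # Shape notes (assuming the default N_depth=2 bidirectional LSTMs): the
        # attention-pooled column vector K_num_col has shape (B, N_h), and
        # cond_num_col2hid1/2 project and reshape it to (num_layers*2, B, N_h/2)
        # = (4, B, N_h/2) to serve as the initial (h, c) of cond_num_lstm;
        # cond_num_score comes out as (B, 5), one score for each possible number
        # of conditions (0 to 4).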
117 | e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, 118 | col_len, self.cond_num_name_enc) 119 | num_col_att_val = self.cond_num_col_att(e_num_col).squeeze() 120 | for idx, num in enumerate(col_num): 121 | if num < max(col_num): 122 | num_col_att_val[idx, num:] = -100 123 | num_col_att = self.softmax(num_col_att_val) 124 | K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1) 125 | cond_num_h1 = self.cond_num_col2hid1(K_num_col).view( 126 | B, 4, self.N_h/2).transpose(0, 1).contiguous() 127 | cond_num_h2 = self.cond_num_col2hid2(K_num_col).view( 128 | B, 4, self.N_h/2).transpose(0, 1).contiguous() 129 | 130 | h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len, 131 | hidden=(cond_num_h1, cond_num_h2)) 132 | 133 | num_att_val = self.cond_num_att(h_num_enc).squeeze() 134 | 135 | for idx, num in enumerate(x_len): 136 | if num < max_x_len: 137 | num_att_val[idx, num:] = -100 138 | num_att = self.softmax(num_att_val) 139 | 140 | K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as( 141 | h_num_enc)).sum(1) 142 | cond_num_score = self.cond_num_out(K_cond_num) 143 | 144 | #Predict the columns of conditions 145 | e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, 146 | self.cond_col_name_enc) 147 | 148 | h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len) 149 | if self.use_ca: 150 | col_att_val = torch.bmm(e_cond_col, 151 | self.cond_col_att(h_col_enc).transpose(1, 2)) 152 | for idx, num in enumerate(x_len): 153 | if num < max_x_len: 154 | col_att_val[idx, :, num:] = -100 155 | col_att = self.softmax(col_att_val.view( 156 | (-1, max_x_len))).view(B, -1, max_x_len) 157 | K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2) 158 | else: 159 | col_att_val = self.cond_col_att(h_col_enc).squeeze() 160 | for idx, num in enumerate(x_len): 161 | if num < max_x_len: 162 | col_att_val[idx, num:] = -100 163 | col_att = self.softmax(col_att_val) 164 | K_cond_col = (h_col_enc * 165 | col_att_val.unsqueeze(2)).sum(1).unsqueeze(1) 166 | 167 | cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) + 168 | self.cond_col_out_col(e_cond_col)).squeeze() 169 | max_col_num = max(col_num) 170 | for b, num in enumerate(col_num): 171 | if num < max_col_num: 172 | cond_col_score[b, num:] = -100 173 | 174 | #Predict the operator of conditions 175 | chosen_col_gt = [] 176 | if gt_cond is None: 177 | cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1) 178 | col_scores = cond_col_score.data.cpu().numpy() 179 | chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]]) 180 | for b in range(len(cond_nums))] 181 | else: 182 | chosen_col_gt = [ [x[0] for x in one_gt_cond] for 183 | one_gt_cond in gt_cond] 184 | 185 | e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, 186 | col_len, self.cond_op_name_enc) 187 | col_emb = [] 188 | for b in range(B): 189 | cur_col_emb = torch.stack([e_cond_col[b, x] 190 | for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] * 191 | (4 - len(chosen_col_gt[b]))) # Pad the columns to maximum (4) 192 | col_emb.append(cur_col_emb) 193 | col_emb = torch.stack(col_emb) 194 | 195 | h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len) 196 | if self.use_ca: 197 | op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1), 198 | col_emb.unsqueeze(3)).squeeze() 199 | for idx, num in enumerate(x_len): 200 | if num < max_x_len: 201 | op_att_val[idx, :, num:] = -100 202 | op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1) 203 | K_cond_op = (h_op_enc.unsqueeze(1) * 
op_att.unsqueeze(3)).sum(2) 204 | else: 205 | op_att_val = self.cond_op_att(h_op_enc).squeeze() 206 | for idx, num in enumerate(x_len): 207 | if num < max_x_len: 208 | op_att_val[idx, num:] = -100 209 | op_att = self.softmax(op_att_val) 210 | K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1) 211 | 212 | cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) + 213 | self.cond_op_out_col(col_emb)).squeeze() 214 | 215 | #Predict the string of conditions 216 | h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len) 217 | e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, 218 | col_len, self.cond_str_name_enc) 219 | col_emb = [] 220 | for b in range(B): 221 | cur_col_emb = torch.stack([e_cond_col[b, x] 222 | for x in chosen_col_gt[b]] + 223 | [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b]))) 224 | col_emb.append(cur_col_emb) 225 | col_emb = torch.stack(col_emb) 226 | 227 | if gt_where is not None: 228 | gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where) 229 | g_str_s_flat, _ = self.cond_str_decoder( 230 | gt_tok_seq.view(B*4, -1, self.max_tok_num)) 231 | g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h) 232 | 233 | h_ext = h_str_enc.unsqueeze(1).unsqueeze(1) 234 | g_ext = g_str_s.unsqueeze(3) 235 | col_ext = col_emb.unsqueeze(2).unsqueeze(2) 236 | 237 | cond_str_score = self.cond_str_out( 238 | self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) + 239 | self.cond_str_out_col(col_ext)).squeeze() 240 | for b, num in enumerate(x_len): 241 | if num < max_x_len: 242 | cond_str_score[b, :, :, num:] = -100 243 | else: 244 | h_ext = h_str_enc.unsqueeze(1).unsqueeze(1) 245 | col_ext = col_emb.unsqueeze(2).unsqueeze(2) 246 | scores = [] 247 | 248 | t = 0 249 | init_inp = np.zeros((B*4, 1, self.max_tok_num), dtype=np.float32) 250 | init_inp[:,0,0] = 1 #Set the token 251 | if self.gpu: 252 | cur_inp = Variable(torch.from_numpy(init_inp).cuda()) 253 | else: 254 | cur_inp = Variable(torch.from_numpy(init_inp)) 255 | cur_h = None 256 | while t < 50: 257 | if cur_h: 258 | g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h) 259 | else: 260 | g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp) 261 | g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h) 262 | g_ext = g_str_s.unsqueeze(3) 263 | 264 | cur_cond_str_score = self.cond_str_out( 265 | self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) 266 | + self.cond_str_out_col(col_ext)).squeeze() 267 | for b, num in enumerate(x_len): 268 | if num < max_x_len: 269 | cur_cond_str_score[b, :, num:] = -100 270 | scores.append(cur_cond_str_score) 271 | 272 | _, ans_tok_var = cur_cond_str_score.view(B*4, max_x_len).max(1) 273 | ans_tok = ans_tok_var.data.cpu() 274 | data = torch.zeros(B*4, self.max_tok_num).scatter_( 275 | 1, ans_tok.unsqueeze(1), 1) 276 | if self.gpu: #To one-hot 277 | cur_inp = Variable(data.cuda()) 278 | else: 279 | cur_inp = Variable(data) 280 | cur_inp = cur_inp.unsqueeze(1) 281 | 282 | t += 1 283 | 284 | cond_str_score = torch.stack(scores, 2) 285 | for b, num in enumerate(x_len): 286 | if num < max_x_len: 287 | cond_str_score[b, :, :, num:] = -100 #[B, IDX, T, TOK_NUM] 288 | 289 | cond_score = (cond_num_score, 290 | cond_col_score, cond_op_score, cond_str_score) 291 | 292 | return cond_score 293 | -------------------------------------------------------------------------------- /sqlnet/model/modules/word_embedding.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as 
F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | class WordEmbedding(nn.Module): 9 | def __init__(self, word_emb, N_word, gpu, SQL_TOK, 10 | our_model, trainable=False): 11 | super(WordEmbedding, self).__init__() 12 | self.trainable = trainable 13 | self.N_word = N_word 14 | self.our_model = our_model 15 | self.gpu = gpu 16 | self.SQL_TOK = SQL_TOK 17 | 18 | if trainable: 19 | print "Using trainable embedding" 20 | self.w2i, word_emb_val = word_emb 21 | self.embedding = nn.Embedding(len(self.w2i), N_word) 22 | self.embedding.weight = nn.Parameter( 23 | torch.from_numpy(word_emb_val.astype(np.float32))) 24 | else: 25 | self.word_emb = word_emb 26 | print "Using fixed embedding" 27 | 28 | 29 | def gen_x_batch(self, q, col): 30 | B = len(q) 31 | val_embs = [] 32 | val_len = np.zeros(B, dtype=np.int64) 33 | for i, (one_q, one_col) in enumerate(zip(q, col)): 34 | if self.trainable: 35 | q_val = map(lambda x:self.w2i.get(x, 0), one_q) 36 | else: 37 | q_val = map(lambda x:self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32)), one_q) 38 | if self.our_model: 39 | if self.trainable: 40 | val_embs.append([1] + q_val + [2]) # and 41 | else: 42 | val_embs.append([np.zeros(self.N_word, dtype=np.float32)] + q_val + [np.zeros(self.N_word, dtype=np.float32)]) # and 43 | val_len[i] = 1 + len(q_val) + 1 44 | else: 45 | one_col_all = [x for toks in one_col for x in toks+[',']] 46 | if self.trainable: 47 | col_val = map(lambda x:self.w2i.get(x, 0), one_col_all) 48 | val_embs.append( [0 for _ in self.SQL_TOK] + col_val + [0] + q_val+ [0]) 49 | else: 50 | col_val = map(lambda x:self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32)), one_col_all) 51 | val_embs.append( [np.zeros(self.N_word, dtype=np.float32) for _ in self.SQL_TOK] + col_val + [np.zeros(self.N_word, dtype=np.float32)] + q_val+ [np.zeros(self.N_word, dtype=np.float32)]) 52 | val_len[i] = len(self.SQL_TOK) + len(col_val) + 1 + len(q_val) + 1 53 | max_len = max(val_len) 54 | 55 | if self.trainable: 56 | val_tok_array = np.zeros((B, max_len), dtype=np.int64) 57 | for i in range(B): 58 | for t in range(len(val_embs[i])): 59 | val_tok_array[i,t] = val_embs[i][t] 60 | val_tok = torch.from_numpy(val_tok_array) 61 | if self.gpu: 62 | val_tok = val_tok.cuda() 63 | val_tok_var = Variable(val_tok) 64 | val_inp_var = self.embedding(val_tok_var) 65 | else: 66 | val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32) 67 | for i in range(B): 68 | for t in range(len(val_embs[i])): 69 | val_emb_array[i,t,:] = val_embs[i][t] 70 | val_inp = torch.from_numpy(val_emb_array) 71 | if self.gpu: 72 | val_inp = val_inp.cuda() 73 | val_inp_var = Variable(val_inp) 74 | return val_inp_var, val_len 75 | 76 | def gen_col_batch(self, cols): 77 | ret = [] 78 | col_len = np.zeros(len(cols), dtype=np.int64) 79 | 80 | names = [] 81 | for b, one_cols in enumerate(cols): 82 | names = names + one_cols 83 | col_len[b] = len(one_cols) 84 | 85 | name_inp_var, name_len = self.str_list_to_batch(names) 86 | return name_inp_var, name_len, col_len 87 | 88 | def str_list_to_batch(self, str_list): 89 | B = len(str_list) 90 | 91 | val_embs = [] 92 | val_len = np.zeros(B, dtype=np.int64) 93 | for i, one_str in enumerate(str_list): 94 | if self.trainable: 95 | val = [self.w2i.get(x, 0) for x in one_str] 96 | else: 97 | val = [self.word_emb.get(x, np.zeros( 98 | self.N_word, dtype=np.float32)) for x in one_str] 99 | val_embs.append(val) 100 | val_len[i] = len(val) 101 | max_len = max(val_len) 102 | 103 | if self.trainable: 104 | val_tok_array = 
np.zeros((B, max_len), dtype=np.int64) 105 | for i in range(B): 106 | for t in range(len(val_embs[i])): 107 | val_tok_array[i,t] = val_embs[i][t] 108 | val_tok = torch.from_numpy(val_tok_array) 109 | if self.gpu: 110 | val_tok = val_tok.cuda() 111 | val_tok_var = Variable(val_tok) 112 | val_inp_var = self.embedding(val_tok_var) 113 | else: 114 | val_emb_array = np.zeros( 115 | (B, max_len, self.N_word), dtype=np.float32) 116 | for i in range(B): 117 | for t in range(len(val_embs[i])): 118 | val_emb_array[i,t,:] = val_embs[i][t] 119 | val_inp = torch.from_numpy(val_emb_array) 120 | if self.gpu: 121 | val_inp = val_inp.cuda() 122 | val_inp_var = Variable(val_inp) 123 | 124 | return val_inp_var, val_len 125 | -------------------------------------------------------------------------------- /sqlnet/model/seq2sql.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from modules.word_embedding import WordEmbedding 8 | from modules.aggregator_predict import AggPredictor 9 | from modules.selection_predict import SelPredictor 10 | from modules.seq2sql_condition_predict import Seq2SQLCondPredictor 11 | 12 | # This is a re-implementation based on the following paper: 13 | 14 | # Victor Zhong, Caiming Xiong, and Richard Socher. 2017. 15 | # Seq2SQL: Generating Structured Queries from Natural Language using 16 | # Reinforcement Learning. arXiv:1709.00103 17 | 18 | class Seq2SQL(nn.Module): 19 | def __init__(self, word_emb, N_word, N_h=100, N_depth=2, 20 | gpu=False, trainable_emb=False): 21 | super(Seq2SQL, self).__init__() 22 | self.trainable_emb = trainable_emb 23 | 24 | self.gpu = gpu 25 | self.N_h = N_h 26 | self.N_depth = N_depth 27 | 28 | self.max_col_num = 45 29 | self.max_tok_num = 200 30 | self.SQL_TOK = ['', '', 'WHERE', 'AND', 31 | 'EQL', 'GT', 'LT', ''] 32 | self.COND_OPS = ['EQL', 'GT', 'LT'] 33 | 34 | #Word embedding 35 | if trainable_emb: 36 | self.agg_embed_layer = WordEmbedding(word_emb, N_word, gpu, 37 | self.SQL_TOK, our_model=False, 38 | trainable=trainable_emb) 39 | self.sel_embed_layer = WordEmbedding(word_emb, N_word, gpu, 40 | self.SQL_TOK, our_model=False, 41 | trainable=trainable_emb) 42 | self.cond_embed_layer = WordEmbedding(word_emb, N_word, gpu, 43 | self.SQL_TOK, our_model=False, 44 | trainable=trainable_emb) 45 | else: 46 | self.embed_layer = WordEmbedding(word_emb, N_word, gpu, 47 | self.SQL_TOK, our_model=False, 48 | trainable=trainable_emb) 49 | 50 | #Predict aggregator 51 | self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=False) 52 | 53 | #Predict selected column 54 | self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, 55 | use_ca=False) 56 | 57 | #Predict number of cond 58 | self.cond_pred = Seq2SQLCondPredictor( 59 | N_word, N_h, N_depth, self.max_col_num, self.max_tok_num, gpu) 60 | 61 | 62 | self.CE = nn.CrossEntropyLoss() 63 | self.softmax = nn.Softmax() 64 | self.log_softmax = nn.LogSoftmax() 65 | self.bce_logit = nn.BCEWithLogitsLoss() 66 | if gpu: 67 | self.cuda() 68 | 69 | 70 | def generate_gt_where_seq(self, q, col, query): 71 | # data format 72 | # WHERE cond1_col cond1_op cond1 73 | # AND cond2_col cond2_op cond2 74 | # AND ... 
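        # A hypothetical example: if cur_query is
        #   ['SELECT', 'col1_tok', 'WHERE', 'col0_tok', 'EQL', 'france']
        # only the part from 'WHERE' onward is encoded, so cur_seq starts with the
        # index of the begin marker in all_toks, then the indices of 'WHERE',
        # 'col0_tok', 'EQL' and 'france' (0 when a token is not in all_toks), and
        # ends with the index of the end marker.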
75 | 76 | ret_seq = [] 77 | for cur_q, cur_col, cur_query in zip(q, col, query): 78 | connect_col = [tok for col_tok in cur_col for tok in col_tok+[',']] 79 | all_toks = self.SQL_TOK + connect_col + [None] + cur_q + [None] 80 | cur_seq = [all_toks.index('')] 81 | if 'WHERE' in cur_query: 82 | cur_where_query = cur_query[cur_query.index('WHERE'):] 83 | cur_seq = cur_seq + map(lambda tok:all_toks.index(tok) 84 | if tok in all_toks else 0, cur_where_query) 85 | cur_seq.append(all_toks.index('')) 86 | ret_seq.append(cur_seq) 87 | return ret_seq 88 | 89 | 90 | def forward(self, q, col, col_num, pred_entry, 91 | gt_where = None, gt_cond=None, reinforce=False, gt_sel=None): 92 | B = len(q) 93 | pred_agg, pred_sel, pred_cond = pred_entry 94 | 95 | agg_score = None 96 | sel_score = None 97 | cond_score = None 98 | 99 | if self.trainable_emb: 100 | if pred_agg: 101 | x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col) 102 | batch = self.agg_embed_layer.gen_col_batch(col) 103 | col_inp_var, col_name_len, col_len = batch 104 | max_x_len = max(x_len) 105 | agg_score = self.agg_pred(x_emb_var, x_len) 106 | 107 | if pred_sel: 108 | x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col) 109 | batch = self.sel_embed_layer.gen_col_batch(col) 110 | col_inp_var, col_name_len, col_len = batch 111 | max_x_len = max(x_len) 112 | sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, 113 | col_name_len, col_len, col_num) 114 | 115 | if pred_cond: 116 | x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col) 117 | batch = self.cond_embed_layer.gen_col_batch(col) 118 | col_inp_var, col_name_len, col_len = batch 119 | max_x_len = max(x_len) 120 | cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, 121 | col_name_len, col_len, col_num, 122 | gt_where, gt_cond, 123 | reinforce=reinforce) 124 | else: 125 | x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col) 126 | batch = self.embed_layer.gen_col_batch(col) 127 | col_inp_var, col_name_len, col_len = batch 128 | max_x_len = max(x_len) 129 | if pred_agg: 130 | agg_score = self.agg_pred(x_emb_var, x_len) 131 | 132 | if pred_sel: 133 | sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, 134 | col_name_len, col_len, col_num) 135 | 136 | if pred_cond: 137 | cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, 138 | col_name_len, col_len, col_num, 139 | gt_where, gt_cond, 140 | reinforce=reinforce) 141 | 142 | return (agg_score, sel_score, cond_score) 143 | 144 | def loss(self, score, truth_num, pred_entry, gt_where): 145 | pred_agg, pred_sel, pred_cond = pred_entry 146 | agg_score, sel_score, cond_score = score 147 | loss = 0 148 | if pred_agg: 149 | agg_truth = map(lambda x:x[0], truth_num) 150 | data = torch.from_numpy(np.array(agg_truth)) 151 | if self.gpu: 152 | agg_truth_var = Variable(data.cuda()) 153 | else: 154 | agg_truth_var = Variable(data) 155 | 156 | loss += self.CE(agg_score, agg_truth_var) 157 | 158 | if pred_sel: 159 | sel_truth = map(lambda x:x[1], truth_num) 160 | data = torch.from_numpy(np.array(sel_truth)) 161 | if self.gpu: 162 | sel_truth_var = Variable(data).cuda() 163 | else: 164 | sel_truth_var = Variable(data) 165 | 166 | loss += self.CE(sel_score, sel_truth_var) 167 | 168 | if pred_cond: 169 | for b in range(len(gt_where)): 170 | if self.gpu: 171 | cond_truth_var = Variable( 172 | torch.from_numpy(np.array(gt_where[b][1:])).cuda()) 173 | else: 174 | cond_truth_var = Variable( 175 | torch.from_numpy(np.array(gt_where[b][1:]))) 176 | cond_pred_score = cond_score[b, :len(gt_where[b])-1] 177 | 178 | loss += ( 
self.CE( 179 | cond_pred_score, cond_truth_var) / len(gt_where) ) 180 | 181 | return loss 182 | 183 | def reinforce_backward(self, score, rewards): 184 | agg_score, sel_score, cond_score = score 185 | 186 | cur_reward = rewards[:] 187 | eof = self.SQL_TOK.index('') 188 | for t in range(len(cond_score[1])): 189 | reward_inp = torch.FloatTensor(cur_reward).unsqueeze(1) 190 | if self.gpu: 191 | reward_inp = reward_inp.cuda() 192 | cond_score[1][t].reinforce(reward_inp) 193 | 194 | for b in range(len(rewards)): 195 | if cond_score[1][t][b].data.cpu().numpy()[0] == eof: 196 | cur_reward[b] = 0 197 | torch.autograd.backward(cond_score[1], [None for _ in cond_score[1]]) 198 | return 199 | 200 | def check_acc(self, vis_info, pred_queries, gt_queries, pred_entry): 201 | def pretty_print(vis_data): 202 | print 'question:', vis_data[0] 203 | print 'headers: (%s)'%(' || '.join(vis_data[1])) 204 | print 'query:', vis_data[2] 205 | 206 | def gen_cond_str(conds, header): 207 | if len(conds) == 0: 208 | return 'None' 209 | cond_str = [] 210 | for cond in conds: 211 | cond_str.append( 212 | header[cond[0]] + ' ' + self.COND_OPS[cond[1]] + \ 213 | ' ' + unicode(cond[2]).lower()) 214 | return 'WHERE ' + ' AND '.join(cond_str) 215 | 216 | pred_agg, pred_sel, pred_cond = pred_entry 217 | 218 | B = len(gt_queries) 219 | 220 | tot_err = agg_err = sel_err = cond_err = cond_num_err = \ 221 | cond_col_err = cond_op_err = cond_val_err = 0.0 222 | agg_ops = ['None', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 223 | for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)): 224 | good = True 225 | if pred_agg: 226 | agg_pred = pred_qry['agg'] 227 | agg_gt = gt_qry['agg'] 228 | if agg_pred != agg_gt: 229 | agg_err += 1 230 | good = False 231 | 232 | if pred_sel: 233 | sel_pred = pred_qry['sel'] 234 | sel_gt = gt_qry['sel'] 235 | if sel_pred != sel_gt: 236 | sel_err += 1 237 | good = False 238 | 239 | if pred_cond: 240 | cond_pred = pred_qry['conds'] 241 | cond_gt = gt_qry['conds'] 242 | flag = True 243 | if len(cond_pred) != len(cond_gt): 244 | flag = False 245 | cond_num_err += 1 246 | 247 | if flag and set( 248 | x[0] for x in cond_pred) != set(x[0] for x in cond_gt): 249 | flag = False 250 | cond_col_err += 1 251 | 252 | for idx in range(len(cond_pred)): 253 | if not flag: 254 | break 255 | gt_idx = tuple(x[0] for x in cond_gt).index(cond_pred[idx][0]) 256 | if flag and cond_gt[gt_idx][1] != cond_pred[idx][1]: 257 | flag = False 258 | cond_op_err += 1 259 | 260 | for idx in range(len(cond_pred)): 261 | if not flag: 262 | break 263 | gt_idx = tuple(x[0] for x in cond_gt).index(cond_pred[idx][0]) 264 | if flag and unicode(cond_gt[gt_idx][2]).lower() != \ 265 | unicode(cond_pred[idx][2]).lower(): 266 | flag = False 267 | cond_val_err += 1 268 | 269 | if not flag: 270 | cond_err += 1 271 | good = False 272 | 273 | if not good: 274 | tot_err += 1 275 | 276 | return np.array((agg_err, sel_err, cond_err)), tot_err 277 | 278 | 279 | def gen_query(self, score, q, col, raw_q, raw_col, pred_entry, 280 | reinforce=False, verbose=False): 281 | def merge_tokens(tok_list, raw_tok_str): 282 | tok_str = raw_tok_str.lower() 283 | alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$(' 284 | special = {'-LRB-':'(', '-RRB-':')', '-LSB-':'[', '-RSB-':']', 285 | '``':'"', '\'\'':'"', '--':u'\u2013'} 286 | ret = '' 287 | double_quote_appear = 0 288 | for raw_tok in tok_list: 289 | if not raw_tok: 290 | continue 291 | tok = special.get(raw_tok, raw_tok) 292 | if tok == '"': 293 | double_quote_appear = 1 - double_quote_appear 294 | 295 | 
if len(ret) == 0: 296 | pass 297 | elif len(ret) > 0 and ret + ' ' + tok in tok_str: 298 | ret = ret + ' ' 299 | elif len(ret) > 0 and ret + tok in tok_str: 300 | pass 301 | elif tok == '"': 302 | if double_quote_appear: 303 | ret = ret + ' ' 304 | elif tok[0] not in alphabet: 305 | pass 306 | elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) and \ 307 | (ret[-1] != '"' or not double_quote_appear): 308 | ret = ret + ' ' 309 | ret = ret + tok 310 | return ret.strip() 311 | 312 | pred_agg, pred_sel, pred_cond = pred_entry 313 | agg_score, sel_score, cond_score = score 314 | 315 | ret_queries = [] 316 | if pred_agg: 317 | B = len(agg_score) 318 | elif pred_sel: 319 | B = len(sel_score) 320 | elif pred_cond: 321 | B = len(cond_score[0]) if reinforce else len(cond_score) 322 | for b in range(B): 323 | cur_query = {} 324 | if pred_agg: 325 | cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy()) 326 | if pred_sel: 327 | cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy()) 328 | if pred_cond: 329 | cur_query['conds'] = [] 330 | all_toks = self.SQL_TOK + \ 331 | [x for toks in col[b] for x in 332 | toks+[',']] + [''] + q[b] + [''] 333 | cond_toks = [] 334 | if reinforce: 335 | for choices in cond_score[1]: 336 | if choices[b].data.cpu().numpy()[0] < len(all_toks): 337 | cond_val = all_toks[choices[b].data.cpu().numpy()[0]] 338 | else: 339 | cond_val = '' 340 | if cond_val == '': 341 | break 342 | cond_toks.append(cond_val) 343 | else: 344 | for where_score in cond_score[b].data.cpu().numpy(): 345 | cond_tok = np.argmax(where_score) 346 | cond_val = all_toks[cond_tok] 347 | if cond_val == '': 348 | break 349 | cond_toks.append(cond_val) 350 | 351 | if verbose: 352 | print cond_toks 353 | if len(cond_toks) > 0: 354 | cond_toks = cond_toks[1:] 355 | st = 0 356 | while st < len(cond_toks): 357 | cur_cond = [None, None, None] 358 | ed = len(cond_toks) if 'AND' not in cond_toks[st:] \ 359 | else cond_toks[st:].index('AND') + st 360 | if 'EQL' in cond_toks[st:ed]: 361 | op = cond_toks[st:ed].index('EQL') + st 362 | cur_cond[1] = 0 363 | elif 'GT' in cond_toks[st:ed]: 364 | op = cond_toks[st:ed].index('GT') + st 365 | cur_cond[1] = 1 366 | elif 'LT' in cond_toks[st:ed]: 367 | op = cond_toks[st:ed].index('LT') + st 368 | cur_cond[1] = 2 369 | else: 370 | op = st 371 | cur_cond[1] = 0 372 | sel_col = cond_toks[st:op] 373 | to_idx = [x.lower() for x in raw_col[b]] 374 | pred_col = merge_tokens(sel_col, raw_q[b] + ' || ' + \ 375 | ' || '.join(raw_col[b])) 376 | if pred_col in to_idx: 377 | cur_cond[0] = to_idx.index(pred_col) 378 | else: 379 | cur_cond[0] = 0 380 | cur_cond[2] = merge_tokens(cond_toks[op+1:ed], raw_q[b]) 381 | cur_query['conds'].append(cur_cond) 382 | st = ed + 1 383 | ret_queries.append(cur_query) 384 | 385 | return ret_queries 386 | -------------------------------------------------------------------------------- /sqlnet/model/sqlnet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from modules.word_embedding import WordEmbedding 8 | from modules.aggregator_predict import AggPredictor 9 | from modules.selection_predict import SelPredictor 10 | from modules.sqlnet_condition_predict import SQLNetCondPredictor 11 | 12 | 13 | class SQLNet(nn.Module): 14 | def __init__(self, word_emb, N_word, N_h=100, N_depth=2, 15 | gpu=False, use_ca=True, trainable_emb=False): 16 | 
super(SQLNet, self).__init__() 17 | self.use_ca = use_ca 18 | self.trainable_emb = trainable_emb 19 | 20 | self.gpu = gpu 21 | self.N_h = N_h 22 | self.N_depth = N_depth 23 | 24 | self.max_col_num = 45 25 | self.max_tok_num = 200 26 | self.SQL_TOK = ['', '', 'WHERE', 'AND', 27 | 'EQL', 'GT', 'LT', ''] 28 | self.COND_OPS = ['EQL', 'GT', 'LT'] 29 | 30 | #Word embedding 31 | if trainable_emb: 32 | self.agg_embed_layer = WordEmbedding(word_emb, N_word, gpu, 33 | self.SQL_TOK, our_model=True, trainable=trainable_emb) 34 | self.sel_embed_layer = WordEmbedding(word_emb, N_word, gpu, 35 | self.SQL_TOK, our_model=True, trainable=trainable_emb) 36 | self.cond_embed_layer = WordEmbedding(word_emb, N_word, gpu, 37 | self.SQL_TOK, our_model=True, trainable=trainable_emb) 38 | else: 39 | self.embed_layer = WordEmbedding(word_emb, N_word, gpu, 40 | self.SQL_TOK, our_model=True, trainable=trainable_emb) 41 | 42 | #Predict aggregator 43 | self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca) 44 | 45 | #Predict selected column 46 | self.sel_pred = SelPredictor(N_word, N_h, N_depth, 47 | self.max_tok_num, use_ca=use_ca) 48 | 49 | #Predict number of cond 50 | self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, 51 | self.max_col_num, self.max_tok_num, use_ca, gpu) 52 | 53 | 54 | self.CE = nn.CrossEntropyLoss() 55 | self.softmax = nn.Softmax() 56 | self.log_softmax = nn.LogSoftmax() 57 | self.bce_logit = nn.BCEWithLogitsLoss() 58 | if gpu: 59 | self.cuda() 60 | 61 | 62 | def generate_gt_where_seq(self, q, col, query): 63 | ret_seq = [] 64 | for cur_q, cur_col, cur_query in zip(q, col, query): 65 | cur_values = [] 66 | st = cur_query.index(u'WHERE')+1 if \ 67 | u'WHERE' in cur_query else len(cur_query) 68 | all_toks = [''] + cur_q + [''] 69 | while st < len(cur_query): 70 | ed = len(cur_query) if 'AND' not in cur_query[st:]\ 71 | else cur_query[st:].index('AND') + st 72 | if 'EQL' in cur_query[st:ed]: 73 | op = cur_query[st:ed].index('EQL') + st 74 | elif 'GT' in cur_query[st:ed]: 75 | op = cur_query[st:ed].index('GT') + st 76 | elif 'LT' in cur_query[st:ed]: 77 | op = cur_query[st:ed].index('LT') + st 78 | else: 79 | raise RuntimeError("No operator in it!") 80 | this_str = [''] + cur_query[op+1:ed] + [''] 81 | cur_seq = [all_toks.index(s) if s in all_toks \ 82 | else 0 for s in this_str] 83 | cur_values.append(cur_seq) 84 | st = ed+1 85 | ret_seq.append(cur_values) 86 | return ret_seq 87 | 88 | 89 | def forward(self, q, col, col_num, pred_entry, 90 | gt_where = None, gt_cond=None, reinforce=False, gt_sel=None): 91 | B = len(q) 92 | pred_agg, pred_sel, pred_cond = pred_entry 93 | 94 | agg_score = None 95 | sel_score = None 96 | cond_score = None 97 | 98 | #Predict aggregator 99 | if self.trainable_emb: 100 | if pred_agg: 101 | x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col) 102 | col_inp_var, col_name_len, col_len = \ 103 | self.agg_embed_layer.gen_col_batch(col) 104 | max_x_len = max(x_len) 105 | agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, 106 | col_name_len, col_len, col_num, gt_sel=gt_sel) 107 | 108 | if pred_sel: 109 | x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col) 110 | col_inp_var, col_name_len, col_len = \ 111 | self.sel_embed_layer.gen_col_batch(col) 112 | max_x_len = max(x_len) 113 | sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, 114 | col_name_len, col_len, col_num) 115 | 116 | if pred_cond: 117 | x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col) 118 | col_inp_var, col_name_len, col_len = \ 119 | 
self.cond_embed_layer.gen_col_batch(col) 120 | max_x_len = max(x_len) 121 | cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, 122 | col_name_len, col_len, col_num, 123 | gt_where, gt_cond, reinforce=reinforce) 124 | else: 125 | x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col) 126 | col_inp_var, col_name_len, col_len = \ 127 | self.embed_layer.gen_col_batch(col) 128 | max_x_len = max(x_len) 129 | if pred_agg: 130 | agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, 131 | col_name_len, col_len, col_num, gt_sel=gt_sel) 132 | 133 | if pred_sel: 134 | sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, 135 | col_name_len, col_len, col_num) 136 | 137 | if pred_cond: 138 | cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, 139 | col_name_len, col_len, col_num, 140 | gt_where, gt_cond, reinforce=reinforce) 141 | 142 | return (agg_score, sel_score, cond_score) 143 | 144 | def loss(self, score, truth_num, pred_entry, gt_where): 145 | pred_agg, pred_sel, pred_cond = pred_entry 146 | agg_score, sel_score, cond_score = score 147 | 148 | loss = 0 149 | if pred_agg: 150 | agg_truth = map(lambda x:x[0], truth_num) 151 | data = torch.from_numpy(np.array(agg_truth)) 152 | if self.gpu: 153 | agg_truth_var = Variable(data.cuda()) 154 | else: 155 | agg_truth_var = Variable(data) 156 | 157 | loss += self.CE(agg_score, agg_truth_var) 158 | 159 | if pred_sel: 160 | sel_truth = map(lambda x:x[1], truth_num) 161 | data = torch.from_numpy(np.array(sel_truth)) 162 | if self.gpu: 163 | sel_truth_var = Variable(data.cuda()) 164 | else: 165 | sel_truth_var = Variable(data) 166 | 167 | loss += self.CE(sel_score, sel_truth_var) 168 | 169 | if pred_cond: 170 | B = len(truth_num) 171 | cond_num_score, cond_col_score,\ 172 | cond_op_score, cond_str_score = cond_score 173 | #Evaluate the number of conditions 174 | cond_num_truth = map(lambda x:x[2], truth_num) 175 | data = torch.from_numpy(np.array(cond_num_truth)) 176 | if self.gpu: 177 | cond_num_truth_var = Variable(data.cuda()) 178 | else: 179 | cond_num_truth_var = Variable(data) 180 | loss += self.CE(cond_num_score, cond_num_truth_var) 181 | 182 | #Evaluate the columns of conditions 183 | T = len(cond_col_score[0]) 184 | truth_prob = np.zeros((B, T), dtype=np.float32) 185 | for b in range(B): 186 | if len(truth_num[b][3]) > 0: 187 | truth_prob[b][list(truth_num[b][3])] = 1 188 | data = torch.from_numpy(truth_prob) 189 | if self.gpu: 190 | cond_col_truth_var = Variable(data.cuda()) 191 | else: 192 | cond_col_truth_var = Variable(data) 193 | 194 | sigm = nn.Sigmoid() 195 | cond_col_prob = sigm(cond_col_score) 196 | bce_loss = -torch.mean( 3*(cond_col_truth_var * \ 197 | torch.log(cond_col_prob+1e-10)) + \ 198 | (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) ) 199 | loss += bce_loss 200 | 201 | #Evaluate the operator of conditions 202 | for b in range(len(truth_num)): 203 | if len(truth_num[b][4]) == 0: 204 | continue 205 | data = torch.from_numpy(np.array(truth_num[b][4])) 206 | if self.gpu: 207 | cond_op_truth_var = Variable(data.cuda()) 208 | else: 209 | cond_op_truth_var = Variable(data) 210 | cond_op_pred = cond_op_score[b, :len(truth_num[b][4])] 211 | loss += (self.CE(cond_op_pred, cond_op_truth_var) \ 212 | / len(truth_num)) 213 | 214 | #Evaluate the strings of conditions 215 | for b in range(len(gt_where)): 216 | for idx in range(len(gt_where[b])): 217 | cond_str_truth = gt_where[b][idx] 218 | if len(cond_str_truth) == 1: 219 | continue 220 | data = torch.from_numpy(np.array(cond_str_truth[1:])) 221 | if self.gpu: 222 | 
cond_str_truth_var = Variable(data.cuda()) 223 | else: 224 | cond_str_truth_var = Variable(data) 225 | str_end = len(cond_str_truth)-1 226 | cond_str_pred = cond_str_score[b, idx, :str_end] 227 | loss += (self.CE(cond_str_pred, cond_str_truth_var) \ 228 | / (len(gt_where) * len(gt_where[b]))) 229 | 230 | return loss 231 | 232 | def check_acc(self, vis_info, pred_queries, gt_queries, pred_entry): 233 | def pretty_print(vis_data): 234 | print 'question:', vis_data[0] 235 | print 'headers: (%s)'%(' || '.join(vis_data[1])) 236 | print 'query:', vis_data[2] 237 | 238 | def gen_cond_str(conds, header): 239 | if len(conds) == 0: 240 | return 'None' 241 | cond_str = [] 242 | for cond in conds: 243 | cond_str.append(header[cond[0]] + ' ' + 244 | self.COND_OPS[cond[1]] + ' ' + unicode(cond[2]).lower()) 245 | return 'WHERE ' + ' AND '.join(cond_str) 246 | 247 | pred_agg, pred_sel, pred_cond = pred_entry 248 | 249 | B = len(gt_queries) 250 | 251 | tot_err = agg_err = sel_err = cond_err = 0.0 252 | cond_num_err = cond_col_err = cond_op_err = cond_val_err = 0.0 253 | agg_ops = ['None', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 254 | for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)): 255 | good = True 256 | if pred_agg: 257 | agg_pred = pred_qry['agg'] 258 | agg_gt = gt_qry['agg'] 259 | if agg_pred != agg_gt: 260 | agg_err += 1 261 | good = False 262 | 263 | if pred_sel: 264 | sel_pred = pred_qry['sel'] 265 | sel_gt = gt_qry['sel'] 266 | if sel_pred != sel_gt: 267 | sel_err += 1 268 | good = False 269 | 270 | if pred_cond: 271 | cond_pred = pred_qry['conds'] 272 | cond_gt = gt_qry['conds'] 273 | flag = True 274 | if len(cond_pred) != len(cond_gt): 275 | flag = False 276 | cond_num_err += 1 277 | 278 | if flag and set(x[0] for x in cond_pred) != \ 279 | set(x[0] for x in cond_gt): 280 | flag = False 281 | cond_col_err += 1 282 | 283 | for idx in range(len(cond_pred)): 284 | if not flag: 285 | break 286 | gt_idx = tuple( 287 | x[0] for x in cond_gt).index(cond_pred[idx][0]) 288 | if flag and cond_gt[gt_idx][1] != cond_pred[idx][1]: 289 | flag = False 290 | cond_op_err += 1 291 | 292 | for idx in range(len(cond_pred)): 293 | if not flag: 294 | break 295 | gt_idx = tuple( 296 | x[0] for x in cond_gt).index(cond_pred[idx][0]) 297 | if flag and unicode(cond_gt[gt_idx][2]).lower() != \ 298 | unicode(cond_pred[idx][2]).lower(): 299 | flag = False 300 | cond_val_err += 1 301 | 302 | if not flag: 303 | cond_err += 1 304 | good = False 305 | 306 | if not good: 307 | tot_err += 1 308 | 309 | return np.array((agg_err, sel_err, cond_err)), tot_err 310 | 311 | 312 | def gen_query(self, score, q, col, raw_q, raw_col, 313 | pred_entry, reinforce=False, verbose=False): 314 | def merge_tokens(tok_list, raw_tok_str): 315 | tok_str = raw_tok_str.lower() 316 | alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$(' 317 | special = {'-LRB-':'(', 318 | '-RRB-':')', 319 | '-LSB-':'[', 320 | '-RSB-':']', 321 | '``':'"', 322 | '\'\'':'"', 323 | '--':u'\u2013'} 324 | ret = '' 325 | double_quote_appear = 0 326 | for raw_tok in tok_list: 327 | if not raw_tok: 328 | continue 329 | tok = special.get(raw_tok, raw_tok) 330 | if tok == '"': 331 | double_quote_appear = 1 - double_quote_appear 332 | 333 | if len(ret) == 0: 334 | pass 335 | elif len(ret) > 0 and ret + ' ' + tok in tok_str: 336 | ret = ret + ' ' 337 | elif len(ret) > 0 and ret + tok in tok_str: 338 | pass 339 | elif tok == '"': 340 | if double_quote_appear: 341 | ret = ret + ' ' 342 | elif tok[0] not in alphabet: 343 | pass 344 | elif (ret[-1] not in ['(', '/', 
                    u'\u2013', '#', '$', '&']) \
345 |                     and (ret[-1] != '"' or not double_quote_appear):
346 |                     ret = ret + ' '
347 |                 ret = ret + tok
348 |             return ret.strip()
349 | 
350 |         pred_agg, pred_sel, pred_cond = pred_entry
351 |         agg_score, sel_score, cond_score = score
352 | 
353 |         ret_queries = []
354 |         if pred_agg:
355 |             B = len(agg_score)
356 |         elif pred_sel:
357 |             B = len(sel_score)
358 |         elif pred_cond:
359 |             B = len(cond_score[0])
360 |         for b in range(B):
361 |             cur_query = {}
362 |             if pred_agg:
363 |                 cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
364 |             if pred_sel:
365 |                 cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
366 |             if pred_cond:
367 |                 cur_query['conds'] = []
368 |                 cond_num_score,cond_col_score,cond_op_score,cond_str_score =\
369 |                     [x.data.cpu().numpy() for x in cond_score]
370 |                 cond_num = np.argmax(cond_num_score[b])
371 |                 all_toks = ['<BEG>'] + q[b] + ['<END>']
372 |                 max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
373 |                 for idx in range(cond_num):
374 |                     cur_cond = []
375 |                     cur_cond.append(max_idxes[idx])
376 |                     cur_cond.append(np.argmax(cond_op_score[b][idx]))
377 |                     cur_cond_str_toks = []
378 |                     for str_score in cond_str_score[b][idx]:
379 |                         str_tok = np.argmax(str_score[:len(all_toks)])
380 |                         str_val = all_toks[str_tok]
381 |                         if str_val == '<END>':
382 |                             break
383 |                         cur_cond_str_toks.append(str_val)
384 |                     cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
385 |                     cur_query['conds'].append(cur_cond)
386 |             ret_queries.append(cur_query)
387 | 
388 |         return ret_queries
389 | 
--------------------------------------------------------------------------------
/sqlnet/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from lib.dbengine import DBEngine
3 | import re
4 | import numpy as np
5 | #from nltk.tokenize import StanfordTokenizer
6 | 
7 | def load_data(sql_paths, table_paths, use_small=False):
8 |     if not isinstance(sql_paths, list):
9 |         sql_paths = (sql_paths, )
10 |     if not isinstance(table_paths, list):
11 |         table_paths = (table_paths, )
12 |     sql_data = []
13 |     table_data = {}
14 | 
15 |     max_col_num = 0
16 |     for SQL_PATH in sql_paths:
17 |         print "Loading data from %s"%SQL_PATH
18 |         with open(SQL_PATH) as inf:
19 |             for idx, line in enumerate(inf):
20 |                 if use_small and idx >= 1000:
21 |                     break
22 |                 sql = json.loads(line.strip())
23 |                 sql_data.append(sql)
24 | 
25 |     for TABLE_PATH in table_paths:
26 |         print "Loading data from %s"%TABLE_PATH
27 |         with open(TABLE_PATH) as inf:
28 |             for line in inf:
29 |                 tab = json.loads(line.strip())
30 |                 table_data[tab[u'id']] = tab
31 | 
32 |     for sql in sql_data:
33 |         assert sql[u'table_id'] in table_data
34 | 
35 |     return sql_data, table_data
36 | 
37 | def load_dataset(dataset_id, use_small=False):
38 |     if dataset_id == 0:
39 |         print "Loading from original dataset"
40 |         sql_data, table_data = load_data('data/train_tok.jsonl',
41 |                 'data/train_tok.tables.jsonl', use_small=use_small)
42 |         val_sql_data, val_table_data = load_data('data/dev_tok.jsonl',
43 |                 'data/dev_tok.tables.jsonl', use_small=use_small)
44 |         test_sql_data, test_table_data = load_data('data/test_tok.jsonl',
45 |                 'data/test_tok.tables.jsonl', use_small=use_small)
46 |         TRAIN_DB = 'data/train.db'
47 |         DEV_DB = 'data/dev.db'
48 |         TEST_DB = 'data/test.db'
49 |     else:
50 |         print "Loading from re-split dataset"
51 |         sql_data, table_data = load_data('data_resplit/train.jsonl',
52 |                 'data_resplit/tables.jsonl', use_small=use_small)
53 |         val_sql_data, val_table_data = load_data('data_resplit/dev.jsonl',
54 | 
'data_resplit/tables.jsonl', use_small=use_small) 55 | test_sql_data, test_table_data = load_data('data_resplit/test.jsonl', 56 | 'data_resplit/tables.jsonl', use_small=use_small) 57 | TRAIN_DB = 'data_resplit/table.db' 58 | DEV_DB = 'data_resplit/table.db' 59 | TEST_DB = 'data_resplit/table.db' 60 | 61 | return sql_data, table_data, val_sql_data, val_table_data,\ 62 | test_sql_data, test_table_data, TRAIN_DB, DEV_DB, TEST_DB 63 | 64 | def best_model_name(args, for_load=False): 65 | new_data = 'new' if args.dataset > 0 else 'old' 66 | mode = 'seq2sql' if args.baseline else 'sqlnet' 67 | if for_load: 68 | use_emb = use_rl = '' 69 | else: 70 | use_emb = '_train_emb' if args.train_emb else '' 71 | use_rl = 'rl_' if args.rl else '' 72 | use_ca = '_ca' if args.ca else '' 73 | 74 | agg_model_name = 'saved_model/%s_%s%s%s.agg_model'%(new_data, 75 | mode, use_emb, use_ca) 76 | sel_model_name = 'saved_model/%s_%s%s%s.sel_model'%(new_data, 77 | mode, use_emb, use_ca) 78 | cond_model_name = 'saved_model/%s_%s%s%s.cond_%smodel'%(new_data, 79 | mode, use_emb, use_ca, use_rl) 80 | 81 | if not for_load and args.train_emb: 82 | agg_embed_name = 'saved_model/%s_%s%s%s.agg_embed'%(new_data, 83 | mode, use_emb, use_ca) 84 | sel_embed_name = 'saved_model/%s_%s%s%s.sel_embed'%(new_data, 85 | mode, use_emb, use_ca) 86 | cond_embed_name = 'saved_model/%s_%s%s%s.cond_embed'%(new_data, 87 | mode, use_emb, use_ca) 88 | 89 | return agg_model_name, sel_model_name, cond_model_name,\ 90 | agg_embed_name, sel_embed_name, cond_embed_name 91 | else: 92 | return agg_model_name, sel_model_name, cond_model_name 93 | 94 | 95 | def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False): 96 | q_seq = [] 97 | col_seq = [] 98 | col_num = [] 99 | ans_seq = [] 100 | query_seq = [] 101 | gt_cond_seq = [] 102 | vis_seq = [] 103 | for i in range(st, ed): 104 | sql = sql_data[idxes[i]] 105 | q_seq.append(sql['question_tok']) 106 | col_seq.append(table_data[sql['table_id']]['header_tok']) 107 | col_num.append(len(table_data[sql['table_id']]['header'])) 108 | ans_seq.append((sql['sql']['agg'], 109 | sql['sql']['sel'], 110 | len(sql['sql']['conds']), 111 | tuple(x[0] for x in sql['sql']['conds']), 112 | tuple(x[1] for x in sql['sql']['conds']))) 113 | query_seq.append(sql['query_tok']) 114 | gt_cond_seq.append(sql['sql']['conds']) 115 | vis_seq.append((sql['question'], 116 | table_data[sql['table_id']]['header'], sql['query'])) 117 | if ret_vis_data: 118 | return q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, vis_seq 119 | else: 120 | return q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq 121 | 122 | def to_batch_query(sql_data, idxes, st, ed): 123 | query_gt = [] 124 | table_ids = [] 125 | for i in range(st, ed): 126 | query_gt.append(sql_data[idxes[i]]['sql']) 127 | table_ids.append(sql_data[idxes[i]]['table_id']) 128 | return query_gt, table_ids 129 | 130 | def epoch_train(model, optimizer, batch_size, sql_data, table_data, pred_entry): 131 | model.train() 132 | perm=np.random.permutation(len(sql_data)) 133 | cum_loss = 0.0 134 | st = 0 135 | while st < len(sql_data): 136 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 137 | 138 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq = \ 139 | to_batch_seq(sql_data, table_data, perm, st, ed) 140 | gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq) 141 | gt_sel_seq = [x[1] for x in ans_seq] 142 | score = model.forward(q_seq, col_seq, col_num, pred_entry, 143 | gt_where=gt_where_seq, gt_cond=gt_cond_seq, 
gt_sel=gt_sel_seq) 144 | loss = model.loss(score, ans_seq, pred_entry, gt_where_seq) 145 | cum_loss += loss.data.cpu().numpy()[0]*(ed - st) 146 | optimizer.zero_grad() 147 | loss.backward() 148 | optimizer.step() 149 | 150 | st = ed 151 | 152 | return cum_loss / len(sql_data) 153 | 154 | def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path): 155 | engine = DBEngine(db_path) 156 | 157 | model.eval() 158 | perm = list(range(len(sql_data))) 159 | tot_acc_num = 0.0 160 | acc_of_log = 0.0 161 | st = 0 162 | while st < len(sql_data): 163 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 164 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \ 165 | to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True) 166 | raw_q_seq = [x[0] for x in raw_data] 167 | raw_col_seq = [x[1] for x in raw_data] 168 | gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq) 169 | query_gt, table_ids = to_batch_query(sql_data, perm, st, ed) 170 | gt_sel_seq = [x[1] for x in ans_seq] 171 | score = model.forward(q_seq, col_seq, col_num, 172 | (True, True, True), gt_sel=gt_sel_seq) 173 | pred_queries = model.gen_query(score, q_seq, col_seq, 174 | raw_q_seq, raw_col_seq, (True, True, True)) 175 | 176 | for idx, (sql_gt, sql_pred, tid) in enumerate( 177 | zip(query_gt, pred_queries, table_ids)): 178 | ret_gt = engine.execute(tid, 179 | sql_gt['sel'], sql_gt['agg'], sql_gt['conds']) 180 | try: 181 | ret_pred = engine.execute(tid, 182 | sql_pred['sel'], sql_pred['agg'], sql_pred['conds']) 183 | except: 184 | ret_pred = None 185 | tot_acc_num += (ret_gt == ret_pred) 186 | 187 | st = ed 188 | 189 | return tot_acc_num / len(sql_data) 190 | 191 | def epoch_acc(model, batch_size, sql_data, table_data, pred_entry): 192 | model.eval() 193 | perm = list(range(len(sql_data))) 194 | st = 0 195 | one_acc_num = 0.0 196 | tot_acc_num = 0.0 197 | while st < len(sql_data): 198 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 199 | 200 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True) 201 | raw_q_seq = [x[0] for x in raw_data] 202 | raw_col_seq = [x[1] for x in raw_data] 203 | query_gt, table_ids = to_batch_query(sql_data, perm, st, ed) 204 | gt_sel_seq = [x[1] for x in ans_seq] 205 | score = model.forward(q_seq, col_seq, col_num, 206 | pred_entry, gt_sel = gt_sel_seq) 207 | pred_queries = model.gen_query(score, q_seq, col_seq, 208 | raw_q_seq, raw_col_seq, pred_entry) 209 | one_err, tot_err = model.check_acc(raw_data, 210 | pred_queries, query_gt, pred_entry) 211 | 212 | one_acc_num += (ed-st-one_err) 213 | tot_acc_num += (ed-st-tot_err) 214 | 215 | st = ed 216 | return tot_acc_num / len(sql_data), one_acc_num / len(sql_data) 217 | 218 | def epoch_reinforce_train(model, optimizer, batch_size, sql_data, table_data, db_path): 219 | engine = DBEngine(db_path) 220 | 221 | model.train() 222 | perm = np.random.permutation(len(sql_data)) 223 | cum_reward = 0.0 224 | st = 0 225 | while st < len(sql_data): 226 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 227 | 228 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data =\ 229 | to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True) 230 | gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq) 231 | raw_q_seq = [x[0] for x in raw_data] 232 | raw_col_seq = [x[1] for x in raw_data] 233 | query_gt, table_ids = to_batch_query(sql_data, perm, st, ed) 234 | gt_sel_seq = 
[x[1] for x in ans_seq] 235 | score = model.forward(q_seq, col_seq, col_num, (True, True, True), 236 | reinforce=True, gt_sel=gt_sel_seq) 237 | pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq, 238 | raw_col_seq, (True, True, True), reinforce=True) 239 | 240 | query_gt, table_ids = to_batch_query(sql_data, perm, st, ed) 241 | rewards = [] 242 | for idx, (sql_gt, sql_pred, tid) in enumerate( 243 | zip(query_gt, pred_queries, table_ids)): 244 | ret_gt = engine.execute(tid, 245 | sql_gt['sel'], sql_gt['agg'], sql_gt['conds']) 246 | try: 247 | ret_pred = engine.execute(tid, 248 | sql_pred['sel'], sql_pred['agg'], sql_pred['conds']) 249 | except: 250 | ret_pred = None 251 | 252 | if ret_pred is None: 253 | rewards.append(-2) 254 | elif ret_pred != ret_gt: 255 | rewards.append(-1) 256 | else: 257 | rewards.append(1) 258 | 259 | cum_reward += (sum(rewards)) 260 | optimizer.zero_grad() 261 | model.reinforce_backward(score, rewards) 262 | optimizer.step() 263 | 264 | st = ed 265 | 266 | return cum_reward / len(sql_data) 267 | 268 | 269 | def load_word_emb(file_name, load_used=False, use_small=False): 270 | if not load_used: 271 | print ('Loading word embedding from %s'%file_name) 272 | ret = {} 273 | with open(file_name) as inf: 274 | for idx, line in enumerate(inf): 275 | if (use_small and idx >= 5000): 276 | break 277 | info = line.strip().split(' ') 278 | if info[0].lower() not in ret: 279 | ret[info[0]] = np.array(map(lambda x:float(x), info[1:])) 280 | return ret 281 | else: 282 | print ('Load used word embedding') 283 | with open('glove/word2idx.json') as inf: 284 | w2i = json.load(inf) 285 | with open('glove/usedwordemb.npy') as inf: 286 | word_emb_val = np.load(inf) 287 | return w2i, word_emb_val 288 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from sqlnet.utils import * 4 | from sqlnet.model.seq2sql import Seq2SQL 5 | from sqlnet.model.sqlnet import SQLNet 6 | import numpy as np 7 | import datetime 8 | 9 | import argparse 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--toy', action='store_true', 14 | help='If set, use small data; used for fast debugging.') 15 | parser.add_argument('--ca', action='store_true', 16 | help='Use conditional attention.') 17 | parser.add_argument('--dataset', type=int, default=0, 18 | help='0: original dataset, 1: re-split dataset') 19 | parser.add_argument('--rl', action='store_true', 20 | help='Use RL for Seq2SQL.') 21 | parser.add_argument('--baseline', action='store_true', 22 | help='If set, then test Seq2SQL model; default is SQLNet model.') 23 | parser.add_argument('--train_emb', action='store_true', 24 | help='Use trained word embedding for SQLNet.') 25 | args = parser.parse_args() 26 | 27 | N_word=300 28 | B_word=42 29 | if args.toy: 30 | USE_SMALL=True 31 | GPU=True 32 | BATCH_SIZE=15 33 | else: 34 | USE_SMALL=False 35 | GPU=True 36 | BATCH_SIZE=64 37 | TEST_ENTRY=(True, True, True) # (AGG, SEL, COND) 38 | 39 | sql_data, table_data, val_sql_data, val_table_data, \ 40 | test_sql_data, test_table_data, \ 41 | TRAIN_DB, DEV_DB, TEST_DB = load_dataset( 42 | args.dataset, use_small=USE_SMALL) 43 | 44 | word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \ 45 | load_used=True, use_small=USE_SMALL) # load_used can speed up loading 46 | 47 | if args.baseline: 48 | model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU, 
trainable_emb = True) 49 | else: 50 | model = SQLNet(word_emb, N_word=N_word, use_ca=args.ca, gpu=GPU, 51 | trainable_emb = True) 52 | 53 | if args.train_emb: 54 | agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args) 55 | print "Loading from %s"%agg_m 56 | model.agg_pred.load_state_dict(torch.load(agg_m)) 57 | print "Loading from %s"%sel_m 58 | model.sel_pred.load_state_dict(torch.load(sel_m)) 59 | print "Loading from %s"%cond_m 60 | model.cond_pred.load_state_dict(torch.load(cond_m)) 61 | print "Loading from %s"%agg_e 62 | model.agg_embed_layer.load_state_dict(torch.load(agg_e)) 63 | print "Loading from %s"%sel_e 64 | model.sel_embed_layer.load_state_dict(torch.load(sel_e)) 65 | print "Loading from %s"%cond_e 66 | model.cond_embed_layer.load_state_dict(torch.load(cond_e)) 67 | else: 68 | agg_m, sel_m, cond_m = best_model_name(args) 69 | print "Loading from %s"%agg_m 70 | model.agg_pred.load_state_dict(torch.load(agg_m)) 71 | print "Loading from %s"%sel_m 72 | model.sel_pred.load_state_dict(torch.load(sel_m)) 73 | print "Loading from %s"%cond_m 74 | model.cond_pred.load_state_dict(torch.load(cond_m)) 75 | 76 | print "Dev acc_qm: %s;\n breakdown on (agg, sel, where): %s"%epoch_acc( 77 | model, BATCH_SIZE, val_sql_data, val_table_data, TEST_ENTRY) 78 | print "Dev execution acc: %s"%epoch_exec_acc( 79 | model, BATCH_SIZE, val_sql_data, val_table_data, DEV_DB) 80 | print "Test acc_qm: %s;\n breakdown on (agg, sel, where): %s"%epoch_acc( 81 | model, BATCH_SIZE, test_sql_data, test_table_data, TEST_ENTRY) 82 | print "Test execution acc: %s"%epoch_exec_acc( 83 | model, BATCH_SIZE, test_sql_data, test_table_data, TEST_DB) 84 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from sqlnet.utils import * 4 | from sqlnet.model.seq2sql import Seq2SQL 5 | from sqlnet.model.sqlnet import SQLNet 6 | import numpy as np 7 | import datetime 8 | 9 | import argparse 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--toy', action='store_true', 14 | help='If set, use small data; used for fast debugging.') 15 | parser.add_argument('--suffix', type=str, default='', 16 | help='The suffix at the end of saved model name.') 17 | parser.add_argument('--ca', action='store_true', 18 | help='Use conditional attention.') 19 | parser.add_argument('--dataset', type=int, default=0, 20 | help='0: original dataset, 1: re-split dataset') 21 | parser.add_argument('--rl', action='store_true', 22 | help='Use RL for Seq2SQL(requires pretrained model).') 23 | parser.add_argument('--baseline', action='store_true', 24 | help='If set, then train Seq2SQL model; default is SQLNet model.') 25 | parser.add_argument('--train_emb', action='store_true', 26 | help='Train word embedding for SQLNet(requires pretrained model).') 27 | args = parser.parse_args() 28 | 29 | N_word=300 30 | B_word=42 31 | if args.toy: 32 | USE_SMALL=True 33 | GPU=True 34 | BATCH_SIZE=15 35 | else: 36 | USE_SMALL=False 37 | GPU=True 38 | BATCH_SIZE=64 39 | TRAIN_ENTRY=(True, True, True) # (AGG, SEL, COND) 40 | TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY 41 | learning_rate = 1e-4 if args.rl else 1e-3 42 | 43 | sql_data, table_data, val_sql_data, val_table_data, \ 44 | test_sql_data, test_table_data, \ 45 | TRAIN_DB, DEV_DB, TEST_DB = load_dataset( 46 | args.dataset, use_small=USE_SMALL) 47 | 48 | word_emb = 
        load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \
49 |             load_used=args.train_emb, use_small=USE_SMALL)
50 | 
51 |     if args.baseline:
52 |         model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU,
53 |                 trainable_emb = args.train_emb)
54 |         assert not args.train_emb, "Seq2SQL can\'t train embedding."
55 |     else:
56 |         model = SQLNet(word_emb, N_word=N_word, use_ca=args.ca,
57 |                 gpu=GPU, trainable_emb = args.train_emb)
58 |         assert not args.rl, "SQLNet can\'t do reinforcement learning."
59 |     optimizer = torch.optim.Adam(model.parameters(),
60 |             lr=learning_rate, weight_decay = 0)
61 | 
62 |     if args.train_emb:
63 |         agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args)
64 |     else:
65 |         agg_m, sel_m, cond_m = best_model_name(args)
66 | 
67 |     if args.rl or args.train_emb: # Load pretrained model.
68 |         agg_lm, sel_lm, cond_lm = best_model_name(args, for_load=True)
69 |         print "Loading from %s"%agg_lm
70 |         model.agg_pred.load_state_dict(torch.load(agg_lm))
71 |         print "Loading from %s"%sel_lm
72 |         model.sel_pred.load_state_dict(torch.load(sel_lm))
73 |         print "Loading from %s"%cond_lm
74 |         model.cond_pred.load_state_dict(torch.load(cond_lm))
75 | 
76 |     if args.rl:
77 |         best_acc = 0.0
78 |         best_idx = -1
79 |         print "Init dev acc_qm: %s\n breakdown on (agg, sel, where): %s"% \
80 |                 epoch_acc(model, BATCH_SIZE, val_sql_data,\
81 |                 val_table_data, TRAIN_ENTRY)
82 |         print "Init dev acc_ex: %s"%epoch_exec_acc(
83 |                 model, BATCH_SIZE, val_sql_data, val_table_data, DEV_DB)
84 |         torch.save(model.cond_pred.state_dict(), cond_m)
85 |         for i in range(100):
86 |             print 'Epoch %d @ %s'%(i+1, datetime.datetime.now())
87 |             print ' Avg reward = %s'%epoch_reinforce_train(
88 |                     model, optimizer, BATCH_SIZE, sql_data, table_data, TRAIN_DB)
89 |             print ' dev acc_qm: %s\n breakdown result: %s'% epoch_acc(
90 |                     model, BATCH_SIZE, val_sql_data, val_table_data, TRAIN_ENTRY)
91 |             exec_acc = epoch_exec_acc(
92 |                     model, BATCH_SIZE, val_sql_data, val_table_data, DEV_DB)
93 |             print ' dev acc_ex: %s'%exec_acc
94 |             if exec_acc > best_acc:
95 |                 best_acc = exec_acc
96 |                 best_idx = i+1
97 |                 torch.save(model.cond_pred.state_dict(),
98 |                         'saved_model/epoch%d.cond_model%s'%(i+1, args.suffix))
99 |                 torch.save(model.cond_pred.state_dict(), cond_m)
100 |             print ' Best exec acc = %s, on epoch %s'%(best_acc, best_idx)
101 |     else:
102 |         init_acc = epoch_acc(model, BATCH_SIZE,
103 |                 val_sql_data, val_table_data, TRAIN_ENTRY)
104 |         best_agg_acc = init_acc[1][0]
105 |         best_agg_idx = 0
106 |         best_sel_acc = init_acc[1][1]
107 |         best_sel_idx = 0
108 |         best_cond_acc = init_acc[1][2]
109 |         best_cond_idx = 0
110 |         print 'Init dev acc_qm: %s\n breakdown on (agg, sel, where): %s'%\
111 |                 init_acc
112 |         if TRAIN_AGG:
113 |             torch.save(model.agg_pred.state_dict(), agg_m)
114 |             if args.train_emb:
115 |                 torch.save(model.agg_embed_layer.state_dict(), agg_e)
116 |         if TRAIN_SEL:
117 |             torch.save(model.sel_pred.state_dict(), sel_m)
118 |             if args.train_emb:
119 |                 torch.save(model.sel_embed_layer.state_dict(), sel_e)
120 |         if TRAIN_COND:
121 |             torch.save(model.cond_pred.state_dict(), cond_m)
122 |             if args.train_emb:
123 |                 torch.save(model.cond_embed_layer.state_dict(), cond_e)
124 |         for i in range(100):
125 |             print 'Epoch %d @ %s'%(i+1, datetime.datetime.now())
126 |             print ' Loss = %s'%epoch_train(
127 |                     model, optimizer, BATCH_SIZE,
128 |                     sql_data, table_data, TRAIN_ENTRY)
129 |             print ' Train acc_qm: %s\n breakdown result: %s'%epoch_acc(
130 |                     model, BATCH_SIZE, sql_data, table_data, TRAIN_ENTRY)
131 |             #val_acc = epoch_token_acc(model, BATCH_SIZE,
val_sql_data, val_table_data, TRAIN_ENTRY) 132 | val_acc = epoch_acc(model, 133 | BATCH_SIZE, val_sql_data, val_table_data, TRAIN_ENTRY) 134 | print ' Dev acc_qm: %s\n breakdown result: %s'%val_acc 135 | if TRAIN_AGG: 136 | if val_acc[1][0] > best_agg_acc: 137 | best_agg_acc = val_acc[1][0] 138 | best_agg_idx = i+1 139 | torch.save(model.agg_pred.state_dict(), 140 | 'saved_model/epoch%d.agg_model%s'%(i+1, args.suffix)) 141 | torch.save(model.agg_pred.state_dict(), agg_m) 142 | if args.train_emb: 143 | torch.save(model.agg_embed_layer.state_dict(), 144 | 'saved_model/epoch%d.agg_embed%s'%(i+1, args.suffix)) 145 | torch.save(model.agg_embed_layer.state_dict(), agg_e) 146 | if TRAIN_SEL: 147 | if val_acc[1][1] > best_sel_acc: 148 | best_sel_acc = val_acc[1][1] 149 | best_sel_idx = i+1 150 | torch.save(model.sel_pred.state_dict(), 151 | 'saved_model/epoch%d.sel_model%s'%(i+1, args.suffix)) 152 | torch.save(model.sel_pred.state_dict(), sel_m) 153 | if args.train_emb: 154 | torch.save(model.sel_embed_layer.state_dict(), 155 | 'saved_model/epoch%d.sel_embed%s'%(i+1, args.suffix)) 156 | torch.save(model.sel_embed_layer.state_dict(), sel_e) 157 | if TRAIN_COND: 158 | if val_acc[1][2] > best_cond_acc: 159 | best_cond_acc = val_acc[1][2] 160 | best_cond_idx = i+1 161 | torch.save(model.cond_pred.state_dict(), 162 | 'saved_model/epoch%d.cond_model%s'%(i+1, args.suffix)) 163 | torch.save(model.cond_pred.state_dict(), cond_m) 164 | if args.train_emb: 165 | torch.save(model.cond_embed_layer.state_dict(), 166 | 'saved_model/epoch%d.cond_embed%s'%(i+1, args.suffix)) 167 | torch.save(model.cond_embed_layer.state_dict(), cond_e) 168 | print ' Best val acc = %s, on epoch %s individually'%( 169 | (best_agg_acc, best_sel_acc, best_cond_acc), 170 | (best_agg_idx, best_sel_idx, best_cond_idx)) 171 | --------------------------------------------------------------------------------
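As a quick orientation for how the utilities above fit together, the following is a minimal sketch (not a file in the repository) of scoring an already-trained SQLNet checkpoint on the dev split without going through `test.py`'s argument parsing. It assumes the original dataset (`--dataset 0`), a model trained with `python train.py --ca` (column attention, no trainable embedding), and the checkpoint paths that `best_model_name` generates for that configuration; the file name `eval_sketch.py` and the hard-coded paths are illustrative, so adjust them if your flags differ.

```python
# eval_sketch.py -- illustrative only, not part of the repository.
# Assumes data.tar.bz2 is extracted, extract_vocab.py has been run,
# and a model was trained with `python train.py --ca`.
import torch
from sqlnet.utils import load_dataset, load_word_emb, epoch_acc, epoch_exec_acc
from sqlnet.model.sqlnet import SQLNet

N_word, B_word = 300, 42

# Same splits and databases that train.py and test.py load for dataset 0.
sql_data, table_data, val_sql_data, val_table_data, \
    test_sql_data, test_table_data, TRAIN_DB, DEV_DB, TEST_DB = \
    load_dataset(0, use_small=False)

# load_used=True reads the vocabulary written by extract_vocab.py,
# which is much faster than parsing the full GloVe text file.
word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                         load_used=True, use_small=False)

# Mirrors the model construction in train.py for the non --train_emb case.
model = SQLNet(word_emb, N_word=N_word, use_ca=True, gpu=True,
               trainable_emb=False)

# Checkpoint names follow best_model_name() for dataset 0 with --ca and
# no trainable embedding ('old_sqlnet_ca'); change them for other flags.
model.agg_pred.load_state_dict(torch.load('saved_model/old_sqlnet_ca.agg_model'))
model.sel_pred.load_state_dict(torch.load('saved_model/old_sqlnet_ca.sel_model'))
model.cond_pred.load_state_dict(torch.load('saved_model/old_sqlnet_ca.cond_model'))

# Logical-form accuracy (overall plus agg/sel/where breakdown) and
# execution accuracy on the dev split, as reported by test.py.
print('Dev acc_qm: %s; breakdown (agg, sel, where): %s' %
      epoch_acc(model, 64, val_sql_data, val_table_data, (True, True, True)))
print('Dev execution acc: %s' %
      epoch_exec_acc(model, 64, val_sql_data, val_table_data, DEV_DB))
```

As in `test.py` and `train.py`, `gpu=True` is assumed, so this sketch needs a CUDA-capable PyTorch build; the batch size of 64 matches the non-toy setting used by both scripts.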