# ├── corenlp_local.py ├── model_save_and_infer.py ├── LICENSE ├── README.md ├── load_model.py ├── dbengine_sqlnet.py ├── load_data.py ├── seq2sql_model_testing.py ├── infer_functions.py ├── roberta_training.py ├── seq2sql_model_internal_functions.py ├── seq2sql_model_training_functions.py ├── dev_function.py └── seq2sql_model_classes.py
# /corenlp_local.py:
# --------------------------------------------------------------------------------


def get_g_wvi_bert_from_g_wvi_corenlp(wh_to_wp_index, g_wvi_corenlp):
    """
    Generate SQuAD-style start and end indices of where-values in the NLU.
    The indices refer to positions after WordPiece tokenization.
    Assumption: where_str is always present in the NLU.

    wh_to_wp_index[b][i] maps first-level (CoreNLP) token i of example b to a
    sub-token index; g_wvi_corenlp[b] is a list of [start, end] CoreNLP token
    spans. Returns the same nesting with both ends mapped through
    wh_to_wp_index[b].
    """
    g_wvi = []
    for b, g_wvi_corenlp1 in enumerate(g_wvi_corenlp):
        # Looked up even when the example has no spans, so a too-short
        # wh_to_wp_index still raises IndexError.
        wh_to_wp_index1 = wh_to_wp_index[b]
        g_wvi1 = []
        for st_idx, ed_idx in g_wvi_corenlp1:
            g_wvi1.append([wh_to_wp_index1[st_idx], wh_to_wp_index1[ed_idx]])
        g_wvi.append(g_wvi1)
    return g_wvi


def get_g_wvi_corenlp(t):
    """Return the 'wvi_corenlp' field of every example in ``t``."""
    return [t1['wvi_corenlp'] for t1 in t]


# --------------------------------------------------------------------------------
# /model_save_and_infer.py:
# --------------------------------------------------------------------------------
import os
import json
import numbers
import random as python_random


def save_for_evaluation(path_save, results, dataset_name):
    """
    Write ``results`` as JSON Lines to ``<path_save>/results_<dataset_name>.jsonl``.

    Each element of ``results`` becomes one UTF-8 line (non-ASCII kept as-is);
    integer-like values json cannot handle natively go through
    json_default_type_checker.
    """
    path_save_file = os.path.join(path_save, f'results_{dataset_name}.jsonl')
    with open(path_save_file, 'w', encoding='utf-8') as file:
        for line in results:
            json_string = json.dumps(line, ensure_ascii=False, default=json_default_type_checker)
            # write() the whole record at once (writelines() on a str iterated it char by char).
            file.write(json_string + '\n')


def json_default_type_checker(o):
    """
    ``default=`` hook for ``json.dumps`` that converts integer-like objects to int.

    From https://stackoverflow.com/questions/11942364/typeerror-integer-is-not-json-serializable-when-serializing-json-in-python

    json only calls this hook for objects it cannot serialize itself, so a
    builtin ``int`` never reaches it. The objects that do are non-builtin
    integers such as NumPy integer scalars, which are registered as
    ``numbers.Integral`` but are not ``int`` subclasses.

    Raises:
        TypeError: for any other object, as json expects from a default hook.
    """
    if isinstance(o, numbers.Integral):
        return int(o)
    raise TypeError(f'Object of type {type(o).__name__} is not JSON serializable')


def load_jsonl(path_file, seed=1):
    """
    Load a JSON Lines file into a list of parsed objects.

    Blank lines are skipped. ``seed`` is accepted for signature compatibility
    and is currently unused.
    """
    data = []
    with open(path_file, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data

# --------------------------------------------------------------------------------
# /LICENSE:
# --------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2020 Debaditya Pal
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Agnostic RoBERTa-based Natural Language to SQL Query Generation 2 | This repo contains the code for the NL2SQL paper titled "Data Agnostic RoBERTa-based Natural Language to SQL Query Generation", available at https://arxiv.org/abs/2010.05243 3 | 4 | Our code is based on previous work introduced in the paper "Content Enhanced BERT-based Text-to-SQL Generation", whose code can be found [here](https://github.com/guotong1988/NL2SQL-RULE). Feel free to experiment with the models. 5 | 6 | ## Getting Started 7 | Our work was mainly done in a Google Colab workspace with a GPU runtime accelerator. To reproduce the results, we recommend running the code in a similar workspace, so we have created a notebook, which can be found at: 8 | https://colab.research.google.com/drive/1qYJTbbEXYFVdY6xae9Zmt96hkeW8ZFrn 9 | Below are step-by-step instructions for running the models: 10 | 11 | 1. Visit https://drive.google.com/drive/folders/13f2MrdpieC9QGXM_DJnj2f1Hs6ZBh2ZT?usp=sharing. This Drive folder contains all the datasets and pretrained model weights. **Add a shortcut to it in your own Drive.** This step is important. 12 | 2. Open the notebook linked above. 13 | 3. Make sure the runtime accelerator is set to GPU. 14 | 4. Mount your Google Drive and clone this repository in the runtime environment; the code for this step is already provided in the notebook. 15 | 5. The notebook is now ready for use. 16 | 17 | All package requirements are listed in the notebook. 
#
# ## Results
#
# | Dev Logical Form Acc | Dev Execution Accuracy | Test Logical Form Accuracy | Test Execution Accuracy |
# |:-:|:-:|:-:|:-:|
# | 69.4% | 77.0% |68.9% | 76.7% |
#
# ## References
#
# - https://github.com/salesforce/WikiSQL
# - https://github.com/guotong1988/NL2SQL-RULE
# --------------------------------------------------------------------------------
# /load_model.py:
# --------------------------------------------------------------------------------
import torch
import os
from seq2sql_model_classes import Seq2SQL_v1
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer

# Use the GPU when one is available; otherwise fall back to CPU instead of
# failing later with a CUDA error.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_roberta_model():
    '''
    Load the pretrained "roberta-base" encoder and its tokenizer.

    Returns:
        Roberta_Model: pretrained RobertaModel, moved to ``device``
        tokenizer: the matching RobertaTokenizer
        configuration: the loaded model's configuration (Roberta_Model.config)
    '''
    # from_pretrained is a classmethod, so it is called on the class directly;
    # building RobertaModel(RobertaConfig()) first only allocated a throwaway
    # randomly initialised model.
    Roberta_Model = RobertaModel.from_pretrained("roberta-base")
    Roberta_Model.to(device)

    # Accessing the model configuration
    configuration = Roberta_Model.config

    # get the Roberta Tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    return Roberta_Model, tokenizer, configuration


def get_seq2sql_model(roberta_hidden_layer_size, number_of_layers=2,
                      hidden_vector_dimensions=100,
                      number_lstm_layers=2,
                      dropout_rate=0.3,
                      load_pretrained_model=False, model_path=None):
    '''
    Build the Seq2SQL_v1 head and optionally load saved weights into it.

    Arguments:
        roberta_hidden_layer_size: hidden size of the RoBERTa encoder
        number_of_layers: number of final RoBERTa layers used by the downstream task
        hidden_vector_dimensions: dimension of hidden vectors in the seq-to-SQL module
        number_lstm_layers: number of LSTM layers in the seq-to-SQL module
        dropout_rate: dropout rate
        load_pretrained_model: whether to load weights from model_path
        model_path: path to a checkpoint whose 'model' entry is a state dict

    Returns:
        model: the Seq2SQL_v1 model, moved to ``device``

    Raises:
        ValueError: if load_pretrained_model is True but model_path is None
    '''
    sql_main_operators = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    sql_conditional_operators = ['=', '>', '<', 'OP']

    # Seq-to-SQL input vector dimension: one hidden-size block per RoBERTa
    # layer used (presumably the layers are concatenated — confirm in Seq2SQL_v1).
    number_of_neurons = roberta_hidden_layer_size * number_of_layers

    model = Seq2SQL_v1(number_of_neurons, hidden_vector_dimensions, number_lstm_layers, dropout_rate,
                       len(sql_conditional_operators), len(sql_main_operators))
    model = model.to(device)

    if load_pretrained_model:
        if model_path is None:
            raise ValueError('model_path is required when load_pretrained_model is True')
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        res = torch.load(model_path, map_location=device)
        model.load_state_dict(res['model'])

    return model


def get_optimizers(model, model_roberta, learning_rate_model=1e-3, learning_rate_roberta=1e-5):
    '''
    Create one Adam optimizer for the Seq2SQL head and one for RoBERTa.

    Only parameters with requires_grad=True are optimized; weight decay is 0.

    Arguments:
        model: model returned by get_seq2sql_model
        model_roberta: model returned by get_roberta_model
        learning_rate_model: learning rate for ``model``
        learning_rate_roberta: learning rate for ``model_roberta``

    Returns:
        opt: Adam optimizer over ``model``'s trainable parameters
        opt_roberta: Adam optimizer over ``model_roberta``'s trainable parameters
    '''
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=learning_rate_model, weight_decay=0)

    opt_roberta = torch.optim.Adam(filter(lambda p: p.requires_grad, model_roberta.parameters()),
                                   lr=learning_rate_roberta, weight_decay=0)

    return opt, opt_roberta

# --------------------------------------------------------------------------------
# /dbengine_sqlnet.py:
# --------------------------------------------------------------------------------

import records
import re
from babel.numbers import parse_decimal, NumberFormatError
# From original SQLNet code.
# Wonseok modified. 20180607

# Captures the text between the outer parentheses of a CREATE TABLE statement,
# i.e. the "colN type, colM type, ..." column list.
schema_re = re.compile(r'\((.+)\)')
# Matches a signed decimal such as -34.34 or .4543, or a plain run of digits.
num_re = re.compile(r'[-+]?\d*\.\d+|\d+')

agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']


class DBEngine:
    """Run WikiSQL-style (select, aggregation, conditions) queries on a SQLite file."""

    def __init__(self, fdb):
        # fdb: path to the SQLite database file, e.g. 'data/test.db'.
        self.db = records.Database('sqlite:///{}'.format(fdb)).get_connection()

    def execute_query(self, table_id, query, *args, **kwargs):
        """Execute a query object exposing sel_index, agg_index and conditions."""
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """
        Execute one structured query and return the list of result values.

        conditions is an iterable of (column_index, operator_index, value);
        string values are lower-cased when lower is True.
        """
        query, where_map = self._build_query(table_id, select_index, aggregation_index, conditions, lower)
        out = self.db.query(query, **where_map)
        return [o.result for o in out]

    def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Like execute(), but also return the SQL text (with :colN placeholders)."""
        query, where_map = self._build_query(table_id, select_index, aggregation_index, conditions, lower)
        out = self.db.query(query, **where_map)
        return [o.result for o in out], query

    def show_table(self, table_id):
        """Print every row of the table."""
        table_id = self._normalize_table_id(table_id)
        # Table names cannot be bound as SQL parameters; table_id is assumed
        # to come from the dataset rather than from end users.
        rows = self.db.query('select * from ' + table_id)
        print(rows.dataset)

    @staticmethod
    def _normalize_table_id(table_id):
        """Map a table id like '1-1000181-1' to its SQLite name 'table_1_1000181_1'."""
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        return table_id

    def _get_schema(self, table_id):
        """Return {column_name: sqlite_type} parsed from the table's CREATE statement."""
        table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name',
                                   name=table_id).all()[0].sql.replace('\n', '')
        schema_str = schema_re.findall(table_info)[0]
        schema = {}
        for tup in schema_str.split(', '):
            c, t = tup.split()
            schema[c] = t
        return schema

    @staticmethod
    def _coerce_real(val):
        """
        Best-effort conversion of a condition value for a 'real' column to float.

        Tries en_US locale-aware parsing first, then the first number found in
        the string; if neither works the value is returned unchanged.
        """
        try:
            return float(parse_decimal(val, locale='en_US'))
        except NumberFormatError:
            try:
                return float(num_re.findall(val)[0])
            except (IndexError, TypeError):
                # The column is numeric but the value is not; compare it as-is.
                return val

    def _build_query(self, table_id, select_index, aggregation_index, conditions, lower):
        """Return (sql_text, bind_params) shared by execute() and execute_return_query()."""
        table_id = self._normalize_table_id(table_id)
        schema = self._get_schema(table_id)
        select = 'col{}'.format(select_index)
        agg = agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)
        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            if lower and isinstance(val, str):
                val = val.lower()
            if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
                val = self._coerce_real(val)
            # NOTE(review): placeholders are named after the column, so two
            # conditions on the same column share the last value bound. Kept
            # as in the original SQLNet code, presumably so execution accuracy
            # stays comparable with published results.
            where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index))
            where_map['col{}'.format(col_index)] = val
        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
        return query, where_map

# --------------------------------------------------------------------------------
# /load_data.py:
# --------------------------------------------------------------------------------
import json
import torch
import os
import json
from matplotlib.pylab import *


def _identity_collate(batch):
    """
    Return the batch (a list of example dicts) unchanged, so dict values are not merged.

    A named module-level function instead of ``lambda x: x`` so DataLoader
    worker processes can pickle it under the 'spawn' start method.
    """
    return batch


def _read_jsonl(path):
    """Parse a JSON Lines file into a list of objects, skipping blank lines."""
    rows = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def _read_tables(path):
    """Parse a JSON Lines tables file into a dict keyed by each table's 'id'."""
    return {table['id']: table for table in _read_jsonl(path)}


def _make_loader(dataset, batch_size):
    """Build the shuffling, 4-worker DataLoader used for every split."""
    return torch.utils.data.DataLoader(
        batch_size=batch_size,
        dataset=dataset,
        shuffle=True,
        num_workers=4,
        collate_fn=_identity_collate,
    )


def get_data(file_path: str, batch_size: int):
    '''
    Gets data from the dataset and creates the data loaders

    Arguments:
        file_path: The path to the directory in which the dataset is contained
        batch_size: Batch size to be used for the data loaders

    Returns:
        train_data: Training dataset (Natural Language utterances)
        train_table: Training tables (Table schema and table data)
        dev_data: Development dataset (Natural Language utterances)
        dev_table: Development tables (Table schema and table data)
        train_loader: Training dataset loader
        dev_loader: Development dataset loader
    '''
    # Development split
    dev_data = _read_jsonl(os.path.join(file_path, 'dev_knowledge.jsonl'))
    dev_table = _read_tables(os.path.join(file_path, 'dev.tables.jsonl'))

    # Training split
    train_data = _read_jsonl(os.path.join(file_path, 'train_knowledge.jsonl'))
    train_table = _read_tables(os.path.join(file_path, 'train.tables.jsonl'))

    train_loader = _make_loader(train_data, batch_size)
    dev_loader = _make_loader(dev_data, batch_size)

    return train_data, train_table, dev_data, dev_table, train_loader, dev_loader


def get_test_data(file_path: str, batch_size: int):
    '''Load test_knowledge.jsonl and test.tables.jsonl; return (test_data, test_table, test_loader).'''
    test_data = _read_jsonl(os.path.join(file_path, 'test_knowledge.jsonl'))
    test_table = _read_tables(os.path.join(file_path, 'test.tables.jsonl'))
    test_loader = _make_loader(test_data, batch_size)
    return test_data, test_table, test_loader


def get_zero_data(file_path: str, batch_size: int):
    '''Load zero.jsonl with the test tables; return (test_data, test_table, test_loader).'''
    test_data = _read_jsonl(os.path.join(file_path, 'zero.jsonl'))
    test_table = _read_tables(os.path.join(file_path, 'test.tables.jsonl'))
    test_loader = _make_loader(test_data, batch_size)
    return test_data, test_table, test_loader


def get_fields(data, header_tokenization=False, sql_tokenization=False):
    '''
    Split a batch of example dicts into parallel per-field lists.

    Returns (question, question_tok, sql, query, query_tok, table_info,
    header_tok, header). query_tok / header_tok entries are None unless the
    matching *_tokenization flag is set; table_info entries are dicts holding
    the example's table "id", "header" and "types".
    '''
    natural_language_utterance = []
    tokenized_natural_language_utterance = []
    sql_indexing = []
    sql_query = []
    tokenized_sql_query = []
    table_indices = []
    tokenized_headers = []
    headers = []

    for one_data in data:
        natural_language_utterance.append(one_data['question'])
        tokenized_natural_language_utterance.append(one_data['question_tok'])
        sql_indexing.append(one_data['sql'])
        sql_query.append(one_data['query'])
        headers.append(one_data['header'])
        table_indices.append({
            "id": one_data["table_id"],
            "header": one_data["header"],
            "types": one_data["types"]
        })
        tokenized_sql_query.append(one_data['query_tok'] if sql_tokenization else None)
        tokenized_headers.append(one_data['header_tok'] if header_tokenization else None)

    return (natural_language_utterance, tokenized_natural_language_utterance, sql_indexing, sql_query,
            tokenized_sql_query, table_indices, tokenized_headers, headers)

# --------------------------------------------------------------------------------
# /seq2sql_model_testing.py:
# --------------------------------------------------------------------------------
import torch
#import torch_xla
#import torch_xla.core.xla_model as xm

device = torch.device("cuda")


def generate_sql_q(sql_i, tb):
    """Render each structured query sql_i[b] against table tb[b]; return the list of SQL strings."""
    return [generate_sql_q1(sql_i1, tb[b]) for b, sql_i1 in enumerate(sql_i)]


def generate_sql_q1(sql_i1, tb1):
    """
    Render one structured query as a readable SQL string.

    Example input: sql_i1 = {'sel': 5, 'agg': 4, 'conds': [[3, 0, '59']]}
    agg_ops = ['', 'max', 'min', 'count', 'sum', 'avg']
    cond_ops = ['=', '>', '<', 'OP']
    The table name is tb1["id"]. Condition values are always wrapped in single
    quotes without escaping, so treat the result as display text.
    """
    agg_ops = ['', 'max', 'min', 'count', 'sum', 'avg']
    cond_ops = ['=', '>', '<', 'OP']

    headers = tb1["header"]
    select_table = tb1["id"]

    select_agg = agg_ops[sql_i1['agg']]
    select_header = headers[sql_i1['sel']]
    sql_query_part1 = f'SELECT {select_agg}({select_header}) '

    conds = sql_i1['conds']
    if not conds:
        return sql_query_part1 + f'FROM {select_table}'

    # Text and numeric columns were rendered identically, so no type branch is needed.
    where_parts = [
        f"{headers[where_header_idx]} {cond_ops[where_op_idx]} '{where_str}'"
        for where_header_idx, where_op_idx, where_str in conds
    ]
    return sql_query_part1 + f'FROM {select_table} WHERE ' + ' AND '.join(where_parts)


def report_detail(hds, nlu,
                  g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_str, g_sql_q, g_ans,
                  pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans,
                  cnt_list, current_cnt):
    """
    Print ground-truth vs. predicted query components and running accuracies.

    current_cnt unpacks as (cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc,
    cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x); each accuracy is a count divided
    by cnt, so cnt must be non-zero.
    """
    cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x = current_cnt

    print(f'cnt = {cnt} / {cnt_tot} ===============================')

    print(f'headers: {hds}')
    print(f'nlu: {nlu}')

    print('===============================')
    print(f'g_sc : {g_sc}')
    print(f'pr_sc: {pr_sc}')
    print(f'g_sa : {g_sa}')
    print(f'pr_sa: {pr_sa}')
    print(f'g_wn : {g_wn}')
    print(f'pr_wn: {pr_wn}')
    print(f'g_wc : {g_wc}')
    print(f'pr_wc: {pr_wc}')
    print(f'g_wo : {g_wo}')
    print(f'pr_wo: {pr_wo}')
    print(f'g_wv : {g_wv}')
    print('g_wv_str:', g_wv_str)
    print('p_wv_str:', pr_wv_str)
    print(f'g_sql_q: {g_sql_q}')
    print(f'pr_sql_q: {pr_sql_q}')
    print(f'g_ans: {g_ans}')
    print(f'pr_ans: {pr_ans}')
    print('--------------------------------')

    print(cnt_list)

    print(f'acc_lx = {cnt_lx / cnt:.3f}, acc_x = {cnt_x / cnt:.3f}\n',
          f'acc_sc = {cnt_sc / cnt:.3f}, acc_sa = {cnt_sa / cnt:.3f}, acc_wn = {cnt_wn / cnt:.3f}\n',
          f'acc_wc = {cnt_wc / cnt:.3f}, acc_wo = {cnt_wo / cnt:.3f}, acc_wv = {cnt_wv / cnt:.3f}')
    print('===============================')

# --------------------------------------------------------------------------------
# /infer_functions.py:
-------------------------------------------------------------------------------- 1 | from matplotlib.pylab import * 2 | from roberta_training import * 3 | from seq2sql_model_testing import * 4 | import json 5 | import nltk 6 | from nltk.tokenize import word_tokenize, sent_tokenize 7 | import re 8 | import os 9 | re_ = re.compile(' ') 10 | 11 | def tokenize_corenlp_direct_version(client, nlu1): 12 | nlu1_tok = [] 13 | for sentence in client.annotate(nlu1).sentence: 14 | for tok in sentence.token: 15 | nlu1_tok.append(tok.originalText) 16 | return nlu1_tok 17 | 18 | def sent_split(documents): 19 | words = [] 20 | for sent in sent_tokenize(documents): 21 | for word in word_tokenize(sent): 22 | words.append(word) 23 | return words 24 | 25 | def load_jsonl(path_file, toy_data=False, toy_size=4, shuffle=False, seed=1): 26 | data = [] 27 | 28 | with open(path_file, "r", encoding="utf-8") as f: 29 | for idx, line in enumerate(f): 30 | if toy_data and idx >= toy_size and (not shuffle): 31 | break 32 | t1 = json.loads(line.strip()) 33 | data.append(t1) 34 | 35 | if shuffle and toy_data: 36 | # When shuffle required, get all the data, shuffle, and get the part of data. 
37 | print( 38 | f"If the toy-data is used, the whole data loaded first and then shuffled before get the first {toy_size} data") 39 | 40 | python_random.Random(seed).shuffle(data) # fixed 41 | data = data[:toy_size] 42 | 43 | return data 44 | 45 | def sort_and_generate_pr_w(pr_sql_i): 46 | pr_wc = [] 47 | pr_wo = [] 48 | pr_wv = [] 49 | for b, pr_sql_i1 in enumerate(pr_sql_i): 50 | conds1 = pr_sql_i1["conds"] 51 | pr_wc1 = [] 52 | pr_wo1 = [] 53 | pr_wv1 = [] 54 | 55 | # Generate 56 | for i_wn, conds11 in enumerate(conds1): 57 | pr_wc1.append( conds11[0]) 58 | pr_wo1.append( conds11[1]) 59 | pr_wv1.append( conds11[2]) 60 | 61 | # sort based on pr_wc1 62 | idx = argsort(pr_wc1) 63 | pr_wc1 = array(pr_wc1)[idx].tolist() 64 | pr_wo1 = array(pr_wo1)[idx].tolist() 65 | pr_wv1 = array(pr_wv1)[idx].tolist() 66 | 67 | conds1_sorted = [] 68 | for i, idx1 in enumerate(idx): 69 | conds1_sorted.append( conds1[idx1] ) 70 | 71 | 72 | pr_wc.append(pr_wc1) 73 | pr_wo.append(pr_wo1) 74 | pr_wv.append(pr_wv1) 75 | 76 | pr_sql_i1['conds'] = conds1_sorted 77 | 78 | return pr_wc, pr_wo, pr_wv, pr_sql_i 79 | 80 | def process(data,tokenize): 81 | final_all = [] 82 | badcase = 0 83 | for i, one_data in enumerate(data): 84 | nlu_t1 = one_data["question_tok"] 85 | 86 | # 1. 2nd tokenization using RoBERTa Tokenizer 87 | charindex2wordindex = {} 88 | total = 0 89 | tt_to_t_idx1 = [] # number indicates where sub-token belongs to in 1st-level-tokens (here, CoreNLP). 90 | t_to_tt_idx1 = [] # orig_to_tok_idx[i] = start index of i-th-1st-level-token in all_tokens. 91 | nlu_tt1 = [] # all_doc_tokens[ orig_to_tok_idx[i] ] returns first sub-token segement of i-th-1st-level-token 92 | for (ii, token) in enumerate(nlu_t1): 93 | t_to_tt_idx1.append( 94 | len(nlu_tt1)) # all_doc_tokens[ indicate the start position of original 'white-space' tokens. 
95 | sub_tokens = tokenize.tokenize(token, is_pretokenized=True) 96 | for sub_token in sub_tokens: 97 | tt_to_t_idx1.append(ii) 98 | nlu_tt1.append(sub_token) # all_doc_tokens are further tokenized using RoBERTa tokenizer 99 | 100 | token_ = re_.sub('',token) 101 | for iii in range(len(token_)): 102 | charindex2wordindex[total+iii]=ii 103 | total += len(token_) 104 | 105 | one_final = one_data 106 | # one_table = table[one_data["table_id"]] 107 | final_question = [0] * len(nlu_tt1) 108 | one_final["bertindex_knowledge"] = final_question 109 | final_header = [0] * len(one_data["header"]) 110 | one_final["header_knowledge"] = final_header 111 | for ii,h in enumerate(one_data["header"]): 112 | h = h.lower() 113 | hs = h.split("/") 114 | for h_ in hs: 115 | flag, start_, end_ = contains2(re_.sub('', h_), "".join(one_data["question_tok"]).lower()) 116 | if flag == True: 117 | try: 118 | start = t_to_tt_idx1[charindex2wordindex[start_]] 119 | end = t_to_tt_idx1[charindex2wordindex[end_]] 120 | for iii in range(start,end): 121 | final_question[iii] = 4 122 | final_question[start] = 4 123 | final_question[end] = 4 124 | one_final["bertindex_knowledge"] = final_question 125 | except: 126 | # print("!!!!!") 127 | continue 128 | 129 | for ii,h in enumerate(one_data["header"]): 130 | h = h.lower() 131 | hs = h.split("/") 132 | for h_ in hs: 133 | flag, start_, end_ = contains2(re_.sub('', h_), "".join(one_data["question_tok"]).lower()) 134 | if flag == True: 135 | try: 136 | final_header[ii] = 1 137 | break 138 | except: 139 | # print("!!!!") 140 | continue 141 | 142 | one_final["header_knowledge"] = final_header 143 | 144 | if "bertindex_knowledge" not in one_final and len(one_final["sql"]["conds"])>0: 145 | one_final["bertindex_knowledge"] = [0] * len(nlu_tt1) 146 | badcase+=1 147 | 148 | final_all.append([one_data["question_tok"],one_final["bertindex_knowledge"],one_final["header_knowledge"]]) 149 | return final_all 150 | 151 | 152 | def contains2(small_str,big_str): 153 | 
if small_str in big_str: 154 | start = big_str.index(small_str) 155 | return True,start,start+len(small_str)-1 156 | else: 157 | return False,-1,-1 158 | 159 | def infer(nlu1, 160 | table_id, headers, types, tokenizer, 161 | model, model_roberta, roberta_config, max_seq_length, num_target_layers, 162 | beam_size=4): 163 | 164 | model.eval() 165 | model_roberta.eval() 166 | 167 | # Get inputs 168 | nlu = [nlu1] 169 | 170 | nlu_t1 = sent_split(nlu1) 171 | nlu_t = [nlu_t1] 172 | hds = [headers] 173 | hs_t = [[]] 174 | 175 | data = {} 176 | data['question_tok'] = nlu_t[0] 177 | data['table_id'] = table_id 178 | data['header'] = headers 179 | data = [data] 180 | 181 | tb = {} 182 | tb['id'] = table_id 183 | tb['header'] = headers 184 | tb['types'] = types 185 | tb = [tb] 186 | 187 | tk = tokenizer 188 | 189 | check = process(data, tk) 190 | knowledge = [check[0][1]] 191 | header_knowledge = [check[0][2]] 192 | 193 | wemb_n, wemb_h, l_n, l_hpu, l_hs, \ 194 | nlu_tt, t_to_tt_idx, tt_to_t_idx \ 195 | = get_wemb_roberta(roberta_config, model_roberta, tokenizer, nlu_t, hds, max_seq_length, 196 | num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers) 197 | 198 | prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, 199 | l_hs, tb, 200 | nlu_t, nlu_tt, 201 | tt_to_t_idx, nlu, 202 | beam_size=beam_size, 203 | knowledge=knowledge, 204 | knowledge_header=header_knowledge) 205 | pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i) 206 | 207 | 208 | if len(pr_sql_i) != 1: 209 | raise EnvironmentError 210 | pr_sql_q1 = generate_sql_q(pr_sql_i, tb) 211 | pr_sql_q = [pr_sql_q1] 212 | ''' 213 | print(f'START ============================================================= ') 214 | print(f'{hds}') 215 | print(f'nlu: {nlu}') 216 | print(f'pr_sql_i : {pr_sql_i}') 217 | print(f'pr_sql_q : {pr_sql_q}') 218 | print(f'---------------------------------------------------------------------') 219 | ''' 220 | return 
pr_sql_q1 221 | -------------------------------------------------------------------------------- /roberta_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda") 4 | 5 | def get_where_column(conds): 6 | """ 7 | [ [where_column, where_operator, where_value], 8 | [where_column, where_operator, where_value], ... 9 | ] 10 | """ 11 | where_column = [] 12 | for cond in conds: 13 | where_column.append(cond[0]) 14 | return where_column 15 | 16 | 17 | def get_where_operator(conds): 18 | """ 19 | [ [where_column, where_operator, where_value], 20 | [where_column, where_operator, where_value], ... 21 | ] 22 | """ 23 | where_operator = [] 24 | for cond in conds: 25 | where_operator.append(cond[1]) 26 | return where_operator 27 | 28 | 29 | def get_where_value(conds): 30 | """ 31 | [ [where_column, where_operator, where_value], 32 | [where_column, where_operator, where_value], ... 33 | ] 34 | """ 35 | where_value = [] 36 | for cond in conds: 37 | where_value.append(cond[2]) 38 | return where_value 39 | 40 | 41 | def get_ground_truth_values(canonical_sql_queries): 42 | 43 | ground_select_column = [] 44 | ground_select_aggregate = [] 45 | ground_where_number = [] 46 | ground_where_column = [] 47 | ground_where_operator = [] 48 | ground_where_value = [] 49 | for _, canonical_sql_query in enumerate(canonical_sql_queries): 50 | ground_select_column.append( canonical_sql_query["sel"] ) 51 | ground_select_aggregate.append( canonical_sql_query["agg"]) 52 | 53 | conds = canonical_sql_query['conds'] 54 | if not canonical_sql_query["agg"] < 0: 55 | ground_where_number.append( len( conds ) ) 56 | ground_where_column.append( get_where_column(conds) ) 57 | ground_where_operator.append( get_where_operator(conds) ) 58 | ground_where_value.append( get_where_value(conds) ) 59 | else: 60 | raise EnvironmentError 61 | return ground_select_column, ground_select_aggregate, ground_where_number, 
ground_where_column,\ 62 | ground_where_operator, ground_where_value 63 | 64 | 65 | def get_wemb_roberta(roberta_config, model_roberta, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n=1, num_out_layers_h=1): 66 | ''' 67 | wemb_n : word embedding of natural language question 68 | wemb_h : word embedding of header 69 | l_n : length of natural question 70 | l_hs : length of header 71 | nlu_tt : Natural language double tokenized 72 | t_to_tt_idx : map first level tokenization to second level tokenization 73 | tt_to_t_idx : map second level tokenization to first level tokenization 74 | ''' 75 | # get contextual output of all tokens from RoBERTa 76 | all_encoder_layer, i_nlu, i_headers,\ 77 | l_n, l_hpu, l_hs, \ 78 | nlu_tt, t_to_tt_idx, tt_to_t_idx = get_roberta_output(model_roberta, tokenizer, nlu_t, hds, max_seq_length) 79 | # all_encoder_layer: RoBERTa outputs from all layers. 80 | # i_nlu: start and end indices of question in tokens 81 | # i_headers: start and end indices of headers 82 | 83 | # get the wemb 84 | wemb_n = get_wemb_n(i_nlu, l_n, roberta_config.hidden_size, roberta_config.num_hidden_layers, all_encoder_layer, 85 | num_out_layers_n) 86 | 87 | wemb_h = get_wemb_h(i_headers, l_hpu, l_hs, roberta_config.hidden_size, roberta_config.num_hidden_layers, all_encoder_layer, 88 | num_out_layers_h) 89 | 90 | return wemb_n, wemb_h, l_n, l_hpu, l_hs, \ 91 | nlu_tt, t_to_tt_idx, tt_to_t_idx 92 | 93 | def get_roberta_output(model_roberta, tokenizer, nlu_t, headers, max_seq_length): 94 | """ 95 | Here, input is toknized further by RoBERTa Tokenizer and fed into RoBERTa 96 | INPUT 97 | :param model_roberta: 98 | :param tokenizer: RoBERTa toknizer 99 | :param nlu: Question 100 | :param nlu_t: tokenized natural_language_utterance. 
101 | :param headers: Headers of the table 102 | :param max_seq_length: max input token length 103 | OUTPUT 104 | tokens: RoBERTa input tokens 105 | nlu_tt: RoBERTa-tokenized input natural language questions 106 | orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token 107 | tok_to_orig_index: inverse map. 108 | """ 109 | 110 | l_n = [] 111 | l_hs = [] # The length of columns for each batch 112 | 113 | input_ids = [] 114 | input_mask = [] 115 | 116 | i_nlu = [] # index to retreive the position of contextual vector later. 117 | i_headers = [] 118 | 119 | nlu_tt = [] 120 | 121 | t_to_tt_idx = [] 122 | tt_to_t_idx = [] 123 | for b, nlu_t1 in enumerate(nlu_t): 124 | 125 | batch_headers = headers[b] 126 | l_hs.append(len(batch_headers)) 127 | 128 | 129 | # 1. Tokenization using RoBERTa Tokenizer 130 | tt_to_t_idx1 = [] # number indicates where sub-token belongs to in 1st-level-tokens 131 | t_to_tt_idx1 = [] # orig_to_tok_idx[i] = start index of i-th-1st-level-token in all_tokens. 132 | nlu_tt1 = [] 133 | for (i, token) in enumerate(nlu_t1): 134 | t_to_tt_idx1.append( 135 | len(nlu_tt1)) 136 | sub_tokens = tokenizer.tokenize(token, is_pretokenized=True) 137 | for sub_token in sub_tokens: 138 | tt_to_t_idx1.append(i) 139 | nlu_tt1.append(sub_token) 140 | nlu_tt.append(nlu_tt1) 141 | tt_to_t_idx.append(tt_to_t_idx1) 142 | t_to_tt_idx.append(t_to_tt_idx1) 143 | 144 | l_n.append(len(nlu_tt1)) 145 | 146 | # nlu col1 col2 ...col-n 147 | # 2. Generate RoBERTa inputs & indices. 148 | tokens, i_nlu1, i_batch_headers = generate_inputs(tokenizer, nlu_tt1, batch_headers) 149 | input_ids1 = tokenizer.convert_tokens_to_ids(tokens) 150 | 151 | # Input masks 152 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 153 | # tokens are attended to. 154 | input_mask1 = [1] * len(input_ids1) 155 | 156 | # 3. Zero-pad up to the sequence length. 
157 | while len(input_ids1) < max_seq_length: 158 | input_ids1.append(0) 159 | input_mask1.append(0) 160 | 161 | assert len(input_ids1) == max_seq_length 162 | assert len(input_mask1) == max_seq_length 163 | 164 | input_ids.append(input_ids1) 165 | input_mask.append(input_mask1) 166 | 167 | i_nlu.append(i_nlu1) 168 | i_headers.append(i_batch_headers) 169 | 170 | # Convert to tensor 171 | all_input_ids = torch.tensor(input_ids, dtype=torch.long).to(device) 172 | all_input_mask = torch.tensor(input_mask, dtype=torch.long).to(device) 173 | 174 | # 4. Generate RoBERTa output. 175 | check, _, all_encoder_layer = model_roberta(input_ids=all_input_ids, attention_mask=all_input_mask, output_hidden_states=True) 176 | all_encoder_layer = list(all_encoder_layer) 177 | 178 | assert all((check == all_encoder_layer[-1]).tolist()) 179 | 180 | # 5. generate l_hpu from i_headers 181 | l_hpu = gen_l_hpu(i_headers) 182 | 183 | return all_encoder_layer, i_nlu, i_headers, \ 184 | l_n, l_hpu, l_hs, \ 185 | nlu_tt, t_to_tt_idx, tt_to_t_idx 186 | 187 | 188 | def generate_inputs(tokenizer, nlu1_tok, hds1): 189 | tokens = [] 190 | 191 | tokens.append("") 192 | i_st_nlu = len(tokens) # to use it later 193 | 194 | for token in nlu1_tok: 195 | tokens.append(token) 196 | i_ed_nlu = len(tokens) 197 | tokens.append("") 198 | 199 | i_headers = [] 200 | 201 | for i, hds11 in enumerate(hds1): 202 | i_st_hd = len(tokens) 203 | sub_tok = tokenizer.tokenize(hds11) 204 | tokens += sub_tok 205 | i_ed_hd = len(tokens) 206 | i_headers.append((i_st_hd, i_ed_hd)) 207 | if i < len(hds1)-1: 208 | tokens.append("") 209 | elif i == len(hds1)-1: 210 | tokens.append("") 211 | else: 212 | raise EnvironmentError 213 | 214 | i_nlu = (i_st_nlu, i_ed_nlu) 215 | 216 | return tokens, i_nlu, i_headers 217 | 218 | def gen_l_hpu(i_headers): 219 | """ 220 | # Treat columns as if it is a batch of natural language utterance with batch-size = # of columns * # of batch_size 221 | i_headers = [(17, 18), (19, 21), (22, 23), (24, 
25), (26, 29), (30, 34)]) 222 | """ 223 | l_hpu = [] 224 | 225 | for i_header in i_headers: 226 | for index_pair in i_header: 227 | l_hpu.append(index_pair[1] - index_pair[0]) 228 | 229 | return l_hpu 230 | 231 | def get_wemb_n(i_nlu, l_n, hS, num_hidden_layers, all_encoder_layer, num_out_layers_n): 232 | """ 233 | Get the representation of each tokens. 234 | """ 235 | bS = len(l_n) 236 | l_n_max = max(l_n) 237 | wemb_n = torch.zeros([bS, l_n_max, hS * num_out_layers_n]).to(device) 238 | for b in range(bS): 239 | 240 | l_n1 = l_n[b] 241 | i_nlu1 = i_nlu[b] 242 | for i_noln in range(num_out_layers_n): 243 | i_layer = num_hidden_layers - 1 - i_noln 244 | st = i_noln * hS 245 | ed = (i_noln + 1) * hS 246 | wemb_n[b, 0:(i_nlu1[1] - i_nlu1[0]), st:ed] = all_encoder_layer[i_layer][b, i_nlu1[0]:i_nlu1[1], :] 247 | 248 | return wemb_n 249 | 250 | def get_wemb_h(i_headers, l_hpu, l_hs, hS, num_hidden_layers, all_encoder_layer, num_out_layers_h): 251 | """ 252 | As if 253 | [ [table-1-col-1-tok1, t1-c1-t2, ...], 254 | [t1-c2-t1, t1-c2-t2, ...]. 255 | ... 
256 | [t2-c1-t1, ...,] 257 | ] 258 | """ 259 | bS = len(l_hs) 260 | l_hpu_max = max(l_hpu) 261 | num_of_all_hds = sum(l_hs) 262 | wemb_h = torch.zeros([num_of_all_hds, l_hpu_max, hS * num_out_layers_h]).to(device) 263 | b_pu = -1 264 | for b, i_header in enumerate(i_headers): 265 | for b1, index_pair in enumerate(i_header): 266 | b_pu += 1 267 | for i_nolh in range(num_out_layers_h): 268 | i_layer = num_hidden_layers - 1 - i_nolh 269 | st = i_nolh * hS 270 | ed = (i_nolh + 1) * hS 271 | wemb_h[b_pu, 0:(index_pair[1] - index_pair[0]), st:ed] \ 272 | = all_encoder_layer[i_layer][b, index_pair[0]:index_pair[1],:] 273 | 274 | return wemb_h 275 | -------------------------------------------------------------------------------- /seq2sql_model_internal_functions.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import random as rd 5 | from copy import deepcopy 6 | 7 | from matplotlib.pylab import * 8 | 9 | import math 10 | import torch 11 | import torchvision.datasets as dsets 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | # import torch_xla 15 | # import torch_xla.core.xla_model as xm 16 | 17 | device = torch.device("cuda") 18 | 19 | 20 | 21 | 22 | def encode(lstm, wemb_l, l, return_hidden=False, hc0=None, last_only=False): 23 | """ [batch_size, max token length, dim_emb] 24 | """ 25 | bS, mL, eS = wemb_l.shape 26 | 27 | # sort before packking 28 | l = array(l) 29 | perm_idx = argsort(-l) 30 | perm_idx_inv = generate_perm_inv(perm_idx) 31 | 32 | # pack sequence 33 | 34 | packed_wemb_l = nn.utils.rnn.pack_padded_sequence(wemb_l[perm_idx, :, :], 35 | l[perm_idx], 36 | batch_first=True) 37 | 38 | # Time to encode 39 | if hc0 is not None: 40 | hc0 = (hc0[0][:, perm_idx], hc0[1][:, perm_idx]) 41 | 42 | # ipdb.set_trace() 43 | packed_wemb_l = packed_wemb_l.float() # I don't know why.. 
44 | packed_wenc, hc_out = lstm(packed_wemb_l, hc0) 45 | hout, cout = hc_out 46 | 47 | # unpack 48 | wenc, _l = nn.utils.rnn.pad_packed_sequence(packed_wenc, batch_first=True) 49 | 50 | if last_only: 51 | # Take only final outputs for each columns. 52 | wenc = wenc[tuple(range(bS)), l[perm_idx] - 1] # [batch_size, dim_emb] 53 | wenc.unsqueeze_(1) # [batch_size, 1, dim_emb] 54 | 55 | wenc = wenc[perm_idx_inv] 56 | 57 | if return_hidden: 58 | # hout.shape = [number_of_directoin * num_of_layer, seq_len(=batch size), dim * number_of_direction ] w/ batch_first.. w/o batch_first? I need to see. 59 | hout = hout[:, perm_idx_inv].to(device) 60 | cout = cout[:, perm_idx_inv].to(device) # Is this correct operation? 61 | 62 | return wenc, hout, cout 63 | else: 64 | return wenc 65 | 66 | 67 | def encode_hpu(lstm, wemb_hpu, l_hpu, l_hs): 68 | wenc_hpu, hout, cout = encode(lstm, 69 | wemb_hpu, 70 | l_hpu, 71 | return_hidden=True, 72 | hc0=None, 73 | last_only=True) 74 | 75 | wenc_hpu = wenc_hpu.squeeze(1) 76 | bS_hpu, mL_hpu, eS = wemb_hpu.shape 77 | hS = wenc_hpu.size(-1) 78 | 79 | wenc_hs = wenc_hpu.new_zeros(len(l_hs), max(l_hs), hS) 80 | wenc_hs = wenc_hs.to(device) 81 | 82 | # Re-pack according to batch. 83 | # ret = [B_NLq, max_len_headers_all, dim_lstm] 84 | st = 0 85 | for i, l_hs1 in enumerate(l_hs): 86 | wenc_hs[i, :l_hs1] = wenc_hpu[st:(st + l_hs1)] 87 | st += l_hs1 88 | 89 | return wenc_hs 90 | 91 | 92 | def generate_perm_inv(perm): 93 | # Definitly correct. 94 | perm_inv = zeros(len(perm), dtype=int) # Was an undefine int32 variable 95 | for i, p in enumerate(perm): 96 | perm_inv[int(p)] = i 97 | 98 | return perm_inv 99 | 100 | 101 | def pred_sc(s_sc): 102 | """ 103 | return: [ pr_wc1_i, pr_wc2_i, ...] 104 | """ 105 | # get g_num 106 | pr_sc = [] 107 | for s_sc1 in s_sc: 108 | pr_sc.append(s_sc1.argmax().item()) 109 | 110 | return pr_sc 111 | 112 | 113 | def pred_sc_beam(s_sc, beam_size): 114 | """ 115 | return: [ pr_wc1_i, pr_wc2_i, ...] 
116 | """ 117 | # get g_num 118 | pr_sc_beam = [] 119 | 120 | for s_sc1 in s_sc: 121 | val, idxes = s_sc1.topk(k=beam_size) 122 | pr_sc_beam.append(idxes.tolist()) 123 | 124 | return pr_sc_beam 125 | 126 | 127 | def pred_sa(s_sa): 128 | """ 129 | return: [ pr_wc1_i, pr_wc2_i, ...] 130 | """ 131 | # get g_num 132 | pr_sa = [] 133 | for s_sa1 in s_sa: 134 | pr_sa.append(s_sa1.argmax().item()) 135 | 136 | return pr_sa 137 | 138 | 139 | def pred_wn(s_wn): 140 | """ 141 | return: [ pr_wc1_i, pr_wc2_i, ...] 142 | """ 143 | # get g_num 144 | pr_wn = [] 145 | for s_wn1 in s_wn: 146 | pr_wn.append(s_wn1.argmax().item()) 147 | # print(pr_wn, s_wn1) 148 | # if s_wn1.argmax().item() == 3: 149 | # input('') 150 | 151 | return pr_wn 152 | 153 | 154 | def pred_wc(wn, s_wc): 155 | """ 156 | return: [ pr_wc1_i, pr_wc2_i, ...] 157 | ! Returned index is sorted! 158 | """ 159 | # get g_num 160 | pr_wc = [] 161 | for b, wn1 in enumerate(wn): 162 | s_wc1 = s_wc[b] 163 | 164 | pr_wc1 = argsort(-s_wc1.data.cpu().numpy())[:wn1] 165 | pr_wc1.sort() 166 | 167 | pr_wc.append(list(pr_wc1)) 168 | return pr_wc 169 | 170 | 171 | def pred_wo(wn, s_wo): 172 | """ 173 | return: [ pr_wc1_i, pr_wc2_i, ...] 174 | """ 175 | # s_wo = [B, 4, n_op] 176 | pr_wo_a = s_wo.argmax(dim=2) # [B, 4] 177 | # get g_num 178 | pr_wo = [] 179 | for b, pr_wo_a1 in enumerate(pr_wo_a): 180 | wn1 = wn[b] 181 | pr_wo.append(list(pr_wo_a1.data.cpu().numpy()[:wn1])) 182 | 183 | return pr_wo 184 | 185 | 186 | def topk_multi_dim(tensor, n_topk=1, batch_exist=True): 187 | 188 | if batch_exist: 189 | idxs = [] 190 | for b, tensor1 in enumerate(tensor): 191 | idxs1 = [] 192 | tensor1_1d = tensor1.reshape(-1) 193 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk) 194 | idxs_list = unravel_index(idxs_1d.cpu().numpy(), tensor1.shape) 195 | # (dim0, dim1, dim2, ...) 
196 | 197 | # reconstruct 198 | for i_beam in range(n_topk): 199 | idxs11 = [] 200 | for idxs_list1 in idxs_list: 201 | idxs11.append(idxs_list1[i_beam]) 202 | idxs1.append(idxs11) 203 | idxs.append(idxs1) 204 | 205 | else: 206 | tensor1 = tensor 207 | idxs1 = [] 208 | tensor1_1d = tensor1.reshape(-1) 209 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk) 210 | idxs_list = unravel_index(idxs_1d.numpy(), tensor1.shape) 211 | # (dim0, dim1, dim2, ...) 212 | 213 | # reconstruct 214 | for i_beam in range(n_topk): 215 | idxs11 = [] 216 | for idxs_list1 in idxs_list: 217 | idxs11.append(idxs_list1[i_beam]) 218 | idxs1.append(idxs11) 219 | idxs = idxs1 220 | return idxs 221 | 222 | 223 | def remap_sc_idx(idxs, pr_sc_beam): 224 | for b, idxs1 in enumerate(idxs): 225 | for i_beam, idxs11 in enumerate(idxs1): 226 | sc_beam_idx = idxs[b][i_beam][0] 227 | sc_idx = pr_sc_beam[b][sc_beam_idx] 228 | idxs[b][i_beam][0] = sc_idx 229 | 230 | return idxs 231 | 232 | 233 | def check_sc_sa_pairs(tb, pr_sc, pr_sa, ): 234 | """ 235 | Check whether pr_sc, pr_sa are allowed pairs or not. 236 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 237 | """ 238 | bS = len(pr_sc) 239 | check = [False] * bS 240 | for b, pr_sc1 in enumerate(pr_sc): 241 | pr_sa1 = pr_sa[b] 242 | hd_types1 = tb[b]['types'] 243 | hd_types11 = hd_types1[pr_sc1] 244 | if hd_types11 == 'text': 245 | if pr_sa1 == 0 or pr_sa1 == 3: # '' 246 | check[b] = True 247 | else: 248 | check[b] = False 249 | 250 | elif hd_types11 == 'real': 251 | check[b] = True 252 | else: 253 | raise Exception("New TYPE!!") 254 | 255 | return check 256 | 257 | 258 | def pred_wvi_se_beam(max_wn, s_wv, beam_size): 259 | """ 260 | s_wv: [B, 4, mL, 2] 261 | - predict best st-idx & ed-idx 262 | output: 263 | pr_wvi_beam = [B, max_wn, n_pairs, 2]. 2 means [st, ed]. 
264 | prob_wvi_beam = [B, max_wn, n_pairs] 265 | """ 266 | bS = s_wv.shape[0] 267 | 268 | # [B, 4, mL, 2] -> [B, 4, mL, 1], [B, 4, mL, 1] 269 | s_wv_st, s_wv_ed = s_wv.split(1, dim=3) 270 | 271 | s_wv_st = s_wv_st.squeeze(3) # [B, 4, mL, 1] -> [B, 4, mL] 272 | s_wv_ed = s_wv_ed.squeeze(3) 273 | 274 | prob_wv_st = F.softmax(s_wv_st, dim=-1).detach().to('cpu').numpy() 275 | prob_wv_ed = F.softmax(s_wv_ed, dim=-1).detach().to('cpu').numpy() 276 | 277 | k_logit = int(ceil(sqrt(beam_size))) 278 | n_pairs = k_logit**2 279 | assert n_pairs >= beam_size 280 | values_st, idxs_st = s_wv_st.topk(k_logit) # [B, 4, mL] -> [B, 4, k_logit] 281 | values_ed, idxs_ed = s_wv_ed.topk(k_logit) # [B, 4, mL] -> [B, 4, k_logit] 282 | 283 | # idxs = [B, k_logit, 2] 284 | # Generate all possible combination of st, ed indices & prob 285 | pr_wvi_beam = [] # [B, max_wn, k_logit**2 [st, ed] paris] 286 | prob_wvi_beam = zeros([bS, max_wn, n_pairs]) 287 | for b in range(bS): 288 | pr_wvi_beam1 = [] 289 | 290 | idxs_st1 = idxs_st[b] 291 | idxs_ed1 = idxs_ed[b] 292 | for i_wn in range(max_wn): 293 | idxs_st11 = idxs_st1[i_wn] 294 | idxs_ed11 = idxs_ed1[i_wn] 295 | 296 | pr_wvi_beam11 = [] 297 | pair_idx = -1 298 | for i_k in range(k_logit): 299 | for j_k in range(k_logit): 300 | pair_idx += 1 301 | st = idxs_st11[i_k].item() 302 | ed = idxs_ed11[j_k].item() 303 | pr_wvi_beam11.append([st, ed]) 304 | 305 | p1 = prob_wv_st[b, i_wn, st] 306 | p2 = prob_wv_ed[b, i_wn, ed] 307 | prob_wvi_beam[b, i_wn, pair_idx] = p1*p2 308 | pr_wvi_beam1.append(pr_wvi_beam11) 309 | pr_wvi_beam.append(pr_wvi_beam1) 310 | 311 | # prob 312 | 313 | return pr_wvi_beam, prob_wvi_beam 314 | 315 | 316 | def convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_wp_t, wp_to_wh_index, nlu): 317 | """ 318 | - Convert to the string in whilte-space-separated tokens 319 | - Add-hoc addition. 
320 | """ 321 | pr_wv_str_wp = [] # word-piece version 322 | pr_wv_str = [] 323 | for b, pr_wvi1 in enumerate(pr_wvi): 324 | pr_wv_str_wp1 = [] 325 | pr_wv_str1 = [] 326 | wp_to_wh_index1 = wp_to_wh_index[b] 327 | nlu_wp_t1 = nlu_wp_t[b] 328 | nlu_t1 = nlu_t[b] 329 | 330 | for i_wn, pr_wvi11 in enumerate(pr_wvi1): 331 | st_idx, ed_idx = pr_wvi11 332 | 333 | # Ad-hoc modification of ed_idx to deal with wp-tokenization effect. 334 | # e.g.) to convert "butler cc (" ->"butler cc (ks)" (dev set 1st question). 335 | pr_wv_str_wp11 = nlu_wp_t1[st_idx:ed_idx+1] 336 | pr_wv_str_wp1.append(pr_wv_str_wp11) 337 | 338 | st_wh_idx = wp_to_wh_index1[st_idx] 339 | ed_wh_idx = wp_to_wh_index1[ed_idx] 340 | pr_wv_str11 = nlu_t1[st_wh_idx:ed_wh_idx+1] 341 | 342 | pr_wv_str1.append(pr_wv_str11) 343 | 344 | pr_wv_str_wp.append(pr_wv_str_wp1) 345 | pr_wv_str.append(pr_wv_str1) 346 | 347 | return pr_wv_str, pr_wv_str_wp 348 | 349 | 350 | def merge_wv_t1_eng(where_str_tokens, NLq): 351 | """ 352 | Almost copied of SQLNet. 353 | The main purpose is pad blank line while combining tokens. 354 | """ 355 | nlq = NLq.lower() 356 | where_str_tokens = [tok.lower() for tok in where_str_tokens] 357 | alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$' 358 | special = {'-LRB-': '(', 359 | '-RRB-': ')', 360 | '-LSB-': '[', 361 | '-RSB-': ']', 362 | '``': '"', 363 | '\'\'': '"', 364 | } 365 | # '--': '\u2013'} # this generate error for test 5661 case. 366 | ret = '' 367 | double_quote_appear = 0 368 | for raw_w_token in where_str_tokens: 369 | # if '' (empty string) of None, continue 370 | if not raw_w_token: 371 | continue 372 | 373 | # Change the special characters 374 | # maybe necessary for some case? 375 | w_token = special.get(raw_w_token, raw_w_token) 376 | 377 | # check the double quote 378 | if w_token == '"': 379 | double_quote_appear = 1 - double_quote_appear 380 | 381 | # Check whether ret is empty. ret is selected where condition. 
382 | if len(ret) == 0: 383 | pass 384 | # Check blank character. 385 | elif len(ret) > 0 and ret + ' ' + w_token in nlq: 386 | # Pad ' ' if ret + ' ' is part of nlq. 387 | ret = ret + ' ' 388 | 389 | elif len(ret) > 0 and ret + w_token in nlq: 390 | pass # already in good form. Later, ret + w_token will performed. 391 | 392 | # Below for unnatural question I guess. Is it likely to appear? 393 | elif w_token == '"': 394 | if double_quote_appear: 395 | ret = ret + ' ' # pad blank line between next token when " because in this case, it is of closing apperas 396 | # for the case of opening, no blank line. 397 | 398 | elif w_token[0] not in alphabet: 399 | pass # non alphabet one does not pad blank line. 400 | 401 | # when previous character is the special case. 402 | elif (ret[-1] not in ['(', '/', '\u2013', '#', '$', '&']) and (ret[-1] != '"' or not double_quote_appear): 403 | ret = ret + ' ' 404 | ret = ret + w_token 405 | 406 | return ret.strip() 407 | -------------------------------------------------------------------------------- /seq2sql_model_training_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from matplotlib.pylab import * 5 | from copy import deepcopy 6 | #import torch_xla 7 | #import torch_xla.core.xla_model as xm 8 | 9 | device = torch.device("cuda") 10 | 11 | 12 | def Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi): 13 | """ 14 | :param s_wv: score [ B, n_conds, T, score] 15 | :param g_wn: [ B ] 16 | :param g_wvi: [B, conds, pnt], e.g. 
[[[0, 6, 7, 8, 15], [0, 1, 2, 3, 4, 15]], [[0, 1, 2, 3, 16], [0, 7, 8, 9, 16]]] 17 | :return: 18 | """ 19 | loss = 0 20 | loss += F.cross_entropy(s_sc, torch.tensor(g_sc).to(device)) 21 | loss += F.cross_entropy(s_sa, torch.tensor(g_sa).to(device)) 22 | loss += F.cross_entropy(s_wn, torch.tensor(g_wn).to(device)) 23 | loss += Loss_wc(s_wc, g_wc) 24 | loss += Loss_wo(s_wo, g_wn, g_wo) 25 | loss += Loss_wv_se(s_wv, g_wn, g_wvi) 26 | 27 | return loss 28 | 29 | 30 | def Loss_wc(s_wc, g_wc): 31 | 32 | # Construct index matrix 33 | bS, max_h_len = s_wc.shape 34 | im = torch.zeros([bS, max_h_len]).to(device) 35 | for b, g_wc1 in enumerate(g_wc): 36 | for g_wc11 in g_wc1: 37 | im[b, g_wc11] = 1.0 38 | # Construct prob. 39 | p = F.sigmoid(s_wc) 40 | loss = F.binary_cross_entropy(p, im) 41 | 42 | return loss 43 | 44 | 45 | def Loss_wo(s_wo, g_wn, g_wo): 46 | 47 | # Construct index matrix 48 | loss = 0 49 | for b, g_wn1 in enumerate(g_wn): 50 | if g_wn1 == 0: 51 | continue 52 | g_wo1 = g_wo[b] 53 | s_wo1 = s_wo[b] 54 | loss += F.cross_entropy(s_wo1[:g_wn1], torch.tensor(g_wo1).to(device)) 55 | 56 | return loss 57 | 58 | def Loss_wv_se(s_wv, g_wn, g_wvi): 59 | """ 60 | s_wv: [bS, 4, mL, 2], 4 stands for maximum # of condition, 2 tands for start & end logits. 61 | g_wvi: [ [1, 3, 2], [4,3] ] (when B=2, wn(b=1) = 3, wn(b=2) = 2). 
62 | """ 63 | loss = 0 64 | # g_wvi = torch.tensor(g_wvi).to(device) 65 | for b, g_wvi1 in enumerate(g_wvi): 66 | # for i_wn, g_wvi11 in enumerate(g_wvi1): 67 | 68 | g_wn1 = g_wn[b] 69 | if g_wn1 == 0: 70 | continue 71 | g_wvi1 = torch.tensor(g_wvi1).to(device) 72 | g_st1 = g_wvi1[:,0] 73 | g_ed1 = g_wvi1[:,1] 74 | # loss from the start position 75 | loss += F.cross_entropy(s_wv[b,:g_wn1,:,0], g_st1) 76 | 77 | # print("st_login: ", s_wv[b,:g_wn1,:,0], g_st1, loss) 78 | # loss from the end position 79 | loss += F.cross_entropy(s_wv[b,:g_wn1,:,1], g_ed1) 80 | # print("ed_login: ", s_wv[b,:g_wn1,:,1], g_ed1, loss) 81 | 82 | return loss 83 | 84 | def pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv): 85 | pr_sc = pred_sc(s_sc) 86 | pr_sa = pred_sa(s_sa) 87 | pr_wn = pred_wn(s_wn) 88 | pr_wc = pred_wc(pr_wn, s_wc) 89 | pr_wo = pred_wo(pr_wn, s_wo) 90 | pr_wvi = pred_wvi_se(pr_wn, s_wv) 91 | 92 | return pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi 93 | 94 | def pred_sc(s_sc): 95 | """ 96 | return: [ pr_wc1_i, pr_wc2_i, ...] 97 | """ 98 | # get g_num 99 | pr_sc = [] 100 | for s_sc1 in s_sc: 101 | pr_sc.append(s_sc1.argmax().item()) 102 | 103 | return pr_sc 104 | 105 | def pred_sc(s_sc): 106 | """ 107 | return: [ pr_wc1_i, pr_wc2_i, ...] 108 | """ 109 | # get g_num 110 | pr_sc = [] 111 | for s_sc1 in s_sc: 112 | pr_sc.append(s_sc1.argmax().item()) 113 | 114 | return pr_sc 115 | 116 | def pred_sa(s_sa): 117 | """ 118 | return: [ pr_wc1_i, pr_wc2_i, ...] 119 | """ 120 | # get g_num 121 | pr_sa = [] 122 | for s_sa1 in s_sa: 123 | pr_sa.append(s_sa1.argmax().item()) 124 | 125 | return pr_sa 126 | 127 | def pred_wn(s_wn): 128 | """ 129 | return: [ pr_wc1_i, pr_wc2_i, ...] 130 | """ 131 | # get g_num 132 | pr_wn = [] 133 | for s_wn1 in s_wn: 134 | pr_wn.append(s_wn1.argmax().item()) 135 | # print(pr_wn, s_wn1) 136 | # if s_wn1.argmax().item() == 3: 137 | # input('') 138 | 139 | return pr_wn 140 | 141 | def pred_wc(wn, s_wc): 142 | """ 143 | return: [ pr_wc1_i, pr_wc2_i, ...] 144 | ! 
Returned index is sorted! 145 | """ 146 | # get g_num 147 | pr_wc = [] 148 | for b, wn1 in enumerate(wn): 149 | s_wc1 = s_wc[b] 150 | 151 | pr_wc1 = argsort(-s_wc1.data.cpu().numpy())[:wn1] 152 | pr_wc1.sort() 153 | 154 | pr_wc.append(list(pr_wc1)) 155 | return pr_wc 156 | 157 | def pred_wo(wn, s_wo): 158 | """ 159 | return: [ pr_wc1_i, pr_wc2_i, ...] 160 | """ 161 | # s_wo = [B, 4, n_op] 162 | pr_wo_a = s_wo.argmax(dim=2) # [B, 4] 163 | # get g_num 164 | pr_wo = [] 165 | for b, pr_wo_a1 in enumerate(pr_wo_a): 166 | wn1 = wn[b] 167 | pr_wo.append(list(pr_wo_a1.data.cpu().numpy()[:wn1])) 168 | 169 | return pr_wo 170 | 171 | def pred_wvi_se(wn, s_wv): 172 | """ 173 | s_wv: [B, 4, mL, 2] 174 | - predict best st-idx & ed-idx 175 | """ 176 | 177 | s_wv_st, s_wv_ed = s_wv.split(1, dim=3) # [B, 4, mL, 2] -> [B, 4, mL, 1], [B, 4, mL, 1] 178 | 179 | s_wv_st = s_wv_st.squeeze(3) # [B, 4, mL, 1] -> [B, 4, mL] 180 | s_wv_ed = s_wv_ed.squeeze(3) 181 | 182 | pr_wvi_st_idx = s_wv_st.argmax(dim=2) # [B, 4, mL] -> [B, 4, 1] 183 | pr_wvi_ed_idx = s_wv_ed.argmax(dim=2) 184 | 185 | pr_wvi = [] 186 | for b, wn1 in enumerate(wn): 187 | pr_wvi1 = [] 188 | for i_wn in range(wn1): 189 | pr_wvi_st_idx11 = pr_wvi_st_idx[b][i_wn] 190 | pr_wvi_ed_idx11 = pr_wvi_ed_idx[b][i_wn] 191 | pr_wvi1.append([pr_wvi_st_idx11.item(), pr_wvi_ed_idx11.item()]) 192 | pr_wvi.append(pr_wvi1) 193 | 194 | return pr_wvi 195 | 196 | def convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_wp_t, wp_to_wh_index, nlu): 197 | """ 198 | - Convert to the string in whilte-space-separated tokens 199 | - Add-hoc addition. 
200 | """ 201 | pr_wv_str_wp = [] # word-piece version 202 | pr_wv_str = [] 203 | for b, pr_wvi1 in enumerate(pr_wvi): 204 | pr_wv_str_wp1 = [] 205 | pr_wv_str1 = [] 206 | wp_to_wh_index1 = wp_to_wh_index[b] 207 | nlu_wp_t1 = nlu_wp_t[b] 208 | nlu_t1 = nlu_t[b] 209 | 210 | for i_wn, pr_wvi11 in enumerate(pr_wvi1): 211 | st_idx, ed_idx = pr_wvi11 212 | 213 | # Ad-hoc modification of ed_idx to deal with wp-tokenization effect. 214 | # e.g.) to convert "butler cc (" ->"butler cc (ks)" (dev set 1st question). 215 | pr_wv_str_wp11 = nlu_wp_t1[st_idx:ed_idx+1] 216 | pr_wv_str_wp1.append(pr_wv_str_wp11) 217 | 218 | st_wh_idx = wp_to_wh_index1[st_idx] 219 | ed_wh_idx = wp_to_wh_index1[ed_idx] 220 | pr_wv_str11 = nlu_t1[st_wh_idx:ed_wh_idx+1] 221 | 222 | pr_wv_str1.append(pr_wv_str11) 223 | 224 | pr_wv_str_wp.append(pr_wv_str_wp1) 225 | pr_wv_str.append(pr_wv_str1) 226 | 227 | return pr_wv_str, pr_wv_str_wp 228 | 229 | def sort_pr_wc(pr_wc, g_wc): 230 | """ 231 | Input: list 232 | pr_wc = [B, n_conds] 233 | g_wc = [B, n_conds] 234 | Return: list 235 | pr_wc_sorted = [B, n_conds] 236 | """ 237 | pr_wc_sorted = [] 238 | for b, pr_wc1 in enumerate(pr_wc): 239 | g_wc1 = g_wc[b] 240 | pr_wc1_sorted = [] 241 | 242 | if set(g_wc1) == set(pr_wc1): 243 | pr_wc1_sorted = deepcopy(g_wc1) 244 | else: 245 | # no sorting when g_wc1 and pr_wc1 are different. 
246 | pr_wc1_sorted = deepcopy(pr_wc1) 247 | 248 | pr_wc_sorted.append(pr_wc1_sorted) 249 | return pr_wc_sorted 250 | 251 | def generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu): 252 | pr_sql_i = [] 253 | for b, nlu1 in enumerate(nlu): 254 | conds = [] 255 | for i_wn in range(pr_wn[b]): 256 | conds1 = [] 257 | conds1.append(pr_wc[b][i_wn]) 258 | conds1.append(pr_wo[b][i_wn]) 259 | merged_wv11 = merge_wv_t1_eng(pr_wv_str[b][i_wn], nlu[b]) 260 | conds1.append(merged_wv11) 261 | conds.append(conds1) 262 | 263 | pr_sql_i1 = {'agg': pr_sa[b], 'sel': pr_sc[b], 'conds': conds} 264 | pr_sql_i.append(pr_sql_i1) 265 | return pr_sql_i 266 | 267 | def merge_wv_t1_eng(where_str_tokens, NLq): 268 | """ 269 | Almost copied of SQLNet. 270 | The main purpose is pad blank line while combining tokens. 271 | """ 272 | nlq = NLq.lower() 273 | where_str_tokens = [tok.lower() for tok in where_str_tokens] 274 | alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$' 275 | special = {'-LRB-': '(', 276 | '-RRB-': ')', 277 | '-LSB-': '[', 278 | '-RSB-': ']', 279 | '``': '"', 280 | '\'\'': '"', 281 | } 282 | # '--': '\u2013'} # this generate error for test 5661 case. 283 | ret = '' 284 | double_quote_appear = 0 285 | for raw_w_token in where_str_tokens: 286 | # if '' (empty string) of None, continue 287 | if not raw_w_token: 288 | continue 289 | 290 | # Change the special characters 291 | w_token = special.get(raw_w_token, raw_w_token) # maybe necessary for some case? 292 | 293 | # check the double quote 294 | if w_token == '"': 295 | double_quote_appear = 1 - double_quote_appear 296 | 297 | # Check whether ret is empty. ret is selected where condition. 298 | if len(ret) == 0: 299 | pass 300 | # Check blank character. 301 | elif len(ret) > 0 and ret + ' ' + w_token in nlq: 302 | # Pad ' ' if ret + ' ' is part of nlq. 303 | ret = ret + ' ' 304 | 305 | elif len(ret) > 0 and ret + w_token in nlq: 306 | pass # already in good form. Later, ret + w_token will performed. 
307 | 308 | # Below for unnatural question I guess. Is it likely to appear? 309 | elif w_token == '"': 310 | if double_quote_appear: 311 | ret = ret + ' ' # pad blank line between next token when " because in this case, it is of closing apperas 312 | # for the case of opening, no blank line. 313 | 314 | elif w_token[0] not in alphabet: 315 | pass # non alphabet one does not pad blank line. 316 | 317 | # when previous character is the special case. 318 | elif (ret[-1] not in ['(', '/', '\u2013', '#', '$', '&']) and (ret[-1] != '"' or not double_quote_appear): 319 | ret = ret + ' ' 320 | ret = ret + w_token 321 | 322 | return ret.strip() 323 | 324 | def get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, 325 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, 326 | g_sql_i, pr_sql_i, 327 | mode): 328 | """ usalbe only when g_wc was used to find pr_wv 329 | """ 330 | cnt_sc = get_cnt_sc_list(g_sc, pr_sc) 331 | cnt_sa = get_cnt_sc_list(g_sa, pr_sa) 332 | cnt_wn = get_cnt_sc_list(g_wn, pr_wn) 333 | cnt_wc = get_cnt_wc_list(g_wc, pr_wc) 334 | cnt_wo = get_cnt_wo_list(g_wn, g_wc, g_wo, pr_wc, pr_wo, mode) 335 | if pr_wvi: 336 | cnt_wvi = get_cnt_wvi_list(g_wn, g_wc, g_wvi, pr_wvi, mode) 337 | else: 338 | cnt_wvi = [0]*len(cnt_sc) 339 | cnt_wv = get_cnt_wv_list(g_wn, g_wc, g_sql_i, pr_sql_i, mode) # compare using wv-str which presented in original data. 
340 | 341 | 342 | return cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wvi, cnt_wv 343 | 344 | def get_cnt_sc_list(g_sc, pr_sc): 345 | cnt_list = [] 346 | for b, g_sc1 in enumerate(g_sc): 347 | pr_sc1 = pr_sc[b] 348 | if pr_sc1 == g_sc1: 349 | cnt_list.append(1) 350 | else: 351 | cnt_list.append(0) 352 | 353 | return cnt_list 354 | 355 | def get_cnt_wc_list(g_wc, pr_wc): 356 | cnt_list= [] 357 | for b, g_wc1 in enumerate(g_wc): 358 | 359 | pr_wc1 = pr_wc[b] 360 | pr_wn1 = len(pr_wc1) 361 | g_wn1 = len(g_wc1) 362 | 363 | if pr_wn1 != g_wn1: 364 | cnt_list.append(0) 365 | continue 366 | else: 367 | wc1 = array(g_wc1) 368 | wc1.sort() 369 | 370 | if array_equal(pr_wc1, wc1): 371 | cnt_list.append(1) 372 | else: 373 | cnt_list.append(0) 374 | 375 | return cnt_list 376 | 377 | def get_cnt_wo_list(g_wn, g_wc, g_wo, pr_wc, pr_wo, mode): 378 | """ pr's are all sorted as pr_wc are sorted in increasing order (in column idx) 379 | However, g's are not sorted. 380 | Sort g's in increasing order (in column idx) 381 | """ 382 | cnt_list=[] 383 | for b, g_wo1 in enumerate(g_wo): 384 | g_wc1 = g_wc[b] 385 | pr_wc1 = pr_wc[b] 386 | pr_wo1 = pr_wo[b] 387 | pr_wn1 = len(pr_wo1) 388 | g_wn1 = g_wn[b] 389 | 390 | if g_wn1 != pr_wn1: 391 | cnt_list.append(0) 392 | continue 393 | else: 394 | # Sort based wc sequence. 395 | if mode == 'test': 396 | idx = argsort(array(g_wc1)) 397 | 398 | g_wo1_s = array(g_wo1)[idx] 399 | g_wo1_s = list(g_wo1_s) 400 | elif mode == 'train': 401 | # due to tearch forcing, no need to sort. 
402 | g_wo1_s = g_wo1 403 | else: 404 | raise ValueError 405 | 406 | if type(pr_wo1) != list: 407 | raise TypeError 408 | if g_wo1_s == pr_wo1: 409 | cnt_list.append(1) 410 | else: 411 | cnt_list.append(0) 412 | return cnt_list 413 | 414 | def get_cnt_wvi_list(g_wn, g_wc, g_wvi, pr_wvi, mode): 415 | """ usalbe only when g_wc was used to find pr_wv 416 | """ 417 | cnt_list =[] 418 | for b, g_wvi1 in enumerate(g_wvi): 419 | g_wc1 = g_wc[b] 420 | pr_wvi1 = pr_wvi[b] 421 | pr_wn1 = len(pr_wvi1) 422 | g_wn1 = g_wn[b] 423 | 424 | # Now sorting. 425 | # Sort based wc sequence. 426 | if mode == 'test': 427 | idx1 = argsort(array(g_wc1)) 428 | elif mode == 'train': 429 | idx1 = list( range( g_wn1) ) 430 | else: 431 | raise ValueError 432 | 433 | if g_wn1 != pr_wn1: 434 | cnt_list.append(0) 435 | continue 436 | else: 437 | flag = True 438 | for i_wn, idx11 in enumerate(idx1): 439 | g_wvi11 = g_wvi1[idx11] 440 | pr_wvi11 = pr_wvi1[i_wn] 441 | if g_wvi11 != pr_wvi11: 442 | flag = False 443 | # print(g_wv1, g_wv11) 444 | # print(pr_wv1, pr_wv11) 445 | # input('') 446 | break 447 | if flag: 448 | cnt_list.append(1) 449 | else: 450 | cnt_list.append(0) 451 | 452 | return cnt_list 453 | 454 | def get_cnt_wv_list(g_wn, g_wc, g_sql_i, pr_sql_i, mode): 455 | """ usalbe only when g_wc was used to find pr_wv 456 | """ 457 | cnt_list =[] 458 | for b, g_wc1 in enumerate(g_wc): 459 | pr_wn1 = len(pr_sql_i[b]["conds"]) 460 | g_wn1 = g_wn[b] 461 | 462 | # Now sorting. 463 | # Sort based wc sequence. 
464 | if mode == 'test': 465 | idx1 = argsort(array(g_wc1)) 466 | elif mode == 'train': 467 | idx1 = list( range( g_wn1) ) 468 | else: 469 | raise ValueError 470 | 471 | if g_wn1 != pr_wn1: 472 | cnt_list.append(0) 473 | continue 474 | else: 475 | flag = True 476 | for i_wn, idx11 in enumerate(idx1): 477 | g_wvi_str11 = str(g_sql_i[b]["conds"][idx11][2]).lower() 478 | pr_wvi_str11 = str(pr_sql_i[b]["conds"][i_wn][2]).lower() 479 | # print(g_wvi_str11) 480 | # print(pr_wvi_str11) 481 | # print(g_wvi_str11==pr_wvi_str11) 482 | if g_wvi_str11 != pr_wvi_str11: 483 | flag = False 484 | # print(g_wv1, g_wv11) 485 | # print(pr_wv1, pr_wv11) 486 | # input('') 487 | break 488 | if flag: 489 | cnt_list.append(1) 490 | else: 491 | cnt_list.append(0) 492 | 493 | return cnt_list 494 | 495 | def get_cnt_lx_list(cnt_sc1, cnt_sa1, cnt_wn1, cnt_wc1, cnt_wo1, cnt_wv1): 496 | # all cnt are list here. 497 | cnt_list = [] 498 | cnt_lx = 0 499 | for csc, csa, cwn, cwc, cwo, cwv in zip(cnt_sc1, cnt_sa1, cnt_wn1, cnt_wc1, cnt_wo1, cnt_wv1): 500 | if csc and csa and cwn and cwc and cwo and cwv: 501 | cnt_list.append(1) 502 | else: 503 | cnt_list.append(0) 504 | 505 | return cnt_list 506 | 507 | def get_cnt_x_list(engine, tb, g_sc, g_sa, g_sql_i, pr_sc, pr_sa, pr_sql_i): 508 | cnt_x1_list = [] 509 | g_ans = [] 510 | pr_ans = [] 511 | for b in range(len(g_sc)): 512 | g_ans1 = engine.execute(tb[b]['id'], g_sc[b], g_sa[b], g_sql_i[b]['conds']) 513 | # print(f'cnt: {cnt}') 514 | # print(f"pr_sql_i: {pr_sql_i[b]['conds']}") 515 | try: 516 | pr_ans1 = engine.execute(tb[b]['id'], pr_sc[b], pr_sa[b], pr_sql_i[b]['conds']) 517 | 518 | if bool(pr_ans1): # not empty due to lack of the data from incorretly generated sql 519 | if g_ans1 == pr_ans1: 520 | cnt_x1 = 1 521 | else: 522 | cnt_x1 = 0 523 | else: 524 | cnt_x1 = 0 525 | except: 526 | # type error etc... Execution-guided decoding may be used here. 
527 | pr_ans1 = None 528 | cnt_x1 = 0 529 | cnt_x1_list.append(cnt_x1) 530 | g_ans.append(g_ans1) 531 | pr_ans.append(pr_ans1) 532 | 533 | return cnt_x1_list, g_ans, pr_ans 534 | -------------------------------------------------------------------------------- /dev_function.py: -------------------------------------------------------------------------------- 1 | from dbengine_sqlnet import DBEngine 2 | 3 | import os 4 | import seq2sql_model_training_functions 5 | import corenlp_local 6 | import load_data 7 | import roberta_training 8 | import infer_functions 9 | import torch 10 | from tqdm.notebook import tqdm 11 | import seq2sql_model_testing 12 | 13 | 14 | def train(seq2sql_model,roberta_model,model_optimizer,roberta_optimizer,roberta_tokenizer,roberta_config,path_wikisql,train_loader): 15 | 16 | roberta_model.train() 17 | seq2sql_model.train() 18 | 19 | results=[] 20 | average_loss = 0 21 | 22 | count_select_column = 0 # count the # of correct predictions of select column 23 | count_select_agg = 0 # of selectd aggregation 24 | count_where_number = 0 # of where number 25 | count_where_column = 0 # of where column 26 | count_where_operator = 0 # of where operator 27 | count_where_value = 0 # of where-value 28 | count_where_value_index = 0 # of where-value index (on question tokens) 29 | count_logical_form_acc = 0 # of logical form accuracy 30 | count_execution_acc = 0 # of execution accuracy 31 | 32 | 33 | # Engine for SQL querying. 34 | engine = DBEngine(os.path.join(path_wikisql, f"train.db")) 35 | count = 0 # count the # of examples 36 | for batch_index, batch in enumerate(tqdm(train_loader)): 37 | count += len(batch) 38 | 39 | # if batch_index > 2: 40 | # break 41 | # Get fields 42 | 43 | # nlu : natural language utterance 44 | # nlu_t: tokenized nlu 45 | # sql_i: canonical form of SQL query 46 | # sql_q: full SQL query text. Not used. 47 | # sql_t: tokenized SQL query 48 | # tb : table metadata. No row data needed 49 | # hs_t : tokenized headers. Not used. 
50 | natural_lang_utterance, natural_lang_utterance_tokenized, sql_canonical, \ 51 | _, _, table_metadata, _, headers = load_data.get_fields(batch) 52 | 53 | 54 | select_column_ground, select_agg_ground, where_number_ground, \ 55 | where_column_ground, where_operator_ground, _ = roberta_training.get_ground_truth_values(sql_canonical) 56 | # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset. 57 | 58 | 59 | natural_lang_embeddings, header_embeddings, question_token_length, header_token_length, header_count, \ 60 | natural_lang_double_tokenized, punkt_to_roberta_token_indices, roberta_to_punkt_token_indices \ 61 | = roberta_training.get_wemb_roberta(roberta_config, roberta_model, roberta_tokenizer, 62 | natural_lang_utterance_tokenized, headers,max_seq_length= 222, 63 | num_out_layers_n=2, num_out_layers_h=2) 64 | # natural_lang_embeddings: natural language embedding 65 | # header_embeddings: header embedding 66 | # question_token_length: token lengths of each question 67 | # header_token_length: header token lengths 68 | # header_count: the number of columns (headers) of the tables. 69 | 70 | where_value_index_ground_corenlp = corenlp_local.get_g_wvi_corenlp(batch) 71 | try: 72 | # 73 | where_value_index_ground = corenlp_local.get_g_wvi_bert_from_g_wvi_corenlp(punkt_to_roberta_token_indices, where_value_index_ground_corenlp) 74 | except: 75 | # Exception happens when where-condition is not found in natural_lang_double_tokenized. 76 | # In this case, that train example is not used. 77 | # During test, that example considered as wrongly answered. 78 | # e.g. train: 32. 
79 | continue 80 | 81 | knowledge = [] 82 | for k in batch: 83 | if "bertindex_knowledge" in k: 84 | knowledge.append(k["bertindex_knowledge"]) 85 | else: 86 | knowledge.append(max(question_token_length)*[0]) 87 | 88 | knowledge_header = [] 89 | for k in batch: 90 | if "header_knowledge" in k: 91 | knowledge_header.append(k["header_knowledge"]) 92 | else: 93 | knowledge_header.append(max(header_count) * [0]) 94 | 95 | # score 96 | 97 | select_column_score, select_agg_score, where_number_score, where_column_score,\ 98 | where_operator_score, where_value_score = seq2sql_model(natural_lang_embeddings, question_token_length, header_embeddings, 99 | header_token_length, header_count, 100 | g_sc=select_column_ground, g_sa=select_agg_ground, 101 | g_wn=where_number_ground, g_wc=where_column_ground, 102 | g_wo=where_operator_ground, g_wvi=where_value_index_ground, 103 | knowledge = knowledge, 104 | knowledge_header = knowledge_header) 105 | 106 | # Calculate loss & step 107 | loss = seq2sql_model_training_functions.Loss_sw_se(select_column_score, select_agg_score, where_number_score, 108 | where_column_score, where_operator_score, where_value_score, 109 | select_column_ground, select_agg_ground, 110 | where_number_ground, where_column_ground, 111 | where_operator_ground, where_value_index_ground) 112 | 113 | 114 | model_optimizer.zero_grad() 115 | if roberta_optimizer: 116 | roberta_optimizer.zero_grad() 117 | loss.backward() 118 | model_optimizer.step() 119 | if roberta_optimizer: 120 | roberta_optimizer.step() 121 | 122 | 123 | # Prediction 124 | select_column_predict, select_agg_predict, where_number_predict, \ 125 | where_column_predict, where_operator_predict, where_val_index_predict = seq2sql_model_training_functions.pred_sw_se( 126 | select_column_score, select_agg_score, where_number_score, 127 | where_column_score, where_operator_score, where_value_score) 128 | where_value_string_predict, _ = seq2sql_model_training_functions.convert_pr_wvi_to_string( 129 | 
where_val_index_predict, 130 | natural_lang_utterance_tokenized, natural_lang_double_tokenized, 131 | roberta_to_punkt_token_indices, natural_lang_utterance) 132 | 133 | 134 | # Sort where_column_predict: 135 | # Sort where_column_predict when training the model as where_operator_predict and where_val_index_predict are predicted using ground-truth where-column (g_wc) 136 | # In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference. 137 | where_column_predict_sorted = seq2sql_model_training_functions.sort_pr_wc(where_column_predict, where_column_ground) 138 | 139 | sql_canonical_predict = seq2sql_model_training_functions.generate_sql_i( 140 | select_column_predict, select_agg_predict, where_number_predict, 141 | where_column_predict_sorted, where_operator_predict, 142 | where_value_string_predict, natural_lang_utterance) 143 | 144 | # Cacluate accuracy 145 | select_col_batchlist, select_agg_batchlist, where_number_batchlist, \ 146 | where_column_batchlist, where_operator_batchlist, where_value_index_batchlist, \ 147 | where_value_batchlist = seq2sql_model_training_functions.get_cnt_sw_list( 148 | select_column_ground, select_agg_ground, 149 | where_number_ground, where_column_ground, 150 | where_operator_ground, where_value_index_ground, 151 | select_column_predict, select_agg_predict, where_number_predict, 152 | where_column_predict, where_operator_predict, where_val_index_predict, 153 | sql_canonical, sql_canonical_predict, 154 | mode='train') 155 | 156 | logical_form_acc_batchlist = seq2sql_model_training_functions.get_cnt_lx_list( 157 | select_col_batchlist, select_agg_batchlist, where_number_batchlist, 158 | where_column_batchlist,where_operator_batchlist, where_value_batchlist) 159 | # lx stands for logical form accuracy 160 | # Execution accuracy test. 
161 | execution_acc_batchlist, _, _ = seq2sql_model_training_functions.get_cnt_x_list( 162 | engine, table_metadata, select_column_ground, select_agg_ground, 163 | sql_canonical, select_column_predict, select_agg_predict, sql_canonical_predict) 164 | # statistics 165 | average_loss += loss.item() 166 | 167 | # count 168 | count_select_column += sum(select_col_batchlist) 169 | count_select_agg += sum(select_agg_batchlist) 170 | count_where_number += sum(where_number_batchlist) 171 | count_where_column += sum(where_column_batchlist) 172 | count_where_operator += sum(where_operator_batchlist) 173 | count_where_value_index += sum(where_value_index_batchlist) 174 | count_where_value += sum(where_value_batchlist) 175 | count_logical_form_acc += sum(logical_form_acc_batchlist) 176 | count_execution_acc += sum(execution_acc_batchlist) 177 | 178 | average_loss /= count 179 | select_column_acc = count_select_column / count 180 | select_agg_acc = count_select_agg / count 181 | where_number_acc = count_where_number / count 182 | where_column_acc = count_where_column / count 183 | where_operator_acc = count_where_operator / count 184 | where_value_index_acc = count_where_value_index / count 185 | where_value_acc = count_where_value / count 186 | logical_form_acc = count_logical_form_acc / count 187 | execution_acc = count_execution_acc / count 188 | accuracy = [average_loss, select_column_acc, select_agg_acc, where_number_acc, where_column_acc, 189 | where_operator_acc, where_value_index_acc, where_value_acc, logical_form_acc, execution_acc] 190 | 191 | return accuracy 192 | 193 | 194 | def test(seq2sql_model,roberta_model,model_optimizer,roberta_tokenizer,roberta_config,path_wikisql,test_loader,mode="dev"): 195 | 196 | roberta_model.eval() 197 | seq2sql_model.eval() 198 | 199 | count_batchlist=[] 200 | results=[] 201 | 202 | 203 | count_select_column = 0 # count the # of correct predictions of select column 204 | count_select_agg = 0 # of selectd aggregation 205 | 
count_where_number = 0 # of where number 206 | count_where_column = 0 # of where column 207 | count_where_operator = 0 # of where operator 208 | count_where_value = 0 # of where-value 209 | count_where_value_index = 0 # of where-value index (on question tokens) 210 | count_logical_form_acc = 0 # of logical form accuracy 211 | count_execution_acc = 0 # of execution accurac 212 | 213 | 214 | # Engine for SQL querying. 215 | engine = DBEngine(os.path.join(path_wikisql, mode+".db")) 216 | 217 | count = 0 218 | for batch_index, batch in enumerate(tqdm(test_loader)): 219 | count += len(batch) 220 | 221 | # if batch_index > 2: 222 | # break 223 | # Get fields 224 | natural_lang_utterance, natural_lang_utterance_tokenized, sql_canonical, \ 225 | _, _, table_metadata, _, headers = load_data.get_fields(batch) 226 | 227 | 228 | select_column_ground, select_agg_ground, where_number_ground, \ 229 | where_column_ground, where_operator_ground, _ = roberta_training.get_ground_truth_values(sql_canonical) 230 | # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset. 231 | 232 | 233 | natural_lang_embeddings, header_embeddings, question_token_length, header_token_length, header_count, \ 234 | natural_lang_double_tokenized, punkt_to_roberta_token_indices, roberta_to_punkt_token_indices \ 235 | = roberta_training.get_wemb_roberta(roberta_config, roberta_model, roberta_tokenizer, 236 | natural_lang_utterance_tokenized, headers,max_seq_length= 222, 237 | num_out_layers_n=2, num_out_layers_h=2) 238 | # natural_lang_embeddings: natural language embedding 239 | # header_embeddings: header embedding 240 | # question_token_length: token lengths of each question 241 | # header_token_length: header token lengths 242 | # header_count: the number of columns (headers) of the tables. 
243 | 244 | where_value_index_ground_corenlp = corenlp_local.get_g_wvi_corenlp(batch) 245 | try: 246 | # 247 | where_value_index_ground = corenlp_local.get_g_wvi_bert_from_g_wvi_corenlp(punkt_to_roberta_token_indices, where_value_index_ground_corenlp) 248 | except: 249 | # Exception happens when where-condition is not found in nlu_tt. 250 | # In this case, that train example is not used. 251 | # During test, that example considered as wrongly answered. 252 | # e.g. train: 32. 253 | for b in range(len(natural_lang_utterance)): 254 | curr_results = {} 255 | curr_results["error"] = "Skip happened" 256 | curr_results["nlu"] = natural_lang_utterance[b] 257 | curr_results["table_id"] = table_metadata[b]["id"] 258 | results.append(curr_results) 259 | continue 260 | 261 | 262 | knowledge = [] 263 | for k in batch: 264 | if "bertindex_knowledge" in k: 265 | knowledge.append(k["bertindex_knowledge"]) 266 | else: 267 | knowledge.append(max(question_token_length)*[0]) 268 | 269 | knowledge_header = [] 270 | for k in batch: 271 | if "header_knowledge" in k: 272 | knowledge_header.append(k["header_knowledge"]) 273 | else: 274 | knowledge_header.append(max(header_count) * [0]) 275 | 276 | 277 | 278 | # score 279 | _, _, _, select_column_predict, select_agg_predict, where_number_predict, sql_predict = seq2sql_model.beam_forward( 280 | natural_lang_embeddings, question_token_length, header_embeddings, 281 | header_token_length, header_count, table_metadata, 282 | natural_lang_utterance_tokenized, natural_lang_double_tokenized, 283 | roberta_to_punkt_token_indices, natural_lang_utterance, 284 | beam_size=4, knowledge=knowledge, knowledge_header=knowledge_header) 285 | 286 | # sort and generate 287 | where_column_predict, where_operator_predict, _, sql_predict = infer_functions.sort_and_generate_pr_w(sql_predict) 288 | 289 | # Follosing variables are just for the consistency with no-EG case. 
290 | where_value_index_predict = None # not used 291 | 292 | for b, sql_predict_instance in enumerate(sql_predict): 293 | curr_results = {} 294 | curr_results["query"] = sql_predict_instance 295 | curr_results["table_id"] = table_metadata[b]["id"] 296 | curr_results["nlu"] = natural_lang_utterance[b] 297 | results.append(curr_results) 298 | 299 | # Cacluate accuracy 300 | select_column_batchlist, select_agg_batchlist, where_number_batchlist, \ 301 | where_column_batchlist, where_operator_batchlist, \ 302 | where_value_index_batchlist, where_value_batchlist = seq2sql_model_training_functions.get_cnt_sw_list( 303 | select_column_ground, select_agg_ground, where_number_ground, 304 | where_column_ground, where_operator_ground, where_value_index_ground, 305 | select_column_predict, select_agg_predict, where_number_predict, where_column_predict, 306 | where_operator_predict, where_value_index_predict, 307 | sql_canonical, sql_predict, 308 | mode='test') 309 | 310 | logical_form_acc_batchlist = seq2sql_model_training_functions.get_cnt_lx_list(select_column_batchlist, select_agg_batchlist, where_number_batchlist, where_column_batchlist, 311 | where_operator_batchlist, where_value_batchlist) 312 | # lx stands for logical form accuracy 313 | 314 | # Execution accuracy test. 
315 | execution_acc_batchlist, _, _ = seq2sql_model_training_functions.get_cnt_x_list( 316 | engine, table_metadata, select_column_ground, select_agg_ground, sql_canonical, select_column_predict, select_agg_predict, sql_predict) 317 | 318 | # statistics 319 | # ave_loss += loss.item() 320 | 321 | # count 322 | count_select_column += sum(select_column_batchlist) 323 | count_select_agg += sum(select_agg_batchlist) 324 | count_where_number += sum(where_number_batchlist) 325 | count_where_column += sum(where_column_batchlist) 326 | count_where_operator += sum(where_operator_batchlist) 327 | count_where_value_index += sum(where_value_index_batchlist) 328 | count_where_value += sum(where_value_batchlist) 329 | count_logical_form_acc += sum(logical_form_acc_batchlist) 330 | count_execution_acc += sum(execution_acc_batchlist) 331 | 332 | count_curr_batchlist = [select_column_batchlist, select_agg_batchlist, where_number_batchlist, where_column_batchlist, where_operator_batchlist, where_value_batchlist, logical_form_acc_batchlist,execution_acc_batchlist] 333 | count_batchlist.append(count_curr_batchlist) 334 | 335 | # ave_loss /= cnt 336 | select_column_acc = count_select_column / count 337 | select_agg_acc = count_select_agg / count 338 | where_number_acc = count_where_number / count 339 | where_column_acc = count_where_column / count 340 | where_operator_acc = count_where_operator / count 341 | where_value_index_acc = count_where_value_index / count 342 | where_value_acc = count_where_value / count 343 | logical_form_acc = count_logical_form_acc / count 344 | execution_acc = count_execution_acc / count 345 | 346 | accuracy = [None, select_column_acc, select_agg_acc, where_number_acc, 347 | where_column_acc, where_operator_acc, where_value_index_acc, 348 | where_value_acc, logical_form_acc, execution_acc] 349 | 350 | return accuracy, results, count_batchlist 351 | -------------------------------------------------------------------------------- /seq2sql_model_classes.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from copy import deepcopy 4 | from matplotlib.pylab import * 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from seq2sql_model_internal_functions import * 9 | #import torch_xla 10 | #import torch_xla.core.xla_model as xm 11 | 12 | device = torch.device("cuda") 13 | 14 | 15 | class Seq2SQL_v1(nn.Module): 16 | def __init__(self, iS, hS, lS, dr, n_cond_ops, n_agg_ops, old=False): 17 | super(Seq2SQL_v1, self).__init__() 18 | self.iS = iS 19 | self.hS = hS 20 | self.ls = lS 21 | self.dr = dr 22 | 23 | self.max_wn = 4 24 | self.n_cond_ops = n_cond_ops 25 | self.n_agg_ops = n_agg_ops 26 | 27 | self.scp = SCP(iS, hS, lS, dr) 28 | self.sap = SAP(iS, hS, lS, dr, n_agg_ops, old=old) 29 | self.wnp = WNP(iS, hS, lS, dr) 30 | self.wcp = WCP(iS, hS, lS, dr) 31 | self.wop = WOP(iS, hS, lS, dr, n_cond_ops) 32 | # start-end-search-discriminative model 33 | self.wvp = WVP_se(iS, hS, lS, dr, n_cond_ops, old=old) 34 | 35 | def forward(self, wemb_n, l_n, wemb_h, l_hpu, l_hs, 36 | g_sc=None, g_sa=None, g_wn=None, g_wc=None, g_wo=None, g_wvi=None, 37 | show_p_sc=False, show_p_sa=False, 38 | show_p_wn=False, show_p_wc=False, show_p_wo=False, show_p_wv=False, 39 | knowledge=None, 40 | knowledge_header=None): 41 | # sc 42 | s_sc = self.scp(wemb_n, l_n, wemb_h, l_hpu, l_hs, show_p_sc=show_p_sc, 43 | knowledge=knowledge, knowledge_header=knowledge_header) 44 | 45 | if g_sc: 46 | pr_sc = g_sc 47 | else: 48 | pr_sc = pred_sc(s_sc) 49 | # sa 50 | s_sa = self.sap(wemb_n, l_n, wemb_h, l_hpu, l_hs, pr_sc, show_p_sa=show_p_sa, 51 | knowledge=knowledge, knowledge_header=knowledge_header) 52 | if g_sa: 53 | # it's not necessary though. 
54 | pr_sa = g_sa 55 | else: 56 | pr_sa = pred_sa(s_sa) 57 | # wn 58 | s_wn = self.wnp(wemb_n, l_n, wemb_h, l_hpu, l_hs, show_p_wn=show_p_wn, 59 | knowledge=knowledge, knowledge_header=knowledge_header) 60 | 61 | if g_wn: 62 | pr_wn = g_wn 63 | else: 64 | pr_wn = pred_wn(s_wn) 65 | # wc 66 | s_wc = self.wcp(wemb_n, l_n, wemb_h, l_hpu, l_hs, show_p_wc=show_p_wc, penalty=True, predict_select_column=pr_sc, 67 | knowledge=knowledge, knowledge_header=knowledge_header) 68 | 69 | if g_wc: 70 | pr_wc = g_wc 71 | else: 72 | pr_wc = pred_wc(pr_wn, s_wc) 73 | # for b, columns in enumerate(pr_wc): 74 | # for c in columns: 75 | # s_sc[b, c] = -1e+10 76 | 77 | # wo 78 | s_wo = self.wop(wemb_n, l_n, wemb_h, l_hpu, l_hs, wn=pr_wn, wc=pr_wc, show_p_wo=show_p_wo, 79 | knowledge=knowledge, knowledge_header=knowledge_header) 80 | 81 | if g_wo: 82 | pr_wo = g_wo 83 | else: 84 | pr_wo = pred_wo(pr_wn, s_wo) 85 | # wv 86 | s_wv = self.wvp(wemb_n, l_n, wemb_h, l_hpu, l_hs, wn=pr_wn, wc=pr_wc, wo=pr_wo, show_p_wv=show_p_wv, 87 | knowledge=knowledge, knowledge_header=knowledge_header) 88 | return s_sc, s_sa, s_wn, s_wc, s_wo, s_wv 89 | 90 | def beam_forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, tb, 91 | nlu_t, nlu_wp_t, wp_to_wh_index, nlu, 92 | beam_size=4, 93 | show_p_sc=False, show_p_sa=False, 94 | show_p_wn=False, show_p_wc=False, show_p_wo=False, show_p_wv=False, 95 | knowledge=None, 96 | knowledge_header=None): 97 | """ 98 | Execution-guided beam decoding. 
99 | """ 100 | # s_sc = [batch_size, header_len] 101 | s_sc = self.scp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_sc=show_p_sc, 102 | knowledge=knowledge, knowledge_header=knowledge_header) 103 | prob_sc = F.softmax(s_sc, dim=-1) 104 | bS, mcL = s_sc.shape 105 | 106 | # minimum_hs_length = min(l_hs) 107 | # beam_size = minimum_hs_length if beam_size > minimum_hs_length else beam_size 108 | 109 | # sa 110 | # Construct all possible sc_sa_score 111 | prob_sc_sa = torch.zeros([bS, beam_size, self.n_agg_ops]).to(device) 112 | prob_sca = torch.zeros_like(prob_sc_sa).to(device) 113 | 114 | # get the top-k indices. pr_sc_beam = [B, beam_size] 115 | pr_sc_beam = pred_sc_beam(s_sc, beam_size) 116 | 117 | # calculate and predict s_sa. 118 | for i_beam in range(beam_size): 119 | pr_sc = list(array(pr_sc_beam)[:, i_beam]) # pr_sc = [batch_size] 120 | s_sa = self.sap(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, pr_sc, show_p_sa=show_p_sa, 121 | knowledge=knowledge, knowledge_header=knowledge_header) 122 | prob_sa = F.softmax(s_sa, dim=-1) 123 | prob_sc_sa[:, i_beam, :] = prob_sa 124 | 125 | prob_sc_selected = prob_sc[range(bS), pr_sc] # [B] 126 | prob_sca[:, i_beam, :] = (prob_sa.t() * prob_sc_selected).t() 127 | # [mcL, B] * [B] -> [mcL, B] (element-wise multiplication) 128 | # [mcL, B] -> [B, mcL] 129 | 130 | # Calculate the dimension of tensor 131 | # tot_dim = len(prob_sca.shape) 132 | 133 | # First flatten to 1-d 134 | idxs = topk_multi_dim(torch.tensor(prob_sca), 135 | n_topk=beam_size, batch_exist=True) 136 | # Now as sc_idx is already sorted, re-map them properly. 137 | 138 | # [sc_beam_idx, sa_idx] -> [sc_idx, sa_idx] 139 | idxs = remap_sc_idx(idxs, pr_sc_beam) 140 | idxs_arr = array(idxs) 141 | # [B, beam_size, remainig dim] 142 | # idxs[b][0] gives first probable [sc_idx, sa_idx] pairs. 143 | # idxs[b][1] gives of second. 
144 | 145 | # Calculate prob_sca, a joint probability 146 | beam_idx_sca = [0] * bS 147 | beam_meet_the_final = [False] * bS 148 | while True: 149 | pr_sc = idxs_arr[range(bS), beam_idx_sca, 0] 150 | pr_sa = idxs_arr[range(bS), beam_idx_sca, 1] 151 | 152 | # map index properly 153 | 154 | check = check_sc_sa_pairs(tb, pr_sc, pr_sa) 155 | 156 | if sum(check) == bS: 157 | break 158 | else: 159 | for b, check1 in enumerate(check): 160 | if not check1: # wrong pair 161 | beam_idx_sca[b] += 1 162 | if beam_idx_sca[b] >= beam_size: 163 | beam_meet_the_final[b] = True 164 | beam_idx_sca[b] -= 1 165 | else: 166 | beam_meet_the_final[b] = True 167 | 168 | if sum(beam_meet_the_final) == bS: 169 | break 170 | 171 | # Now pr_sc, pr_sa are properly predicted. 172 | pr_sc_best = list(pr_sc) 173 | pr_sa_best = list(pr_sa) 174 | 175 | # Now, Where-clause beam search. 176 | s_wn = self.wnp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wn=show_p_wn, 177 | knowledge=knowledge, knowledge_header=knowledge_header) 178 | prob_wn = F.softmax(s_wn, dim=-1).detach().to('cpu').numpy() 179 | 180 | # Found "executable" most likely 4(=max_num_of_conditions) where-clauses. 181 | # wc 182 | s_wc = self.wcp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wc=show_p_wc, penalty=True, 183 | knowledge=knowledge, knowledge_header=knowledge_header) 184 | prob_wc = F.sigmoid(s_wc).detach().to('cpu').numpy() 185 | # pr_wc_sorted_by_prob = pred_wc_sorted_by_prob(s_wc) 186 | 187 | # get max_wn # of most probable columns & their prob. 
188 | pr_wn_max = [self.max_wn]*bS 189 | # if some column do not have executable where-claouse, omit that column 190 | pr_wc_max = pred_wc(pr_wn_max, s_wc) 191 | prob_wc_max = zeros([bS, self.max_wn]) 192 | for b, pr_wc_max1 in enumerate(pr_wc_max): 193 | prob_wc_max[b, :] = prob_wc[b, pr_wc_max1] 194 | 195 | # get most probable max_wn where-clouses 196 | # wo 197 | s_wo_max = self.wop(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn=pr_wn_max, wc=pr_wc_max, show_p_wo=show_p_wo, 198 | knowledge=knowledge, knowledge_header=knowledge_header) 199 | prob_wo_max = F.softmax(s_wo_max, dim=-1).detach().to('cpu').numpy() 200 | # [B, max_wn, n_cond_op] 201 | 202 | pr_wvi_beam_op_list = [] 203 | prob_wvi_beam_op_list = [] 204 | for i_op in range(self.n_cond_ops-1): 205 | pr_wo_temp = [[i_op]*self.max_wn]*bS 206 | # wv 207 | s_wv = self.wvp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn=pr_wn_max, wc=pr_wc_max, wo=pr_wo_temp, show_p_wv=show_p_wv, 208 | knowledge=knowledge, knowledge_header=knowledge_header) 209 | prob_wv = F.softmax(s_wv, dim=-2).detach().to('cpu').numpy() 210 | 211 | # prob_wv 212 | pr_wvi_beam, prob_wvi_beam = pred_wvi_se_beam( 213 | self.max_wn, s_wv, beam_size) 214 | pr_wvi_beam_op_list.append(pr_wvi_beam) 215 | prob_wvi_beam_op_list.append(prob_wvi_beam) 216 | # pr_wvi_beam = [B, max_wn, k_logit**2 [st, ed] paris] 217 | 218 | # pred_wv_beam 219 | 220 | # Calculate joint probability of where-clause 221 | # prob_w = [batch, wc, wo, wv] = [B, max_wn, n_cond_op, n_pairs] 222 | n_wv_beam_pairs = prob_wvi_beam.shape[2] 223 | prob_w = zeros([bS, self.max_wn, self.n_cond_ops-1, n_wv_beam_pairs]) 224 | for b in range(bS): 225 | for i_wn in range(self.max_wn): 226 | for i_op in range(self.n_cond_ops-1): # do not use final one 227 | for i_wv_beam in range(n_wv_beam_pairs): 228 | # i_wc = pr_wc_max[b][i_wn] # already done 229 | p_wc = prob_wc_max[b, i_wn] 230 | p_wo = prob_wo_max[b, i_wn, i_op] 231 | p_wv = prob_wvi_beam_op_list[i_op][b, i_wn, i_wv_beam] 232 | 233 | prob_w[b, i_wn, 
i_op, i_wv_beam] = p_wc * p_wo * p_wv 234 | 235 | # Perform execution guided decoding 236 | conds_max = [] 237 | prob_conds_max = [] 238 | # while len(conds_max) < self.max_wn: 239 | idxs = topk_multi_dim(torch.tensor( 240 | prob_w), n_topk=beam_size, batch_exist=True) 241 | # idxs = [B, i_wc_beam, i_op, i_wv_pairs] 242 | 243 | # Construct conds1 244 | for b, idxs1 in enumerate(idxs): 245 | conds_max1 = [] 246 | prob_conds_max1 = [] 247 | for i_wn, idxs11 in enumerate(idxs1): 248 | i_wc = pr_wc_max[b][idxs11[0]] 249 | i_op = idxs11[1] 250 | wvi = pr_wvi_beam_op_list[i_op][b][idxs11[0]][idxs11[2]] 251 | 252 | # get wv_str 253 | temp_pr_wv_str, _ = convert_pr_wvi_to_string( 254 | [[wvi]], [nlu_t[b]], [nlu_wp_t[b]], [wp_to_wh_index[b]], [nlu[b]]) 255 | merged_wv11 = merge_wv_t1_eng(temp_pr_wv_str[0][0], nlu[b]) 256 | conds11 = [i_wc, i_op, merged_wv11] 257 | 258 | prob_conds11 = prob_w[b, idxs11[0], idxs11[1], idxs11[2]] 259 | 260 | # test execution 261 | # print(nlu[b]) 262 | # print(tb[b]['id'], tb[b]['types'], pr_sc[b], pr_sa[b], [conds11]) 263 | # pr_ans = engine.execute( 264 | # tb[b]['id'], pr_sc[b], pr_sa[b], [conds11]) 265 | # if bool(pr_ans): 266 | # # pr_ans is not empty! 267 | conds_max1.append(conds11) 268 | prob_conds_max1.append(prob_conds11) 269 | conds_max.append(conds_max1) 270 | prob_conds_max.append(prob_conds_max1) 271 | 272 | # May need to do more exhuastive search? 273 | # i.e. up to.. getting all executable cases. 274 | 275 | # Calculate total probability to decide the number of where-clauses 276 | pr_sql_i = [] 277 | prob_wn_w = [] 278 | pr_wn_based_on_prob = [] 279 | 280 | for b, prob_wn1 in enumerate(prob_wn): 281 | max_executable_wn1 = len(conds_max[b]) 282 | prob_wn_w1 = [] 283 | prob_wn_w1.append(prob_wn1[0]) # wn=0 case. 
284 | for i_wn in range(max_executable_wn1): 285 | prob_wn_w11 = prob_wn1[i_wn+1] * prob_conds_max[b][i_wn] 286 | prob_wn_w1.append(prob_wn_w11) 287 | pr_wn_based_on_prob.append(argmax(prob_wn_w1)) 288 | prob_wn_w.append(prob_wn_w1) 289 | 290 | pr_sql_i1 = {'agg': pr_sa_best[b], 'sel': pr_sc_best[b], 291 | 'conds': conds_max[b][:pr_wn_based_on_prob[b]]} 292 | pr_sql_i.append(pr_sql_i1) 293 | # s_wv = [B, max_wn, max_nlu_tokens, 2] 294 | return prob_sca, prob_w, prob_wn_w, pr_sc_best, pr_sa_best, pr_wn_based_on_prob, pr_sql_i 295 | 296 | 297 | class SCP(nn.Module): 298 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3): 299 | super(SCP, self).__init__() 300 | self.iS = iS 301 | self.hS = hS 302 | self.lS = lS 303 | self.dr = dr 304 | 305 | self.question_knowledge_dim = 5 306 | self.header_knowledge_dim = 3 307 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 308 | num_layers=lS, batch_first=True, 309 | dropout=dr, bidirectional=True) 310 | 311 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 312 | num_layers=lS, batch_first=True, 313 | dropout=dr, bidirectional=True) 314 | 315 | self.W_att = nn.Linear( 316 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim) 317 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS) 318 | self.W_hs = nn.Linear(hS+self.header_knowledge_dim, hS) 319 | self.sc_out = nn.Sequential(nn.Tanh(), nn.Linear(2 * hS, 1)) 320 | 321 | self.softmax_dim1 = nn.Softmax(dim=1) 322 | self.softmax_dim2 = nn.Softmax(dim=2) 323 | 324 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_sc=False, 325 | knowledge=None, 326 | knowledge_header=None): 327 | # Encode 328 | mL_n = max(l_n) 329 | bS = len(l_hs) 330 | wenc_n = encode(self.enc_n, wemb_n, l_n, 331 | return_hidden=False, 332 | hc0=None, 333 | last_only=False) # [b, n, dim] 334 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge] 335 | knowledge = torch.tensor(knowledge).unsqueeze(-1) 336 | 337 | feature = torch.zeros(bS, mL_n, 
self.question_knowledge_dim).scatter_(dim=-1, 338 | index=knowledge, 339 | value=1).to(device) 340 | wenc_n = torch.cat([wenc_n, feature], -1) 341 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim] 342 | knowledge_header = [k + (max(l_hs) - len(k)) * [0] 343 | for k in knowledge_header] 344 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1) 345 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1, 346 | index=knowledge_header, 347 | value=1).to(device) 348 | wenc_hs = torch.cat([wenc_hs, feature2], -1) 349 | bS = len(l_hs) 350 | mL_n = max(l_n) 351 | 352 | # [bS, mL_hs, 100] * [bS, 100, mL_n] -> [bS, mL_hs, mL_n] 353 | att_h = torch.bmm(wenc_hs, self.W_att(wenc_n).transpose(1, 2)) 354 | 355 | # Penalty on blank parts 356 | for b, l_n1 in enumerate(l_n): 357 | if l_n1 < mL_n: 358 | att_h[b, :, l_n1:] = -10000000000 359 | 360 | p_n = self.softmax_dim2(att_h) 361 | if show_p_sc: 362 | # p = [b, hs, n] 363 | if p_n.shape[0] != 1: 364 | raise Exception("Batch size should be 1.") 365 | fig = figure(2001, figsize=(12, 3.5)) 366 | # subplot(6,2,7) 367 | subplot2grid((7, 2), (3, 0), rowspan=2) 368 | cla() 369 | _color = 'rgbkcm' 370 | _symbol = '.......' 
371 | for i_h in range(l_hs[0]): 372 | color_idx = i_h % len(_color) 373 | plot(p_n[0][i_h][:].data.numpy() - i_h, '--' + 374 | _symbol[color_idx]+_color[color_idx], ms=7) 375 | 376 | title('sc: p_n for each h') 377 | grid(True) 378 | fig.tight_layout() 379 | fig.canvas.draw() 380 | show() 381 | 382 | # p_n [ bS, mL_hs, mL_n] -> [ bS, mL_hs, mL_n, 1] 383 | # wenc_n [ bS, mL_n, 100] -> [ bS, 1, mL_n, 100] 384 | # -> [bS, mL_hs, mL_n, 100] -> [bS, mL_hs, 100] 385 | c_n = torch.mul(p_n.unsqueeze(3), wenc_n.unsqueeze(1)).sum(dim=2) 386 | 387 | vec = torch.cat([self.W_c(c_n), self.W_hs(wenc_hs)], dim=2) 388 | s_sc = self.sc_out(vec).squeeze(2) # [bS, mL_hs, 1] -> [bS, mL_hs] 389 | 390 | # Penalty 391 | mL_hs = max(l_hs) 392 | for b, l_hs1 in enumerate(l_hs): 393 | if l_hs1 < mL_hs: 394 | s_sc[b, l_hs1:] = -10000000000 395 | 396 | return s_sc 397 | 398 | 399 | class SAP(nn.Module): 400 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, n_agg_ops=-1, old=False): 401 | super(SAP, self).__init__() 402 | self.iS = iS 403 | self.hS = hS 404 | self.lS = lS 405 | self.dr = dr 406 | 407 | self.question_knowledge_dim = 5 408 | self.header_knowledge_dim = 3 409 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 410 | num_layers=lS, batch_first=True, 411 | dropout=dr, bidirectional=True) 412 | 413 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 414 | num_layers=lS, batch_first=True, 415 | dropout=dr, bidirectional=True) 416 | 417 | self.W_att = nn.Linear( 418 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim) 419 | self.sa_out = nn.Sequential(nn.Linear(hS + self.question_knowledge_dim, hS), 420 | nn.Tanh(), 421 | nn.Linear(hS, n_agg_ops)) # Fixed number of aggregation operator. 
422 | 423 | self.softmax_dim1 = nn.Softmax(dim=1) 424 | self.softmax_dim2 = nn.Softmax(dim=2) 425 | 426 | if old: 427 | # for backwoard compatibility 428 | self.W_c = nn.Linear(hS, hS) 429 | self.W_hs = nn.Linear(hS, hS) 430 | 431 | # wemb_hpu [batch_size*header_num, max_header_len, hidden_dim] 432 | # l_hpu [batch_size*header_num] 433 | # l_hs [batch_size] 434 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, pr_sc, show_p_sa=False, 435 | knowledge=None, 436 | knowledge_header=None): 437 | # Encode 438 | mL_n = max(l_n) 439 | bS = len(l_hs) 440 | wenc_n = encode(self.enc_n, wemb_n, l_n, 441 | return_hidden=False, 442 | hc0=None, 443 | last_only=False) # [b, n, dim] 444 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge] 445 | knowledge = torch.tensor(knowledge).unsqueeze(-1) 446 | 447 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1, 448 | index=knowledge, 449 | value=1).to(device) 450 | wenc_n = torch.cat([wenc_n, feature], -1) 451 | 452 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim] 453 | knowledge_header = [k + (max(l_hs) - len(k)) * [0] 454 | for k in knowledge_header] 455 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1) 456 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1, 457 | index=knowledge_header, 458 | value=1).to(device) 459 | wenc_hs = torch.cat([wenc_hs, feature2], -1) 460 | bS = len(l_hs) 461 | mL_n = max(l_n) 462 | 463 | # list, so one sample for each batch. 
464 | wenc_hs_ob = wenc_hs[list(range(bS)), pr_sc] 465 | 466 | # [bS, mL_n, 100] * [bS, 100, 1] -> [bS, mL_n] 467 | att = torch.bmm(self.W_att(wenc_n), wenc_hs_ob.unsqueeze(2)).squeeze(2) 468 | 469 | # Penalty on blank parts 470 | for b, l_n1 in enumerate(l_n): 471 | if l_n1 < mL_n: 472 | att[b, l_n1:] = -10000000000 473 | # [bS, mL_n] 474 | p = self.softmax_dim1(att) 475 | 476 | if show_p_sa: 477 | if p.shape[0] != 1: 478 | raise Exception("Batch size should be 1.") 479 | fig = figure(2001) 480 | subplot(7, 2, 3) 481 | cla() 482 | plot(p[0].data.numpy(), '--rs', ms=7) 483 | title('sa: nlu_weight') 484 | grid(True) 485 | fig.tight_layout() 486 | fig.canvas.draw() 487 | show() 488 | 489 | # [bS, mL_n, 100] * ( [bS, mL_n, 1] -> [bS, mL_n, 100]) 490 | # -> [bS, mL_n, 100] -> [bS, 100] 491 | c_n = torch.mul(wenc_n, p.unsqueeze(2).expand_as(wenc_n)).sum(dim=1) 492 | s_sa = self.sa_out(c_n) 493 | 494 | return s_sa 495 | 496 | 497 | class WNP(nn.Module): 498 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, ): 499 | super(WNP, self).__init__() 500 | self.iS = iS 501 | self.hS = hS 502 | self.lS = lS 503 | self.dr = dr 504 | 505 | self.mL_w = 4 # max where condition number 506 | self.question_knowledge_dim = 5 507 | self.header_knowledge_dim = 3 508 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 509 | num_layers=lS, batch_first=True, 510 | dropout=dr, bidirectional=True) 511 | 512 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 513 | num_layers=lS, batch_first=True, 514 | dropout=dr, bidirectional=True) 515 | 516 | self.W_att_h = nn.Linear(hS + self.header_knowledge_dim, 1) 517 | self.W_hidden = nn.Linear(hS + self.header_knowledge_dim, lS * hS) 518 | self.W_cell = nn.Linear(hS + self.header_knowledge_dim, lS * hS) 519 | 520 | self.W_att_n = nn.Linear(hS + self.question_knowledge_dim, 1) 521 | self.wn_out = nn.Sequential(nn.Linear(hS + self.question_knowledge_dim, hS), 522 | nn.Tanh(), 523 | nn.Linear(hS, self.mL_w + 1)) # max number (4 + 1) 
524 | 525 | self.softmax_dim1 = nn.Softmax(dim=1) 526 | self.softmax_dim2 = nn.Softmax(dim=2) 527 | 528 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wn=False, 529 | knowledge=None, 530 | knowledge_header=None): 531 | # Encode 532 | mL_n = max(l_n) 533 | bS = len(l_hs) 534 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, 535 | l_hs) # [b, mL_hs, dim] 536 | knowledge_header = [k + (max(l_hs) - len(k)) * [0] 537 | for k in knowledge_header] 538 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1) 539 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1, 540 | index=knowledge_header, 541 | value=1).to(device) 542 | wenc_hs = torch.cat([wenc_hs, feature2], -1) 543 | 544 | bS = len(l_hs) 545 | mL_n = max(l_n) 546 | mL_hs = max(l_hs) 547 | # mL_h = max(l_hpu) 548 | 549 | # (self-attention?) column Embedding? 550 | # [B, mL_hs, 100] -> [B, mL_hs, 1] -> [B, mL_hs] 551 | att_h = self.W_att_h(wenc_hs).squeeze(2) 552 | 553 | # Penalty 554 | for b, l_hs1 in enumerate(l_hs): 555 | if l_hs1 < mL_hs: 556 | att_h[b, l_hs1:] = -10000000000 557 | p_h = self.softmax_dim1(att_h) 558 | 559 | if show_p_wn: 560 | if p_h.shape[0] != 1: 561 | raise Exception("Batch size should be 1.") 562 | fig = figure(2001) 563 | subplot(7, 2, 5) 564 | cla() 565 | plot(p_h[0].data.numpy(), '--rs', ms=7) 566 | title('wn: header_weight') 567 | grid(True) 568 | fig.canvas.draw() 569 | show() 570 | # input('Type Eenter to continue.') 571 | 572 | # [B, mL_hs, 100] * [ B, mL_hs, 1] -> [B, mL_hs, 100] -> [B, 100] 573 | c_hs = torch.mul(wenc_hs, p_h.unsqueeze(2)).sum(1) 574 | 575 | # [B, 100] --> [B, 2*100] Enlarge because there are two layers. 576 | hidden = self.W_hidden(c_hs) # [B, 4, 200/2] 577 | hidden = hidden.view(bS, self.lS * 2, int( 578 | self.hS / 2)) # [4, B, 100/2] # number_of_layer_layer * (bi-direction) # lstm input convention. 
579 | hidden = hidden.transpose(0, 1).contiguous() 580 | 581 | cell = self.W_cell(c_hs) # [B, 4, 100/2] 582 | cell = cell.view(bS, self.lS * 2, int(self.hS / 2)) # [4, B, 100/2] 583 | cell = cell.transpose(0, 1).contiguous() 584 | 585 | wenc_n = encode(self.enc_n, wemb_n, l_n, 586 | return_hidden=False, 587 | hc0=(hidden, cell), 588 | last_only=False) # [b, n, dim] 589 | 590 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge] 591 | knowledge = torch.tensor(knowledge).unsqueeze(-1) 592 | 593 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1, 594 | index=knowledge, 595 | value=1).to(device) 596 | wenc_n = torch.cat([wenc_n, feature], -1) 597 | 598 | # [B, max_len, 100] -> [B, max_len, 1] -> [B, max_len] 599 | att_n = self.W_att_n(wenc_n).squeeze(2) 600 | 601 | # Penalty 602 | for b, l_n1 in enumerate(l_n): 603 | if l_n1 < mL_n: 604 | att_n[b, l_n1:] = -10000000000 605 | p_n = self.softmax_dim1(att_n) 606 | 607 | if show_p_wn: 608 | if p_n.shape[0] != 1: 609 | raise Exception("Batch size should be 1.") 610 | fig = figure(2001) 611 | subplot(7, 2, 6) 612 | cla() 613 | plot(p_n[0].data.numpy(), '--rs', ms=7) 614 | title('wn: nlu_weight') 615 | grid(True) 616 | fig.canvas.draw() 617 | 618 | show() 619 | # input('Type Enter to continue.') 620 | 621 | # [B, mL_n, 100] *([B, mL_n] -> [B, mL_n, 1] -> [B, mL_n, 100] ) -> [B, 100] 622 | c_n = torch.mul(wenc_n, p_n.unsqueeze(2).expand_as(wenc_n)).sum(dim=1) 623 | s_wn = self.wn_out(c_n) 624 | 625 | return s_wn 626 | 627 | 628 | class WCP(nn.Module): 629 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3): 630 | super(WCP, self).__init__() 631 | self.iS = iS 632 | self.hS = hS 633 | self.lS = lS 634 | self.dr = dr 635 | self.question_knowledge_dim = 5 636 | self.header_knowledge_dim = 3 637 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 638 | num_layers=lS, batch_first=True, 639 | dropout=dr, bidirectional=True) 640 | 641 | self.enc_n = nn.LSTM(input_size=iS, 
hidden_size=int(hS / 2), 642 | num_layers=lS, batch_first=True, 643 | dropout=dr, bidirectional=True) 644 | 645 | self.W_att = nn.Linear( 646 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim) 647 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS) 648 | self.W_hs = nn.Linear(hS + self.header_knowledge_dim, hS) 649 | self.W_out = nn.Sequential( 650 | nn.Tanh(), nn.Linear(2 * hS, 1) 651 | ) 652 | 653 | self.softmax_dim1 = nn.Softmax(dim=1) 654 | self.softmax_dim2 = nn.Softmax(dim=2) 655 | 656 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wc, penalty=True, predict_select_column=None, 657 | knowledge=None, 658 | knowledge_header=None): 659 | # Encode 660 | mL_n = max(l_n) 661 | bS = len(l_hs) 662 | wenc_n = encode(self.enc_n, wemb_n, l_n, 663 | return_hidden=False, 664 | hc0=None, 665 | last_only=False) # [b, n, dim] 666 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge] 667 | knowledge = torch.tensor(knowledge).unsqueeze(-1) 668 | 669 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1, 670 | index=knowledge, 671 | value=1).to(device) 672 | wenc_n = torch.cat([wenc_n, feature], -1) 673 | 674 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim] 675 | knowledge_header = [k + (max(l_hs) - len(k)) * [0] 676 | for k in knowledge_header] 677 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1) 678 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1, 679 | index=knowledge_header, 680 | value=1).to(device) 681 | wenc_hs = torch.cat([wenc_hs, feature2], -1) 682 | # attention 683 | # wenc = [bS, mL, hS] 684 | # att = [bS, mL_hs, mL_n] 685 | # att[b, i_h, j_n] = p(j_n| i_h) 686 | att = torch.bmm(wenc_hs, self.W_att(wenc_n).transpose(1, 2)) 687 | 688 | # penalty to blank part. 
689 | mL_n = max(l_n) 690 | for b_n, l_n1 in enumerate(l_n): 691 | if l_n1 < mL_n: 692 | att[b_n, :, l_n1:] = -10000000000 693 | 694 | # for b, c in enumerate(predict_select_column): 695 | # att[b, c, :] = -10000000000 696 | 697 | # make p(j_n | i_h) 698 | p = self.softmax_dim2(att) 699 | 700 | if show_p_wc: 701 | # p = [b, hs, n] 702 | if p.shape[0] != 1: 703 | raise Exception("Batch size should be 1.") 704 | fig = figure(2001) 705 | # subplot(6,2,7) 706 | subplot2grid((7, 2), (3, 1), rowspan=2) 707 | cla() 708 | _color = 'rgbkcm' 709 | _symbol = '.......' 710 | for i_h in range(l_hs[0]): 711 | color_idx = i_h % len(_color) 712 | plot(p[0][i_h][:].data.numpy() - i_h, '--' + 713 | _symbol[color_idx]+_color[color_idx], ms=7) 714 | 715 | title('wc: p_n for each h') 716 | grid(True) 717 | fig.tight_layout() 718 | fig.canvas.draw() 719 | show() 720 | # max nlu context vectors 721 | # [bS, mL_hs, mL_n]*[bS, mL_hs, mL_n] 722 | wenc_n = wenc_n.unsqueeze(1) # [ b, n, dim] -> [b, 1, n, dim] 723 | p = p.unsqueeze(3) # [b, hs, n] -> [b, hs, n, 1] 724 | # -> [b, hs, dim], c_n for each header. 
725 | c_n = torch.mul(wenc_n, p).sum(2) 726 | 727 | # bS = len(l_hs) 728 | # index = torch.tensor(predict_select_column).unsqueeze(-1) 729 | # feature = torch.zeros(bS, max(l_hs)).scatter_(dim=-1, 730 | # index=index, 731 | # value=1).to(device) 732 | # c_n = torch.cat([c_n, feature.unsqueeze(-1)],dim=-1) 733 | 734 | y = torch.cat([self.W_c(c_n), self.W_hs(wenc_hs)], 735 | dim=2) # [b, hs, 2*dim] 736 | score = self.W_out(y).squeeze(2) # [b, hs] 737 | 738 | if penalty: 739 | for b, l_hs1 in enumerate(l_hs): 740 | score[b, l_hs1:] = -1e+10 741 | 742 | # for b, c in enumerate(predict_select_column): 743 | # score[b, c] = -1e+10 744 | 745 | return score 746 | 747 | 748 | class WOP(nn.Module): 749 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, n_cond_ops=3): 750 | super(WOP, self).__init__() 751 | self.iS = iS 752 | self.hS = hS 753 | self.lS = lS 754 | self.dr = dr 755 | self.question_knowledge_dim = 0 756 | self.header_knowledge_dim = 0 757 | self.mL_w = 4 # max where condition number 758 | 759 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 760 | num_layers=lS, batch_first=True, 761 | dropout=dr, bidirectional=True) 762 | 763 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 764 | num_layers=lS, batch_first=True, 765 | dropout=dr, bidirectional=True) 766 | 767 | self.W_att = nn.Linear( 768 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim) 769 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS) 770 | self.W_hs = nn.Linear(hS + self.header_knowledge_dim, hS) 771 | self.wo_out = nn.Sequential( 772 | nn.Linear(2*hS, hS), 773 | nn.Tanh(), 774 | nn.Linear(hS, n_cond_ops) 775 | ) 776 | 777 | self.softmax_dim1 = nn.Softmax(dim=1) 778 | self.softmax_dim2 = nn.Softmax(dim=2) 779 | 780 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn, wc, wenc_n=None, show_p_wo=False, 781 | knowledge=None, 782 | knowledge_header=None): 783 | # Encode 784 | mL_n = max(l_n) 785 | bS = len(l_hs) 786 | if not wenc_n: 787 | wenc_n 
= encode(self.enc_n, wemb_n, l_n, 788 | return_hidden=False, 789 | hc0=None, 790 | last_only=False) # [b, n, dim] 791 | if self.question_knowledge_dim != 0: 792 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge] 793 | knowledge = torch.tensor(knowledge).unsqueeze(-1) 794 | 795 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1, 796 | index=knowledge, 797 | value=1).to(device) 798 | wenc_n = torch.cat([wenc_n, feature], -1) 799 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim] 800 | if self.header_knowledge_dim != 0: 801 | knowledge_header = [k + (max(l_hs) - len(k)) * [0] 802 | for k in knowledge_header] 803 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1) 804 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1, 805 | index=knowledge_header, 806 | value=1).to(device) 807 | wenc_hs = torch.cat([wenc_hs, feature2], -1) 808 | bS = len(l_hs) 809 | # wn 810 | 811 | wenc_hs_ob = [] # observed hs 812 | for b in range(bS): 813 | # [[...], [...]] 814 | # Pad list to maximum number of selections 815 | real = [wenc_hs[b, col] for col in wc[b]] 816 | # this padding could be wrong. Test with zero padding later. 817 | pad = (self.mL_w - wn[b]) * [wenc_hs[b, 0]] 818 | # It is not used in the loss function. 819 | wenc_hs_ob1 = torch.stack(real + pad) 820 | wenc_hs_ob.append(wenc_hs_ob1) 821 | 822 | # list to [B, 4, dim] tensor. 823 | wenc_hs_ob = torch.stack(wenc_hs_ob) # list to tensor. 824 | wenc_hs_ob = wenc_hs_ob.to(device) 825 | 826 | # [B, 1, mL_n, dim] * [B, 4, dim, 1] 827 | # -> [B, 4, mL_n, 1] -> [B, 4, mL_n] 828 | # multiplication bewteen NLq-tokens and selected column 829 | att = torch.matmul(self.W_att(wenc_n).unsqueeze(1), 830 | wenc_hs_ob.unsqueeze(3) 831 | ).squeeze(3) 832 | # Penalty for blank part. 
833 | mL_n = max(l_n) 834 | for b, l_n1 in enumerate(l_n): 835 | if l_n1 < mL_n: 836 | att[b, :, l_n1:] = -10000000000 837 | 838 | p = self.softmax_dim2(att) # p( n| selected_col ) 839 | if show_p_wo: 840 | # p = [b, hs, n] 841 | if p.shape[0] != 1: 842 | raise Exception("Batch size should be 1.") 843 | fig = figure(2001) 844 | # subplot(6,2,7) 845 | subplot2grid((7, 2), (5, 0), rowspan=2) 846 | cla() 847 | _color = 'rgbkcm' 848 | _symbol = '.......' 849 | for i_wn in range(self.mL_w): 850 | color_idx = i_wn % len(_color) 851 | plot(p[0][i_wn][:].data.numpy() - i_wn, '--' + 852 | _symbol[color_idx]+_color[color_idx], ms=7) 853 | 854 | title('wo: p_n for selected h') 855 | grid(True) 856 | fig.tight_layout() 857 | fig.canvas.draw() 858 | show() 859 | 860 | # [B, 1, mL_n, dim] * [B, 4, mL_n, 1] 861 | # --> [B, 4, mL_n, dim] 862 | # --> [B, 4, dim] 863 | c_n = torch.mul(wenc_n.unsqueeze(1), p.unsqueeze(3)).sum(dim=2) 864 | 865 | # [bS, 5-1, dim] -> [bS, 5-1, 3] 866 | 867 | vec = torch.cat([self.W_c(c_n), self.W_hs(wenc_hs_ob)], dim=2) 868 | s_wo = self.wo_out(vec) 869 | return s_wo 870 | 871 | 872 | class WVP_se(nn.Module): 873 | """ 874 | Discriminative model 875 | Get start and end. 876 | Here, classifier for [ [투수], [팀1], [팀2], [연도], ...] 877 | Input: Encoded nlu & selected column. 878 | Algorithm: Encoded nlu & selected column. -> classifier -> mask scores -> ... 
879 | """ 880 | 881 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, n_cond_ops=4, old=False): 882 | super(WVP_se, self).__init__() 883 | self.iS = iS 884 | self.hS = hS 885 | self.lS = lS 886 | self.dr = dr 887 | self.n_cond_ops = n_cond_ops 888 | self.question_knowledge_dim = 5 889 | self.header_knowledge_dim = 3 890 | self.mL_w = 4 # max where condition number 891 | 892 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 893 | num_layers=lS, batch_first=True, 894 | dropout=dr, bidirectional=True) 895 | 896 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2), 897 | num_layers=lS, batch_first=True, 898 | dropout=dr, bidirectional=True) 899 | 900 | self.W_att = nn.Linear( 901 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim) 902 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS) 903 | self.W_hs = nn.Linear(hS + self.header_knowledge_dim, hS) 904 | self.W_op = nn.Linear(n_cond_ops, hS) 905 | 906 | # self.W_n = nn.Linear(hS, hS) 907 | if old: 908 | self.wv_out = nn.Sequential( 909 | nn.Linear(4 * hS, 2) 910 | ) 911 | else: 912 | self.wv_out = nn.Sequential( 913 | nn.Linear(4 * hS + self.question_knowledge_dim, hS), 914 | nn.Tanh(), 915 | nn.Linear(hS, 2) 916 | ) 917 | # self.wv_out = nn.Sequential( 918 | # nn.Linear(3 * hS, hS), 919 | # nn.Tanh(), 920 | # nn.Linear(hS, self.gdkL) 921 | # ) 922 | 923 | self.softmax_dim1 = nn.Softmax(dim=1) 924 | self.softmax_dim2 = nn.Softmax(dim=2) 925 | 926 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn, wc, wo, wenc_n=None, show_p_wv=False, 927 | knowledge=None, 928 | knowledge_header=None): 929 | mL_n = max(l_n) 930 | bS = len(l_hs) 931 | # Encode 932 | if not wenc_n: 933 | wenc_n, hout, cout = encode(self.enc_n, wemb_n, l_n, 934 | return_hidden=True, 935 | hc0=None, 936 | last_only=False) # [b, n, dim] 937 | 938 | knowledge = [k+(mL_n-len(k))*[0] for k in knowledge] 939 | knowledge = torch.tensor(knowledge).unsqueeze(-1) 940 | 941 | feature = torch.zeros(bS, mL_n, 
self.question_knowledge_dim).scatter_(dim=-1, 942 | index=knowledge, 943 | value=1).to(device) 944 | wenc_n = torch.cat([wenc_n, feature], -1) 945 | 946 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim] 947 | 948 | knowledge_header = [k + (max(l_hs) - len(k)) * [0] 949 | for k in knowledge_header] 950 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1) 951 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1, 952 | index=knowledge_header, 953 | value=1).to(device) 954 | wenc_hs = torch.cat([wenc_hs, feature2], -1) 955 | 956 | wenc_hs_ob = [] # observed hs 957 | 958 | for b in range(bS): 959 | # [[...], [...]] 960 | # Pad list to maximum number of selections 961 | real = [wenc_hs[b, col] for col in wc[b]] 962 | # this padding could be wrong. Test with zero padding later. 963 | pad = (self.mL_w - wn[b]) * [wenc_hs[b, 0]] 964 | # It is not used in the loss function. 965 | wenc_hs_ob1 = torch.stack(real + pad) 966 | wenc_hs_ob.append(wenc_hs_ob1) 967 | 968 | # list to [B, 4, dim] tensor. 969 | wenc_hs_ob = torch.stack(wenc_hs_ob) # list to tensor. 970 | wenc_hs_ob = wenc_hs_ob.to(device) 971 | 972 | # Column attention 973 | # [B, 1, mL_n, dim] * [B, 4, dim, 1] 974 | # -> [B, 4, mL_n, 1] -> [B, 4, mL_n] 975 | # multiplication bewteen NLq-tokens and 【selected】 column 976 | att = torch.matmul(self.W_att(wenc_n).unsqueeze(1), 977 | wenc_hs_ob.unsqueeze(3) 978 | ).squeeze(3) 979 | # Penalty for blank part. 980 | 981 | for b, l_n1 in enumerate(l_n): 982 | if l_n1 < mL_n: 983 | att[b, :, l_n1:] = -10000000000 984 | 985 | p = self.softmax_dim2(att) # p( n| selected_col ) 986 | 987 | if show_p_wv: 988 | # p = [b, hs, n] 989 | if p.shape[0] != 1: 990 | raise Exception("Batch size should be 1.") 991 | fig = figure(2001) 992 | # subplot(6,2,7) 993 | subplot2grid((7, 2), (5, 1), rowspan=2) 994 | cla() 995 | _color = 'rgbkcm' 996 | _symbol = '.......' 
997 | for i_wn in range(self.mL_w): 998 | color_idx = i_wn % len(_color) 999 | plot(p[0][i_wn][:].data.numpy() - i_wn, '--' + 1000 | _symbol[color_idx]+_color[color_idx], ms=7) 1001 | 1002 | title('wv: p_n for selected h') 1003 | grid(True) 1004 | fig.tight_layout() 1005 | fig.canvas.draw() 1006 | show() 1007 | 1008 | # [B, 1, mL_n, dim] * [B, 4, mL_n, 1] 1009 | # --> [B, 4, mL_n, dim] 1010 | # --> [B, 4, dim] 1011 | c_n = torch.mul(wenc_n.unsqueeze(1), p.unsqueeze(3)).sum(dim=2) 1012 | 1013 | # Select observed headers only. 1014 | # Also generate one_hot vector encoding info of the operator 1015 | # [B, 4, dim] 1016 | wenc_op = [] 1017 | for b in range(bS): 1018 | # [[...], [...]] 1019 | # Pad list to maximum number of selections 1020 | wenc_op1 = torch.zeros(self.mL_w, self.n_cond_ops) 1021 | wo1 = wo[b] 1022 | idx_scatter = [] 1023 | l_wo1 = len(wo1) 1024 | for i_wo11 in range(self.mL_w): 1025 | if i_wo11 < l_wo1: 1026 | wo11 = wo1[i_wo11] 1027 | idx_scatter.append([int(wo11)]) 1028 | else: 1029 | idx_scatter.append([0]) # not used anyway 1030 | 1031 | wenc_op1 = wenc_op1.scatter(1, torch.tensor(idx_scatter), 1) 1032 | 1033 | wenc_op.append(wenc_op1) 1034 | 1035 | # list to [B, 4, dim] tensor. 1036 | wenc_op = torch.stack(wenc_op) # list to tensor. 1037 | wenc_op = wenc_op.to(device) 1038 | 1039 | # Now after concat, calculate logits for each token 1040 | # [bS, 5-1, 3*hS] = [bS, 4, 300] 1041 | vec = torch.cat([self.W_c(c_n), self.W_hs( 1042 | wenc_hs_ob), self.W_op(wenc_op)], dim=2) 1043 | 1044 | # Make extended vector based on encoded nl token containing column and operator information. 
1045 | # wenc_n = [bS, mL, 100] 1046 | # vec2 = [bS, 4, mL, 400] 1047 | # [bS, 4, 1, 300] -> [bS, 4, mL, 300] 1048 | vec1e = vec.unsqueeze(2).expand(-1, -1, mL_n, -1) 1049 | # [bS, 1, mL, 100] -> [bS, 4, mL, 100] 1050 | wenc_ne = wenc_n.unsqueeze(1).expand(-1, 4, -1, -1) 1051 | vec2 = torch.cat([vec1e, wenc_ne], dim=3) 1052 | 1053 | # now make logits 1054 | s_wv = self.wv_out(vec2) # [bS, 4, mL, 400] -> [bS, 4, mL, 2] 1055 | 1056 | # penalty for spurious tokens 1057 | for b, l_n1 in enumerate(l_n): 1058 | if l_n1 < mL_n: 1059 | s_wv[b, :, l_n1:, :] = -10000000000 1060 | return s_wv 1061 | --------------------------------------------------------------------------------