├── corenlp_local.py
├── model_save_and_infer.py
├── LICENSE
├── README.md
├── load_model.py
├── dbengine_sqlnet.py
├── load_data.py
├── seq2sql_model_testing.py
├── infer_functions.py
├── roberta_training.py
├── seq2sql_model_internal_functions.py
├── seq2sql_model_training_functions.py
├── dev_function.py
└── seq2sql_model_classes.py
/corenlp_local.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def get_g_wvi_bert_from_g_wvi_corenlp(wh_to_wp_index, g_wvi_corenlp):
4 | """
5 | Generate SQuAD-style start and end indices of wv in nlu. Indices refer to positions after WordPiece tokenization.
6 | Assumption: where_str is always present in the nlu.
7 | """
8 | g_wvi = []
9 | for b, g_wvi_corenlp1 in enumerate(g_wvi_corenlp):
10 | wh_to_wp_index1 = wh_to_wp_index[b]
11 | g_wvi1 = []
12 | for i_wn, g_wvi_corenlp11 in enumerate(g_wvi_corenlp1):
13 |
14 | st_idx, ed_idx = g_wvi_corenlp11
15 |
16 | st_wp_idx = wh_to_wp_index1[st_idx]
17 | ed_wp_idx = wh_to_wp_index1[ed_idx]
18 |
19 | g_wvi11 = [st_wp_idx, ed_wp_idx]
20 | g_wvi1.append(g_wvi11)
21 |
22 | g_wvi.append(g_wvi1)
23 |
24 | return g_wvi
25 |
26 | def get_g_wvi_corenlp(t):
27 | g_wvi_corenlp = []
28 | for t1 in t:
29 | g_wvi_corenlp.append( t1['wvi_corenlp'] )
30 | return g_wvi_corenlp
31 |
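# A minimal self-check sketch for the mapping above (values are hypothetical):
# token 2 is split into two WordPieces, so the CoreNLP span [1, 2] becomes
# the WordPiece span [1, 3].
if __name__ == "__main__":
    wh_to_wp_index = [[0, 1, 3, 4]]   # CoreNLP token position -> first WordPiece position
    g_wvi_corenlp = [[[1, 2]]]        # one condition spanning CoreNLP tokens 1..2
    assert get_g_wvi_bert_from_g_wvi_corenlp(wh_to_wp_index, g_wvi_corenlp) == [[[1, 3]]]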
--------------------------------------------------------------------------------
/model_save_and_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random as python_random
4 |
5 | def save_for_evaluation(path_save, results, dataset_name):
6 | path_save_file = os.path.join(path_save, f'results_{dataset_name}.jsonl')
7 | with open(path_save_file, 'w', encoding='utf-8') as file:
8 | for line in results:
9 | json_string = json.dumps(line, ensure_ascii=False, default=json_default_type_checker)
10 | json_string += '\n'
11 |
12 | file.write(json_string)
13 |
14 | def json_default_type_checker(o):
15 | """
16 | From https://stackoverflow.com/questions/11942364/typeerror-integer-is-not-json-serializable-when-serializing-json-in-python
17 | """
18 | if isinstance(o, int): return int(o)
19 | raise TypeError
20 |
21 | def load_jsonl(path_file, seed=1):
22 | data = []
23 | with open(path_file, "r", encoding="utf-8") as file:
24 | for idx, line in enumerate(file):
25 | curr_line = json.loads(line.strip())
26 | data.append(curr_line)
27 | return data
28 |
29 |
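# A minimal round-trip sketch (the path and result dict are hypothetical):
if __name__ == "__main__":
    results = [{'query': 'SELECT col0 FROM table_x', 'nlu': 'example question'}]
    save_for_evaluation('.', results, 'dev')        # writes ./results_dev.jsonl
    assert load_jsonl('./results_dev.jsonl') == results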
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Debaditya Pal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data Agnostic RoBERTa-based Natural Language to SQL Query Generation
2 | This repo contains the code for the NL2SQL paper "Data Agnostic RoBERTa-based Natural Language to SQL Query Generation": https://arxiv.org/abs/2010.05243
3 |
4 | Our code builds on the work introduced in the paper "Content Enhanced BERT-based Text-to-SQL Generation"; the code for that paper can be found [here](https://github.com/guotong1988/NL2SQL-RULE). Feel free to experiment with the models.
5 |
6 | ## Getting Started
7 | Our work was mainly done in a Google Colab workspace with a GPU runtime accelerator, so to reproduce the results we recommend running the code in a similar workspace. To that end, we have created a notebook, which can be found at:
8 | https://colab.research.google.com/drive/1qYJTbbEXYFVdY6xae9Zmt96hkeW8ZFrn
9 | Below are the step-by-step instructions to run the models:
10 |
11 | 1. Visit https://drive.google.com/drive/folders/13f2MrdpieC9QGXM_DJnj2f1Hs6ZBh2ZT?usp=sharing. This Drive folder contains all the datasets and pretrained model weights. **Add a shortcut to your Drive.** This step is important.
12 | 2. Open the notebook linked above.
13 | 3. Make sure the runtime accelerator is set to GPU.
14 | 4. Mount your Google Drive and clone this repository in the runtime environment; the code for this step is already provided in the notebook.
15 | 5. The notebook is now ready for use.
16 |
17 | All package requirements are listed in the notebook.
18 |
19 | ## Results
20 |
21 | | Dev Logical Form Accuracy | Dev Execution Accuracy | Test Logical Form Accuracy | Test Execution Accuracy |
22 | |:-:|:-:|:-:|:-:|
23 | | 69.4% | 77.0% | 68.9% | 76.7% |
24 |
25 | ## References
26 |
27 | - https://github.com/salesforce/WikiSQL
28 | - https://github.com/guotong1988/NL2SQL-RULE
29 |
--------------------------------------------------------------------------------
/load_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from seq2sql_model_classes import Seq2SQL_v1
4 | from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
5 |
6 | device = torch.device("cuda")
7 |
8 | def get_roberta_model():
9 |
10 | # Initializing a RoBERTa configuration
11 | configuration = RobertaConfig()
12 |
13 | # Loading a model with the pretrained roberta-base weights
14 | Roberta_Model = RobertaModel.from_pretrained("roberta-base")
15 | Roberta_Model.to(device)
16 |
17 | # Accessing the model configuration
18 | configuration = Roberta_Model.config
19 |
20 | #get the Roberta Tokenizer
21 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
22 |
23 | return Roberta_Model, tokenizer, configuration
24 |
25 |
26 | def get_seq2sql_model(roberta_hidden_layer_size, number_of_layers = 2,
27 | hidden_vector_dimensions = 100,
28 | number_lstm_layers = 2,
29 | dropout_rate = 0.3,
30 | load_pretrained_model=False, model_path=None):
31 |
32 | '''
33 |
34 | get_seq2sql_model
35 | Arguments:
36 | roberta_hidden_layer_size: hidden layer size of the RoBERTa model
37 | number_of_layers : number of final RoBERTa layers used in the downstream task
38 | hidden_vector_dimensions : dimension of the hidden vectors in the seq-to-SQL module
39 | number_lstm_layers : number of LSTM layers in the seq-to-SQL module
40 | dropout_rate : dropout rate
41 | load_pretrained_model : whether to load a pretrained model (True or False)
42 | model_path : path to the saved model checkpoint file
43 |
44 | Returns:
45 | model: the constructed Seq2SQL_v1 model
46 |
47 | '''
48 |
49 | # number_of_layers = "The Number of final layers of RoBERTa to be used in downstream task."
50 | # hidden_vector_dimensions : "The dimension of hidden vector in the seq-to-SQL module."
51 | # number_lstm_layers : "The number of LSTM layers." in seqtosqlmodule
52 |
53 | sql_main_operators = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
54 | sql_conditional_operators = ['=', '>', '<', 'OP']
55 |
56 | number_of_neurons = roberta_hidden_layer_size * number_of_layers # Seq-to-SQL input vector dimension
57 |
58 | model = Seq2SQL_v1(number_of_neurons, hidden_vector_dimensions, number_lstm_layers, dropout_rate, len(sql_conditional_operators), len(sql_main_operators))
59 | model = model.to(device)
60 |
61 | if load_pretrained_model:
62 | assert model_path is not None
63 | if torch.cuda.is_available():
64 | res = torch.load(model_path)
65 | else:
66 | res = torch.load(model_path, map_location='cpu')
67 | model.load_state_dict(res['model'])
68 |
69 | return model
70 |
71 | def get_optimizers(model, model_roberta, learning_rate_model=1e-3, learning_rate_roberta=1e-5):
72 | '''
73 | get_optimizers
74 | Arguments:
75 | model : model returned from get_seq2sql_model
76 | model_roberta : model returned from get_roberta_model
77 | learning_rate_model : learning rate for the seq-to-SQL model (from get_seq2sql_model)
78 | learning_rate_roberta : learning rate for the RoBERTa model (from get_roberta_model)
79 |
80 | Returns:
81 | opt : Adam optimizer for the seq-to-SQL model
82 | opt_roberta : Adam optimizer for the RoBERTa model
83 |
84 | '''
85 |
86 |
87 | opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
88 | lr=learning_rate_model, weight_decay=0)
89 |
90 | opt_roberta = torch.optim.Adam(filter(lambda p: p.requires_grad, model_roberta.parameters()),
91 | lr=learning_rate_roberta, weight_decay=0)
92 |
93 | return opt, opt_roberta
94 |
95 |
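# A minimal end-to-end setup sketch (requires a CUDA device; the hyperparameters
# here are illustrative, not necessarily the values used in the paper):
#
#   model_roberta, tokenizer, config = get_roberta_model()
#   model = get_seq2sql_model(config.hidden_size, number_of_layers=2)
#   opt, opt_roberta = get_optimizers(model, model_roberta)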
--------------------------------------------------------------------------------
/dbengine_sqlnet.py:
--------------------------------------------------------------------------------
1 |
2 | import records
3 | import re
4 | from babel.numbers import parse_decimal, NumberFormatError
5 | # From original SQLNet code.
6 | # Wonseok modified. 20180607
7 |
8 | schema_re = re.compile(r'\((.+)\)') # matches the column list inside the parentheses of a CREATE TABLE statement
9 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') # ? = zero or one of the preceding token; * = zero or more
10 | # Catches numbers like -34.34, .4543, 12
11 | # | is 'or'
12 |
13 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
14 | cond_ops = ['=', '>', '<', 'OP']
15 |
16 | class DBEngine:
17 |
18 | def __init__(self, fdb):
19 | #fdb = 'data/test.db'
20 | self.db = records.Database('sqlite:///{}'.format(fdb)).get_connection()
21 |
22 | def execute_query(self, table_id, query, *args, **kwargs):
23 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)
24 |
25 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
26 | if not table_id.startswith('table'):
27 | table_id = 'table_{}'.format(table_id.replace('-', '_'))
28 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','')
29 | schema_str = schema_re.findall(table_info)[0]
30 | schema = {}
31 | for tup in schema_str.split(', '):
32 | c, t = tup.split()
33 | schema[c] = t
34 | select = 'col{}'.format(select_index)
35 | agg = agg_ops[aggregation_index]
36 | if agg:
37 | select = '{}({})'.format(agg, select)
38 | where_clause = []
39 | where_map = {}
40 | for col_index, op, val in conditions:
41 | if lower and isinstance(val, str):
42 | val = val.lower()
43 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
44 | try:
45 | # print('!!!!!!value of val is: ', val, 'type is: ', type(val))
46 | # val = float(parse_decimal(val)) # somehow it generates error.
47 | val = float(parse_decimal(val, locale='en_US'))
48 | # print('!!!!!!After: val', val)
49 |
50 | except NumberFormatError:
51 | try:
52 | val = float(num_re.findall(val)[0]) # fall back to the first number-like substring in val
53 | except (IndexError, ValueError):
54 | # Although the column is numeric, the selected value is not a number. Do nothing in this case.
55 | pass
56 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index))
57 | where_map['col{}'.format(col_index)] = val
58 | where_str = ''
59 | if where_clause:
60 | where_str = 'WHERE ' + ' AND '.join(where_clause)
61 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
62 | #print query
63 | out = self.db.query(query, **where_map)
64 |
65 |
66 | return [o.result for o in out]
67 | def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True):
68 | if not table_id.startswith('table'):
69 | table_id = 'table_{}'.format(table_id.replace('-', '_'))
70 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','')
71 | schema_str = schema_re.findall(table_info)[0]
72 | schema = {}
73 | for tup in schema_str.split(', '):
74 | c, t = tup.split()
75 | schema[c] = t
76 | select = 'col{}'.format(select_index)
77 | agg = agg_ops[aggregation_index]
78 | if agg:
79 | select = '{}({})'.format(agg, select)
80 | where_clause = []
81 | where_map = {}
82 | for col_index, op, val in conditions:
83 | if lower and isinstance(val, str):
84 | val = val.lower()
85 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
86 | try:
87 | # print('!!!!!!value of val is: ', val, 'type is: ', type(val))
88 | # val = float(parse_decimal(val)) # somehow it generates error.
89 | val = float(parse_decimal(val, locale='en_US'))
90 | # print('!!!!!!After: val', val)
91 |
92 | except NumberFormatError:
93 | val = float(num_re.findall(val)[0])
94 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index))
95 | where_map['col{}'.format(col_index)] = val
96 | where_str = ''
97 | if where_clause:
98 | where_str = 'WHERE ' + ' AND '.join(where_clause)
99 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
100 | #print query
101 | out = self.db.query(query, **where_map)
102 |
103 |
104 | return [o.result for o in out], query
105 | def show_table(self, table_id):
106 | if not table_id.startswith('table'):
107 | table_id = 'table_{}'.format(table_id.replace('-', '_'))
108 | rows = self.db.query('select * from ' + table_id)
109 | print(rows.dataset)
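# A minimal usage sketch (the .db path and table id below are hypothetical):
#
#   engine = DBEngine('data/dev.db')
#   # Executes: SELECT COUNT(col1) FROM table_1_10015132_11 WHERE col0 = 'some value'
#   answers = engine.execute('1-10015132-11', select_index=1,
#                            aggregation_index=3, conditions=[[0, 0, 'some value']])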
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import os
4 |
5 | from matplotlib.pylab import *
6 |
7 | def get_data(file_path: str,batch_size: int):
8 | '''
9 | Gets data from the dataset and creates a data loader
10 |
11 | Arguments:
12 | file_path: The path to the directory in which the dataset is contained
13 | batch_size: Batch size to be used for the data loaders
14 |
15 | Returns:
16 | train_data: Training dataset (Natural Language utterances)
17 | train_table: Training tables (Table schema and table data)
18 | dev_data: Development dataset (Natural Language utterances)
19 | dev_table: Development tables (Table schema and table data)
20 | train_loader: Training dataset loader
21 | dev_loader: Development dataset loader
22 | '''
23 | # Loading Dev Files(Development Dataset)
24 | dev_data = []
25 | dev_table = {}
26 |
27 | with open(file_path + '/dev_knowledge.jsonl') as dev_data_file:
28 | for idx, line in enumerate(dev_data_file):
29 | current_line = json.loads(line.strip())
30 | dev_data.append(current_line)
31 |
32 | with open(file_path + '/dev.tables.jsonl') as dev_table_file:
33 | for idx, line in enumerate(dev_table_file):
34 | current_line = json.loads(line.strip())
35 | dev_table[current_line['id']] = current_line
36 |
37 | # Loading Train Files(Training Dataset)
38 | train_data = []
39 | train_table = {}
40 |
41 | with open(file_path + '/train_knowledge.jsonl') as train_data_file:
42 | for idx, line in enumerate(train_data_file):
43 | current_line = json.loads(line.strip())
44 | train_data.append(current_line)
45 |
46 | with open(file_path + '/train.tables.jsonl') as train_table_file:
47 | for idx, line in enumerate(train_table_file):
48 | current_line = json.loads(line.strip())
49 | train_table[current_line['id']] = current_line
50 |
51 | train_loader = torch.utils.data.DataLoader(
52 | batch_size=batch_size,
53 | dataset=train_data,
54 | shuffle=True,
55 | num_workers=4,
56 | collate_fn=lambda x: x # now dictionary values are not merged!
57 | )
58 |
59 | dev_loader = torch.utils.data.DataLoader(
60 | batch_size=batch_size,
61 | dataset=dev_data,
62 | shuffle=True,
63 | num_workers=4,
64 | collate_fn=lambda x: x # now dictionary values are not merged!
65 | )
66 |
67 | return train_data, train_table, dev_data, dev_table, train_loader, dev_loader
68 |
69 | def get_test_data(file_path: str,batch_size: int):
70 | test_data=[]
71 | test_table = {}
72 |
73 | with open(file_path + '/test_knowledge.jsonl') as test_data_file:
74 | for idx, line in enumerate(test_data_file):
75 | current_line = json.loads(line.strip())
76 | test_data.append(current_line)
77 |
78 | with open(file_path + '/test.tables.jsonl') as test_table_file:
79 | for idx, line in enumerate(test_table_file):
80 | current_line = json.loads(line.strip())
81 | test_table[current_line['id']] = current_line
82 |
83 | test_loader = torch.utils.data.DataLoader(
84 | batch_size=batch_size,
85 | dataset=test_data,
86 | shuffle=True,
87 | num_workers=4,
88 | collate_fn=lambda x: x # now dictionary values are not merged!
89 | )
90 |
91 | return test_data,test_table,test_loader
92 |
93 | def get_zero_data(file_path: str,batch_size: int):
94 | test_data=[]
95 | test_table = {}
96 |
97 | with open(file_path + '/zero.jsonl') as test_data_file:
98 | for idx, line in enumerate(test_data_file):
99 | current_line = json.loads(line.strip())
100 | test_data.append(current_line)
101 |
102 | with open(file_path + '/test.tables.jsonl') as test_table_file:
103 | for idx, line in enumerate(test_table_file):
104 | current_line = json.loads(line.strip())
105 | test_table[current_line['id']] = current_line
106 |
107 | test_loader = torch.utils.data.DataLoader(
108 | batch_size=batch_size,
109 | dataset=test_data,
110 | shuffle=True,
111 | num_workers=4,
112 | collate_fn=lambda x: x # now dictionary values are not merged!
113 | )
114 |
115 | return test_data,test_table,test_loader
116 |
117 |
118 | def get_fields(data, header_tokenization=False, sql_tokenization=False):
119 |
120 | natural_language_utterance = []
121 | tokenized_natural_language_utterance = []
122 | sql_indexing = []
123 | sql_query = []
124 | tokenized_sql_query = []
125 | table_indices = []
126 | tokenized_headers = []
127 | headers = []
128 |
129 | for one_data in data:
130 | natural_language_utterance.append(one_data['question'])
131 | tokenized_natural_language_utterance.append(one_data['question_tok'])
132 | sql_indexing.append(one_data['sql'])
133 | sql_query.append(one_data['query'])
134 | headers.append(one_data['header'])
135 | table_indices.append({
136 | "id" : one_data["table_id"],
137 | "header": one_data["header"],
138 | "types" : one_data["types"]
139 | })
140 |
141 | if sql_tokenization:
142 | tokenized_sql_query.append(one_data['query_tok'])
143 | else:
144 | tokenized_sql_query.append(None)
145 |
146 | if header_tokenization:
147 | tokenized_headers.append(one_data['header_tok'])
148 | else:
149 | tokenized_headers.append(None)
150 |
151 | return natural_language_utterance,tokenized_natural_language_utterance,sql_indexing,sql_query,tokenized_sql_query,table_indices,tokenized_headers,headers
152 |
153 |
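# A minimal usage sketch (assumes the WikiSQL knowledge files are in ./data):
#
#   train_data, train_table, dev_data, dev_table, train_loader, dev_loader = \
#       get_data('data', batch_size=8)
#   for batch in train_loader:  # each batch is a plain list of dicts, thanks to collate_fn
#       nlu, nlu_t, sql_i, sql_q, _, tb, _, hds = get_fields(batch)
#       break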
--------------------------------------------------------------------------------
/seq2sql_model_testing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | #import torch_xla
3 | #import torch_xla.core.xla_model as xm
4 |
5 | device = torch.device("cuda")
6 |
7 | def generate_sql_q(sql_i, tb):
8 | sql_q = []
9 | for b, sql_i1 in enumerate(sql_i):
10 | tb1 = tb[b]
11 | sql_q1 = generate_sql_q1(sql_i1, tb1)
12 | sql_q.append(sql_q1)
13 |
14 | return sql_q
15 |
16 | def generate_sql_q1(sql_i1, tb1):
17 | """
18 | sql = {'sel': 5, 'agg': 4, 'conds': [[3, 0, '59']]}
19 | agg_ops = ['', 'max', 'min', 'count', 'sum', 'avg']
20 | cond_ops = ['=', '>', '<', 'OP']
21 | Temporary: it can only render the simple (AND-joined) conditioned case.
22 | sql_query: real sql_query
23 | sql_plus_query: more readable sql_query
24 | "PLUS" indicates that it deals with some DB-specific facts like PCODE <-> NAME
25 | """
26 | agg_ops = ['', 'max', 'min', 'count', 'sum', 'avg']
27 | cond_ops = ['=', '>', '<', 'OP']
28 |
29 | headers = tb1["header"]
30 | types = tb1["types"]
31 | # select_header = headers[sql['sel']].lower()
32 | # try:
33 | # select_table = tb1["name"]
34 | # except:
35 | # print(f"No table name while headers are {headers}")
36 | select_table = tb1["id"]
37 |
38 | select_agg = agg_ops[sql_i1['agg']]
39 | select_header = headers[sql_i1['sel']]
40 | sql_query_part1 = f'SELECT {select_agg}({select_header}) '
41 |
42 |
43 | where_num = len(sql_i1['conds'])
44 | if where_num == 0:
45 | sql_query_part2 = f'FROM {select_table}'
46 | # sql_plus_query_part2 = f'FROM {select_table}'
47 |
48 | else:
49 | sql_query_part2 = f'FROM {select_table} WHERE'
50 | # sql_plus_query_part2 = f'FROM {select_table_refined} WHERE'
51 | # ----------------------------------------------------------------------------------------------------------
52 | for i in range(where_num):
53 | # check 'OR'
54 | # number_of_sub_conds = len(sql['conds'][i])
55 | where_header_idx, where_op_idx, where_str = sql_i1['conds'][i]
56 | where_header = headers[where_header_idx]
57 | where_op = cond_ops[where_op_idx]
58 | if i > 0:
59 | sql_query_part2 += ' AND'
60 | # sql_plus_query_part2 += ' AND'
61 | if types[where_header_idx]=='text':
62 | sql_query_part2 += f" {where_header} {where_op} '{where_str}'"
63 | else:
64 | sql_query_part2 += f" {where_header} {where_op} '{where_str}'" # NOTE: both branches currently quote the value, including numeric columns
65 | sql_query = sql_query_part1 + sql_query_part2
66 | # sql_plus_query = sql_plus_query_part1 + sql_plus_query_part2
67 |
68 | return sql_query
69 |
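# A worked example with the docstring values (table id and headers are hypothetical):
#
#   sql_i1 = {'sel': 5, 'agg': 4, 'conds': [[3, 0, '59']]}
#   tb1 = {'id': 'table_x', 'header': ['c0', 'c1', 'c2', 'c3', 'c4', 'c5'],
#          'types': ['text', 'text', 'text', 'real', 'text', 'real']}
#   generate_sql_q1(sql_i1, tb1)
#   # -> "SELECT sum(c5) FROM table_x WHERE c3 = '59'"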
70 | def report_detail(hds, nlu,
71 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_str, g_sql_q, g_ans,
72 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans,
73 | cnt_list, current_cnt):
74 | cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x = current_cnt
75 |
76 | print(f'cnt = {cnt} / {cnt_tot} ===============================')
77 |
78 | print(f'headers: {hds}')
79 | print(f'nlu: {nlu}')
80 |
81 | # print(f's_sc: {s_sc[0]}')
82 | # print(f's_sa: {s_sa[0]}')
83 | # print(f's_wn: {s_wn[0]}')
84 | # print(f's_wc: {s_wc[0]}')
85 | # print(f's_wo: {s_wo[0]}')
86 | # print(f's_wv: {s_wv[0][0]}')
87 | print(f'===============================')
88 | print(f'g_sc : {g_sc}')
89 | print(f'pr_sc: {pr_sc}')
90 | print(f'g_sa : {g_sa}')
91 | print(f'pr_sa: {pr_sa}')
92 | print(f'g_wn : {g_wn}')
93 | print(f'pr_wn: {pr_wn}')
94 | print(f'g_wc : {g_wc}')
95 | print(f'pr_wc: {pr_wc}')
96 | print(f'g_wo : {g_wo}')
97 | print(f'pr_wo: {pr_wo}')
98 | print(f'g_wv : {g_wv}')
99 | # print(f'pr_wvi: {pr_wvi}')
100 | print('g_wv_str:', g_wv_str)
101 | print('p_wv_str:', pr_wv_str)
102 | print(f'g_sql_q: {g_sql_q}')
103 | print(f'pr_sql_q: {pr_sql_q}')
104 | print(f'g_ans: {g_ans}')
105 | print(f'pr_ans: {pr_ans}')
106 | print(f'--------------------------------')
107 |
108 | print(cnt_list)
109 |
110 | print(f'acc_lx = {cnt_lx / cnt:.3f}, acc_x = {cnt_x / cnt:.3f}\n',
111 | f'acc_sc = {cnt_sc / cnt:.3f}, acc_sa = {cnt_sa / cnt:.3f}, acc_wn = {cnt_wn / cnt:.3f}\n',
112 | f'acc_wc = {cnt_wc / cnt:.3f}, acc_wo = {cnt_wo / cnt:.3f}, acc_wv = {cnt_wv / cnt:.3f}')
113 | print(f'===============================')
114 |
180 |
--------------------------------------------------------------------------------
/infer_functions.py:
--------------------------------------------------------------------------------
1 | from matplotlib.pylab import *
2 | from roberta_training import *
3 | from seq2sql_model_testing import *
4 | import json
5 | import random as python_random # needed by load_jsonl's shuffle option below
6 | import nltk
7 | from nltk.tokenize import word_tokenize, sent_tokenize
8 | import re
9 | import os
10 | re_ = re.compile(' ')
10 |
11 | def tokenize_corenlp_direct_version(client, nlu1):
12 | nlu1_tok = []
13 | for sentence in client.annotate(nlu1).sentence:
14 | for tok in sentence.token:
15 | nlu1_tok.append(tok.originalText)
16 | return nlu1_tok
17 |
18 | def sent_split(documents):
19 | words = []
20 | for sent in sent_tokenize(documents):
21 | for word in word_tokenize(sent):
22 | words.append(word)
23 | return words
24 |
25 | def load_jsonl(path_file, toy_data=False, toy_size=4, shuffle=False, seed=1):
26 | data = []
27 |
28 | with open(path_file, "r", encoding="utf-8") as f:
29 | for idx, line in enumerate(f):
30 | if toy_data and idx >= toy_size and (not shuffle):
31 | break
32 | t1 = json.loads(line.strip())
33 | data.append(t1)
34 |
35 | if shuffle and toy_data:
36 | # When shuffle required, get all the data, shuffle, and get the part of data.
37 | print(
38 | f"When toy data is used, the whole dataset is loaded first and shuffled before taking the first {toy_size} examples")
39 |
40 | python_random.Random(seed).shuffle(data) # fixed
41 | data = data[:toy_size]
42 |
43 | return data
44 |
45 | def sort_and_generate_pr_w(pr_sql_i):
46 | pr_wc = []
47 | pr_wo = []
48 | pr_wv = []
49 | for b, pr_sql_i1 in enumerate(pr_sql_i):
50 | conds1 = pr_sql_i1["conds"]
51 | pr_wc1 = []
52 | pr_wo1 = []
53 | pr_wv1 = []
54 |
55 | # Generate
56 | for i_wn, conds11 in enumerate(conds1):
57 | pr_wc1.append( conds11[0])
58 | pr_wo1.append( conds11[1])
59 | pr_wv1.append( conds11[2])
60 |
61 | # sort based on pr_wc1
62 | idx = argsort(pr_wc1)
63 | pr_wc1 = array(pr_wc1)[idx].tolist()
64 | pr_wo1 = array(pr_wo1)[idx].tolist()
65 | pr_wv1 = array(pr_wv1)[idx].tolist()
66 |
67 | conds1_sorted = []
68 | for i, idx1 in enumerate(idx):
69 | conds1_sorted.append( conds1[idx1] )
70 |
71 |
72 | pr_wc.append(pr_wc1)
73 | pr_wo.append(pr_wo1)
74 | pr_wv.append(pr_wv1)
75 |
76 | pr_sql_i1['conds'] = conds1_sorted
77 |
78 | return pr_wc, pr_wo, pr_wv, pr_sql_i
79 |
80 | def process(data,tokenize):
81 | final_all = []
82 | badcase = 0
83 | for i, one_data in enumerate(data):
84 | nlu_t1 = one_data["question_tok"]
85 |
86 | # 1. 2nd tokenization using RoBERTa Tokenizer
87 | charindex2wordindex = {}
88 | total = 0
89 | tt_to_t_idx1 = [] # maps each sub-token to the 1st-level token (here, CoreNLP) it belongs to.
90 | t_to_tt_idx1 = [] # t_to_tt_idx1[i] = start index of the i-th 1st-level token in nlu_tt1.
91 | nlu_tt1 = [] # nlu_tt1[ t_to_tt_idx1[i] ] is the first sub-token segment of the i-th 1st-level token
92 | for (ii, token) in enumerate(nlu_t1):
93 | t_to_tt_idx1.append(
94 | len(nlu_tt1)) # record the start position of this 'white-space' token among the sub-tokens.
95 | sub_tokens = tokenize.tokenize(token, is_pretokenized=True)
96 | for sub_token in sub_tokens:
97 | tt_to_t_idx1.append(ii)
98 | nlu_tt1.append(sub_token) # all_doc_tokens are further tokenized using RoBERTa tokenizer
99 |
100 | token_ = re_.sub('', token) # strip spaces so character offsets match the joined question string
101 | for iii in range(len(token_)):
102 | charindex2wordindex[total+iii]=ii
103 | total += len(token_)
104 |
105 | one_final = one_data
106 | # one_table = table[one_data["table_id"]]
107 | final_question = [0] * len(nlu_tt1)
108 | one_final["bertindex_knowledge"] = final_question
109 | final_header = [0] * len(one_data["header"])
110 | one_final["header_knowledge"] = final_header
111 | for ii,h in enumerate(one_data["header"]):
112 | h = h.lower()
113 | hs = h.split("/")
114 | for h_ in hs:
115 | flag, start_, end_ = contains2(re_.sub('', h_), "".join(one_data["question_tok"]).lower())
116 | if flag == True:
117 | try:
118 | start = t_to_tt_idx1[charindex2wordindex[start_]]
119 | end = t_to_tt_idx1[charindex2wordindex[end_]]
120 | for iii in range(start,end):
121 | final_question[iii] = 4
122 | final_question[start] = 4
123 | final_question[end] = 4
124 | one_final["bertindex_knowledge"] = final_question
125 | except:
126 | # print("!!!!!")
127 | continue
128 |
129 | for ii,h in enumerate(one_data["header"]):
130 | h = h.lower()
131 | hs = h.split("/")
132 | for h_ in hs:
133 | flag, start_, end_ = contains2(re_.sub('', h_), "".join(one_data["question_tok"]).lower())
134 | if flag == True:
135 | try:
136 | final_header[ii] = 1
137 | break
138 | except:
139 | # print("!!!!")
140 | continue
141 |
142 | one_final["header_knowledge"] = final_header
143 |
144 | if "bertindex_knowledge" not in one_final and len(one_final["sql"]["conds"])>0:
145 | one_final["bertindex_knowledge"] = [0] * len(nlu_tt1)
146 | badcase+=1
147 |
148 | final_all.append([one_data["question_tok"],one_final["bertindex_knowledge"],one_final["header_knowledge"]])
149 | return final_all
150 |
151 |
152 | def contains2(small_str,big_str):
153 | if small_str in big_str:
154 | start = big_str.index(small_str)
155 | return True,start,start+len(small_str)-1
156 | else:
157 | return False,-1,-1
158 |
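# contains2 returns inclusive character indices of the first match, e.g.:
#
#   contains2('cde', 'abcdef')   # -> (True, 2, 4)
#   contains2('xyz', 'abcdef')   # -> (False, -1, -1)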
159 | def infer(nlu1,
160 | table_id, headers, types, tokenizer,
161 | model, model_roberta, roberta_config, max_seq_length, num_target_layers,
162 | beam_size=4):
163 |
164 | model.eval()
165 | model_roberta.eval()
166 |
167 | # Get inputs
168 | nlu = [nlu1]
169 |
170 | nlu_t1 = sent_split(nlu1)
171 | nlu_t = [nlu_t1]
172 | hds = [headers]
173 | hs_t = [[]]
174 |
175 | data = {}
176 | data['question_tok'] = nlu_t[0]
177 | data['table_id'] = table_id
178 | data['header'] = headers
179 | data = [data]
180 |
181 | tb = {}
182 | tb['id'] = table_id
183 | tb['header'] = headers
184 | tb['types'] = types
185 | tb = [tb]
186 |
187 | tk = tokenizer
188 |
189 | check = process(data, tk)
190 | knowledge = [check[0][1]]
191 | header_knowledge = [check[0][2]]
192 |
193 | wemb_n, wemb_h, l_n, l_hpu, l_hs, \
194 | nlu_tt, t_to_tt_idx, tt_to_t_idx \
195 | = get_wemb_roberta(roberta_config, model_roberta, tokenizer, nlu_t, hds, max_seq_length,
196 | num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
197 |
198 | prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu,
199 | l_hs, tb,
200 | nlu_t, nlu_tt,
201 | tt_to_t_idx, nlu,
202 | beam_size=beam_size,
203 | knowledge=knowledge,
204 | knowledge_header=header_knowledge)
205 | pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
206 |
207 |
208 | if len(pr_sql_i) != 1:
209 | raise EnvironmentError
210 | pr_sql_q1 = generate_sql_q(pr_sql_i, tb)
211 | pr_sql_q = [pr_sql_q1]
212 | '''
213 | print(f'START ============================================================= ')
214 | print(f'{hds}')
215 | print(f'nlu: {nlu}')
216 | print(f'pr_sql_i : {pr_sql_i}')
217 | print(f'pr_sql_q : {pr_sql_q}')
218 | print(f'---------------------------------------------------------------------')
219 | '''
220 | return pr_sql_q1
221 |
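# A minimal invocation sketch (identifiers below are hypothetical; the models and
# tokenizer are assumed to be loaded, e.g. via the load_model helpers):
#
#   sql = infer('How many singers are from France?', table_id='1-xxxx-1',
#               headers=['Name', 'Country'], types=['text', 'text'],
#               tokenizer=tokenizer, model=model, model_roberta=model_roberta,
#               roberta_config=config, max_seq_length=222, num_target_layers=2)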
--------------------------------------------------------------------------------
/roberta_training.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | device = torch.device("cuda")
4 |
5 | def get_where_column(conds):
6 | """
7 | [ [where_column, where_operator, where_value],
8 | [where_column, where_operator, where_value], ...
9 | ]
10 | """
11 | where_column = []
12 | for cond in conds:
13 | where_column.append(cond[0])
14 | return where_column
15 |
16 |
17 | def get_where_operator(conds):
18 | """
19 | [ [where_column, where_operator, where_value],
20 | [where_column, where_operator, where_value], ...
21 | ]
22 | """
23 | where_operator = []
24 | for cond in conds:
25 | where_operator.append(cond[1])
26 | return where_operator
27 |
28 |
29 | def get_where_value(conds):
30 | """
31 | [ [where_column, where_operator, where_value],
32 | [where_column, where_operator, where_value], ...
33 | ]
34 | """
35 | where_value = []
36 | for cond in conds:
37 | where_value.append(cond[2])
38 | return where_value
39 |
40 |
41 | def get_ground_truth_values(canonical_sql_queries):
42 |
43 | ground_select_column = []
44 | ground_select_aggregate = []
45 | ground_where_number = []
46 | ground_where_column = []
47 | ground_where_operator = []
48 | ground_where_value = []
49 | for _, canonical_sql_query in enumerate(canonical_sql_queries):
50 | ground_select_column.append( canonical_sql_query["sel"] )
51 | ground_select_aggregate.append( canonical_sql_query["agg"])
52 |
53 | conds = canonical_sql_query['conds']
54 | if canonical_sql_query["agg"] >= 0:
55 | ground_where_number.append( len( conds ) )
56 | ground_where_column.append( get_where_column(conds) )
57 | ground_where_operator.append( get_where_operator(conds) )
58 | ground_where_value.append( get_where_value(conds) )
59 | else:
60 | raise EnvironmentError
61 | return ground_select_column, ground_select_aggregate, ground_where_number, ground_where_column,\
62 | ground_where_operator, ground_where_value
63 |
64 |
65 | def get_wemb_roberta(roberta_config, model_roberta, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n=1, num_out_layers_h=1):
66 | '''
67 | wemb_n : word embedding of natural language question
68 | wemb_h : word embedding of header
69 | l_n : length of natural question
70 | l_hs : length of header
71 | nlu_tt : Natural language double tokenized
72 | t_to_tt_idx : map first level tokenization to second level tokenization
73 | tt_to_t_idx : map second level tokenization to first level tokenization
74 | '''
75 | # get contextual output of all tokens from RoBERTa
76 | all_encoder_layer, i_nlu, i_headers,\
77 | l_n, l_hpu, l_hs, \
78 | nlu_tt, t_to_tt_idx, tt_to_t_idx = get_roberta_output(model_roberta, tokenizer, nlu_t, hds, max_seq_length)
79 | # all_encoder_layer: RoBERTa outputs from all layers.
80 | # i_nlu: start and end indices of question in tokens
81 | # i_headers: start and end indices of headers
82 |
83 | # get the wemb
84 | wemb_n = get_wemb_n(i_nlu, l_n, roberta_config.hidden_size, roberta_config.num_hidden_layers, all_encoder_layer,
85 | num_out_layers_n)
86 |
87 | wemb_h = get_wemb_h(i_headers, l_hpu, l_hs, roberta_config.hidden_size, roberta_config.num_hidden_layers, all_encoder_layer,
88 | num_out_layers_h)
89 |
90 | return wemb_n, wemb_h, l_n, l_hpu, l_hs, \
91 | nlu_tt, t_to_tt_idx, tt_to_t_idx
92 |
93 | def get_roberta_output(model_roberta, tokenizer, nlu_t, headers, max_seq_length):
94 | """
95 | Here, the input is tokenized further by the RoBERTa tokenizer and fed into RoBERTa.
96 | INPUT
97 | :param model_roberta:
98 | :param tokenizer: RoBERTa tokenizer
99 | :param nlu: Question
100 | :param nlu_t: tokenized natural_language_utterance.
101 | :param headers: Headers of the table
102 | :param max_seq_length: max input token length
103 | OUTPUT
104 | tokens: RoBERTa input tokens
105 | nlu_tt: RoBERTa-tokenized input natural language questions
106 | orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token
107 | tok_to_orig_index: inverse map.
108 | """
109 |
110 | l_n = []
111 | l_hs = [] # the number of columns (headers) for each example in the batch
112 |
113 | input_ids = []
114 | input_mask = []
115 |
116 | i_nlu = [] # index to retrieve the position of contextual vector later.
117 | i_headers = []
118 |
119 | nlu_tt = []
120 |
121 | t_to_tt_idx = []
122 | tt_to_t_idx = []
123 | for b, nlu_t1 in enumerate(nlu_t):
124 |
125 | batch_headers = headers[b]
126 | l_hs.append(len(batch_headers))
127 |
128 |
129 | # 1. Tokenization using RoBERTa Tokenizer
130 | tt_to_t_idx1 = [] # number indicates where sub-token belongs to in 1st-level-tokens
131 | t_to_tt_idx1 = [] # orig_to_tok_idx[i] = start index of i-th-1st-level-token in all_tokens.
132 | nlu_tt1 = []
133 | for (i, token) in enumerate(nlu_t1):
134 | t_to_tt_idx1.append(
135 | len(nlu_tt1))
136 | sub_tokens = tokenizer.tokenize(token, is_pretokenized=True)
137 | for sub_token in sub_tokens:
138 | tt_to_t_idx1.append(i)
139 | nlu_tt1.append(sub_token)
140 | nlu_tt.append(nlu_tt1)
141 | tt_to_t_idx.append(tt_to_t_idx1)
142 | t_to_tt_idx.append(t_to_tt_idx1)
143 |
144 | l_n.append(len(nlu_tt1))
145 |
146 | # nlu col1 col2 ...col-n
147 | # 2. Generate RoBERTa inputs & indices.
148 | tokens, i_nlu1, i_batch_headers = generate_inputs(tokenizer, nlu_tt1, batch_headers)
149 | input_ids1 = tokenizer.convert_tokens_to_ids(tokens)
150 |
151 | # Input masks
152 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
153 | # tokens are attended to.
154 | input_mask1 = [1] * len(input_ids1)
155 |
156 | # 3. Zero-pad up to the sequence length.
157 | while len(input_ids1) < max_seq_length:
158 | input_ids1.append(0)
159 | input_mask1.append(0)
160 |
161 | assert len(input_ids1) == max_seq_length
162 | assert len(input_mask1) == max_seq_length
163 |
164 | input_ids.append(input_ids1)
165 | input_mask.append(input_mask1)
166 |
167 | i_nlu.append(i_nlu1)
168 | i_headers.append(i_batch_headers)
169 |
170 | # Convert to tensor
171 | all_input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
172 | all_input_mask = torch.tensor(input_mask, dtype=torch.long).to(device)
173 |
174 | # 4. Generate RoBERTa output.
175 | check, _, all_encoder_layer = model_roberta(input_ids=all_input_ids, attention_mask=all_input_mask, output_hidden_states=True)
176 | all_encoder_layer = list(all_encoder_layer)
177 |
178 | assert all((check == all_encoder_layer[-1]).tolist())
179 |
180 | # 5. generate l_hpu from i_headers
181 | l_hpu = gen_l_hpu(i_headers)
182 |
183 | return all_encoder_layer, i_nlu, i_headers, \
184 | l_n, l_hpu, l_hs, \
185 | nlu_tt, t_to_tt_idx, tt_to_t_idx
186 |
187 |
188 | def generate_inputs(tokenizer, nlu1_tok, hds1):
189 | tokens = []
190 |
191 | tokens.append("") # empty-string placeholder start token (where BERT-style inputs use [CLS])
192 | i_st_nlu = len(tokens) # to use it later
193 |
194 | for token in nlu1_tok:
195 | tokens.append(token)
196 | i_ed_nlu = len(tokens)
197 | tokens.append("") # empty-string placeholder separator (where BERT-style inputs use [SEP])
198 |
199 | i_headers = []
200 |
201 | for i, hds11 in enumerate(hds1):
202 | i_st_hd = len(tokens)
203 | sub_tok = tokenizer.tokenize(hds11)
204 | tokens += sub_tok
205 | i_ed_hd = len(tokens)
206 | i_headers.append((i_st_hd, i_ed_hd))
207 | if i < len(hds1)-1:
208 | tokens.append("") # placeholder separator between headers
209 | elif i == len(hds1)-1:
210 | tokens.append("") # placeholder separator after the last header
211 | else:
212 | raise EnvironmentError
213 |
214 | i_nlu = (i_st_nlu, i_ed_nlu)
215 |
216 | return tokens, i_nlu, i_headers
217 |
218 | def gen_l_hpu(i_headers):
219 | """
220 | Treat the columns as if they form a batch of natural language utterances with batch size = (# of columns) * (# of batches)
221 | e.g. i_headers = [[(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)]]
222 | """
223 | l_hpu = []
224 |
225 | for i_header in i_headers:
226 | for index_pair in i_header:
227 | l_hpu.append(index_pair[1] - index_pair[0])
228 |
229 | return l_hpu
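# Each (start, end) pair contributes (end - start) sub-tokens, e.g.:
#   gen_l_hpu([[(17, 18), (19, 21), (22, 23)]])   # -> [1, 2, 1]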
230 |
231 | def get_wemb_n(i_nlu, l_n, hS, num_hidden_layers, all_encoder_layer, num_out_layers_n):
232 | """
233 | Get the representation of each token.
234 | """
235 | bS = len(l_n)
236 | l_n_max = max(l_n)
237 | wemb_n = torch.zeros([bS, l_n_max, hS * num_out_layers_n]).to(device)
238 | for b in range(bS):
239 |
240 | l_n1 = l_n[b]
241 | i_nlu1 = i_nlu[b]
242 | for i_noln in range(num_out_layers_n):
243 | i_layer = num_hidden_layers - 1 - i_noln
244 | st = i_noln * hS
245 | ed = (i_noln + 1) * hS
246 | wemb_n[b, 0:(i_nlu1[1] - i_nlu1[0]), st:ed] = all_encoder_layer[i_layer][b, i_nlu1[0]:i_nlu1[1], :]
247 |
248 | return wemb_n
249 |
250 | def get_wemb_h(i_headers, l_hpu, l_hs, hS, num_hidden_layers, all_encoder_layer, num_out_layers_h):
251 | """
252 | As if
253 | [ [table-1-col-1-tok1, t1-c1-t2, ...],
254 | [t1-c2-t1, t1-c2-t2, ...].
255 | ...
256 | [t2-c1-t1, ...,]
257 | ]
258 | """
259 | bS = len(l_hs)
260 | l_hpu_max = max(l_hpu)
261 | num_of_all_hds = sum(l_hs)
262 | wemb_h = torch.zeros([num_of_all_hds, l_hpu_max, hS * num_out_layers_h]).to(device)
263 | b_pu = -1
264 | for b, i_header in enumerate(i_headers):
265 | for b1, index_pair in enumerate(i_header):
266 | b_pu += 1
267 | for i_nolh in range(num_out_layers_h):
268 | i_layer = num_hidden_layers - 1 - i_nolh
269 | st = i_nolh * hS
270 | ed = (i_nolh + 1) * hS
271 | wemb_h[b_pu, 0:(index_pair[1] - index_pair[0]), st:ed] \
272 | = all_encoder_layer[i_layer][b, index_pair[0]:index_pair[1],:]
273 |
274 | return wemb_h
275 |
--------------------------------------------------------------------------------
/seq2sql_model_internal_functions.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import random as rd
5 | from copy import deepcopy
6 |
7 | from matplotlib.pylab import *
8 |
9 | import math
10 | import torch
11 | import torchvision.datasets as dsets
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | # import torch_xla
15 | # import torch_xla.core.xla_model as xm
16 |
17 | device = torch.device("cuda")
18 |
19 |
20 |
21 |
22 | def encode(lstm, wemb_l, l, return_hidden=False, hc0=None, last_only=False):
23 | """ [batch_size, max token length, dim_emb]
24 | """
25 | bS, mL, eS = wemb_l.shape
26 |
27 | # sort before packing
28 | l = array(l)
29 | perm_idx = argsort(-l)
30 | perm_idx_inv = generate_perm_inv(perm_idx)
31 |
32 | # pack sequence
33 |
34 | packed_wemb_l = nn.utils.rnn.pack_padded_sequence(wemb_l[perm_idx, :, :],
35 | l[perm_idx],
36 | batch_first=True)
37 |
38 | # Time to encode
39 | if hc0 is not None:
40 | hc0 = (hc0[0][:, perm_idx], hc0[1][:, perm_idx])
41 |
42 | # ipdb.set_trace()
43 | packed_wemb_l = packed_wemb_l.float() # cast to float32; the LSTM weights are float32
44 | packed_wenc, hc_out = lstm(packed_wemb_l, hc0)
45 | hout, cout = hc_out
46 |
47 | # unpack
48 | wenc, _l = nn.utils.rnn.pad_packed_sequence(packed_wenc, batch_first=True)
49 |
50 | if last_only:
51 | # Take only the final output for each column.
52 | wenc = wenc[tuple(range(bS)), l[perm_idx] - 1] # [batch_size, dim_emb]
53 | wenc.unsqueeze_(1) # [batch_size, 1, dim_emb]
54 |
55 | wenc = wenc[perm_idx_inv]
56 |
57 | if return_hidden:
58 | # hout.shape = [num_layers * num_directions, batch_size, hidden_size]; h_n is not affected by batch_first.
59 | hout = hout[:, perm_idx_inv].to(device)
60 | cout = cout[:, perm_idx_inv].to(device) # Is this correct operation?
61 |
62 | return wenc, hout, cout
63 | else:
64 | return wenc
65 |
66 |
67 | def encode_hpu(lstm, wemb_hpu, l_hpu, l_hs):
68 | wenc_hpu, hout, cout = encode(lstm,
69 | wemb_hpu,
70 | l_hpu,
71 | return_hidden=True,
72 | hc0=None,
73 | last_only=True)
74 |
75 | wenc_hpu = wenc_hpu.squeeze(1)
76 | bS_hpu, mL_hpu, eS = wemb_hpu.shape
77 | hS = wenc_hpu.size(-1)
78 |
79 | wenc_hs = wenc_hpu.new_zeros(len(l_hs), max(l_hs), hS)
80 | wenc_hs = wenc_hs.to(device)
81 |
82 | # Re-pack according to batch.
83 | # ret = [B_NLq, max_len_headers_all, dim_lstm]
84 | st = 0
85 | for i, l_hs1 in enumerate(l_hs):
86 | wenc_hs[i, :l_hs1] = wenc_hpu[st:(st + l_hs1)]
87 | st += l_hs1
88 |
89 | return wenc_hs
90 |
91 |
92 | def generate_perm_inv(perm):
93 | # Definitely correct.
94 | perm_inv = zeros(len(perm), dtype=int) # was previously an undefined int32 variable
95 | for i, p in enumerate(perm):
96 | perm_inv[int(p)] = i
97 |
98 | return perm_inv
99 |
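# Example: indexing by perm and then by perm_inv restores the original order, e.g.
#   perm = argsort(-array([3, 1, 2]))    # -> [0, 2, 1]
#   generate_perm_inv(perm)              # -> [0, 2, 1] (this permutation is its own inverse)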
100 |
101 | def pred_sc(s_sc):
102 | """
103 | return: [ pr_wc1_i, pr_wc2_i, ...]
104 | """
105 | # get g_num
106 | pr_sc = []
107 | for s_sc1 in s_sc:
108 | pr_sc.append(s_sc1.argmax().item())
109 |
110 | return pr_sc
111 |
112 |
113 | def pred_sc_beam(s_sc, beam_size):
114 | """
115 | return: [ pr_wc1_i, pr_wc2_i, ...]
116 | """
117 | # get g_num
118 | pr_sc_beam = []
119 |
120 | for s_sc1 in s_sc:
121 | val, idxes = s_sc1.topk(k=beam_size)
122 | pr_sc_beam.append(idxes.tolist())
123 |
124 | return pr_sc_beam
125 |
126 |
127 | def pred_sa(s_sa):
128 | """
129 | return: [ pr_wc1_i, pr_wc2_i, ...]
130 | """
131 | # get g_num
132 | pr_sa = []
133 | for s_sa1 in s_sa:
134 | pr_sa.append(s_sa1.argmax().item())
135 |
136 | return pr_sa
137 |
138 |
139 | def pred_wn(s_wn):
140 | """
141 | return: [ pr_wc1_i, pr_wc2_i, ...]
142 | """
143 | # get g_num
144 | pr_wn = []
145 | for s_wn1 in s_wn:
146 | pr_wn.append(s_wn1.argmax().item())
147 | # print(pr_wn, s_wn1)
148 | # if s_wn1.argmax().item() == 3:
149 | # input('')
150 |
151 | return pr_wn
152 |
153 |
154 | def pred_wc(wn, s_wc):
155 | """
156 | return: [ pr_wc1_i, pr_wc2_i, ...]
157 | ! Returned index is sorted!
158 | """
159 | # get g_num
160 | pr_wc = []
161 | for b, wn1 in enumerate(wn):
162 | s_wc1 = s_wc[b]
163 |
164 | pr_wc1 = argsort(-s_wc1.data.cpu().numpy())[:wn1]
165 | pr_wc1.sort()
166 |
167 | pr_wc.append(list(pr_wc1))
168 | return pr_wc
169 |
170 |
171 | def pred_wo(wn, s_wo):
172 | """
173 | return: [ pr_wc1_i, pr_wc2_i, ...]
174 | """
175 | # s_wo = [B, 4, n_op]
176 | pr_wo_a = s_wo.argmax(dim=2) # [B, 4]
177 | # get g_num
178 | pr_wo = []
179 | for b, pr_wo_a1 in enumerate(pr_wo_a):
180 | wn1 = wn[b]
181 | pr_wo.append(list(pr_wo_a1.data.cpu().numpy()[:wn1]))
182 |
183 | return pr_wo
184 |
185 |
186 | def topk_multi_dim(tensor, n_topk=1, batch_exist=True):
187 |
188 | if batch_exist:
189 | idxs = []
190 | for b, tensor1 in enumerate(tensor):
191 | idxs1 = []
192 | tensor1_1d = tensor1.reshape(-1)
193 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk)
194 | idxs_list = unravel_index(idxs_1d.cpu().numpy(), tensor1.shape)
195 | # (dim0, dim1, dim2, ...)
196 |
197 | # reconstruct
198 | for i_beam in range(n_topk):
199 | idxs11 = []
200 | for idxs_list1 in idxs_list:
201 | idxs11.append(idxs_list1[i_beam])
202 | idxs1.append(idxs11)
203 | idxs.append(idxs1)
204 |
205 | else:
206 | tensor1 = tensor
207 | idxs1 = []
208 | tensor1_1d = tensor1.reshape(-1)
209 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk)
210 | idxs_list = unravel_index(idxs_1d.numpy(), tensor1.shape)
211 | # (dim0, dim1, dim2, ...)
212 |
213 | # reconstruct
214 | for i_beam in range(n_topk):
215 | idxs11 = []
216 | for idxs_list1 in idxs_list:
217 | idxs11.append(idxs_list1[i_beam])
218 | idxs1.append(idxs11)
219 | idxs = idxs1
220 | return idxs
221 |
222 |
223 | def remap_sc_idx(idxs, pr_sc_beam):
224 | for b, idxs1 in enumerate(idxs):
225 | for i_beam, idxs11 in enumerate(idxs1):
226 | sc_beam_idx = idxs[b][i_beam][0]
227 | sc_idx = pr_sc_beam[b][sc_beam_idx]
228 | idxs[b][i_beam][0] = sc_idx
229 |
230 | return idxs
231 |
232 |
233 | def check_sc_sa_pairs(tb, pr_sc, pr_sa, ):
234 | """
235 | Check whether pr_sc, pr_sa are allowed pairs or not.
236 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
237 | """
238 | bS = len(pr_sc)
239 | check = [False] * bS
240 | for b, pr_sc1 in enumerate(pr_sc):
241 | pr_sa1 = pr_sa[b]
242 | hd_types1 = tb[b]['types']
243 | hd_types11 = hd_types1[pr_sc1]
244 | if hd_types11 == 'text':
245 | if pr_sa1 == 0 or pr_sa1 == 3: # '' (no aggregation) or COUNT
246 | check[b] = True
247 | else:
248 | check[b] = False
249 |
250 | elif hd_types11 == 'real':
251 | check[b] = True
252 | else:
253 | raise Exception("New TYPE!!")
254 |
255 | return check
256 |
257 |
258 | def pred_wvi_se_beam(max_wn, s_wv, beam_size):
259 | """
260 | s_wv: [B, 4, mL, 2]
261 | - predict best st-idx & ed-idx
262 | output:
263 | pr_wvi_beam = [B, max_wn, n_pairs, 2]. 2 means [st, ed].
264 | prob_wvi_beam = [B, max_wn, n_pairs]
265 | """
266 | bS = s_wv.shape[0]
267 |
268 | # [B, 4, mL, 2] -> [B, 4, mL, 1], [B, 4, mL, 1]
269 | s_wv_st, s_wv_ed = s_wv.split(1, dim=3)
270 |
271 | s_wv_st = s_wv_st.squeeze(3) # [B, 4, mL, 1] -> [B, 4, mL]
272 | s_wv_ed = s_wv_ed.squeeze(3)
273 |
274 | prob_wv_st = F.softmax(s_wv_st, dim=-1).detach().to('cpu').numpy()
275 | prob_wv_ed = F.softmax(s_wv_ed, dim=-1).detach().to('cpu').numpy()
276 |
277 | k_logit = int(ceil(sqrt(beam_size)))
278 | n_pairs = k_logit**2
279 | assert n_pairs >= beam_size
280 | values_st, idxs_st = s_wv_st.topk(k_logit) # [B, 4, mL] -> [B, 4, k_logit]
281 | values_ed, idxs_ed = s_wv_ed.topk(k_logit) # [B, 4, mL] -> [B, 4, k_logit]
282 |
283 | # idxs = [B, k_logit, 2]
284 | # Generate all possible combination of st, ed indices & prob
285 | pr_wvi_beam = [] # [B, max_wn, k_logit**2] list of [st, ed] pairs
286 | prob_wvi_beam = zeros([bS, max_wn, n_pairs])
287 | for b in range(bS):
288 | pr_wvi_beam1 = []
289 |
290 | idxs_st1 = idxs_st[b]
291 | idxs_ed1 = idxs_ed[b]
292 | for i_wn in range(max_wn):
293 | idxs_st11 = idxs_st1[i_wn]
294 | idxs_ed11 = idxs_ed1[i_wn]
295 |
296 | pr_wvi_beam11 = []
297 | pair_idx = -1
298 | for i_k in range(k_logit):
299 | for j_k in range(k_logit):
300 | pair_idx += 1
301 | st = idxs_st11[i_k].item()
302 | ed = idxs_ed11[j_k].item()
303 | pr_wvi_beam11.append([st, ed])
304 |
305 | p1 = prob_wv_st[b, i_wn, st]
306 | p2 = prob_wv_ed[b, i_wn, ed]
307 | prob_wvi_beam[b, i_wn, pair_idx] = p1*p2
308 | pr_wvi_beam1.append(pr_wvi_beam11)
309 | pr_wvi_beam.append(pr_wvi_beam1)
310 |
311 | # prob
312 |
313 | return pr_wvi_beam, prob_wvi_beam
314 |
315 |
316 | def convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_wp_t, wp_to_wh_index, nlu):
317 | """
318 | - Convert to the string in white-space-separated tokens
319 | - Ad-hoc addition.
320 | """
321 | pr_wv_str_wp = [] # word-piece version
322 | pr_wv_str = []
323 | for b, pr_wvi1 in enumerate(pr_wvi):
324 | pr_wv_str_wp1 = []
325 | pr_wv_str1 = []
326 | wp_to_wh_index1 = wp_to_wh_index[b]
327 | nlu_wp_t1 = nlu_wp_t[b]
328 | nlu_t1 = nlu_t[b]
329 |
330 | for i_wn, pr_wvi11 in enumerate(pr_wvi1):
331 | st_idx, ed_idx = pr_wvi11
332 |
333 | # Ad-hoc modification of ed_idx to deal with wp-tokenization effect.
334 | # e.g.) to convert "butler cc (" ->"butler cc (ks)" (dev set 1st question).
335 | pr_wv_str_wp11 = nlu_wp_t1[st_idx:ed_idx+1]
336 | pr_wv_str_wp1.append(pr_wv_str_wp11)
337 |
338 | st_wh_idx = wp_to_wh_index1[st_idx]
339 | ed_wh_idx = wp_to_wh_index1[ed_idx]
340 | pr_wv_str11 = nlu_t1[st_wh_idx:ed_wh_idx+1]
341 |
342 | pr_wv_str1.append(pr_wv_str11)
343 |
344 | pr_wv_str_wp.append(pr_wv_str_wp1)
345 | pr_wv_str.append(pr_wv_str1)
346 |
347 | return pr_wv_str, pr_wv_str_wp
348 |
349 |
350 | def merge_wv_t1_eng(where_str_tokens, NLq):
351 | """
352 | Almost entirely copied from SQLNet.
353 | The main purpose is to pad blanks correctly while combining tokens.
354 | """
355 | nlq = NLq.lower()
356 | where_str_tokens = [tok.lower() for tok in where_str_tokens]
357 | alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$'
358 | special = {'-LRB-': '(',
359 | '-RRB-': ')',
360 | '-LSB-': '[',
361 | '-RSB-': ']',
362 | '``': '"',
363 | '\'\'': '"',
364 | }
365 | # '--': '\u2013'} # this generate error for test 5661 case.
366 | ret = ''
367 | double_quote_appear = 0
368 | for raw_w_token in where_str_tokens:
369 | # if '' (empty string) of None, continue
370 | if not raw_w_token:
371 | continue
372 |
373 | # Change the special characters
374 | # maybe necessary for some case?
375 | w_token = special.get(raw_w_token, raw_w_token)
376 |
377 | # check the double quote
378 | if w_token == '"':
379 | double_quote_appear = 1 - double_quote_appear
380 |
381 | # Check whether ret is empty. ret is selected where condition.
382 | if len(ret) == 0:
383 | pass
384 | # Check blank character.
385 | elif len(ret) > 0 and ret + ' ' + w_token in nlq:
386 | # Pad ' ' if ret + ' ' is part of nlq.
387 | ret = ret + ' '
388 |
389 | elif len(ret) > 0 and ret + w_token in nlq:
390 | pass # already in good form. Later, ret + w_token will performed.
391 |
392 | # The cases below are for unusual questions; they are unlikely to appear but handled anyway.
393 | elif w_token == '"':
394 | if double_quote_appear:
395 | ret = ret + ' ' # pad a blank before the double quote when it is a closing one;
396 | # no blank is padded before an opening quote.
397 |
398 | elif w_token[0] not in alphabet:
399 | pass # non alphabet one does not pad blank line.
400 |
401 | # when previous character is the special case.
402 | elif (ret[-1] not in ['(', '/', '\u2013', '#', '$', '&']) and (ret[-1] != '"' or not double_quote_appear):
403 | ret = ret + ' '
404 | ret = ret + w_token
405 |
406 | return ret.strip()
407 |
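# Example: tokens are re-joined with spacing recovered from the original question.
if __name__ == "__main__":
    merged = merge_wv_t1_eng(['butler', 'cc', '-LRB-', 'ks', '-RRB-'],
                             'Who played at Butler CC (KS)?')
    assert merged == 'butler cc (ks)'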
--------------------------------------------------------------------------------
/seq2sql_model_training_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from matplotlib.pylab import *
5 | from copy import deepcopy
6 | #import torch_xla
7 | #import torch_xla.core.xla_model as xm
8 |
9 | device = torch.device("cuda")
10 |
11 |
12 | def Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi):
13 | """
14 | :param s_wv: score [ B, n_conds, T, score]
15 | :param g_wn: [ B ]
16 | :param g_wvi: [B, conds, 2] start/end token indices per condition, e.g. [[[0, 6], [8, 15]], [[1, 3], [7, 9]]]
17 | :return:
18 | """
19 | loss = 0
20 | loss += F.cross_entropy(s_sc, torch.tensor(g_sc).to(device))
21 | loss += F.cross_entropy(s_sa, torch.tensor(g_sa).to(device))
22 | loss += F.cross_entropy(s_wn, torch.tensor(g_wn).to(device))
23 | loss += Loss_wc(s_wc, g_wc)
24 | loss += Loss_wo(s_wo, g_wn, g_wo)
25 | loss += Loss_wv_se(s_wv, g_wn, g_wvi)
26 |
27 | return loss
28 |
29 |
30 | def Loss_wc(s_wc, g_wc):
31 |
32 | # Construct index matrix
33 | bS, max_h_len = s_wc.shape
34 | im = torch.zeros([bS, max_h_len]).to(device)
35 | for b, g_wc1 in enumerate(g_wc):
36 | for g_wc11 in g_wc1:
37 | im[b, g_wc11] = 1.0
38 | # Construct prob.
39 | p = torch.sigmoid(s_wc) # F.sigmoid is deprecated in newer PyTorch
40 | loss = F.binary_cross_entropy(p, im)
41 |
42 | return loss
43 |
44 |
45 | def Loss_wo(s_wo, g_wn, g_wo):
46 |
47 | # Construct index matrix
48 | loss = 0
49 | for b, g_wn1 in enumerate(g_wn):
50 | if g_wn1 == 0:
51 | continue
52 | g_wo1 = g_wo[b]
53 | s_wo1 = s_wo[b]
54 | loss += F.cross_entropy(s_wo1[:g_wn1], torch.tensor(g_wo1).to(device))
55 |
56 | return loss
57 |
58 | def Loss_wv_se(s_wv, g_wn, g_wvi):
59 | """
60 | s_wv: [bS, 4, mL, 2], 4 stands for the maximum # of conditions, 2 stands for the start & end logits.
61 | g_wvi: per-example lists of [st, ed] pairs, e.g. [ [[0, 1], [2, 3], [4, 4]], [[5, 6], [7, 7]] ] (when B=2, wn(b=1) = 3, wn(b=2) = 2).
62 | """
63 | loss = 0
64 | # g_wvi = torch.tensor(g_wvi).to(device)
65 | for b, g_wvi1 in enumerate(g_wvi):
66 | # for i_wn, g_wvi11 in enumerate(g_wvi1):
67 |
68 | g_wn1 = g_wn[b]
69 | if g_wn1 == 0:
70 | continue
71 | g_wvi1 = torch.tensor(g_wvi1).to(device)
72 | g_st1 = g_wvi1[:,0]
73 | g_ed1 = g_wvi1[:,1]
74 | # loss from the start position
75 | loss += F.cross_entropy(s_wv[b,:g_wn1,:,0], g_st1)
76 |
77 | # print("st_login: ", s_wv[b,:g_wn1,:,0], g_st1, loss)
78 | # loss from the end position
79 | loss += F.cross_entropy(s_wv[b,:g_wn1,:,1], g_ed1)
80 | # print("ed_login: ", s_wv[b,:g_wn1,:,1], g_ed1, loss)
81 |
82 | return loss
83 |
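# A minimal training-step sketch around the combined loss (s_* are model scores,
# g_* the ground truths from roberta_training.get_ground_truth_values; opt and
# opt_roberta come from load_model.get_optimizers):
#
#   loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
#                     g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)
#   opt.zero_grad(); opt_roberta.zero_grad()
#   loss.backward()
#   opt.step(); opt_roberta.step()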
84 | def pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv):
85 | pr_sc = pred_sc(s_sc)
86 | pr_sa = pred_sa(s_sa)
87 | pr_wn = pred_wn(s_wn)
88 | pr_wc = pred_wc(pr_wn, s_wc)
89 | pr_wo = pred_wo(pr_wn, s_wo)
90 | pr_wvi = pred_wvi_se(pr_wn, s_wv)
91 |
92 | return pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi
93 |
94 | def pred_sc(s_sc):
95 | """
96 |     return: [ pr_sc1_i, pr_sc2_i, ...]
97 | """
98 | # get g_num
99 | pr_sc = []
100 | for s_sc1 in s_sc:
101 | pr_sc.append(s_sc1.argmax().item())
102 |
103 | return pr_sc
104 |
116 | def pred_sa(s_sa):
117 | """
118 |     return: [ pr_sa1_i, pr_sa2_i, ...]
119 | """
120 | # get g_num
121 | pr_sa = []
122 | for s_sa1 in s_sa:
123 | pr_sa.append(s_sa1.argmax().item())
124 |
125 | return pr_sa
126 |
127 | def pred_wn(s_wn):
128 | """
129 |     return: [ pr_wn1_i, pr_wn2_i, ...]
130 | """
131 | # get g_num
132 | pr_wn = []
133 | for s_wn1 in s_wn:
134 | pr_wn.append(s_wn1.argmax().item())
135 | # print(pr_wn, s_wn1)
136 | # if s_wn1.argmax().item() == 3:
137 | # input('')
138 |
139 | return pr_wn
140 |
141 | def pred_wc(wn, s_wc):
142 | """
143 | return: [ pr_wc1_i, pr_wc2_i, ...]
144 | ! Returned index is sorted!
145 | """
146 | # get g_num
147 | pr_wc = []
148 | for b, wn1 in enumerate(wn):
149 | s_wc1 = s_wc[b]
150 |
151 | pr_wc1 = argsort(-s_wc1.data.cpu().numpy())[:wn1]
152 | pr_wc1.sort()
153 |
154 | pr_wc.append(list(pr_wc1))
155 | return pr_wc
156 |
157 | def pred_wo(wn, s_wo):
158 | """
159 |     return: [ pr_wo1_i, pr_wo2_i, ...]
160 | """
161 | # s_wo = [B, 4, n_op]
162 | pr_wo_a = s_wo.argmax(dim=2) # [B, 4]
163 | # get g_num
164 | pr_wo = []
165 | for b, pr_wo_a1 in enumerate(pr_wo_a):
166 | wn1 = wn[b]
167 | pr_wo.append(list(pr_wo_a1.data.cpu().numpy()[:wn1]))
168 |
169 | return pr_wo
170 |
171 | def pred_wvi_se(wn, s_wv):
172 | """
173 | s_wv: [B, 4, mL, 2]
174 | - predict best st-idx & ed-idx
175 | """
176 |
177 | s_wv_st, s_wv_ed = s_wv.split(1, dim=3) # [B, 4, mL, 2] -> [B, 4, mL, 1], [B, 4, mL, 1]
178 |
179 | s_wv_st = s_wv_st.squeeze(3) # [B, 4, mL, 1] -> [B, 4, mL]
180 | s_wv_ed = s_wv_ed.squeeze(3)
181 |
182 |     pr_wvi_st_idx = s_wv_st.argmax(dim=2) # [B, 4, mL] -> [B, 4]
183 | pr_wvi_ed_idx = s_wv_ed.argmax(dim=2)
184 |
185 | pr_wvi = []
186 | for b, wn1 in enumerate(wn):
187 | pr_wvi1 = []
188 | for i_wn in range(wn1):
189 | pr_wvi_st_idx11 = pr_wvi_st_idx[b][i_wn]
190 | pr_wvi_ed_idx11 = pr_wvi_ed_idx[b][i_wn]
191 | pr_wvi1.append([pr_wvi_st_idx11.item(), pr_wvi_ed_idx11.item()])
192 | pr_wvi.append(pr_wvi1)
193 |
194 | return pr_wvi
195 |
196 | def convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_wp_t, wp_to_wh_index, nlu):
197 | """
198 |     - Convert to the string of white-space-separated tokens
199 |     - Ad-hoc addition.
200 | """
201 | pr_wv_str_wp = [] # word-piece version
202 | pr_wv_str = []
203 | for b, pr_wvi1 in enumerate(pr_wvi):
204 | pr_wv_str_wp1 = []
205 | pr_wv_str1 = []
206 | wp_to_wh_index1 = wp_to_wh_index[b]
207 | nlu_wp_t1 = nlu_wp_t[b]
208 | nlu_t1 = nlu_t[b]
209 |
210 | for i_wn, pr_wvi11 in enumerate(pr_wvi1):
211 | st_idx, ed_idx = pr_wvi11
212 |
213 | # Ad-hoc modification of ed_idx to deal with wp-tokenization effect.
214 | # e.g.) to convert "butler cc (" ->"butler cc (ks)" (dev set 1st question).
215 | pr_wv_str_wp11 = nlu_wp_t1[st_idx:ed_idx+1]
216 | pr_wv_str_wp1.append(pr_wv_str_wp11)
217 |
218 | st_wh_idx = wp_to_wh_index1[st_idx]
219 | ed_wh_idx = wp_to_wh_index1[ed_idx]
220 | pr_wv_str11 = nlu_t1[st_wh_idx:ed_wh_idx+1]
221 |
222 | pr_wv_str1.append(pr_wv_str11)
223 |
224 | pr_wv_str_wp.append(pr_wv_str_wp1)
225 | pr_wv_str.append(pr_wv_str1)
226 |
227 | return pr_wv_str, pr_wv_str_wp
228 |
229 | def sort_pr_wc(pr_wc, g_wc):
230 | """
231 | Input: list
232 | pr_wc = [B, n_conds]
233 | g_wc = [B, n_conds]
234 | Return: list
235 | pr_wc_sorted = [B, n_conds]
236 | """
237 | pr_wc_sorted = []
238 | for b, pr_wc1 in enumerate(pr_wc):
239 | g_wc1 = g_wc[b]
240 | pr_wc1_sorted = []
241 |
242 | if set(g_wc1) == set(pr_wc1):
243 | pr_wc1_sorted = deepcopy(g_wc1)
244 | else:
245 | # no sorting when g_wc1 and pr_wc1 are different.
246 | pr_wc1_sorted = deepcopy(pr_wc1)
247 |
248 | pr_wc_sorted.append(pr_wc1_sorted)
249 | return pr_wc_sorted
250 |
251 | def generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu):
252 | pr_sql_i = []
253 | for b, nlu1 in enumerate(nlu):
254 | conds = []
255 | for i_wn in range(pr_wn[b]):
256 | conds1 = []
257 | conds1.append(pr_wc[b][i_wn])
258 | conds1.append(pr_wo[b][i_wn])
259 | merged_wv11 = merge_wv_t1_eng(pr_wv_str[b][i_wn], nlu[b])
260 | conds1.append(merged_wv11)
261 | conds.append(conds1)
262 |
263 | pr_sql_i1 = {'agg': pr_sa[b], 'sel': pr_sc[b], 'conds': conds}
264 | pr_sql_i.append(pr_sql_i1)
265 | return pr_sql_i
266 |
267 | def merge_wv_t1_eng(where_str_tokens, NLq):
268 | """
269 |     Almost copied from SQLNet.
270 |     The main purpose is to pad blank spaces while combining tokens.
271 | """
272 | nlq = NLq.lower()
273 | where_str_tokens = [tok.lower() for tok in where_str_tokens]
274 | alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$'
275 | special = {'-LRB-': '(',
276 | '-RRB-': ')',
277 | '-LSB-': '[',
278 | '-RSB-': ']',
279 | '``': '"',
280 | '\'\'': '"',
281 | }
282 |     # '--': '\u2013'} # this generates an error for test case 5661.
283 | ret = ''
284 | double_quote_appear = 0
285 | for raw_w_token in where_str_tokens:
286 |         # if '' (empty string) or None, continue
287 | if not raw_w_token:
288 | continue
289 |
290 | # Change the special characters
291 | w_token = special.get(raw_w_token, raw_w_token) # maybe necessary for some case?
292 |
293 | # check the double quote
294 | if w_token == '"':
295 | double_quote_appear = 1 - double_quote_appear
296 |
297 | # Check whether ret is empty. ret is selected where condition.
298 | if len(ret) == 0:
299 | pass
300 | # Check blank character.
301 | elif len(ret) > 0 and ret + ' ' + w_token in nlq:
302 | # Pad ' ' if ret + ' ' is part of nlq.
303 | ret = ret + ' '
304 |
305 | elif len(ret) > 0 and ret + w_token in nlq:
306 |             pass # already in good form. Later, ret + w_token will be performed.
307 |
308 | # Below for unnatural question I guess. Is it likely to appear?
309 | elif w_token == '"':
310 | if double_quote_appear:
311 |                 ret = ret + ' ' # pad a blank space before the token when the " is a closing quote;
312 |                 # for an opening quote, no blank space is added.
313 |
314 | elif w_token[0] not in alphabet:
315 |             pass # a token whose first character is not in alphabet does not get a preceding blank space.
316 |
317 |         # otherwise pad a blank space, unless the previous character is a special case.
318 | elif (ret[-1] not in ['(', '/', '\u2013', '#', '$', '&']) and (ret[-1] != '"' or not double_quote_appear):
319 | ret = ret + ' '
320 | ret = ret + w_token
321 |
322 | return ret.strip()
323 |
324 | def get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
325 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
326 | g_sql_i, pr_sql_i,
327 | mode):
328 |     """ usable only when g_wc was used to find pr_wv
329 | """
330 | cnt_sc = get_cnt_sc_list(g_sc, pr_sc)
331 | cnt_sa = get_cnt_sc_list(g_sa, pr_sa)
332 | cnt_wn = get_cnt_sc_list(g_wn, pr_wn)
333 | cnt_wc = get_cnt_wc_list(g_wc, pr_wc)
334 | cnt_wo = get_cnt_wo_list(g_wn, g_wc, g_wo, pr_wc, pr_wo, mode)
335 | if pr_wvi:
336 | cnt_wvi = get_cnt_wvi_list(g_wn, g_wc, g_wvi, pr_wvi, mode)
337 | else:
338 | cnt_wvi = [0]*len(cnt_sc)
339 |     cnt_wv = get_cnt_wv_list(g_wn, g_wc, g_sql_i, pr_sql_i, mode) # compare using the wv-str presented in the original data.
340 |
341 |
342 | return cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wvi, cnt_wv
343 |
344 | def get_cnt_sc_list(g_sc, pr_sc):
345 | cnt_list = []
346 | for b, g_sc1 in enumerate(g_sc):
347 | pr_sc1 = pr_sc[b]
348 | if pr_sc1 == g_sc1:
349 | cnt_list.append(1)
350 | else:
351 | cnt_list.append(0)
352 |
353 | return cnt_list
354 |
355 | def get_cnt_wc_list(g_wc, pr_wc):
356 | cnt_list= []
357 | for b, g_wc1 in enumerate(g_wc):
358 |
359 | pr_wc1 = pr_wc[b]
360 | pr_wn1 = len(pr_wc1)
361 | g_wn1 = len(g_wc1)
362 |
363 | if pr_wn1 != g_wn1:
364 | cnt_list.append(0)
365 | continue
366 | else:
367 | wc1 = array(g_wc1)
368 | wc1.sort()
369 |
370 | if array_equal(pr_wc1, wc1):
371 | cnt_list.append(1)
372 | else:
373 | cnt_list.append(0)
374 |
375 | return cnt_list
376 |
377 | def get_cnt_wo_list(g_wn, g_wc, g_wo, pr_wc, pr_wo, mode):
378 | """ pr's are all sorted as pr_wc are sorted in increasing order (in column idx)
379 | However, g's are not sorted.
380 | Sort g's in increasing order (in column idx)
381 | """
382 | cnt_list=[]
383 | for b, g_wo1 in enumerate(g_wo):
384 | g_wc1 = g_wc[b]
385 | pr_wc1 = pr_wc[b]
386 | pr_wo1 = pr_wo[b]
387 | pr_wn1 = len(pr_wo1)
388 | g_wn1 = g_wn[b]
389 |
390 | if g_wn1 != pr_wn1:
391 | cnt_list.append(0)
392 | continue
393 | else:
394 |             # Sort based on the wc sequence.
395 | if mode == 'test':
396 | idx = argsort(array(g_wc1))
397 |
398 | g_wo1_s = array(g_wo1)[idx]
399 | g_wo1_s = list(g_wo1_s)
400 | elif mode == 'train':
401 |                 # due to teacher forcing, no need to sort.
402 | g_wo1_s = g_wo1
403 | else:
404 | raise ValueError
405 |
406 | if type(pr_wo1) != list:
407 | raise TypeError
408 | if g_wo1_s == pr_wo1:
409 | cnt_list.append(1)
410 | else:
411 | cnt_list.append(0)
412 | return cnt_list
413 |
414 | def get_cnt_wvi_list(g_wn, g_wc, g_wvi, pr_wvi, mode):
415 |     """ usable only when g_wc was used to find pr_wv
416 | """
417 | cnt_list =[]
418 | for b, g_wvi1 in enumerate(g_wvi):
419 | g_wc1 = g_wc[b]
420 | pr_wvi1 = pr_wvi[b]
421 | pr_wn1 = len(pr_wvi1)
422 | g_wn1 = g_wn[b]
423 |
424 | # Now sorting.
425 |         # Sort based on the wc sequence.
426 | if mode == 'test':
427 | idx1 = argsort(array(g_wc1))
428 | elif mode == 'train':
429 | idx1 = list( range( g_wn1) )
430 | else:
431 | raise ValueError
432 |
433 | if g_wn1 != pr_wn1:
434 | cnt_list.append(0)
435 | continue
436 | else:
437 | flag = True
438 | for i_wn, idx11 in enumerate(idx1):
439 | g_wvi11 = g_wvi1[idx11]
440 | pr_wvi11 = pr_wvi1[i_wn]
441 | if g_wvi11 != pr_wvi11:
442 | flag = False
443 | # print(g_wv1, g_wv11)
444 | # print(pr_wv1, pr_wv11)
445 | # input('')
446 | break
447 | if flag:
448 | cnt_list.append(1)
449 | else:
450 | cnt_list.append(0)
451 |
452 | return cnt_list
453 |
454 | def get_cnt_wv_list(g_wn, g_wc, g_sql_i, pr_sql_i, mode):
455 |     """ usable only when g_wc was used to find pr_wv
456 | """
457 | cnt_list =[]
458 | for b, g_wc1 in enumerate(g_wc):
459 | pr_wn1 = len(pr_sql_i[b]["conds"])
460 | g_wn1 = g_wn[b]
461 |
462 | # Now sorting.
463 |         # Sort based on the wc sequence.
464 | if mode == 'test':
465 | idx1 = argsort(array(g_wc1))
466 | elif mode == 'train':
467 | idx1 = list( range( g_wn1) )
468 | else:
469 | raise ValueError
470 |
471 | if g_wn1 != pr_wn1:
472 | cnt_list.append(0)
473 | continue
474 | else:
475 | flag = True
476 | for i_wn, idx11 in enumerate(idx1):
477 | g_wvi_str11 = str(g_sql_i[b]["conds"][idx11][2]).lower()
478 | pr_wvi_str11 = str(pr_sql_i[b]["conds"][i_wn][2]).lower()
479 | # print(g_wvi_str11)
480 | # print(pr_wvi_str11)
481 | # print(g_wvi_str11==pr_wvi_str11)
482 | if g_wvi_str11 != pr_wvi_str11:
483 | flag = False
484 | # print(g_wv1, g_wv11)
485 | # print(pr_wv1, pr_wv11)
486 | # input('')
487 | break
488 | if flag:
489 | cnt_list.append(1)
490 | else:
491 | cnt_list.append(0)
492 |
493 | return cnt_list
494 |
495 | def get_cnt_lx_list(cnt_sc1, cnt_sa1, cnt_wn1, cnt_wc1, cnt_wo1, cnt_wv1):
496 | # all cnt are list here.
497 | cnt_list = []
498 | cnt_lx = 0
499 | for csc, csa, cwn, cwc, cwo, cwv in zip(cnt_sc1, cnt_sa1, cnt_wn1, cnt_wc1, cnt_wo1, cnt_wv1):
500 | if csc and csa and cwn and cwc and cwo and cwv:
501 | cnt_list.append(1)
502 | else:
503 | cnt_list.append(0)
504 |
505 | return cnt_list
506 |
507 | def get_cnt_x_list(engine, tb, g_sc, g_sa, g_sql_i, pr_sc, pr_sa, pr_sql_i):
508 | cnt_x1_list = []
509 | g_ans = []
510 | pr_ans = []
511 | for b in range(len(g_sc)):
512 | g_ans1 = engine.execute(tb[b]['id'], g_sc[b], g_sa[b], g_sql_i[b]['conds'])
513 | # print(f'cnt: {cnt}')
514 | # print(f"pr_sql_i: {pr_sql_i[b]['conds']}")
515 | try:
516 | pr_ans1 = engine.execute(tb[b]['id'], pr_sc[b], pr_sa[b], pr_sql_i[b]['conds'])
517 |
518 |             if bool(pr_ans1): # skip empty results, which arise from incorrectly generated sql
519 | if g_ans1 == pr_ans1:
520 | cnt_x1 = 1
521 | else:
522 | cnt_x1 = 0
523 | else:
524 | cnt_x1 = 0
525 | except:
526 | # type error etc... Execution-guided decoding may be used here.
527 | pr_ans1 = None
528 | cnt_x1 = 0
529 | cnt_x1_list.append(cnt_x1)
530 | g_ans.append(g_ans1)
531 | pr_ans.append(pr_ans1)
532 |
533 | return cnt_x1_list, g_ans, pr_ans
534 |
--------------------------------------------------------------------------------
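Taken together, the helpers above form the per-batch training path used in dev_function.py below: Loss_sw_se drives the gradient step, and pred_sw_se turns the same scores into discrete predictions. A minimal sketch with random scores, under illustrative shapes (B=2 questions, 4 headers, 6 question tokens); a CUDA device is assumed because the module hard-codes device = torch.device("cuda"):

    import torch
    from seq2sql_model_training_functions import Loss_sw_se, pred_sw_se

    B, n_h, mL, n_agg, n_op = 2, 4, 6, 6, 3
    dev = torch.device("cuda")
    s_sc = torch.randn(B, n_h, device=dev)       # select-column scores
    s_sa = torch.randn(B, n_agg, device=dev)     # select-aggregation scores
    s_wn = torch.randn(B, 5, device=dev)         # where-number scores (0..4 conditions)
    s_wc = torch.randn(B, n_h, device=dev)       # where-column scores
    s_wo = torch.randn(B, 4, n_op, device=dev)   # where-operator scores per condition slot
    s_wv = torch.randn(B, 4, mL, 2, device=dev)  # where-value start/end scores

    # Hypothetical ground truth: one where-condition per example.
    g_sc, g_sa, g_wn = [0, 1], [0, 0], [1, 1]
    g_wc, g_wo, g_wvi = [[2], [0]], [[0], [1]], [[[1, 2]], [[3, 3]]]

    loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                      g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)
    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
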
/dev_function.py:
--------------------------------------------------------------------------------
1 | from dbengine_sqlnet import DBEngine
2 |
3 | import os
4 | import seq2sql_model_training_functions
5 | import corenlp_local
6 | import load_data
7 | import roberta_training
8 | import infer_functions
9 | import torch
10 | from tqdm.notebook import tqdm
11 | import seq2sql_model_testing
12 |
13 |
14 | def train(seq2sql_model,roberta_model,model_optimizer,roberta_optimizer,roberta_tokenizer,roberta_config,path_wikisql,train_loader):
15 |
16 | roberta_model.train()
17 | seq2sql_model.train()
18 |
19 | results=[]
20 | average_loss = 0
21 |
22 | count_select_column = 0 # count the # of correct predictions of select column
23 |     count_select_agg = 0 # of selected aggregation
24 | count_where_number = 0 # of where number
25 | count_where_column = 0 # of where column
26 | count_where_operator = 0 # of where operator
27 | count_where_value = 0 # of where-value
28 | count_where_value_index = 0 # of where-value index (on question tokens)
29 | count_logical_form_acc = 0 # of logical form accuracy
30 | count_execution_acc = 0 # of execution accuracy
31 |
32 |
33 | # Engine for SQL querying.
34 |     engine = DBEngine(os.path.join(path_wikisql, "train.db"))
35 | count = 0 # count the # of examples
36 | for batch_index, batch in enumerate(tqdm(train_loader)):
37 | count += len(batch)
38 |
39 | # if batch_index > 2:
40 | # break
41 | # Get fields
42 |
43 | # nlu : natural language utterance
44 | # nlu_t: tokenized nlu
45 | # sql_i: canonical form of SQL query
46 | # sql_q: full SQL query text. Not used.
47 | # sql_t: tokenized SQL query
48 | # tb : table metadata. No row data needed
49 | # hs_t : tokenized headers. Not used.
50 | natural_lang_utterance, natural_lang_utterance_tokenized, sql_canonical, \
51 | _, _, table_metadata, _, headers = load_data.get_fields(batch)
52 |
53 |
54 | select_column_ground, select_agg_ground, where_number_ground, \
55 | where_column_ground, where_operator_ground, _ = roberta_training.get_ground_truth_values(sql_canonical)
56 |         # get ground-truth where-value index under the CoreNLP tokenization scheme. This is already done for the train set.
57 |
58 |
59 | natural_lang_embeddings, header_embeddings, question_token_length, header_token_length, header_count, \
60 | natural_lang_double_tokenized, punkt_to_roberta_token_indices, roberta_to_punkt_token_indices \
61 | = roberta_training.get_wemb_roberta(roberta_config, roberta_model, roberta_tokenizer,
62 | natural_lang_utterance_tokenized, headers,max_seq_length= 222,
63 | num_out_layers_n=2, num_out_layers_h=2)
64 | # natural_lang_embeddings: natural language embedding
65 | # header_embeddings: header embedding
66 | # question_token_length: token lengths of each question
67 | # header_token_length: header token lengths
68 | # header_count: the number of columns (headers) of the tables.
69 |
70 | where_value_index_ground_corenlp = corenlp_local.get_g_wvi_corenlp(batch)
71 | try:
72 | #
73 | where_value_index_ground = corenlp_local.get_g_wvi_bert_from_g_wvi_corenlp(punkt_to_roberta_token_indices, where_value_index_ground_corenlp)
74 | except:
75 | # Exception happens when where-condition is not found in natural_lang_double_tokenized.
76 | # In this case, that train example is not used.
77 |             # During test, such an example is considered wrongly answered.
78 | # e.g. train: 32.
79 | continue
80 |
81 | knowledge = []
82 | for k in batch:
83 | if "bertindex_knowledge" in k:
84 | knowledge.append(k["bertindex_knowledge"])
85 | else:
86 | knowledge.append(max(question_token_length)*[0])
87 |
88 | knowledge_header = []
89 | for k in batch:
90 | if "header_knowledge" in k:
91 | knowledge_header.append(k["header_knowledge"])
92 | else:
93 | knowledge_header.append(max(header_count) * [0])
94 |
95 | # score
96 |
97 | select_column_score, select_agg_score, where_number_score, where_column_score,\
98 | where_operator_score, where_value_score = seq2sql_model(natural_lang_embeddings, question_token_length, header_embeddings,
99 | header_token_length, header_count,
100 | g_sc=select_column_ground, g_sa=select_agg_ground,
101 | g_wn=where_number_ground, g_wc=where_column_ground,
102 | g_wo=where_operator_ground, g_wvi=where_value_index_ground,
103 | knowledge = knowledge,
104 | knowledge_header = knowledge_header)
105 |
106 | # Calculate loss & step
107 | loss = seq2sql_model_training_functions.Loss_sw_se(select_column_score, select_agg_score, where_number_score,
108 | where_column_score, where_operator_score, where_value_score,
109 | select_column_ground, select_agg_ground,
110 | where_number_ground, where_column_ground,
111 | where_operator_ground, where_value_index_ground)
112 |
113 |
114 | model_optimizer.zero_grad()
115 | if roberta_optimizer:
116 | roberta_optimizer.zero_grad()
117 | loss.backward()
118 | model_optimizer.step()
119 | if roberta_optimizer:
120 | roberta_optimizer.step()
121 |
122 |
123 | # Prediction
124 | select_column_predict, select_agg_predict, where_number_predict, \
125 | where_column_predict, where_operator_predict, where_val_index_predict = seq2sql_model_training_functions.pred_sw_se(
126 | select_column_score, select_agg_score, where_number_score,
127 | where_column_score, where_operator_score, where_value_score)
128 | where_value_string_predict, _ = seq2sql_model_training_functions.convert_pr_wvi_to_string(
129 | where_val_index_predict,
130 | natural_lang_utterance_tokenized, natural_lang_double_tokenized,
131 | roberta_to_punkt_token_indices, natural_lang_utterance)
132 |
133 |
134 | # Sort where_column_predict:
135 | # Sort where_column_predict when training the model as where_operator_predict and where_val_index_predict are predicted using ground-truth where-column (g_wc)
136 | # In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference.
137 | where_column_predict_sorted = seq2sql_model_training_functions.sort_pr_wc(where_column_predict, where_column_ground)
138 |
139 | sql_canonical_predict = seq2sql_model_training_functions.generate_sql_i(
140 | select_column_predict, select_agg_predict, where_number_predict,
141 | where_column_predict_sorted, where_operator_predict,
142 | where_value_string_predict, natural_lang_utterance)
143 |
144 |         # Calculate accuracy
145 | select_col_batchlist, select_agg_batchlist, where_number_batchlist, \
146 | where_column_batchlist, where_operator_batchlist, where_value_index_batchlist, \
147 | where_value_batchlist = seq2sql_model_training_functions.get_cnt_sw_list(
148 | select_column_ground, select_agg_ground,
149 | where_number_ground, where_column_ground,
150 | where_operator_ground, where_value_index_ground,
151 | select_column_predict, select_agg_predict, where_number_predict,
152 | where_column_predict, where_operator_predict, where_val_index_predict,
153 | sql_canonical, sql_canonical_predict,
154 | mode='train')
155 |
156 | logical_form_acc_batchlist = seq2sql_model_training_functions.get_cnt_lx_list(
157 | select_col_batchlist, select_agg_batchlist, where_number_batchlist,
158 | where_column_batchlist,where_operator_batchlist, where_value_batchlist)
159 | # lx stands for logical form accuracy
160 | # Execution accuracy test.
161 | execution_acc_batchlist, _, _ = seq2sql_model_training_functions.get_cnt_x_list(
162 | engine, table_metadata, select_column_ground, select_agg_ground,
163 | sql_canonical, select_column_predict, select_agg_predict, sql_canonical_predict)
164 | # statistics
165 | average_loss += loss.item()
166 |
167 | # count
168 | count_select_column += sum(select_col_batchlist)
169 | count_select_agg += sum(select_agg_batchlist)
170 | count_where_number += sum(where_number_batchlist)
171 | count_where_column += sum(where_column_batchlist)
172 | count_where_operator += sum(where_operator_batchlist)
173 | count_where_value_index += sum(where_value_index_batchlist)
174 | count_where_value += sum(where_value_batchlist)
175 | count_logical_form_acc += sum(logical_form_acc_batchlist)
176 | count_execution_acc += sum(execution_acc_batchlist)
177 |
178 | average_loss /= count
179 | select_column_acc = count_select_column / count
180 | select_agg_acc = count_select_agg / count
181 | where_number_acc = count_where_number / count
182 | where_column_acc = count_where_column / count
183 | where_operator_acc = count_where_operator / count
184 | where_value_index_acc = count_where_value_index / count
185 | where_value_acc = count_where_value / count
186 | logical_form_acc = count_logical_form_acc / count
187 | execution_acc = count_execution_acc / count
188 | accuracy = [average_loss, select_column_acc, select_agg_acc, where_number_acc, where_column_acc,
189 | where_operator_acc, where_value_index_acc, where_value_acc, logical_form_acc, execution_acc]
190 |
191 | return accuracy
192 |
193 |
194 | def test(seq2sql_model,roberta_model,model_optimizer,roberta_tokenizer,roberta_config,path_wikisql,test_loader,mode="dev"):
195 |
196 | roberta_model.eval()
197 | seq2sql_model.eval()
198 |
199 | count_batchlist=[]
200 | results=[]
201 |
202 |
203 | count_select_column = 0 # count the # of correct predictions of select column
204 |     count_select_agg = 0 # of selected aggregation
205 | count_where_number = 0 # of where number
206 | count_where_column = 0 # of where column
207 | count_where_operator = 0 # of where operator
208 | count_where_value = 0 # of where-value
209 | count_where_value_index = 0 # of where-value index (on question tokens)
210 | count_logical_form_acc = 0 # of logical form accuracy
211 |     count_execution_acc = 0 # of execution accuracy
212 |
213 |
214 | # Engine for SQL querying.
215 | engine = DBEngine(os.path.join(path_wikisql, mode+".db"))
216 |
217 | count = 0
218 | for batch_index, batch in enumerate(tqdm(test_loader)):
219 | count += len(batch)
220 |
221 | # if batch_index > 2:
222 | # break
223 | # Get fields
224 | natural_lang_utterance, natural_lang_utterance_tokenized, sql_canonical, \
225 | _, _, table_metadata, _, headers = load_data.get_fields(batch)
226 |
227 |
228 | select_column_ground, select_agg_ground, where_number_ground, \
229 | where_column_ground, where_operator_ground, _ = roberta_training.get_ground_truth_values(sql_canonical)
230 |         # get ground-truth where-value index under the CoreNLP tokenization scheme. This is already done for the train set.
231 |
232 |
233 | natural_lang_embeddings, header_embeddings, question_token_length, header_token_length, header_count, \
234 | natural_lang_double_tokenized, punkt_to_roberta_token_indices, roberta_to_punkt_token_indices \
235 | = roberta_training.get_wemb_roberta(roberta_config, roberta_model, roberta_tokenizer,
236 | natural_lang_utterance_tokenized, headers,max_seq_length= 222,
237 | num_out_layers_n=2, num_out_layers_h=2)
238 | # natural_lang_embeddings: natural language embedding
239 | # header_embeddings: header embedding
240 | # question_token_length: token lengths of each question
241 | # header_token_length: header token lengths
242 | # header_count: the number of columns (headers) of the tables.
243 |
244 | where_value_index_ground_corenlp = corenlp_local.get_g_wvi_corenlp(batch)
245 | try:
246 | #
247 | where_value_index_ground = corenlp_local.get_g_wvi_bert_from_g_wvi_corenlp(punkt_to_roberta_token_indices, where_value_index_ground_corenlp)
248 | except:
249 |             # Exception happens when a where-condition is not found in natural_lang_double_tokenized.
250 |             # During training, such an example is simply not used.
251 |             # During test, the example is considered wrongly answered.
252 | # e.g. train: 32.
253 | for b in range(len(natural_lang_utterance)):
254 | curr_results = {}
255 | curr_results["error"] = "Skip happened"
256 | curr_results["nlu"] = natural_lang_utterance[b]
257 | curr_results["table_id"] = table_metadata[b]["id"]
258 | results.append(curr_results)
259 | continue
260 |
261 |
262 | knowledge = []
263 | for k in batch:
264 | if "bertindex_knowledge" in k:
265 | knowledge.append(k["bertindex_knowledge"])
266 | else:
267 | knowledge.append(max(question_token_length)*[0])
268 |
269 | knowledge_header = []
270 | for k in batch:
271 | if "header_knowledge" in k:
272 | knowledge_header.append(k["header_knowledge"])
273 | else:
274 | knowledge_header.append(max(header_count) * [0])
275 |
276 |
277 |
278 | # score
279 | _, _, _, select_column_predict, select_agg_predict, where_number_predict, sql_predict = seq2sql_model.beam_forward(
280 | natural_lang_embeddings, question_token_length, header_embeddings,
281 | header_token_length, header_count, table_metadata,
282 | natural_lang_utterance_tokenized, natural_lang_double_tokenized,
283 | roberta_to_punkt_token_indices, natural_lang_utterance,
284 | beam_size=4, knowledge=knowledge, knowledge_header=knowledge_header)
285 |
286 | # sort and generate
287 | where_column_predict, where_operator_predict, _, sql_predict = infer_functions.sort_and_generate_pr_w(sql_predict)
288 |
289 |         # Following variables are just for consistency with the no-EG case.
290 | where_value_index_predict = None # not used
291 |
292 | for b, sql_predict_instance in enumerate(sql_predict):
293 | curr_results = {}
294 | curr_results["query"] = sql_predict_instance
295 | curr_results["table_id"] = table_metadata[b]["id"]
296 | curr_results["nlu"] = natural_lang_utterance[b]
297 | results.append(curr_results)
298 |
299 |         # Calculate accuracy
300 | select_column_batchlist, select_agg_batchlist, where_number_batchlist, \
301 | where_column_batchlist, where_operator_batchlist, \
302 | where_value_index_batchlist, where_value_batchlist = seq2sql_model_training_functions.get_cnt_sw_list(
303 | select_column_ground, select_agg_ground, where_number_ground,
304 | where_column_ground, where_operator_ground, where_value_index_ground,
305 | select_column_predict, select_agg_predict, where_number_predict, where_column_predict,
306 | where_operator_predict, where_value_index_predict,
307 | sql_canonical, sql_predict,
308 | mode='test')
309 |
310 | logical_form_acc_batchlist = seq2sql_model_training_functions.get_cnt_lx_list(select_column_batchlist, select_agg_batchlist, where_number_batchlist, where_column_batchlist,
311 | where_operator_batchlist, where_value_batchlist)
312 | # lx stands for logical form accuracy
313 |
314 | # Execution accuracy test.
315 | execution_acc_batchlist, _, _ = seq2sql_model_training_functions.get_cnt_x_list(
316 | engine, table_metadata, select_column_ground, select_agg_ground, sql_canonical, select_column_predict, select_agg_predict, sql_predict)
317 |
318 | # statistics
319 | # ave_loss += loss.item()
320 |
321 | # count
322 | count_select_column += sum(select_column_batchlist)
323 | count_select_agg += sum(select_agg_batchlist)
324 | count_where_number += sum(where_number_batchlist)
325 | count_where_column += sum(where_column_batchlist)
326 | count_where_operator += sum(where_operator_batchlist)
327 | count_where_value_index += sum(where_value_index_batchlist)
328 | count_where_value += sum(where_value_batchlist)
329 | count_logical_form_acc += sum(logical_form_acc_batchlist)
330 | count_execution_acc += sum(execution_acc_batchlist)
331 |
332 | count_curr_batchlist = [select_column_batchlist, select_agg_batchlist, where_number_batchlist, where_column_batchlist, where_operator_batchlist, where_value_batchlist, logical_form_acc_batchlist,execution_acc_batchlist]
333 | count_batchlist.append(count_curr_batchlist)
334 |
335 | # ave_loss /= cnt
336 | select_column_acc = count_select_column / count
337 | select_agg_acc = count_select_agg / count
338 | where_number_acc = count_where_number / count
339 | where_column_acc = count_where_column / count
340 | where_operator_acc = count_where_operator / count
341 | where_value_index_acc = count_where_value_index / count
342 | where_value_acc = count_where_value / count
343 | logical_form_acc = count_logical_form_acc / count
344 | execution_acc = count_execution_acc / count
345 |
346 | accuracy = [None, select_column_acc, select_agg_acc, where_number_acc,
347 | where_column_acc, where_operator_acc, where_value_index_acc,
348 | where_value_acc, logical_form_acc, execution_acc]
349 |
350 | return accuracy, results, count_batchlist
351 |
--------------------------------------------------------------------------------
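The predictor classes in seq2sql_model_classes.py below augment their LSTM encodings with a one-hot "knowledge" feature built via scatter_ (WOP opts out by setting its knowledge dims to 0). A self-contained sketch of that pattern, with illustrative dimensions (the classes themselves use question_knowledge_dim=5 and header_knowledge_dim=3):

    import torch

    bS, mL_n, knowledge_dim = 2, 4, 5              # batch size, padded question length, feature size
    knowledge = [[1, 0, 2], [3, 0, 0, 1]]          # per-token knowledge labels, ragged across the batch
    # Pad each list to the max question length, exactly as the forward() methods do.
    knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge]
    index = torch.tensor(knowledge).unsqueeze(-1)  # [bS, mL_n, 1]
    # One-hot encode: feature[b, t, knowledge[b][t]] = 1.0
    feature = torch.zeros(bS, mL_n, knowledge_dim).scatter_(dim=-1, index=index, value=1)
    print(feature.shape)                           # torch.Size([2, 4, 5])
    # The classes then concatenate this feature onto the encodings:
    # wenc_n = torch.cat([wenc_n, feature], -1)
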
/seq2sql_model_classes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from copy import deepcopy
4 | from matplotlib.pylab import *
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from seq2sql_model_internal_functions import *
9 | #import torch_xla
10 | #import torch_xla.core.xla_model as xm
11 |
12 | device = torch.device("cuda")
13 |
14 |
15 | class Seq2SQL_v1(nn.Module):
16 | def __init__(self, iS, hS, lS, dr, n_cond_ops, n_agg_ops, old=False):
17 | super(Seq2SQL_v1, self).__init__()
18 | self.iS = iS
19 | self.hS = hS
20 |         self.lS = lS
21 | self.dr = dr
22 |
23 | self.max_wn = 4
24 | self.n_cond_ops = n_cond_ops
25 | self.n_agg_ops = n_agg_ops
26 |
27 | self.scp = SCP(iS, hS, lS, dr)
28 | self.sap = SAP(iS, hS, lS, dr, n_agg_ops, old=old)
29 | self.wnp = WNP(iS, hS, lS, dr)
30 | self.wcp = WCP(iS, hS, lS, dr)
31 | self.wop = WOP(iS, hS, lS, dr, n_cond_ops)
32 | # start-end-search-discriminative model
33 | self.wvp = WVP_se(iS, hS, lS, dr, n_cond_ops, old=old)
34 |
35 | def forward(self, wemb_n, l_n, wemb_h, l_hpu, l_hs,
36 | g_sc=None, g_sa=None, g_wn=None, g_wc=None, g_wo=None, g_wvi=None,
37 | show_p_sc=False, show_p_sa=False,
38 | show_p_wn=False, show_p_wc=False, show_p_wo=False, show_p_wv=False,
39 | knowledge=None,
40 | knowledge_header=None):
41 | # sc
42 | s_sc = self.scp(wemb_n, l_n, wemb_h, l_hpu, l_hs, show_p_sc=show_p_sc,
43 | knowledge=knowledge, knowledge_header=knowledge_header)
44 |
45 | if g_sc:
46 | pr_sc = g_sc
47 | else:
48 | pr_sc = pred_sc(s_sc)
49 | # sa
50 | s_sa = self.sap(wemb_n, l_n, wemb_h, l_hpu, l_hs, pr_sc, show_p_sa=show_p_sa,
51 | knowledge=knowledge, knowledge_header=knowledge_header)
52 | if g_sa:
53 | # it's not necessary though.
54 | pr_sa = g_sa
55 | else:
56 | pr_sa = pred_sa(s_sa)
57 | # wn
58 | s_wn = self.wnp(wemb_n, l_n, wemb_h, l_hpu, l_hs, show_p_wn=show_p_wn,
59 | knowledge=knowledge, knowledge_header=knowledge_header)
60 |
61 | if g_wn:
62 | pr_wn = g_wn
63 | else:
64 | pr_wn = pred_wn(s_wn)
65 | # wc
66 | s_wc = self.wcp(wemb_n, l_n, wemb_h, l_hpu, l_hs, show_p_wc=show_p_wc, penalty=True, predict_select_column=pr_sc,
67 | knowledge=knowledge, knowledge_header=knowledge_header)
68 |
69 | if g_wc:
70 | pr_wc = g_wc
71 | else:
72 | pr_wc = pred_wc(pr_wn, s_wc)
73 | # for b, columns in enumerate(pr_wc):
74 | # for c in columns:
75 | # s_sc[b, c] = -1e+10
76 |
77 | # wo
78 | s_wo = self.wop(wemb_n, l_n, wemb_h, l_hpu, l_hs, wn=pr_wn, wc=pr_wc, show_p_wo=show_p_wo,
79 | knowledge=knowledge, knowledge_header=knowledge_header)
80 |
81 | if g_wo:
82 | pr_wo = g_wo
83 | else:
84 | pr_wo = pred_wo(pr_wn, s_wo)
85 | # wv
86 | s_wv = self.wvp(wemb_n, l_n, wemb_h, l_hpu, l_hs, wn=pr_wn, wc=pr_wc, wo=pr_wo, show_p_wv=show_p_wv,
87 | knowledge=knowledge, knowledge_header=knowledge_header)
88 | return s_sc, s_sa, s_wn, s_wc, s_wo, s_wv
89 |
90 | def beam_forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, tb,
91 | nlu_t, nlu_wp_t, wp_to_wh_index, nlu,
92 | beam_size=4,
93 | show_p_sc=False, show_p_sa=False,
94 | show_p_wn=False, show_p_wc=False, show_p_wo=False, show_p_wv=False,
95 | knowledge=None,
96 | knowledge_header=None):
97 | """
98 | Execution-guided beam decoding.
99 | """
100 | # s_sc = [batch_size, header_len]
101 | s_sc = self.scp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_sc=show_p_sc,
102 | knowledge=knowledge, knowledge_header=knowledge_header)
103 | prob_sc = F.softmax(s_sc, dim=-1)
104 | bS, mcL = s_sc.shape
105 |
106 | # minimum_hs_length = min(l_hs)
107 | # beam_size = minimum_hs_length if beam_size > minimum_hs_length else beam_size
108 |
109 | # sa
110 | # Construct all possible sc_sa_score
111 | prob_sc_sa = torch.zeros([bS, beam_size, self.n_agg_ops]).to(device)
112 | prob_sca = torch.zeros_like(prob_sc_sa).to(device)
113 |
114 | # get the top-k indices. pr_sc_beam = [B, beam_size]
115 | pr_sc_beam = pred_sc_beam(s_sc, beam_size)
116 |
117 | # calculate and predict s_sa.
118 | for i_beam in range(beam_size):
119 | pr_sc = list(array(pr_sc_beam)[:, i_beam]) # pr_sc = [batch_size]
120 | s_sa = self.sap(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, pr_sc, show_p_sa=show_p_sa,
121 | knowledge=knowledge, knowledge_header=knowledge_header)
122 | prob_sa = F.softmax(s_sa, dim=-1)
123 | prob_sc_sa[:, i_beam, :] = prob_sa
124 |
125 | prob_sc_selected = prob_sc[range(bS), pr_sc] # [B]
126 | prob_sca[:, i_beam, :] = (prob_sa.t() * prob_sc_selected).t()
127 | # [mcL, B] * [B] -> [mcL, B] (element-wise multiplication)
128 | # [mcL, B] -> [B, mcL]
129 |
130 | # Calculate the dimension of tensor
131 | # tot_dim = len(prob_sca.shape)
132 |
133 | # First flatten to 1-d
134 | idxs = topk_multi_dim(torch.tensor(prob_sca),
135 | n_topk=beam_size, batch_exist=True)
136 | # Now as sc_idx is already sorted, re-map them properly.
137 |
138 | # [sc_beam_idx, sa_idx] -> [sc_idx, sa_idx]
139 | idxs = remap_sc_idx(idxs, pr_sc_beam)
140 | idxs_arr = array(idxs)
141 |         # [B, beam_size, remaining dim]
142 |         # idxs[b][0] gives the most probable [sc_idx, sa_idx] pair.
143 |         # idxs[b][1] gives the second most probable pair.
144 |
145 | # Calculate prob_sca, a joint probability
146 | beam_idx_sca = [0] * bS
147 | beam_meet_the_final = [False] * bS
148 | while True:
149 | pr_sc = idxs_arr[range(bS), beam_idx_sca, 0]
150 | pr_sa = idxs_arr[range(bS), beam_idx_sca, 1]
151 |
152 | # map index properly
153 |
154 | check = check_sc_sa_pairs(tb, pr_sc, pr_sa)
155 |
156 | if sum(check) == bS:
157 | break
158 | else:
159 | for b, check1 in enumerate(check):
160 | if not check1: # wrong pair
161 | beam_idx_sca[b] += 1
162 | if beam_idx_sca[b] >= beam_size:
163 | beam_meet_the_final[b] = True
164 | beam_idx_sca[b] -= 1
165 | else:
166 | beam_meet_the_final[b] = True
167 |
168 | if sum(beam_meet_the_final) == bS:
169 | break
170 |
171 | # Now pr_sc, pr_sa are properly predicted.
172 | pr_sc_best = list(pr_sc)
173 | pr_sa_best = list(pr_sa)
174 |
175 | # Now, Where-clause beam search.
176 | s_wn = self.wnp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wn=show_p_wn,
177 | knowledge=knowledge, knowledge_header=knowledge_header)
178 | prob_wn = F.softmax(s_wn, dim=-1).detach().to('cpu').numpy()
179 |
180 |         # Find the most likely "executable" where-clauses (up to 4 = max number of conditions).
181 | # wc
182 | s_wc = self.wcp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wc=show_p_wc, penalty=True,
183 | knowledge=knowledge, knowledge_header=knowledge_header)
184 | prob_wc = F.sigmoid(s_wc).detach().to('cpu').numpy()
185 | # pr_wc_sorted_by_prob = pred_wc_sorted_by_prob(s_wc)
186 |
187 | # get max_wn # of most probable columns & their prob.
188 | pr_wn_max = [self.max_wn]*bS
189 |         # if some column does not have an executable where-clause, omit that column
190 | pr_wc_max = pred_wc(pr_wn_max, s_wc)
191 | prob_wc_max = zeros([bS, self.max_wn])
192 | for b, pr_wc_max1 in enumerate(pr_wc_max):
193 | prob_wc_max[b, :] = prob_wc[b, pr_wc_max1]
194 |
195 |         # get the most probable max_wn where-clauses
196 | # wo
197 | s_wo_max = self.wop(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn=pr_wn_max, wc=pr_wc_max, show_p_wo=show_p_wo,
198 | knowledge=knowledge, knowledge_header=knowledge_header)
199 | prob_wo_max = F.softmax(s_wo_max, dim=-1).detach().to('cpu').numpy()
200 | # [B, max_wn, n_cond_op]
201 |
202 | pr_wvi_beam_op_list = []
203 | prob_wvi_beam_op_list = []
204 | for i_op in range(self.n_cond_ops-1):
205 | pr_wo_temp = [[i_op]*self.max_wn]*bS
206 | # wv
207 | s_wv = self.wvp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn=pr_wn_max, wc=pr_wc_max, wo=pr_wo_temp, show_p_wv=show_p_wv,
208 | knowledge=knowledge, knowledge_header=knowledge_header)
209 | prob_wv = F.softmax(s_wv, dim=-2).detach().to('cpu').numpy()
210 |
211 | # prob_wv
212 | pr_wvi_beam, prob_wvi_beam = pred_wvi_se_beam(
213 | self.max_wn, s_wv, beam_size)
214 | pr_wvi_beam_op_list.append(pr_wvi_beam)
215 | prob_wvi_beam_op_list.append(prob_wvi_beam)
216 |             # pr_wvi_beam = [B, max_wn, k_logit**2 [st, ed] pairs]
217 |
218 | # pred_wv_beam
219 |
220 | # Calculate joint probability of where-clause
221 | # prob_w = [batch, wc, wo, wv] = [B, max_wn, n_cond_op, n_pairs]
222 | n_wv_beam_pairs = prob_wvi_beam.shape[2]
223 | prob_w = zeros([bS, self.max_wn, self.n_cond_ops-1, n_wv_beam_pairs])
224 | for b in range(bS):
225 | for i_wn in range(self.max_wn):
226 | for i_op in range(self.n_cond_ops-1): # do not use final one
227 | for i_wv_beam in range(n_wv_beam_pairs):
228 | # i_wc = pr_wc_max[b][i_wn] # already done
229 | p_wc = prob_wc_max[b, i_wn]
230 | p_wo = prob_wo_max[b, i_wn, i_op]
231 | p_wv = prob_wvi_beam_op_list[i_op][b, i_wn, i_wv_beam]
232 |
233 | prob_w[b, i_wn, i_op, i_wv_beam] = p_wc * p_wo * p_wv
234 |
235 | # Perform execution guided decoding
236 | conds_max = []
237 | prob_conds_max = []
238 | # while len(conds_max) < self.max_wn:
239 | idxs = topk_multi_dim(torch.tensor(
240 | prob_w), n_topk=beam_size, batch_exist=True)
241 | # idxs = [B, i_wc_beam, i_op, i_wv_pairs]
242 |
243 | # Construct conds1
244 | for b, idxs1 in enumerate(idxs):
245 | conds_max1 = []
246 | prob_conds_max1 = []
247 | for i_wn, idxs11 in enumerate(idxs1):
248 | i_wc = pr_wc_max[b][idxs11[0]]
249 | i_op = idxs11[1]
250 | wvi = pr_wvi_beam_op_list[i_op][b][idxs11[0]][idxs11[2]]
251 |
252 | # get wv_str
253 | temp_pr_wv_str, _ = convert_pr_wvi_to_string(
254 | [[wvi]], [nlu_t[b]], [nlu_wp_t[b]], [wp_to_wh_index[b]], [nlu[b]])
255 | merged_wv11 = merge_wv_t1_eng(temp_pr_wv_str[0][0], nlu[b])
256 | conds11 = [i_wc, i_op, merged_wv11]
257 |
258 | prob_conds11 = prob_w[b, idxs11[0], idxs11[1], idxs11[2]]
259 |
260 | # test execution
261 | # print(nlu[b])
262 | # print(tb[b]['id'], tb[b]['types'], pr_sc[b], pr_sa[b], [conds11])
263 | # pr_ans = engine.execute(
264 | # tb[b]['id'], pr_sc[b], pr_sa[b], [conds11])
265 | # if bool(pr_ans):
266 | # # pr_ans is not empty!
267 | conds_max1.append(conds11)
268 | prob_conds_max1.append(prob_conds11)
269 | conds_max.append(conds_max1)
270 | prob_conds_max.append(prob_conds_max1)
271 |
272 |         # May need to do a more exhaustive search,
273 |         # i.e. up to getting all executable cases.
274 |
275 | # Calculate total probability to decide the number of where-clauses
276 | pr_sql_i = []
277 | prob_wn_w = []
278 | pr_wn_based_on_prob = []
279 |
280 | for b, prob_wn1 in enumerate(prob_wn):
281 | max_executable_wn1 = len(conds_max[b])
282 | prob_wn_w1 = []
283 | prob_wn_w1.append(prob_wn1[0]) # wn=0 case.
284 | for i_wn in range(max_executable_wn1):
285 | prob_wn_w11 = prob_wn1[i_wn+1] * prob_conds_max[b][i_wn]
286 | prob_wn_w1.append(prob_wn_w11)
287 | pr_wn_based_on_prob.append(argmax(prob_wn_w1))
288 | prob_wn_w.append(prob_wn_w1)
289 |
290 | pr_sql_i1 = {'agg': pr_sa_best[b], 'sel': pr_sc_best[b],
291 | 'conds': conds_max[b][:pr_wn_based_on_prob[b]]}
292 | pr_sql_i.append(pr_sql_i1)
293 | # s_wv = [B, max_wn, max_nlu_tokens, 2]
294 | return prob_sca, prob_w, prob_wn_w, pr_sc_best, pr_sa_best, pr_wn_based_on_prob, pr_sql_i
295 |
296 |
297 | class SCP(nn.Module):
298 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3):
299 | super(SCP, self).__init__()
300 | self.iS = iS
301 | self.hS = hS
302 | self.lS = lS
303 | self.dr = dr
304 |
305 | self.question_knowledge_dim = 5
306 | self.header_knowledge_dim = 3
307 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
308 | num_layers=lS, batch_first=True,
309 | dropout=dr, bidirectional=True)
310 |
311 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
312 | num_layers=lS, batch_first=True,
313 | dropout=dr, bidirectional=True)
314 |
315 | self.W_att = nn.Linear(
316 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim)
317 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS)
318 | self.W_hs = nn.Linear(hS+self.header_knowledge_dim, hS)
319 | self.sc_out = nn.Sequential(nn.Tanh(), nn.Linear(2 * hS, 1))
320 |
321 | self.softmax_dim1 = nn.Softmax(dim=1)
322 | self.softmax_dim2 = nn.Softmax(dim=2)
323 |
324 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_sc=False,
325 | knowledge=None,
326 | knowledge_header=None):
327 | # Encode
328 | mL_n = max(l_n)
329 | bS = len(l_hs)
330 | wenc_n = encode(self.enc_n, wemb_n, l_n,
331 | return_hidden=False,
332 | hc0=None,
333 | last_only=False) # [b, n, dim]
334 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge]
335 | knowledge = torch.tensor(knowledge).unsqueeze(-1)
336 |
337 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1,
338 | index=knowledge,
339 | value=1).to(device)
340 | wenc_n = torch.cat([wenc_n, feature], -1)
341 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim]
342 | knowledge_header = [k + (max(l_hs) - len(k)) * [0]
343 | for k in knowledge_header]
344 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1)
345 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1,
346 | index=knowledge_header,
347 | value=1).to(device)
348 | wenc_hs = torch.cat([wenc_hs, feature2], -1)
349 | bS = len(l_hs)
350 | mL_n = max(l_n)
351 |
352 | # [bS, mL_hs, 100] * [bS, 100, mL_n] -> [bS, mL_hs, mL_n]
353 | att_h = torch.bmm(wenc_hs, self.W_att(wenc_n).transpose(1, 2))
354 |
355 | # Penalty on blank parts
356 | for b, l_n1 in enumerate(l_n):
357 | if l_n1 < mL_n:
358 | att_h[b, :, l_n1:] = -10000000000
359 |
360 | p_n = self.softmax_dim2(att_h)
361 | if show_p_sc:
362 | # p = [b, hs, n]
363 | if p_n.shape[0] != 1:
364 | raise Exception("Batch size should be 1.")
365 | fig = figure(2001, figsize=(12, 3.5))
366 | # subplot(6,2,7)
367 | subplot2grid((7, 2), (3, 0), rowspan=2)
368 | cla()
369 | _color = 'rgbkcm'
370 | _symbol = '.......'
371 | for i_h in range(l_hs[0]):
372 | color_idx = i_h % len(_color)
373 | plot(p_n[0][i_h][:].data.numpy() - i_h, '--' +
374 | _symbol[color_idx]+_color[color_idx], ms=7)
375 |
376 | title('sc: p_n for each h')
377 | grid(True)
378 | fig.tight_layout()
379 | fig.canvas.draw()
380 | show()
381 |
382 | # p_n [ bS, mL_hs, mL_n] -> [ bS, mL_hs, mL_n, 1]
383 | # wenc_n [ bS, mL_n, 100] -> [ bS, 1, mL_n, 100]
384 | # -> [bS, mL_hs, mL_n, 100] -> [bS, mL_hs, 100]
385 | c_n = torch.mul(p_n.unsqueeze(3), wenc_n.unsqueeze(1)).sum(dim=2)
386 |
387 | vec = torch.cat([self.W_c(c_n), self.W_hs(wenc_hs)], dim=2)
388 | s_sc = self.sc_out(vec).squeeze(2) # [bS, mL_hs, 1] -> [bS, mL_hs]
389 |
390 | # Penalty
391 | mL_hs = max(l_hs)
392 | for b, l_hs1 in enumerate(l_hs):
393 | if l_hs1 < mL_hs:
394 | s_sc[b, l_hs1:] = -10000000000
395 |
396 | return s_sc
397 |
398 |
399 | class SAP(nn.Module):
400 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, n_agg_ops=-1, old=False):
401 | super(SAP, self).__init__()
402 | self.iS = iS
403 | self.hS = hS
404 | self.lS = lS
405 | self.dr = dr
406 |
407 | self.question_knowledge_dim = 5
408 | self.header_knowledge_dim = 3
409 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
410 | num_layers=lS, batch_first=True,
411 | dropout=dr, bidirectional=True)
412 |
413 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
414 | num_layers=lS, batch_first=True,
415 | dropout=dr, bidirectional=True)
416 |
417 | self.W_att = nn.Linear(
418 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim)
419 | self.sa_out = nn.Sequential(nn.Linear(hS + self.question_knowledge_dim, hS),
420 | nn.Tanh(),
421 | nn.Linear(hS, n_agg_ops)) # Fixed number of aggregation operator.
422 |
423 | self.softmax_dim1 = nn.Softmax(dim=1)
424 | self.softmax_dim2 = nn.Softmax(dim=2)
425 |
426 | if old:
427 |             # for backward compatibility
428 | self.W_c = nn.Linear(hS, hS)
429 | self.W_hs = nn.Linear(hS, hS)
430 |
431 | # wemb_hpu [batch_size*header_num, max_header_len, hidden_dim]
432 | # l_hpu [batch_size*header_num]
433 | # l_hs [batch_size]
434 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, pr_sc, show_p_sa=False,
435 | knowledge=None,
436 | knowledge_header=None):
437 | # Encode
438 | mL_n = max(l_n)
439 | bS = len(l_hs)
440 | wenc_n = encode(self.enc_n, wemb_n, l_n,
441 | return_hidden=False,
442 | hc0=None,
443 | last_only=False) # [b, n, dim]
444 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge]
445 | knowledge = torch.tensor(knowledge).unsqueeze(-1)
446 |
447 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1,
448 | index=knowledge,
449 | value=1).to(device)
450 | wenc_n = torch.cat([wenc_n, feature], -1)
451 |
452 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim]
453 | knowledge_header = [k + (max(l_hs) - len(k)) * [0]
454 | for k in knowledge_header]
455 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1)
456 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1,
457 | index=knowledge_header,
458 | value=1).to(device)
459 | wenc_hs = torch.cat([wenc_hs, feature2], -1)
460 | bS = len(l_hs)
461 | mL_n = max(l_n)
462 |
463 |         # pr_sc is a list: pick the selected column's encoding for each batch element.
464 | wenc_hs_ob = wenc_hs[list(range(bS)), pr_sc]
465 |
466 | # [bS, mL_n, 100] * [bS, 100, 1] -> [bS, mL_n]
467 | att = torch.bmm(self.W_att(wenc_n), wenc_hs_ob.unsqueeze(2)).squeeze(2)
468 |
469 | # Penalty on blank parts
470 | for b, l_n1 in enumerate(l_n):
471 | if l_n1 < mL_n:
472 | att[b, l_n1:] = -10000000000
473 | # [bS, mL_n]
474 | p = self.softmax_dim1(att)
475 |
476 | if show_p_sa:
477 | if p.shape[0] != 1:
478 | raise Exception("Batch size should be 1.")
479 | fig = figure(2001)
480 | subplot(7, 2, 3)
481 | cla()
482 | plot(p[0].data.numpy(), '--rs', ms=7)
483 | title('sa: nlu_weight')
484 | grid(True)
485 | fig.tight_layout()
486 | fig.canvas.draw()
487 | show()
488 |
489 | # [bS, mL_n, 100] * ( [bS, mL_n, 1] -> [bS, mL_n, 100])
490 | # -> [bS, mL_n, 100] -> [bS, 100]
491 | c_n = torch.mul(wenc_n, p.unsqueeze(2).expand_as(wenc_n)).sum(dim=1)
492 | s_sa = self.sa_out(c_n)
493 |
494 | return s_sa
495 |
496 |
497 | class WNP(nn.Module):
498 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, ):
499 | super(WNP, self).__init__()
500 | self.iS = iS
501 | self.hS = hS
502 | self.lS = lS
503 | self.dr = dr
504 |
505 | self.mL_w = 4 # max where condition number
506 | self.question_knowledge_dim = 5
507 | self.header_knowledge_dim = 3
508 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
509 | num_layers=lS, batch_first=True,
510 | dropout=dr, bidirectional=True)
511 |
512 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
513 | num_layers=lS, batch_first=True,
514 | dropout=dr, bidirectional=True)
515 |
516 | self.W_att_h = nn.Linear(hS + self.header_knowledge_dim, 1)
517 | self.W_hidden = nn.Linear(hS + self.header_knowledge_dim, lS * hS)
518 | self.W_cell = nn.Linear(hS + self.header_knowledge_dim, lS * hS)
519 |
520 | self.W_att_n = nn.Linear(hS + self.question_knowledge_dim, 1)
521 | self.wn_out = nn.Sequential(nn.Linear(hS + self.question_knowledge_dim, hS),
522 | nn.Tanh(),
523 | nn.Linear(hS, self.mL_w + 1)) # max number (4 + 1)
524 |
525 | self.softmax_dim1 = nn.Softmax(dim=1)
526 | self.softmax_dim2 = nn.Softmax(dim=2)
527 |
528 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wn=False,
529 | knowledge=None,
530 | knowledge_header=None):
531 | # Encode
532 | mL_n = max(l_n)
533 | bS = len(l_hs)
534 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu,
535 | l_hs) # [b, mL_hs, dim]
536 | knowledge_header = [k + (max(l_hs) - len(k)) * [0]
537 | for k in knowledge_header]
538 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1)
539 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1,
540 | index=knowledge_header,
541 | value=1).to(device)
542 | wenc_hs = torch.cat([wenc_hs, feature2], -1)
543 |
544 | bS = len(l_hs)
545 | mL_n = max(l_n)
546 | mL_hs = max(l_hs)
547 | # mL_h = max(l_hpu)
548 |
549 | # (self-attention?) column Embedding?
550 | # [B, mL_hs, 100] -> [B, mL_hs, 1] -> [B, mL_hs]
551 | att_h = self.W_att_h(wenc_hs).squeeze(2)
552 |
553 | # Penalty
554 | for b, l_hs1 in enumerate(l_hs):
555 | if l_hs1 < mL_hs:
556 | att_h[b, l_hs1:] = -10000000000
557 | p_h = self.softmax_dim1(att_h)
558 |
559 | if show_p_wn:
560 | if p_h.shape[0] != 1:
561 | raise Exception("Batch size should be 1.")
562 | fig = figure(2001)
563 | subplot(7, 2, 5)
564 | cla()
565 | plot(p_h[0].data.numpy(), '--rs', ms=7)
566 | title('wn: header_weight')
567 | grid(True)
568 | fig.canvas.draw()
569 | show()
570 |             # input('Type Enter to continue.')
571 |
572 | # [B, mL_hs, 100] * [ B, mL_hs, 1] -> [B, mL_hs, 100] -> [B, 100]
573 | c_hs = torch.mul(wenc_hs, p_h.unsqueeze(2)).sum(1)
574 |
575 | # [B, 100] --> [B, 2*100] Enlarge because there are two layers.
576 | hidden = self.W_hidden(c_hs) # [B, 4, 200/2]
577 | hidden = hidden.view(bS, self.lS * 2, int(
578 |             self.hS / 2)) # [B, 4, 100/2]; num_layers * num_directions (bi-directional), lstm input convention after transpose.
579 | hidden = hidden.transpose(0, 1).contiguous()
580 |
581 | cell = self.W_cell(c_hs) # [B, 4, 100/2]
582 | cell = cell.view(bS, self.lS * 2, int(self.hS / 2)) # [4, B, 100/2]
583 | cell = cell.transpose(0, 1).contiguous()
584 |
585 | wenc_n = encode(self.enc_n, wemb_n, l_n,
586 | return_hidden=False,
587 | hc0=(hidden, cell),
588 | last_only=False) # [b, n, dim]
589 |
590 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge]
591 | knowledge = torch.tensor(knowledge).unsqueeze(-1)
592 |
593 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1,
594 | index=knowledge,
595 | value=1).to(device)
596 | wenc_n = torch.cat([wenc_n, feature], -1)
597 |
598 | # [B, max_len, 100] -> [B, max_len, 1] -> [B, max_len]
599 | att_n = self.W_att_n(wenc_n).squeeze(2)
600 |
601 | # Penalty
602 | for b, l_n1 in enumerate(l_n):
603 | if l_n1 < mL_n:
604 | att_n[b, l_n1:] = -10000000000
605 | p_n = self.softmax_dim1(att_n)
606 |
607 | if show_p_wn:
608 | if p_n.shape[0] != 1:
609 | raise Exception("Batch size should be 1.")
610 | fig = figure(2001)
611 | subplot(7, 2, 6)
612 | cla()
613 | plot(p_n[0].data.numpy(), '--rs', ms=7)
614 | title('wn: nlu_weight')
615 | grid(True)
616 | fig.canvas.draw()
617 |
618 | show()
619 | # input('Type Enter to continue.')
620 |
621 | # [B, mL_n, 100] *([B, mL_n] -> [B, mL_n, 1] -> [B, mL_n, 100] ) -> [B, 100]
622 | c_n = torch.mul(wenc_n, p_n.unsqueeze(2).expand_as(wenc_n)).sum(dim=1)
623 | s_wn = self.wn_out(c_n)
624 |
625 | return s_wn
626 |
627 |
628 | class WCP(nn.Module):
629 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3):
630 | super(WCP, self).__init__()
631 | self.iS = iS
632 | self.hS = hS
633 | self.lS = lS
634 | self.dr = dr
635 | self.question_knowledge_dim = 5
636 | self.header_knowledge_dim = 3
637 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
638 | num_layers=lS, batch_first=True,
639 | dropout=dr, bidirectional=True)
640 |
641 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
642 | num_layers=lS, batch_first=True,
643 | dropout=dr, bidirectional=True)
644 |
645 | self.W_att = nn.Linear(
646 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim)
647 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS)
648 | self.W_hs = nn.Linear(hS + self.header_knowledge_dim, hS)
649 | self.W_out = nn.Sequential(
650 | nn.Tanh(), nn.Linear(2 * hS, 1)
651 | )
652 |
653 | self.softmax_dim1 = nn.Softmax(dim=1)
654 | self.softmax_dim2 = nn.Softmax(dim=2)
655 |
656 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wc, penalty=True, predict_select_column=None,
657 | knowledge=None,
658 | knowledge_header=None):
659 | # Encode
660 | mL_n = max(l_n)
661 | bS = len(l_hs)
662 | wenc_n = encode(self.enc_n, wemb_n, l_n,
663 | return_hidden=False,
664 | hc0=None,
665 | last_only=False) # [b, n, dim]
666 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge]
667 | knowledge = torch.tensor(knowledge).unsqueeze(-1)
668 |
669 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1,
670 | index=knowledge,
671 | value=1).to(device)
672 | wenc_n = torch.cat([wenc_n, feature], -1)
673 |
674 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim]
675 | knowledge_header = [k + (max(l_hs) - len(k)) * [0]
676 | for k in knowledge_header]
677 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1)
678 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1,
679 | index=knowledge_header,
680 | value=1).to(device)
681 | wenc_hs = torch.cat([wenc_hs, feature2], -1)
682 | # attention
683 | # wenc = [bS, mL, hS]
684 | # att = [bS, mL_hs, mL_n]
685 | # att[b, i_h, j_n] = p(j_n| i_h)
686 | att = torch.bmm(wenc_hs, self.W_att(wenc_n).transpose(1, 2))
687 |
688 | # penalty to blank part.
689 | mL_n = max(l_n)
690 | for b_n, l_n1 in enumerate(l_n):
691 | if l_n1 < mL_n:
692 | att[b_n, :, l_n1:] = -10000000000
693 |
694 | # for b, c in enumerate(predict_select_column):
695 | # att[b, c, :] = -10000000000
696 |
697 | # make p(j_n | i_h)
698 | p = self.softmax_dim2(att)
699 |
700 | if show_p_wc:
701 | # p = [b, hs, n]
702 | if p.shape[0] != 1:
703 | raise Exception("Batch size should be 1.")
704 | fig = figure(2001)
705 | # subplot(6,2,7)
706 | subplot2grid((7, 2), (3, 1), rowspan=2)
707 | cla()
708 | _color = 'rgbkcm'
709 | _symbol = '.......'
710 | for i_h in range(l_hs[0]):
711 | color_idx = i_h % len(_color)
712 | plot(p[0][i_h][:].data.numpy() - i_h, '--' +
713 | _symbol[color_idx]+_color[color_idx], ms=7)
714 |
715 | title('wc: p_n for each h')
716 | grid(True)
717 | fig.tight_layout()
718 | fig.canvas.draw()
719 | show()
720 | # max nlu context vectors
721 | # [bS, mL_hs, mL_n]*[bS, mL_hs, mL_n]
722 | wenc_n = wenc_n.unsqueeze(1) # [ b, n, dim] -> [b, 1, n, dim]
723 | p = p.unsqueeze(3) # [b, hs, n] -> [b, hs, n, 1]
724 | # -> [b, hs, dim], c_n for each header.
725 | c_n = torch.mul(wenc_n, p).sum(2)
726 |
727 | # bS = len(l_hs)
728 | # index = torch.tensor(predict_select_column).unsqueeze(-1)
729 | # feature = torch.zeros(bS, max(l_hs)).scatter_(dim=-1,
730 | # index=index,
731 | # value=1).to(device)
732 | # c_n = torch.cat([c_n, feature.unsqueeze(-1)],dim=-1)
733 |
734 | y = torch.cat([self.W_c(c_n), self.W_hs(wenc_hs)],
735 | dim=2) # [b, hs, 2*dim]
736 | score = self.W_out(y).squeeze(2) # [b, hs]
737 |
738 | if penalty:
739 | for b, l_hs1 in enumerate(l_hs):
740 | score[b, l_hs1:] = -1e+10
741 |
742 | # for b, c in enumerate(predict_select_column):
743 | # score[b, c] = -1e+10
744 |
745 | return score
746 |
747 |
748 | class WOP(nn.Module):
749 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, n_cond_ops=3):
750 | super(WOP, self).__init__()
751 | self.iS = iS
752 | self.hS = hS
753 | self.lS = lS
754 | self.dr = dr
755 | self.question_knowledge_dim = 0
756 | self.header_knowledge_dim = 0
757 | self.mL_w = 4 # max where condition number
758 |
759 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
760 | num_layers=lS, batch_first=True,
761 | dropout=dr, bidirectional=True)
762 |
763 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
764 | num_layers=lS, batch_first=True,
765 | dropout=dr, bidirectional=True)
766 |
767 | self.W_att = nn.Linear(
768 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim)
769 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS)
770 | self.W_hs = nn.Linear(hS + self.header_knowledge_dim, hS)
771 | self.wo_out = nn.Sequential(
772 | nn.Linear(2*hS, hS),
773 | nn.Tanh(),
774 | nn.Linear(hS, n_cond_ops)
775 | )
776 |
777 | self.softmax_dim1 = nn.Softmax(dim=1)
778 | self.softmax_dim2 = nn.Softmax(dim=2)
779 |
780 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn, wc, wenc_n=None, show_p_wo=False,
781 | knowledge=None,
782 | knowledge_header=None):
783 | # Encode
784 | mL_n = max(l_n)
785 | bS = len(l_hs)
786 | if not wenc_n:
787 | wenc_n = encode(self.enc_n, wemb_n, l_n,
788 | return_hidden=False,
789 | hc0=None,
790 | last_only=False) # [b, n, dim]
791 | if self.question_knowledge_dim != 0:
792 | knowledge = [k + (mL_n - len(k)) * [0] for k in knowledge]
793 | knowledge = torch.tensor(knowledge).unsqueeze(-1)
794 |
795 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1,
796 | index=knowledge,
797 | value=1).to(device)
798 | wenc_n = torch.cat([wenc_n, feature], -1)
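            # i.e. each token's integer knowledge tag is padded to mL_n, then
            # scattered into a one-hot feature of width question_knowledge_dim
            # and appended to that token's encoding.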
799 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim]
800 | if self.header_knowledge_dim != 0:
801 | knowledge_header = [k + (max(l_hs) - len(k)) * [0]
802 | for k in knowledge_header]
803 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1)
804 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1,
805 | index=knowledge_header,
806 | value=1).to(device)
807 | wenc_hs = torch.cat([wenc_hs, feature2], -1)
808 | bS = len(l_hs)
809 | # wn
810 |
811 | wenc_hs_ob = [] # observed hs
812 | for b in range(bS):
813 | # [[...], [...]]
814 | # Pad list to maximum number of selections
815 | real = [wenc_hs[b, col] for col in wc[b]]
816 | # this padding could be wrong. Test with zero padding later.
817 | pad = (self.mL_w - wn[b]) * [wenc_hs[b, 0]]
818 | # It is not used in the loss function.
819 | wenc_hs_ob1 = torch.stack(real + pad)
820 | wenc_hs_ob.append(wenc_hs_ob1)
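        # e.g. wc[b] = [3, 1] with wn[b] = 2 -> rows [h3, h1, h0, h0]: the
        # first header embedding is reused as padding for the empty slots.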
821 |
822 | # list to [B, 4, dim] tensor.
823 | wenc_hs_ob = torch.stack(wenc_hs_ob) # list to tensor.
824 | wenc_hs_ob = wenc_hs_ob.to(device)
825 |
826 | # [B, 1, mL_n, dim] * [B, 4, dim, 1]
827 | # -> [B, 4, mL_n, 1] -> [B, 4, mL_n]
828 |         # multiplication between NLq-tokens and selected column
829 | att = torch.matmul(self.W_att(wenc_n).unsqueeze(1),
830 | wenc_hs_ob.unsqueeze(3)
831 | ).squeeze(3)
832 | # Penalty for blank part.
833 | mL_n = max(l_n)
834 | for b, l_n1 in enumerate(l_n):
835 | if l_n1 < mL_n:
836 | att[b, :, l_n1:] = -10000000000
837 |
838 | p = self.softmax_dim2(att) # p( n| selected_col )
839 | if show_p_wo:
840 | # p = [b, hs, n]
841 | if p.shape[0] != 1:
842 | raise Exception("Batch size should be 1.")
843 | fig = figure(2001)
844 | # subplot(6,2,7)
845 | subplot2grid((7, 2), (5, 0), rowspan=2)
846 | cla()
847 | _color = 'rgbkcm'
848 | _symbol = '.......'
849 | for i_wn in range(self.mL_w):
850 | color_idx = i_wn % len(_color)
851 | plot(p[0][i_wn][:].data.numpy() - i_wn, '--' +
852 | _symbol[color_idx]+_color[color_idx], ms=7)
853 |
854 | title('wo: p_n for selected h')
855 | grid(True)
856 | fig.tight_layout()
857 | fig.canvas.draw()
858 | show()
859 |
860 | # [B, 1, mL_n, dim] * [B, 4, mL_n, 1]
861 | # --> [B, 4, mL_n, dim]
862 | # --> [B, 4, dim]
863 | c_n = torch.mul(wenc_n.unsqueeze(1), p.unsqueeze(3)).sum(dim=2)
864 |
865 |         # [bS, 4, 2*hS] -> [bS, 4, n_cond_ops]
866 |
867 | vec = torch.cat([self.W_c(c_n), self.W_hs(wenc_hs_ob)], dim=2)
868 | s_wo = self.wo_out(vec)
869 | return s_wo
870 |
871 |
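# Illustrative sketch (not part of the original model): WOP returns s_wo of
# shape [bS, 4, n_cond_ops], one operator distribution per padded where-clause
# slot. A plausible decoding, assuming the first wn[b] slots are the real
# conditions (the repo's own inference code may differ):

def _decode_where_ops_sketch(s_wo, wn):
    """s_wo: [bS, 4, n_cond_ops] logits; wn: #where-conditions per example."""
    pr_wo_all = s_wo.argmax(dim=2)  # [bS, 4] most likely operator per slot
    return [pr_wo_all[b, :wn[b]].tolist() for b in range(len(wn))]

# With wn = [2, 1], only the first 2 (resp. 1) slots are kept; the remaining
# slots exist only to give every example the fixed [bS, 4, ...] layout.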
872 | class WVP_se(nn.Module):
873 |     """
874 |     Discriminative model.
875 |     Predicts the start and end token of each where-value span.
876 |     Here, a classifier over tokens such as [ [pitcher], [team1], [team2], [year], ...]
877 |     Input: encoded nlu & selected column.
878 |     Algorithm: encoded nlu & selected column -> classifier -> masked scores -> ...
879 |     """
880 |
881 | def __init__(self, iS=300, hS=100, lS=2, dr=0.3, n_cond_ops=4, old=False):
882 | super(WVP_se, self).__init__()
883 | self.iS = iS
884 | self.hS = hS
885 | self.lS = lS
886 | self.dr = dr
887 | self.n_cond_ops = n_cond_ops
888 | self.question_knowledge_dim = 5
889 | self.header_knowledge_dim = 3
890 | self.mL_w = 4 # max where condition number
891 |
892 | self.enc_h = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
893 | num_layers=lS, batch_first=True,
894 | dropout=dr, bidirectional=True)
895 |
896 | self.enc_n = nn.LSTM(input_size=iS, hidden_size=int(hS / 2),
897 | num_layers=lS, batch_first=True,
898 | dropout=dr, bidirectional=True)
899 |
900 | self.W_att = nn.Linear(
901 | hS + self.question_knowledge_dim, hS + self.header_knowledge_dim)
902 | self.W_c = nn.Linear(hS + self.question_knowledge_dim, hS)
903 | self.W_hs = nn.Linear(hS + self.header_knowledge_dim, hS)
904 | self.W_op = nn.Linear(n_cond_ops, hS)
905 |
906 | # self.W_n = nn.Linear(hS, hS)
907 | if old:
908 | self.wv_out = nn.Sequential(
909 | nn.Linear(4 * hS, 2)
910 | )
911 | else:
912 | self.wv_out = nn.Sequential(
913 | nn.Linear(4 * hS + self.question_knowledge_dim, hS),
914 | nn.Tanh(),
915 | nn.Linear(hS, 2)
916 | )
917 | # self.wv_out = nn.Sequential(
918 | # nn.Linear(3 * hS, hS),
919 | # nn.Tanh(),
920 | # nn.Linear(hS, self.gdkL)
921 | # )
922 |
923 | self.softmax_dim1 = nn.Softmax(dim=1)
924 | self.softmax_dim2 = nn.Softmax(dim=2)
925 |
926 | def forward(self, wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn, wc, wo, wenc_n=None, show_p_wv=False,
927 | knowledge=None,
928 | knowledge_header=None):
929 | mL_n = max(l_n)
930 | bS = len(l_hs)
931 | # Encode
932 |         if wenc_n is None:  # truth-testing a Tensor is ambiguous; check for None
933 | wenc_n, hout, cout = encode(self.enc_n, wemb_n, l_n,
934 | return_hidden=True,
935 | hc0=None,
936 | last_only=False) # [b, n, dim]
937 |
938 | knowledge = [k+(mL_n-len(k))*[0] for k in knowledge]
939 | knowledge = torch.tensor(knowledge).unsqueeze(-1)
940 |
941 | feature = torch.zeros(bS, mL_n, self.question_knowledge_dim).scatter_(dim=-1,
942 | index=knowledge,
943 | value=1).to(device)
944 | wenc_n = torch.cat([wenc_n, feature], -1)
945 |
946 | wenc_hs = encode_hpu(self.enc_h, wemb_hpu, l_hpu, l_hs) # [b, hs, dim]
947 |
948 | knowledge_header = [k + (max(l_hs) - len(k)) * [0]
949 | for k in knowledge_header]
950 | knowledge_header = torch.tensor(knowledge_header).unsqueeze(-1)
951 | feature2 = torch.zeros(bS, max(l_hs), self.header_knowledge_dim).scatter_(dim=-1,
952 | index=knowledge_header,
953 | value=1).to(device)
954 | wenc_hs = torch.cat([wenc_hs, feature2], -1)
955 |
956 | wenc_hs_ob = [] # observed hs
957 |
958 | for b in range(bS):
959 | # [[...], [...]]
960 | # Pad list to maximum number of selections
961 | real = [wenc_hs[b, col] for col in wc[b]]
962 | # this padding could be wrong. Test with zero padding later.
963 | pad = (self.mL_w - wn[b]) * [wenc_hs[b, 0]]
964 | # It is not used in the loss function.
965 | wenc_hs_ob1 = torch.stack(real + pad)
966 | wenc_hs_ob.append(wenc_hs_ob1)
967 |
968 | # list to [B, 4, dim] tensor.
969 | wenc_hs_ob = torch.stack(wenc_hs_ob) # list to tensor.
970 | wenc_hs_ob = wenc_hs_ob.to(device)
971 |
972 | # Column attention
973 | # [B, 1, mL_n, dim] * [B, 4, dim, 1]
974 | # -> [B, 4, mL_n, 1] -> [B, 4, mL_n]
975 |         # multiplication between NLq-tokens and selected column
976 | att = torch.matmul(self.W_att(wenc_n).unsqueeze(1),
977 | wenc_hs_ob.unsqueeze(3)
978 | ).squeeze(3)
979 | # Penalty for blank part.
980 |
981 | for b, l_n1 in enumerate(l_n):
982 | if l_n1 < mL_n:
983 | att[b, :, l_n1:] = -10000000000
984 |
985 | p = self.softmax_dim2(att) # p( n| selected_col )
986 |
987 | if show_p_wv:
988 | # p = [b, hs, n]
989 | if p.shape[0] != 1:
990 | raise Exception("Batch size should be 1.")
991 | fig = figure(2001)
992 | # subplot(6,2,7)
993 | subplot2grid((7, 2), (5, 1), rowspan=2)
994 | cla()
995 | _color = 'rgbkcm'
996 | _symbol = '.......'
997 | for i_wn in range(self.mL_w):
998 | color_idx = i_wn % len(_color)
999 | plot(p[0][i_wn][:].data.numpy() - i_wn, '--' +
1000 | _symbol[color_idx]+_color[color_idx], ms=7)
1001 |
1002 | title('wv: p_n for selected h')
1003 | grid(True)
1004 | fig.tight_layout()
1005 | fig.canvas.draw()
1006 | show()
1007 |
1008 | # [B, 1, mL_n, dim] * [B, 4, mL_n, 1]
1009 | # --> [B, 4, mL_n, dim]
1010 | # --> [B, 4, dim]
1011 | c_n = torch.mul(wenc_n.unsqueeze(1), p.unsqueeze(3)).sum(dim=2)
1012 |
1013 | # Select observed headers only.
1014 | # Also generate one_hot vector encoding info of the operator
1015 | # [B, 4, dim]
1016 | wenc_op = []
1017 | for b in range(bS):
1018 | # [[...], [...]]
1019 | # Pad list to maximum number of selections
1020 | wenc_op1 = torch.zeros(self.mL_w, self.n_cond_ops)
1021 | wo1 = wo[b]
1022 | idx_scatter = []
1023 | l_wo1 = len(wo1)
1024 | for i_wo11 in range(self.mL_w):
1025 | if i_wo11 < l_wo1:
1026 | wo11 = wo1[i_wo11]
1027 | idx_scatter.append([int(wo11)])
1028 | else:
1029 | idx_scatter.append([0]) # not used anyway
1030 |
1031 | wenc_op1 = wenc_op1.scatter(1, torch.tensor(idx_scatter), 1)
1032 |
1033 | wenc_op.append(wenc_op1)
1034 |
1035 | # list to [B, 4, dim] tensor.
1036 | wenc_op = torch.stack(wenc_op) # list to tensor.
1037 | wenc_op = wenc_op.to(device)
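        # e.g. with n_cond_ops=4 and wo[b] = [0, 2], wenc_op[b] is
        # [[1,0,0,0], [0,0,1,0], [1,0,0,0], [1,0,0,0]]; the last two rows
        # are padding for unused slots and are ignored by the loss.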
1038 |
1039 | # Now after concat, calculate logits for each token
1040 |         # [bS, 4, 3*hS] = [bS, 4, 300]
1041 | vec = torch.cat([self.W_c(c_n), self.W_hs(
1042 | wenc_hs_ob), self.W_op(wenc_op)], dim=2)
1043 |
1044 | # Make extended vector based on encoded nl token containing column and operator information.
1045 |         # wenc_n = [bS, mL, hS + question_knowledge_dim]
1046 |         # vec2 = [bS, 4, mL, 4*hS + question_knowledge_dim]
1047 |         # [bS, 4, 1, 3*hS] -> [bS, 4, mL, 3*hS]
1048 | vec1e = vec.unsqueeze(2).expand(-1, -1, mL_n, -1)
1049 | # [bS, 1, mL, 100] -> [bS, 4, mL, 100]
1050 |         wenc_ne = wenc_n.unsqueeze(1).expand(-1, self.mL_w, -1, -1)
1051 | vec2 = torch.cat([vec1e, wenc_ne], dim=3)
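        # i.e. every token position now carries (column context, column
        # encoding, operator one-hot) plus its own encoding, so the 2-way
        # start/end logits below are computed per token.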
1052 |
1053 | # now make logits
1054 |         s_wv = self.wv_out(vec2)  # [bS, 4, mL, 4*hS + question_knowledge_dim] -> [bS, 4, mL, 2]
1055 |
1056 | # penalty for spurious tokens
1057 | for b, l_n1 in enumerate(l_n):
1058 | if l_n1 < mL_n:
1059 | s_wv[b, :, l_n1:, :] = -10000000000
1060 | return s_wv
1061 |
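# Illustrative sketch (not part of the original model): s_wv is
# [bS, 4, mL_n, 2]; assuming channel 0 scores each token as a span start and
# channel 1 as a span end, a minimal decoder looks like this (the repo's own
# inference code may differ):

def _decode_wv_spans_sketch(s_wv, wn):
    """s_wv: [bS, 4, mL_n, 2] logits; wn: #where-conditions per example."""
    st = s_wv[:, :, :, 0].argmax(dim=2)  # [bS, 4] best start index
    ed = s_wv[:, :, :, 1].argmax(dim=2)  # [bS, 4] best end index
    return [[(int(st[b, i]), int(ed[b, i])) for i in range(wn[b])]
            for b in range(len(wn))]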
--------------------------------------------------------------------------------