├── sqlova ├── __init__.py ├── model │ ├── __init__.py │ ├── nl2sql │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── wikisql_models.cpython-36.pyc │ └── __pycache__ │ │ └── __init__.cpython-36.pyc ├── utils │ ├── __init__.py │ ├── utils.py │ └── wikisql_formatter.py └── __pycache__ │ └── __init__.cpython-36.pyc ├── human_eval ├── sample.png └── README.md ├── run_make_table.sh ├── run_infer.sh ├── .gitignore ├── wikisql ├── __pycache__ │ └── __init__.cpython-36.pyc ├── lib │ ├── common.py │ ├── dbengine.py │ ├── table.py │ └── query.py ├── LICENSE_WikiSQL ├── evaluate.py └── annotate.py ├── run_train.sh ├── sqlnet ├── LICENSE └── dbengine.py ├── add_question.py ├── evaluate_ws.py ├── bert ├── convert_tf_checkpoint_to_pytorch.py ├── LICENSE_bert ├── tokenization.py └── README_bert.md ├── add_csv.py ├── NOTICE ├── predict.py ├── annotate_ws.py ├── README.md ├── LICENSE.md ├── train_decoder_layer.py └── train_shallow_layer.py /sqlova/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /sqlova/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sqlova/model/nl2sql/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sqlova/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /human_eval/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/sqlova/HEAD/human_eval/sample.png -------------------------------------------------------------------------------- /run_make_table.sh: -------------------------------------------------------------------------------- 1 | python3 add_csv.py ctable ftable1.csv 2 | python3 add_csv.py ctable ftable2.csv -------------------------------------------------------------------------------- /run_infer.sh: -------------------------------------------------------------------------------- 1 | python3 train.py --do_infer --infer_loop --trained --bert_type_abb uL --max_seq_length 222 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | sqlova/**/*.pyc 2 | wikisql/**/*.pyc 3 | sqlnet/__pycache__ 4 | bert/__pycache__ 5 | .idea 6 | *.swp 7 | -------------------------------------------------------------------------------- /sqlova/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/sqlova/HEAD/sqlova/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /wikisql/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/sqlova/HEAD/wikisql/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /sqlova/model/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/sqlova/HEAD/sqlova/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /sqlova/model/nl2sql/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/sqlova/HEAD/sqlova/model/nl2sql/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | python3 train.py --do_train --seed 1 --bS 16 --accumulate_gradients 2 --bert_type_abb uS --fine_tune --lr 0.001 --lr_bert 0.00001 --max_seq_length 222 2 | -------------------------------------------------------------------------------- /sqlova/model/nl2sql/__pycache__/wikisql_models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/sqlova/HEAD/sqlova/model/nl2sql/__pycache__/wikisql_models.cpython-36.pyc -------------------------------------------------------------------------------- /wikisql/lib/common.py: -------------------------------------------------------------------------------- 1 | def count_lines(fname): 2 | with open(fname) as f: 3 | return sum(1 for line in f) 4 | 5 | 6 | def detokenize(tokens): 7 | ret = '' 8 | for g, a in zip(tokens['gloss'], tokens['after']): 9 | ret += g + a 10 | return ret.strip() 11 | -------------------------------------------------------------------------------- /human_eval/README.md: -------------------------------------------------------------------------------- 1 | # The (approximate) human performance on the WikiSQL dataset 2 | 3 | We (approximately) measured human performance on the [WikiSQL dataset](https://arxiv.org/abs/1709.00103) using the crowdsourcing platform Amazon Mechanical Turk. 4 | We used 1,551 randomly sampled examples (\~10%) from the WikiSQL test set, which consists of 15,878 examples in total; 246 different qualified crowdworkers participated in the task. 5 | 6 | A sample task page shown to crowdworkers is given below: 7 | ![example](sample.png "An example of the task page") 8 | 9 | The crowd task mimics the execution of a SQL query so that the results are comparable with the execution accuracy adopted by the [official WikiSQL leaderboard](https://github.com/salesforce/WikiSQL). 10 | The accuracy of crowdworkers on the randomly sampled test data is 88.3%, while the execution accuracy of our model (SQLova) over the same 1,551 samples is 86.8% (w/o EG) and 91.0% (w/ EG). 11 | 12 | Please note that the Ground Truth values in the [measurement result file](result.tsv) are the SQL execution results of the ground-truth SQL queries. The SQL execution was performed with the script from the [official leaderboard](https://github.com/salesforce/WikiSQL). The retrieved Ground Truth values are somewhat erroneous, so we (the authors) manually checked the actual answers to the given questions and verified the correctness of all crowd answers. 13 | 14 | The details of the human performance measurement can be found in Section 5.2, Measuring Human Performance, of [our paper](https://arxiv.org/pdf/1902.01069.pdf).
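
For reference, the execution-style grading described above uses the same machinery as `wikisql/evaluate.py` (included later in this listing). The sketch below is not part of the repository; the file-path arguments and the `subset_ids` argument (for restricting scoring to a sampled subset such as the 1,551 test examples) are hypothetical.

```python
# Minimal sketch, mirroring (and simplifying) the logic of wikisql/evaluate.py
# shown further down. Paths and subset_ids are placeholders, not shipped files.
import json
from wikisql.lib.dbengine import DBEngine
from wikisql.lib.query import Query

def execution_accuracy(source_file, db_file, pred_file, subset_ids=None):
    engine = DBEngine(db_file)
    grades = []
    with open(source_file) as fs, open(pred_file) as fp:
        for idx, (ls, lp) in enumerate(zip(fs, fp)):
            if subset_ids is not None and idx not in subset_ids:
                continue  # score only the sampled examples
            eg, ep = json.loads(ls), json.loads(lp)
            gold = engine.execute_query(
                eg['table_id'], Query.from_dict(eg['sql'], ordered=False), lower=True)
            try:
                pred = engine.execute_query(
                    eg['table_id'], Query.from_dict(ep['query'], ordered=False), lower=True)
            except Exception as e:
                pred = repr(e)  # unexecutable prediction counts as wrong
            grades.append(pred == gold)
    return sum(grades) / len(grades)
```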
15 | -------------------------------------------------------------------------------- /wikisql/LICENSE_WikiSQL: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Salesforce Research 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /sqlnet/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Xiaojun Xu, Chang Liu and Dawn Song 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /wikisql/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from lib.dbengine import DBEngine 6 | from lib.query import Query 7 | from lib.common import count_lines 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = ArgumentParser() 12 | parser.add_argument('source_file', help='source file for the prediction') 13 | parser.add_argument('db_file', help='source database for the prediction') 14 | parser.add_argument('pred_file', help='predictions by the model') 15 | parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions') 16 | args = parser.parse_args() 17 | 18 | engine = DBEngine(args.db_file) 19 | exact_match = [] 20 | with open(args.source_file) as fs, open(args.pred_file) as fp: 21 | grades = [] 22 | for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): 23 | eg = json.loads(ls) 24 | ep = json.loads(lp) 25 | qg = Query.from_dict(eg['sql'], ordered=args.ordered) 26 | gold = engine.execute_query(eg['table_id'], qg, lower=True) 27 | pred = ep.get('error', None) 28 | qp = None 29 | if not ep.get('error', None): 30 | try: 31 | qp = Query.from_dict(ep['query'], ordered=args.ordered) 32 | pred = engine.execute_query(eg['table_id'], qp, lower=True) 33 | except Exception as e: 34 | pred = repr(e) 35 | correct = pred == gold 36 | match = qp == qg 37 | grades.append(correct) 38 | exact_match.append(match) 39 | print(json.dumps({ 40 | 'ex_accuracy': sum(grades) / len(grades), 41 | 'lf_accuracy': sum(exact_match) / len(exact_match), 42 | }, indent=2)) 43 | -------------------------------------------------------------------------------- /add_question.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Add a line of json representing a question into .jsonl 4 | # Call as: 5 | # python add_question.py 6 | # 7 | # This utility is not intended for use during training. A dummy label is added to the 8 | # question to make it loadable by existing code. 
9 | # 10 | # For example, suppose we downloaded this list of us state abbreviations: 11 | # https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/USstateAbbreviations.csv 12 | # Let's rename it as something short, say "abbrev.csv" 13 | # Now we can add it to a split called say "playground": 14 | # python add_csv.py playground abbrev.csv 15 | # And now we can add a question about it to the same split: 16 | # python add_question.py playground abbrev "what state has ansi digits of 11" 17 | # The next step would be to annotate the split: 18 | # python annotate_ws.py --din $PWD --dout $PWD --split playground 19 | # Then we're ready to run prediction on the split with predict.py 20 | 21 | import argparse, csv, json 22 | 23 | from sqlalchemy import Column, create_engine, Integer, MetaData, String, Table 24 | from sqlalchemy.exc import ArgumentError 25 | from sqlalchemy.ext.declarative import declarative_base 26 | from sqlalchemy.orm import create_session, mapper 27 | 28 | def question_to_json(table_id, question, json_file_name): 29 | record = { 30 | 'phase': 1, 31 | 'table_id': table_id, 32 | 'question': question, 33 | 'sql': {'sel': 0, 'conds': [], 'agg': 0} 34 | } 35 | with open(json_file_name, 'a+') as fout: 36 | json.dump(record, fout) 37 | fout.write('\n') 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('split') 42 | parser.add_argument('table_id') 43 | parser.add_argument('question', type=str, nargs='+') 44 | args = parser.parse_args() 45 | json_file_name = '{}.jsonl'.format(args.split) 46 | question_to_json(args.table_id, " ".join(args.question), json_file_name) 47 | print("Added question (with dummy label) to {}".format(json_file_name)) 48 | -------------------------------------------------------------------------------- /wikisql/lib/dbengine.py: -------------------------------------------------------------------------------- 1 | import records 2 | import re 3 | from babel.numbers import parse_decimal, NumberFormatError 4 | from wikisql.lib.query import Query 5 | 6 | # Jan 3, 2019. Wonseok modify the lib. 
path 7 | 8 | 9 | schema_re = re.compile(r'\((.+)\)') 10 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') 11 | 12 | 13 | class DBEngine: 14 | 15 | def __init__(self, fdb): 16 | self.db = records.Database('sqlite:///{}'.format(fdb)) 17 | 18 | def execute_query(self, table_id, query, *args, **kwargs): 19 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs) 20 | 21 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True): 22 | if not table_id.startswith('table'): 23 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 24 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql 25 | schema_str = schema_re.findall(table_info)[0] 26 | schema = {} 27 | for tup in schema_str.split(', '): 28 | c, t = tup.split() 29 | schema[c] = t 30 | select = 'col{}'.format(select_index) 31 | agg = Query.agg_ops[aggregation_index] 32 | if agg: 33 | select = '{}({})'.format(agg, select) 34 | where_clause = [] 35 | where_map = {} 36 | for col_index, op, val in conditions: 37 | if lower and isinstance(val, str): 38 | val = val.lower() 39 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 40 | try: 41 | val = float(parse_decimal(val)) 42 | except NumberFormatError as e: 43 | val = float(num_re.findall(val)[0]) 44 | where_clause.append('col{} {} :col{}'.format(col_index, Query.cond_ops[op], col_index)) 45 | where_map['col{}'.format(col_index)] = val 46 | where_str = '' 47 | if where_clause: 48 | where_str = 'WHERE ' + ' AND '.join(where_clause) 49 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 50 | out = self.db.query(query, **where_map) 51 | return [o.result for o in out] 52 | -------------------------------------------------------------------------------- /evaluate_ws.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from wikisql.lib.dbengine import DBEngine 6 | from wikisql.lib.query import Query 7 | from wikisql.lib.common import count_lines 8 | 9 | import os 10 | 11 | # Jan1 2019. Wonseok. Path info has added to original wikisql/evaluation.py 12 | # Only need to add "query" (essentially "sql" in original data) and "table_id" while constructing file. 13 | 14 | if __name__ == '__main__': 15 | 16 | # Hyper parameters 17 | mode = 'dev' 18 | ordered = False 19 | 20 | dset_name = 'wikisql_tok' 21 | saved_epoch = 'best' # 30-162 22 | 23 | # Set path 24 | path_h = '/home/wonseok' # change to your home folder 25 | path_wikisql_tok = os.path.join(path_h, 'data', 'wikisql_tok') 26 | path_save_analysis = '.' 27 | 28 | # Path for evaluation results. 
29 | path_wikisql0 = os.path.join(path_h,'data/WikiSQL-1.1/data') 30 | path_source = os.path.join(path_wikisql0, f'{mode}.jsonl') 31 | path_db = os.path.join(path_wikisql0, f'{mode}.db') 32 | path_pred = os.path.join(path_save_analysis, f'results_{mode}.jsonl') 33 | 34 | 35 | # For the case when use "argument" 36 | parser = ArgumentParser() 37 | parser.add_argument('--source_file', help='source file for the prediction', default=path_source) 38 | parser.add_argument('--db_file', help='source database for the prediction', default=path_db) 39 | parser.add_argument('--pred_file', help='predictions by the model', default=path_pred) 40 | parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions') 41 | args = parser.parse_args() 42 | args.ordered=ordered 43 | 44 | engine = DBEngine(args.db_file) 45 | exact_match = [] 46 | with open(args.source_file) as fs, open(args.pred_file) as fp: 47 | grades = [] 48 | for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): 49 | eg = json.loads(ls) 50 | ep = json.loads(lp) 51 | qg = Query.from_dict(eg['sql'], ordered=args.ordered) 52 | gold = engine.execute_query(eg['table_id'], qg, lower=True) 53 | pred = ep.get('error', None) 54 | qp = None 55 | if not ep.get('error', None): 56 | try: 57 | qp = Query.from_dict(ep['query'], ordered=args.ordered) 58 | pred = engine.execute_query(eg['table_id'], qp, lower=True) 59 | except Exception as e: 60 | pred = repr(e) 61 | correct = pred == gold 62 | match = qp == qg 63 | grades.append(correct) 64 | exact_match.append(match) 65 | 66 | print(json.dumps({ 67 | 'ex_accuracy': sum(grades) / len(grades), 68 | 'lf_accuracy': sum(exact_match) / len(exact_match), 69 | }, indent=2)) 70 | 71 | 72 | -------------------------------------------------------------------------------- /sqlova/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present NAVER Corp. 2 | # Apache License v2.0 3 | 4 | # Wonseok Hwang 5 | import os, json 6 | import random as python_random 7 | from matplotlib.pylab import * 8 | 9 | 10 | def generate_perm_inv(perm): 11 | # Definitly correct. 12 | perm_inv = zeros(len(perm), dtype=int32) 13 | for i, p in enumerate(perm): 14 | perm_inv[int(p)] = i 15 | 16 | return perm_inv 17 | 18 | 19 | def ensure_dir(my_path): 20 | """ Generate directory if not exists 21 | """ 22 | if not os.path.exists(my_path): 23 | os.makedirs(my_path) 24 | 25 | 26 | def topk_multi_dim(tensor, n_topk=1, batch_exist=True): 27 | 28 | if batch_exist: 29 | idxs = [] 30 | for b, tensor1 in enumerate(tensor): 31 | idxs1 = [] 32 | tensor1_1d = tensor1.reshape(-1) 33 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk) 34 | idxs_list = unravel_index(idxs_1d.cpu().numpy(), tensor1.shape) 35 | # (dim0, dim1, dim2, ...) 36 | 37 | # reconstruct 38 | for i_beam in range(n_topk): 39 | idxs11 = [] 40 | for idxs_list1 in idxs_list: 41 | idxs11.append(idxs_list1[i_beam]) 42 | idxs1.append(idxs11) 43 | idxs.append(idxs1) 44 | 45 | else: 46 | tensor1 = tensor 47 | idxs1 = [] 48 | tensor1_1d = tensor1.reshape(-1) 49 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk) 50 | idxs_list = unravel_index(idxs_1d.numpy(), tensor1.shape) 51 | # (dim0, dim1, dim2, ...) 
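# unravel_index converts the flat top-k positions back into per-dimension indices.
# Example: for tensor1 of shape (3, 4) and idxs_1d == [7, 2], unravel_index returns
# ([1, 0], [3, 2]), i.e. idxs_list[d][k] is the index along dimension d of the k-th
# best entry; the reconstruction loop below regroups this per beam as [[1, 3], [0, 2]].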
52 | 53 | # reconstruct 54 | for i_beam in range(n_topk): 55 | idxs11 = [] 56 | for idxs_list1 in idxs_list: 57 | idxs11.append(idxs_list1[i_beam]) 58 | idxs1.append(idxs11) 59 | idxs = idxs1 60 | return idxs 61 | 62 | 63 | def json_default_type_checker(o): 64 | """ 65 | From https://stackoverflow.com/questions/11942364/typeerror-integer-is-not-json-serializable-when-serializing-json-in-python 66 | """ 67 | if isinstance(o, int64): return int(o) 68 | raise TypeError 69 | 70 | 71 | def load_jsonl(path_file, toy_data=False, toy_size=4, shuffle=False, seed=1): 72 | data = [] 73 | 74 | with open(path_file, "r", encoding="utf-8") as f: 75 | for idx, line in enumerate(f): 76 | if toy_data and idx >= toy_size and (not shuffle): 77 | break 78 | t1 = json.loads(line.strip()) 79 | data.append(t1) 80 | 81 | if shuffle and toy_data: 82 | # When shuffle required, get all the data, shuffle, and get the part of data. 83 | print( 84 | f"If the toy-data is used, the whole data loaded first and then shuffled before get the first {toy_size} data") 85 | 86 | python_random.Random(seed).shuffle(data) # fixed 87 | data = data[:toy_size] 88 | 89 | return data 90 | -------------------------------------------------------------------------------- /sqlova/utils/wikisql_formatter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present NAVER Corp. 2 | # Apache License v2.0 3 | 4 | 5 | # Wonseok Hwang 6 | # Convert the wikisql format to the suitable format for the BERT. 7 | import os, sys, json 8 | from matplotlib.pylab import * 9 | 10 | 11 | def get_squad_style_ans(nlu, sql): 12 | conds = sql['conds'] 13 | answers = [] 14 | for cond1 in conds: 15 | a1 = {} 16 | wv1 = cond1[2] 17 | a1['text'] = wv1 18 | a1['answer_start'] = nlu.lower().find(str(wv1).lower()) 19 | if a1['answer_start'] < 0 or a1['answer_start'] >= len(nlu): 20 | raise EnvironmentError 21 | answers.append(a1) 22 | 23 | return answers 24 | 25 | 26 | def get_qas(path_q, tid): 27 | qas = [] 28 | with open(path_q, 'r') as f_q: 29 | qnum = -1 30 | for j, q1 in enumerate(f_q): 31 | q1 = json.loads(q1) 32 | tid_q = q1['table_id'] 33 | 34 | if tid_q != tid: 35 | continue 36 | else: 37 | qnum += 1 38 | # print(tid_q, tid) 39 | qas1 = {} 40 | nlu = q1['question'] 41 | sql = q1['sql'] 42 | 43 | qas1['question'] = nlu 44 | qas1['id'] = f'{tid_q}-{qnum}' 45 | qas1['answers'] = get_squad_style_ans(nlu, sql) 46 | qas1['c_answers'] = sql 47 | 48 | qas.append(qas1) 49 | 50 | return qas 51 | 52 | 53 | def get_tbl_context(t1): 54 | context = '' 55 | 56 | header_tok = t1['header'] 57 | # Here Join scheme can be changed. 
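# e.g. a header of ['Player', 'No.', 'Nationality', 'Position'] is joined into the
# single string 'Player No. Nationality Position', which serves as the table context.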
58 | header_joined = ' '.join(header_tok) 59 | context += header_joined 60 | 61 | return context 62 | 63 | def generate_wikisql_bert(path_wikisql, dset_type): 64 | path_q = os.path.join(path_wikisql, f'{dset_type}.jsonl') 65 | path_tbl = os.path.join(path_wikisql, f'{dset_type}.tables.jsonl') 66 | 67 | # Generate new json file 68 | with open(path_tbl, 'r') as f_tbl: 69 | wikisql = {'version': "v1.1"} 70 | data = [] 71 | data1 = {} 72 | paragraphs = [] # new tbls 73 | for i, t1 in enumerate(f_tbl): 74 | paragraphs1 = {} 75 | 76 | t1 = json.loads(t1) 77 | tid = t1['id'] 78 | qas = get_qas(path_q, tid) 79 | 80 | paragraphs1['qas'] = qas 81 | paragraphs1['tid'] = tid 82 | paragraphs1['context'] = get_tbl_context(t1) 83 | # paragraphs1['context_page_title'] = t1['page_title'] # not always present 84 | paragraphs1['context_headers'] = t1['header'] 85 | paragraphs1['context_headers_type'] = t1['types'] 86 | paragraphs1['context_contents'] = t1['rows'] 87 | 88 | paragraphs.append(paragraphs1) 89 | data1['paragraphs'] = paragraphs 90 | data1['title'] = 'wikisql' 91 | data.append(data1) 92 | wikisql['data'] = data 93 | 94 | # Save 95 | with open(os.path.join(path_wikisql, f'{dset_type}_bert.json'), 'w', encoding='utf-8') as fnew: 96 | json_str = json.dumps(wikisql, ensure_ascii=False) 97 | json_str += '\n' 98 | fnew.writelines(json_str) 99 | 100 | 101 | if __name__=='__main__': 102 | 103 | # 0. Load wikisql 104 | path_h = '/Users/wonseok' 105 | path_wikisql = os.path.join(path_h, 'data', 'WikiSQL-1.1', 'data') 106 | 107 | 108 | dset_type_list = ['dev', 'test', 'train'] 109 | 110 | for dset_type in dset_type_list: 111 | generate_wikisql_bert(path_wikisql, dset_type) 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if name[0] in ['redictions', 'eq_relationship']: 75 | print("Skipping") 76 | continue 77 | pointer = model 78 | for m_name in name: 79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 80 | l = re.split(r'_(\d+)', m_name) 81 | else: 82 | l = [m_name] 83 | if l[0] == 'kernel': 84 | pointer = getattr(pointer, 'weight') 85 | else: 86 | pointer = getattr(pointer, l[0]) 87 | if len(l) >= 2: 88 | num = int(l[1]) 89 | pointer = pointer[num] 90 | if m_name[-11:] == '_embeddings': 91 | pointer = getattr(pointer, 'weight') 92 | elif m_name == 'kernel': 93 | array = np.transpose(array) 94 | try: 95 | assert pointer.shape == array.shape 96 | except AssertionError as e: 97 | e.args += (pointer.shape, array.shape) 98 | raise 99 | pointer.data = torch.from_numpy(array) 100 | 101 | # Save pytorch-model 102 | torch.save(model.state_dict(), args.pytorch_dump_path) 103 | 104 | if __name__ == "__main__": 105 | convert() 106 | -------------------------------------------------------------------------------- /add_csv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Add a CSV file as a table into .db and .tables.jsonl 4 | # Call as: 5 | # python add_csv.py 6 | # For a CSV file called data.csv, the table will be called table_data in the .db 7 | # file, and will be assigned the id 'data'. 8 | # All columns are treated as text - no attempt is made to sniff the type of value 9 | # stored in the column. 
10 | 11 | import argparse, csv, json, os, re 12 | from sqlalchemy import Column, create_engine, MetaData, String, Table 13 | 14 | 15 | def get_table_name(table_id): 16 | return 'table_{}'.format(table_id) 17 | 18 | 19 | def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name, working_folder='.'): 20 | sqlite_file_name = os.path.join(working_folder, sqlite_file_name) 21 | csv_file_name = os.path.join(working_folder, csv_file_name) 22 | 23 | engine = create_engine('sqlite:///{}'.format(sqlite_file_name)) 24 | 25 | with open(csv_file_name) as f: 26 | metadata = MetaData(bind=engine) 27 | cf = csv.DictReader(f, delimiter=',') 28 | simple_name = dict([(name, 'col%d' % i) for i, name in enumerate(cf.fieldnames)]) 29 | table = Table(get_table_name(table_id), metadata, 30 | *(Column(simple_name[name], String()) 31 | for name in cf.fieldnames)) 32 | table.drop(checkfirst=True) 33 | table.create() 34 | for row in cf: 35 | row = dict((simple_name[name], val) for name, val in row.items()) 36 | table.insert().values(**row).execute() 37 | return engine 38 | 39 | 40 | def is_num(val): 41 | pattern = re.compile(r'[-+]?\d*\.\d+|\d+') 42 | if pattern.match(val): 43 | return True 44 | else: 45 | return False 46 | 47 | 48 | def get_types(rows): 49 | types = [] 50 | row1 = rows[0] 51 | types = [] 52 | for val in row1: 53 | if is_num(val): 54 | types.append('real') 55 | else: 56 | types.append('text') 57 | return types 58 | 59 | 60 | def get_refined_rows(rows, types): 61 | real_idx = [] 62 | for i, type in enumerate(types): 63 | if type == 'real': 64 | real_idx.append(i) 65 | 66 | if len(real_idx) == 0: 67 | rrs = rows 68 | else: 69 | rrs = [] 70 | for row in rows: 71 | rr = row 72 | for idx in real_idx: 73 | rr[idx] = float(row[idx]) 74 | rrs.append(rr) 75 | return rrs 76 | 77 | 78 | def csv_to_json(table_id, csv_file_name, json_file_name, working_folder='.'): 79 | csv_file_name = os.path.join(working_folder, csv_file_name) 80 | json_file_name = os.path.join(working_folder, json_file_name) 81 | with open(csv_file_name) as f: 82 | cf = csv.DictReader(f, delimiter=',') 83 | record = {} 84 | record['header'] = [(name or 'col{}'.format(i)) for i, name in enumerate(cf.fieldnames)] 85 | record['page_title'] = None 86 | record['id'] = table_id 87 | record['caption'] = None 88 | record['rows'] = [list(row.values()) for row in cf] 89 | record['name'] = get_table_name(table_id) 90 | 91 | # infer type based on first row 92 | 93 | record['types'] = get_types(rows=record['rows']) 94 | refined_rows = get_refined_rows(rows=record['rows'], types=record['types']) 95 | record['rows'] = refined_rows 96 | 97 | # save 98 | with open(json_file_name, 'a+') as fout: 99 | json.dump(record, fout) 100 | fout.write('\n') 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('split') 106 | parser.add_argument('file', metavar='file.csv') 107 | working_folder = './data_and_model' 108 | args = parser.parse_args() 109 | table_id = os.path.splitext(os.path.basename(args.file))[0] 110 | csv_to_sqlite(table_id, args.file, '{}.db'.format(args.split), working_folder) 111 | csv_to_json(table_id, args.file, '{}.tables.jsonl'.format(args.split), working_folder) 112 | print("Added table with id '{id}' (name '{name}') to {split}.db and {split}.tables.jsonl".format( 113 | id=table_id, name=get_table_name(table_id), split=args.split)) 114 | -------------------------------------------------------------------------------- /NOTICE: 
-------------------------------------------------------------------------------- 1 | SQLova 2 | Copyright 2019-present NAVER Corp. 3 | 4 | This project contains subcomponents with separate copyright notices and license terms. 5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 6 | 7 | ======================================================================= 8 | huggingface/pytorch-pretrained-BERT from https://github.com/huggingface/pytorch-pretrained-BERT 9 | ======================================================================= 10 | 11 | # Copyright 2018 The HugginFace Inc. team. 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 24 | 25 | ======================================================================= 26 | xiaojunxu/SQLNet from https://github.com/xiaojunxu/SQLNet 27 | ======================================================================= 28 | 29 | BSD 3-Clause License 30 | 31 | Copyright (c) 2017, Xiaojun Xu, Chang Liu and Dawn Song 32 | All rights reserved. 33 | 34 | Redistribution and use in source and binary forms, with or without 35 | modification, are permitted provided that the following conditions are met: 36 | 37 | * Redistributions of source code must retain the above copyright notice, this 38 | list of conditions and the following disclaimer. 39 | 40 | * Redistributions in binary form must reproduce the above copyright notice, 41 | this list of conditions and the following disclaimer in the documentation 42 | and/or other materials provided with the distribution. 43 | 44 | * Neither the name of the copyright holder nor the names of its 45 | contributors may be used to endorse or promote products derived from 46 | this software without specific prior written permission. 47 | 48 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 49 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 50 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 51 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 52 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 53 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 54 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 55 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 56 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 57 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 58 | 59 | ======================================================================= 60 | salesforce/WikiSQL from https://github.com/salesforce/WikiSQL 61 | ======================================================================= 62 | 63 | BSD 3-Clause License 64 | 65 | Copyright (c) 2017, Salesforce Research 66 | All rights reserved. 
67 | 68 | Redistribution and use in source and binary forms, with or without 69 | modification, are permitted provided that the following conditions are met: 70 | 71 | * Redistributions of source code must retain the above copyright notice, this 72 | list of conditions and the following disclaimer. 73 | 74 | * Redistributions in binary form must reproduce the above copyright notice, 75 | this list of conditions and the following disclaimer in the documentation 76 | and/or other materials provided with the distribution. 77 | 78 | * Neither the name of the copyright holder nor the names of its 79 | contributors may be used to endorse or promote products derived from 80 | this software without specific prior written permission. 81 | 82 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 83 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 84 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 85 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 86 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 87 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 88 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 89 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 90 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 91 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 92 | -------------------------------------------------------------------------------- /wikisql/annotate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 3 | import os 4 | import records 5 | import ujson as json 6 | from stanza.nlp.corenlp import CoreNLPClient 7 | from tqdm import tqdm 8 | import copy 9 | from lib.common import count_lines, detokenize 10 | from lib.query import Query 11 | 12 | 13 | client = None 14 | 15 | 16 | def annotate(sentence, lower=True): 17 | global client 18 | if client is None: 19 | client = CoreNLPClient(default_annotators='ssplit,tokenize'.split(',')) 20 | words, gloss, after = [], [], [] 21 | for s in client.annotate(sentence): 22 | for t in s: 23 | words.append(t.word) 24 | gloss.append(t.originalText) 25 | after.append(t.after) 26 | if lower: 27 | words = [w.lower() for w in words] 28 | return { 29 | 'gloss': gloss, 30 | 'words': words, 31 | 'after': after, 32 | } 33 | 34 | 35 | def annotate_example(example, table): 36 | ann = {'table_id': example['table_id']} 37 | ann['question'] = annotate(example['question']) 38 | ann['table'] = { 39 | 'header': [annotate(h) for h in table['header']], 40 | } 41 | ann['query'] = sql = copy.deepcopy(example['sql']) 42 | for c in ann['query']['conds']: 43 | c[-1] = annotate(str(c[-1])) 44 | 45 | q1 = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(Query.agg_ops[sql['agg']], table['header'][sql['sel']]) 46 | q2 = ['SYMCOL {} SYMOP {} SYMCOND {}'.format(table['header'][col], Query.cond_ops[op], detokenize(cond)) for col, op, cond in sql['conds']] 47 | if q2: 48 | q2 = 'SYMWHERE ' + ' SYMAND '.join(q2) + ' SYMEND' 49 | else: 50 | q2 = 'SYMEND' 51 | inp = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question} SYMEND'.format( 52 | syms=' '.join(['SYM' + s for s in Query.syms]), 53 | table=' '.join(['SYMCOL ' + s for s in table['header']]), 54 | 
question=example['question'], 55 | aggops=' '.join([s for s in Query.agg_ops]), 56 | condops=' '.join([s for s in Query.cond_ops]), 57 | ) 58 | ann['seq_input'] = annotate(inp) 59 | out = '{q1} {q2}'.format(q1=q1, q2=q2) if q2 else q1 60 | ann['seq_output'] = annotate(out) 61 | ann['where_output'] = annotate(q2) 62 | assert 'symend' in ann['seq_output']['words'] 63 | assert 'symend' in ann['where_output']['words'] 64 | return ann 65 | 66 | 67 | def is_valid_example(e): 68 | if not all([h['words'] for h in e['table']['header']]): 69 | return False 70 | headers = [detokenize(h).lower() for h in e['table']['header']] 71 | if len(headers) != len(set(headers)): 72 | return False 73 | input_vocab = set(e['seq_input']['words']) 74 | for w in e['seq_output']['words']: 75 | if w not in input_vocab: 76 | print('query word "{}" is not in input vocabulary.\n{}'.format(w, e['seq_input']['words'])) 77 | return False 78 | input_vocab = set(e['question']['words']) 79 | for col, op, cond in e['query']['conds']: 80 | for w in cond['words']: 81 | if w not in input_vocab: 82 | print('cond word "{}" is not in input vocabulary.\n{}'.format(w, e['question']['words'])) 83 | return False 84 | return True 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 89 | parser.add_argument('--din', default='data', help='data directory') 90 | parser.add_argument('--dout', default='annotated', help='output directory') 91 | args = parser.parse_args() 92 | 93 | if not os.path.isdir(args.dout): 94 | os.makedirs(args.dout) 95 | 96 | for split in ['train', 'dev', 'test']: 97 | fsplit = os.path.join(args.din, split) + '.jsonl' 98 | ftable = os.path.join(args.din, split) + '.tables.jsonl' 99 | fout = os.path.join(args.dout, split) + '.jsonl' 100 | 101 | print('annotating {}'.format(fsplit)) 102 | with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo: 103 | print('loading tables') 104 | tables = {} 105 | for line in tqdm(ft, total=count_lines(ftable)): 106 | d = json.loads(line) 107 | tables[d['id']] = d 108 | print('loading examples') 109 | n_written = 0 110 | for line in tqdm(fs, total=count_lines(fsplit)): 111 | d = json.loads(line) 112 | a = annotate_example(d, tables[d['table_id']]) 113 | if not is_valid_example(a): 114 | raise Exception(str(a)) 115 | 116 | gold = Query.from_tokenized_dict(a['query']) 117 | reconstruct = Query.from_sequence(a['seq_output'], a['table'], lowercase=True) 118 | if gold.lower() != reconstruct.lower(): 119 | raise Exception ('Expected:\n{}\nGot:\n{}'.format(gold, reconstruct)) 120 | fo.write(json.dumps(a) + '\n') 121 | n_written += 1 122 | print('wrote {} examples'.format(n_written)) 123 | -------------------------------------------------------------------------------- /sqlnet/dbengine.py: -------------------------------------------------------------------------------- 1 | # From original SQLNet code. 2 | # Wonseok modified. 20180607 3 | 4 | import records 5 | import re 6 | from babel.numbers import parse_decimal, NumberFormatError 7 | 8 | 9 | schema_re = re.compile(r'\((.+)\)') # group (.......) dfdf (.... )group 10 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') # ? zero or one time appear of preceding character, * zero or several time appear of preceding character. 
11 | # Catch something like -34.34, .4543, 12 | # | is 'or' 13 | 14 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 15 | cond_ops = ['=', '>', '<', 'OP'] 16 | 17 | class DBEngine: 18 | 19 | def __init__(self, fdb): 20 | #fdb = 'data/test.db' 21 | self.db = records.Database('sqlite:///{}'.format(fdb)) 22 | 23 | def execute_query(self, table_id, query, *args, **kwargs): 24 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs) 25 | 26 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True): 27 | if not table_id.startswith('table'): 28 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 29 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','') 30 | schema_str = schema_re.findall(table_info)[0] 31 | schema = {} 32 | for tup in schema_str.split(', '): 33 | c, t = tup.split() 34 | schema[c] = t 35 | select = 'col{}'.format(select_index) 36 | agg = agg_ops[aggregation_index] 37 | if agg: 38 | select = '{}({})'.format(agg, select) 39 | where_clause = [] 40 | where_map = {} 41 | for col_index, op, val in conditions: 42 | if lower and (isinstance(val, str) or isinstance(val, str)): 43 | val = val.lower() 44 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 45 | try: 46 | # print('!!!!!!value of val is: ', val, 'type is: ', type(val)) 47 | # val = float(parse_decimal(val)) # somehow it generates error. 48 | val = float(parse_decimal(val, locale='en_US')) 49 | # print('!!!!!!After: val', val) 50 | 51 | except NumberFormatError as e: 52 | try: 53 | val = float(num_re.findall(val)[0]) # need to understand and debug this part. 54 | except: 55 | # Although column is of number, selected one is not number. Do nothing in this case. 56 | pass 57 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index)) 58 | where_map['col{}'.format(col_index)] = val 59 | where_str = '' 60 | if where_clause: 61 | where_str = 'WHERE ' + ' AND '.join(where_clause) 62 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 63 | #print query 64 | out = self.db.query(query, **where_map) 65 | 66 | 67 | return [o.result for o in out] 68 | def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True): 69 | if not table_id.startswith('table'): 70 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 71 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','') 72 | schema_str = schema_re.findall(table_info)[0] 73 | schema = {} 74 | for tup in schema_str.split(', '): 75 | c, t = tup.split() 76 | schema[c] = t 77 | select = 'col{}'.format(select_index) 78 | agg = agg_ops[aggregation_index] 79 | if agg: 80 | select = '{}({})'.format(agg, select) 81 | where_clause = [] 82 | where_map = {} 83 | for col_index, op, val in conditions: 84 | if lower and (isinstance(val, str) or isinstance(val, str)): 85 | val = val.lower() 86 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 87 | try: 88 | # print('!!!!!!value of val is: ', val, 'type is: ', type(val)) 89 | # val = float(parse_decimal(val)) # somehow it generates error. 
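# parse_decimal() defaults to the system LC_NUMERIC locale when no locale is
# given, so whether a string like '1,000.50' parses can vary from machine to
# machine; pinning locale='en_US' makes the behaviour deterministic.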
90 | val = float(parse_decimal(val, locale='en_US')) 91 | # print('!!!!!!After: val', val) 92 | 93 | except NumberFormatError as e: 94 | val = float(num_re.findall(val)[0]) 95 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index)) 96 | where_map['col{}'.format(col_index)] = val 97 | where_str = '' 98 | if where_clause: 99 | where_str = 'WHERE ' + ' AND '.join(where_clause) 100 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 101 | #print query 102 | out = self.db.query(query, **where_map) 103 | 104 | 105 | return [o.result for o in out], query 106 | def show_table(self, table_id): 107 | if not table_id.startswith('table'): 108 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 109 | rows = self.db.query('select * from ' +table_id) 110 | print(rows.dataset) -------------------------------------------------------------------------------- /wikisql/lib/table.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tabulate import tabulate 3 | from lib.query import Query 4 | import random 5 | 6 | 7 | class Table: 8 | 9 | schema_re = re.compile('\((.+)\)') 10 | 11 | def __init__(self, table_id, header, types, rows, caption=None): 12 | self.table_id = table_id 13 | self.header = header 14 | self.types = types 15 | self.rows = rows 16 | self.caption = caption 17 | 18 | def __repr__(self): 19 | return 'Table: {id}\nCaption: {caption}\n{tabulate}'.format( 20 | id=self.table_id, 21 | caption=self.caption, 22 | tabulate=tabulate(self.rows, headers=self.header) 23 | ) 24 | 25 | @classmethod 26 | def get_schema(cls, db, table_id): 27 | table_infos = db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=cls.get_id(table_id)).all() 28 | if table_infos: 29 | return table_infos[0].sql 30 | else: 31 | return None 32 | 33 | @classmethod 34 | def get_id(cls, table_id): 35 | return 'table_{}'.format(table_id.replace('-', '_')) 36 | 37 | @classmethod 38 | def from_db(cls, db, table_id): 39 | table_info = cls.get_schema(db, table_id) 40 | if table_info: 41 | schema_str = cls.schema_re.findall(table_info)[0] 42 | header, types = [], [] 43 | for tup in schema_str.split(', '): 44 | c, t = tup.split() 45 | header.append(c) 46 | types.append(t) 47 | rows = [[getattr(r, h) for h in header] for r in db.query('SELECT * from {}'.format(cls.get_id(table_id)))] 48 | return cls(table_id, header, types, rows) 49 | else: 50 | return None 51 | 52 | @property 53 | def name(self): 54 | return self.get_id(self.table_id) 55 | 56 | def create_table(self, db, replace_existing=False, lower=True): 57 | exists = self.get_schema(db, self.table_id) 58 | if exists: 59 | if replace_existing: 60 | db.query('DROP TABLE {}'.format(self.name)) 61 | else: 62 | return 63 | type_str = ', '.join(['col{} {}'.format(i, t) for i, t in enumerate(self.types)]) 64 | db.query('CREATE TABLE {name} ({types})'.format(name=self.name, types=type_str)) 65 | for row in self.rows: 66 | value_str = ', '.join([':val{}'.format(j) for j, c in enumerate(row)]) 67 | value_dict = {'val{}'.format(j): c for j, c in enumerate(row)} 68 | if lower: 69 | value_dict = {k: v.lower() if isinstance(v, str) else v for k, v in value_dict.items()} 70 | db.query('INSERT INTO {name} VALUES ({values})'.format(name=self.name, values=value_str), **value_dict) 71 | 72 | def execute_query(self, db, query, lower=True): 73 | sel_str = 'col{}'.format(query.sel_index) if query.sel_index >= 0 else '*' 74 | agg_str = sel_str 75 | agg_op = 
Query.agg_ops[query.agg_index] 76 | if agg_op: 77 | agg_str = '{}({})'.format(agg_op, sel_str) 78 | where_str = ' AND '.join(['col{} {} :col{}'.format(i, Query.cond_ops[o], i) for i, o, v in query.conditions]) 79 | where_map = {'col{}'.format(i): v for i, o, v in query.conditions} 80 | if lower: 81 | where_map = {k: v.lower() if isinstance(v, str) else v for k, v in where_map.items()} 82 | if where_map: 83 | where_str = 'WHERE ' + where_str 84 | 85 | if query.sel_index >= 0: 86 | query_str = 'SELECT {agg_str} AS result FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str) 87 | return [r.result for r in db.query(query_str, **where_map)] 88 | else: 89 | query_str = 'SELECT {agg_str} FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str) 90 | return [[getattr(r, 'col{}'.format(i)) for i in range(len(self.header))] for r in db.query(query_str, **where_map)] 91 | 92 | def query_str(self, query): 93 | agg_str = self.header[query.sel_index] 94 | agg_op = Query.agg_ops[query.agg_index] 95 | if agg_op: 96 | agg_str = '{}({})'.format(agg_op, agg_str) 97 | where_str = ' AND '.join(['{} {} {}'.format(self.header[i], Query.cond_ops[o], v) for i, o, v in query.conditions]) 98 | return 'SELECT {} FROM {} WHERE {}'.format(agg_str, self.name, where_str) 99 | 100 | def generate_query(self, db, max_cond=4): 101 | max_cond = min(len(self.header), max_cond) 102 | # sample a select column 103 | sel_index = random.choice(list(range(len(self.header)))) 104 | # sample where conditions 105 | query = Query(-1, Query.agg_ops.index('')) 106 | results = self.execute_query(db, query) 107 | condition_options = list(range(len(self.header))) 108 | condition_options.remove(sel_index) 109 | for i in range(max_cond): 110 | if not results: 111 | break 112 | cond_index = random.choice(condition_options) 113 | if self.types[cond_index] == 'text': 114 | cond_op = Query.cond_ops.index('=') 115 | else: 116 | cond_op = random.choice(list(range(len(Query.cond_ops)))) 117 | cond_val = random.choice([r[cond_index] for r in results]) 118 | query.conditions.append((cond_index, cond_op, cond_val)) 119 | new_results = self.execute_query(db, query) 120 | if [r[sel_index] for r in new_results] != [r[sel_index] for r in results]: 121 | condition_options.remove(cond_index) 122 | results = new_results 123 | else: 124 | query.conditions.pop() 125 | # sample an aggregation operation 126 | if self.types[sel_index] == 'text': 127 | query.agg_index = Query.agg_ops.index('') 128 | else: 129 | query.agg_index = random.choice(list(range(len(Query.agg_ops)))) 130 | query.sel_index = sel_index 131 | results = self.execute_query(db, query) 132 | return query, results 133 | 134 | def generate_queries(self, db, n=1, max_tries=5, lower=True): 135 | qs = [] 136 | for i in range(n): 137 | n_tries = 0 138 | r = None 139 | while r is None and n_tries < max_tries: 140 | q, r = self.generate_query(db, max_cond=4) 141 | n_tries += 1 142 | if r: 143 | qs.append((q, r)) 144 | return qs 145 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Use existing model to predict sql from tables and questions. 
4 | # 5 | # For example, you can get a pretrained model from https://github.com/naver/sqlova/releases: 6 | # https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt 7 | # https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_best.pt 8 | # 9 | # Make sure you also have the following support files (see README for where to get them): 10 | # - bert_config_uncased_*.json 11 | # - vocab_uncased_*.txt 12 | # 13 | # Finally, you need some data - some files called: 14 | # - .db 15 | # - .jsonl 16 | # - .tables.jsonl 17 | # - _tok.jsonl # derived using annotate_ws.py 18 | # You can play with the existing train/dev/test splits, or make your own with 19 | # the add_csv.py and add_question.py utilities. 20 | # 21 | # Once you have all that, you are ready to predict, using: 22 | # python predict.py \ 23 | # --bert_type_abb uL \ # need to match the architecture of the model you are using 24 | # --model_file /model_best.pt \ 25 | # --bert_model_file /model_bert_best.pt \ 26 | # --bert_path \ 27 | # --result_path \ 28 | # --data_path \ 29 | # --split 30 | # 31 | # Results will be in a file called results_.jsonl in the result_path. 32 | 33 | import argparse, os 34 | from sqlnet.dbengine import DBEngine 35 | from sqlova.utils.utils_wikisql import * 36 | from train import construct_hyper_param, get_models 37 | 38 | # This is a stripped down version of the test() method in train.py - identical, except: 39 | # - does not attempt to measure accuracy and indeed does not expect the data to be labelled. 40 | # - saves plain text sql queries. 41 | # 42 | def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer, 43 | max_seq_length, 44 | num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4, 45 | path_db=None, dset_name='test'): 46 | 47 | model.eval() 48 | model_bert.eval() 49 | 50 | engine = DBEngine(os.path.join(path_db, f"{dset_name}.db")) 51 | results = [] 52 | for iB, t in enumerate(data_loader): 53 | nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True) 54 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i) 55 | g_wvi_corenlp = get_g_wvi_corenlp(t) 56 | wemb_n, wemb_h, l_n, l_hpu, l_hs, \ 57 | nlu_tt, t_to_tt_idx, tt_to_t_idx \ 58 | = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, 59 | num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers) 60 | if not EG: 61 | # No Execution guided decoding 62 | s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs) 63 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, ) 64 | pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu) 65 | pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu) 66 | else: 67 | # Execution guided decoding 68 | prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, 69 | l_hs, engine, tb, 70 | nlu_t, nlu_tt, 71 | tt_to_t_idx, nlu, 72 | beam_size=beam_size) 73 | # sort and generate 74 | pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i) 75 | # Following variables are just for consistency with no-EG case. 
76 | pr_wvi = None # not used 77 | pr_wv_str=None 78 | pr_wv_str_wp=None 79 | 80 | pr_sql_q = generate_sql_q(pr_sql_i, tb) 81 | 82 | for b, (pr_sql_i1, pr_sql_q1) in enumerate(zip(pr_sql_i, pr_sql_q)): 83 | results1 = {} 84 | results1["query"] = pr_sql_i1 85 | results1["table_id"] = tb[b]["id"] 86 | results1["nlu"] = nlu[b] 87 | results1["sql"] = pr_sql_q1 88 | results.append(results1) 89 | 90 | return results 91 | 92 | ## Set up hyper parameters and paths 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--model_file", required=True, help='model file to use (e.g. model_best.pt)') 95 | parser.add_argument("--bert_model_file", required=True, help='bert model file to use (e.g. model_bert_best.pt)') 96 | parser.add_argument("--bert_path", required=True, help='path to bert files (bert_config*.json etc)') 97 | parser.add_argument("--data_path", required=True, help='path to *.jsonl and *.db files') 98 | parser.add_argument("--split", required=True, help='prefix of jsonl and db files (e.g. dev)') 99 | parser.add_argument("--result_path", required=True, help='directory in which to place results') 100 | args = construct_hyper_param(parser) 101 | 102 | BERT_PT_PATH = args.bert_path 103 | path_save_for_evaluation = args.result_path 104 | 105 | # Load pre-trained models 106 | path_model_bert = args.bert_model_file 107 | path_model = args.model_file 108 | args.no_pretraining = True # counterintuitive, but avoids loading unused models 109 | model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model) 110 | 111 | # Load data 112 | dev_data, dev_table = load_wikisql_data(args.data_path, mode=args.split, toy_model=args.toy_model, toy_size=args.toy_size, no_hs_tok=True) 113 | dev_loader = torch.utils.data.DataLoader( 114 | batch_size=args.bS, 115 | dataset=dev_data, 116 | shuffle=False, 117 | num_workers=1, 118 | collate_fn=lambda x: x # now dictionary values are not merged! 119 | ) 120 | 121 | # Run prediction 122 | with torch.no_grad(): 123 | results = predict(dev_loader, 124 | dev_table, 125 | model, 126 | model_bert, 127 | bert_config, 128 | tokenizer, 129 | args.max_seq_length, 130 | args.num_target_layers, 131 | detail=False, 132 | path_db=args.data_path, 133 | st_pos=0, 134 | dset_name=args.split, EG=args.EG) 135 | 136 | # Save results 137 | save_for_evaluation(path_save_for_evaluation, results, args.split) 138 | -------------------------------------------------------------------------------- /annotate_ws.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # docker run --name corenlp -d -p 9000:9000 vzhong/corenlp-server 3 | # Wonseok Hwang. 
Jan 6 2019, Comment added 4 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 5 | import os 6 | import records 7 | import ujson as json 8 | from stanza.nlp.corenlp import CoreNLPClient 9 | from tqdm import tqdm 10 | import copy 11 | from wikisql.lib.common import count_lines, detokenize 12 | from wikisql.lib.query import Query 13 | 14 | 15 | client = None 16 | 17 | 18 | def annotate(sentence, lower=True): 19 | global client 20 | if client is None: 21 | client = CoreNLPClient(default_annotators='ssplit,tokenize'.split(',')) 22 | words, gloss, after = [], [], [] 23 | for s in client.annotate(sentence): 24 | for t in s: 25 | words.append(t.word) 26 | gloss.append(t.originalText) 27 | after.append(t.after) 28 | if lower: 29 | words = [w.lower() for w in words] 30 | return { 31 | 'gloss': gloss, 32 | 'words': words, 33 | 'after': after, 34 | } 35 | 36 | 37 | def annotate_example(example, table): 38 | ann = {'table_id': example['table_id']} 39 | ann['question'] = annotate(example['question']) 40 | ann['table'] = { 41 | 'header': [annotate(h) for h in table['header']], 42 | } 43 | ann['query'] = sql = copy.deepcopy(example['sql']) 44 | for c in ann['query']['conds']: 45 | c[-1] = annotate(str(c[-1])) 46 | 47 | q1 = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(Query.agg_ops[sql['agg']], table['header'][sql['sel']]) 48 | q2 = ['SYMCOL {} SYMOP {} SYMCOND {}'.format(table['header'][col], Query.cond_ops[op], detokenize(cond)) for col, op, cond in sql['conds']] 49 | if q2: 50 | q2 = 'SYMWHERE ' + ' SYMAND '.join(q2) + ' SYMEND' 51 | else: 52 | q2 = 'SYMEND' 53 | inp = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question} SYMEND'.format( 54 | syms=' '.join(['SYM' + s for s in Query.syms]), 55 | table=' '.join(['SYMCOL ' + s for s in table['header']]), 56 | question=example['question'], 57 | aggops=' '.join([s for s in Query.agg_ops]), 58 | condops=' '.join([s for s in Query.cond_ops]), 59 | ) 60 | ann['seq_input'] = annotate(inp) 61 | out = '{q1} {q2}'.format(q1=q1, q2=q2) if q2 else q1 62 | ann['seq_output'] = annotate(out) 63 | ann['where_output'] = annotate(q2) 64 | assert 'symend' in ann['seq_output']['words'] 65 | assert 'symend' in ann['where_output']['words'] 66 | return ann 67 | 68 | def find_sub_list(sl, l): 69 | # from stack overflow. 70 | results = [] 71 | sll = len(sl) 72 | for ind in (i for i, e in enumerate(l) if e == sl[0]): 73 | if l[ind:ind + sll] == sl: 74 | results.append((ind, ind + sll - 1)) 75 | 76 | return results 77 | 78 | def check_wv_tok_in_nlu_tok(wv_tok1, nlu_t1): 79 | """ 80 | Jan.2019: Wonseok 81 | Generate SQuAD style start and end index of wv in nlu. Index is for of after WordPiece tokenization. 82 | 83 | Assumption: where_str always presents in the nlu. 84 | 85 | return: 86 | st_idx of where-value string token in nlu under CoreNLP tokenization scheme. 87 | """ 88 | g_wvi1_corenlp = [] 89 | nlu_t1_low = [tok.lower() for tok in nlu_t1] 90 | for i_wn, wv_tok11 in enumerate(wv_tok1): 91 | wv_tok11_low = [tok.lower() for tok in wv_tok11] 92 | results = find_sub_list(wv_tok11_low, nlu_t1_low) 93 | st_idx, ed_idx = results[0] 94 | 95 | g_wvi1_corenlp.append( [st_idx, ed_idx] ) 96 | 97 | return g_wvi1_corenlp 98 | 99 | 100 | def annotate_example_ws(example, table): 101 | """ 102 | Jan. 2019: Wonseok 103 | Annotate only the information that will be used in our model. 
104 | """ 105 | ann = {'table_id': example['table_id'],'phase': example['phase']} 106 | _nlu_ann = annotate(example['question']) 107 | ann['question'] = example['question'] 108 | ann['question_tok'] = _nlu_ann['gloss'] 109 | # ann['table'] = { 110 | # 'header': [annotate(h) for h in table['header']], 111 | # } 112 | ann['sql'] = example['sql'] 113 | ann['query'] = sql = copy.deepcopy(example['sql']) 114 | 115 | conds1 = ann['sql']['conds'] 116 | wv_ann1 = [] 117 | for conds11 in conds1: 118 | _wv_ann1 = annotate(str(conds11[2])) 119 | wv_ann11 = _wv_ann1['gloss'] 120 | wv_ann1.append( wv_ann11 ) 121 | 122 | # Check whether wv_ann exsits inside question_tok 123 | 124 | try: 125 | wvi1_corenlp = check_wv_tok_in_nlu_tok(wv_ann1, ann['question_tok']) 126 | ann['wvi_corenlp'] = wvi1_corenlp 127 | except: 128 | ann['wvi_corenlp'] = None 129 | ann['tok_error'] = 'SQuAD style st, ed are not found under CoreNLP.' 130 | 131 | return ann 132 | 133 | 134 | def is_valid_example(e): 135 | if not all([h['words'] for h in e['table']['header']]): 136 | return False 137 | headers = [detokenize(h).lower() for h in e['table']['header']] 138 | if len(headers) != len(set(headers)): 139 | return False 140 | input_vocab = set(e['seq_input']['words']) 141 | for w in e['seq_output']['words']: 142 | if w not in input_vocab: 143 | print('query word "{}" is not in input vocabulary.\n{}'.format(w, e['seq_input']['words'])) 144 | return False 145 | input_vocab = set(e['question']['words']) 146 | for col, op, cond in e['query']['conds']: 147 | for w in cond['words']: 148 | if w not in input_vocab: 149 | print('cond word "{}" is not in input vocabulary.\n{}'.format(w, e['question']['words'])) 150 | return False 151 | return True 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 156 | parser.add_argument('--din', default='/Users/wonseok/data/WikiSQL-1.1/data', help='data directory') 157 | parser.add_argument('--dout', default='/Users/wonseok/data/wikisql_tok', help='output directory') 158 | parser.add_argument('--split', default='train,dev,test', help='comma=separated list of splits to process') 159 | args = parser.parse_args() 160 | 161 | answer_toy = not True 162 | toy_size = 10 163 | 164 | if not os.path.isdir(args.dout): 165 | os.makedirs(args.dout) 166 | 167 | # for split in ['train', 'dev', 'test']: 168 | for split in args.split.split(','): 169 | fsplit = os.path.join(args.din, split) + '.jsonl' 170 | ftable = os.path.join(args.din, split) + '.tables.jsonl' 171 | fout = os.path.join(args.dout, split) + '_tok.jsonl' 172 | 173 | print('annotating {}'.format(fsplit)) 174 | with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo: 175 | print('loading tables') 176 | 177 | # ws: Construct table dict with table_id as a key. 
178 | tables = {}
179 | for line in tqdm(ft, total=count_lines(ftable)):
180 | d = json.loads(line)
181 | tables[d['id']] = d
182 | print('loading examples')
183 | n_written = 0
184 | cnt = -1
185 | for line in tqdm(fs, total=count_lines(fsplit)):
186 | cnt += 1
187 | d = json.loads(line)
188 | # a = annotate_example(d, tables[d['table_id']])
189 | a = annotate_example_ws(d, tables[d['table_id']])
190 | fo.write(json.dumps(a) + '\n')
191 | n_written += 1
192 |
193 | if answer_toy:
194 | if cnt > toy_size:
195 | break
196 | print('wrote {} examples'.format(n_written))
197 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # SQLova
2 | - SQLova is a neural semantic parser that translates natural language utterances into SQL queries. The name originates from the name of our department: **S**earch & **QLova** ([Search & Clova](https://clova.ai/ko/research/publications.html)).
3 |
4 | ### Authors
5 | - [Wonseok Hwang](mailto:wonseok.hwang@navercorp.com), [Jinyeong Yim](mailto:jinyeong.yim@navercorp.com), [Seunghyun Park](mailto:seung.park@navercorp.com), and [Minjoon Seo](https://seominjoon.github.io).
6 | - Affiliation: Clova AI Research, NAVER Corp., Seongnam, Korea.
7 | - The updated version of the manuscript is available on [arXiv](https://arxiv.org/abs/1902.01069).
8 | - The manuscript has been significantly re-written to improve readability.
9 | - Detailed descriptions of the model and the human evaluation process have been added.
10 | - To be presented at the [KR2ML Workshop at NeurIPS 2019](https://kr2ml.github.io/2019/#about).
11 | - [The old version](https://ssl.pstatic.net/static/clova/service/clova_ai/research/publications/SQLova.pdf).
12 |
13 | ### Abstract
14 | - We present a new state-of-the-art semantic parsing model that translates a natural language (NL) utterance into a SQL query.
15 | - The model is evaluated on [WikiSQL](https://github.com/salesforce/WikiSQL), a semantic parsing dataset consisting of 80,654 (NL, SQL) pairs over 24,241 tables from Wikipedia.
16 | - We achieve **83.6%** logical form accuracy and **89.6%** execution accuracy on the WikiSQL test set.
17 |
18 | ### The model in a nutshell
19 | - [BERT](https://arxiv.org/abs/1810.04805)-based, table- and context-aware word embeddings.
20 | - A sequence-to-SQL model leveraging recent works ([Seq2SQL](https://arxiv.org/abs/1709.00103), [SQLNet](https://arxiv.org/abs/1711.04436)).
21 | - [Execution-guided decoding](https://arxiv.org/abs/1807.03100) is applied in SQLova-EG.
22 |
23 | ### Results (updated Jan 12, 2019)
24 | | **Model** | Dev logical form accuracy | Dev execution accuracy | Test logical form accuracy | Test execution accuracy |
25 | | --------- | ------------------------- | ---------------------- | -------------------------- | ----------------------- |
26 | | SQLova | 81.6 (**+5.5**)^ | 87.2 (**+3.2**)^ | 80.7 (**+5.3**)^ | 86.2 (**+2.5**)^ |
27 | | SQLova-EG | 84.2 (**+8.2**)* | 90.2 (**+3.0**)* | 83.6 (**+8.2**)* | 89.6 (**+2.5**)* |
28 |
29 | - ^: Compared to current [SOTA](https://github.com/salesforce/WikiSQL) models that do not use execution-guided decoding.
30 | - *: Compared to current [SOTA](https://github.com/salesforce/WikiSQL).
31 | - The order of where conditions is ignored when measuring logical form accuracy for our model.
32 |
33 |
34 |
35 | ### Source code
36 | #### Requirements
37 | - `python3.6` or higher.
38 | - `PyTorch 0.4.0` or higher.
39 | - `CUDA 9.0`
40 | - Python libraries: `babel, matplotlib, defusedxml, tqdm`
41 | - Example:
42 |   - Install [miniconda](https://conda.io/miniconda.html)
43 |   - `conda install pytorch torchvision -c pytorch`
44 |   - `conda install -c conda-forge records==0.5.2`
45 |   - `conda install babel`
46 |   - `conda install matplotlib`
47 |   - `conda install defusedxml`
48 |   - `conda install tqdm`
49 | - The code has been tested on a Tesla M40 GPU running Ubuntu 16.04.4 LTS.
50 |
51 | #### Running code
52 | - Type `python3 train.py --seed 1 --bS 16 --accumulate_gradients 2 --bert_type_abb uS --fine_tune --lr 0.001 --lr_bert 0.00001 --max_seq_length 222` in the terminal.
53 |   - `--seed 1`: Set the seed of the random number generator. The accuracy changes by a few percent depending on `seed`.
54 |   - `--bS 16`: Set the batch size to 16.
55 |   - `--accumulate_gradients 2`: Make the effective batch size `16 * 2 = 32`.
56 |   - `--bert_type_abb uS`: Use the Uncased-Base BERT model. Use `uL` for Uncased-Large BERT.
57 |   - `--fine_tune`: Train BERT. Without this, only the sequence-to-SQL module is trained.
58 |   - `--lr 0.001`: Set the learning rate of the sequence-to-SQL module to 0.001.
59 |   - `--lr_bert 0.00001`: Set the learning rate of the BERT module to 0.00001.
60 |   - `--max_seq_length 222`: Set the maximum number of input tokens for BERT.
61 | - The model should reach ~79% logical form accuracy (lx) on the dev set after ~12 hrs (~10 epochs). Higher accuracy can be obtained with longer training, a different seed, the Uncased-Large BERT model, or execution-guided decoding.
62 | - Add the `--EG` argument when running `train.py` to use execution-guided decoding.
63 | - Whenever a higher logical form accuracy is obtained on the dev set, the following three files are saved in the current folder:
64 |   - `model_best.pt`: the checkpoint of the sequence-to-SQL module.
65 |   - `model_bert_best.pt`: the checkpoint of the BERT module.
66 |   - `results_dev.jsonl`: JSON file for the official evaluation.
67 | - `Shallow-Layer` and `Decoder-Layer` models can be trained similarly (`train_shallow_layer.py`, `train_decoder_layer.py`).
68 |
69 | #### Evaluation on WikiSQL DEV set
70 | - To calculate logical form and execution accuracies on the `dev` set using the official evaluation script:
71 |   - Download the original [WikiSQL dataset](https://github.com/salesforce/WikiSQL).
72 |   - `tar xvf data.tar.bz2`
73 |   - Move the extracted files under `$HOME/data/WikiSQL-1.1/data`.
74 |   - Set the paths in `evaluate_ws.py`. This is the file where the path information has been added to the original `evaluate.py` script. Alternatively, you can use the original [`evaluate.py`](https://github.com/salesforce/WikiSQL) and set the paths to the files yourself.
75 |   - Type `python3 evaluate_ws.py` in the terminal (a consolidated sketch of these steps is shown below).
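The dev-set evaluation steps above can be collected into a short shell session. This is a minimal sketch rather than an official script: it assumes the WikiSQL repository is cloned into the current working directory, that the tarball unpacks into a `data/` directory, and that training has already produced `results_dev.jsonl`; adjust the paths to match whatever you set inside `evaluate_ws.py`.

```sh
# Fetch the original WikiSQL data and unpack it (assumed to extract into ./data).
git clone https://github.com/salesforce/WikiSQL
tar xvf WikiSQL/data.tar.bz2

# Place the files where evaluate_ws.py expects them.
mkdir -p $HOME/data/WikiSQL-1.1
mv data $HOME/data/WikiSQL-1.1/data

# Compute logical form and execution accuracies for the saved dev predictions.
python3 evaluate_ws.py
```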
76 |
77 | #### Evaluation on WikiSQL TEST set
78 | - Uncomment lines 550-557 of `train.py` to load `test_loader` and `test_table`.
79 | - In the `test(...)` call, use `test_loader` and `test_table` instead of `dev_loader` and `dev_table`.
80 | - Save the output of `test(...)` with the `save_for_evaluation(...)` function.
81 | - Evaluate with `evaluate_ws.py` as before.
82 |
83 | #### Load pre-trained SQLova parameters
84 | - Pretrained SQLova model parameters are uploaded as a [release](https://github.com/naver/sqlova/releases). To start from these, uncomment lines 562-565 and set the paths.
85 |
86 |
87 | #### Code base
88 | - Pretrained BERT models were downloaded from the [official repository](https://github.com/google-research/bert).
89 | - The BERT code is from [huggingface-pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT).
90 | - The sequence-to-SQL model started from the source code of [SQLNet](https://github.com/xiaojunxu/SQLNet) and was significantly re-written while maintaining the basic column-attention and sequence-to-set structure of SQLNet.
91 |
92 | #### Data
93 | - The data is annotated using `annotate_ws.py`, which is based on [`annotate.py`](https://github.com/salesforce/WikiSQL) from the WikiSQL repository. The tokens of the natural language query, and the start and end indices of the where-condition values over those tokens, are annotated.
94 | - Pre-trained BERT parameters can be downloaded from the BERT [official repository](https://github.com/google-research/bert) and converted to a `.pt` file using the following script. You need to install both PyTorch and TensorFlow, and change `BERT_BASE_DIR` to your data directory.
95 |
96 | ```sh
97 | cd sqlova
98 | export BERT_BASE_DIR=data/uncased_L-12_H-768_A-12
99 | python bert/convert_tf_checkpoint_to_pytorch.py \
100 | --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
101 | --bert_config_file $BERT_BASE_DIR/bert_config.json \
102 | --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
103 | ```
104 |
105 | - `bert/convert_tf_checkpoint_to_pytorch.py` is from a previous version of [huggingface-pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT); the current version of `pytorch-pretrained-BERT` is not compatible with the BERT model used in this repo due to a difference in variable names (in LayerNorm). See [this issue](https://github.com/naver/sqlova/issues/1) for details.
106 | - For convenience, the annotated WikiSQL data and the PyTorch-converted pre-trained BERT parameters are available [here](https://drive.google.com/file/d/1iJvsf38f16el58H4NPINQ7uzal5-V4v4/view?usp=sharing).
107 |
108 | ### License
109 | ```
110 | Copyright 2019-present NAVER Corp.
111 |
112 | Licensed under the Apache License, Version 2.0 (the "License");
113 | you may not use this file except in compliance with the License.
114 | You may obtain a copy of the License at
115 |
116 | http://www.apache.org/licenses/LICENSE-2.0
117 |
118 | Unless required by applicable law or agreed to in writing, software
119 | distributed under the License is distributed on an "AS IS" BASIS,
120 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121 | See the License for the specific language governing permissions and
122 | limitations under the License.
123 | ``` 124 | -------------------------------------------------------------------------------- /wikisql/lib/query.py: -------------------------------------------------------------------------------- 1 | from wikisql.lib.common import detokenize 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | import re 5 | 6 | 7 | re_whitespace = re.compile(r'\s+', flags=re.UNICODE) 8 | 9 | 10 | class Query: 11 | 12 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 13 | cond_ops = ['=', '>', '<', 'OP'] 14 | syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS'] 15 | 16 | def __init__(self, sel_index, agg_index, conditions=tuple(), ordered=False): 17 | self.sel_index = sel_index 18 | self.agg_index = agg_index 19 | self.conditions = list(conditions) 20 | self.ordered = ordered 21 | 22 | def __eq__(self, other): 23 | if isinstance(other, self.__class__): 24 | indices = self.sel_index == other.sel_index and self.agg_index == other.agg_index 25 | if other.ordered: 26 | conds = [(col, op, str(cond).lower()) for col, op, cond in self.conditions] == [(col, op, str(cond).lower()) for col, op, cond in other.conditions] 27 | else: 28 | conds = set([(col, op, str(cond).lower()) for col, op, cond in self.conditions]) == set([(col, op, str(cond).lower()) for col, op, cond in other.conditions]) 29 | 30 | return indices and conds 31 | return NotImplemented 32 | 33 | def __ne__(self, other): 34 | if isinstance(other, self.__class__): 35 | return not self.__eq__(other) 36 | return NotImplemented 37 | 38 | def __hash__(self): 39 | return hash(tuple(sorted(self.__dict__.items()))) 40 | 41 | def __repr__(self): 42 | rep = 'SELECT {agg} {sel} FROM table'.format( 43 | agg=self.agg_ops[self.agg_index], 44 | sel='col{}'.format(self.sel_index), 45 | ) 46 | if self.conditions: 47 | rep += ' WHERE ' + ' AND '.join(['{} {} {}'.format('col{}'.format(i), self.cond_ops[o], v) for i, o, v in self.conditions]) 48 | return rep 49 | 50 | def to_dict(self): 51 | return {'sel': self.sel_index, 'agg': self.agg_index, 'conds': self.conditions} 52 | 53 | def lower(self): 54 | conds = [] 55 | for col, op, cond in self.conditions: 56 | conds.append([col, op, cond.lower()]) 57 | return self.__class__(self.sel_index, self.agg_index, conds) 58 | 59 | @classmethod 60 | def from_dict(cls, d, ordered=False): 61 | return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'], ordered=ordered) 62 | 63 | @classmethod 64 | def from_tokenized_dict(cls, d): 65 | conds = [] 66 | for col, op, val in d['conds']: 67 | conds.append([col, op, detokenize(val)]) 68 | return cls(d['sel'], d['agg'], conds) 69 | 70 | @classmethod 71 | def from_generated_dict(cls, d): 72 | conds = [] 73 | for col, op, val in d['conds']: 74 | end = len(val['words']) 75 | conds.append([col, op, detokenize(val)]) 76 | return cls(d['sel'], d['agg'], conds) 77 | 78 | @classmethod 79 | def from_sequence(cls, sequence, table, lowercase=True): 80 | sequence = deepcopy(sequence) 81 | if 'symend' in sequence['words']: 82 | end = sequence['words'].index('symend') 83 | for k, v in sequence.items(): 84 | sequence[k] = v[:end] 85 | terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])] 86 | headers = [detokenize(h) for h in table['header']] 87 | 88 | # lowercase everything and truncate sequence 89 | if lowercase: 90 | headers = [h.lower() for h in headers] 91 | for i, t in enumerate(terms): 92 | for k, v in 
t.items(): 93 | t[k] = v.lower() 94 | headers_no_whitespcae = [re.sub(re_whitespace, '', h) for h in headers] 95 | 96 | # get select 97 | if 'symselect' != terms.pop(0)['word']: 98 | raise Exception('Missing symselect operator') 99 | 100 | # get aggregation 101 | if 'symagg' != terms.pop(0)['word']: 102 | raise Exception('Missing symagg operator') 103 | agg_op = terms.pop(0)['word'] 104 | 105 | if agg_op == 'symcol': 106 | agg_op = '' 107 | else: 108 | if 'symcol' != terms.pop(0)['word']: 109 | raise Exception('Missing aggregation column') 110 | try: 111 | agg_op = cls.agg_ops.index(agg_op.upper()) 112 | except Exception as e: 113 | raise Exception('Invalid agg op {}'.format(agg_op)) 114 | 115 | def find_column(name): 116 | return headers_no_whitespcae.index(re.sub(re_whitespace, '', name)) 117 | 118 | def flatten(tokens): 119 | ret = {'words': [], 'after': [], 'gloss': []} 120 | for t in tokens: 121 | ret['words'].append(t['word']) 122 | ret['after'].append(t['after']) 123 | ret['gloss'].append(t['gloss']) 124 | return ret 125 | where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere'] 126 | where_index = where_index[0] if where_index else len(terms) 127 | flat = flatten(terms[:where_index]) 128 | try: 129 | agg_col = find_column(detokenize(flat)) 130 | except Exception as e: 131 | raise Exception('Cannot find aggregation column {}'.format(flat['words'])) 132 | where_terms = terms[where_index+1:] 133 | 134 | # get conditions 135 | conditions = [] 136 | while where_terms: 137 | t = where_terms.pop(0) 138 | flat = flatten(where_terms) 139 | if t['word'] != 'symcol': 140 | raise Exception('Missing conditional column {}'.format(flat['words'])) 141 | try: 142 | op_index = flat['words'].index('symop') 143 | col_tokens = flatten(where_terms[:op_index]) 144 | except Exception as e: 145 | raise Exception('Missing conditional operator {}'.format(flat['words'])) 146 | cond_op = where_terms[op_index+1]['word'] 147 | try: 148 | cond_op = cls.cond_ops.index(cond_op.upper()) 149 | except Exception as e: 150 | raise Exception('Invalid cond op {}'.format(cond_op)) 151 | try: 152 | cond_col = find_column(detokenize(col_tokens)) 153 | except Exception as e: 154 | raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) 155 | try: 156 | val_index = flat['words'].index('symcond') 157 | except Exception as e: 158 | raise Exception('Cannot find conditional value {}'.format(flat['words'])) 159 | 160 | where_terms = where_terms[val_index+1:] 161 | flat = flatten(where_terms) 162 | val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms) 163 | cond_val = detokenize(flatten(where_terms[:val_end_index])) 164 | conditions.append([cond_col, cond_op, cond_val]) 165 | where_terms = where_terms[val_end_index+1:] 166 | q = cls(agg_col, agg_op, conditions) 167 | return q 168 | 169 | @classmethod 170 | def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True): 171 | sequence = deepcopy(sequence) 172 | if 'symend' in sequence['words']: 173 | end = sequence['words'].index('symend') 174 | for k, v in sequence.items(): 175 | sequence[k] = v[:end] 176 | terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])] 177 | headers = [detokenize(h) for h in table['header']] 178 | 179 | # lowercase everything and truncate sequence 180 | if lowercase: 181 | headers = [h.lower() for h in headers] 182 | for i, t in enumerate(terms): 183 | for k, v in t.items(): 184 | t[k] = 
v.lower() 185 | headers_no_whitespcae = [re.sub(re_whitespace, '', h) for h in headers] 186 | 187 | def find_column(name): 188 | return headers_no_whitespcae.index(re.sub(re_whitespace, '', name)) 189 | 190 | def flatten(tokens): 191 | ret = {'words': [], 'after': [], 'gloss': []} 192 | for t in tokens: 193 | ret['words'].append(t['word']) 194 | ret['after'].append(t['after']) 195 | ret['gloss'].append(t['gloss']) 196 | return ret 197 | where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere'] 198 | where_index = where_index[0] if where_index else len(terms) 199 | where_terms = terms[where_index+1:] 200 | 201 | # get conditions 202 | conditions = [] 203 | while where_terms: 204 | t = where_terms.pop(0) 205 | flat = flatten(where_terms) 206 | if t['word'] != 'symcol': 207 | raise Exception('Missing conditional column {}'.format(flat['words'])) 208 | try: 209 | op_index = flat['words'].index('symop') 210 | col_tokens = flatten(where_terms[:op_index]) 211 | except Exception as e: 212 | raise Exception('Missing conditional operator {}'.format(flat['words'])) 213 | cond_op = where_terms[op_index+1]['word'] 214 | try: 215 | cond_op = cls.cond_ops.index(cond_op.upper()) 216 | except Exception as e: 217 | raise Exception('Invalid cond op {}'.format(cond_op)) 218 | try: 219 | cond_col = find_column(detokenize(col_tokens)) 220 | except Exception as e: 221 | raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) 222 | try: 223 | val_index = flat['words'].index('symcond') 224 | except Exception as e: 225 | raise Exception('Cannot find conditional value {}'.format(flat['words'])) 226 | 227 | where_terms = where_terms[val_index+1:] 228 | flat = flatten(where_terms) 229 | val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms) 230 | cond_val = detokenize(flatten(where_terms[:val_end_index])) 231 | conditions.append([cond_col, cond_op, cond_val]) 232 | where_terms = where_terms[val_end_index+1:] 233 | q = cls(agg_col, agg_op, conditions) 234 | return q 235 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /bert/LICENSE_bert: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 
51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | index = 0 73 | with open(vocab_file, "r", encoding="utf-8") as reader: 74 | while True: 75 | token = convert_to_unicode(reader.readline()) 76 | if not token: 77 | break 78 | token = token.strip() 79 | vocab[token] = index 80 | index += 1 81 | return vocab 82 | 83 | 84 | def convert_tokens_to_ids(vocab, tokens): 85 | """Converts a sequence of tokens into ids using the vocab.""" 86 | ids = [] 87 | for token in tokens: 88 | ids.append(vocab[token]) 89 | return ids 90 | 91 | 92 | def whitespace_tokenize(text): 93 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 94 | text = text.strip() 95 | if not text: 96 | return [] 97 | tokens = text.split() 98 | return tokens 99 | 100 | 101 | class FullTokenizer(object): 102 | """Runs end-to-end tokenziation.""" 103 | 104 | def __init__(self, vocab_file, do_lower_case=True): 105 | self.vocab = load_vocab(vocab_file) 106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 108 | 109 | def tokenize(self, text): 110 | split_tokens = [] 111 | for token in self.basic_tokenizer.tokenize(text): 112 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 113 | split_tokens.append(sub_token) 114 | 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | return convert_tokens_to_ids(self.vocab, tokens) 119 | 120 | 121 | class BasicTokenizer(object): 122 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 123 | 124 | def __init__(self, do_lower_case=True): 125 | """Constructs a BasicTokenizer. 126 | 127 | Args: 128 | do_lower_case: Whether to lower case the input. 129 | """ 130 | self.do_lower_case = do_lower_case 131 | 132 | def tokenize(self, text): 133 | """Tokenizes a piece of text.""" 134 | text = convert_to_unicode(text) 135 | text = self._clean_text(text) 136 | # This was added on November 1st, 2018 for the multilingual and Chinese 137 | # models. This is also applied to the English models now, but it doesn't 138 | # matter since the English models were not trained on any Chinese data 139 | # and generally don't have any Chinese data in them (there are Chinese 140 | # characters in the vocabulary because Wikipedia does have some Chinese 141 | # words in the English Wikipedia.). 
142 | text = self._tokenize_chinese_chars(text) 143 | orig_tokens = whitespace_tokenize(text) 144 | split_tokens = [] 145 | for token in orig_tokens: 146 | if self.do_lower_case: 147 | token = token.lower() 148 | token = self._run_strip_accents(token) 149 | split_tokens.extend(self._run_split_on_punc(token)) 150 | 151 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 152 | return output_tokens 153 | 154 | def _run_strip_accents(self, text): 155 | """Strips accents from a piece of text.""" 156 | text = unicodedata.normalize("NFD", text) 157 | output = [] 158 | for char in text: 159 | cat = unicodedata.category(char) 160 | if cat == "Mn": 161 | continue 162 | output.append(char) 163 | return "".join(output) 164 | 165 | def _run_split_on_punc(self, text): 166 | """Splits punctuation on a piece of text.""" 167 | chars = list(text) 168 | i = 0 169 | start_new_word = True 170 | output = [] 171 | while i < len(chars): 172 | char = chars[i] 173 | if _is_punctuation(char): 174 | output.append([char]) 175 | start_new_word = True 176 | else: 177 | if start_new_word: 178 | output.append([]) 179 | start_new_word = False 180 | output[-1].append(char) 181 | i += 1 182 | 183 | return ["".join(x) for x in output] 184 | 185 | def _tokenize_chinese_chars(self, text): 186 | """Adds whitespace around any CJK character.""" 187 | output = [] 188 | for char in text: 189 | cp = ord(char) 190 | if self._is_chinese_char(cp): 191 | output.append(" ") 192 | output.append(char) 193 | output.append(" ") 194 | else: 195 | output.append(char) 196 | return "".join(output) 197 | 198 | def _is_chinese_char(self, cp): 199 | """Checks whether CP is the codepoint of a CJK character.""" 200 | # This defines a "chinese character" as anything in the CJK Unicode block: 201 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 202 | # 203 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 204 | # despite its name. The modern Korean Hangul alphabet is a different block, 205 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 206 | # space-separated words, so they are not treated specially and handled 207 | # like the all of the other languages. 208 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 209 | (cp >= 0x3400 and cp <= 0x4DBF) or # 210 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 211 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 212 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 213 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 214 | (cp >= 0xF900 and cp <= 0xFAFF) or # 215 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 216 | return True 217 | 218 | return False 219 | 220 | def _clean_text(self, text): 221 | """Performs invalid character removal and whitespace cleanup on text.""" 222 | output = [] 223 | for char in text: 224 | cp = ord(char) 225 | if cp == 0 or cp == 0xfffd or _is_control(char): 226 | continue 227 | if _is_whitespace(char): 228 | output.append(" ") 229 | else: 230 | output.append(char) 231 | return "".join(output) 232 | 233 | 234 | class WordpieceTokenizer(object): 235 | """Runs WordPiece tokenization.""" 236 | 237 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 238 | self.vocab = vocab 239 | self.unk_token = unk_token 240 | self.max_input_chars_per_word = max_input_chars_per_word 241 | 242 | def tokenize(self, text): 243 | """Tokenizes a piece of text into its word pieces. 244 | 245 | This uses a greedy longest-match-first algorithm to perform tokenization 246 | using the given vocabulary. 
247 | 248 | For example: 249 | input = "unaffable" 250 | output = ["un", "##aff", "##able"] 251 | 252 | Args: 253 | text: A single token or whitespace separated tokens. This should have 254 | already been passed through `BasicTokenizer. 255 | 256 | Returns: 257 | A list of wordpiece tokens. 258 | """ 259 | 260 | text = convert_to_unicode(text) 261 | 262 | output_tokens = [] 263 | for token in whitespace_tokenize(text): 264 | chars = list(token) 265 | if len(chars) > self.max_input_chars_per_word: 266 | output_tokens.append(self.unk_token) 267 | continue 268 | 269 | is_bad = False 270 | start = 0 271 | sub_tokens = [] 272 | while start < len(chars): 273 | end = len(chars) 274 | cur_substr = None 275 | while start < end: 276 | substr = "".join(chars[start:end]) 277 | if start > 0: 278 | substr = "##" + substr 279 | if substr in self.vocab: 280 | cur_substr = substr 281 | break 282 | end -= 1 283 | if cur_substr is None: 284 | is_bad = True 285 | break 286 | sub_tokens.append(cur_substr) 287 | start = end 288 | 289 | if is_bad: 290 | output_tokens.append(self.unk_token) 291 | else: 292 | output_tokens.extend(sub_tokens) 293 | return output_tokens 294 | 295 | 296 | def _is_whitespace(char): 297 | """Checks whether `chars` is a whitespace character.""" 298 | # \t, \n, and \r are technically contorl characters but we treat them 299 | # as whitespace since they are generally considered as such. 300 | if char == " " or char == "\t" or char == "\n" or char == "\r": 301 | return True 302 | cat = unicodedata.category(char) 303 | if cat == "Zs": 304 | return True 305 | return False 306 | 307 | 308 | def _is_control(char): 309 | """Checks whether `chars` is a control character.""" 310 | # These are technically control characters but we count them as whitespace 311 | # characters. 312 | if char == "\t" or char == "\n" or char == "\r": 313 | return False 314 | cat = unicodedata.category(char) 315 | if cat.startswith("C"): 316 | return True 317 | return False 318 | 319 | 320 | def _is_punctuation(char): 321 | """Checks whether `chars` is a punctuation character.""" 322 | cp = ord(char) 323 | # We treat all non-letter/number ASCII as punctuation. 324 | # Characters such as "^", "$", and "`" are not in the Unicode 325 | # Punctuation class but we treat them as punctuation anyways, for 326 | # consistency. 
327 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 328 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 329 | return True 330 | cat = unicodedata.category(char) 331 | if cat.startswith("P"): 332 | return True 333 | return False 334 | -------------------------------------------------------------------------------- /bert/README_bert.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of Google AI's BERT model with a script to load Google's pre-trained models 2 | 3 | ## Forked for wikisql application 4 | 5 | ## NSML 6 | 7 | ### SQuAD1.1 finetuning 8 | 9 | ``` 10 | nsml run -d squad_bert -g 4 -e run_squad.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu" 11 | 12 | ``` 13 | 14 | ### SQuAD2.0 finetuning 15 | 16 | ``` 17 | nsml run -d squad_bert -g 4 -e run_squad2.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu" 18 | ``` 19 | 20 | ### Evaluation 21 | 22 | 1. Download the prediction file from the NSML session 23 | 24 | ``` 25 | nsml download -f /app/squad_base [NSML_ID]/squad_bert/[SESSION] . 26 | ``` 27 | 28 | 2. Run the official evaluation script 29 | 30 | ``` 31 | python3 evaluate-v1.1.py [dev.json] [predictions.json] 32 | 33 | python3 evaluate-v2.0.py [dev.json] [predictions.json] -n [na_probs.json] 34 | ``` 35 | 36 | ## Introduction 37 | 38 | This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 39 | 40 | This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)), and a conversion script is provided (see below). 41 | 42 | Support for [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated). 43 | 44 | ## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models)) 45 | 46 | You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) into a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script. 47 | 48 | This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model, and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`). 49 | 50 | You only need to run this conversion script **once** to get a PyTorch model.
You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`), but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`), as these are needed for the PyTorch model too. 51 | 52 | To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. 53 | 54 | Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: 55 | 56 | ```shell 57 | export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 58 | 59 | python convert_tf_checkpoint_to_pytorch.py \ 60 | --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ 61 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 62 | --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin 63 | ``` 64 | 65 | You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models). 66 | 67 | ## PyTorch models for BERT 68 | 69 | We include three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py): 70 | 71 | - `BertModel` - the basic BERT Transformer model 72 | - `BertForSequenceClassification` - the BERT model with a sequence classification head on top 73 | - `BertForQuestionAnswering` - the BERT model with a token classification head on top 74 | 75 | Here are some details on each class. 76 | 77 | ### 1. `BertModel` 78 | 79 | `BertModel` is the basic BERT Transformer model with a layer of summed token, position and segment embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large). 80 | 81 | The inputs and outputs are **identical to the TensorFlow model inputs and outputs**. 82 | 83 | We detail them here. This model takes as inputs: 84 | 85 | - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the token preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), 86 | - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token and type 1 corresponds to a `sentence B` token (see the BERT paper for more details), and 87 | - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. 88 | 89 | This model outputs a tuple composed of: 90 | 91 | - `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and 92 | - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated with the first token of the input (`[CLS]`) to train on the next-sentence prediction task (see the BERT paper). 93 | 94 | An example of how to use this class is given in the `extract_features.py` script, which can be used to extract the hidden states of the model for a given input. 95 |
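A minimal usage sketch of these inputs and outputs is shown below. It is not part of this repository's scripts: the file paths, the example sentence, and the single-sentence masking are placeholders, and the exact call signatures live in `modeling.py` and `tokenization.py` (the pattern mirrors the `get_bert()` helpers in `train_shallow_layer.py` and `train_decoder_layer.py`).

```python
import torch
from bert.modeling import BertConfig, BertModel
from bert.tokenization import FullTokenizer

# Build the model from a converted checkpoint (paths are placeholders).
config = BertConfig.from_json_file("bert_config.json")
model = BertModel(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Tokenize a single sentence ("sentence A" only, no padding in this toy case).
tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
tokens = ["[CLS]"] + tokenizer.tokenize("who wrote the book ?") + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])  # [1, seq_len]
token_type_ids = torch.zeros_like(input_ids)  # every token belongs to sentence A
attention_mask = torch.ones_like(input_ids)   # no padding, so attend everywhere

with torch.no_grad():
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
# all_encoder_layers: a list of [1, seq_len, hidden_size] tensors, one per layer
# pooled_output: a [1, hidden_size] tensor derived from the [CLS] position
```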
96 | ### 2. `BertForSequenceClassification` 97 | 98 | `BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`. 99 | 100 | The sequence-level classifier is a linear layer that takes as input the last hidden state of the first token (`[CLS]`) in the input sequence (see Figures 3a and 3b in the BERT paper). 101 | 102 | An example of how to use this class is given in the `run_classifier.py` script, which can be used to fine-tune a single sequence (or pair of sequences) classifier using BERT, for example for the MRPC task. 103 | 104 | ### 3. `BertForQuestionAnswering` 105 | 106 | `BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifier on top of the full sequence of last hidden states. 107 | 108 | The token-level classifier takes as input the full sequence of last hidden states and computes several (e.g. two) scores for each token; these can, for example, be the scores that a given token is a `start_span` or an `end_span` token (see Figures 3c and 3d in the BERT paper). 109 | 110 | An example of how to use this class is given in the `run_squad.py` script, which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task. 111 | 112 | ## Installation, requirements, test 113 | 114 | This code was tested on Python 3.5+. The requirements are: 115 | 116 | - PyTorch (>= 0.4.1) 117 | - tqdm 118 | 119 | To install the dependencies: 120 | 121 | ````bash 122 | pip install -r ./requirements.txt 123 | ```` 124 | 125 | A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). 126 | 127 | You can run the tests with the command: 128 | ```bash 129 | python -m pytest -sv tests/ 130 | ``` 131 | 132 | ## Training on large batches: gradient accumulation, multi-GPU and distributed training 133 | 134 | BERT-base and BERT-large are 110M- and 340M-parameter models respectively, and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32). 135 | 136 | To help with fine-tuning these models, we have included four techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: optimization on CPU, gradient accumulation, multi-GPU and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month. 137 | 138 | Here is how to use these techniques in our scripts: 139 | 140 | - **Optimize on CPU**: The Adam optimizer maintains 2 moving averages of all the weights of the model, which means that if you keep them on GPU 1 (the typical behavior), your first GPU will have to store 3 times the size of the model. This is not optimal when using a large model like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU to free more room on the GPU(s). As the most computationally intensive operation is the backward pass, this usually doesn't increase the computation time by a lot. This is the only way to fine-tune `BERT-large` in a reasonable time on GPU(s) (see below).
Activate this option with `--optimize_on_cpu` on the `run_squad.py` script. 141 | - **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps. 142 | - **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected, and the batches are split across the GPUs. 143 | - **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to 0 to the `--local_rank` argument. To use distributed training, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see the above blog post for more details): 144 | 145 | ```bash 146 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script) 147 | ``` 148 | 149 | Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...), and the machine with rank 0 has the IP address `192.168.1.1` and an open port `1234`. 150 | 151 | ## TPU support and pretraining scripts 152 | 153 | TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPUs and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)). 154 | 155 | We will add TPU support when this next release is published. 156 | 157 | The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py). 158 | 159 | Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to be completed in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch to convert these pre-training scripts. 160 | 161 | ## Comparing the PyTorch model and the TensorFlow model predictions 162 | 163 | We also include [two Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model. 164 | 165 | - The first NoteBook ([Comparing TF and PT models.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models.ipynb)) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models.
166 | 167 | - The second NoteBook ([Comparing TF and PT models SQuAD predictions.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models%20SQuAD%20predictions.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models. 168 | 169 | Please follow the instructions given in the notebooks to run and modify them. They are also nice examples of how to use the models in a simpler way than the full fine-tuning scripts we provide. 170 | 171 | ## Fine-tuning with BERT: running the examples 172 | 173 | We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD. 174 | 175 | Before running these examples you should download the 176 | [GLUE data](https://gluebenchmark.com/tasks) by running 177 | [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) 178 | and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base` 179 | checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section. 180 | 181 | This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase 182 | Corpus (MRPC) and runs in less than 10 minutes on a single K-80. 183 | 184 | ```shell 185 | export GLUE_DIR=/path/to/glue 186 | 187 | python run_classifier.py \ 188 | --task_name MRPC \ 189 | --do_train \ 190 | --do_eval \ 191 | --do_lower_case \ 192 | --data_dir $GLUE_DIR/MRPC/ \ 193 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 194 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 195 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 196 | --max_seq_length 128 \ 197 | --train_batch_size 32 \ 198 | --learning_rate 2e-5 \ 199 | --num_train_epochs 3.0 \ 200 | --output_dir /tmp/mrpc_output/ 201 | ``` 202 | 203 | Our tests, run on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks), gave evaluation results between 84% and 88%. 204 | 205 | The second example fine-tunes `BERT-Base` on the SQuAD question answering task. 206 | 207 | The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
208 | 209 | * [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json) 210 | * [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json) 211 | * [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py) 212 | 213 | ```shell 214 | export SQUAD_DIR=/path/to/SQUAD 215 | 216 | python run_squad.py \ 217 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 218 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 219 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 220 | --do_train \ 221 | --do_predict \ 222 | --do_lower_case \ 223 | --train_file $SQUAD_DIR/train-v1.1.json \ 224 | --predict_file $SQUAD_DIR/dev-v1.1.json \ 225 | --train_batch_size 12 \ 226 | --learning_rate 3e-5 \ 227 | --num_train_epochs 2.0 \ 228 | --max_seq_length 384 \ 229 | --doc_stride 128 \ 230 | --output_dir ../debug_squad/ 231 | ``` 232 | 233 | Training with the previous hyper-parameters gave us the following results: 234 | ```bash 235 | {"f1": 88.52381567990474, "exact_match": 81.22043519394512} 236 | ``` 237 | 238 | # Fine-tuning BERT-large on GPUs 239 | 240 | The options listed above allow you to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation. 241 | 242 | For example, fine-tuning BERT-large on SQuAD can be done on a server with 4 K-80s (these are pretty old now) in 18 hours. Our results are similar to the TensorFlow implementation results (actually slightly higher): 243 | ```bash 244 | {"exact_match": 84.56953642384106, "f1": 91.04028647786927} 245 | ``` 246 | To get these results we used a combination of: 247 | - multi-GPU training (automatically activated on a multi-GPU server), 248 | - 2 steps of gradient accumulation (sketched below), and 249 | - performing the optimization step on CPU to store Adam's averages in RAM.
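If you have not used gradient accumulation before, the pattern behind `--gradient_accumulation_steps` (and the `--accumulate_gradients` flag of the training scripts in this repository) boils down to the following minimal, self-contained sketch; the toy model and random data are placeholders, not the actual fine-tuning loop.

```python
import torch
import torch.nn as nn

# Toy setup, just to make the accumulation pattern concrete.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]  # mini-batches of size 4

gradient_accumulation_steps = 2  # effective batch size = 4 * 2 = 8

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)
    loss = loss / gradient_accumulation_steps   # keep gradient magnitudes comparable
    loss.backward()                             # gradients accumulate in the .grad buffers
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()                        # one parameter update per accumulated "big batch"
        optimizer.zero_grad()
```

Memory stays at the mini-batch level while each update behaves approximately like one taken on the larger accumulated batch.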
250 | 251 | Here are the full list of hyper-parameters we used for this run: 252 | ```bash 253 | python ./run_squad.py --vocab_file $BERT_LARGE_DIR/vocab.txt --bert_config_file $BERT_LARGE_DIR/bert_config.json --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin --do_lower_case --do_train --do_predict --train_file $SQUAD_TRAIN --predict_file $SQUAD_EVAL --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir $OUTPUT_DIR/bert_large_bsz_24 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu 254 | ``` 255 | 256 | -------------------------------------------------------------------------------- /train_decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Wonseok Hwang 2 | # Sep30, 2018 3 | import os, sys, argparse, re, json 4 | import random as python_random 5 | 6 | from matplotlib.pylab import * 7 | import torch.nn as nn 8 | import torch 9 | import torch.nn.functional as F 10 | # import torchvision.datasets as dsets 11 | 12 | # BERT 13 | import bert.tokenization as tokenization 14 | from bert.modeling import BertConfig, BertModel 15 | 16 | from sqlova.utils.utils_wikisql import * 17 | from sqlova.model.nl2sql.wikisql_models import * 18 | from sqlnet.dbengine import DBEngine 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | def construct_hyper_param(parser): 24 | parser.add_argument('--tepoch', default=200, type=int) 25 | parser.add_argument("--bS", default=4, type=int, 26 | help="Batch size") 27 | parser.add_argument("--accumulate_gradients", default=8, type=int, 28 | help="The number of accumulation of backpropagation to effectivly increase the batch size.") 29 | parser.add_argument('--fine_tune', 30 | default=True, 31 | action='store_true', 32 | help="If present, BERT is trained.") 33 | 34 | parser.add_argument("--model_type", default='FT_s2s_1', type=str, 35 | help="Type of model.") 36 | 37 | parser.add_argument('--aug', 38 | default=False, 39 | action='store_true', 40 | help="If present, aug.train.jsonl is used.") 41 | 42 | # 1.2 BERT Parameters 43 | parser.add_argument("--vocab_file", 44 | default='vocab.txt', type=str, 45 | help="The vocabulary file that the BERT model was trained on.") 46 | parser.add_argument("--max_seq_length", 47 | default=270, type=int, # Set based on maximum length of input tokens. 48 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 49 | "longer than this will be truncated, and sequences shorter than this will be padded.") 50 | parser.add_argument("--num_target_layers", 51 | default=1, type=int, 52 | help="The Number of final layers of BERT to be used in downstream task.") 53 | parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.') 54 | parser.add_argument('--seed', 55 | type=int, 56 | default=42, 57 | help="random seed for initialization") 58 | parser.add_argument('--no_pretraining', action='store_true', help='Use BERT pretrained model') 59 | parser.add_argument("--bert_type_abb", default='uS', type=str, 60 | help="Type of BERT model to load. e.g.) uS, uL, cS, cL, and mcS") 61 | parser.add_argument("--col_pool_type", default='start_tok', type=str, 62 | help="Which col-token shall be used? 
start_tok, end_tok, or avg are possible choices.") 63 | 64 | # 1.3 Seq-to-SQL module parameters 65 | parser.add_argument('--lS', default=2, type=int, help="The number of LSTM layers.") 66 | parser.add_argument('--dr', default=0.3, type=float, help="Dropout rate.") 67 | parser.add_argument('--lr', default=1e-3, type=float, help="Learning rate.") 68 | parser.add_argument("--hS", default=100, type=int, help="The dimension of hidden vector in the seq-to-SQL module.") 69 | 70 | 71 | # 1.4 Execution-guided decoding beam-size. It is used only in test.py 72 | parser.add_argument('--EG', 73 | default=False, 74 | action='store_true', 75 | help="If present, Execution guided decoding is used in test.") 76 | 77 | parser.add_argument('--beam_only', 78 | default=False, 79 | action='store_true', 80 | help="If present, no Execution guided while doing beam-searching.") 81 | 82 | parser.add_argument('--beam_size', 83 | type=int, 84 | default=4, 85 | help="The size of beam for smart decoding") 86 | 87 | # 1.5 S2S model 88 | 89 | parser.add_argument('--sql_vocab_type', 90 | type=int, 91 | default=0, 92 | help="Sql-vocab type") 93 | 94 | # 1.5 Arguments only for test.py 95 | parser.add_argument('--sn', default=42, type=int, help="The targetting session number.") 96 | parser.add_argument("--target_epoch", default='best', type=str, 97 | help="Targer epoch (the save name from nsml).") 98 | 99 | parser.add_argument("--tag", default='', type=str, 100 | help="Tag of saved files. e.g.) '', 'FT1', 'FT1_aug', 'no_pretraining', 'no_tuning',..") 101 | 102 | args = parser.parse_args() 103 | assert args.sql_vocab_type == 0 # type 0 is better than type 1 slightly.. although there seems to be some statistical fluctuation. 104 | 105 | map_bert_type_abb = {'uS': 'uncased_L-12_H-768_A-12', 106 | 'uL': 'uncased_L-24_H-1024_A-16', 107 | 'cS': 'cased_L-12_H-768_A-12', 108 | 'cL': 'cased_L-24_H-1024_A-16', 109 | 'mcS': 'multi_cased_L-12_H-768_A-12'} 110 | args.bert_type = map_bert_type_abb[args.bert_type_abb] 111 | print(f"BERT-type: {args.bert_type}") 112 | 113 | sql_vocab_list = [ 114 | ( 115 | "none", "max", "min", "count", "sum", "average", 116 | "select", "where", "and", 117 | "equal", "greater than", "less than", 118 | "start", "end" 119 | ), 120 | 121 | ( 122 | "sql none", "sql max", "sql min", "sql count", "sql sum", "sql average", 123 | "sql select", "sql where", "sql and", 124 | "sql equal", "sql greater than", "sql less than", 125 | "sql start", "sql end" 126 | ) 127 | ] 128 | args.sql_vocab = sql_vocab_list[args.sql_vocab_type] 129 | 130 | 131 | # 132 | # Decide whether to use lower_case. 133 | if args.bert_type_abb == 'cS' or args.bert_type_abb == 'cL' or args.bert_type_abb == 'mcS': 134 | args.do_lower_case = False 135 | else: 136 | args.do_lower_case = True 137 | 138 | # args.toy_model = not torch.cuda.is_available() 139 | args.toy_model = False 140 | args.toy_size = 32 141 | 142 | if args.model_type == 'FT_s2s_1': 143 | assert args.num_target_layers == 1 144 | assert args.fine_tune == True 145 | 146 | # Seeds for random number generation. 
147 | seed(args.seed) 148 | python_random.seed(args.seed) 149 | np.random.seed(args.seed) 150 | torch.manual_seed(args.seed) 151 | if torch.cuda.is_available(): 152 | torch.cuda.manual_seed_all(args.seed) 153 | 154 | 155 | 156 | return args 157 | 158 | 159 | def get_bert(BERT_PT_PATH, bert_type, do_lower_case, no_pretraining): 160 | 161 | 162 | bert_config_file = os.path.join(BERT_PT_PATH, f'bert_config_{bert_type}.json') 163 | vocab_file = os.path.join(BERT_PT_PATH, f'vocab_{bert_type}.txt') 164 | init_checkpoint = os.path.join(BERT_PT_PATH, f'pytorch_model_{bert_type}.bin') 165 | 166 | 167 | 168 | bert_config = BertConfig.from_json_file(bert_config_file) 169 | tokenizer = tokenization.FullTokenizer( 170 | vocab_file=vocab_file, do_lower_case=do_lower_case) 171 | bert_config.print_status() 172 | 173 | model_bert = BertModel(bert_config) 174 | if no_pretraining: 175 | pass 176 | else: 177 | model_bert.load_state_dict(torch.load(init_checkpoint, map_location='cpu')) 178 | print("Load pre-trained parameters.") 179 | model_bert.to(device) 180 | 181 | return model_bert, tokenizer, bert_config 182 | 183 | def get_opt(model, model_bert, model_type): 184 | # if model_type == 'FT_Scalar_1': 185 | # # Model itself does not have trainable parameters. Thus, 186 | # opt_bert = torch.optim.Adam(list(filter(lambda p: p.requires_grad, model.parameters())) \ 187 | # # + list(model_bert.parameters()), 188 | # + list(filter(lambda p: p.requires_grad, model_bert.parameters())), 189 | # lr=args.lr, weight_decay=0) 190 | # opt = opt_bert # for consistency in interface 191 | if model_type == 'FT_s2s_1': 192 | opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 193 | lr=args.lr, weight_decay=0) 194 | 195 | opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()), 196 | lr=args.lr_bert, weight_decay=0) 197 | # opt = opt_bert 198 | else: 199 | raise NotImplementedError 200 | # opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 201 | # lr=args.lr, weight_decay=0) 202 | # 203 | # opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()), 204 | # lr=args.lr_bert, weight_decay=0) 205 | 206 | return opt, opt_bert 207 | 208 | def get_models(args, BERT_PT_PATH, trained=False, path_model_bert=None, path_model=None): 209 | # some constants 210 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 211 | cond_ops = ['=', '>', '<', 'OP'] # do not know why 'OP' required. 
Hence, 212 | 213 | print(f"Batch_size = {args.bS * args.accumulate_gradients}") 214 | print(f"BERT parameters:") 215 | print(f"learning rate: {args.lr_bert}") 216 | print(f"Fine-tune BERT: {args.fine_tune}") 217 | 218 | # Get BERT 219 | model_bert, tokenizer, bert_config = get_bert(BERT_PT_PATH, args.bert_type, args.do_lower_case, 220 | args.no_pretraining) 221 | args.iS = bert_config.hidden_size * args.num_target_layers # Seq-to-SQL input vector dimenstion 222 | 223 | # Get Seq-to-SQL 224 | 225 | n_cond_ops = len(cond_ops) 226 | n_agg_ops = len(agg_ops) 227 | print(f"Seq-to-SQL: the number of final BERT layers to be used: {args.num_target_layers}") 228 | print(f"Seq-to-SQL: the size of hidden dimension = {args.hS}") 229 | print(f"Seq-to-SQL: LSTM encoding layer size = {args.lS}") 230 | print(f"Seq-to-SQL: dropout rate = {args.dr}") 231 | print(f"Seq-to-SQL: learning rate = {args.lr}") 232 | model = FT_s2s_1(args.iS, args.hS, args.lS, args.dr, args.max_seq_length, n_cond_ops, n_agg_ops) 233 | model = model.to(device) 234 | 235 | if trained: 236 | assert path_model_bert != None 237 | assert path_model != None 238 | 239 | if torch.cuda.is_available(): 240 | res = torch.load(path_model_bert) 241 | else: 242 | res = torch.load(path_model_bert, map_location='cpu') 243 | model_bert.load_state_dict(res['model_bert']) 244 | model_bert.to(device) 245 | 246 | if torch.cuda.is_available(): 247 | res = torch.load(path_model) 248 | else: 249 | res = torch.load(path_model, map_location='cpu') 250 | 251 | model.load_state_dict(res['model']) 252 | 253 | return model, model_bert, tokenizer, bert_config 254 | 255 | def get_data(path_wikisql, args): 256 | train_data, train_table, dev_data, dev_table, _, _ = load_wikisql(path_wikisql, args.toy_model, args.toy_size, 257 | no_w2i=True, no_hs_tok=True, 258 | aug=args.aug) 259 | train_loader, dev_loader = get_loader_wikisql(train_data, dev_data, args.bS, shuffle_train=True) 260 | 261 | return train_data, train_table, dev_data, dev_table, train_loader, dev_loader 262 | 263 | 264 | def train(train_loader, train_table, model, model_bert, opt, tokenizer,sql_vocab, 265 | max_seq_length, accumulate_gradients=1, check_grad=False, 266 | st_pos=0, opt_bert=None, path_db=None, dset_name='train', col_pool_type='start_tok', aug=False): 267 | model.train() 268 | model_bert.train() 269 | 270 | ave_loss = 0 271 | cnt = 0 # count the # of examples 272 | cnt_x = 0 273 | cnt_lx = 0 # of logical form acc 274 | 275 | # Engine for SQL querying. 276 | engine = DBEngine(os.path.join(path_db, f"{dset_name}.db")) 277 | 278 | for iB, t in enumerate(train_loader): 279 | cnt += len(t) 280 | 281 | if cnt < st_pos: 282 | continue 283 | # Get fields 284 | nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, train_table, no_hs_t=True, no_sql_t=True) 285 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i) 286 | # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset. 287 | g_wvi_corenlp = get_g_wvi_corenlp(t) 288 | 289 | 290 | # g_wvi_corenlp = get_g_wvi_corenlp(t) 291 | all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, i_sql_vocab, \ 292 | l_n, l_hpu, l_hs, l_input, \ 293 | nlu_tt, t_to_tt_idx, tt_to_t_idx \ 294 | = get_bert_output_s2s(model_bert, tokenizer, nlu_t, hds, sql_vocab, max_seq_length) 295 | 296 | try: 297 | # 298 | g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp) 299 | except: 300 | # Exception happens when where-condition is not found in nlu_tt. 301 | # In this case, that train example is not used. 
302 | # During test, that example considered as wrongly answered. 303 | # e.g. train: 32. 304 | continue 305 | 306 | 307 | # Generate g_pnt_idx 308 | g_pnt_idxs = gen_g_pnt_idx(g_wvi, sql_i, i_hds, i_sql_vocab, col_pool_type=col_pool_type) 309 | pnt_start_tok = i_sql_vocab[0][-2][0] 310 | pnt_end_tok = i_sql_vocab[0][-1][0] 311 | # check 312 | # print(array(tokens[0])[g_pnt_idxs[0]]) 313 | wenc_s2s = all_encoder_layer[-1] 314 | 315 | # wemb_h = [B, max_header_number, hS] 316 | cls_vec = pooled_output 317 | 318 | score = model(wenc_s2s, l_input, cls_vec, pnt_start_tok, g_pnt_idxs=g_pnt_idxs) 319 | 320 | 321 | # Calculate loss & step 322 | loss = Loss_s2s(score, g_pnt_idxs) 323 | 324 | # Calculate gradient 325 | if iB % accumulate_gradients == 0: # mode 326 | # at start, perform zero_grad 327 | opt.zero_grad() 328 | opt_bert.zero_grad() 329 | loss.backward() 330 | if accumulate_gradients == 1: 331 | opt.step() 332 | opt_bert.step() 333 | elif iB % accumulate_gradients == (accumulate_gradients-1): 334 | # at the final, take step with accumulated graident 335 | loss.backward() 336 | opt.step() 337 | opt_bert.step() 338 | else: 339 | # at intermediate stage, just accumulates the gradients 340 | loss.backward() 341 | 342 | if check_grad: 343 | named_parameters = model.named_parameters() 344 | 345 | mu_list, sig_list = get_mean_grad(named_parameters) 346 | 347 | grad_abs_mean_mean = mean(mu_list) 348 | grad_abs_mean_sig = std(mu_list) 349 | grad_abs_sig_mean = mean(sig_list) 350 | else: 351 | grad_abs_mean_mean = 1 352 | grad_abs_mean_sig = 1 353 | grad_abs_sig_mean = 1 354 | 355 | # Prediction 356 | pr_pnt_idxs = pred_pnt_idxs(score, pnt_start_tok, pnt_end_tok) 357 | # generate pr_sql_q 358 | # pr_sql_q_rough = generate_sql_q_s2s(pr_pnt_idxs, tokens, tb) 359 | # g_sql_q_rough = generate_sql_q_s2s(g_pnt_idxs, tokens, tb) 360 | 361 | g_i_vg_list, g_i_vg_sub_list = gen_i_vg_from_pnt_idxs(g_pnt_idxs, i_sql_vocab, i_nlu, i_hds) 362 | 363 | g_sql_q_s2s, g_sql_i = gen_sql_q_from_i_vg(tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok, g_pnt_idxs, g_i_vg_list, 364 | g_i_vg_sub_list) 365 | 366 | pr_i_vg_list, pr_i_vg_sub_list = gen_i_vg_from_pnt_idxs(pr_pnt_idxs, i_sql_vocab, i_nlu, i_hds) 367 | 368 | pr_sql_q_s2s, pr_sql_i = gen_sql_q_from_i_vg(tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok, 369 | pr_pnt_idxs, pr_i_vg_list, pr_i_vg_sub_list) 370 | 371 | g_sql_q = generate_sql_q(sql_i, tb) 372 | 373 | try: 374 | pr_sql_q = generate_sql_q(pr_sql_i, tb) 375 | # gen pr_sc, pr_sa 376 | pr_sc = [] 377 | pr_sa = [] 378 | for pr_sql_i1 in pr_sql_i: 379 | pr_sc.append(pr_sql_i1["sel"]) 380 | pr_sa.append(pr_sql_i1["agg"]) 381 | except: 382 | bS = len(sql_i) 383 | pr_sql_q = ['NA'] * bS 384 | pr_sc = ['NA'] * bS 385 | pr_sa = ['NA'] * bS 386 | 387 | 388 | # Cacluate accuracy 389 | cnt_lx1_list = get_cnt_lx_list_s2s(g_pnt_idxs, pr_pnt_idxs) 390 | 391 | if not aug: 392 | cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i) 393 | else: 394 | cnt_x1_list = [0] * len(t) 395 | g_ans = ['N/A (data augmented'] * len(t) 396 | pr_ans = ['N/A (data augmented'] * len(t) 397 | 398 | 399 | # statistics 400 | ave_loss += loss.item() 401 | 402 | # count 403 | cnt_lx += sum(cnt_lx1_list) 404 | cnt_x += sum(cnt_x1_list) 405 | 406 | ave_loss /= cnt 407 | acc_lx = cnt_lx / cnt 408 | acc_x = cnt_x / cnt 409 | 410 | acc = [ave_loss, acc_lx, acc_x] 411 | aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean] 412 | 413 | return acc, aux_out 414 | 
415 | def report_detail(hds, nlu, 416 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_str, g_sql_q, g_ans, 417 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans, 418 | cnt_list, current_cnt): 419 | cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x = current_cnt 420 | 421 | print(f'cnt = {cnt} / {cnt_tot} ===============================') 422 | 423 | print(f'headers: {hds}') 424 | print(f'nlu: {nlu}') 425 | 426 | # print(f's_sc: {s_sc[0]}') 427 | # print(f's_sa: {s_sa[0]}') 428 | # print(f's_wn: {s_wn[0]}') 429 | # print(f's_wc: {s_wc[0]}') 430 | # print(f's_wo: {s_wo[0]}') 431 | # print(f's_wv: {s_wv[0][0]}') 432 | print(f'===============================') 433 | print(f'g_sc : {g_sc}') 434 | print(f'pr_sc: {pr_sc}') 435 | print(f'g_sa : {g_sa}') 436 | print(f'pr_sa: {pr_sa}') 437 | print(f'g_wn : {g_wn}') 438 | print(f'pr_wn: {pr_wn}') 439 | print(f'g_wc : {g_wc}') 440 | print(f'pr_wc: {pr_wc}') 441 | print(f'g_wo : {g_wo}') 442 | print(f'pr_wo: {pr_wo}') 443 | print(f'g_wv : {g_wv}') 444 | # print(f'pr_wvi: {pr_wvi}') 445 | print('g_wv_str:', g_wv_str) 446 | print('p_wv_str:', pr_wv_str) 447 | print(f'g_sql_q: {g_sql_q}') 448 | print(f'pr_sql_q: {pr_sql_q}') 449 | print(f'g_ans: {g_ans}') 450 | print(f'pr_ans: {pr_ans}') 451 | print(f'--------------------------------') 452 | 453 | print(cnt_list) 454 | 455 | print(f'acc_lx = {cnt_lx/cnt:.3f}, acc_x = {cnt_x/cnt:.3f}\n', 456 | f'acc_sc = {cnt_sc/cnt:.3f}, acc_sa = {cnt_sa/cnt:.3f}, acc_wn = {cnt_wn/cnt:.3f}\n', 457 | f'acc_wc = {cnt_wc/cnt:.3f}, acc_wo = {cnt_wo/cnt:.3f}, acc_wv = {cnt_wv/cnt:.3f}') 458 | print(f'===============================') 459 | 460 | def test(data_loader, data_table, model, model_bert, tokenizer, sql_vocab, 461 | max_seq_length, 462 | detail=False, st_pos=0, cnt_tot=1, EG=False, beam_only=True, beam_size=4, 463 | path_db=None, dset_name='test', col_pool_type='start_tok', aug=False, 464 | ): 465 | model.eval() 466 | model_bert.eval() 467 | 468 | ave_loss = 0 469 | cnt = 0 470 | cnt_lx = 0 471 | cnt_x = 0 472 | results = [] 473 | cnt_list = [] 474 | 475 | engine = DBEngine(os.path.join(path_db, f"{dset_name}.db")) 476 | 477 | for iB, t in enumerate(data_loader): 478 | 479 | cnt += len(t) 480 | if cnt < st_pos: 481 | continue 482 | # Get fields 483 | nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True) 484 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i) 485 | g_wvi_corenlp = get_g_wvi_corenlp(t) 486 | 487 | all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, i_sql_vocab, \ 488 | l_n, l_hpu, l_hs, l_input, \ 489 | nlu_tt, t_to_tt_idx, tt_to_t_idx \ 490 | = get_bert_output_s2s(model_bert, tokenizer, nlu_t, hds, sql_vocab, max_seq_length) 491 | try: 492 | # 493 | g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp) 494 | except: 495 | # Exception happens when where-condition is not found in nlu_tt. 496 | # In this case, that train example is not used. 497 | # During test, that example considered as wrongly answered. 498 | # e.g. train: 32. 
499 | for b in range(len(nlu)): 500 | results1 = {} 501 | results1["error"] = "Skip happened" 502 | results1["nlu"] = nlu[b] 503 | results1["table_id"] = tb[b]["id"] 504 | results.append(results1) 505 | continue 506 | 507 | # Generate g_pnt_idx 508 | g_pnt_idxs = gen_g_pnt_idx(g_wvi, sql_i, i_hds, i_sql_vocab, col_pool_type=col_pool_type) 509 | pnt_start_tok = i_sql_vocab[0][-2][0] 510 | pnt_end_tok = i_sql_vocab[0][-1][0] 511 | # check 512 | # print(array(tokens[0])[g_pnt_idxs[0]]) 513 | wenc_s2s = all_encoder_layer[-1] 514 | 515 | # wemb_h = [B, max_header_number, hS] 516 | cls_vec = pooled_output 517 | 518 | if not EG: 519 | score = model(wenc_s2s, l_input, cls_vec, pnt_start_tok,) 520 | loss = Loss_s2s(score, g_pnt_idxs) 521 | 522 | pr_pnt_idxs = pred_pnt_idxs(score, pnt_start_tok, pnt_end_tok) 523 | else: 524 | # EG 525 | pr_pnt_idxs, p_list, pnt_list_beam = model.EG_forward(wenc_s2s, l_input, cls_vec, 526 | pnt_start_tok, pnt_end_tok, 527 | i_sql_vocab, i_nlu, i_hds, # for EG 528 | tokens, nlu, nlu_t, hds, tt_to_t_idx, # for EG 529 | tb, engine, 530 | beam_size, beam_only=beam_only) 531 | if beam_only: 532 | loss = torch.tensor([0]) 533 | else: 534 | # print('EG on!') 535 | loss = torch.tensor([1]) 536 | 537 | 538 | g_i_vg_list, g_i_vg_sub_list = gen_i_vg_from_pnt_idxs(g_pnt_idxs, i_sql_vocab, i_nlu, i_hds) 539 | g_sql_q_s2s, g_sql_i = gen_sql_q_from_i_vg(tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok, 540 | g_pnt_idxs, g_i_vg_list, 541 | g_i_vg_sub_list) 542 | 543 | pr_i_vg_list, pr_i_vg_sub_list = gen_i_vg_from_pnt_idxs(pr_pnt_idxs, i_sql_vocab, i_nlu, i_hds) 544 | pr_sql_q_s2s, pr_sql_i = gen_sql_q_from_i_vg(tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok, 545 | pr_pnt_idxs, pr_i_vg_list, pr_i_vg_sub_list) 546 | 547 | g_sql_q = generate_sql_q(sql_i, tb) 548 | 549 | try: 550 | pr_sql_q = generate_sql_q(pr_sql_i, tb) 551 | # gen pr_sc, pr_sa 552 | pr_sc = [] 553 | pr_sa = [] 554 | for pr_sql_i1 in pr_sql_i: 555 | pr_sc.append(pr_sql_i1["sel"]) 556 | pr_sa.append(pr_sql_i1["agg"]) 557 | except: 558 | bS = len(sql_i) 559 | pr_sql_q = ['NA'] * bS 560 | pr_sc = ['NA'] * bS 561 | pr_sa = ['NA'] * bS 562 | 563 | for b, pr_sql_i1 in enumerate(pr_sql_i): 564 | results1 = {} 565 | results1["query"] = pr_sql_i1 566 | results1["table_id"] = tb[b]["id"] 567 | results1["nlu"] = nlu[b] 568 | results.append(results1) 569 | 570 | # Cacluate accuracy 571 | cnt_lx1_list = get_cnt_lx_list_s2s(g_pnt_idxs, pr_pnt_idxs) 572 | 573 | if not aug: 574 | cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i) 575 | else: 576 | cnt_x1_list = [0] * len(t) 577 | g_ans = ['N/A (data augmented'] * len(t) 578 | pr_ans = ['N/A (data augmented'] * len(t) 579 | 580 | # statistics 581 | ave_loss += loss.item() 582 | 583 | # count 584 | cnt_lx += sum(cnt_lx1_list) 585 | cnt_x += sum(cnt_x1_list) 586 | # report 587 | if detail: 588 | print(f"Ground T : {g_pnt_idxs}") 589 | print(f"Prediction: {pr_pnt_idxs}") 590 | print(f"Ground T : {g_sql_q}") 591 | print(f"Prediction: {pr_sql_q}") 592 | 593 | 594 | ave_loss /= cnt 595 | 596 | acc_lx = cnt_lx / cnt 597 | acc_x = cnt_x / cnt 598 | 599 | acc = [ave_loss, acc_lx, acc_x] 600 | return acc, results 601 | 602 | 603 | def print_result(epoch, acc, dname): 604 | ave_loss, acc_lx, acc_x = acc 605 | 606 | print(f'{dname} results ------------') 607 | print( 608 | f" Epoch: {epoch}, ave loss: {ave_loss}, acc_lx: {acc_lx:.3f}, acc_x: {acc_x:.3f}" 609 | ) 610 | 611 | if __name__ == '__main__': 612 | 613 | ## 
1. Hyper parameters 614 | parser = argparse.ArgumentParser() 615 | args = construct_hyper_param(parser) 616 | 617 | ## 2. Paths 618 | path_h = '/home/wonseok' 619 | path_wikisql = os.path.join(path_h, 'data', 'wikisql_tok') 620 | BERT_PT_PATH = path_wikisql 621 | 622 | path_save_for_evaluation = './' 623 | 624 | ## 3. Load data 625 | train_data, train_table, dev_data, dev_table, train_loader, dev_loader = get_data(path_wikisql, args) 626 | 627 | ## 4. Build & Load models 628 | model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH) 629 | 630 | 631 | ## 5. Get optimizers 632 | opt, opt_bert = get_opt(model, model_bert, args.model_type) 633 | 634 | ## 6. Train 635 | acc_lx_t_best = -1 636 | epoch_best = -1 637 | for epoch in range(args.tepoch): 638 | # train 639 | acc_train, aux_out_train = train(train_loader, 640 | train_table, 641 | model, 642 | model_bert, 643 | opt, 644 | tokenizer, 645 | args.sql_vocab, 646 | args.max_seq_length, 647 | args.accumulate_gradients, 648 | opt_bert=opt_bert, 649 | st_pos=0, 650 | path_db=path_wikisql, 651 | dset_name='train', 652 | col_pool_type=args.col_pool_type, 653 | aug=args.aug) 654 | 655 | # check DEV 656 | with torch.no_grad(): 657 | acc_dev, results_dev = test(dev_loader, 658 | dev_table, 659 | model, 660 | model_bert, 661 | tokenizer, 662 | args.sql_vocab, 663 | args.max_seq_length, 664 | detail=False, 665 | path_db=path_wikisql, 666 | st_pos=0, 667 | dset_name='dev', EG=args.EG, 668 | col_pool_type=args.col_pool_type, 669 | aug=args.aug) 670 | 671 | 672 | print_result(epoch, acc_train, 'train') 673 | print_result(epoch, acc_dev, 'dev') 674 | 675 | # save results for the offical evaluation 676 | save_for_evaluation(path_save_for_evaluation, results_dev, 'dev') 677 | 678 | # save best model 679 | # Based on Dev Set logical accuracy lx 680 | acc_lx_t = acc_dev[-2] 681 | if acc_lx_t > acc_lx_t_best: 682 | acc_lx_t_best = acc_lx_t 683 | epoch_best = epoch 684 | # save best model 685 | state = {'model': model.state_dict()} 686 | torch.save(state, os.path.join('.', 'model_best.pt')) 687 | 688 | state = {'model_bert': model_bert.state_dict()} 689 | torch.save(state, os.path.join('.', 'model_bert_best.pt')) 690 | 691 | print(f" Best Dev lx acc: {acc_lx_t_best} at epoch: {epoch_best}") 692 | -------------------------------------------------------------------------------- /train_shallow_layer.py: -------------------------------------------------------------------------------- 1 | # Wonseok Hwang 2 | # Sep30, 2018 3 | import os, sys, argparse, re, json 4 | import random as python_random 5 | 6 | from matplotlib.pylab import * 7 | import torch.nn as nn 8 | import torch 9 | import torch.nn.functional as F 10 | # import torchvision.datasets as dsets 11 | 12 | # BERT 13 | import bert.tokenization as tokenization 14 | from bert.modeling import BertConfig, BertModel 15 | 16 | from sqlova.utils.utils_wikisql import * 17 | from sqlova.model.nl2sql.wikisql_models import * 18 | from sqlnet.dbengine import DBEngine 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | def construct_hyper_param(parser): 24 | parser.add_argument('--tepoch', default=200, type=int) 25 | parser.add_argument("--bS", default=16, type=int, 26 | help="Batch size") 27 | parser.add_argument("--accumulate_gradients", default=2, type=int, 28 | help="The number of accumulation of backpropagation to effectivly increase the batch size.") 29 | parser.add_argument('--fine_tune', 30 | default=True, 31 | action='store_true', 32 | help="If present, 
BERT is trained.") 33 | 34 | parser.add_argument("--model_type", default='FT_Scalar_1', type=str, 35 | help="Type of model.") 36 | 37 | parser.add_argument('--aug', 38 | default=False, 39 | action='store_true', 40 | help="If present, aug.train.jsonl is used.") 41 | 42 | # 1.2 BERT Parameters 43 | parser.add_argument("--vocab_file", 44 | default='vocab.txt', type=str, 45 | help="The vocabulary file that the BERT model was trained on.") 46 | parser.add_argument("--max_seq_length", 47 | default=222, type=int, # Set based on maximum length of input tokens. 48 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 49 | "longer than this will be truncated, and sequences shorter than this will be padded.") 50 | parser.add_argument("--num_target_layers", 51 | default=1, type=int, 52 | help="The Number of final layers of BERT to be used in downstream task.") 53 | parser.add_argument('--lr_bert', default=6e-6, type=float, help='BERT model learning rate.') 54 | parser.add_argument('--seed', 55 | type=int, 56 | default=42, 57 | help="random seed for initialization") 58 | parser.add_argument('--no_pretraining', action='store_true', help='Use BERT pretrained model') 59 | parser.add_argument("--bert_type_abb", default='uS', type=str, 60 | help="Type of BERT model to load. e.g.) uS, uL, cS, cL, and mcS") 61 | parser.add_argument("--col_pool_type", default='start_tok', type=str, 62 | help="Which col-token shall be used? start_tok, end_tok, or avg are possible choices.") 63 | 64 | # 1.3 Seq-to-SQL module parameters 65 | parser.add_argument('--lS', default=2, type=int, help="The number of LSTM layers.") 66 | parser.add_argument('--dr', default=0.3, type=float, help="Dropout rate.") 67 | parser.add_argument('--lr', default=1e-5, type=float, help="Learning rate.") 68 | parser.add_argument("--hS", default=100, type=int, help="The dimension of hidden vector in the seq-to-SQL module.") 69 | 70 | 71 | # 1.4 Execution-guided decoding beam-size. It is used only in test.py 72 | parser.add_argument('--EG', 73 | default=False, 74 | action='store_true', 75 | help="If present, Execution guided decoding is used in test.") 76 | parser.add_argument('--beam_size', 77 | type=int, 78 | default=4, 79 | help="The size of beam for smart decoding") 80 | 81 | 82 | # 1.5 Arguments only for test.py 83 | parser.add_argument('--sn', default=42, type=int, help="The targetting session number.") 84 | parser.add_argument("--target_epoch", default='best', type=str, 85 | help="Targer epoch (the save name from nsml).") 86 | 87 | parser.add_argument("--tag", default='', type=str, 88 | help="Tag of saved files. e.g.) '', 'FT1', 'FT1_aug', 'no_pretraining', 'no_tuning',..") 89 | 90 | args = parser.parse_args() 91 | 92 | map_bert_type_abb = {'uS': 'uncased_L-12_H-768_A-12', 93 | 'uL': 'uncased_L-24_H-1024_A-16', 94 | 'cS': 'cased_L-12_H-768_A-12', 95 | 'cL': 'cased_L-24_H-1024_A-16', 96 | 'mcS': 'multi_cased_L-12_H-768_A-12'} 97 | args.bert_type = map_bert_type_abb[args.bert_type_abb] 98 | print(f"BERT-type: {args.bert_type}") 99 | 100 | # 101 | # Decide whether to use lower_case. 
102 | if args.bert_type_abb == 'cS' or args.bert_type_abb == 'cL' or args.bert_type_abb == 'mcS': 103 | args.do_lower_case = False 104 | else: 105 | args.do_lower_case = True 106 | 107 | # args.toy_model = not torch.cuda.is_available() 108 | args.toy_model = not True 109 | args.toy_size = 32 110 | if args.model_type == 'FT_Scalar_1': 111 | assert args.num_target_layers == 1 112 | assert args.fine_tune == True 113 | 114 | 115 | # Seeds for random number generation. 116 | seed(args.seed) 117 | python_random.seed(args.seed) 118 | np.random.seed(args.seed) 119 | torch.manual_seed(args.seed) 120 | if torch.cuda.is_available(): 121 | torch.cuda.manual_seed_all(args.seed) 122 | 123 | return args 124 | 125 | 126 | def get_bert(BERT_PT_PATH, bert_type, do_lower_case, no_pretraining): 127 | 128 | 129 | bert_config_file = os.path.join(BERT_PT_PATH, f'bert_config_{bert_type}.json') 130 | vocab_file = os.path.join(BERT_PT_PATH, f'vocab_{bert_type}.txt') 131 | init_checkpoint = os.path.join(BERT_PT_PATH, f'pytorch_model_{bert_type}.bin') 132 | 133 | 134 | 135 | bert_config = BertConfig.from_json_file(bert_config_file) 136 | tokenizer = tokenization.FullTokenizer( 137 | vocab_file=vocab_file, do_lower_case=do_lower_case) 138 | bert_config.print_status() 139 | 140 | model_bert = BertModel(bert_config) 141 | if no_pretraining: 142 | pass 143 | else: 144 | model_bert.load_state_dict(torch.load(init_checkpoint, map_location='cpu')) 145 | print("Load pre-trained parameters.") 146 | model_bert.to(device) 147 | 148 | return model_bert, tokenizer, bert_config 149 | 150 | def get_opt(model, model_bert, model_type): 151 | if model_type == 'FT_Scalar_1': 152 | # Model itself does not have trainable parameters. Thus, 153 | opt_bert = torch.optim.Adam(list(filter(lambda p: p.requires_grad, model.parameters())) \ 154 | # + list(model_bert.parameters()), 155 | + list(filter(lambda p: p.requires_grad, model_bert.parameters())), 156 | lr=args.lr_bert, weight_decay=0) 157 | opt = opt_bert # for consistency in interface 158 | else: 159 | raise NotImplementedError 160 | # opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 161 | # lr=args.lr, weight_decay=0) 162 | # 163 | # opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()), 164 | # lr=args.lr_bert, weight_decay=0) 165 | 166 | return opt, opt_bert 167 | 168 | def get_models(args, BERT_PT_PATH, trained=False, path_model_bert=None, path_model=None): 169 | # some constants 170 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 171 | cond_ops = ['=', '>', '<', 'OP'] # do not know why 'OP' required. 
Hence, 172 | 173 | print(f"Batch_size = {args.bS * args.accumulate_gradients}") 174 | print(f"BERT parameters:") 175 | print(f"learning rate: {args.lr_bert}") 176 | print(f"Fine-tune BERT: {args.fine_tune}") 177 | 178 | # Get BERT 179 | model_bert, tokenizer, bert_config = get_bert(BERT_PT_PATH, args.bert_type, args.do_lower_case, 180 | args.no_pretraining) 181 | args.iS = bert_config.hidden_size * args.num_target_layers # Seq-to-SQL input vector dimenstion 182 | 183 | # Get Seq-to-SQL 184 | 185 | n_cond_ops = len(cond_ops) 186 | n_agg_ops = len(agg_ops) 187 | print(f"Seq-to-SQL: the number of final BERT layers to be used: {args.num_target_layers}") 188 | print(f"Seq-to-SQL: the size of hidden dimension = {args.hS}") 189 | print(f"Seq-to-SQL: LSTM encoding layer size = {args.lS}") 190 | print(f"Seq-to-SQL: dropout rate = {args.dr}") 191 | print(f"Seq-to-SQL: learning rate = {args.lr}") 192 | model = FT_Scalar_1(args.iS, args.hS, args.lS, args.dr, n_cond_ops, n_agg_ops) 193 | model = model.to(device) 194 | 195 | if trained: 196 | assert path_model_bert != None 197 | assert path_model != None 198 | 199 | if torch.cuda.is_available(): 200 | res = torch.load(path_model_bert) 201 | else: 202 | res = torch.load(path_model_bert, map_location='cpu') 203 | model_bert.load_state_dict(res['model_bert']) 204 | model_bert.to(device) 205 | 206 | if torch.cuda.is_available(): 207 | res = torch.load(path_model) 208 | else: 209 | res = torch.load(path_model, map_location='cpu') 210 | 211 | model.load_state_dict(res['model']) 212 | 213 | return model, model_bert, tokenizer, bert_config 214 | 215 | def get_data(path_wikisql, args): 216 | train_data, train_table, dev_data, dev_table, _, _ = load_wikisql(path_wikisql, args.toy_model, args.toy_size, 217 | no_w2i=True, no_hs_tok=True, 218 | aug=args.aug) 219 | train_loader, dev_loader = get_loader_wikisql(train_data, dev_data, args.bS, shuffle_train=True) 220 | 221 | return train_data, train_table, dev_data, dev_table, train_loader, dev_loader 222 | 223 | 224 | def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer, 225 | max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=False, 226 | st_pos=0, opt_bert=None, path_db=None, dset_name='train', col_pool_type='start_tok', aug=False): 227 | model.train() 228 | model_bert.train() 229 | 230 | ave_loss = 0 231 | cnt = 0 # count the # of examples 232 | cnt_sc = 0 # count the # of correct predictions of select column 233 | cnt_sa = 0 # of selectd aggregation 234 | cnt_wn = 0 # of where number 235 | cnt_wc = 0 # of where column 236 | cnt_wo = 0 # of where operator 237 | cnt_wv = 0 # of where-value 238 | cnt_wvi = 0 # of where-value index (on question tokens) 239 | cnt_lx = 0 # of logical form acc 240 | cnt_x = 0 # of execution acc 241 | 242 | # Engine for SQL querying. 243 | engine = DBEngine(os.path.join(path_db, f"{dset_name}.db")) 244 | 245 | for iB, t in enumerate(train_loader): 246 | cnt += len(t) 247 | 248 | if cnt < st_pos: 249 | continue 250 | # Get fields 251 | nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, train_table, no_hs_t=True, no_sql_t=True) 252 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i) 253 | # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset. 
254 | g_wvi_corenlp = get_g_wvi_corenlp(t) 255 | 256 | 257 | all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \ 258 | l_n, l_hpu, l_hs, \ 259 | nlu_tt, t_to_tt_idx, tt_to_t_idx \ 260 | = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length) 261 | 262 | try: 263 | # 264 | g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp) 265 | except: 266 | # Exception happens when where-condition is not found in nlu_tt. 267 | # In this case, that train example is not used. 268 | # During test, that example considered as wrongly answered. 269 | # e.g. train: 32. 270 | continue 271 | 272 | wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size, 273 | bert_config.num_hidden_layers, all_encoder_layer, 1) 274 | wemb_h = get_wemb_h_FT_Scalar_1(i_hds, l_hs, bert_config.hidden_size, all_encoder_layer, 275 | col_pool_type=col_pool_type) 276 | # wemb_h = [B, max_header_number, hS] 277 | cls_vec = pooled_output 278 | 279 | # model specific part 280 | # get g_wvi (it is idex for word-piece tok) 281 | # score 282 | s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hs, cls_vec, 283 | g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_wo=g_wo, g_wvi=g_wvi) 284 | 285 | # Calculate loss & step 286 | loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi) 287 | 288 | # Calculate gradient 289 | if iB % accumulate_gradients == 0: # mode 290 | # at start, perform zero_grad 291 | opt.zero_grad() 292 | if opt_bert: 293 | opt_bert.zero_grad() 294 | loss.backward() 295 | if accumulate_gradients == 1: 296 | opt.step() 297 | if opt_bert: 298 | opt_bert.step() 299 | elif iB % accumulate_gradients == (accumulate_gradients-1): 300 | # at the final, take step with accumulated graident 301 | loss.backward() 302 | opt.step() 303 | if opt_bert: 304 | opt_bert.step() 305 | else: 306 | # at intermediate stage, just accumulates the gradients 307 | loss.backward() 308 | 309 | if check_grad: 310 | named_parameters = model.named_parameters() 311 | 312 | mu_list, sig_list = get_mean_grad(named_parameters) 313 | 314 | grad_abs_mean_mean = mean(mu_list) 315 | grad_abs_mean_sig = std(mu_list) 316 | grad_abs_sig_mean = mean(sig_list) 317 | else: 318 | grad_abs_mean_mean = 1 319 | grad_abs_mean_sig = 1 320 | grad_abs_sig_mean = 1 321 | 322 | # Prediction 323 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, ) 324 | pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu) 325 | pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu) 326 | 327 | 328 | # Cacluate accuracy 329 | cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \ 330 | cnt_wc1_list, cnt_wo1_list, \ 331 | cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, 332 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, 333 | sql_i, pr_sql_i, 334 | mode='train') 335 | 336 | cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, 337 | cnt_wo1_list, cnt_wv1_list) 338 | # lx stands for logical form accuracy 339 | 340 | # Execution accuracy test. 
341 | if not aug: 342 | cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i) 343 | else: 344 | cnt_x1_list = [0] * len(t) 345 | g_ans = ['N/A (data augmented'] * len(t) 346 | pr_ans = ['N/A (data augmented'] * len(t) 347 | # statistics 348 | ave_loss += loss.item() 349 | 350 | # count 351 | cnt_sc += sum(cnt_sc1_list) 352 | cnt_sa += sum(cnt_sa1_list) 353 | cnt_wn += sum(cnt_wn1_list) 354 | cnt_wc += sum(cnt_wc1_list) 355 | cnt_wo += sum(cnt_wo1_list) 356 | cnt_wvi += sum(cnt_wvi1_list) 357 | cnt_wv += sum(cnt_wv1_list) 358 | cnt_lx += sum(cnt_lx1_list) 359 | cnt_x += sum(cnt_x1_list) 360 | 361 | ave_loss /= cnt 362 | acc_sc = cnt_sc / cnt 363 | acc_sa = cnt_sa / cnt 364 | acc_wn = cnt_wn / cnt 365 | acc_wc = cnt_wc / cnt 366 | acc_wo = cnt_wo / cnt 367 | acc_wvi = cnt_wv / cnt 368 | acc_wv = cnt_wv / cnt 369 | acc_lx = cnt_lx / cnt 370 | acc_x = cnt_x / cnt 371 | 372 | acc = [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x] 373 | aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean] 374 | 375 | return acc, aux_out 376 | 377 | def report_detail(hds, nlu, 378 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_str, g_sql_q, g_ans, 379 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans, 380 | cnt_list, current_cnt): 381 | cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x = current_cnt 382 | 383 | print(f'cnt = {cnt} / {cnt_tot} ===============================') 384 | 385 | print(f'headers: {hds}') 386 | print(f'nlu: {nlu}') 387 | 388 | # print(f's_sc: {s_sc[0]}') 389 | # print(f's_sa: {s_sa[0]}') 390 | # print(f's_wn: {s_wn[0]}') 391 | # print(f's_wc: {s_wc[0]}') 392 | # print(f's_wo: {s_wo[0]}') 393 | # print(f's_wv: {s_wv[0][0]}') 394 | print(f'===============================') 395 | print(f'g_sc : {g_sc}') 396 | print(f'pr_sc: {pr_sc}') 397 | print(f'g_sa : {g_sa}') 398 | print(f'pr_sa: {pr_sa}') 399 | print(f'g_wn : {g_wn}') 400 | print(f'pr_wn: {pr_wn}') 401 | print(f'g_wc : {g_wc}') 402 | print(f'pr_wc: {pr_wc}') 403 | print(f'g_wo : {g_wo}') 404 | print(f'pr_wo: {pr_wo}') 405 | print(f'g_wv : {g_wv}') 406 | # print(f'pr_wvi: {pr_wvi}') 407 | print('g_wv_str:', g_wv_str) 408 | print('p_wv_str:', pr_wv_str) 409 | print(f'g_sql_q: {g_sql_q}') 410 | print(f'pr_sql_q: {pr_sql_q}') 411 | print(f'g_ans: {g_ans}') 412 | print(f'pr_ans: {pr_ans}') 413 | print(f'--------------------------------') 414 | 415 | print(cnt_list) 416 | 417 | print(f'acc_lx = {cnt_lx/cnt:.3f}, acc_x = {cnt_x/cnt:.3f}\n', 418 | f'acc_sc = {cnt_sc/cnt:.3f}, acc_sa = {cnt_sa/cnt:.3f}, acc_wn = {cnt_wn/cnt:.3f}\n', 419 | f'acc_wc = {cnt_wc/cnt:.3f}, acc_wo = {cnt_wo/cnt:.3f}, acc_wv = {cnt_wv/cnt:.3f}') 420 | print(f'===============================') 421 | 422 | def test(data_loader, data_table, model, model_bert, bert_config, tokenizer, 423 | max_seq_length, 424 | num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4, 425 | path_db=None, dset_name='test', col_pool_type='start_tok', aug=False): 426 | model.eval() 427 | model_bert.eval() 428 | 429 | ave_loss = 0 430 | cnt = 0 431 | cnt_sc = 0 432 | cnt_sa = 0 433 | cnt_wn = 0 434 | cnt_wc = 0 435 | cnt_wo = 0 436 | cnt_wv = 0 437 | cnt_wvi = 0 438 | cnt_lx = 0 439 | cnt_x = 0 440 | 441 | cnt_list = [] 442 | p_list = [] # List of prediction probabilities. 443 | data_list = [] # Miscellanerous data. Save it for later convenience of analysis. 
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):

        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
        l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        except:
            # The exception happens when the where-condition value is not found in nlu_tt.
            # During training such an example is skipped; during test it is counted
            # as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # Model-specific part: score each clause.
        wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size,
                            bert_config.num_hidden_layers, all_encoder_layer, 1)
        wemb_h = get_wemb_h_FT_Scalar_1(i_hds, l_hs, bert_config.hidden_size, all_encoder_layer,
                                        col_pool_type=col_pool_type)
        # wemb_h = [B, max_header_number, hS]
        cls_vec = pooled_output

        if not EG:
            # No execution-guided decoding.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hs, cls_vec)

            # get loss & step
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

            # prediction
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)

            # calculate probability
            p_tot, p_select, p_where, p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi \
                = cal_prob(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                           pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi)

        else:
            # Execution-guided decoding.
            pr_sc_best, pr_sa_best, pr_wn_based_on_prob, pr_wvi_best, \
            pr_sql_i, p_tot, p_select, p_where, p_sc_best, p_sa_best, \
            p_wn_best, p_wc_best, p_wo_best, p_wvi_best \
                = model.forward_EG(wemb_n, l_n, wemb_h, l_hs, cls_vec, engine, tb,
                                   nlu_t, nlu_tt, tt_to_t_idx, nlu,
                                   beam_size=beam_size)

            pr_sc = pr_sc_best
            pr_sa = pr_sa_best
            pr_wn = pr_wn_based_on_prob

            p_sc = p_sc_best
            p_sa = p_sa_best
            p_wn = p_wn_best

            # Sort and generate: prob-based sort (descending) -> wc-idx-based sort (ascending).
            pr_wc, pr_wo, pr_wv_str, pr_wvi, pr_sql_i, \
            p_wc, p_wo, p_wvi = sort_and_generate_pr_w(pr_sql_i, pr_wvi_best, p_wc_best, p_wo_best, p_wvi_best)
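            # Note (added, hedged): in the EG branch the per-clause probabilities
            # (p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi) come out of the beam search in
            # model.forward_EG / sort_and_generate_pr_w rather than from cal_prob,
            # but they are packed into p_list below in the same order as in the
            # no-EG branch, so downstream analysis code can treat both cases alike.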
            # The following variables are set only for consistency with the no-EG case.
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        p_list_batch = [p_tot, p_select, p_where, p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi]
        p_list.append(p_list_batch)

        g_sql_q = generate_sql_q(sql_i, tb)
        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # Save predictions for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='test')

        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # lx stands for logical-form accuracy.

        # Execution accuracy test.
        if not aug:
            cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)
        else:
            cnt_x1_list = [0] * len(t)
            g_ans = ['N/A (data augmented)'] * len(t)
            pr_ans = ['N/A (data augmented)'] * len(t)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        current_cnt = [cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x]
        cnt_list_batch = [cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, cnt_wo1_list, cnt_wv1_list,
                          cnt_lx1_list, cnt_x1_list]
        cnt_list.append(cnt_list_batch)

        # report
        if detail:
            report_detail(hds, nlu,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_str, g_sql_q, g_ans,
                          pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans,
                          cnt_list_batch, current_cnt)

        data_batch = []
        for b, nlu1 in enumerate(nlu):
            data1 = [nlu[b], nlu_t[b], sql_i[b], g_sql_q[b], g_ans[b],
                     pr_sql_i[b], pr_sql_q[b], pr_ans[b], tb[b]]
            data_batch.append(data1)

        data_list.append(data_batch)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]
    return acc, results, cnt_list, p_list, data_list


def print_result(epoch, acc, dname):
    ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x = acc

    print(f'{dname} results ------------')
    print(f" Epoch: {epoch}, ave loss: {ave_loss}, acc_sc: {acc_sc:.3f}, acc_sa: {acc_sa:.3f}, "
          f"acc_wn: {acc_wn:.3f}, acc_wc: {acc_wc:.3f}, acc_wo: {acc_wo:.3f}, acc_wvi: {acc_wvi:.3f}, "
          f"acc_wv: {acc_wv:.3f}, acc_lx: {acc_lx:.3f}, acc_x: {acc_x:.3f}")

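# Usage sketch (added, hedged, not part of the original code): the accuracy lists
# returned by train()/test() and consumed by print_result() are positional. If named
# metrics are more convenient for logging, something like the following works;
# ACC_KEYS and acc_dict are illustrative names, not defined elsewhere in this repository.
#
#   ACC_KEYS = ['ave_loss', 'acc_sc', 'acc_sa', 'acc_wn', 'acc_wc',
#               'acc_wo', 'acc_wvi', 'acc_wv', 'acc_lx', 'acc_x']
#   acc_dict = dict(zip(ACC_KEYS, acc_dev))
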
if __name__ == '__main__':

    ## 1. Hyperparameters
    parser = argparse.ArgumentParser()
    args = construct_hyper_param(parser)

    ## 2. Paths
    path_h = '/home/wonseok'
    path_wikisql = os.path.join(path_h, 'data', 'wikisql_tok')
    BERT_PT_PATH = path_wikisql

    path_save_for_evaluation = './'

    ## 3. Load data
    train_data, train_table, dev_data, dev_table, train_loader, dev_loader = get_data(path_wikisql, args)

    ## 4. Build & load models
    model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH)

    # nsml binding

    ## 5. Get optimizers
    opt, opt_bert = get_opt(model, model_bert, args.model_type)

    ## 6. Train
    acc_lx_t_best = -1
    epoch_best = -1
    for epoch in range(args.tepoch):
        # train
        acc_train, aux_out_train = train(train_loader,
                                         train_table,
                                         model,
                                         model_bert,
                                         opt,
                                         bert_config,
                                         tokenizer,
                                         args.max_seq_length,
                                         args.num_target_layers,
                                         args.accumulate_gradients,
                                         opt_bert=opt_bert,
                                         st_pos=0,
                                         path_db=path_wikisql,
                                         dset_name='train',
                                         col_pool_type=args.col_pool_type,
                                         aug=args.aug)

        # check DEV
        with torch.no_grad():
            acc_dev, results_dev, cnt_list_dev, p_list_dev, data_list_dev = test(dev_loader,
                                                                                 dev_table,
                                                                                 model,
                                                                                 model_bert,
                                                                                 bert_config,
                                                                                 tokenizer,
                                                                                 args.max_seq_length,
                                                                                 args.num_target_layers,
                                                                                 detail=False,
                                                                                 path_db=path_wikisql,
                                                                                 st_pos=0,
                                                                                 dset_name='dev',
                                                                                 EG=args.EG,
                                                                                 col_pool_type=args.col_pool_type,
                                                                                 beam_size=args.beam_size,
                                                                                 aug=args.aug)

        print_result(epoch, acc_train, 'train')
        print_result(epoch, acc_dev, 'dev')

        # save results for the official evaluation
        save_for_evaluation(path_save_for_evaluation, results_dev, 'dev')

        # save the best model, based on dev-set logical-form accuracy (lx)
        acc_lx_t = acc_dev[-2]
        if acc_lx_t > acc_lx_t_best:
            acc_lx_t_best = acc_lx_t
            epoch_best = epoch
            state = {'model': model.state_dict()}
            torch.save(state, os.path.join('.', 'model_best.pt'))

            state = {'model_bert': model_bert.state_dict()}
            torch.save(state, os.path.join('.', 'model_bert_best.pt'))

        print(f" Best Dev lx acc: {acc_lx_t_best} at epoch: {epoch_best}")

--------------------------------------------------------------------------------