├── .gitignore
├── README.md
├── requirements.txt
└── zero-shot-text-to-SQL
    ├── annotate.py
    ├── evaluate.py
    ├── job_config.sh
    ├── lib
    │   ├── .DS_Store
    │   ├── __init__.py
    │   ├── common.py
    │   ├── dbengine.py
    │   ├── query.py
    │   └── table.py
    ├── opts.py
    ├── preprocess.py
    ├── run.sh
    ├── table
    │   ├── .DS_Store
    │   ├── Beam.py
    │   ├── IO.py
    │   ├── Loss.py
    │   ├── ModelConstructor.py
    │   ├── Models.py
    │   ├── Optim.py
    │   ├── ParseResult.py
    │   ├── Trainer.py
    │   ├── Translator.py
    │   ├── Utils.py
    │   ├── __init__.py
    │   └── modules
    │       ├── .DS_Store
    │       ├── Embeddings.py
    │       ├── Gate.py
    │       ├── GlobalAttention.py
    │       ├── LockedDropout.py
    │       ├── StackedRNN.py
    │       ├── UtilClass.py
    │       ├── WeightDrop.py
    │       ├── WeightNorm.py
    │       ├── __init__.py
    │       └── cross_entropy_smooth.py
    └── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# environment
venv

# checkpoints
checkpoints
saves

# mac decoration file
.DS_Store

# pycharm files
.idea

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# DotEnv configuration
.env

# Database
*.db
*.rdb

# Pycharm
.idea

# VS Code
.vscode/

# Spyder
.spyproject/

# Jupyter NB Checkpoints
.ipynb_checkpoints/

# exclude data from source control by default

# Mac OS-specific storage files
.DS_Store

# vim
*.swp
*.swo

# Mypy cache
.mypy_cache/

nl_table_data/

# data/result/

data/datasets

data/result/all_negative_df.tsv

data/interim/full_partial_inputs/intermediate/*

!data/interim/full_partial_inputs/intermediate/.gitkeep

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Zero-shot Text-to-SQL Learning with Auxiliary Task
Code for [Zero-shot Text-to-SQL Learning with Auxiliary Task](https://arxiv.org/pdf/1908.11052.pdf)

## Usage

### Conda Environment
Please use Python 3.6 and PyTorch 1.3. The remaining Python dependencies are listed in requirements.txt; install them with:
```
pip install -r requirements.txt
```

### Download Data
[Data](https://drive.google.com/file/d/1UQmL-F5tGUqAit35ybto7kk-3emkqtgE/view?usp=sharing) can be downloaded from Google Drive. Please download the archive and extract it into the repository root.
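
After extraction, the code expects the data under `data_model/` in the repository root. The sketch below shows the assumed layout, inferred from the paths referenced in `job_config.sh` and `evaluate.py` (directory names come from the code, so the archive contents may differ slightly):
```
data_model/
└── wikisql/
    ├── annotated_ent/   # annotated {train,dev,test}.jsonl
    ├── data/            # original {train,dev,test}.jsonl and .db files
    └── embedding/       # pretrained word embeddings
```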

### Generate our resplit WikiSQL data
```
cd data_model/wikisql
python make_zs.py
python make_fs.py

```

### Run the model on the original WikiSQL and on our split

```
cd zero-shot-text-to-SQL
./run.sh GPU_ID
```

## Acknowledgement
- This implementation is based on [coarse2fine](https://github.com/donglixp/coarse2fine).
- The preprocessing and evaluation code used for WikiSQL is from [salesforce/WikiSQL](https://github.com/salesforce/WikiSQL).

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.3.0
torchtext==0.2.0
nltk==3.2.5
tensorboard_logger==0.0.4
records==0.5.2
more_itertools==3.2.0
six==1.11.0
Babel==2.5.1
tabulate==0.8.1
tqdm==4.19.8
forked_path==0.2.3
scikit_learn==0.19.1
stanza==0.3
json_lines==0.5.0

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/annotate.py:
--------------------------------------------------------------------------------
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os
import records
# import ujson as json
import json
from stanza.nlp.corenlp import CoreNLPClient
from tqdm import tqdm
import copy
from lib.common import count_lines, detokenize
from lib.query import Query, agg_ops, cond_ops


client = None


def annotate(sentence, lower=True):
    global client
    if client is None:
        client = CoreNLPClient(default_annotators='ssplit,tokenize'.split(','))
    words, gloss, after = [], [], []
    for s in client.annotate(sentence):
        for t in s:
            words.append(t.word)
            gloss.append(t.originalText)
            after.append(t.after)
    if lower:
        words = [w.lower() for w in words]
    return {
        'gloss': gloss,
        'words': words,
        'after': after,
    }
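

# Added illustration (not in the original file): with a CoreNLP server
# available, annotate('Who won in 2010?') returns token-aligned lists along
# the lines of
#   {'gloss': ['Who', 'won', 'in', '2010', '?'],
#    'words': ['who', 'won', 'in', '2010', '?'],
#    'after': [' ', ' ', ' ', '', '']}
# 'words' holds lowercased tokens, 'gloss' the original surface forms, and
# 'after' the whitespace after each token, which lib.common.detokenize uses
# to reconstruct the original string.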


def annotate_example(example, table):
    ann = {'table_id': example['table_id']}
    ann['question'] = annotate(example['question'])
    ann['table'] = {
        'header': [annotate(h) for h in table['header']],
    }
    ann['query'] = sql = copy.deepcopy(example['sql'])
    for c in ann['query']['conds']:
        c[-1] = annotate(str(c[-1]))

    q1 = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(
        agg_ops[sql['agg']], table['header'][sql['sel']])
    q2 = ['SYMCOL {} SYMOP {} SYMCOND {}'.format(
        table['header'][col], cond_ops[op], detokenize(cond)) for col, op, cond in sql['conds']]
    if q2:
        q2 = 'SYMWHERE ' + ' SYMAND '.join(q2) + ' SYMEND'
    else:
        q2 = 'SYMEND'
    inp = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question} SYMEND'.format(
        syms=' '.join(['SYM' + s for s in Query.syms]),
        table=' '.join(['SYMCOL ' + s for s in table['header']]),
        question=example['question'],
        aggops=' '.join([s for s in agg_ops]),
        condops=' '.join([s for s in cond_ops]),
    )
    ann['seq_input'] = annotate(inp)
    out = '{q1} {q2}'.format(q1=q1, q2=q2) if q2 else q1
    ann['seq_output'] = annotate(out)
    ann['where_output'] = annotate(q2)
    assert 'symend' in ann['seq_output']['words']
    assert 'symend' in ann['where_output']['words']
    return ann


def is_valid_example(e):
    if not all([h['words'] for h in e['table']['header']]):
        return False
    headers = [detokenize(h).lower() for h in e['table']['header']]
    if len(headers) != len(set(headers)):
        return False
    input_vocab = set(e['seq_input']['words'])
    for w in e['seq_output']['words']:
        if w not in input_vocab:
            print('query word "{}" is not in input vocabulary.\n{}'.format(
                w, e['seq_input']['words']))
            return False
    input_vocab = set(e['question']['words'])
    for col, op, cond in e['query']['conds']:
        for w in cond['words']:
            if w not in input_vocab:
                print('cond word "{}" is not in input vocabulary.\n{}'.format(
                    w, e['question']['words']))
                return False
    return True


if __name__ == '__main__':
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--din', default='', help='data directory')
    parser.add_argument(
        '--dout', default='', help='output directory')
    args = parser.parse_args()

    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    for split in ['train', 'dev', 'test']:
        fsplit = os.path.join(args.din, split) + '.jsonl'
        ftable = os.path.join(args.din, split) + '.tables.jsonl'
        fout = os.path.join(args.dout, split) + '.jsonl'

        print('annotating {}'.format(fsplit))
        with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
            print('loading tables')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading examples')
            n_written = 0
            for line in tqdm(fs, total=count_lines(fsplit)):
                d = json.loads(line)
                a = annotate_example(d, tables[d['table_id']])
                if not is_valid_example(a):
                    raise Exception(str(a))

                gold = Query.from_tokenized_dict(a['query'])
                reconstruct = Query.from_sequence(
                    a['seq_output'], a['table'], lowercase=True)
                if gold.lower() != reconstruct.lower():
                    raise Exception(
                        'Expected:\n{}\nGot:\n{}'.format(gold, reconstruct))
                a['id'] = '{}-{}'.format(split, n_written)
                fo.write(json.dumps(a) + '\n')
                n_written += 1
            print('wrote {} examples'.format(n_written))

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/evaluate.py:
--------------------------------------------------------------------------------
from __future__ import division
from builtins import bytes
import os
import argparse
import math
import codecs
import torch
import sys
import table
import table.IO
import opts
import random
from itertools import takewhile, count
try:
    from itertools import zip_longest
except ImportError:
    from itertools import izip_longest as zip_longest
import glob
import json
from tqdm import tqdm
from lib.dbengine import DBEngine
from lib.query import Query

parser = argparse.ArgumentParser(description='evaluate.py')
opts.translate_opts(parser)
opt = parser.parse_args()
torch.cuda.set_device(opt.gpu)
#annotated

if opt.split == 'finaltest':
    split = 'test'
else:
    split = opt.split
if opt.unseen_table!='full':
    opt.anno = os.path.join(
        opt.data_path, 'annotated_ent_'+opt.unseen_table+'/{}.jsonl'.format(split))
    # source
    opt.source_file = os.path.join(
        opt.data_path,
'data_'+opt.unseen_table+'/{}.jsonl'.format(split)) 40 | # DB 41 | opt.db_file = os.path.join(opt.data_path, 'data/{}.db'.format(split)) 42 | else: 43 | opt.anno = os.path.join( 44 | opt.data_path, 'annotated_ent/{}.jsonl'.format(split)) 45 | #source 46 | opt.source_file = os.path.join( 47 | opt.data_path, 'data/{}.jsonl'.format(split)) 48 | #DB 49 | opt.db_file = os.path.join(opt.data_path, 'data/{}.db'.format(split)) 50 | 51 | #pretrained embedding 52 | opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding') 53 | 54 | 55 | def main(): 56 | dummy_parser = argparse.ArgumentParser(description='train.py') 57 | opts.model_opts(dummy_parser) 58 | opts.train_opts(dummy_parser) 59 | dummy_opt = dummy_parser.parse_known_args([])[0] 60 | 61 | engine = DBEngine(opt.db_file) 62 | 63 | with codecs.open(opt.source_file, "r", "utf-8") as corpus_file: 64 | sql_list = [json.loads(line)['sql'] for line in corpus_file] 65 | 66 | js_list = table.IO.read_anno_json(opt.anno) 67 | 68 | prev_best = (None, None) 69 | print(opt.split, opt.model_path) 70 | 71 | num_models=0 72 | 73 | f_out=open('Two-stream-' +opt.unseen_table+'-out-case','w') 74 | 75 | for fn_model in glob.glob(opt.model_path): 76 | num_models += 1 77 | sys.stdout.flush() 78 | print(fn_model) 79 | print(opt.anno) 80 | opt.model = fn_model 81 | 82 | translator = table.Translator(opt, dummy_opt.__dict__) 83 | data = table.IO.TableDataset(js_list, translator.fields, None, False) 84 | #torch.save(data, open( 'data.pt', 'wb')) 85 | test_data = table.IO.OrderedIterator( 86 | dataset=data, device=opt.gpu, batch_size=opt.batch_size, train=False, sort=True, sort_within_batch=False) 87 | 88 | # inference 89 | r_list = [] 90 | for batch in test_data: 91 | r_list += translator.translate(batch) 92 | r_list.sort(key=lambda x: x.idx) 93 | assert len(r_list) == len(js_list), 'len(r_list) != len(js_list): {} != {}'.format( 94 | len(r_list), len(js_list)) 95 | # evaluation 96 | error_cases = [] 97 | for pred, gold, sql_gold in zip(r_list, js_list, sql_list): 98 | error_cases.append(pred.eval(opt.split, gold, sql_gold, engine)) 99 | # error_cases.append(pred.eval(opt.split, gold, sql_gold)) 100 | print('Results:') 101 | for metric_name in ('all', 'exe', 'agg', 'sel', 'where', 'col', 'span', 'lay','BIO','BIO_col'): 102 | c_correct = sum((x.correct[metric_name] for x in r_list)) 103 | print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct, 104 | len(r_list), c_correct / len(r_list))) 105 | if metric_name=='all': 106 | all_acc=c_correct 107 | if metric_name=='exe': 108 | exe_acc=c_correct 109 | if prev_best[0] is None or all_acc+exe_acc >prev_best[1]+prev_best[2]: 110 | prev_best = (fn_model, all_acc, exe_acc) 111 | 112 | # random.shuffle(error_cases) 113 | for error_case in error_cases: 114 | if len(error_case) == 0: 115 | continue 116 | json.dump(error_case,f_out) 117 | f_out.write('\n') 118 | # print('table_id:\t', error_case['table_id']) 119 | # print('question_id:\t',error_case['question_id']) 120 | # print('question:\t', error_case['question']) 121 | # print('table_head:\t', error_case['table_head']) 122 | # print('table_content:\t', error_case['table_content']) 123 | # print() 124 | 125 | # print(error_case['BIO']) 126 | # print(error_case['BIO_col']) 127 | # print() 128 | 129 | # print('gold:','agg:',error_case['gold']['agg'],'sel:',error_case['predict']['sel']) 130 | # for i in range(len(error_case['gold']['conds'])): 131 | # print(error_case['gold']['conds'][i]) 132 | 133 | # 
print('predict:','agg:',error_case['predict']['agg'],'sel:',error_case['predict']['sel'])
        # for i in range(len(error_case['predict']['conds'])):
        #     print(error_case['predict']['conds'][i])
        # print('\n\n')


    print(prev_best)
    if (opt.split == 'dev') and (prev_best[0] is not None) and num_models!=1:
        if opt.unseen_table=='full':
            with codecs.open(os.path.join(opt.save_path, 'dev_best.txt'), 'w', encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
        else:
            with codecs.open(os.path.join(opt.save_path, 'dev_best_'+opt.unseen_table+'.txt'), 'w', encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/job_config.sh:
--------------------------------------------------------------------------------
model_config='run.zero-shot-text-to-SQL'
DATANAME='wikisql'
GPU_ID=$1

DATA_DIR=../data_model/$DATANAME #data, annotated_data, embedding folders are in DATA_DIR
SAVE_PATH=$DATA_DIR/$model_config #processed data and model are in SAVE_PATH

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/lib/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JD-AI-Research-Silicon-Valley/auxiliary-task-for-text-to-sql/9c0ff806cabab9e06b1b7fd0fac557bae79ff610/zero-shot-text-to-SQL/lib/.DS_Store

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JD-AI-Research-Silicon-Valley/auxiliary-task-for-text-to-sql/9c0ff806cabab9e06b1b7fd0fac557bae79ff610/zero-shot-text-to-SQL/lib/__init__.py

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/lib/common.py:
--------------------------------------------------------------------------------
def count_lines(fname):
    with open(fname) as f:
        return sum(1 for line in f)


def detokenize(tokens):
    ret = ''
    for g, a in zip(tokens['gloss'], tokens['after']):
        ret += g + a
    return ret.strip()
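

# Added illustration, not part of the original file: given token annotations
# like {'gloss': ['New', 'York'], 'after': [' ', '']}, detokenize returns
# 'New York'; each token's stored trailing whitespace is used to rebuild the
# original surface string, and count_lines simply counts the records in a
# newline-delimited file.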

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/lib/dbengine.py:
--------------------------------------------------------------------------------
import records
import re
from babel.numbers import parse_decimal, NumberFormatError
from lib.query import Query


schema_re = re.compile(r'\((.+)\)')
num_re = re.compile(r'[-+]?\d*\.\d+|\d+')


class DBEngine:

    def __init__(self, fdb):
        self.db = records.Database('sqlite:///{}'.format(fdb))

    def execute_query(self, table_id, query, *args, **kwargs):
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql
        schema_str = schema_re.findall(table_info)[0]
        schema = {}
        for tup in schema_str.split(', '):
            c, t = tup.split()
            schema[c] = t
        select = 'col{}'.format(select_index)
        agg = Query.agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)
        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            if lower and isinstance(val, str):
                val = val.lower()
            if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
                try:
                    val = float(parse_decimal(val))
                except NumberFormatError as e:
                    val = float(num_re.findall(val)[0])
            where_clause.append('col{} {} :col{}'.format(col_index, Query.cond_ops[op], col_index))
            where_map['col{}'.format(col_index)] = val
        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
        out = self.db.query(query, **where_map)
        return [o.result for o in out]

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/lib/query.py:
--------------------------------------------------------------------------------
from lib.common import detokenize
from collections import defaultdict
from copy import deepcopy
import re


re_whitespace = re.compile(r'\s+', flags=re.UNICODE)
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']


class Query:

    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']

    def __init__(self, sel_index, agg_index, conditions=tuple()):
        self.sel_index = sel_index
        self.agg_index = agg_index
        self.conditions = list(conditions)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            indices = self.sel_index == other.sel_index and self.agg_index == other.agg_index
            conds = [(col, op, cond.lower() if isinstance(cond, str) else cond) for col, op, cond in self.conditions] == [(col, op, cond.lower() if isinstance(cond, str) else cond) for col, op, cond in other.conditions]
            return indices and conds
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __hash__(self):
        return hash(tuple(sorted(self.__dict__.items())))

    def __repr__(self):
        rep = 'SELECT {agg} {sel} FROM table'.format(
            agg=self.agg_ops[self.agg_index],
            sel='col{}'.format(self.sel_index),
        )
        if self.conditions:
            rep += ' WHERE ' + ' AND '.join(['{} {} {}'.format('col{}'.format(i), self.cond_ops[o], v) for i, o, v in self.conditions])
        return rep

    def to_dict(self):
        return {'sel': self.sel_index, 'agg': self.agg_index, 'conds': self.conditions}

    def lower(self):
        conds = []
        for col, op, cond in self.conditions:
            conds.append([col, op, cond.lower()])
        return self.__class__(self.sel_index, self.agg_index, conds)

    @classmethod
    def from_dict(cls, d):
        return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'])

    @classmethod
    def from_tokenized_dict(cls, d):
        conds = []
        for col, op, val in d['conds']:
            conds.append([col, op, detokenize(val)])
        return cls(d['sel'], d['agg'], conds)

    @classmethod
    def
from_generated_dict(cls, d): 69 | conds = [] 70 | for col, op, val in d['conds']: 71 | end = len(val['words']) 72 | conds.append([col, op, detokenize(val)]) 73 | return cls(d['sel'], d['agg'], conds) 74 | 75 | @classmethod 76 | def from_sequence(cls, sequence, table, lowercase=True): 77 | sequence = deepcopy(sequence) 78 | if 'symend' in sequence['words']: 79 | end = sequence['words'].index('symend') 80 | for k, v in sequence.items(): 81 | sequence[k] = v[:end] 82 | terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])] 83 | headers = [detokenize(h) for h in table['header']] 84 | 85 | # lowercase everything and truncate sequence 86 | if lowercase: 87 | headers = [h.lower() for h in headers] 88 | for i, t in enumerate(terms): 89 | for k, v in t.items(): 90 | t[k] = v.lower() 91 | headers_no_whitespcae = [re.sub(re_whitespace, '', h) for h in headers] 92 | 93 | # get select 94 | if 'symselect' != terms.pop(0)['word']: 95 | raise Exception('Missing symselect operator') 96 | 97 | # get aggregation 98 | if 'symagg' != terms.pop(0)['word']: 99 | raise Exception('Missing symagg operator') 100 | agg_op = terms.pop(0)['word'] 101 | 102 | if agg_op == 'symcol': 103 | agg_op = '' 104 | else: 105 | if 'symcol' != terms.pop(0)['word']: 106 | raise Exception('Missing aggregation column') 107 | try: 108 | agg_op = cls.agg_ops.index(agg_op.upper()) 109 | except Exception as e: 110 | raise Exception('Invalid agg op {}'.format(agg_op)) 111 | 112 | def find_column(name): 113 | return headers_no_whitespcae.index(re.sub(re_whitespace, '', name)) 114 | 115 | def flatten(tokens): 116 | ret = {'words': [], 'after': [], 'gloss': []} 117 | for t in tokens: 118 | ret['words'].append(t['word']) 119 | ret['after'].append(t['after']) 120 | ret['gloss'].append(t['gloss']) 121 | return ret 122 | where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere'] 123 | where_index = where_index[0] if where_index else len(terms) 124 | flat = flatten(terms[:where_index]) 125 | try: 126 | agg_col = find_column(detokenize(flat)) 127 | except Exception as e: 128 | raise Exception('Cannot find aggregation column {}'.format(flat['words'])) 129 | where_terms = terms[where_index+1:] 130 | 131 | # get conditions 132 | conditions = [] 133 | while where_terms: 134 | t = where_terms.pop(0) 135 | flat = flatten(where_terms) 136 | if t['word'] != 'symcol': 137 | raise Exception('Missing conditional column {}'.format(flat['words'])) 138 | try: 139 | op_index = flat['words'].index('symop') 140 | col_tokens = flatten(where_terms[:op_index]) 141 | except Exception as e: 142 | raise Exception('Missing conditional operator {}'.format(flat['words'])) 143 | cond_op = where_terms[op_index+1]['word'] 144 | try: 145 | cond_op = cls.cond_ops.index(cond_op.upper()) 146 | except Exception as e: 147 | raise Exception('Invalid cond op {}'.format(cond_op)) 148 | try: 149 | cond_col = find_column(detokenize(col_tokens)) 150 | except Exception as e: 151 | raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) 152 | try: 153 | val_index = flat['words'].index('symcond') 154 | except Exception as e: 155 | raise Exception('Cannot find conditional value {}'.format(flat['words'])) 156 | 157 | where_terms = where_terms[val_index+1:] 158 | flat = flatten(where_terms) 159 | val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms) 160 | cond_val = detokenize(flatten(where_terms[:val_end_index])) 161 | 
conditions.append([cond_col, cond_op, cond_val]) 162 | where_terms = where_terms[val_end_index+1:] 163 | q = cls(agg_col, agg_op, conditions) 164 | return q 165 | 166 | @classmethod 167 | def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True): 168 | sequence = deepcopy(sequence) 169 | if 'symend' in sequence['words']: 170 | end = sequence['words'].index('symend') 171 | for k, v in sequence.items(): 172 | sequence[k] = v[:end] 173 | terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])] 174 | headers = [detokenize(h) for h in table['header']] 175 | 176 | # lowercase everything and truncate sequence 177 | if lowercase: 178 | headers = [h.lower() for h in headers] 179 | for i, t in enumerate(terms): 180 | for k, v in t.items(): 181 | t[k] = v.lower() 182 | headers_no_whitespcae = [re.sub(re_whitespace, '', h) for h in headers] 183 | 184 | def find_column(name): 185 | return headers_no_whitespcae.index(re.sub(re_whitespace, '', name)) 186 | 187 | def flatten(tokens): 188 | ret = {'words': [], 'after': [], 'gloss': []} 189 | for t in tokens: 190 | ret['words'].append(t['word']) 191 | ret['after'].append(t['after']) 192 | ret['gloss'].append(t['gloss']) 193 | return ret 194 | where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere'] 195 | where_index = where_index[0] if where_index else len(terms) 196 | where_terms = terms[where_index+1:] 197 | 198 | # get conditions 199 | conditions = [] 200 | while where_terms: 201 | t = where_terms.pop(0) 202 | flat = flatten(where_terms) 203 | if t['word'] != 'symcol': 204 | raise Exception('Missing conditional column {}'.format(flat['words'])) 205 | try: 206 | op_index = flat['words'].index('symop') 207 | col_tokens = flatten(where_terms[:op_index]) 208 | except Exception as e: 209 | raise Exception('Missing conditional operator {}'.format(flat['words'])) 210 | cond_op = where_terms[op_index+1]['word'] 211 | try: 212 | cond_op = cls.cond_ops.index(cond_op.upper()) 213 | except Exception as e: 214 | raise Exception('Invalid cond op {}'.format(cond_op)) 215 | try: 216 | cond_col = find_column(detokenize(col_tokens)) 217 | except Exception as e: 218 | raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) 219 | try: 220 | val_index = flat['words'].index('symcond') 221 | except Exception as e: 222 | raise Exception('Cannot find conditional value {}'.format(flat['words'])) 223 | 224 | where_terms = where_terms[val_index+1:] 225 | flat = flatten(where_terms) 226 | val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms) 227 | cond_val = detokenize(flatten(where_terms[:val_end_index])) 228 | conditions.append([cond_col, cond_op, cond_val]) 229 | where_terms = where_terms[val_end_index+1:] 230 | q = cls(agg_col, agg_op, conditions) 231 | return q 232 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/lib/table.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tabulate import tabulate 3 | from lib.query import Query 4 | import random 5 | 6 | 7 | class Table: 8 | 9 | schema_re = re.compile('\((.+)\)') 10 | 11 | def __init__(self, table_id, header, types, rows, caption=None): 12 | self.table_id = table_id 13 | self.header = header 14 | self.types = types 15 | self.rows = rows 16 | self.caption = caption 17 | 18 | def __repr__(self): 19 | return 'Table: {id}\nCaption: 
{caption}\n{tabulate}'.format(
            id=self.table_id,
            caption=self.caption,
            tabulate=tabulate(self.rows, headers=self.header)
        )

    @classmethod
    def get_schema(cls, db, table_id):
        table_infos = db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=cls.get_id(table_id)).all()
        if table_infos:
            return table_infos[0].sql
        else:
            return None

    @classmethod
    def get_id(cls, table_id):
        return 'table_{}'.format(table_id.replace('-', '_'))

    @classmethod
    def from_db(cls, db, table_id):
        table_info = cls.get_schema(db, table_id)
        if table_info:
            schema_str = cls.schema_re.findall(table_info)[0]
            header, types = [], []
            for tup in schema_str.split(', '):
                c, t = tup.split()
                header.append(c)
                types.append(t)
            rows = [[getattr(r, h) for h in header] for r in db.query('SELECT * from {}'.format(cls.get_id(table_id)))]
            return cls(table_id, header, types, rows)
        else:
            return None

    @property
    def name(self):
        return self.get_id(self.table_id)

    def create_table(self, db, replace_existing=False, lower=True):
        exists = self.get_schema(db, self.table_id)
        if exists:
            if replace_existing:
                db.query('DROP TABLE {}'.format(self.name))
            else:
                return
        type_str = ', '.join(['col{} {}'.format(i, t) for i, t in enumerate(self.types)])
        db.query('CREATE TABLE {name} ({types})'.format(name=self.name, types=type_str))
        for row in self.rows:
            value_str = ', '.join([':val{}'.format(j) for j, c in enumerate(row)])
            value_dict = {'val{}'.format(j): c for j, c in enumerate(row)}
            if lower:
                value_dict = {k: v.lower() if isinstance(v, str) else v for k, v in value_dict.items()}
            db.query('INSERT INTO {name} VALUES ({values})'.format(name=self.name, values=value_str), **value_dict)

    def execute_query(self, db, query, lower=True):
        sel_str = 'col{}'.format(query.sel_index) if query.sel_index >= 0 else '*'
        agg_str = sel_str
        agg_op = Query.agg_ops[query.agg_index]
        if agg_op:
            agg_str = '{}({})'.format(agg_op, sel_str)
        where_str = ' AND '.join(['col{} {} :col{}'.format(i, Query.cond_ops[o], i) for i, o, v in query.conditions])
        where_map = {'col{}'.format(i): v for i, o, v in query.conditions}
        if lower:
            where_map = {k: v.lower() if isinstance(v, str) else v for k, v in where_map.items()}
        if where_map:
            where_str = 'WHERE ' + where_str

        if query.sel_index >= 0:
            query_str = 'SELECT {agg_str} AS result FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str)
            return [r.result for r in db.query(query_str, **where_map)]
        else:
            query_str = 'SELECT {agg_str} FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str)
            return [[getattr(r, 'col{}'.format(i)) for i in range(len(self.header))] for r in db.query(query_str, **where_map)]

    def query_str(self, query):
        agg_str = self.header[query.sel_index]
        agg_op = Query.agg_ops[query.agg_index]
        if agg_op:
            agg_str = '{}({})'.format(agg_op, agg_str)
        where_str = ' AND '.join(['{} {} {}'.format(self.header[i], Query.cond_ops[o], v) for i, o, v in query.conditions])
        return 'SELECT {} FROM {} WHERE {}'.format(agg_str, self.name, where_str)
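
    # Added illustration (not in the original source): for a table with
    # header ['fruit', 'count'], query_str(Query(sel_index=1, agg_index=3,
    # conditions=[(0, 0, 'apple')])) renders as
    #   SELECT COUNT(count) FROM table_<id> WHERE fruit = apple
    # since Query.agg_ops[3] == 'COUNT' and Query.cond_ops[0] == '='.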
    def generate_query(self, db, max_cond=4):
        max_cond = min(len(self.header), max_cond)
        # sample a select column
        sel_index = random.choice(list(range(len(self.header))))
        # sample where conditions
        query = Query(-1, Query.agg_ops.index(''))
        results = self.execute_query(db, query)
        condition_options = list(range(len(self.header)))
        condition_options.remove(sel_index)
        for i in range(max_cond):
            if not results:
                break
            cond_index = random.choice(condition_options)
            if self.types[cond_index] == 'text':
                cond_op = Query.cond_ops.index('=')
            else:
                cond_op = random.choice(list(range(len(Query.cond_ops))))
            cond_val = random.choice([r[cond_index] for r in results])
            query.conditions.append((cond_index, cond_op, cond_val))
            new_results = self.execute_query(db, query)
            if [r[sel_index] for r in new_results] != [r[sel_index] for r in results]:
                condition_options.remove(cond_index)
                results = new_results
            else:
                query.conditions.pop()
        # sample an aggregation operation
        if self.types[sel_index] == 'text':
            query.agg_index = Query.agg_ops.index('')
        else:
            query.agg_index = random.choice(list(range(len(Query.agg_ops))))
        query.sel_index = sel_index
        results = self.execute_query(db, query)
        return query, results

    def generate_queries(self, db, n=1, max_tries=5, lower=True):
        qs = []
        for i in range(n):
            n_tries = 0
            r = None
            while r is None and n_tries < max_tries:
                q, r = self.generate_query(db, max_cond=4)
                n_tries += 1
            if r:
                qs.append((q, r))
        return qs

--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/opts.py:
--------------------------------------------------------------------------------
def model_opts(parser):
    """
    These options are passed to the construction of the model.
    Be careful with these as they will be used during translation.
5 | """ 6 | # Model options 7 | # Embedding Options 8 | parser.add_argument('-use_type', type=str, default="nouse", help="To use type") 9 | 10 | parser.add_argument('-word_vec_size', type=int, default=300, 11 | help='Word embedding for both.') 12 | parser.add_argument('-ent_vec_size', type=int, default=10, 13 | help='POS embedding size.') 14 | parser.add_argument('-type_vec_size', type=int, default=20, 15 | help='POS embedding size.') 16 | 17 | parser.add_argument('-elmo_path',type=str, default='/media/DataStor/shuaichen/data_model/wikisql/annotated_ent/elmp_embeddings') 18 | 19 | # RNN Options 20 | parser.add_argument('-encoder_type', type=str, default='brnn', 21 | choices=['rnn', 'brnn'], 22 | help="""Type of encoder layer to use.""") 23 | parser.add_argument('-decoder_type', type=str, default='rnn', 24 | choices=['rnn', 'transformer', 'cnn'], 25 | help='Type of decoder layer to use.') 26 | 27 | parser.add_argument('-layers', type=int, default=1, 28 | help='Number of layers in enc/dec.') 29 | parser.add_argument('-enc_layers', type=int, default=1, 30 | help='Number of layers in the encoder') 31 | parser.add_argument('-dec_layers', type=int, default=1, 32 | help='Number of layers in the decoder') 33 | 34 | parser.add_argument('-rnn_size', type=int, default=250, 35 | help='Size of LSTM hidden states') 36 | parser.add_argument('-score_size', type=int, default=64, 37 | help='Size of hidden layer in scorer') 38 | 39 | parser.add_argument('-rnn_type', type=str, default='LSTM', 40 | choices=['LSTM', 'GRU'], 41 | help="""The gate type to use in the RNNs""") 42 | parser.add_argument('-brnn_merge', default='concat', 43 | choices=['concat', 'sum'], 44 | help="Merge action for the bidir hidden states") 45 | 46 | # Table encoding options 47 | parser.add_argument('-split_type', default='incell', 48 | choices=['incell', 'outcell'], 49 | help="whether encode column split token |") 50 | parser.add_argument('-merge_type', default='cat', 51 | choices=['sub', 'cat', 'mlp'], 52 | help="compute span vector for table column: mlp>cat>sub") 53 | 54 | # Decoder options 55 | parser.add_argument('-layout_encode', default='rnn', 56 | choices=['none', 'rnn'], 57 | help="Layout encoding method.") 58 | parser.add_argument('-cond_op_vec_size', type=int, default=150, 59 | help='Layout embedding size.') 60 | 61 | # Attention options 62 | parser.add_argument('-global_attention', type=str, default='general', 63 | choices=['dot', 'general', 'mlp'], 64 | help="""The attention type to use: 65 | dotprot or general (Luong) or MLP (Bahdanau)""") 66 | parser.add_argument('-attn_hidden', type=int, default=64, 67 | help="if attn_hidden > 0, then attention score = f(Ue) B f(Ud)") 68 | parser.add_argument('-co_attention', action="store_true", 69 | help="if attn_hidden > 0, then attention score = f(Ue) B f(Ud)") 70 | 71 | 72 | def preprocess_opts(parser): 73 | # Dictionary Options 74 | parser.add_argument('-src_vocab_size', type=int, default=100000, 75 | help="Size of the source vocabulary") 76 | parser.add_argument('-src_words_min_frequency', type=int, default=0) 77 | 78 | # Truncation options 79 | parser.add_argument('-src_seq_length', type=int, default=50, 80 | help="Maximum source sequence length") 81 | parser.add_argument('-src_seq_length_trunc', type=int, default=0, 82 | help="Truncate source sequence length.") 83 | parser.add_argument('-tgt_seq_length', type=int, default=50, 84 | help="Maximum target sequence length to keep.") 85 | parser.add_argument('-tgt_seq_length_trunc', type=int, default=0, 86 | help="Truncate target 
sequence length.") 87 | 88 | # Data processing options 89 | parser.add_argument('-shuffle', type=int, default=1, 90 | help="Shuffle data") 91 | parser.add_argument('-lower', action='store_true', help='lowercase data') 92 | 93 | parser.add_argument('-span_exact_match', action="store_true", 94 | help='Must have exact match for cond span in WHERE clause') 95 | 96 | 97 | def train_opts(parser): 98 | # Model loading/saving options 99 | parser.add_argument('-data', default='', 100 | help="""Path prefix to the "train.pt" and 101 | "valid.pt" file path from preprocess.py""") 102 | parser.add_argument('-embd', default='', 103 | help="""Path prefix to embedding folder""") 104 | parser.add_argument('-save_dir', default='', 105 | help="Model save dir") 106 | parser.add_argument('-train_from', default='', type=str, 107 | help="""If training from a checkpoint then this is the 108 | path to the pretrained model's state_dict.""") 109 | # GPU 110 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int, 111 | help="Use CUDA on the listed devices.") 112 | parser.add_argument('-seed', type=int, default=123, 113 | help="""Random seed used for the experiments 114 | reproducibility.""") 115 | 116 | # Init options 117 | parser.add_argument('-start_epoch', type=int, default=1, 118 | help='The epoch from which to start') 119 | parser.add_argument('-param_init', type=float, default=0.08, 120 | help="""Parameters are initialized over uniform distribution 121 | with support (-param_init, param_init). 122 | Use 0 to not use initialization""") 123 | 124 | parser.add_argument('-fix_word_vecs', action='store_true', 125 | help="Fix word embeddings on the encoder side.") 126 | parser.add_argument('-update_word_vecs_after', type=int, default=10,#10 127 | help='When fix_word_vecs=True, only update word vectors after update_word_vecs_after epochs.') 128 | parser.add_argument('-agg_sample_rate', type=float, default=0.5, 129 | help='Randomly skip agg loss, because this loss term tends to be overfitting.') 130 | 131 | # Optimization options 132 | parser.add_argument('-batch_size', type=int, default=200, 133 | help='Maximum batch size') 134 | parser.add_argument('-max_generator_batches', type=int, default=32, 135 | help="""Maximum batches of words in a sequence to run 136 | the generator on in parallel. 
Higher is faster, but 137 | uses more memory.""") 138 | parser.add_argument('-epochs', type=int, default=60, 139 | help='Number of training epochs') 140 | parser.add_argument('-optim', default='adam',#'rmsprop', 141 | choices=['sgd', 'adagrad', 142 | 'adadelta', 'adam', 'rmsprop'], 143 | help="""Optimization method.""") 144 | parser.add_argument('-max_grad_norm', type=float, default=5, 145 | help="""If the norm of the gradient vector exceeds this, 146 | renormalize it to have the norm equal to 147 | max_grad_norm""") 148 | parser.add_argument('-dropout', type=float, default=0.5, 149 | help="Dropout rate.") 150 | parser.add_argument('-lock_dropout', action='store_true', 151 | help="Use the same dropout mask for RNNs.") 152 | parser.add_argument('-weight_dropout', type=float, default=0, 153 | help=">0: Weight dropout probability; applied in LSTM stacks.") 154 | parser.add_argument('-smooth_eps', type=float, default=0, 155 | help="Label smoothing") 156 | # learning rate 157 | parser.add_argument('-learning_rate', type=float, default=0.001,#0.002, 158 | help="""Starting learning rate.""") 159 | parser.add_argument('-alpha', type=float, default=0.95, 160 | help="Optimization hyperparameter") 161 | parser.add_argument('-learning_rate_decay', type=float, default=0.98, 162 | help="""If update_learning_rate, decay learning rate by this much if (i) perplexity does not decrease on the validation set or (ii) epoch has gone past start_decay_at""") 163 | parser.add_argument('-start_decay_at', type=int, default=8, 164 | help="""Start decaying every epoch after and including this epoch""") 165 | parser.add_argument('-start_checkpoint_at', type=int, default=30, 166 | help="""Start checkpointing every epoch after and including this epoch""") 167 | parser.add_argument('-decay_method', type=str, default="", 168 | choices=['noam'], help="Use a custom decay rate.") 169 | parser.add_argument('-warmup_steps', type=int, default=4000, 170 | help="""Number of warmup steps for custom decay.""") 171 | 172 | parser.add_argument('-report_every', type=int, default=50, 173 | help="Print stats at this interval.") 174 | parser.add_argument('-exp', type=str, default="", 175 | help="Name of the experiment for logging.") 176 | 177 | 178 | def translate_opts(parser): 179 | parser.add_argument('-model_path', required=True, 180 | help='Path to model .pt file') 181 | parser.add_argument('-data_path', default='', 182 | help='Path to data') 183 | parser.add_argument('-save_path', default='', 184 | help='Path to processed data and devbest') 185 | 186 | parser.add_argument('-unseen_table', type=str, default="full", help="To test model on unseen tables only") 187 | parser.add_argument('-split', default="dev", 188 | help="Path to the evaluation annotated data") 189 | parser.add_argument('-output', default='pred.txt', 190 | help="""Path to output the predictions (each line will be the decoded sequence""") 191 | parser.add_argument('-batch_size', type=int, default=30, 192 | help='Batch size') 193 | parser.add_argument('-gpu', type=int, default=0, 194 | help="Device to run on") 195 | parser.add_argument('-gold_layout', action='store_true', 196 | help="Given the golden layout sequences for evaluation.") 197 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import argparse 5 | import codecs 6 | import torch 7 | 8 | import table 9 | import 
table.IO 10 | import opts 11 | from table.Utils import set_seed 12 | 13 | parser = argparse.ArgumentParser(description='preprocess.py') 14 | 15 | 16 | # **Preprocess Options** 17 | parser.add_argument('-config', help="Read options from this file") 18 | 19 | parser.add_argument('-train_anno', default="train.jsonl", 20 | help="Path to the training annotated data") 21 | parser.add_argument('-valid_anno', default="dev.jsonl", 22 | help="Path to the validation annotated data") 23 | parser.add_argument('-test_anno', default="test.jsonl", 24 | help="Path to the test annotated data") 25 | 26 | parser.add_argument('-save_data', default="", 27 | help="Output file for the prepared data") 28 | 29 | parser.add_argument('-src_vocab', 30 | help="Path to an existing source vocabulary") 31 | parser.add_argument('-tgt_vocab', 32 | help="Path to an existing target vocabulary") 33 | parser.add_argument('-seed', type=int, default=123, 34 | help="Random seed") 35 | parser.add_argument('-report_every', type=int, default=100000, 36 | help="Report status every this many sentences") 37 | 38 | opts.preprocess_opts(parser) 39 | 40 | opt = parser.parse_args() 41 | set_seed(opt.seed) 42 | 43 | 44 | def main(): 45 | print('Preparing training ...') 46 | fields = table.IO.TableDataset.get_fields() 47 | print("Building Training...") 48 | train = table.IO.TableDataset(opt.train_anno, fields, opt, True) 49 | 50 | print("Building Valid...") 51 | valid = table.IO.TableDataset(opt.valid_anno, fields, opt, True) 52 | 53 | print("Building Test...") 54 | test = table.IO.TableDataset(opt.test_anno, fields, opt, False) 55 | 56 | print("Building Vocab...") 57 | table.IO.TableDataset.build_vocab(train, valid, test, opt) 58 | 59 | print("Saving train/valid/fields") 60 | # Can't save fields, so remove/reconstruct at training time. 61 | if os.path.exists(opt.save_data) is False: 62 | os.mkdir(opt.save_data) 63 | torch.save(table.IO.TableDataset.save_vocab(fields), 64 | open(os.path.join(opt.save_data, 'vocab.pt'), 'wb')) 65 | train.fields = [] 66 | valid.fields = [] 67 | torch.save(valid, open(os.path.join(opt.save_data, 'valid.pt'), 'wb')) 68 | torch.save(train, open(os.path.join(opt.save_data, 'train.pt'), 'wb')) 69 | #torch.save(test, open(os.path.join(opt.save_data, 'test.pt'), 'wb')) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/run.sh: -------------------------------------------------------------------------------- 1 | . 
./job_config.sh 2 | 3 | 4 | python preprocess.py -train_anno "$DATA_DIR/annotated_ent/train.jsonl" -valid_anno "$DATA_DIR/annotated_ent/dev.jsonl" -test_anno "$DATA_DIR/annotated_ent/test.jsonl" -save_data "$SAVE_PATH" 5 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 train.py -start_checkpoint_at 25 -split_type "incell" -epochs 45 -global_attention "general" -fix_word_vecs -dropout 0.5 -score_size 64 -attn_hidden 64 -rnn_size 250 -co_attention -embd "$DATA_DIR" -data "$SAVE_PATH" -save_dir "$SAVE_PATH" >out-train 6 | 7 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 evaluate.py -unseen_table 'full' -split dev -data_path "$DATA_DIR" -save_path "$SAVE_PATH" -model_path "$SAVE_PATH/m_*.pt" >out-dev 8 | MODEL_PATH=$(head -n1 $SAVE_PATH/dev_best.txt) 9 | echo $MODEL_PATH 10 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 evaluate.py -unseen_table 'full' -split finaltest -data_path "$DATA_DIR" -save_path "$SAVE_PATH" -model_path "$MODEL_PATH" >out-test 11 | 12 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 evaluate.py -unseen_table 'zs' -split finaltest -data_path "$DATA_DIR" -save_path "$SAVE_PATH" -model_path "$MODEL_PATH" >out-test-zs 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JD-AI-Research-Silicon-Valley/auxiliary-task-for-text-to-sql/9c0ff806cabab9e06b1b7fd0fac557bae79ff610/zero-shot-text-to-SQL/table/.DS_Store -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import table 4 | 5 | """ 6 | Class for managing the internals of the beam search process. 7 | 8 | Takes care of beams, back pointers, and scores. 9 | """ 10 | 11 | 12 | class Beam(object): 13 | def __init__(self, size, n_best=1, cuda=False, vocab=None, 14 | global_scorer=None): 15 | 16 | self.size = size 17 | self.tt = torch.cuda if cuda else torch 18 | 19 | # The score for each translation on the beam. 20 | self.scores = self.tt.FloatTensor(size).zero_() 21 | self.allScores = [] 22 | 23 | # The backpointers at each time-step. 24 | self.prevKs = [] 25 | 26 | # The outputs at each time-step. 27 | self.nextYs = [self.tt.LongTensor(size) 28 | .fill_(vocab.stoi[table.IO.PAD_WORD])] 29 | self.nextYs[0][0] = vocab.stoi[table.IO.BOS_WORD] 30 | self.vocab = vocab 31 | 32 | # Has EOS topped the beam yet. 33 | self._eos = self.vocab.stoi[table.IO.EOS_WORD] 34 | self.eosTop = False 35 | 36 | # The attentions (matrix) for each time. 37 | self.attn = [] 38 | 39 | # Time and k pair for finished. 40 | self.finished = [] 41 | self.n_best = n_best 42 | 43 | # Information for global scoring. 44 | self.globalScorer = global_scorer 45 | self.globalState = {} 46 | 47 | def getCurrentState(self): 48 | "Get the outputs for the current timestep." 49 | return self.nextYs[-1] 50 | 51 | def getCurrentOrigin(self): 52 | "Get the backpointers for the current timestep." 53 | return self.prevKs[-1] 54 | 55 | def advance(self, wordLk, attnOut): 56 | """ 57 | Given prob over words for every last beam `wordLk` and attention 58 | `attnOut`: Compute and update the beam search. 59 | 60 | Parameters: 61 | 62 | * `wordLk`- probs of advancing from the last step (K x words) 63 | * `attnOut`- attention at the last step 64 | 65 | Returns: True if beam search is complete. 
66 | """ 67 | numWords = wordLk.size(1) 68 | 69 | # Sum the previous scores. 70 | if len(self.prevKs) > 0: 71 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 72 | 73 | # Don't let EOS have children. 74 | for i in range(self.nextYs[-1].size(0)): 75 | if self.nextYs[-1][i] == self._eos: 76 | beamLk[i] = -1e20 77 | else: 78 | beamLk = wordLk[0] 79 | flatBeamLk = beamLk.view(-1) 80 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 81 | 82 | self.allScores.append(self.scores) 83 | self.scores = bestScores 84 | 85 | # bestScoresId is flattened beam x word array, so calculate which 86 | # word and beam each score came from 87 | prevK = bestScoresId / numWords 88 | self.prevKs.append(prevK) 89 | self.nextYs.append((bestScoresId - prevK * numWords)) 90 | self.attn.append(attnOut.index_select(0, prevK)) 91 | 92 | if self.globalScorer is not None: 93 | self.globalScorer.updateGlobalState(self) 94 | 95 | for i in range(self.nextYs[-1].size(0)): 96 | if self.nextYs[-1][i] == self._eos: 97 | s = self.scores[i] 98 | if self.globalScorer is not None: 99 | globalScores = self.globalScorer.score(self, self.scores) 100 | s = globalScores[i] 101 | self.finished.append((s, len(self.nextYs) - 1, i)) 102 | 103 | # End condition is when top-of-beam is EOS and no global score. 104 | if self.nextYs[-1][0] == self.vocab.stoi[table.IO.EOS_WORD]: 105 | # self.allScores.append(self.scores) 106 | self.eosTop = True 107 | 108 | def done(self): 109 | return self.eosTop and len(self.finished) >= self.n_best 110 | 111 | def sortFinished(self, minimum=None): 112 | if minimum is not None: 113 | i = 0 114 | # Add from beam until we have minimum outputs. 115 | while len(self.finished) < minimum: 116 | s = self.scores[i] 117 | if self.globalScorer is not None: 118 | globalScores = self.globalScorer.score(self, self.scores) 119 | s = globalScores[i] 120 | self.finished.append((s, len(self.nextYs) - 1, i)) 121 | 122 | self.finished.sort(key=lambda a: -a[0]) 123 | scores = [sc for sc, _, _ in self.finished] 124 | ks = [(t, k) for _, t, k in self.finished] 125 | return scores, ks 126 | 127 | def getHyp(self, timestep, k): 128 | """ 129 | Walk back to construct the full hypothesis. 130 | """ 131 | hyp, attn = [], [] 132 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 133 | hyp.append(self.nextYs[j+1][k]) 134 | attn.append(self.attn[j][k]) 135 | k = self.prevKs[j][k] 136 | return hyp[::-1], torch.stack(attn[::-1]) 137 | 138 | 139 | class GNMTGlobalScorer(object): 140 | """ 141 | Google NMT ranking score from Wu et al. 
142 | """ 143 | def __init__(self, alpha, beta): 144 | self.alpha = alpha 145 | self.beta = beta 146 | 147 | def score(self, beam, logprobs): 148 | "Additional term add to log probability" 149 | cov = beam.globalState["coverage"] 150 | pen = self.beta * torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 151 | l_term = (((5 + len(beam.nextYs)) ** self.alpha) / 152 | ((5 + 1) ** self.alpha)) 153 | return (logprobs / l_term) + pen 154 | 155 | def updateGlobalState(self, beam): 156 | "Keeps the coverage vector as sum of attens" 157 | if len(beam.prevKs) == 1: 158 | beam.globalState["coverage"] = beam.attn[-1] 159 | else: 160 | beam.globalState["coverage"] = beam.globalState["coverage"] \ 161 | .index_select(0, beam.prevKs[-1]).add(beam.attn[-1]) 162 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/IO.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import codecs 4 | import json 5 | import random as rnd 6 | import numpy as np 7 | from collections import Counter, defaultdict 8 | from itertools import chain, count 9 | from six import string_types 10 | 11 | import torch 12 | import torchtext.data 13 | import torchtext.vocab 14 | import h5py 15 | 16 | UNK_WORD = '' 17 | UNK = 0 18 | PAD_WORD = '' 19 | PAD = 1 20 | BOS_WORD = '' 21 | EOS_WORD = '' 22 | SPLIT_WORD = '<|>' 23 | special_token_list = [UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD, SPLIT_WORD] 24 | zero_vector = [] 25 | for _ in range(1024): 26 | zero_vector.append(0.0) 27 | 28 | def __getstate__(self): 29 | return dict(self.__dict__, stoi=dict(self.stoi)) 30 | 31 | 32 | def __setstate__(self, state): 33 | self.__dict__.update(state) 34 | self.stoi = defaultdict(lambda: 0, self.stoi) 35 | 36 | 37 | torchtext.vocab.Vocab.__getstate__ = __getstate__ 38 | torchtext.vocab.Vocab.__setstate__ = __setstate__ 39 | 40 | 41 | def merge_vocabs(vocabs, vocab_size=None): 42 | """ 43 | Merge individual vocabularies (assumed to be generated from disjoint 44 | documents) into a larger vocabulary. 45 | 46 | Args: 47 | vocabs: `torchtext.vocab.Vocab` vocabularies to be merged 48 | vocab_size: `int` the final vocabulary size. `None` for no limit. 
49 | Return: 50 | `torchtext.vocab.Vocab` 51 | """ 52 | merged = sum([vocab.freqs for vocab in vocabs], Counter()) 53 | return torchtext.vocab.Vocab(merged, 54 | specials=list(special_token_list), 55 | max_size=vocab_size) 56 | 57 | 58 | def join_dicts(*args): 59 | """ 60 | args: dictionaries with disjoint keys 61 | returns: a single dictionary that has the union of these keys 62 | """ 63 | return dict(chain(*[d.items() for d in args])) 64 | 65 | 66 | class OrderedIterator(torchtext.data.Iterator): 67 | def create_batches(self): 68 | if self.train: 69 | self.batches = torchtext.data.pool( 70 | self.data(), self.batch_size, 71 | self.sort_key, self.batch_size_fn, 72 | random_shuffler=self.random_shuffler) 73 | else: 74 | self.batches = [] 75 | _=0 76 | for b in torchtext.data.batch(self.data(), self.batch_size, 77 | self.batch_size_fn): 78 | self.batches.append(sorted(b, key=self.sort_key)) 79 | _+=1 80 | print(_,self.batch_size,_*self.batch_size) 81 | 82 | 83 | #"conds": [ [5, 0, {"words": ["butler", "cc", "-lrb-", "ks", "-rrb-"], "after": [" ", " ", "", "", ""], "gloss": ["Butler", "CC", "(", "KS", ")"]}] ] 84 | 85 | def find(q, c): 86 | ans=[] 87 | for st in range(0, len(q) - len(c) + 1): 88 | if q[st:st + len(c)] == c: 89 | ans.append(st) 90 | return st 91 | return 0 92 | 93 | 94 | def read_anno_json(anno_path): 95 | with codecs.open(anno_path, "r", "utf-8") as corpus_file: 96 | js_list = [json.loads(line) for line in corpus_file] 97 | for js in js_list: 98 | #cond_list = list(enumerate(js['query']['conds'])) 99 | # sort by (op, orginal index) 100 | # cond_list.sort(key=lambda x: (x[1][1], x[0])) 101 | #cond_list.sort(key=lambda x: x[1][1]) 102 | #js['query']['conds'] = [x[1] for x in cond_list] 103 | 104 | cond_list = js['query']['conds'] 105 | pos=[] 106 | for i in range(len(cond_list)): 107 | pos.append(find(js['question']['words'],cond_list[i][2]['words'])) 108 | #print(pos) 109 | S=list(zip(cond_list,pos)) 110 | S.sort(key=lambda x: x[1]) 111 | js['query']['conds'] = [x[0] for x in S] 112 | return js_list 113 | 114 | 115 | class TableDataset(torchtext.data.Dataset): 116 | """Defines a dataset for machine translation.""" 117 | 118 | @staticmethod 119 | def sort_key(ex): 120 | "Sort in reverse size order" 121 | return -len(ex.src) 122 | 123 | def __init__(self, anno, fields, opt, filter_ex, **kwargs): 124 | """ 125 | Create a TranslationDataset given paths and fields. 
126 | 127 | anno: location of annotated data / js_list 128 | filter_ex: False - keep all the examples for evaluation (should not have filtered examples); True - filter examples with unmatched spans; 129 | """ 130 | if isinstance(anno, string_types): 131 | js_list = read_anno_json(anno) 132 | else: 133 | js_list = anno 134 | 135 | 136 | 137 | src_data = self._read_annotated_file( #return a generator 138 | opt, js_list, 'question', filter_ex) 139 | src_examples = self._construct_examples(src_data, 'src') #return a generator of dict ('src':data) 140 | 141 | #elmo_data = self._read_annotated_file( # return a generator 142 | # opt, js_list, 'elmo', filter_ex) 143 | #elmo_examples = self._construct_examples(elmo_data, 'elmo') # return a generator of dict ('src':data) 144 | 145 | 146 | ent_data = self._read_annotated_file(opt, js_list, 'ent', filter_ex) 147 | ent_examples = self._construct_examples(ent_data, 'ent') 148 | 149 | type_data = self._read_annotated_file(opt, js_list, 'type', filter_ex) 150 | type_examples = self._construct_examples(type_data, 'type') 151 | 152 | agg_data = self._read_annotated_file(opt, js_list, 'agg', filter_ex) 153 | agg_examples = self._construct_examples(agg_data, 'agg') 154 | 155 | sel_data = self._read_annotated_file(opt, js_list, 'sel', filter_ex) 156 | sel_examples = self._construct_examples(sel_data, 'sel') 157 | 158 | tbl_data = self._read_annotated_file(opt, js_list, 'tbl', filter_ex) 159 | tbl_examples = self._construct_examples(tbl_data, 'tbl') 160 | 161 | tbl_split_data = self._read_annotated_file( 162 | opt, js_list, 'tbl_split', filter_ex) 163 | tbl_split_examples = self._construct_examples( 164 | tbl_split_data, 'tbl_split') 165 | 166 | tbl_mask_data = self._read_annotated_file( 167 | opt, js_list, 'tbl_mask', filter_ex) 168 | tbl_mask_examples = self._construct_examples( 169 | tbl_mask_data, 'tbl_mask') 170 | 171 | lay_data = self._read_annotated_file(opt, js_list, 'lay', filter_ex) 172 | lay_examples = self._construct_examples(lay_data, 'lay') 173 | 174 | cond_op_data = self._read_annotated_file( 175 | opt, js_list, 'cond_op', filter_ex) 176 | cond_op_examples = self._construct_examples(cond_op_data, 'cond_op') 177 | 178 | cond_col_data = list( 179 | self._read_annotated_file(opt, js_list, 'cond_col', filter_ex)) 180 | cond_col_examples = self._construct_examples(cond_col_data, 'cond_col') 181 | cond_col_loss_examples = self._construct_examples( 182 | cond_col_data, 'cond_col_loss') 183 | 184 | def _map_to_sublist_index(d_list, idx): 185 | return [([it[idx] for it in d] if (d is not None) else None) for d in d_list] 186 | span_data = list(self._read_annotated_file( 187 | opt, js_list, 'cond_span', filter_ex)) 188 | 189 | ### 190 | BIO_label_data = list(self._read_annotated_file( 191 | opt, js_list, 'BIO_label', filter_ex)) 192 | 193 | BIO_label_examples = self._construct_examples( BIO_label_data, 'BIO_label') 194 | BIO_label_loss_examples = self._construct_examples(BIO_label_data, 'BIO_label_loss') 195 | 196 | ### 197 | BIO_op_label_data = list(self._read_annotated_file( 198 | opt, js_list, 'BIO_op_label', filter_ex)) 199 | 200 | BIO_op_label_examples = self._construct_examples(BIO_op_label_data, 'BIO_op_label') 201 | BIO_op_label_loss_examples = self._construct_examples(BIO_op_label_data, 'BIO_op_label_loss') 202 | 203 | ### 204 | BIO_column_label_data = list(self._read_annotated_file( 205 | opt, js_list, 'BIO_column_label', filter_ex)) 206 | BIO_column_label_examples = self._construct_examples(BIO_column_label_data, 'BIO_column_label') 207 | 
BIO_column_label_loss_examples = self._construct_examples(BIO_column_label_data, 'BIO_column_label_loss')
208 | 
209 | 
210 | 
211 | 
212 |         span_l_examples = self._construct_examples(
213 |             _map_to_sublist_index(span_data, 0), 'cond_span_l')
214 |         span_r_examples = self._construct_examples(
215 |             _map_to_sublist_index(span_data, 1), 'cond_span_r')
216 |         span_l_loss_examples = self._construct_examples(
217 |             _map_to_sublist_index(span_data, 0), 'cond_span_l_loss')
218 |         span_r_loss_examples = self._construct_examples(
219 |             _map_to_sublist_index(span_data, 1), 'cond_span_r_loss')
220 | 
221 | 
222 | 
223 | 
224 | 
225 |         # examples: one dict per annotated example, merging every field above
226 |         examples = [join_dicts(*it) for it in zip(src_examples, ent_examples, type_examples, agg_examples, sel_examples, lay_examples, tbl_examples, tbl_split_examples, tbl_mask_examples,
227 |                                                   cond_op_examples, cond_col_examples, span_l_examples, span_r_examples, cond_col_loss_examples, span_l_loss_examples, span_r_loss_examples,
228 |                                                   BIO_label_examples, BIO_label_loss_examples, BIO_column_label_examples, BIO_column_label_loss_examples, BIO_op_label_examples, BIO_op_label_loss_examples)]
229 |         # the examples should not contain None
230 |         len_before_filter = len(examples)
231 | 
232 |         examples = list(filter(lambda x: all(
233 |             (v is not None for k, v in x.items())), examples))
234 |         len_after_filter = len(examples)
235 | 
236 |         num_filter = len_before_filter - len_after_filter
237 |         if num_filter > 0:
238 |             print('Filter #examples (with None): {} / {} = {:.2%}'.format(num_filter,
239 |                   len_before_filter, num_filter / len_before_filter))
240 | 
241 |         len_lay_list = []
242 |         len_tgt_list = []
243 |         for ex in examples:
244 |             has_agg = 0 if int(ex['agg']) == 0 else 1
245 |             if len(ex['cond_op']) == 0:
246 |                 len_lay_list.append(0)
247 |                 len_tgt_list.append(1 + has_agg + 1)
248 |             else:
249 |                 len_lay = len(ex['cond_op']) * 2
250 |                 len_lay_list.append(len_lay)
251 |                 len_tgt_list.append(
252 |                     1 + has_agg + 1 + len_lay + len(ex['cond_op']) * 2)
253 | 
254 | 
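        # Worked example for the bookkeeping above: a query with agg != 0 and
        # two WHERE conditions gives has_agg = 1, len_lay = 2 * 2 = 4, and
        # len_tgt = 1 + 1 + 1 + 4 + 4 = 11 decoder slots.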
255 |         # Peek at the first example to see which fields are used.
256 |         ex = examples[0]
257 |         keys = ex.keys()
258 |         fields = [(k, fields[k]) for k in (list(keys) + ["indices"])]
259 | 
260 |         def construct_final(examples):
261 |             for i, ex in enumerate(examples):
262 |                 yield torchtext.data.Example.fromlist(
263 |                     [ex[k] for k in keys] + [i],
264 |                     fields)
265 | 
266 |         def filter_pred(example):
267 |             return True
268 | 
269 |         super(TableDataset, self).__init__(
270 |             construct_final(examples), fields, filter_pred)
271 | 
272 | 
273 |     def _read_annotated_file(self, opt, js_list, field, filter_ex):
274 |         """
275 |         Yield, for each example in `js_list`, the annotation selected by `field`.
276 |         `filter_ex` controls whether examples with unmatched spans yield None.
277 |         """
278 |         if field in ('sel', 'agg'):
279 |             lines = (line['query'][field] for line in js_list)
280 |         elif field in ('ent',):
281 |             lines = (line['question']['ent'] for line in js_list)
282 |         elif field in ('type',):
283 | 
284 | 
285 |             def filt_type(line):
286 |                 return [w if w == 'column' else 'others' for w in line['question']['word_type']]
287 | 
288 |             lines = (filt_type(line) for line in js_list)
289 |         elif field in ('tbl',):
290 |             def _tbl(line):
291 |                 tk_list = [SPLIT_WORD]
292 |                 tk_split = '\t' + SPLIT_WORD + '\t'
293 |                 tk_list.extend(tk_split.join(
294 |                     ['\t'.join(col['words']) for col in line['table']['header']]).strip().split('\t'))
295 |                 tk_list.append(SPLIT_WORD)
296 |                 return tk_list
297 |             lines = (_tbl(line) for line in js_list)
298 |         elif field in ('tbl_split',):
299 |             def _cum_length_for_split(line):
300 |                 len_list = [len(col['words'])
301 |                             for col in line['table']['header']]
302 |                 r = [0]
303 |                 for i in range(len(len_list)):
304 |                     r.append(r[-1] + len_list[i] + 1)
305 |                 return r
306 |             lines = (_cum_length_for_split(line) for line in js_list)
307 |         elif field in ('tbl_mask',):
308 |             lines = ([0 for col in line['table']['header']]
309 |                      for line in js_list)
310 |         elif field in ('lay',):
311 |             def _lay(where_list):
312 |                 return ' '.join([str(op) for col, op, cond in where_list])
313 |             lines = (_lay(line['query']['conds'])
314 |                      for line in js_list)
315 |         elif field in ('cond_op',):
316 |             lines = ([str(op) for col, op, cond in line['query']['conds']]
317 |                      for line in js_list)
318 |         elif field in ('cond_col',):
319 |             lines = ([col for col, op, cond in line['query']['conds']]
320 |                      for line in js_list)
321 |         elif field in ('cond_span',):
322 |             def _find_span(q_list, where_list):
323 |                 r_list = []
324 |                 for col, op, cond in where_list:
325 |                     tk_list = cond['words']
326 |                     # find an exact match first
327 |                     if len(tk_list) <= len(q_list):
328 |                         match_list = []
329 |                         for st in range(0, len(q_list) - len(tk_list) + 1):
330 |                             if q_list[st:st + len(tk_list)] == tk_list:
331 |                                 match_list.append((st, st + len(tk_list) - 1))
332 |                         if len(match_list) > 0:
333 |                             r_list.append(rnd.choice(match_list))  # on multiple matches, pick one at random
334 | continue 335 | elif (opt is not None) and opt.span_exact_match: 336 | return None 337 | else: 338 | # do not have exact match, then fuzzy match (w/o considering order) 339 | for len_span in range(len(tk_list), len(tk_list) + 2): 340 | for st in range(0, len(q_list) - len_span + 1): 341 | if set(tk_list) <= set(q_list[st:st + len_span]): 342 | match_list.append( 343 | (st, st + len_span - 1)) 344 | if len(match_list) > 0: 345 | # match spans that are as short as possible 346 | break 347 | if len(match_list) > 0: 348 | r_list.append(rnd.choice(match_list)) 349 | else: 350 | return None 351 | else: 352 | return None 353 | return r_list 354 | 355 | def _span(q_list, where_list, filter_ex): 356 | r_list = _find_span(q_list, where_list) 357 | if (not filter_ex) and (r_list is None): 358 | r_list = [] 359 | for col, op, cond in where_list: 360 | r_list.append((0, 0)) 361 | return r_list 362 | lines = (_span(line['question']['words'], line['query'] 363 | ['conds'], filter_ex) for line in js_list) 364 | print('span',type(lines)) 365 | elif field in('BIO_label','BIO_column_label','BIO_op_label'): 366 | 367 | def _find_span(q_list, where_list): 368 | r_list = [] 369 | for col, op, cond in where_list: 370 | tk_list = cond['words'] 371 | # find exact match first 372 | if len(tk_list) <= len(q_list): 373 | match_list = [] 374 | for st in range(0, len(q_list) - len(tk_list) + 1): 375 | if q_list[st:st + len(tk_list)] == tk_list: 376 | match_list.append((st, st + len(tk_list) - 1,col,op)) 377 | if len(match_list) > 0: 378 | r_list.append(rnd.choice(match_list)) #multi match then random choose one. 379 | continue 380 | elif (opt is not None) and opt.span_exact_match: 381 | return None 382 | else: 383 | # do not have exact match, then fuzzy match (w/o considering order) 384 | for len_span in range(len(tk_list), len(tk_list) + 2): 385 | for st in range(0, len(q_list) - len_span + 1): 386 | if set(tk_list) <= set(q_list[st:st + len_span]): 387 | match_list.append( 388 | (st, st + len_span - 1,col,op)) 389 | if len(match_list) > 0: 390 | # match spans that are as short as possible 391 | break 392 | if len(match_list) > 0: 393 | r_list.append(rnd.choice(match_list)) 394 | else: 395 | return None 396 | else: 397 | return None 398 | return r_list 399 | 400 | def _span(q_list, where_list, filter_ex): 401 | r_list = _find_span(q_list, where_list) 402 | if (not filter_ex) and (r_list is None): 403 | r_list = [] 404 | for col, op, cond in where_list: 405 | r_list.append((0, 0,-1,-1)) 406 | return r_list 407 | 408 | def _col_showin_q(question, cols): # q: a list of word. 
col: a list of col, each col is a list of words 409 | def _find_col(c, cols): 410 | for i in range(len(cols)): 411 | if c == cols[i]: 412 | return i 413 | return -1 414 | 415 | mapping = [] 416 | for i in range(len(question)): 417 | mapping.append(-1) 418 | 419 | for span_len in reversed(range(10)): 420 | for i in range(len(question) - span_len): 421 | ff = 1 422 | for j in range(i, i + span_len + 1): 423 | if mapping[j] != -1: # belong to a col 424 | ff = 0 425 | break 426 | if ff == 0: 427 | continue 428 | 429 | tag = _find_col(question[i:i + span_len + 1], cols) 430 | if tag != -1: 431 | for j in range(i, i + span_len + 1): 432 | mapping[j] = tag 433 | return mapping 434 | 435 | lines_1 = [line['question']['words'] for line in js_list] 436 | 437 | lines_2 = (_span(line['question']['words'], line['query']['conds'], filter_ex) for line in js_list) 438 | 439 | lines_3 = [line['question']['word_type'] for line in js_list] 440 | lines_4 = [_col_showin_q(line['question']['words'], [x['words'] for x in line['table']['header']]) for line 441 | in js_list] 442 | if field in ('BIO_label',): 443 | total_token = 0 444 | BI_token = 0 445 | lines=[] 446 | for qq,span,col_indicator,col_map in zip(lines_1,lines_2,lines_3,lines_4): 447 | q=qq[:] 448 | #print(q) 449 | for i in range(len(q)): 450 | q[i]=2 451 | total_token+=1 452 | for i in range(len(col_map)): 453 | if col_map[i] != -1: 454 | q[i]=3 455 | # for i in range(len(col_indicator)): 456 | # if col_indicator[i]=='column': 457 | # q[i]=3 458 | if span is not None and len(span) > 0: 459 | for sp in span: 460 | BI_token+=1 461 | #print(sp) 462 | q[sp[0]]=0 463 | for i in range(sp[0]+1,sp[1]+1): 464 | q[i]=1 465 | #yield q 466 | lines.append(q) 467 | print(total_token, BI_token, 1.0 * BI_token / total_token) 468 | elif field in ('BIO_column_label'):#BIO_column_label 469 | lines = [] 470 | for qq, span in zip(lines_1, lines_2): 471 | q = qq[:] 472 | for i in range(len(q)): 473 | q[i]=-1 474 | if span is not None and len(span) > 0: 475 | for sp in span: 476 | for i in range(sp[0],sp[1]+1): 477 | q[i]=sp[2] #cond_col 478 | lines.append(q) 479 | else: 480 | lines = [] 481 | for qq, span in zip(lines_1, lines_2): 482 | q = qq[:] 483 | for i in range(len(q)): 484 | q[i] = -1 485 | if span is not None and len(span) > 0: 486 | for sp in span: 487 | for i in range(sp[0], sp[1] + 1): 488 | q[i] = sp[3] #cond_op 489 | lines.append(q) 490 | lines = (line for line in lines) 491 | 492 | #print('BIO', type(lines)) 493 | elif field in ('cond_mask',): 494 | lines = ([0 for col, op, cond in line['query']['conds']] 495 | for line in js_list) 496 | elif field in ('elmo',): 497 | lines = (line[field]['words'] for line in js_list) 498 | else: 499 | lines = (line[field]['words'] for line in js_list) 500 | #print(field) 501 | #print(type(js_list)) 502 | #print(type(lines)) 503 | for line in lines: 504 | yield line 505 | 506 | def _construct_examples(self, lines, side): 507 | for words in lines: 508 | example_dict = {side: words} 509 | yield example_dict 510 | 511 | def __getstate__(self): 512 | return self.__dict__ 513 | 514 | def __setstate__(self, d): 515 | self.__dict__.update(d) 516 | 517 | def __reduce_ex__(self, proto): 518 | "This is a hack. Something is broken with torch pickle." 519 | return super(TableDataset, self).__reduce_ex__() 520 | 521 | @staticmethod 522 | def load_fields(vocab): 523 | vocab = dict(vocab) 524 | fields = TableDataset.get_fields() 525 | for k, v in vocab.items(): 526 | # Hack. 
Can't pickle defaultdict :( 527 | v.stoi = defaultdict(lambda: 0, v.stoi) 528 | fields[k].vocab = v 529 | return fields 530 | 531 | @staticmethod 532 | def save_vocab(fields): 533 | vocab = [] 534 | for k, f in fields.items(): 535 | if 'vocab' in f.__dict__: 536 | f.vocab.stoi = dict(f.vocab.stoi) 537 | vocab.append((k, f.vocab)) 538 | return vocab 539 | 540 | @staticmethod 541 | def get_fields(): 542 | fields = {} 543 | fields["src"] = torchtext.data.Field( 544 | pad_token=PAD_WORD, include_lengths=True)#, eos_token=EOS_WORD) 545 | 546 | 547 | 548 | fields["ent"] = torchtext.data.Field( 549 | pad_token=PAD_WORD, include_lengths=False)#, eos_token=EOS_WORD) 550 | fields["type"] = torchtext.data.Field( 551 | pad_token=PAD_WORD, include_lengths=False) # , eos_token=EOS_WORD) 552 | fields["agg"] = torchtext.data.Field( 553 | sequential=False, use_vocab=False, batch_first=True) 554 | fields["sel"] = torchtext.data.Field( 555 | sequential=False, use_vocab=False, batch_first=True) 556 | fields["tbl"] = torchtext.data.Field( 557 | pad_token=PAD_WORD, include_lengths=True) 558 | fields["tbl_split"] = torchtext.data.Field( 559 | use_vocab=False, pad_token=0) 560 | fields["tbl_mask"] = torchtext.data.Field( 561 | use_vocab=False, tensor_type=torch.ByteTensor, batch_first=True, pad_token=1) 562 | fields["lay"] = torchtext.data.Field( 563 | sequential=False, batch_first=True) 564 | fields["cond_op"] = torchtext.data.Field( 565 | include_lengths=True, pad_token=PAD_WORD) 566 | fields["cond_col"] = torchtext.data.Field( 567 | use_vocab=False, include_lengths=False, pad_token=0) 568 | fields["cond_span_l"] = torchtext.data.Field( 569 | use_vocab=False, include_lengths=False, pad_token=0) 570 | fields["cond_span_r"] = torchtext.data.Field( 571 | use_vocab=False, include_lengths=False, pad_token=0) 572 | fields["cond_col_loss"] = torchtext.data.Field( 573 | use_vocab=False, include_lengths=False, pad_token=-1) 574 | fields["cond_span_l_loss"] = torchtext.data.Field( 575 | use_vocab=False, include_lengths=False, pad_token=-1) 576 | fields["cond_span_r_loss"] = torchtext.data.Field( 577 | use_vocab=False, include_lengths=False, pad_token=-1) 578 | fields["indices"] = torchtext.data.Field( 579 | use_vocab=False, sequential=False) 580 | 581 | fields["BIO_label"] = torchtext.data.Field( 582 | use_vocab=False, include_lengths=False, pad_token=-1) 583 | fields["BIO_label_loss"] = torchtext.data.Field( 584 | use_vocab=False, include_lengths=False, pad_token=-1) 585 | fields["BIO_column_label"] = torchtext.data.Field( 586 | use_vocab=False, include_lengths=False, pad_token=-1) 587 | fields["BIO_column_label_loss"] = torchtext.data.Field( 588 | use_vocab=False, include_lengths=False, pad_token=-1) 589 | 590 | fields["BIO_op_label"] = torchtext.data.Field( 591 | use_vocab=False, include_lengths=False, pad_token=-1) 592 | fields["BIO_op_label_loss"] = torchtext.data.Field( 593 | use_vocab=False, include_lengths=False, pad_token=-1) 594 | 595 | return fields 596 | 597 | @staticmethod 598 | def build_vocab(train, dev, test, opt): 599 | fields = train.fields 600 | 601 | merge_list = [] 602 | merge_name_list = ['tbl','src']#('src', 'tbl') 603 | print(1) 604 | for split in (dev, test, train,): 605 | 606 | for merge_name_it in merge_name_list: 607 | fields[merge_name_it].build_vocab(split, max_size=opt.src_vocab_size, min_freq=0) 608 | merge_list.append(fields[merge_name_it].vocab) 609 | print(merge_name_it, len(fields[merge_name_it].vocab.stoi)) 610 | print(2) 611 | # build vocabulary only based on the training set 612 | 
fields["ent"].build_vocab( 613 | train, max_size=opt.src_vocab_size, min_freq=0) 614 | fields["type"].build_vocab( 615 | train, max_size=opt.src_vocab_size, min_freq=0) 616 | fields["lay"].build_vocab( 617 | train, max_size=opt.src_vocab_size, min_freq=0) 618 | fields["cond_op"].build_vocab( 619 | train, max_size=opt.src_vocab_size, min_freq=0) 620 | print(3) 621 | # need to know all the words to filter the pretrained word embeddings 622 | merged_vocab = merge_vocabs(merge_list, vocab_size=opt.src_vocab_size) 623 | total=0 624 | for x in merge_list: 625 | total+=len(x.stoi) 626 | print(total,len(merged_vocab.stoi)) 627 | for merge_name_it in merge_name_list: 628 | fields[merge_name_it].vocab = merged_vocab 629 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file handles the details of the loss function during training. 3 | 4 | This includes: LossComputeBase and the standard NMTLossCompute, and 5 | sharded loss compute stuff. 6 | """ 7 | from __future__ import division 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | import random as rnd 12 | 13 | import table 14 | from table.modules.cross_entropy_smooth import CrossEntropyLossSmooth 15 | 16 | 17 | class TableLossCompute(nn.Module): 18 | def __init__(self, agg_sample_rate, smooth_eps=0): 19 | super(TableLossCompute, self).__init__() 20 | self.criterion = {} 21 | nll = nn.NLLLoss(size_average=False, ignore_index=-1) 22 | nll_col = nn.NLLLoss(size_average=False, ignore_index=0) 23 | if smooth_eps > 0: 24 | for loss_name in ('sel', 'cond_col', 'cond_span_l', 'cond_span_r','BIO_label', 'BIO_column_label','BIO_op_label'): 25 | self.criterion[loss_name] = nll 26 | for loss_name in ('agg', 'lay'): 27 | self.criterion[loss_name] = CrossEntropyLossSmooth( 28 | size_average=False, ignore_index=-1, smooth_eps=smooth_eps) 29 | else: 30 | for loss_name in ('agg', 'sel', 'lay', 'cond_col', 'cond_span_l', 'cond_span_r','BIO_label', 'BIO_column_label','BIO_op_label'): 31 | self.criterion[loss_name] = nll 32 | #self.criterion['BIO_column_label'] = nll_col 33 | self.agg_sample_rate = agg_sample_rate 34 | 35 | def compute_loss(self, pred, gold): 36 | # sum up the loss functions 37 | loss_list = [] 38 | for loss_name in ('agg', 'sel', 'lay'): 39 | loss = self.criterion[loss_name](pred[loss_name], gold[loss_name]) 40 | if (loss_name != 'agg') or (rnd.random() < self.agg_sample_rate): 41 | loss_list.append(loss) 42 | for loss_name in ('cond_col', 'cond_span_l', 'cond_span_r'): 43 | losses=[] 44 | for p, g in zip(pred[loss_name], gold[loss_name]): 45 | loss = self.criterion[loss_name](p, g) 46 | loss_list.append(loss) 47 | #losses.append(loss) 48 | #loss_list.append(1.0*sum(losses)/len(losses)) 49 | 50 | #return sum(loss_list) 51 | #BIO loss begin 52 | #loss_list=[] 53 | for loss_name in ('BIO_label','BIO_op_label'): 54 | losses=[] 55 | losses.append(loss) 56 | for p,g in zip(pred[loss_name], gold[loss_name]): 57 | loss=self.criterion[loss_name](p,g) 58 | losses.append(loss) 59 | loss_list.append(3.0* sum(losses) / len(losses)) #3 60 | 61 | for loss_name in ('BIO_column_label',): 62 | losses=[] 63 | losses.append(loss) 64 | for p,g in zip(pred[loss_name], gold[loss_name]): 65 | loss=self.criterion[loss_name](p,g) 66 | losses.append(loss) 67 | loss_list.append(3.0* sum(losses) / len(losses)) # 7 68 | # BIO loss end 69 | 70 | 71 | return sum(loss_list) 
-------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/ModelConstructor.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is for models creation, which consults options 3 | and creates each encoder and decoder accordingly. 4 | """ 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch 8 | import table 9 | import table.Models 10 | import table.modules 11 | from table.IO import TableDataset 12 | from table.Models import ParserModel, RNNEncoder, CondDecoder, TableRNNEncoder, MatchScorer, CondMatchScorer, CoAttention,CoAttention_wornn, AttnPool, Softattention 13 | 14 | import torchtext.vocab 15 | from table.modules.Embeddings import PartUpdateEmbedding 16 | 17 | import sys 18 | sys.path.append("..") 19 | from lib.query import agg_ops, cond_ops 20 | 21 | 22 | def make_word_embeddings(opt, word_dict, fields): 23 | word_padding_idx = word_dict.stoi[table.IO.PAD_WORD] 24 | num_word = len(word_dict) 25 | emb_word = nn.Embedding(num_word, opt.word_vec_size, 26 | padding_idx=word_padding_idx) 27 | 28 | if len(opt.pre_word_vecs) > 0: 29 | vectors = torchtext.vocab.GloVe( 30 | name="840B", cache=opt.pre_word_vecs, dim=str(opt.word_vec_size)) 31 | 32 | fields["src"].vocab.load_vectors(vectors) 33 | emb_word.weight.data.copy_(fields["src"].vocab.vectors) 34 | 35 | if opt.fix_word_vecs: 36 | # is 0 37 | num_special = len(table.IO.special_token_list) 38 | # zero vectors in the fixed embedding (emb_word) 39 | emb_word.weight.data[:num_special].zero_() 40 | emb_special = nn.Embedding( 41 | num_special, opt.word_vec_size, padding_idx=word_padding_idx) 42 | emb = PartUpdateEmbedding(num_special, emb_special, emb_word) 43 | return emb 44 | else: 45 | return emb_word 46 | 47 | 48 | def make_embeddings(word_dict, vec_size): 49 | word_padding_idx = word_dict.stoi[table.IO.PAD_WORD] 50 | num_word = len(word_dict) 51 | w_embeddings = nn.Embedding( 52 | num_word, vec_size, padding_idx=word_padding_idx) 53 | return w_embeddings 54 | 55 | 56 | def make_encoder(opt, embeddings, ent_embedding=None, type_embedding=None): 57 | # "rnn" or "brnn" 58 | return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers, opt.rnn_size, opt.dropout, opt.lock_dropout, opt.weight_dropout, embeddings, ent_embedding,type_embedding) 59 | 60 | 61 | def make_table_encoder(opt, embeddings): 62 | # "rnn" or "brnn" 63 | return TableRNNEncoder(make_encoder(opt, embeddings), opt.split_type, opt.merge_type) 64 | 65 | 66 | def make_cond_decoder(opt): 67 | input_size = opt.rnn_size 68 | return CondDecoder(opt.rnn_type, opt.brnn, opt.dec_layers, input_size, opt.rnn_size, opt.global_attention, opt.attn_hidden, opt.dropout, opt.lock_dropout, opt.weight_dropout) 69 | 70 | 71 | def make_co_attention(opt): 72 | if opt.co_attention: 73 | return CoAttention(opt.rnn_type, opt.brnn, opt.enc_layers, opt.rnn_size, opt.dropout, opt.weight_dropout, opt.global_attention, opt.attn_hidden) 74 | return None 75 | 76 | def make_co_attention_wornn(opt): 77 | if opt.co_attention: 78 | return CoAttention_wornn(opt.rnn_type, opt.brnn, opt.enc_layers, opt.rnn_size, opt.dropout, opt.weight_dropout, opt.global_attention, opt.attn_hidden) 79 | return None 80 | 81 | 82 | def make_self_attention(opt): 83 | return AttnPool(opt.rnn_size, opt.attn_hidden) 84 | 85 | def make_base_model(model_opt, fields, checkpoint=None): 86 | """ 87 | Args: 88 | model_opt: the option loaded from checkpoint. 89 | fields: `Field` objects for the model. 
90 | 
91 |         checkpoint: the model generated by the train phase, or a resumed snapshot
92 |                     model from a stopped training.
93 |     Returns:
94 |         the constructed ParserModel.
95 |     """
96 | 
97 |     # embedding
98 |     w_embeddings = make_word_embeddings(model_opt, fields["src"].vocab, fields)
99 | 
100 | 
101 | 
102 |     if model_opt.ent_vec_size > 0:
103 |         ent_embedding = make_embeddings(
104 |             fields["ent"].vocab, model_opt.ent_vec_size)
105 |     else:
106 |         ent_embedding = None
107 | 
108 |     if model_opt.type_vec_size > 0:
109 |         type_embedding = make_embeddings(
110 |             fields["type"].vocab, model_opt.type_vec_size)
111 |     else:
112 |         type_embedding = None
113 | 
114 |     # Make question encoder.
115 |     if model_opt.use_type == 'use':
116 |         q_encoder = make_encoder(model_opt, w_embeddings, ent_embedding, type_embedding)
117 |     else:
118 |         q_encoder = make_encoder(model_opt, w_embeddings, ent_embedding)
119 | 
120 |     # Make table encoder.
121 |     tbl_encoder = make_table_encoder(model_opt, w_embeddings)
122 | 
123 |     co_attention = make_co_attention(model_opt)
124 |     op_selfattention = make_co_attention_wornn(model_opt)
125 | 
126 |     agg_self_attention = make_self_attention(model_opt)
127 |     lay_self_attention = make_self_attention(model_opt)
128 | 
129 |     BIO_classifier = nn.Sequential(
130 |         nn.Dropout(model_opt.dropout),
131 |         nn.Linear(model_opt.rnn_size, 4)  # B, I, O and column tags
132 |     )
133 |     BIO_op_classifier = nn.Sequential(
134 |         nn.Dropout(model_opt.dropout),
135 |         nn.Linear(model_opt.rnn_size, 3)
136 |     )
137 | 
138 |     agg_classifier = nn.Sequential(
139 |         nn.Dropout(model_opt.dropout),
140 |         nn.Linear(model_opt.rnn_size, len(agg_ops)),
141 |         nn.LogSoftmax(dim=-1))
142 |     sel_match = MatchScorer(2 * model_opt.rnn_size,
143 |                             model_opt.score_size, model_opt.dropout)
144 |     lay_classifier = nn.Sequential(
145 |         nn.Dropout(model_opt.dropout),
146 |         nn.Linear(model_opt.rnn_size, len(fields['lay'].vocab)),
147 |         nn.LogSoftmax(dim=-1))
148 | 
149 |     decode_softattention = Softattention(model_opt.rnn_size, model_opt.attn_hidden)
150 | 
151 |     # layout encoding
152 |     if model_opt.layout_encode == 'rnn':
153 |         cond_embedding = make_embeddings(
154 |             fields["cond_op"].vocab, model_opt.cond_op_vec_size)
155 |         lay_encoder = make_encoder(model_opt, cond_embedding)
156 |     else:
157 |         cond_embedding = make_embeddings(
158 |             fields["cond_op"].vocab, model_opt.rnn_size)
159 |         lay_encoder = None
160 | 
161 |     # Make cond models.
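    # Each WHERE condition is decoded as three consecutive steps (operator,
    # column, value span), so the scorers built below each handle one of those
    # steps: ParserModel.forward interleaves emb_op, emb_col and emb_span and
    # then reads off every third decoder state.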
162 | cond_decoder = make_cond_decoder(model_opt) 163 | cond_col_match = CondMatchScorer( 164 | MatchScorer(2 * model_opt.rnn_size, model_opt.score_size, model_opt.dropout)) 165 | 166 | label_col_match = CondMatchScorer( 167 | MatchScorer(2 * model_opt.rnn_size, model_opt.score_size, model_opt.dropout,log_softmax=False)) 168 | 169 | cond_span_l_match = CondMatchScorer( 170 | MatchScorer(2 * model_opt.rnn_size, model_opt.score_size, model_opt.dropout)) 171 | cond_span_r_match = CondMatchScorer( 172 | MatchScorer(3 * model_opt.rnn_size, model_opt.score_size, model_opt.dropout)) 173 | 174 | # Make ParserModel 175 | pad_word_index = fields["src"].vocab.stoi[table.IO.PAD_WORD] 176 | model = ParserModel(q_encoder, tbl_encoder, co_attention, agg_classifier, sel_match, lay_classifier, cond_embedding, 177 | lay_encoder, cond_decoder, cond_col_match, cond_span_l_match, cond_span_r_match, model_opt, pad_word_index, agg_self_attention,lay_self_attention,decode_softattention,BIO_classifier,BIO_op_classifier,label_col_match,op_selfattention) 178 | 179 | if checkpoint is not None: 180 | print('Loading model') 181 | model.load_state_dict(checkpoint['model']) 182 | 183 | if torch.cuda.is_available(): 184 | model.cuda() 185 | 186 | return model 187 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Models.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 7 | import torch.nn.functional as F 8 | import table 9 | from table.Utils import aeq, sort_for_pack,sequence_mask 10 | import random 11 | import os 12 | 13 | def _build_rnn(rnn_type, input_size, hidden_size, num_layers, dropout, weight_dropout, bidirectional=False): 14 | dr = 0 if weight_dropout > 0 else dropout 15 | rnn = getattr(nn, rnn_type)(input_size, hidden_size, 16 | num_layers=num_layers, dropout=dr, bidirectional=bidirectional) 17 | if weight_dropout > 0: 18 | param_list = ['weight_hh_l0'] 19 | if bidirectional: 20 | param_list += [it + '_reverse' for it in param_list] 21 | rnn = table.modules.WeightDrop(rnn, param_list, dropout=weight_dropout) 22 | return rnn 23 | 24 | 25 | class RNNEncoder(nn.Module): 26 | """ The standard RNN encoder. """ 27 | 28 | def __init__(self, rnn_type, bidirectional, num_layers, 29 | hidden_size, dropout, lock_dropout, weight_dropout, embeddings, ent_embedding, type_embedding=None): 30 | super(RNNEncoder, self).__init__() 31 | 32 | num_directions = 2 if bidirectional else 1 33 | self.hidden_size = hidden_size 34 | self.embeddings = embeddings 35 | self.ent_embedding = ent_embedding 36 | self.type_embedding = type_embedding 37 | self.no_pack_padded_seq = False 38 | if lock_dropout: 39 | self.word_dropout = table.modules.LockedDropout(dropout) 40 | else: 41 | self.word_dropout = nn.Dropout(dropout) 42 | 43 | # Use pytorch version when available. 
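        # The encoder input is the word embedding, optionally widened with
        # entity-tag and type embeddings; forward() concatenates them along
        # the feature dimension, so the RNN input size grows accordingly.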
44 | input_size = embeddings.embedding_dim 45 | if ent_embedding is not None: 46 | input_size += ent_embedding.embedding_dim 47 | if type_embedding is not None: 48 | input_size += type_embedding.embedding_dim 49 | 50 | self.rnn = _build_rnn(rnn_type, input_size, 51 | hidden_size // num_directions, num_layers, dropout, weight_dropout, bidirectional) 52 | 53 | 54 | 55 | 56 | def forward(self, input, lengths=None, hidden=None, ent=None, type=None): 57 | emb = self.embeddings(input) 58 | if self.ent_embedding is not None: 59 | emb_ent = self.ent_embedding(ent) 60 | emb = torch.cat((emb, emb_ent), 2) 61 | if self.type_embedding is not None: 62 | emb_type = self.type_embedding(type) 63 | emb = torch.cat((emb, emb_type), 2) 64 | 65 | if self.word_dropout is not None: 66 | emb = self.word_dropout(emb) 67 | # s_len, batch, emb_dim = emb.size() 68 | 69 | packed_emb = emb 70 | need_pack = (lengths is not None) and (not self.no_pack_padded_seq) 71 | if need_pack: 72 | # Lengths data is wrapped inside a Variable. 73 | if not isinstance(lengths, list): 74 | lengths = lengths.view(-1).tolist() 75 | packed_emb = pack(emb, lengths) 76 | 77 | outputs, hidden_t = self.rnn(packed_emb, hidden) 78 | 79 | if need_pack: 80 | outputs = unpack(outputs)[0] 81 | 82 | return hidden_t, outputs 83 | 84 | 85 | def encode_unsorted_batch(encoder, tbl, tbl_len): 86 | # sort for pack() 87 | idx_sorted, tbl_len_sorted, idx_map_back = sort_for_pack(tbl_len) 88 | if torch.cuda.is_available(): 89 | tbl_sorted = tbl.index_select(1, Variable( 90 | torch.LongTensor(idx_sorted).cuda(), requires_grad=False)) 91 | else: 92 | tbl_sorted = tbl.index_select(1, Variable( 93 | torch.LongTensor(idx_sorted), requires_grad=False)) 94 | # tbl_context: (seq_len, batch, hidden_size * num_directions) 95 | __, tbl_context = encoder(tbl_sorted, tbl_len_sorted) 96 | # recover the sort for pack() 97 | if torch.cuda.is_available(): 98 | v_idx_map_back = Variable(torch.LongTensor( 99 | idx_map_back).cuda(), requires_grad=False) 100 | else: 101 | v_idx_map_back = Variable(torch.LongTensor( 102 | idx_map_back), requires_grad=False) 103 | tbl_context = tbl_context.index_select(1, v_idx_map_back) 104 | return tbl_context 105 | 106 | 107 | class TableRNNEncoder(nn.Module): 108 | def __init__(self, encoder, split_type='incell', merge_type='cat'): 109 | super(TableRNNEncoder, self).__init__() 110 | self.split_type = split_type 111 | self.merge_type = merge_type 112 | self.hidden_size = encoder.hidden_size 113 | self.encoder = encoder 114 | if self.merge_type == 'mlp': 115 | self.merge = nn.Sequential( 116 | nn.Linear(2 * self.hidden_size, self.hidden_size), 117 | nn.Tanh()) 118 | 119 | def forward(self, tbl, tbl_len, tbl_split): 120 | """ 121 | Encode table headers. 
122 | :param tbl: header token list 123 | :param tbl_len: length of token list (num_table_header, batch) 124 | :param tbl_split: table header boundary list 125 | """ 126 | tbl_context = encode_unsorted_batch(self.encoder, tbl, tbl_len) 127 | # --> (num_table_header, batch, hidden_size * num_directions) 128 | if self.split_type == 'outcell': 129 | batch_index = torch.LongTensor(range(tbl_split.data.size(1))).unsqueeze_( 130 | 0).expand_as(tbl_split.data) 131 | if torch.cuda.is_available(): 132 | batch_index=batch_index.cuda() 133 | enc_split = tbl_context[tbl_split.data, batch_index, :] 134 | enc_left, enc_right = enc_split[:-1], enc_split[1:] 135 | elif self.split_type == 'incell': 136 | batch_index = torch.LongTensor(range(tbl_split.data.size(1))).unsqueeze_( 137 | 0).expand(tbl_split.data.size(0) - 1, tbl_split.data.size(1)) 138 | if torch.cuda.is_available(): 139 | batch_index=batch_index.cuda() 140 | split_left = (tbl_split.data[:-1] + 141 | 1).clamp(0, tbl_context.size(0) - 1) 142 | enc_left = tbl_context[split_left, batch_index, :] 143 | split_right = (tbl_split.data[1:] - 144 | 1).clamp(0, tbl_context.size(0) - 1) 145 | enc_right = tbl_context[split_right, batch_index, :] 146 | 147 | if self.merge_type == 'sub': 148 | return (enc_right - enc_left) 149 | elif self.merge_type == 'cat': 150 | # take half vector for each direction 151 | half_hidden_size = self.hidden_size // 2 152 | return torch.cat([enc_right[:, :, :half_hidden_size], enc_left[:, :, half_hidden_size:]], 2) 153 | elif self.merge_type == 'mlp': 154 | return self.merge(torch.cat([enc_right, enc_left], 2)) 155 | 156 | 157 | class MatchScorer(nn.Module): 158 | def __init__(self, input_size, score_size, dropout,log_softmax=True): 159 | super(MatchScorer, self).__init__() 160 | self.score_layer = nn.Sequential( 161 | nn.Dropout(dropout), 162 | nn.Linear(input_size, score_size), 163 | nn.Tanh(), 164 | nn.Linear(score_size, 1)) 165 | self.log_sm = nn.LogSoftmax(dim=-1) 166 | self.log_softmax=log_softmax 167 | 168 | def forward(self, q_enc, tbl_enc, tbl_mask, previous_distribution=None,select=False): 169 | """ 170 | Match and return table column score. 171 | :param q_enc: question encoding vectors (batch, rnn_size) 172 | :param tbl_enc: header encoding vectors (num_table_header, batch, rnn_size) 173 | :param tbl_num: length of token list 174 | """ 175 | q_enc_expand = q_enc.unsqueeze(0).expand( 176 | tbl_enc.size(0), tbl_enc.size(1), q_enc.size(1)) 177 | # (batch, num_table_header, input_size) 178 | feat = torch.cat((q_enc_expand, tbl_enc), 2).transpose(0, 1) 179 | # (batch, num_table_header) 180 | 181 | #score = self.score_layer(feat).squeeze(2) 182 | #print(q_enc_expand.size(),tbl_enc.size()) 183 | #print((q_enc_expand*tbl_enc).size()) 184 | if select: 185 | score=torch.sum(q_enc_expand*tbl_enc ,dim=2).transpose(0,1) 186 | else: 187 | score = self.score_layer(feat).squeeze(2) 188 | 189 | # mask scores 190 | score_mask = score.masked_fill(tbl_mask.type(torch.bool), -1e20)#-float('inf')) 191 | 192 | # normalize 193 | if self.log_softmax==False: 194 | return score_mask 195 | return self.log_sm(score_mask) 196 | 197 | 198 | class CondMatchScorer(nn.Module): 199 | def __init__(self, sel_match): 200 | super(CondMatchScorer, self).__init__() 201 | self.sel_match = sel_match 202 | 203 | def forward(self, cond_context_filter, tbl_enc, tbl_mask, previous_distribution=None,emb_span_l=None): 204 | """ 205 | Match and return table column score for cond decoder. 
206 | :param cond_context: cond decoder's context vectors (num_cond*3, batch, rnn_size) 207 | :param tbl_enc: header encoding vectors (num_table_header, batch, rnn_size) 208 | :param tbl_num: length of token list 209 | """ 210 | # -> (num_cond, batch, rnn_size) 211 | if emb_span_l is not None: 212 | # -> (num_cond, batch, 2*rnn_size) 213 | cond_context_filter = torch.cat( 214 | (cond_context_filter, emb_span_l), 2) 215 | r_list = [] 216 | for cond_context_one in cond_context_filter: 217 | # -> (batch, num_table_header) 218 | r_list.append(self.sel_match(cond_context_one, tbl_enc, tbl_mask,previous_distribution=previous_distribution)) 219 | # (num_cond, batch, num_table_header) 220 | return torch.stack(r_list, 0) 221 | 222 | 223 | class CondDecoder(nn.Module): 224 | def __init__(self, rnn_type, bidirectional_encoder, num_layers, input_size, hidden_size, attn_type, attn_hidden, dropout, lock_dropout, weight_dropout): 225 | super(CondDecoder, self).__init__() 226 | 227 | # Basic attributes. 228 | self.decoder_type = 'rnn' 229 | self.bidirectional_encoder = bidirectional_encoder 230 | self.num_layers = num_layers 231 | self.input_size = input_size 232 | self.hidden_size = hidden_size 233 | if lock_dropout: 234 | self.word_dropout = table.modules.LockedDropout(dropout) 235 | else: 236 | self.word_dropout = nn.Dropout(dropout) 237 | 238 | # Build the RNN. 239 | self.rnn = _build_rnn(rnn_type, input_size, 240 | hidden_size, num_layers, dropout, weight_dropout) 241 | 242 | # Set up the standard attention. 243 | self.attn = table.modules.GlobalAttention( 244 | hidden_size, True, attn_type=attn_type, attn_hidden=attn_hidden) 245 | 246 | def forward(self, emb, context, state): 247 | """ 248 | Forward through the decoder. 249 | Args: 250 | input (LongTensor): a sequence of input tokens tensors 251 | of size (len x batch x nfeats). 252 | context (FloatTensor): output(tensor sequence) from the encoder 253 | RNN of size (src_len x batch x hidden_size). 254 | state (FloatTensor): hidden state from the encoder RNN for 255 | initializing the decoder. 256 | Returns: 257 | outputs (FloatTensor): a Tensor sequence of output from the decoder 258 | of shape (len x batch x hidden_size). 259 | state (FloatTensor): final hidden state from the decoder. 260 | attns (dict of (str, FloatTensor)): a dictionary of different 261 | type of attention Tensor from the decoder 262 | of shape (src_len x batch). 263 | """ 264 | # Args Check 265 | assert isinstance(state, RNNDecoderState) 266 | # END Args Check 267 | 268 | # Run the forward pass of the RNN. 269 | hidden, outputs, attns = self._run_forward_pass(emb, context, state) 270 | 271 | # Update the state with the result. 272 | state.update_state(hidden) 273 | 274 | # Concatenates sequence of tensors along a new dimension. 275 | #print(outputs.size()) 276 | #outputs = torch.stack(outputs) 277 | #print(outputs.size()) 278 | #for k in attns: 279 | # print(attns[k].size()) 280 | # attns[k] = torch.stack(attns[k]) 281 | # print(attns[k].size()) 282 | 283 | return outputs, state, attns 284 | 285 | def _fix_enc_hidden(self, h): 286 | """ 287 | The encoder hidden is (layers*directions) x batch x dim. 288 | We need to convert it to layers x batch x (directions*dim). 
289 | """ 290 | if self.bidirectional_encoder: 291 | h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2) 292 | return h 293 | 294 | def init_decoder_state(self, context, enc_hidden): 295 | return RNNDecoderState(context, self.hidden_size, tuple([self._fix_enc_hidden(enc_hidden[i]) for i in range(len(enc_hidden))])) 296 | 297 | def _run_forward_pass(self, emb, context, state): 298 | """ 299 | Private helper for running the specific RNN forward pass. 300 | Must be overriden by all subclasses. 301 | Args: 302 | input (LongTensor): a sequence of input tokens tensors 303 | of size (len x batch x nfeats). 304 | context (FloatTensor): output(tensor sequence) from the encoder 305 | RNN of size (src_len x batch x hidden_size). 306 | state (FloatTensor): hidden state from the encoder RNN for 307 | initializing the decoder. 308 | Returns: 309 | hidden (Variable): final hidden state from the decoder. 310 | outputs ([FloatTensor]): an array of output of every time 311 | step from the decoder. 312 | attns (dict of (str, [FloatTensor]): a dictionary of different 313 | type of attention Tensor array of every time 314 | step from the decoder. 315 | """ 316 | 317 | # Initialize local and return variables. 318 | outputs = [] 319 | attns = {"std": []} 320 | 321 | if self.word_dropout is not None: 322 | emb = self.word_dropout(emb) 323 | 324 | # Run the forward pass of the RNN. 325 | rnn_output, hidden = self.rnn(emb, state.hidden) 326 | 327 | # Calculate the attention. 328 | attn_outputs, attn_scores = self.attn( 329 | rnn_output.transpose(0, 1).contiguous(), # (output_len, batch, d) 330 | context.transpose(0, 1) # (contxt_len, batch, d) 331 | ) 332 | attns["std"] = attn_scores 333 | 334 | outputs = attn_outputs # (input_len, batch, d) 335 | 336 | # Return result. 337 | return hidden, outputs, attns 338 | 339 | 340 | class DecoderState(object): 341 | """ 342 | DecoderState is a base class for models, used during translation 343 | for storing translation states. 344 | """ 345 | 346 | def detach(self): 347 | """ 348 | Detaches all Variables from the graph 349 | that created it, making it a leaf. 350 | """ 351 | for h in self._all: 352 | if h is not None: 353 | h.detach_() 354 | 355 | def beam_update(self, idx, positions, beam_size): 356 | """ Update when beam advances. """ 357 | for e in self._all: 358 | a, br, d = e.size() 359 | sentStates = e.view(a, beam_size, br // beam_size, d)[:, :, idx] 360 | sentStates.data.copy_( 361 | sentStates.data.index_select(1, positions)) 362 | 363 | 364 | class RNNDecoderState(DecoderState): 365 | def __init__(self, context, hidden_size, rnnstate): 366 | """ 367 | Args: 368 | context (FloatTensor): output from the encoder of size 369 | len x batch x rnn_size. 370 | hidden_size (int): the size of hidden layer of the decoder. 371 | rnnstate (Variable): final hidden state from the encoder. 372 | transformed to shape: layers x batch x (directions*dim). 373 | """ 374 | if not isinstance(rnnstate, tuple): 375 | self.hidden = (rnnstate,) 376 | else: 377 | self.hidden = rnnstate 378 | 379 | @property 380 | def _all(self): 381 | return self.hidden 382 | 383 | def update_state(self, rnnstate): 384 | if not isinstance(rnnstate, tuple): 385 | self.hidden = (rnnstate,) 386 | else: 387 | self.hidden = rnnstate 388 | 389 | def repeat_beam_size_times(self, beam_size): 390 | """ Repeat beam_size times along batch dimension. 
""" 391 | v_list = [Variable(e.data.repeat(1, beam_size, 1), volatile=True) 392 | for e in self._all] 393 | self.hidden = tuple(v_list) 394 | 395 | 396 | class CoAttention(nn.Module): 397 | def __init__(self, rnn_type, bidirectional, num_layers, hidden_size, dropout, weight_dropout, attn_type, attn_hidden): 398 | super(CoAttention, self).__init__() 399 | 400 | num_directions = 2 if bidirectional else 1 401 | self.hidden_size = hidden_size 402 | self.no_pack_padded_seq = False 403 | 404 | self.rnn = _build_rnn(rnn_type, 2 * hidden_size, hidden_size // 405 | num_directions, num_layers, dropout, weight_dropout, bidirectional) 406 | self.attn = table.modules.GlobalAttention( 407 | hidden_size, False, attn_type=attn_type, attn_hidden=attn_hidden) 408 | 409 | def forward(self, q_all, lengths, tbl_enc, tbl_mask): 410 | self.attn.applyMask(tbl_mask.data.unsqueeze(0)) 411 | # attention 412 | emb, _ = self.attn( 413 | q_all.transpose(0, 1).contiguous(), # (output_len, batch, d) 414 | tbl_enc.transpose(0, 1) # (contxt_len, batch, d) 415 | ) 416 | 417 | # feed to rnn 418 | if not isinstance(lengths, list): 419 | lengths = lengths.view(-1).tolist() 420 | packed_emb = pack(emb, lengths) 421 | 422 | outputs, hidden_t = self.rnn(packed_emb, None) 423 | 424 | outputs = unpack(outputs)[0] 425 | 426 | return hidden_t, outputs 427 | 428 | 429 | class CoAttention_wornn(nn.Module): 430 | def __init__(self, rnn_type, bidirectional, num_layers, hidden_size, dropout, weight_dropout, attn_type, attn_hidden): 431 | super(CoAttention_wornn, self).__init__() 432 | 433 | num_directions = 2 if bidirectional else 1 434 | self.hidden_size = hidden_size 435 | self.no_pack_padded_seq = False 436 | 437 | self.rnn = _build_rnn(rnn_type, 2 * hidden_size, hidden_size // 438 | num_directions, num_layers, dropout, weight_dropout, bidirectional) 439 | self.attn = table.modules.GlobalAttention( 440 | hidden_size, False, attn_type=attn_type, attn_hidden=attn_hidden) 441 | 442 | def forward(self, q_all, lengths, tbl_enc, tbl_mask): 443 | self.attn.applyMask(tbl_mask.data.unsqueeze(0)) 444 | # attention 445 | emb, _ = self.attn( 446 | q_all.transpose(0, 1).contiguous(), # (output_len, batch, d) 447 | tbl_enc.transpose(0, 1).contiguous() # (contxt_len, batch, d) 448 | ) 449 | 450 | 451 | return emb 452 | 453 | class AttnPool(nn.Module): 454 | def __init__(self, hidden_size, attn_size): 455 | super(AttnPool, self).__init__() 456 | self.hidden_size = hidden_size 457 | self.attn_size=attn_size 458 | self.softmax=nn.Softmax(dim=-1) 459 | self.m=nn.Sequential( 460 | nn.Linear(self.hidden_size, self.attn_size, bias=True), 461 | nn.Tanh(), 462 | nn.Linear(self.attn_size, 1, bias=False),#len,bsz,1 463 | ) 464 | 465 | def forward(self, q_all, lengths): 466 | #self.attn.applyMask(tbl_mask.data.unsqueeze(0)) 467 | # attention 468 | 469 | q_all=q_all.cuda() 470 | 471 | q_att = self.m(q_all) 472 | q_att = torch.squeeze(q_att,dim=2)# len,bsz 473 | #print(q_all.size(),q_att.size()) 474 | q_att=torch.transpose(q_att,0,1) 475 | #print(q_all.size(), q_att.size()) 476 | q_att=torch.transpose(self.softmax(q_att),0,1) 477 | q_att=torch.unsqueeze(q_att,dim=2) 478 | #print(q_all.size(), q_att.size()) 479 | #print(q_att) 480 | #print(q_att.size()) 481 | q_all = torch.sum(q_att * q_all, dim=0) 482 | 483 | return q_all 484 | 485 | class Softattention(nn.Module): 486 | def __init__(self,hidden_size, attn_size): 487 | super(Softattention, self).__init__() 488 | self.hidden_size = hidden_size 489 | self.attn_size = attn_size 490 | self.m=nn.Sequential( 491 | 
nn.Linear(self.hidden_size*2, self.attn_size, bias=True), 492 | nn.Tanh(), 493 | nn.Linear(self.attn_size, 1, bias=False),#len,bsz,1 494 | ) 495 | self.MLP=nn.Sequential( 496 | nn.Linear(self.hidden_size * 2,self.hidden_size, bias=True), 497 | nn.Tanh() 498 | ) 499 | def forward(self, hidden, q_all, lengths): 500 | # hidden (#cond,bsz, #hidden_sz) , q_all (len, bsz, #hidden_sz) 501 | #->hidden (#cond, len, bsz, #hidden_sz) , q_all (#cond, len, bsz, #hidden_sz) 502 | #print(hidden.size(), q_all.size()) 503 | new_hidden=hidden.unsqueeze(1).repeat(1,q_all.size(0),1,1) 504 | new_q_all= q_all.unsqueeze(0).repeat(hidden.size(0),1,1,1) 505 | #print(new_hidden.size(),new_q_all.size()) 506 | x=torch.cat( [new_hidden, new_q_all], dim=-1) 507 | # print(x.size()) 508 | x = self.m(x) 509 | # print(x.size()) 510 | 511 | # print(x.size()) 512 | #print mask.size() 513 | mask=sequence_mask(lengths).eq(0).type(torch.bool) #len,bsz 514 | mask=mask.transpose(0,1) 515 | mask=mask.unsqueeze(0).repeat(hidden.size(0),1,1).unsqueeze(3) #cond,len,bsz 516 | #print(x.size(),mask.size()) 517 | x.masked_fill_(mask,-1e10) 518 | 519 | x = F.softmax(x, dim=1) 520 | 521 | return self.MLP( torch.cat([hidden,torch.sum(x *new_q_all,dim=1)],dim=-1)) 522 | 523 | 524 | class tbl2q_attention(nn.Module): 525 | def __init__(self,hidden_size, attn_size): 526 | super(tbl2q_attention, self).__init__() 527 | self.hidden_size = hidden_size 528 | self.attn_size = attn_size 529 | self.m=nn.Sequential( 530 | nn.Linear(self.hidden_size*2, self.attn_size, bias=True), 531 | nn.Tanh(), 532 | nn.Linear(self.attn_size, 1, bias=False),#len,bsz,1 533 | ) 534 | self.MLP=nn.Sequential( 535 | nn.Linear(self.hidden_size * 2,self.hidden_size, bias=True), 536 | nn.Tanh() 537 | ) 538 | def forward(self, hidden, q_all, lengths): 539 | # hidden (#cond,bsz, #hidden_sz) , q_all (len, bsz, #hidden_sz) 540 | #->hidden (#cond, len, bsz, #hidden_sz) , q_all (#cond, len, bsz, #hidden_sz) 541 | #print(hidden.size(), q_all.size()) 542 | new_hidden=hidden.unsqueeze(1).repeat(1,q_all.size(0),1,1) 543 | new_q_all= q_all.unsqueeze(0).repeat(hidden.size(0),1,1,1) 544 | #print(new_hidden.size(),new_q_all.size()) 545 | x=torch.cat( [new_hidden, new_q_all], dim=-1) 546 | # print(x.size()) 547 | x = self.m(x) 548 | # print(x.size()) 549 | 550 | # print(x.size()) 551 | #print mask.size() 552 | mask=sequence_mask(lengths).eq(0).type(torch.bool) #len,bsz 553 | mask=mask.transpose(0,1) 554 | mask=mask.unsqueeze(0).repeat(hidden.size(0),1,1).unsqueeze(3) #cond,len,bsz 555 | #print(x.size(),mask.size()) 556 | x.masked_fill_(mask,-1e10) 557 | 558 | x = F.softmax(x, dim=1) 559 | 560 | #print(self.MLP( torch.cat([hidden,torch.sum(x *new_q_all,dim=1)],dim=-1)).size()) 561 | return self.MLP( torch.cat([hidden,torch.sum(x *new_q_all,dim=1)],dim=-1)) 562 | 563 | 564 | class ParserModel(nn.Module): 565 | def __init__(self, q_encoder, tbl_encoder, co_attention, agg_classifier, sel_match, lay_classifier, cond_embedding, lay_encoder, cond_decoder, 566 | cond_col_match, cond_span_l_match, cond_span_r_match, model_opt, pad_word_index, agg_self_attention,lay_self_attention,decode_softattention, BIO_classifier,BIO_op_classifier,label_col_match,op_selfattention): 567 | super(ParserModel, self).__init__() 568 | self.q_encoder = q_encoder 569 | self.tbl_encoder = tbl_encoder 570 | self.agg_classifier = agg_classifier 571 | self.sel_match = sel_match 572 | self.lay_classifier = lay_classifier 573 | self.cond_embedding = cond_embedding 574 | self.lay_encoder = lay_encoder 575 | self.cond_decoder = 
cond_decoder 576 | self.opt = model_opt 577 | self.span_merge = nn.Sequential( 578 | nn.Linear(2 * model_opt.rnn_size, model_opt.rnn_size), 579 | nn.Tanh()) 580 | self.cond_col_match = cond_col_match 581 | self.cond_span_l_match = cond_span_l_match 582 | self.cond_span_r_match = cond_span_r_match 583 | self.pad_word_index = pad_word_index 584 | self.co_attention = co_attention 585 | self.agg_self_attention = agg_self_attention 586 | self.lay_self_attention=lay_self_attention 587 | self.decode_softattention=decode_softattention 588 | self.BIO_classifier=BIO_classifier 589 | self.BIO_op_classifier=BIO_op_classifier 590 | self.label_col_match=label_col_match 591 | self.op_selfattention=op_selfattention 592 | self.tbl2q_att=tbl2q_attention(model_opt.rnn_size,model_opt.attn_hidden) 593 | vocab_dic = dict(torch.load(os.path.join(self.opt.data, 'vocab.pt'))) 594 | 595 | self.id2words = vocab_dic['src'].itos 596 | 597 | self.FC=nn.Sequential( 598 | nn.Dropout(model_opt.dropout), 599 | nn.Linear(model_opt.rnn_size*2, model_opt.rnn_size), 600 | ) 601 | 602 | def enc(self, q, q_len, ent, type, tbl, tbl_len, tbl_split, tbl_mask): 603 | q_enc, q_all = self.q_encoder(q, lengths=q_len, ent=ent, type=type) 604 | tbl_enc = self.tbl_encoder(tbl, tbl_len, tbl_split) 605 | 606 | tbl_enc_new=self.tbl2q_att(tbl_enc,q_all,q_len) 607 | 608 | #print(tbl_mask.size(), tbl_enc.size()) 609 | 610 | if self.co_attention is not None: 611 | q_enc, q_all = self.co_attention(q_all, q_len, tbl_enc, tbl_mask) 612 | 613 | # (num_layers * num_directions, batch, hidden_size) 614 | q_ht, q_ct = q_enc 615 | batch_size = q_ht.size(1) 616 | q_ht = q_ht[-1] if not self.opt.brnn else q_ht[-2:].transpose( 617 | 0, 1).contiguous().view(batch_size, -1) 618 | #q_ht=[q_1;q_l] 619 | 620 | 621 | return q_enc, q_all, tbl_enc_new, q_ht, batch_size 622 | 623 | def select3(self, cond_context, start_index): 624 | return cond_context[start_index:cond_context.size( 625 | 0):3] 626 | 627 | def forward(self, q, q_len, ent, type, tbl, tbl_len, tbl_split, tbl_mask, cond_op, cond_op_len, cond_col, cond_span_l, cond_span_r, lay): 628 | fff = 0 629 | if random.random() < 0.000: 630 | fff = 1 631 | # encoding 632 | q_enc, q_all, tbl_enc, q_ht, batch_size = self.enc( 633 | q, q_len, ent, type, tbl, tbl_len, tbl_split, tbl_mask) 634 | 635 | BIO_op_out = self.BIO_op_classifier(q_all) 636 | tsp_q = BIO_op_out.size(0) 637 | bsz = BIO_op_out.size(1) 638 | BIO_op_out = BIO_op_out.view(-1, BIO_op_out.size(2)) 639 | BIO_op_out = F.log_softmax(BIO_op_out,dim=-1) 640 | BIO_op_out_sf = torch.exp(BIO_op_out) 641 | BIO_op_out = BIO_op_out.view(tsp_q, bsz, -1) 642 | BIO_op_out_sf = BIO_op_out_sf.view(tsp_q, bsz, -1) 643 | 644 | 645 | BIO_out = self.BIO_classifier(q_all) 646 | BIO_out = BIO_out.view(-1, BIO_out.size(2)) 647 | BIO_out = F.log_softmax(BIO_out,dim=-1) 648 | BIO_out_sf = torch.exp(BIO_out) 649 | BIO_out = BIO_out.view(tsp_q, bsz, -1) 650 | BIO_out_sf = BIO_out_sf.view(tsp_q, bsz, -1) 651 | 652 | BIO_col_out = self.label_col_match(q_all, tbl_enc, tbl_mask) 653 | BIO_col_out = BIO_col_out.view(-1, BIO_col_out.size(2)) 654 | BIO_col_out = F.log_softmax(BIO_col_out,dim=-1) 655 | BIO_col_out_sf = torch.exp(BIO_col_out) 656 | BIO_col_out = BIO_col_out.view(tsp_q, bsz, -1) 657 | BIO_col_out_sf = BIO_col_out_sf.view(tsp_q, bsz, -1) 658 | if fff == 1: 659 | print([self.id2words[w] for w in q.transpose(0,1)[0].data ]) 660 | print([self.id2words[w] for w in tbl.transpose(0, 1)[0].data ]) 661 | print(BIO_out_sf.transpose(0, 1)[0]) 662 | print(BIO_op_out_sf.transpose(0, 
1)[0]) 663 | print(BIO_col_out_sf.transpose(0, 1)[0]) 664 | 665 | 666 | 667 | 668 | #q_all=self.FC(torch.cat([q_all,BIO_out],dim=2)) 669 | # (1) decoding 670 | q_self_encode=self.agg_self_attention(q_all,q_len) 671 | agg_out = self.agg_classifier(q_self_encode) #select op 672 | #agg_out = self.agg_classifier(q_ht) 673 | sel_out = self.sel_match(q_self_encode, tbl_enc, tbl_mask,select=True) # select column 674 | #sel_out = self.sel_match(q_ht, tbl_enc, tbl_mask) #select column 675 | q_self_encode_layout=self.lay_self_attention(q_all,q_len) 676 | lay_out = self.lay_classifier(q_self_encode_layout) #layout predict 677 | 678 | # (2) decoding 679 | # emb_op 680 | if self.opt.layout_encode == 'rnn': 681 | emb_op = encode_unsorted_batch( 682 | self.lay_encoder, cond_op, cond_op_len.clamp(min=1)) 683 | else: 684 | emb_op = self.cond_embedding(cond_op) 685 | # emb_col 686 | batch_index = torch.LongTensor(range(batch_size)).unsqueeze_( 687 | 0).expand(cond_col.size(0), cond_col.size(1)) 688 | if torch.cuda.is_available(): 689 | batch_index = batch_index.cuda() 690 | emb_col = tbl_enc[cond_col.data, batch_index, :] 691 | # emb_span_l/r: (num_cond, batch, hidden_size) 692 | #print(cond_span_l.size()) 693 | #print(q_all.size()) 694 | #print (batch_index) 695 | emb_span_l = q_all[cond_span_l.data, batch_index, :] 696 | #print(emb_span_l.size()) 697 | 698 | emb_span_r = q_all[cond_span_r.data, batch_index, :] 699 | emb_span = self.span_merge(torch.cat([emb_span_l, emb_span_r], 2)) 700 | 701 | # stack embeddings 702 | # (seq_len*3, batch, hidden_size) 703 | emb = torch.stack([emb_op, emb_col, emb_span], 704 | 1).view(-1, batch_size, emb_op.size(2)) 705 | 706 | # cond decoder 707 | self.cond_decoder.attn.applyMaskBySeqBatch(q) 708 | q_state = self.cond_decoder.init_decoder_state(q_all, q_enc) 709 | cond_context, _, _ = self.cond_decoder(emb, q_all, q_state)# input in decode, encode hidden, initial decode hidden 710 | 711 | 712 | # cond col 713 | cond_context_0 = self.select3(cond_context, 0)#pick one of each 3 714 | #print(cond_context_0.size(),q_self_encode.size()) 715 | 716 | #print(cond_context_0.size()) 717 | # cond_context_0=self.decode_softattention(cond_context_0,q_all,q_len) 718 | #print(cond_context_0.size()) 719 | # print('\n') 720 | 721 | cond_col_out = self.cond_col_match(cond_context_0, tbl_enc, tbl_mask) 722 | 723 | 724 | # cond span 725 | q_mask = Variable(q.data.eq(self.pad_word_index).transpose( 726 | 0, 1), requires_grad=False) 727 | cond_context_1 = self.select3(cond_context, 1) 728 | cond_span_l_out = self.cond_span_l_match( 729 | cond_context_1, q_all, q_mask) 730 | cond_span_r_out = self.cond_span_r_match( 731 | cond_context_1, q_all, q_mask, emb_span_l=emb_span_l) 732 | 733 | return agg_out, sel_out, lay_out, cond_col_out, cond_span_l_out, cond_span_r_out,BIO_out,BIO_col_out,BIO_op_out 734 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.nn.utils import clip_grad_norm_ 3 | 4 | 5 | class Optim(object): 6 | 7 | def set_parameters(self, params): 8 | self.params = [p for p in params if p.requires_grad] 9 | if self.method == 'sgd': 10 | self.optimizer = optim.SGD(self.params, lr=self.lr) 11 | elif self.method == 'rmsprop': 12 | self.optimizer = optim.RMSprop( 13 | self.params, lr=self.lr, alpha=self.alpha) 14 | elif self.method == 'adam': 15 | self.optimizer = optim.Adam(self.params, 
lr=self.lr, 16 | betas=self.betas, eps=1e-9) 17 | else: 18 | raise RuntimeError("Invalid optim method: " + self.method) 19 | 20 | def __init__(self, method, lr, alpha, max_grad_norm, 21 | lr_decay=1, start_decay_at=None, 22 | beta1=0.9, beta2=0.98, 23 | opt=None): 24 | self.last_metric = None 25 | self.lr = lr 26 | self.alpha = alpha 27 | self.max_grad_norm = max_grad_norm 28 | self.method = method 29 | self.lr_decay = lr_decay 30 | self.start_decay_at = start_decay_at 31 | self.start_decay = False 32 | self._step = 0 33 | self.betas = [beta1, beta2] 34 | self.opt = opt 35 | 36 | def _setRate(self, lr): 37 | self.lr = lr 38 | self.optimizer.param_groups[0]['lr'] = self.lr 39 | 40 | def step(self): 41 | "Compute gradients norm." 42 | self._step += 1 43 | 44 | # Decay method used in tensor2tensor. 45 | if self.opt.__dict__.get("decay_method", "") == "noam": 46 | self._setRate( 47 | self.opt.learning_rate * 48 | (self.opt.rnn_size ** (-0.5) * 49 | min(self._step ** (-0.5), 50 | self._step * self.opt.warmup_steps**(-1.5)))) 51 | 52 | if self.max_grad_norm: 53 | clip_grad_norm_(self.params, self.max_grad_norm) 54 | self.optimizer.step() 55 | 56 | def updateLearningRate(self, metric, epoch): 57 | """ 58 | Decay learning rate if val perf does not improve 59 | or we hit the start_decay_at limit. 60 | """ 61 | 62 | if (self.start_decay_at is not None) and (epoch >= self.start_decay_at): 63 | self.start_decay = True 64 | if (self.last_metric is not None) and (metric is not None) and (metric > self.last_metric): 65 | self.start_decay = True 66 | 67 | if self.start_decay: 68 | self.lr = self.lr * self.lr_decay 69 | print("Decaying learning rate to %g" % self.lr) 70 | 71 | self.last_metric = metric 72 | self.optimizer.param_groups[0]['lr'] = self.lr 73 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/ParseResult.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | import sys 4 | sys.path.append("..") 5 | from lib.dbengine import DBEngine 6 | from lib.query import Query 7 | 8 | 9 | class ParseResult(object): 10 | def __init__(self, idx, agg, sel, cond,BIO,BIO_col): 11 | self.idx = idx 12 | self.agg = agg 13 | self.sel = sel 14 | self.cond = cond 15 | self.BIO=BIO 16 | self.BIO_col=BIO_col 17 | self.correct = defaultdict(lambda: 0) 18 | 19 | def eval(self,split, gold, sql_gold, engine=None): 20 | if self.agg == gold['query']['agg']: 21 | self.correct['agg'] = 1 22 | 23 | if self.sel == gold['query']['sel']: 24 | self.correct['sel'] = 1 25 | 26 | 27 | 28 | #gold['BIO_label'].data, gold['BIO_column_label'].data 29 | 30 | op_list_pred = [op for col, op, span in self.cond] 31 | op_list_gold = [op for col, op, span in gold['query']['conds']] 32 | 33 | col_list_pred = [col for col, op, span in self.cond] 34 | col_list_gold = [col for col, op, span in gold['query']['conds']] 35 | 36 | q = gold['question']['words'] 37 | span_list_pred = [' '.join(q[span[0]:span[1] + 1]) 38 | for col, op, span in self.cond] 39 | span_list_gold = [' '.join(span['words']) 40 | for col, op, span in gold['query']['conds']] 41 | 42 | where_pred = list(zip(col_list_pred, op_list_pred, span_list_pred)) 43 | where_gold = list(zip(col_list_gold, op_list_gold, span_list_gold)) 44 | where_pred.sort() 45 | where_gold.sort() 46 | if where_pred == where_gold and (len(col_list_pred) == len(col_list_gold)) and (len(op_list_pred) == len(op_list_gold)) and (len(span_list_pred) == 
len(span_list_gold)): 47 | self.correct['where'] = 1 48 | 49 | if (len(col_list_pred) == len(col_list_gold)) and ([it[0] for it in where_pred] == [it[0] for it in where_gold]): 50 | self.correct['col'] = 1 51 | 52 | if (len(op_list_pred) == len(op_list_gold)) and ([it[1] for it in where_pred] == [it[1] for it in where_gold]): 53 | self.correct['lay'] = 1 54 | 55 | if (len(span_list_pred) == len(span_list_gold)) and ([it[2] for it in where_pred] == [it[2] for it in where_gold]): 56 | self.correct['span'] = 1 57 | 58 | if all((self.correct[it] == 1 for it in ('agg', 'sel', 'where'))): 59 | self.correct['all'] = 1 60 | 61 | # execution 62 | table_id = gold['table_id'] 63 | ans_gold = '0' 64 | ans_pred = '1' 65 | if engine is not None: 66 | ans_gold = engine.execute_query( 67 | table_id, Query.from_dict(sql_gold), lower=True) 68 | 69 | try: 70 | sql_pred = {'agg':self.agg, 'sel':self.sel, 'conds': self.recover_cond_to_gloss(gold)} 71 | ans_pred = engine.execute_query( 72 | table_id, Query.from_dict(sql_pred), lower=True) 73 | except Exception as e: 74 | ans_pred = repr(e) 75 | else: 76 | ans_gold='0' 77 | ans_pred='1' 78 | if set(ans_gold) == set(ans_pred): 79 | self.correct['exe'] = 1 80 | 81 | error_case = {} 82 | if split == 'finaltest': 83 | #if self.correct['where'] != 1: 84 | if True: 85 | error_case['sel'] = self.correct['sel'] 86 | error_case['where'] = self.correct['where'] 87 | error_case['all'] = self.correct['all'] 88 | error_case['table_id'] = gold['table_id'] 89 | error_case['question_id']=gold['id'] 90 | error_case['question'] = gold['question']['words'] 91 | error_case['table_head'] = [head['words'] for head in gold['table']['header']] 92 | #error_case['table_content'] = gold['tbl_content'] 93 | 94 | # for i in range(len(sql_gold['conds'])): 95 | # sql_gold['conds'][i][0] = ( 96 | # sql_gold['conds'][i][0], gold['table']['header'][sql_gold['conds'][i][0]]['words'], 97 | # gold['tbl_content'][sql_gold['conds'][i][0]]) 98 | error_case['gold'] = sql_gold 99 | 100 | error_case['predict'] = {'agg': self.agg.item(), 'sel': self.sel.item(), 101 | 'conds': self.print_recover_cond_to_gloss(gold)} 102 | # error_case['exe result']=self.correct['exe'] 103 | 104 | error_case['BIO']=[(x.item(), y) for x, y in zip(list(self.BIO), list(gold['question']['words']))] 105 | error_case['BIO_col']=self.BIO_col.tolist() 106 | 107 | return error_case 108 | 109 | 110 | def recover_cond_to_gloss(self, gold): 111 | r_list = [] 112 | for col, op, span in self.cond: 113 | tk_list = [] 114 | for i in range(span[0], span[1] + 1): 115 | tk_list.append(gold['question']['gloss'][i]) 116 | tk_list.append(gold['question']['after'][i]) 117 | r_list.append([col, op, ''.join(tk_list).strip()]) 118 | return r_list 119 | def print_recover_cond_to_gloss(self, gold): 120 | r_list = [] 121 | for col, op, span in self.cond: 122 | tk_list = [] 123 | for i in range(span[0], span[1] + 1): 124 | tk_list.append(gold['question']['gloss'][i]) 125 | tk_list.append(gold['question']['after'][i]) 126 | #r_list.append([(col.item(), gold['table']['header'][col.item()]['words'], gold['tbl_content'][col.item()]), op,''.join(tk_list).strip()]) 127 | r_list.append([col.item(), op,''.join(tk_list).strip()]) 128 | return r_list -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the loadable seq2seq trainer library that is 3 | in charge of training details, loss compute, 
and statistics. 4 | """ 5 | from __future__ import division 6 | import os 7 | import time 8 | import sys 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | 13 | import table 14 | import table.modules 15 | from table.Utils import argmax 16 | import random 17 | 18 | class Statistics(object): 19 | def __init__(self, loss, eval_result): 20 | self.loss = loss 21 | self.eval_result = eval_result 22 | self.start_time = time.time() 23 | 24 | def update(self, stat): 25 | self.loss += stat.loss 26 | for k, v in stat.eval_result.items(): 27 | if k in self.eval_result: 28 | v0 = self.eval_result[k][0] + v[0] 29 | v1 = self.eval_result[k][1] + v[1] 30 | self.eval_result[k] = (v0, v1) 31 | else: 32 | self.eval_result[k] = (v[0], v[1]) 33 | 34 | def accuracy(self, return_str=False): 35 | d = sorted([(k, v) 36 | for k, v in self.eval_result.items()], key=lambda x: x[0]) 37 | #print(d) 38 | if return_str: 39 | return '; '.join((('{}: {:.2%}'.format(k, 1.0*v[0] / v[1])) for k, v in d)) 40 | else: 41 | return dict([(k, 100.0 * v[0] / v[1]) for k, v in d]) 42 | 43 | def elapsed_time(self): 44 | return time.time() - self.start_time 45 | 46 | def output(self, epoch, batch, n_batches, start): 47 | print(("Epoch %2d, %5d/%5d; %s; %.0f s elapsed") % 48 | (epoch, batch, n_batches, self.accuracy(True), time.time() - start)) 49 | sys.stdout.flush() 50 | 51 | def log(self, split, logger, lr, step): 52 | pass 53 | 54 | 55 | def count_accuracy(scores, target, mask=None, row=False): 56 | pred = argmax(scores) 57 | if mask is None: 58 | m_correct = pred.eq(target) 59 | num_all = m_correct.numel() 60 | elif row: 61 | m_correct = pred.eq(target).masked_fill_( 62 | mask.type(torch.bool), 1).prod(0, keepdim=False) 63 | 64 | #print('m_correct_row', m_correct.type()) 65 | num_all = m_correct.numel() 66 | else: 67 | non_mask = mask.ne(1).type(torch.bool) 68 | m_correct = pred.eq(target).masked_select(non_mask) 69 | num_all = non_mask.sum().item() 70 | 71 | m_correct = m_correct.type(torch.LongTensor) 72 | if torch.cuda.is_available(): 73 | m_correct=m_correct.cuda() 74 | return (m_correct, num_all) 75 | 76 | 77 | 78 | def count_condition_value_F1(scores1,golden_scores1, target_ls1,target_rs1): 79 | preds = argmax(scores1) 80 | 81 | preds = preds.transpose(0,1) 82 | target_ls=target_ls1.transpose(0,1) 83 | target_rs=target_rs1.transpose(0,1) 84 | golden_scores=golden_scores1.transpose(0,1) 85 | 86 | # print(type(preds),preds.size(),target_ls.size()) 87 | 88 | total_p=0 89 | total_r=0 90 | matched = 0 91 | exact_matched=0 92 | for sample_id in range(preds.size(0)): 93 | pred=preds[sample_id] 94 | #print(pred) 95 | golden_score=golden_scores[sample_id] 96 | #print(g) 97 | target_l=target_ls[sample_id] 98 | #print(target_l) 99 | target_r=target_rs[sample_id] 100 | #print(target_r) 101 | exact_matched+=1 102 | for i in range(pred.size(0)): 103 | if pred[i]!=golden_score[i] and golden_score[i]!=-1: 104 | exact_matched-=1 105 | break 106 | cond_span_lr = [] 107 | l=0 108 | r=0 109 | for i in range(pred.size(0)): 110 | if pred[i]==0: 111 | if l!=0: 112 | cond_span_lr.append((l,r)) 113 | l=i 114 | r=i 115 | elif pred[i]==1: 116 | r=i 117 | else: 118 | if l!=0: 119 | cond_span_lr.append((l,r)) 120 | l=0 121 | r=0 122 | if l != 0: 123 | cond_span_lr.append((l, r)) 124 | 125 | for l,r in cond_span_lr: 126 | for i in range(target_l.size(0)): 127 | if l==target_l[i] and r==target_r[i]: 128 | matched+=1 129 | for i in range(target_l.size(0)): 130 | if target_l[i]!=-1: 131 | total_r+=1 132 | total_p+=len(cond_span_lr) 133 | 134 | 
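    # `matched` counts predicted BIO spans whose (l, r) boundaries exactly
    # match a gold span, `total_p` counts all predicted spans (precision
    # denominator) and `total_r` counts all gold spans (recall denominator);
    # the 1e-10 terms below only guard against division by zero. F1 is
    # returned as a (numerator, denominator) pair -- 2*P*R over P+R -- so
    # Statistics.update() can accumulate it across batches exactly like the
    # other (count, total) metrics.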
#print(matched,total_p,total_r) 135 | #if random.random()<0.01: 136 | # print(preds[:10]) 137 | # print(golden_scores[:10]) 138 | recall=1.0*matched/(total_r+1e-10) 139 | precision=1.0*matched/(total_p+1e-10) 140 | return (exact_matched,preds.size(0)),(precision,1),(recall,1),(recall*precision*2,(recall+precision+1e-10)) 141 | #for span in target: 142 | # if span in cond_span_lr: 143 | # recall+=1 144 | #for span in cond_span_lr: 145 | # if span in target: 146 | # precision+=1 147 | #precision=len(cond_span_lr)>0 ? 1.0*precision/len(cond_span_lr) : 0 148 | #recall = len(target) > 0 ? 1.0 * recall / len(stargetpan): 0 149 | #return precision*recall/(precision+recall)/2.0 150 | 151 | 152 | 153 | def count_condition_value_EM_column_op(scores1,scores_col1,scores_op1, golden_scores1, golden_scores_col1,golden_scores_op1, target_ls1, target_rs1, target_cols1): 154 | preds = argmax(scores1) 155 | preds_col = argmax(scores_col1) 156 | preds_op = argmax(scores_op1) 157 | 158 | preds = preds.transpose(0,1) 159 | preds_col = preds_col.transpose(0, 1) 160 | preds_op = preds_op.transpose(0, 1) 161 | 162 | target_ls = target_ls1.transpose(0,1) 163 | target_rs = target_rs1.transpose(0,1) 164 | golden_scores = golden_scores1.transpose(0,1) 165 | golden_scores_col = golden_scores_col1.transpose(0, 1) 166 | golden_scores_op = golden_scores_op1.transpose(0, 1) 167 | target_cols = target_cols1.transpose(0,1) 168 | 169 | total_p=0 170 | total_r=0 171 | 172 | exact_matched=0 173 | exact_matched_op=0 174 | exact_matched_col=0 175 | for sample_id in range(preds.size(0)): 176 | pred=preds[sample_id] 177 | pred_col=preds_col[sample_id] 178 | pred_op = preds_op[sample_id] 179 | golden_score=golden_scores[sample_id] 180 | golden_score_col = golden_scores_col[sample_id] 181 | golden_score_op = golden_scores_op[sample_id] 182 | 183 | exact_matched += 1 184 | exact_matched_op += 1 185 | exact_matched_col += 1 186 | BIO_not_match = False 187 | for i in range(pred.size(0)): 188 | if pred[i] != golden_score[i] and golden_score[i] != -1: 189 | exact_matched-=1 190 | exact_matched_op-=1 191 | exact_matched_col-=1 192 | BIO_not_match = True 193 | break 194 | 195 | if BIO_not_match == False: 196 | col_not_match = False 197 | for i in range(pred.size(0)): 198 | if pred[i]==0: 199 | column_cnt = [] 200 | for j in range(torch.max(pred_col) + 2): 201 | column_cnt.append(0) 202 | column_cnt[pred_col[i]] = 1 203 | for j in range(i+1,pred.size(0)): 204 | if pred[j]!=1: 205 | break 206 | column_cnt[pred_col[j]] += 1 207 | max_cnt=0 208 | argmax1=pred_col[i] 209 | for j in range(torch.max(pred_col)+2): 210 | if column_cnt[j]>max_cnt: 211 | max_cnt=column_cnt[j] 212 | argmax1=j 213 | if argmax1!=golden_score_col[i]: 214 | exact_matched_col-=1 215 | col_not_match=True 216 | break 217 | 218 | op_not_match = False 219 | for i in range(pred.size(0)): 220 | if pred[i]==0: 221 | op_cnt = [] 222 | for j in range(3): 223 | op_cnt.append(0) 224 | op_cnt[pred_op[i]] = 1 225 | for j in range(i+1,pred.size(0)): 226 | if pred[j]!=1: 227 | break 228 | op_cnt[pred_op[j]] += 1 229 | max_cnt=0 230 | argmax1=pred_op[i] 231 | for j in range(3): 232 | if op_cnt[j]>max_cnt: 233 | max_cnt=op_cnt[j] 234 | argmax1=j 235 | if argmax1!=golden_score_op[i]: 236 | exact_matched_op-=1 237 | op_not_match=True 238 | break 239 | if op_not_match or col_not_match: 240 | exact_matched-=1 241 | 242 | return (exact_matched,preds.size(0)),(exact_matched_col,preds.size(0)),(exact_matched_op,preds.size(0)),0 243 | 244 | 245 | 246 | def 
count_condition_value_F1_column(scores1,scores_col1, golden_scores1, golden_scores_col1, target_ls1, target_rs1,gold_cols): 247 | preds = argmax(scores1) 248 | pred_cols = argmax(scores_col1) 249 | 250 | preds = preds.transpose(0, 1) 251 | target_ls = target_ls1.transpose(0, 1) 252 | target_rs = target_rs1.transpose(0, 1) 253 | 254 | golden_scores = golden_scores1.transpose(0, 1) 255 | pred_cols = pred_cols.transpose(0,1) 256 | golden_scores_col1 = golden_scores_col1.transpose(0,1) 257 | gold_cols = gold_cols.transpose(0, 1) 258 | # print(type(preds),preds.size(),target_ls.size()) 259 | 260 | total_p = 0 261 | total_r = 0 262 | matched = 0 263 | exact_matched = 0 264 | for sample_id in range(preds.size(0)): 265 | pred = preds[sample_id] 266 | # print(pred) 267 | golden_score = golden_scores[sample_id] 268 | # print(g) 269 | target_l = target_ls[sample_id] 270 | # print(target_l) 271 | target_r = target_rs[sample_id] 272 | # print(target_r) 273 | 274 | pred_col=pred_cols[sample_id] 275 | golden_score_col1 = golden_scores_col1[sample_id] 276 | gold_col = gold_cols[sample_id] 277 | 278 | exact_matched += 1 279 | for i in range(pred.size(0)): 280 | if pred[i] != golden_score[i] and golden_score[i] != -1: 281 | exact_matched -= 1 282 | break 283 | cond_span_lr = [] 284 | l = 0 285 | r = 0 286 | for i in range(pred.size(0)): 287 | if pred[i] == 0: 288 | if l != 0: 289 | cond_span_lr.append((l, r)) 290 | l = i 291 | r = i 292 | elif pred[i] == 1: 293 | r = i 294 | else: 295 | if l != 0: 296 | cond_span_lr.append((l, r)) 297 | l = 0 298 | r = 0 299 | if l != 0: 300 | cond_span_lr.append((l, r)) 301 | 302 | for l, r in cond_span_lr: 303 | for i in range(target_l.size(0)): 304 | #print(l,r,pred_col.size(),gold_col.size(),target_l.size(0)) 305 | if l == target_l[i] and r == target_r[i] and pred_col[l]==gold_col[i]: 306 | matched += 1 307 | for i in range(target_l.size(0)): 308 | if target_l[i] != -1: 309 | total_r += 1 310 | total_p += len(cond_span_lr) 311 | 312 | # print(matched,total_p,total_r) 313 | # if random.random()<0.01: 314 | # print(preds[:10]) 315 | # print(golden_scores[:10]) 316 | recall = 1.0 * matched / (total_r + 1e-10) 317 | precision = 1.0 * matched / (total_p + 1e-10) 318 | return (exact_matched, preds.size(0)), (precision, 1), (recall, 1), (recall * precision * 2, (recall + precision + 1e-10)) 319 | 320 | 321 | def count_where_accuracy(score_span_l, score_span_r, score_col, gold_span_ls, gold_span_rs, gold_cols): 322 | pred_span_ls = argmax(score_span_l) 323 | pred_span_rs = argmax(score_span_r) 324 | pred_cols = argmax(score_col) 325 | pred_span_ls =pred_span_ls.transpose(0,1) 326 | pred_span_rs = pred_span_rs.transpose(0, 1) 327 | pred_cols = pred_cols.transpose(0,1) 328 | gold_span_ls = gold_span_ls.transpose(0, 1) 329 | gold_span_rs = gold_span_rs.transpose(0, 1) 330 | gold_cols = gold_cols.transpose(0, 1) 331 | exact_matched = 0 332 | for sample_id in range(pred_span_ls.size(0)): 333 | pred_span_l = pred_span_ls[sample_id] 334 | pred_span_r = pred_span_rs[sample_id] 335 | pred_col = pred_cols[sample_id] 336 | 337 | gold_span_l = gold_span_ls[sample_id] 338 | gold_span_r = gold_span_rs[sample_id] 339 | gold_col = gold_cols[sample_id] 340 | 341 | exact_matched+=1 342 | for i in range(pred_span_l.size(0)): 343 | if gold_col[i]!=-1: 344 | if gold_col[i]!=pred_col[i] or gold_span_l[i]!=pred_span_l[i] or gold_span_r[i]!= pred_span_r[i]: 345 | exact_matched-=1 346 | break 347 | return exact_matched,pred_span_ls.size(0) 348 | 349 | def aggregate_accuracy(r_dict, metric_name_list): 
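    # Each r_dict entry holds a (per-example 0/1 correctness tensor, count)
    # pair produced by count_accuracy. Stacking the tensors and taking the
    # product over the metric axis marks an example correct only when every
    # listed metric is correct for it, so the result is
    # (number of jointly correct examples, batch size).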
350 | m_list = [] 351 | for metric_name in metric_name_list: 352 | m_list.append(r_dict[metric_name][0]) 353 | #print(r_dict[metric_name][0].size(),r_dict[metric_name][0].type()) 354 | #print (len(m_list),m_list[0].type(),m_list[0].size()) 355 | agg= torch.stack(m_list, dim=0) 356 | agg = agg.prod(0, keepdim=False) 357 | return (agg.sum().item(), agg.numel()) 358 | 359 | 360 | class Trainer(object): 361 | def __init__(self, model, train_iter, valid_iter, 362 | train_loss, valid_loss, optim): 363 | """ 364 | Args: 365 | model: the seq2seq model. 366 | train_iter: the train data iterator. 367 | valid_iter: the validate data iterator. 368 | train_loss: the train side LossCompute object for computing loss. 369 | valid_loss: the valid side LossCompute object for computing loss. 370 | optim: the optimizer responsible for lr update. 371 | """ 372 | # Basic attributes. 373 | self.model = model 374 | self.train_iter = train_iter 375 | self.valid_iter = valid_iter 376 | self.train_loss = train_loss 377 | self.valid_loss = valid_loss 378 | self.optim = optim 379 | 380 | # Set model in training mode. 381 | self.model.train() 382 | 383 | def forward(self, batch, criterion): 384 | # 1. F-prop. 385 | q, q_len = batch.src 386 | #print(q) 387 | tbl, tbl_len = batch.tbl 388 | cond_op, cond_op_len = batch.cond_op 389 | agg_out, sel_out, lay_out, cond_col_out, cond_span_l_out, cond_span_r_out, BIO_out, BIO_column_out, BIO_op_out = self.model( 390 | q, q_len, batch.ent,batch.type, tbl, tbl_len, batch.tbl_split, batch.tbl_mask, cond_op, cond_op_len, batch.cond_col, batch.cond_span_l, batch.cond_span_r, batch.lay) 391 | 392 | # 2. Compute loss. 393 | pred = {'agg': agg_out, 'sel': sel_out, 'lay': lay_out, 'cond_col': cond_col_out, 394 | 'cond_span_l': cond_span_l_out, 'cond_span_r': cond_span_r_out, 'BIO_label': BIO_out, 395 | 'BIO_column_label': BIO_column_out, 'BIO_op_label':BIO_op_out} 396 | #print(lay_out) 397 | gold = {'agg': batch.agg, 'sel': batch.sel, 'lay': batch.lay, 'cond_col': batch.cond_col_loss, 398 | 'cond_span_l': batch.cond_span_l_loss, 'cond_span_r': batch.cond_span_r_loss,'BIO_label': batch.BIO_label_loss, 'BIO_column_label': batch.BIO_column_label_loss, 399 | 'BIO_op_label':batch.BIO_op_label_loss} 400 | #print(batch.lay) 401 | loss = criterion.compute_loss(pred, gold) 402 | 403 | # 3. Get the batch statistics. 
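        # count_accuracy is used in three modes here: unmasked for the
        # per-question heads (agg/sel/lay); masked with row=True for the
        # per-condition heads, where padded gold positions (-1) are ignored
        # and an example only counts if every condition slot is right; and
        # masked with row=False for the token-level BIO tagging heads.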
404 |         r_dict = {}
405 |         #print('1',argmax(pred['agg'].data))
406 |         #print('2',gold['agg'].data)
407 | 
408 |         for metric_name in ('agg', 'sel', 'lay'):
409 |             r_dict[metric_name] = count_accuracy(
410 |                 pred[metric_name].data, gold[metric_name].data)
411 |         for metric_name in ('cond_col', 'cond_span_l', 'cond_span_r'):
412 |             #r_dict[metric_name + '-token'] = count_accuracy(
413 |             #    pred[metric_name].data, gold[metric_name].data, mask=gold[metric_name].data.eq(-1), row=False)
414 |             r_dict[metric_name] = count_accuracy(
415 |                 pred[metric_name].data, gold[metric_name].data, mask=gold[metric_name].data.eq(-1), row=True)
416 | 
417 |         for metric_name in ('BIO_label', 'BIO_column_label', 'BIO_op_label'):
418 |             r_dict[metric_name] = count_accuracy(
419 |                 pred[metric_name].data, gold[metric_name].data, mask=gold[metric_name].data.eq(-1), row=False)
420 | 
421 |         #print('3', r_dict['agg'][0])
422 |         st = dict([(k, (int(v[0].sum()), v[1])) for k, v in r_dict.items()])
423 |         #print('4', st['agg'])
424 |         prf = count_condition_value_F1(pred['BIO_label'].data, gold['BIO_label'].data, gold['cond_span_l'].data, gold['cond_span_r'].data)
425 |         st['BIO_label-EM'] = prf[0]
426 |         #st['BIO_label-P'] = prf[1]
427 |         #st['BIO_label-R'] = prf[2]
428 |         st['BIO_label-F1'] = prf[3]
429 | 
430 | 
431 |         prf = count_condition_value_EM_column_op(pred['BIO_label'].data, pred['BIO_column_label'].data, pred['BIO_op_label'].data,
432 |                                                  gold['BIO_label'].data, gold['BIO_column_label'].data, gold['BIO_op_label'].data,
433 |                                                  gold['cond_span_l'].data,
434 |                                                  gold['cond_span_r'].data, gold['cond_col'].data)
435 | 
436 |         st['ALL-label-EM'] = prf[0]
437 |         st['BIO_column_label-EM'] = prf[1]
438 |         st['BIO_op_label-EM'] = prf[2]
439 | 
440 |         prf = count_condition_value_F1_column(pred['BIO_label'].data, pred['BIO_column_label'].data,
441 |                                               gold['BIO_label'].data, gold['BIO_column_label'].data,
442 |                                               gold['cond_span_l'].data,
443 |                                               gold['cond_span_r'].data, gold['cond_col'].data)
444 |         st['BIO_column_label-P'] = prf[1]
445 |         st['BIO_column_label-R'] = prf[2]
446 |         st['BIO_column_label-F1'] = prf[3]
447 | 
448 | 
449 |         #st['where'] = aggregate_accuracy(
450 |         #    r_dict, ('lay', 'cond_col', 'cond_span_l', 'cond_span_r'))
451 |         st['where'] = count_where_accuracy(
452 |             pred['cond_span_l'].data, pred['cond_span_r'].data, pred['cond_col'].data,
453 |             gold['cond_span_l'].data, gold['cond_span_r'].data, gold['cond_col'].data
454 |         )
455 | 
456 | 
457 |         st['all'] = aggregate_accuracy(
458 |             r_dict, ('agg', 'sel', 'lay', 'cond_col', 'cond_span_l', 'cond_span_r'))
459 |         #st['all'] = aggregate_accuracy(
460 |         #    r_dict, ('agg', 'sel', 'where'))
461 |         #batch_stats = Statistics(loss.data[0], st)
462 |         batch_stats = Statistics(loss.data.item(), st)
463 | 
464 |         return loss, batch_stats
465 | 
466 |     def train(self, epoch, report_func=None):
467 |         """ Called for each epoch to train. """
468 |         total_stats = Statistics(0, {})
469 |         report_stats = Statistics(0, {})
470 |         #print(type(self.train_iter))
471 |         for i, batch in enumerate(self.train_iter):
472 |             self.model.zero_grad()
473 |             #print(batch)
474 |             loss, batch_stats = self.forward(batch, self.train_loss)
475 | 
476 |             # Update the parameters and statistics.
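            # The joint loss backpropagates through all heads at once;
            # Optim.step() applies clip_grad_norm_ before the parameter
            # update whenever max_grad_norm is set (see table/Optim.py).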
477 | loss.backward() 478 | self.optim.step() 479 | total_stats.update(batch_stats) 480 | report_stats.update(batch_stats) 481 | 482 | if report_func is not None: 483 | report_stats = report_func( 484 | epoch, i, len(self.train_iter), 485 | total_stats.start_time, self.optim.lr, report_stats) 486 | 487 | return total_stats 488 | 489 | def validate(self): 490 | """ Called for each epoch to validate. """ 491 | # Set model in validating mode. 492 | self.model.eval() 493 | 494 | stats = Statistics(0, {}) 495 | for batch in self.valid_iter: 496 | loss, batch_stats = self.forward(batch, self.valid_loss) 497 | 498 | # Update statistics. 499 | stats.update(batch_stats) 500 | 501 | # Set model back to training mode. 502 | self.model.train() 503 | 504 | return stats 505 | 506 | def epoch_step(self, eval_metric, epoch): 507 | """ Called for each epoch to update learning rate. """ 508 | return self.optim.updateLearningRate(eval_metric, epoch) 509 | 510 | def drop_checkpoint(self, opt, epoch, fields, valid_stats): 511 | """ Called conditionally each epoch to save a snapshot. """ 512 | 513 | model_state_dict = self.model.state_dict() 514 | model_state_dict = {k: v for k, v in model_state_dict.items() 515 | if 'generator' not in k} 516 | checkpoint = { 517 | 'model': model_state_dict, 518 | 'vocab': table.IO.TableDataset.save_vocab(fields), 519 | 'opt': opt, 520 | 'epoch': epoch, 521 | 'optim': self.optim 522 | } 523 | eval_result = valid_stats.accuracy() 524 | torch.save(checkpoint, os.path.join( 525 | opt.save_path, 'm_%d.pt' % (epoch))) 526 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | import table 5 | import table.IO 6 | import table.ModelConstructor 7 | import table.Models 8 | import table.modules 9 | from table.Utils import add_pad, argmax 10 | from table.ParseResult import ParseResult 11 | import torch.nn.functional as F 12 | def v_eval(a): 13 | return Variable(a, volatile=True) 14 | 15 | 16 | def cpu_vector(v): 17 | return v.clone().view(-1).cpu() 18 | 19 | 20 | class Translator(object): 21 | def __init__(self, opt, dummy_opt): 22 | # Add in default model arguments, possibly added since training. 
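        # map_location loads the checkpoint onto CPU storage, so a model
        # trained on GPU can be restored on any device; dummy_opt below
        # back-fills options introduced after this checkpoint was saved.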
23 | self.opt = opt 24 | checkpoint = torch.load(opt.model, 25 | map_location=lambda storage, loc: storage) 26 | self.fields = table.IO.TableDataset.load_fields(checkpoint['vocab']) 27 | 28 | model_opt = checkpoint['opt'] 29 | model_opt.pre_word_vecs = opt.pre_word_vecs 30 | for arg in dummy_opt: 31 | if arg not in model_opt: 32 | model_opt.__dict__[arg] = dummy_opt[arg] 33 | 34 | self.model = table.ModelConstructor.make_base_model( 35 | model_opt, self.fields, checkpoint) 36 | self.model.eval() 37 | 38 | def translate(self, batch): 39 | q, q_len = batch.src 40 | tbl, tbl_len = batch.tbl 41 | ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask 42 | 43 | # encoding 44 | 45 | q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc( 46 | q, q_len, ent, batch.type, tbl, tbl_len, tbl_split, tbl_mask) #query, query length, table, table length, table split, table mask 47 | 48 | BIO_op_out = self.model.BIO_op_classifier(q_all) 49 | tsp_q = BIO_op_out.size(0) 50 | bsz = BIO_op_out.size(1) 51 | BIO_op_out = BIO_op_out.view(-1, BIO_op_out.size(2)) 52 | BIO_op_out = F.log_softmax(BIO_op_out,dim=-1) 53 | BIO_op_out_sf = torch.exp(BIO_op_out) 54 | BIO_op_out = BIO_op_out.view(tsp_q, bsz, -1) 55 | BIO_op_out_sf = BIO_op_out_sf.view(tsp_q, bsz, -1) 56 | 57 | # if fff == 1: 58 | # print(BIO_op_out_sf.transpose(0,1)[0]) 59 | # print(BIO_op_out.transpose(0, 1)[0]) 60 | 61 | BIO_out = self.model.BIO_classifier(q_all) 62 | BIO_out = BIO_out.view(-1, BIO_out.size(2)) 63 | BIO_out = F.log_softmax(BIO_out,dim=-1) 64 | BIO_out_sf = torch.exp(BIO_out) 65 | BIO_out = BIO_out.view(tsp_q, bsz, -1) 66 | BIO_out_sf = BIO_out_sf.view(tsp_q, bsz, -1) 67 | # if fff == 1: 68 | # print(BIO_out_sf.transpose(0,1)[0]) 69 | # print(BIO_out.transpose(0, 1)[0]) 70 | 71 | BIO_col_out = self.model.label_col_match(q_all, tbl_enc, tbl_mask) 72 | # if fff == 1: 73 | # print(BIO_col_out.size()) 74 | # print(BIO_col_out.transpose(0, 1)[0]) 75 | BIO_col_out = BIO_col_out.view(-1, BIO_col_out.size(2)) 76 | BIO_col_out = F.log_softmax(BIO_col_out,dim=-1) 77 | BIO_col_out_sf = torch.exp(BIO_col_out) 78 | BIO_col_out = BIO_col_out.view(tsp_q, bsz, -1) 79 | BIO_col_out_sf = BIO_col_out_sf.view(tsp_q, bsz, -1) 80 | 81 | BIO_pred = argmax(BIO_out_sf.data).transpose(0, 1) 82 | BIO_col_pred = argmax(BIO_col_out_sf.data).transpose(0, 1) 83 | for i in range(BIO_pred.size(0)): 84 | for j in range(BIO_pred.size(1)): 85 | if BIO_pred[i][j] == 2: 86 | BIO_col_pred[i][j] = -1 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | # (1) decoding 96 | q_self_encode = self.model.agg_self_attention(q_all, q_len)#q_ht 97 | q_self_encode_layout = self.model.lay_self_attention(q_all, q_len)#q_ht 98 | agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_self_encode).data)) 99 | sel_out = self.model.sel_match(q_self_encode, tbl_enc, tbl_mask,select=True) # select column 100 | sel_pred = cpu_vector(argmax(self.model.sel_match( 101 | q_self_encode, tbl_enc, tbl_mask,select=True).data)) 102 | lay_pred = argmax(self.model.lay_classifier(q_self_encode_layout).data) 103 | # get layout op tokens 104 | op_batch_list = [] 105 | op_idx_batch_list = [] 106 | if self.opt.gold_layout: 107 | lay_pred = batch.lay.data 108 | cond_op, cond_op_len = batch.cond_op 109 | cond_op_len_list = cond_op_len.view(-1).tolist() 110 | for i, len_it in enumerate(cond_op_len_list): 111 | if len_it == 0: 112 | op_idx_batch_list.append([]) 113 | op_batch_list.append([]) 114 | else: 115 | idx_list = cond_op.data[0:len_it, i].contiguous().view(-1).tolist() 116 | 
op_idx_batch_list.append([int(self.fields['cond_op'].vocab.itos[it]) for it in idx_list])
117 |                     op_batch_list.append(idx_list)
118 |         else:
119 |             lay_batch_list = lay_pred.view(-1).tolist()
120 |             for lay_it in lay_batch_list:
121 |                 tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
122 |                 if (len(tk_list) == 0) or (tk_list[0] == ''):
123 |                     op_idx_batch_list.append([])
124 |                     op_batch_list.append([])
125 |                 else:
126 |                     op_idx_batch_list.append([int(op_str) for op_str in tk_list])
127 |                     op_batch_list.append(
128 |                         [self.fields['cond_op'].vocab.stoi[op_str] for op_str in tk_list])
129 |         # -> (num_cond, batch)
130 |         cond_op = v_eval(add_pad(
131 |             op_batch_list, self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
132 |         cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
133 |         # emb_op -> (num_cond, batch, emb_size)
134 |         if self.model.opt.layout_encode == 'rnn':
135 |             emb_op = table.Models.encode_unsorted_batch(
136 |                 self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
137 |         else:
138 |             emb_op = self.model.cond_embedding(cond_op)
139 | 
140 |         # (2) decoding
141 |         self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
142 |         cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
143 |         cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
144 |         for emb_op_t in emb_op:
145 |             emb_op_t = emb_op_t.unsqueeze(0)
146 |             cond_context, cond_state, _ = self.model.cond_decoder(
147 |                 emb_op_t, q_all, cond_state)
148 |             #print(cond_context.size())
149 |             #cond_context = self.model.decode_softattention(cond_context, q_all, q_len)
150 |             #print(cond_context.size())
151 | 
152 |             # cond col -> (1, batch)
153 |             cond_col = argmax(self.model.cond_col_match(
154 |                 cond_context, tbl_enc, tbl_mask).data)
155 |             cond_col_list.append(cpu_vector(cond_col))
156 |             # emb_col
157 |             batch_index = torch.LongTensor(range(batch_size)).unsqueeze_(0).expand(cond_col.size(0), cond_col.size(1))
158 |             batch_index = batch_index.cuda() if torch.cuda.is_available() else batch_index
159 |             emb_col = tbl_enc[cond_col, batch_index, :]
160 |             cond_context, cond_state, _ = self.model.cond_decoder(
161 |                 emb_col, q_all, cond_state)
162 | 
163 | 
164 |             # cond span
165 |             q_mask = v_eval(
166 |                 q.data.eq(self.model.pad_word_index).transpose(0, 1))
167 |             cond_span_l = argmax(self.model.cond_span_l_match(
168 |                 cond_context, q_all, q_mask).data)
169 |             cond_span_l_list.append(cpu_vector(cond_span_l))
170 |             # emb_span_l: (1, batch, hidden_size)
171 |             emb_span_l = q_all[cond_span_l, batch_index, :]
172 |             cond_span_r = argmax(self.model.cond_span_r_match(
173 |                 cond_context, q_all, q_mask, emb_span_l=emb_span_l).data)
174 |             cond_span_r_list.append(cpu_vector(cond_span_r))
175 |             # emb_span_r: (1, batch, hidden_size)
176 |             emb_span_r = q_all[cond_span_r, batch_index, :]
177 | 
178 |             emb_span = self.model.span_merge(
179 |                 torch.cat([emb_span_l, emb_span_r], 2))
180 | 
181 |             # mask = torch.zeros([cond_col.size(0), q_all.size(0), q_all.size(1)])  # (num_cond,tsp,bsz)
182 |             # for j in range(q_all.size(1)):
183 |             #     for i in range(cond_col.size(0)):
184 |             #         for k in range(cond_span_l[i][j], cond_span_r[i][j] + 1):
185 |             #             mask[i][k][j] = 1
186 | 
187 |             # mask = mask.unsqueeze_(3)  # .expand(cond_col.size(0),q_all.size(0),q_all.size(1),q_all.size(2))
188 | 
189 | 
190 |             # emb_span = Variable(mask.cuda()) * torch.unsqueeze(q_all, 0)  # .expand_as(mask) #(num_cond,tsp,bsz,hidden)
191 |             # emb_span = torch.mean(emb_span, dim=1)  # (num_cond,bsz,hidden) mean pooling
192 | 
193 |             cond_context, cond_state, _ = self.model.cond_decoder(
194 |                 emb_span, q_all, cond_state)
195 | 
196 |         # (3) recover output
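        # Assemble one ParseResult per example: each condition is a
        # (column, operator, (span_l, span_r)) triple of indices into the
        # question, detokenized later by ParseResult.recover_cond_to_gloss().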
197 |         indices = cpu_vector(batch.indices.data)
198 |         r_list = []
199 |         for b in range(batch_size):
200 |             idx = indices[b]
201 |             agg = agg_pred[b]
202 |             sel = sel_pred[b]
203 |             BIO = BIO_pred[b]
204 |             BIO_col = BIO_col_pred[b]
205 |             cond = []
206 |             for i in range(len(op_batch_list[b])):
207 |                 col = cond_col_list[i][b]
208 |                 op = op_idx_batch_list[b][i]
209 |                 span_l = cond_span_l_list[i][b]
210 |                 span_r = cond_span_r_list[i][b]
211 |                 cond.append((col, op, (span_l, span_r)))
212 |             r_list.append(ParseResult(idx, agg, sel, cond, BIO, BIO_col))
213 | 
214 |         return r_list
215 | 
-------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/Utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from collections import defaultdict
6 | 
7 | 
8 | def aeq(*args):
9 |     """
10 |     Assert all arguments have the same value
11 |     """
12 |     arguments = (arg for arg in args)
13 |     first = next(arguments)
14 |     assert all(arg == first for arg in arguments), \
15 |         "Not all arguments have the same value: " + str(args)
16 | 
17 | 
18 | def set_seed(seed):
19 |     """Sets random seed everywhere."""
20 |     torch.manual_seed(seed)
21 |     torch.cuda.manual_seed(seed)
22 |     random.seed(seed)
23 |     np.random.seed(seed)
24 | 
25 | 
26 | def sort_for_pack(input_len):
27 |     idx_sorted, input_len_sorted = zip(
28 |         *sorted(list(enumerate(input_len)), key=lambda x: x[1], reverse=True))
29 |     idx_sorted, input_len_sorted = list(idx_sorted), list(input_len_sorted)
30 |     idx_map_back = list(map(lambda x: x[0], sorted(
31 |         list(enumerate(idx_sorted)), key=lambda x: x[1])))
32 |     return idx_sorted, input_len_sorted, idx_map_back
33 | 
34 | 
35 | def argmax(scores):
36 |     return scores.max(scores.dim() - 1)[1]
37 | 
38 | 
39 | def add_pad(b_list, pad_index, return_tensor=True):
40 |     max_len = max((len(b) for b in b_list))
41 |     r_list = []
42 |     for b in b_list:
43 |         r_list.append(b + [pad_index] * (max_len - len(b)))
44 |     if return_tensor:
45 |         t = torch.LongTensor(r_list)
46 |         return t.cuda() if torch.cuda.is_available() else t
47 |     else:
48 |         return r_list
49 | 
50 | 
51 | def sequence_mask(sequence_length, max_len=None):
52 |     if max_len is None:
53 |         max_len = sequence_length.data.max()
54 |     batch_size = sequence_length.size(0)
55 |     seq_range = torch.arange(0, max_len).long()
56 |     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
57 | 
58 |     if sequence_length.is_cuda:
59 |         seq_range_expand = seq_range_expand.cuda()
60 |     seq_length_expand = (sequence_length.unsqueeze(1)
61 |                          .expand_as(seq_range_expand))
62 |     return seq_range_expand < seq_length_expand
-------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/__init__.py: --------------------------------------------------------------------------------
1 | import table.IO
2 | import table.Models
3 | import table.Loss
4 | import table.ParseResult
5 | from table.Trainer import Trainer, Statistics
6 | from table.Translator import Translator
7 | from table.Optim import Optim
8 | from table.Beam import Beam, GNMTGlobalScorer
9 | 
10 | 
11 | # # For flake8 compatibility
12 | # __all__ = [table.Loss, table.IO, table.Models, Trainer, Translator,
13 | #            Optim, Beam, Statistics, GNMTGlobalScorer]
14 | 
-------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/.DS_Store: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/JD-AI-Research-Silicon-Valley/auxiliary-task-for-text-to-sql/9c0ff806cabab9e06b1b7fd0fac557bae79ff610/zero-shot-text-to-SQL/table/modules/.DS_Store -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/Embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class PartUpdateEmbedding(nn.Module): 7 | def __init__(self, update_index, emb_update, emb_fixed): 8 | super(PartUpdateEmbedding, self).__init__() 9 | self.update_index = update_index 10 | self.emb_update = emb_update 11 | self.emb_fixed = emb_fixed 12 | self.should_update = True 13 | self.embedding_dim = emb_update.embedding_dim 14 | 15 | def set_update(self, should_update): 16 | self.should_update = should_update 17 | 18 | def forward(self, inp): 19 | assert(inp.dim() == 2) 20 | r_update = self.emb_update(inp.clamp(0, self.update_index - 1)) 21 | r_fixed = self.emb_fixed(inp) 22 | mask = Variable(inp.data.lt(self.update_index).float().unsqueeze( 23 | 2).expand_as(r_update), requires_grad=False) 24 | r_update = r_update.mul(mask) 25 | r_fixed = r_fixed.mul(1 - mask) 26 | if self.should_update: 27 | return r_update + r_fixed 28 | else: 29 | return r_update + Variable(r_fixed.data, requires_grad=False) 30 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/Gate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context gate is a decoder module that takes as input the previous word 3 | embedding, the current decoder state and the attention state, and produces a 4 | gate. 5 | The gate can be used to select the input from the target side context 6 | (decoder state), from the source context (attention state) or both. 
7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def ContextGateFactory(type, embeddings_size, decoder_size, 13 | attention_size, output_size): 14 | """Returns the correct ContextGate class""" 15 | 16 | gate_types = {'source': SourceContextGate, 17 | 'target': TargetContextGate, 18 | 'both': BothContextGate} 19 | 20 | assert type in gate_types, "Not valid ContextGate type: {0}".format(type) 21 | return gate_types[type](embeddings_size, decoder_size, attention_size, 22 | output_size) 23 | 24 | 25 | class ContextGate(nn.Module): 26 | """Implement up to the computation of the gate""" 27 | 28 | def __init__(self, embeddings_size, decoder_size, 29 | attention_size, output_size): 30 | super(ContextGate, self).__init__() 31 | input_size = embeddings_size + decoder_size + attention_size 32 | self.gate = nn.Linear(input_size, output_size, bias=True) 33 | self.sig = nn.Sigmoid() 34 | self.source_proj = nn.Linear(attention_size, output_size) 35 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 36 | output_size) 37 | 38 | def forward(self, prev_emb, dec_state, attn_state): 39 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 40 | z = self.sig(self.gate(input_tensor)) 41 | proj_source = self.source_proj(attn_state) 42 | proj_target = self.target_proj( 43 | torch.cat((prev_emb, dec_state), dim=1)) 44 | return z, proj_source, proj_target 45 | 46 | 47 | class SourceContextGate(nn.Module): 48 | """Apply the context gate only to the source context""" 49 | 50 | def __init__(self, embeddings_size, decoder_size, 51 | attention_size, output_size): 52 | super(SourceContextGate, self).__init__() 53 | self.context_gate = ContextGate(embeddings_size, decoder_size, 54 | attention_size, output_size) 55 | self.tanh = nn.Tanh() 56 | 57 | def forward(self, prev_emb, dec_state, attn_state): 58 | z, source, target = self.context_gate( 59 | prev_emb, dec_state, attn_state) 60 | return self.tanh(target + z * source) 61 | 62 | 63 | class TargetContextGate(nn.Module): 64 | """Apply the context gate only to the target context""" 65 | 66 | def __init__(self, embeddings_size, decoder_size, 67 | attention_size, output_size): 68 | super(TargetContextGate, self).__init__() 69 | self.context_gate = ContextGate(embeddings_size, decoder_size, 70 | attention_size, output_size) 71 | self.tanh = nn.Tanh() 72 | 73 | def forward(self, prev_emb, dec_state, attn_state): 74 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 75 | return self.tanh(z * target + source) 76 | 77 | 78 | class BothContextGate(nn.Module): 79 | """Apply the context gate to both contexts""" 80 | 81 | def __init__(self, embeddings_size, decoder_size, 82 | attention_size, output_size): 83 | super(BothContextGate, self).__init__() 84 | self.context_gate = ContextGate(embeddings_size, decoder_size, 85 | attention_size, output_size) 86 | self.tanh = nn.Tanh() 87 | 88 | def forward(self, prev_emb, dec_state, attn_state): 89 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 90 | return self.tanh((1. - z) * target + z * source) 91 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from table.modules.UtilClass import BottleLinear 5 | from table.Utils import aeq 6 | import table.IO 7 | 8 | class GlobalAttention(nn.Module): 9 | """ 10 | Luong Attention. 
11 | 12 | Global attention takes a matrix and a query vector. It 13 | then computes a parameterized convex combination of the matrix 14 | based on the input query. 15 | 16 | 17 | H_1 H_2 H_3 ... H_n 18 | q q q q 19 | | | | | 20 | \ | | / 21 | ..... 22 | \ | / 23 | a 24 | 25 | Constructs a unit mapping. 26 | $$(H_1 + H_n, q) => (a)$$ 27 | Where H is of `batch x n x dim` and q is of `batch x dim`. 28 | 29 | Luong Attention (dot, general): 30 | The full function is 31 | $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$. 32 | 33 | * dot: $$score(h_t,{\overline{h}}_s) = h_t^T{\overline{h}}_s$$ 34 | * general: $$score(h_t,{\overline{h}}_s) = h_t^T W_a {\overline{h}}_s$$ 35 | 36 | Bahdanau Attention (mlp): 37 | $$c = \sum_{j=1}^{SeqLength}\a_jh_j$$. 38 | The Alignment-function $$a$$ computes an alignment as: 39 | $$a_j = softmax(v_a^T \tanh(W_a q + U_a h_j) )$$. 40 | 41 | """ 42 | 43 | def __init__(self, dim, is_transform_out, attn_type="dot", attn_hidden=0): 44 | super(GlobalAttention, self).__init__() 45 | 46 | self.dim = dim 47 | self.attn_type = attn_type 48 | self.attn_hidden = attn_hidden 49 | assert (self.attn_type in ["dot", "general", "mlp"]), ( 50 | "Please select a valid attention type.") 51 | 52 | if attn_hidden > 0: 53 | self.transform_in = nn.Sequential( 54 | nn.Linear(dim, attn_hidden), 55 | nn.ELU(0.1)) 56 | 57 | if self.attn_type == "general": 58 | d = attn_hidden if attn_hidden > 0 else dim 59 | self.linear_in = nn.Linear(d, d, bias=False) 60 | # initialization 61 | # self.linear_in.weight.data.add_(torch.eye(d)) 62 | elif self.attn_type == "mlp": 63 | self.linear_context = BottleLinear(dim, dim, bias=False) 64 | self.linear_query = nn.Linear(dim, dim, bias=True) 65 | self.v = BottleLinear(dim, 1, bias=False) 66 | # mlp wants it with bias 67 | out_bias = self.attn_type == "mlp" 68 | if is_transform_out: 69 | self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias) 70 | else: 71 | self.linear_out = None 72 | 73 | self.sm = nn.Softmax(dim=-1) 74 | self.tanh = nn.Tanh() 75 | self.mask = None 76 | 77 | def applyMask(self, mask): 78 | self.mask = mask 79 | 80 | def applyMaskBySeqBatch(self, q): 81 | self.applyMask(q.data.eq(table.IO.PAD).t().contiguous().unsqueeze(0).type(torch.bool)) 82 | 83 | def score(self, h_t, h_s): 84 | """ 85 | h_t (FloatTensor): batch x tgt_len x dim 86 | h_s (FloatTensor): batch x src_len x dim 87 | returns scores (FloatTensor): batch x tgt_len x src_len: 88 | raw attention scores for each src index 89 | """ 90 | 91 | # Check input sizes 92 | src_batch, src_len, src_dim = h_s.size() 93 | tgt_batch, tgt_len, tgt_dim = h_t.size() 94 | aeq(src_batch, tgt_batch) 95 | aeq(src_dim, tgt_dim) 96 | aeq(self.dim, src_dim) 97 | 98 | if self.attn_type in ["general", "dot"]: 99 | if self.attn_hidden > 0: 100 | h_t = self.transform_in(h_t) 101 | h_s = self.transform_in(h_s) 102 | if self.attn_type == "general": 103 | h_t = self.linear_in(h_t) 104 | h_s_ = h_s.transpose(1, 2) 105 | # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) 106 | return torch.bmm(h_t, h_s_) 107 | else: 108 | dim = self.dim 109 | wq = self.linear_query(h_t.view(-1, dim)) 110 | wq = wq.view(tgt_batch, tgt_len, 1, dim) 111 | wq = wq.expand(tgt_batch, tgt_len, src_len, dim) 112 | 113 | uh = self.linear_context(h_s.contiguous().view(-1, dim)) 114 | uh = uh.view(src_batch, 1, src_len, dim) 115 | uh = uh.expand(src_batch, tgt_len, src_len, dim) 116 | 117 | # (batch, t_len, s_len, d) 118 | wquh = self.tanh(wq + uh) 119 | 120 | return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, 
src_len) 121 | 122 | def forward(self, input, context): 123 | """ 124 | input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output. 125 | context (FloatTensor): batch x src_len x dim: src hidden states 126 | """ 127 | 128 | # one step input 129 | if input.dim() == 2: 130 | one_step = True 131 | input = input.unsqueeze(1) 132 | else: 133 | one_step = False 134 | 135 | batch, sourceL, dim = context.size() 136 | batch_, targetL, dim_ = input.size() 137 | aeq(batch, batch_) 138 | aeq(dim, dim_) 139 | aeq(self.dim, dim) 140 | 141 | if self.mask is not None: 142 | beam_, batch_, sourceL_ = self.mask.size() 143 | aeq(batch, batch_ * beam_) 144 | aeq(sourceL, sourceL_) 145 | 146 | # compute attention scores, as in Luong et al. 147 | align = self.score(input, context) 148 | 149 | if self.mask is not None: 150 | mask_ = self.mask.view(batch, 1, sourceL).type(torch.bool) # make it broardcastable 151 | align.data.masked_fill_(mask_, -float('inf')) 152 | 153 | # Softmax to normalize attention weights 154 | align_vectors = self.sm(align.view(batch * targetL, sourceL)) 155 | align_vectors = align_vectors.view(batch, targetL, sourceL) 156 | 157 | # each context vector c_t is the weighted average 158 | # over all the source hidden states 159 | c = torch.bmm(align_vectors, context) 160 | 161 | # concatenate 162 | concat_c = torch.cat([c, input], 2) 163 | if self.linear_out is None: 164 | attn_h = concat_c 165 | else: 166 | attn_h = self.linear_out(concat_c) 167 | if self.attn_type in ["general", "dot"]: 168 | attn_h = self.tanh(attn_h) 169 | 170 | if one_step: 171 | attn_h = attn_h.squeeze(1) 172 | align_vectors = align_vectors.squeeze(1) 173 | 174 | # Check output sizes 175 | batch_, dim_ = attn_h.size() 176 | aeq(batch, batch_) 177 | # aeq(dim, dim_) 178 | batch_, sourceL_ = align_vectors.size() 179 | aeq(batch, batch_) 180 | aeq(sourceL, sourceL_) 181 | else: 182 | attn_h = attn_h.transpose(0, 1).contiguous() 183 | align_vectors = align_vectors.transpose(0, 1).contiguous() 184 | 185 | # Check output sizes 186 | targetL_, batch_, dim_ = attn_h.size() 187 | aeq(targetL, targetL_) 188 | aeq(batch, batch_) 189 | # aeq(dim, dim_) 190 | targetL_, batch_, sourceL_ = align_vectors.size() 191 | aeq(targetL, targetL_) 192 | aeq(batch, batch_) 193 | aeq(sourceL, sourceL_) 194 | 195 | return attn_h, align_vectors 196 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/LockedDropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class LockedDropout(nn.Module): 7 | def __init__(self, dropout_rate): 8 | super(LockedDropout, self).__init__() 9 | self.dropout_rate = dropout_rate 10 | 11 | def forward(self, x): 12 | if not self.training or not self.dropout_rate: 13 | return x 14 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate) 15 | mask = Variable(m, requires_grad=False) / (1 - self.dropout_rate) 16 | mask = mask.expand_as(x) 17 | return mask * x 18 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/StackedRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class StackedLSTM(nn.Module): 6 | """ 7 | Our own implementation of stacked LSTM. 8 | Needed for the decoder, because we do input feeding. 
9 | """ 10 | def __init__(self, num_layers, input_size, rnn_size, dropout): 11 | super(StackedLSTM, self).__init__() 12 | self.dropout = nn.Dropout(dropout) 13 | self.num_layers = num_layers 14 | self.layers = nn.ModuleList() 15 | 16 | for i in range(num_layers): 17 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 18 | input_size = rnn_size 19 | 20 | def forward(self, input, hidden): 21 | h_0, c_0 = hidden 22 | h_1, c_1 = [], [] 23 | for i, layer in enumerate(self.layers): 24 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 25 | input = h_1_i 26 | if i + 1 != self.num_layers: 27 | input = self.dropout(input) 28 | h_1 += [h_1_i] 29 | c_1 += [c_1_i] 30 | 31 | h_1 = torch.stack(h_1) 32 | c_1 = torch.stack(c_1) 33 | 34 | return input, (h_1, c_1) 35 | 36 | 37 | class StackedGRU(nn.Module): 38 | 39 | def __init__(self, num_layers, input_size, rnn_size, dropout): 40 | super(StackedGRU, self).__init__() 41 | self.dropout = nn.Dropout(dropout) 42 | self.num_layers = num_layers 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(num_layers): 46 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 47 | input_size = rnn_size 48 | 49 | def forward(self, input, hidden): 50 | h_1 = [] 51 | for i, layer in enumerate(self.layers): 52 | h_1_i = layer(input, hidden[0][i]) 53 | input = h_1_i 54 | if i + 1 != self.num_layers: 55 | input = self.dropout(input) 56 | h_1 += [h_1_i] 57 | 58 | h_1 = torch.stack(h_1) 59 | return input, (h_1,) 60 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/UtilClass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Bottle(nn.Module): 6 | def forward(self, input): 7 | if len(input.size()) <= 2: 8 | return super(Bottle, self).forward(input) 9 | size = input.size()[:2] 10 | out = super(Bottle, self).forward(input.view(size[0]*size[1], -1)) 11 | return out.contiguous().view(size[0], size[1], -1) 12 | 13 | 14 | class Bottle2(nn.Module): 15 | def forward(self, input): 16 | if len(input.size()) <= 3: 17 | return super(Bottle2, self).forward(input) 18 | size = input.size() 19 | out = super(Bottle2, self).forward(input.view(size[0]*size[1], 20 | size[2], size[3])) 21 | return out.contiguous().view(size[0], size[1], size[2], size[3]) 22 | 23 | 24 | class LayerNorm(nn.Module): 25 | ''' Layer normalization module ''' 26 | 27 | def __init__(self, d_hid, eps=1e-3): 28 | super(LayerNorm, self).__init__() 29 | 30 | self.eps = eps 31 | self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True) 32 | self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True) 33 | 34 | def forward(self, z): 35 | if z.size(1) == 1: 36 | return z 37 | mu = torch.mean(z, dim=1) 38 | sigma = torch.std(z, dim=1) 39 | # HACK. PyTorch is changing behavior 40 | if mu.dim() == 1: 41 | mu = mu.unsqueeze(1) 42 | sigma = sigma.unsqueeze(1) 43 | ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) 44 | ln_out = ln_out.mul(self.a_2.expand_as(ln_out)) \ 45 | + self.b_2.expand_as(ln_out) 46 | return ln_out 47 | 48 | 49 | class BottleLinear(Bottle, nn.Linear): 50 | pass 51 | 52 | 53 | class BottleLayerNorm(Bottle, LayerNorm): 54 | pass 55 | 56 | 57 | class BottleSoftmax(Bottle, nn.Softmax): 58 | pass 59 | 60 | 61 | class Elementwise(nn.ModuleList): 62 | """ 63 | A simple network container. 64 | Parameters are a list of modules. 65 | Inputs are a 3d Variable whose last dimension is the same length 66 | as the list. 
67 | Outputs are the result of applying modules to inputs elementwise. 68 | An optional merge parameter allows the outputs to be reduced to a 69 | single Variable. 70 | """ 71 | 72 | def __init__(self, merge=None, *args): 73 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 74 | self.merge = merge 75 | super(Elementwise, self).__init__(*args) 76 | 77 | def forward(self, input): 78 | inputs = [feat.squeeze(2) for feat in input.split(1, dim=2)] 79 | assert len(self) == len(inputs) 80 | outputs = [f(x) for f, x in zip(self, inputs)] 81 | if self.merge == 'first': 82 | return outputs[0] 83 | elif self.merge == 'concat' or self.merge == 'mlp': 84 | return torch.cat(outputs, 2) 85 | elif self.merge == 'sum': 86 | return sum(outputs) 87 | else: 88 | return outputs 89 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/WeightDrop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from functools import wraps 4 | 5 | 6 | class WeightDrop(torch.nn.Module): 7 | def __init__(self, module, weights, dropout=0, variational=False): 8 | super(WeightDrop, self).__init__() 9 | self.module = module 10 | self.weights = weights 11 | self.dropout = dropout 12 | self.variational = variational 13 | self._setup() 14 | 15 | def widget_demagnetizer_y2k_edition(*args, **kwargs): 16 | # We need to replace flatten_parameters with a nothing function 17 | # It must be a function rather than a lambda as otherwise pickling explodes 18 | # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! 19 | return 20 | 21 | def _setup(self): 22 | # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN 23 | if issubclass(type(self.module), torch.nn.RNNBase): 24 | self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition 25 | 26 | for name_w in self.weights: 27 | print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) 28 | w = getattr(self.module, name_w) 29 | del self.module._parameters[name_w] 30 | self.module.register_parameter(name_w + '_raw', Parameter(w.data)) 31 | 32 | def _setweights(self): 33 | for name_w in self.weights: 34 | raw_w = getattr(self.module, name_w + '_raw') 35 | w = None 36 | if self.variational: 37 | mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) 38 | if raw_w.is_cuda: 39 | mask = mask.cuda() 40 | mask = torch.nn.functional.dropout( 41 | mask, p=self.dropout, training=True) 42 | w = mask.expand_as(raw_w) * raw_w 43 | else: 44 | w = torch.nn.functional.dropout( 45 | raw_w, p=self.dropout, training=self.training) 46 | setattr(self.module, name_w, w) 47 | 48 | def forward(self, *args): 49 | self._setweights() 50 | return self.module.forward(*args) 51 | 52 | 53 | if __name__ == '__main__': 54 | import torch 55 | 56 | # Input is (seq, batch, input) 57 | x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda() 58 | h0 = None 59 | 60 | ### 61 | 62 | print('Testing WeightDrop') 63 | print('=-=-=-=-=-=-=-=-=-=') 64 | 65 | ### 66 | 67 | print('Testing WeightDrop with Linear') 68 | 69 | lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9) 70 | lin.cuda() 71 | run1 = [x.sum() for x in lin(x).data] 72 | run2 = [x.sum() for x in lin(x).data] 73 | 74 | print('All items should be different') 75 | print('Run 1:', run1) 76 | print('Run 2:', run2) 77 | 78 | assert run1[0] != run2[0] 79 | assert run1[1] != run2[1] 80 | 81 | print('---') 82 | 83 | ### 84 | 85 | 
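    # In the LSTM test below, dropout is applied only to the hidden-to-hidden
    # matrix (weight_hh_l0). The first timestep depends only on the
    # input-to-hidden weights, so run1[0] == run2[0], while every later
    # timestep differs; the asserts check exactly that.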
print('Testing WeightDrop with LSTM') 86 | 87 | wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9) 88 | wdrnn.cuda() 89 | 90 | run1 = [x.sum() for x in wdrnn(x, h0)[0].data] 91 | run2 = [x.sum() for x in wdrnn(x, h0)[0].data] 92 | 93 | print('First timesteps should be equal, all others should differ') 94 | print('Run 1:', run1) 95 | print('Run 2:', run2) 96 | 97 | # First time step, not influenced by hidden to hidden weights, should be equal 98 | assert run1[0] == run2[0] 99 | # Second step should not 100 | assert run1[1] != run2[1] 101 | 102 | print('---') 103 | -------------------------------------------------------------------------------- /zero-shot-text-to-SQL/table/modules/WeightNorm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Weight Normalization: A Simple Reparameterization 3 | to Accelerate Training of Deep Neural Networks" 4 | As a reparameterization method, weight normalization is same 5 | as BatchNormalization, but it doesn't depend on minibatch. 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn import Parameter 11 | from torch.autograd import Variable 12 | 13 | 14 | def get_var_maybe_avg(namespace, var_name, training, polyak_decay): 15 | # utility for retrieving polyak averaged params 16 | # Update average 17 | v = getattr(namespace, var_name) 18 | v_avg = getattr(namespace, var_name + '_avg') 19 | v_avg -= (1 - polyak_decay) * (v_avg - v.data) 20 | 21 | if training: 22 | return v 23 | else: 24 | return Variable(v_avg) 25 | 26 | 27 | def get_vars_maybe_avg(namespace, var_names, training, polyak_decay): 28 | # utility for retrieving polyak averaged params 29 | vars = [] 30 | for vn in var_names: 31 | vars.append(get_var_maybe_avg( 32 | namespace, vn, training, polyak_decay)) 33 | return vars 34 | 35 | 36 | class WeightNormLinear(nn.Linear): 37 | def __init__(self, in_features, out_features, 38 | init_scale=1., polyak_decay=0.9995): 39 | super(WeightNormLinear, self).__init__( 40 | in_features, out_features, bias=True) 41 | 42 | self.V = self.weight 43 | self.g = Parameter(torch.Tensor(out_features)) 44 | self.b = self.bias 45 | 46 | self.register_buffer( 47 | 'V_avg', torch.zeros(out_features, in_features)) 48 | self.register_buffer('g_avg', torch.zeros(out_features)) 49 | self.register_buffer('b_avg', torch.zeros(out_features)) 50 | 51 | self.init_scale = init_scale 52 | self.polyak_decay = polyak_decay 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | return 57 | 58 | def forward(self, x, init=False): 59 | if init is True: 60 | # out_features * in_features 61 | self.V.data.copy_(torch.randn(self.V.data.size()).type_as( 62 | self.V.data) * 0.05) 63 | # norm is out_features * 1 64 | V_norm = self.V.data / \ 65 | self.V.data.norm(2, 1).expand_as(self.V.data) 66 | # batch_size * out_features 67 | x_init = F.linear(x, Variable(V_norm)).data 68 | # out_features 69 | m_init, v_init = x_init.mean(0).squeeze( 70 | 0), x_init.var(0).squeeze(0) 71 | # out_features 72 | scale_init = self.init_scale / \ 73 | torch.sqrt(v_init + 1e-10) 74 | self.g.data.copy_(scale_init) 75 | self.b.data.copy_(-m_init * scale_init) 76 | x_init = scale_init.view(1, -1).expand_as(x_init) \ 77 | * (x_init - m_init.view(1, -1).expand_as(x_init)) 78 | self.V_avg.copy_(self.V.data) 79 | self.g_avg.copy_(self.g.data) 80 | self.b_avg.copy_(self.b.data) 81 | return Variable(x_init) 82 | else: 83 | V, g, b = get_vars_maybe_avg(self, ['V', 'g', 
'b'], 84 | self.training, 85 | polyak_decay=self.polyak_decay) 86 | # batch_size * out_features 87 | x = F.linear(x, V) 88 | scalar = g / torch.norm(V, 2, 1).squeeze(1) 89 | x = scalar.view(1, -1).expand_as(x) * x + \ 90 | b.view(1, -1).expand_as(x) 91 | return x 92 | 93 | 94 | class WeightNormConv2d(nn.Conv2d): 95 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 96 | padding=0, dilation=1, groups=1, init_scale=1., 97 | polyak_decay=0.9995): 98 | super(WeightNormConv2d, self).__init__(in_channels, out_channels, 99 | kernel_size, stride, padding, 100 | dilation, groups) 101 | 102 | self.V = self.weight 103 | self.g = Parameter(torch.Tensor(out_channels)) 104 | self.b = self.bias 105 | 106 | self.register_buffer('V_avg', torch.zeros(self.V.size())) 107 | self.register_buffer('g_avg', torch.zeros(out_channels)) 108 | self.register_buffer('b_avg', torch.zeros(out_channels)) 109 | 110 | self.init_scale = init_scale 111 | self.polyak_decay = polyak_decay 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | return 116 | 117 | def forward(self, x, init=False): 118 | if init is True: 119 | # out_channels, in_channels // groups, * kernel_size 120 | self.V.data.copy_(torch.randn(self.V.data.size() 121 | ).type_as(self.V.data) * 0.05) 122 | V_norm = self.V.data / self.V.data.view(self.out_channels, -1)\ 123 | .norm(2, 1).view(self.out_channels, *( 124 | [1] * (len(self.kernel_size) + 1))).expand_as(self.V.data) 125 | x_init = F.conv2d(x, Variable(V_norm), None, self.stride, 126 | self.padding, self.dilation, self.groups).data 127 | t_x_init = x_init.transpose(0, 1).contiguous().view( 128 | self.out_channels, -1) 129 | m_init, v_init = t_x_init.mean(1).squeeze( 130 | 1), t_x_init.var(1).squeeze(1) 131 | # out_features 132 | scale_init = self.init_scale / \ 133 | torch.sqrt(v_init + 1e-10) 134 | self.g.data.copy_(scale_init) 135 | self.b.data.copy_(-m_init * scale_init) 136 | scale_init_shape = scale_init.view( 137 | 1, self.out_channels, *([1] * (len(x_init.size()) - 2))) 138 | m_init_shape = m_init.view( 139 | 1, self.out_channels, *([1] * (len(x_init.size()) - 2))) 140 | x_init = scale_init_shape.expand_as( 141 | x_init) * (x_init - m_init_shape.expand_as(x_init)) 142 | self.V_avg.copy_(self.V.data) 143 | self.g_avg.copy_(self.g.data) 144 | self.b_avg.copy_(self.b.data) 145 | return Variable(x_init) 146 | else: 147 | V, g, b = get_vars_maybe_avg( 148 | self, ['V', 'g', 'b'], self.training, 149 | polyak_decay=self.polyak_decay) 150 | 151 | scalar = torch.norm(V.view(self.out_channels, -1), 2, 1) 152 | if len(scalar.size()) == 2: 153 | scalar = g / scalar.squeeze(1) 154 | else: 155 | scalar = g / scalar 156 | 157 | W = scalar.view(self.out_channels, * 158 | ([1] * (len(V.size()) - 1))).expand_as(V) * V 159 | 160 | x = F.conv2d(x, W, b, self.stride, 161 | self.padding, self.dilation, self.groups) 162 | return x 163 | 164 | 165 | class WeightNormConvTranspose2d(nn.ConvTranspose2d): 166 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 167 | padding=0, output_padding=0, groups=1, init_scale=1., 168 | polyak_decay=0.9995): 169 | super(WeightNormConvTranspose2d, self).__init__( 170 | in_channels, out_channels, 171 | kernel_size, stride, 172 | padding, output_padding, 173 | groups) 174 | # in_channels, out_channels, *kernel_size 175 | self.V = self.weight 176 | self.g = Parameter(torch.Tensor(out_channels)) 177 | self.b = self.bias 178 | 179 | self.register_buffer('V_avg', torch.zeros(self.V.size())) 180 | self.register_buffer('g_avg', 
        self.register_buffer('V_avg', torch.zeros(self.V.size()))
        self.register_buffer('g_avg', torch.zeros(out_channels))
        self.register_buffer('b_avg', torch.zeros(out_channels))

        self.init_scale = init_scale
        self.polyak_decay = polyak_decay
        self.reset_parameters()

    def reset_parameters(self):
        return

    def forward(self, x, init=False):
        if init is True:
            # in_channels, out_channels, *kernel_size
            self.V.data.copy_(torch.randn(self.V.data.size()).type_as(
                self.V.data) * 0.05)
            V_norm = self.V.data / self.V.data.transpose(0, 1).contiguous() \
                .view(self.out_channels, -1).norm(2, 1).view(
                    self.in_channels, self.out_channels,
                    *([1] * len(self.kernel_size))).expand_as(self.V.data)
            x_init = F.conv_transpose2d(
                x, Variable(V_norm), None, self.stride,
                self.padding, self.output_padding, self.groups).data
            # out_channels, -1
            t_x_init = x_init.transpose(0, 1).contiguous().view(
                self.out_channels, -1)
            # out_features
            m_init, v_init = t_x_init.mean(1).squeeze(
                1), t_x_init.var(1).squeeze(1)
            # out_features
            scale_init = self.init_scale / \
                torch.sqrt(v_init + 1e-10)
            self.g.data.copy_(scale_init)
            self.b.data.copy_(-m_init * scale_init)
            scale_init_shape = scale_init.view(
                1, self.out_channels, *([1] * (len(x_init.size()) - 2)))
            m_init_shape = m_init.view(
                1, self.out_channels, *([1] * (len(x_init.size()) - 2)))

            x_init = scale_init_shape.expand_as(x_init)\
                * (x_init - m_init_shape.expand_as(x_init))
            self.V_avg.copy_(self.V.data)
            self.g_avg.copy_(self.g.data)
            self.b_avg.copy_(self.b.data)
            return Variable(x_init)
        else:
            V, g, b = get_vars_maybe_avg(
                self, ['V', 'g', 'b'], self.training,
                polyak_decay=self.polyak_decay)
            scalar = g / \
                torch.norm(V.transpose(0, 1).contiguous().view(
                    self.out_channels, -1), 2, 1).squeeze(1)
            W = scalar.view(self.in_channels, self.out_channels,
                            *([1] * (len(V.size()) - 2))).expand_as(V) * V

            x = F.conv_transpose2d(x, W, b, self.stride,
                                   self.padding, self.output_padding,
                                   self.groups)
            return x
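As a reference for the classes above, the reparameterization they implement is w = g * V / ||V||, with one norm (and one scale g) per output unit. A minimal sketch with plain tensors, shaped like `WeightNormLinear` (illustrative only, not the module's exact code path):

```python
import torch

out_features, in_features = 4, 3
V = torch.randn(out_features, in_features)  # direction parameters
g = torch.ones(out_features)                # per-row scale parameters
x = torch.randn(2, in_features)             # a toy batch

row_norms = V.norm(2, dim=1)                # one L2 norm per output row
W = (g / row_norms).view(-1, 1) * V         # effective weight: g * V / ||V||
y = x @ W.t()                               # equivalent to F.linear(x, W)
```

At evaluation time the modules swap V, g and b for their polyak-averaged buffers via `get_vars_maybe_avg`; the arithmetic is unchanged.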
--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/table/modules/__init__.py:
--------------------------------------------------------------------------------
from table.modules.UtilClass import LayerNorm, Bottle, BottleLinear, \
    BottleLayerNorm, BottleSoftmax, Elementwise
from table.modules.Gate import ContextGateFactory
from table.modules.GlobalAttention import GlobalAttention
from table.modules.StackedRNN import StackedLSTM, StackedGRU
from table.modules.LockedDropout import LockedDropout
from table.modules.WeightDrop import WeightDrop

# # For flake8 compatibility.
# __all__ = [GlobalAttention, ImageEncoder, CopyGenerator, MultiHeadedAttention,
#            LayerNorm, Bottle, BottleLinear, BottleLayerNorm, BottleSoftmax,
#            TransformerEncoder, TransformerDecoder, Elementwise,
#            MatrixTree, WeightNormConv2d, ConvMultiStepAttention,
#            CNNEncoder, CNNDecoder, StackedLSTM, StackedGRU, ContextGateFactory,
#            CopyGeneratorLossCompute]
--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/table/modules/cross_entropy_smooth.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def onehot(indexes, N=None, ignore_index=None):
    """
    Creates a one-hot representation of indexes with N possible entries.
    If N is not specified, it is inferred from the maximum index appearing.
    indexes is a long-tensor of indexes.
    Positions equal to ignore_index become all-zero rows in the result.
    """
    return_variable = False
    if isinstance(indexes, Variable):
        return_variable = True
        indexes = indexes.data
    mask_idx = None
    if ignore_index is not None:
        mask_idx = indexes.eq(ignore_index)
    if N is None:
        N = indexes.max() + 1
    sz = list(indexes.size())
    output = indexes.new().byte().resize_(*sz, N).zero_()
    if mask_idx is not None:
        # ignore_index could be < 0, so point ignored rows at a valid column
        indexes = indexes.clone().masked_fill_(mask_idx, 0)
    output.scatter_(-1, indexes.unsqueeze(-1), 1)
    if mask_idx is not None:
        output.masked_fill_(mask_idx.unsqueeze(-1), 0)
    if return_variable:
        output = Variable(output, requires_grad=False)

    return output


def _is_long(x):
    if hasattr(x, 'data'):
        x = x.data
    return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)


def cross_entropy(logits, target, weight=None, size_average=True,
                  ignore_index=None, smooth_eps=None, smooth_dist=None):
    """cross entropy loss, with support for target distributions and label smoothing (https://arxiv.org/abs/1512.00567)"""
    if smooth_eps is not None and smooth_eps > 0:
        num_classes = logits.size(-1)
        mask_idx = None
        if _is_long(target):
            if ignore_index is not None:
                mask_idx = target.eq(ignore_index)
            target = onehot(target, num_classes, ignore_index).type_as(logits)
        if smooth_dist is None:
            target = (1 - smooth_eps) * target + \
                smooth_eps / num_classes
        else:
            target = torch.lerp(
                target, smooth_dist.unsqueeze(0), smooth_eps)
        if mask_idx is not None:
            target.masked_fill_(mask_idx.unsqueeze(1), 0)
    if weight is not None:
        target = target * weight.unsqueeze(0)
    ce = -(logits * target).sum(1)
    if size_average:
        ce = ce.mean()
    else:
        ce = ce.sum()
    return ce


class CrossEntropyLossSmooth(nn.CrossEntropyLoss):
    """CrossEntropyLossSmooth - can receive a distribution as target, with optional label smoothing"""

    def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True,
                 smooth_eps=None, smooth_dist=None):
        super(CrossEntropyLossSmooth, self).__init__(
            weight, size_average=size_average, ignore_index=ignore_index)
        self.smooth_eps = smooth_eps
        self.smooth_dist = smooth_dist

    def forward(self, input, target):
        return cross_entropy(input, target, self.weight, self.size_average,
                             self.ignore_index, self.smooth_eps, self.smooth_dist)

if __name__ == '__main__':
    print(onehot(torch.LongTensor([3, 2, -1, 0, 1]), 5, -1))
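To make the smoothing arithmetic above concrete: with smoothing rate eps and K classes, the true class keeps probability 1 - eps + eps/K and every other class receives eps/K. A small, self-contained check of the same formula `cross_entropy` applies (values are illustrative):

```python
import torch

eps, num_classes = 0.1, 5
hard = torch.tensor([[0., 0., 1., 0., 0.]])     # one-hot target for class 2
smooth = (1 - eps) * hard + eps / num_classes   # same formula as cross_entropy
print(smooth)  # tensor([[0.0200, 0.0200, 0.9200, 0.0200, 0.0200]])
```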
--------------------------------------------------------------------------------
/zero-shot-text-to-SQL/train.py:
--------------------------------------------------------------------------------
from __future__ import division

import os
import sys
import argparse
import json
import torch
import torch.nn as nn
from torch import cuda

import table
import table.Models
import table.ModelConstructor
import table.modules
from table.Utils import set_seed
import opts
from tensorboard_logger import Logger
# from path import Path


use_cuda = torch.cuda.is_available()

parser = argparse.ArgumentParser(description='train.py')

# opts.py
opts.model_opts(parser)
opts.train_opts(parser)

opt = parser.parse_args()
opt.save_path = opt.save_dir
print(opt.save_dir)

if opt.layers != -1:
    opt.enc_layers = opt.layers
    opt.dec_layers = opt.layers

opt.brnn = (opt.encoder_type == "brnn")
opt.pre_word_vecs = os.path.join(opt.embd, 'embedding')

print(vars(opt))

json.dump(opt.__dict__, open(os.path.join(
    opt.save_path, 'opt.json'), 'w'), sort_keys=True, indent=2)
if torch.cuda.is_available():
    cuda.set_device(opt.gpuid[0])
set_seed(opt.seed)

# Set up the logging server.
# logger = Logger(os.path.join(opt.save_path, 'tb'))


def report_func(epoch, batch, num_batches,
                start_time, lr, report_stats):
    """
    This is the user-defined batch-level training progress
    report function.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
        report_stats = table.Statistics(0, {})

    return report_stats


def train_model(model, train_data, valid_data, fields, optim):
    train_iter = table.IO.OrderedIterator(
        dataset=train_data, batch_size=opt.batch_size, device=opt.gpuid[0], repeat=False)
    valid_iter = table.IO.OrderedIterator(
        dataset=valid_data, batch_size=opt.batch_size, device=opt.gpuid[0], train=False, sort=True, sort_within_batch=False)

    train_loss = table.Loss.TableLossCompute(opt.agg_sample_rate, smooth_eps=model.opt.smooth_eps)
    if torch.cuda.is_available():
        train_loss = train_loss.cuda()
    valid_loss = table.Loss.TableLossCompute(opt.agg_sample_rate, smooth_eps=model.opt.smooth_eps)
    if torch.cuda.is_available():
        valid_loss = valid_loss.cuda()

    trainer = table.Trainer(model, train_iter, valid_iter,
                            train_loss, valid_loss, optim)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        if opt.fix_word_vecs:
            if (epoch >= opt.update_word_vecs_after):
                model.q_encoder.embeddings.set_update(True)
            else:
                model.q_encoder.embeddings.set_update(False)

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train accuracy: %s' % train_stats.accuracy(True))

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation accuracy: %s' % valid_stats.accuracy(True))

        # 3. Log to remote server.
        # train_stats.log("train", logger, optim.lr, epoch)
        # valid_stats.log("valid", logger, optim.lr, epoch)

        # 4. Update the learning rate.
        trainer.epoch_step(None, epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(opt, epoch, fields, valid_stats)


def load_fields(train, valid, checkpoint):
    fields = table.IO.TableDataset.load_fields(
        torch.load(os.path.join(opt.data, 'vocab.pt')))
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in train.examples[0].__dict__])
    train.fields = fields
    valid.fields = fields

    if opt.train_from:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = table.IO.TableDataset.load_fields(checkpoint['vocab'])

    return fields


def build_model(model_opt, fields, checkpoint):
    print('Building model...')
    model = table.ModelConstructor.make_base_model(
        model_opt, fields, checkpoint)
    print(model)

    return model


def build_optim(model, checkpoint):
    if opt.train_from:
        print('Loading optimizer from checkpoint.')
        optim = checkpoint['optim']
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())
    else:
        # what members of opt does Optim need?
        optim = table.Optim(
            opt.optim, opt.learning_rate, opt.alpha, opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_at=opt.start_decay_at,
            opt=opt
        )

    optim.set_parameters(model.parameters())

    return optim


def main():
    # Load train and validate data.
    print("Loading train and validate data from '%s'" % opt.data)
    train = torch.load(os.path.join(opt.data, 'train.pt'))
    valid = torch.load(os.path.join(opt.data, 'valid.pt'))
    # test = torch.load(os.path.join(opt.data, 'test.pt'))
    print(' * number of training sentences: %d' % len(train))
    print(' * maximum batch size: %d' % opt.batch_size)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(
            opt.train_from, map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Load fields generated from the preprocess phase.
    fields = load_fields(train, valid, checkpoint)

    # Build model.
    model = build_model(model_opt, fields, checkpoint)

    # Build optimizer.
    optim = build_optim(model, checkpoint)

    # Do training.
    train_model(model, train, valid, fields, optim)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
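For reference when resuming with `train_from`: a sketch of inspecting a checkpoint dropped by `trainer.drop_checkpoint`. The file name below is hypothetical; the keys ('opt', 'epoch', 'vocab', 'optim') are the ones train.py itself reads above:

```python
import torch

# Hypothetical path under the save directory; adjust to your run.
ckpt = torch.load('saves/checkpoint_epoch10.pt',
                  map_location=lambda storage, loc: storage)
print(ckpt['epoch'])       # epoch at which the checkpoint was dropped
print(ckpt['opt'].epochs)  # the argparse options the run was trained with
# ckpt['vocab'] feeds table.IO.TableDataset.load_fields, as in load_fields()
# ckpt['optim'] restores the optimizer state, as in build_optim()
```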