├── .gitignore ├── README.md ├── asdls ├── action_info.py ├── asdl.py ├── asdl_ast.py ├── decode_hypothesis.py ├── hypothesis.py ├── sql │ ├── grammar │ │ ├── sql_asdl_v0.txt │ │ ├── sql_asdl_v1.txt │ │ └── sql_asdl_v2.txt │ ├── parser │ │ ├── parser_base.py │ │ ├── parser_v0.py │ │ ├── parser_v1.py │ │ └── parser_v2.py │ ├── sql_transition_system.py │ └── unparser │ │ ├── unparser_base.py │ │ ├── unparser_v0.py │ │ ├── unparser_v1.py │ │ └── unparser_v2.py └── transition_system.py ├── evaluation.py ├── model ├── decoder │ ├── onlstm.py │ └── sql_parser.py ├── encoder │ ├── functions.py │ ├── graph_encoder.py │ ├── graph_input.py │ ├── graph_output.py │ └── rgatsql.py ├── model_constructor.py └── model_utils.py ├── preprocess ├── common_utils.py ├── graph_utils.py ├── process_dataset.py └── process_graphs.py ├── process_sql.py ├── scripts └── text2sql.py └── utils ├── args.py ├── batch.py ├── constants.py ├── evaluator.py ├── example.py ├── graph_example.py ├── hyperparams.py ├── initialization.py ├── logger.py ├── optimization.py ├── vocab.py └── word2vec.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | **/__pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ISESL-SQL 2 | 3 | The source code of paper "Semantic Enhanced Text-to-SQL Parsing via Iteratively Learning Schema Linking Graph" published at KDD 2022. 4 | 5 | ### download data 6 | 7 | spider: https://yale-lily.github.io/spider 8 | 9 | put the data into the data/ directory 10 | 11 | ### preprocess data 12 | 13 | ``` 14 | python3 -u preprocess/process_dataset.py --dataset_path data/train.json --raw_table_path data/tables.json --table_path data/tables.bin --output_path 'data/train.bin' --skip_large --semantic_graph 15 | python3 -u preprocess/process_dataset.py --dataset_path data/dev.json --table_path data/tables.bin --output_path 'data/dev.bin' --skip_large --semantic_graph 16 | python3 -u preprocess/process_graphs.py --dataset_path 'data/train.bin' --table_path data/tables.bin --output_path data/train.rgatsql.bin 17 | python3 -u preprocess/process_graphs.py --dataset_path 'data/dev.bin' --table_path data/tables.bin --output_path data/dev.rgatsql.bin 18 | ``` 19 | 20 | ### train model 21 | 22 | ``` 23 | CUDA_VISIBLE_DEVICES=0 python scripts/text2sql.py --task lgesql_large --seed 999 --device 0 --plm google/electra-large-discriminator --gnn_hidden_size 512 --dropout 0.2 --attn_drop 0.0 --att_vec_size 512 --model rgatsql --output_model without_pruning --score_function affine --relation_share_heads --subword_aggregation attentive-pooling --schema_aggregation head+tail --gnn_num_layers 8 --num_heads 8 --lstm onlstm --chunk_size 8 --drop_connect 0.2 --lstm_hidden_size 512 --lstm_num_layers 1 --action_embed_size 128 --field_embed_size 64 --type_embed_size 64 --no_context_feeding --batch_size 35 --grad_accumulate 5 --lr 1e-4 --l2 0.1 --warmup_ratio 0.1 --lr_schedule linear --eval_after_epoch 120 --smoothing 0.15 --layerwise_decay 0.8 --max_epoch 200 --max_norm 5 --beam_size 5 --logdir logdir/run --train_path train --dev_path dev --training --optimize_graph --schema_loss 24 | ``` 25 | 26 | ### Reference 27 | 28 | ``` 29 | @inproceedings{liu2022semantic, 30 | title={Semantic Enhanced Text-to-SQL Parsing via Iteratively Learning Schema Linking Graph}, 31 | author={Liu, Aiwei and Hu, Xuming and Lin, Li and Wen, Lijie}, 32 | booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 33 | pages={1021--1030}, 34 | year={2022} 35 | } 36 | ``` 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /asdls/action_info.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from asdls.hypothesis import Hypothesis 3 | from asdls.transition_system import ApplyRuleAction, GenTokenAction 4 | from asdls.sql.sql_transition_system import SelectColumnAction, SelectTableAction 5 | 6 | class ActionInfo(object): 7 
| """sufficient statistics for making a prediction of an action at a time step""" 8 | 9 | def __init__(self, action=None): 10 | self.t = 0 11 | self.parent_t = -1 12 | self.action = action 13 | self.frontier_prod = None 14 | self.frontier_field = None 15 | 16 | # for GenToken actions only 17 | self.copy_from_src = False 18 | self.src_token_position = -1 19 | 20 | def __repr__(self, verbose=False): 21 | repr_str = '%s (t=%d, p_t=%d, frontier_field=%s)' % (repr(self.action), 22 | self.t, 23 | self.parent_t, 24 | self.frontier_field.__repr__(True) if self.frontier_field else 'None') 25 | 26 | if verbose: 27 | verbose_repr = 'action_prob=%.4f, ' % self.action_prob 28 | if isinstance(self.action, GenTokenAction): 29 | verbose_repr += 'in_vocab=%s, ' \ 30 | 'gen_copy_switch=%s, ' \ 31 | 'p(gen)=%s, p(copy)=%s, ' \ 32 | 'has_copy=%s, copy_pos=%s' % (self.in_vocab, 33 | self.gen_copy_switch, 34 | self.gen_token_prob, self.copy_token_prob, 35 | self.copy_from_src, self.src_token_position) 36 | 37 | repr_str += '\n' + verbose_repr 38 | 39 | return repr_str 40 | 41 | 42 | def get_action_infos(src_query: list = None, tgt_actions: list = [], force_copy=False): 43 | action_infos = [] 44 | hyp = Hypothesis() 45 | for t, action in enumerate(tgt_actions): 46 | action_info = ActionInfo(action) 47 | action_info.t = t 48 | if hyp.frontier_node: 49 | action_info.parent_t = hyp.frontier_node.created_time 50 | action_info.frontier_prod = hyp.frontier_node.production 51 | action_info.frontier_field = hyp.frontier_field.field 52 | 53 | if isinstance(action, SelectColumnAction) or isinstance(action, SelectTableAction): 54 | pass 55 | elif isinstance(action, GenTokenAction): # GenToken 56 | try: 57 | tok_src_idx = src_query.index(str(action.token)) 58 | action_info.copy_from_src = True 59 | action_info.src_token_position = tok_src_idx 60 | except ValueError: 61 | if force_copy: raise ValueError('cannot copy primitive token %s from source' % action.token) 62 | 63 | hyp.apply_action(action) 64 | action_infos.append(action_info) 65 | 66 | return action_infos 67 | -------------------------------------------------------------------------------- /asdls/asdl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from collections import OrderedDict, Counter 3 | from itertools import chain 4 | import re, os 5 | 6 | def remove_comment(text): 7 | text = re.sub(re.compile("#.*"), "", text) 8 | text = '\n'.join(filter(lambda x: x, text.split('\n'))) 9 | return text 10 | 11 | class ASDLGrammar(object): 12 | """ 13 | Collection of types, constructors and productions 14 | """ 15 | def __init__(self, productions, file_path): 16 | # productions are indexed by their head types 17 | file_name = os.path.basename(file_path) 18 | grammar_name = file_name[:file_name.index('.txt')] if '.txt' in file_name else file_name 19 | self._grammar_name = grammar_name 20 | self._productions = OrderedDict() 21 | self._constructor_production_map = dict() 22 | for prod in productions: 23 | if prod.type not in self._productions: 24 | self._productions[prod.type] = list() 25 | self._productions[prod.type].append(prod) 26 | self._constructor_production_map[prod.constructor.name] = prod 27 | 28 | self.root_type = productions[0].type 29 | # number of constructors 30 | self.size = sum(len(head) for head in self._productions.values()) 31 | 32 | # get entities to their ids map 33 | self.prod2id = {prod: i for i, prod in enumerate(self.productions)} 34 | self.type2id = {type: i for i, type in 
enumerate(self.types)} 35 | self.field2id = {field: i for i, field in enumerate(self.fields)} 36 | 37 | self.id2prod = {i: prod for i, prod in enumerate(self.productions)} 38 | self.id2type = {i: type for i, type in enumerate(self.types)} 39 | self.id2field = {i: field for i, field in enumerate(self.fields)} 40 | 41 | def __len__(self): 42 | return self.size 43 | 44 | @property 45 | def productions(self): 46 | return sorted(chain.from_iterable(self._productions.values()), key=lambda x: repr(x)) 47 | 48 | def __getitem__(self, datum): 49 | if isinstance(datum, str): 50 | return self._productions[ASDLType(datum)] 51 | elif isinstance(datum, ASDLType): 52 | return self._productions[datum] 53 | 54 | def get_prod_by_ctr_name(self, name): 55 | return self._constructor_production_map[name] 56 | 57 | @property 58 | def types(self): 59 | if not hasattr(self, '_types'): 60 | all_types = set() 61 | for prod in self.productions: 62 | all_types.add(prod.type) 63 | all_types.update(map(lambda x: x.type, prod.constructor.fields)) 64 | 65 | self._types = sorted(all_types, key=lambda x: x.name) 66 | 67 | return self._types 68 | 69 | @property 70 | def fields(self): 71 | if not hasattr(self, '_fields'): 72 | all_fields = set() 73 | for prod in self.productions: 74 | all_fields.update(prod.constructor.fields) 75 | 76 | self._fields = sorted(all_fields, key=lambda x: (x.name, x.type.name, x.cardinality)) 77 | 78 | return self._fields 79 | 80 | @property 81 | def primitive_types(self): 82 | return filter(lambda x: isinstance(x, ASDLPrimitiveType), self.types) 83 | 84 | @property 85 | def composite_types(self): 86 | return filter(lambda x: isinstance(x, ASDLCompositeType), self.types) 87 | 88 | def is_composite_type(self, asdl_type): 89 | return asdl_type in self.composite_types 90 | 91 | def is_primitive_type(self, asdl_type): 92 | return asdl_type in self.primitive_types 93 | 94 | @staticmethod 95 | def from_filepath(file_path): 96 | def _parse_field_from_text(_text): 97 | d = _text.strip().split(' ') 98 | name = d[1].strip() 99 | type_str = d[0].strip() 100 | cardinality = 'single' 101 | if type_str[-1] == '*': 102 | type_str = type_str[:-1] 103 | cardinality = 'multiple' 104 | elif type_str[-1] == '?': 105 | type_str = type_str[:-1] 106 | cardinality = 'optional' 107 | 108 | if type_str in primitive_type_names: 109 | return Field(name, ASDLPrimitiveType(type_str), cardinality=cardinality) 110 | else: 111 | return Field(name, ASDLCompositeType(type_str), cardinality=cardinality) 112 | 113 | def _parse_constructor_from_text(_text): 114 | _text = _text.strip() 115 | fields = None 116 | if '(' in _text: 117 | name = _text[:_text.find('(')] 118 | field_blocks = _text[_text.find('(') + 1:_text.find(')')].split(',') 119 | fields = map(_parse_field_from_text, field_blocks) 120 | else: 121 | name = _text 122 | 123 | if name == '': name = None 124 | 125 | return ASDLConstructor(name, fields) 126 | 127 | with open(file_path, 'r') as inf: 128 | text = inf.read() 129 | lines = remove_comment(text).split('\n') 130 | lines = list(map(lambda l: l.strip(), lines)) 131 | lines = list(filter(lambda l: l, lines)) 132 | line_no = 0 133 | 134 | # first line is always the primitive types 135 | primitive_type_names = list(map(lambda x: x.strip(), lines[line_no].split(','))) 136 | line_no += 1 137 | 138 | all_productions = list() 139 | 140 | while True: 141 | type_block = lines[line_no] 142 | type_name = type_block[:type_block.find('=')].strip() 143 | constructors_blocks = type_block[type_block.find('=') + 1:].split('|') 144 | i = 
line_no + 1 145 | while i < len(lines) and lines[i].strip().startswith('|'): 146 | t = lines[i].strip() 147 | cont_constructors_blocks = t[1:].split('|') 148 | constructors_blocks.extend(cont_constructors_blocks) 149 | 150 | i += 1 151 | 152 | constructors_blocks = filter(lambda x: x and x.strip(), constructors_blocks) 153 | 154 | # parse type name 155 | new_type = ASDLPrimitiveType(type_name) if type_name in primitive_type_names else ASDLCompositeType(type_name) 156 | constructors = map(_parse_constructor_from_text, constructors_blocks) 157 | 158 | productions = list(map(lambda c: ASDLProduction(new_type, c), constructors)) 159 | all_productions.extend(productions) 160 | 161 | line_no = i 162 | if line_no == len(lines): 163 | break 164 | 165 | grammar = ASDLGrammar(all_productions, file_path) 166 | 167 | return grammar 168 | 169 | 170 | class ASDLProduction(object): 171 | def __init__(self, type, constructor): 172 | self.type = type 173 | self.constructor = constructor 174 | 175 | @property 176 | def fields(self): 177 | return self.constructor.fields 178 | 179 | def __getitem__(self, field_name): 180 | return self.constructor[field_name] 181 | 182 | def __hash__(self): 183 | h = hash(self.type) ^ hash(self.constructor) 184 | 185 | return h 186 | 187 | def __eq__(self, other): 188 | return isinstance(other, ASDLProduction) and \ 189 | self.type == other.type and \ 190 | self.constructor == other.constructor 191 | 192 | def __ne__(self, other): 193 | return not self.__eq__(other) 194 | 195 | def __repr__(self): 196 | return '%s -> %s' % (self.type.__repr__(plain=True), self.constructor.__repr__(plain=True)) 197 | 198 | 199 | class ASDLConstructor(object): 200 | def __init__(self, name, fields=None): 201 | self.name = name 202 | self.fields = [] 203 | if fields: 204 | self.fields = list(fields) 205 | 206 | def __getitem__(self, field_name): 207 | for field in self.fields: 208 | if field.name == field_name: return field 209 | 210 | raise KeyError 211 | 212 | def __hash__(self): 213 | h = hash(self.name) 214 | for field in self.fields: 215 | h ^= hash(field) 216 | 217 | return h 218 | 219 | def __eq__(self, other): 220 | return isinstance(other, ASDLConstructor) and \ 221 | self.name == other.name and \ 222 | self.fields == other.fields 223 | 224 | def __ne__(self, other): 225 | return not self.__eq__(other) 226 | 227 | def __repr__(self, plain=False): 228 | plain_repr = '%s(%s)' % (self.name, 229 | ', '.join(f.__repr__(plain=True) for f in self.fields)) 230 | if plain: return plain_repr 231 | else: return 'Constructor(%s)' % plain_repr 232 | 233 | 234 | class Field(object): 235 | def __init__(self, name, type, cardinality): 236 | self.name = name 237 | self.type = type 238 | 239 | assert cardinality in ['single', 'optional', 'multiple'] 240 | self.cardinality = cardinality 241 | 242 | def __hash__(self): 243 | h = hash(self.name) ^ hash(self.type) 244 | h ^= hash(self.cardinality) 245 | 246 | return h 247 | 248 | def __eq__(self, other): 249 | return isinstance(other, Field) and \ 250 | self.name == other.name and \ 251 | self.type == other.type and \ 252 | self.cardinality == other.cardinality 253 | 254 | def __ne__(self, other): 255 | return not self.__eq__(other) 256 | 257 | def __repr__(self, plain=False): 258 | plain_repr = '%s%s %s' % (self.type.__repr__(plain=True), 259 | Field.get_cardinality_repr(self.cardinality), 260 | self.name) 261 | if plain: return plain_repr 262 | else: return 'Field(%s)' % plain_repr 263 | 264 | @staticmethod 265 | def get_cardinality_repr(cardinality): 266 | 
return '' if cardinality == 'single' else '?' if cardinality == 'optional' else '*' 267 | 268 | 269 | class ASDLType(object): 270 | def __init__(self, type_name): 271 | self.name = type_name 272 | 273 | def __hash__(self): 274 | return hash(self.name) 275 | 276 | def __eq__(self, other): 277 | return isinstance(other, ASDLType) and self.name == other.name 278 | 279 | def __ne__(self, other): 280 | return not self.__eq__(other) 281 | 282 | def __repr__(self, plain=False): 283 | plain_repr = self.name 284 | if plain: return plain_repr 285 | else: return '%s(%s)' % (self.__class__.__name__, plain_repr) 286 | 287 | 288 | class ASDLCompositeType(ASDLType): 289 | pass 290 | 291 | 292 | class ASDLPrimitiveType(ASDLType): 293 | pass 294 | 295 | 296 | if __name__ == '__main__': 297 | asdl_desc = """ 298 | var, ent, num, var_type 299 | 300 | expr = Variable(var variable) 301 | | Entity(ent entity) 302 | | Number(num number) 303 | | Apply(pred predicate, expr* arguments) 304 | | Argmax(var variable, expr domain, expr body) 305 | | Argmin(var variable, expr domain, expr body) 306 | | Count(var variable, expr body) 307 | | Exists(var variable, expr body) 308 | | Lambda(var variable, var_type type, expr body) 309 | | Max(var variable, expr body) 310 | | Min(var variable, expr body) 311 | | Sum(var variable, expr domain, expr body) 312 | | The(var variable, expr body) 313 | | Not(expr argument) 314 | | And(expr* arguments) 315 | | Or(expr* arguments) 316 | | Compare(cmp_op op, expr left, expr right) 317 | 318 | cmp_op = GreaterThan | Equal | LessThan 319 | """ 320 | 321 | grammar = ASDLGrammar.from_text(asdl_desc) 322 | print(ASDLCompositeType('1') == ASDLPrimitiveType('1')) 323 | 324 | -------------------------------------------------------------------------------- /asdls/asdl_ast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from io import StringIO 3 | from asdls.asdl import * 4 | 5 | class AbstractSyntaxTree(object): 6 | def __init__(self, production, realized_fields=None): 7 | self.production = production 8 | 9 | # a child is essentially a *realized_field* 10 | self.fields = [] 11 | 12 | # record its parent field to which it's attached 13 | self.parent_field = None 14 | 15 | # used in decoding, record the time step when this node was created 16 | self.created_time = 0 17 | 18 | if realized_fields: 19 | assert len(realized_fields) == len(self.production.fields) 20 | 21 | for field in realized_fields: 22 | self.add_child(field) 23 | else: 24 | for field in self.production.fields: 25 | self.add_child(RealizedField(field)) 26 | 27 | def add_child(self, realized_field): 28 | # if isinstance(realized_field.value, AbstractSyntaxTree): 29 | # realized_field.value.parent = self 30 | self.fields.append(realized_field) 31 | realized_field.parent_node = self 32 | 33 | def __getitem__(self, field_name): 34 | for field in self.fields: 35 | if field.name == field_name: return field 36 | raise KeyError 37 | 38 | def sanity_check(self): 39 | if len(self.production.fields) != len(self.fields): 40 | raise ValueError('filed number must match') 41 | for field, realized_field in zip(self.production.fields, self.fields): 42 | assert field == realized_field.field 43 | for child in self.fields: 44 | for child_val in child.as_value_list: 45 | if isinstance(child_val, AbstractSyntaxTree): 46 | child_val.sanity_check() 47 | 48 | def copy(self): 49 | new_tree = AbstractSyntaxTree(self.production) 50 | new_tree.created_time = self.created_time 51 | for i, old_field in 
enumerate(self.fields): 52 | new_field = new_tree.fields[i] 53 | new_field._not_single_cardinality_finished = old_field._not_single_cardinality_finished 54 | if isinstance(old_field.type, ASDLCompositeType): 55 | for value in old_field.as_value_list: 56 | new_field.add_value(value.copy()) 57 | else: 58 | for value in old_field.as_value_list: 59 | new_field.add_value(value) 60 | 61 | return new_tree 62 | 63 | def to_string(self, sb=None): 64 | is_root = False 65 | if sb is None: 66 | is_root = True 67 | sb = StringIO() 68 | 69 | sb.write('(') 70 | sb.write(self.production.constructor.name) 71 | 72 | for field in self.fields: 73 | sb.write(' ') 74 | sb.write('(') 75 | sb.write(field.type.name) 76 | sb.write(Field.get_cardinality_repr(field.cardinality)) 77 | sb.write('-') 78 | sb.write(field.name) 79 | 80 | if field.value is not None: 81 | for val_node in field.as_value_list: 82 | sb.write(' ') 83 | if isinstance(field.type, ASDLCompositeType): 84 | val_node.to_string(sb) 85 | else: 86 | sb.write(str(val_node).replace(' ', '-SPACE-')) 87 | 88 | sb.write(')') # of field 89 | 90 | sb.write(')') # of node 91 | 92 | if is_root: 93 | return sb.getvalue() 94 | 95 | def __hash__(self): 96 | code = hash(self.production) 97 | for field in self.fields: 98 | code = code + 37 * hash(field) 99 | 100 | return code 101 | 102 | def __eq__(self, other): 103 | if not isinstance(other, self.__class__): 104 | return False 105 | 106 | if self.created_time != other.created_time: 107 | return False 108 | 109 | if self.production != other.production: 110 | return False 111 | 112 | if len(self.fields) != len(other.fields): 113 | return False 114 | 115 | for i in range(len(self.fields)): 116 | if self.fields[i] != other.fields[i]: return False 117 | 118 | return True 119 | 120 | def __ne__(self, other): 121 | return not self.__eq__(other) 122 | 123 | def __repr__(self): 124 | return repr(self.production) 125 | 126 | @property 127 | def size(self): 128 | node_num = 1 129 | for field in self.fields: 130 | for val in field.as_value_list: 131 | if isinstance(val, AbstractSyntaxTree): 132 | node_num += val.size 133 | else: node_num += 1 134 | 135 | return node_num 136 | 137 | 138 | class RealizedField(Field): 139 | """wrapper of field realized with values""" 140 | def __init__(self, field, value=None, parent=None): 141 | super(RealizedField, self).__init__(field.name, field.type, field.cardinality) 142 | 143 | # record its parent AST node 144 | self.parent_node = None 145 | 146 | # FIXME: hack, return the field as a property 147 | self.field = field 148 | 149 | # initialize value to correct type 150 | if self.cardinality == 'multiple': 151 | self.value = [] 152 | if value is not None: 153 | for child_node in value: 154 | self.add_value(child_node) 155 | else: 156 | self.value = None 157 | # note the value could be 0! 
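# (e.g. col_id / tab_id values can legitimately be 0, which is falsy), hence the explicit `is not None` check below rather than a plain truthiness test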
158 | if value is not None: self.add_value(value) 159 | 160 | # properties only used in decoding, record if the field is finished generating 161 | # when card in [optional, multiple] 162 | self._not_single_cardinality_finished = False 163 | 164 | def add_value(self, value): 165 | if isinstance(value, AbstractSyntaxTree): 166 | value.parent_field = self 167 | 168 | if self.cardinality == 'multiple': 169 | self.value.append(value) 170 | else: 171 | self.value = value 172 | 173 | @property 174 | def as_value_list(self): 175 | """get value as an iterable""" 176 | if self.cardinality == 'multiple': return self.value 177 | elif self.value is not None: return [self.value] 178 | else: return [] 179 | 180 | @property 181 | def finished(self): 182 | if self.cardinality == 'single': 183 | if self.value is None: return False 184 | else: return True 185 | elif self.cardinality == 'optional' and self.value is not None: 186 | return True 187 | else: 188 | if self._not_single_cardinality_finished: return True 189 | else: return False 190 | 191 | def set_finish(self): 192 | # assert self.cardinality in ('optional', 'multiple') 193 | self._not_single_cardinality_finished = True 194 | 195 | def __eq__(self, other): 196 | if super(RealizedField, self).__eq__(other): 197 | if type(other) == Field: return True # FIXME: hack, Field and RealizedField can compare! 198 | if self.value == other.value: return True 199 | else: return False 200 | else: return False 201 | -------------------------------------------------------------------------------- /asdls/decode_hypothesis.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from asdls.asdl import * 4 | from asdls.hypothesis import Hypothesis 5 | from asdls.transition_system import * 6 | 7 | 8 | class DecodeHypothesis(Hypothesis): 9 | def __init__(self): 10 | super(DecodeHypothesis, self).__init__() 11 | 12 | self.action_infos = [] 13 | self.code = None 14 | 15 | def clone_and_apply_action_info(self, action_info): 16 | action = action_info.action 17 | 18 | new_hyp = self.clone_and_apply_action(action) 19 | new_hyp.action_infos.append(action_info) 20 | 21 | return new_hyp 22 | 23 | def copy(self): 24 | new_hyp = DecodeHypothesis() 25 | if self.tree: 26 | new_hyp.tree = self.tree.copy() 27 | 28 | new_hyp.actions = list(self.actions) 29 | new_hyp.action_infos = list(self.action_infos) 30 | new_hyp.score = self.score 31 | new_hyp._value_buffer = list(self._value_buffer) 32 | new_hyp.t = self.t 33 | new_hyp.code = self.code 34 | 35 | new_hyp.update_frontier_info() 36 | 37 | return new_hyp 38 | -------------------------------------------------------------------------------- /asdls/hypothesis.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from asdls.asdl import * 4 | from asdls.asdl_ast import AbstractSyntaxTree 5 | from asdls.transition_system import * 6 | 7 | class Hypothesis(object): 8 | def __init__(self): 9 | self.tree = None 10 | self.actions = [] 11 | self.score = 0. 
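# frontier_node / frontier_field point at the next unfinished AST node and its unrealized field to expand;
# _value_buffer accumulates GenToken outputs for string-typed primitive fields until a stop signal is seen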
12 | self.frontier_node = None 13 | self.frontier_field = None 14 | self._value_buffer = [] 15 | 16 | # record the current time step 17 | self.t = 0 18 | 19 | def apply_action(self, action): 20 | if self.tree is None: # the first action 21 | assert isinstance(action, ApplyRuleAction), 'Invalid action [%s], only ApplyRule action is valid ' \ 22 | 'at the beginning of decoding' 23 | 24 | self.tree = AbstractSyntaxTree(action.production) 25 | self.update_frontier_info() 26 | elif self.frontier_node: 27 | if isinstance(self.frontier_field.type, ASDLCompositeType): 28 | if isinstance(action, ApplyRuleAction): 29 | field_value = AbstractSyntaxTree(action.production) 30 | field_value.created_time = self.t 31 | self.frontier_field.add_value(field_value) 32 | self.update_frontier_info() 33 | elif isinstance(action, ReduceAction): 34 | assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \ 35 | 'applied on field with multiple ' \ 36 | 'cardinality' 37 | self.frontier_field.set_finish() 38 | self.update_frontier_info() 39 | else: 40 | raise ValueError('Invalid action [%s] on field [%s]' % (action, self.frontier_field)) 41 | else: # fill in a primitive field 42 | if isinstance(action, GenTokenAction): 43 | # only field of type string requires termination signal 44 | end_primitive = False 45 | if self.frontier_field.type.name == 'string': 46 | if action.is_stop_signal(): 47 | self.frontier_field.add_value(' '.join(self._value_buffer)) 48 | self._value_buffer = [] 49 | 50 | end_primitive = True 51 | else: 52 | self._value_buffer.append(action.token) 53 | else: 54 | self.frontier_field.add_value(action.token) 55 | end_primitive = True 56 | 57 | if end_primitive and self.frontier_field.cardinality in ('single', 'optional'): 58 | self.frontier_field.set_finish() 59 | self.update_frontier_info() 60 | 61 | elif isinstance(action, ReduceAction): 62 | assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \ 63 | 'applied on field with multiple ' \ 64 | 'cardinality' 65 | self.frontier_field.set_finish() 66 | self.update_frontier_info() 67 | else: 68 | raise ValueError('Can only invoke GenToken or Reduce actions on primitive fields') 69 | 70 | self.t += 1 71 | self.actions.append(action) 72 | 73 | def update_frontier_info(self): 74 | def _find_frontier_node_and_field(tree_node): 75 | # return None if each field of this ast node is realized else unfinished ast node, unrealized field 76 | if tree_node: 77 | for field in tree_node.fields: 78 | # if it's an intermediate node, check its children 79 | if isinstance(field.type, ASDLCompositeType) and field.value: 80 | if field.cardinality in ('single', 'optional'): iter_values = [field.value] 81 | else: iter_values = field.value 82 | 83 | for child_node in iter_values: 84 | result = _find_frontier_node_and_field(child_node) 85 | if result: return result 86 | 87 | # now all its possible children are checked 88 | if not field.finished: 89 | return tree_node, field 90 | 91 | return None 92 | else: return None 93 | 94 | frontier_info = _find_frontier_node_and_field(self.tree) 95 | if frontier_info: 96 | self.frontier_node, self.frontier_field = frontier_info 97 | else: 98 | self.frontier_node, self.frontier_field = None, None 99 | 100 | def clone_and_apply_action(self, action): 101 | new_hyp = self.copy() 102 | new_hyp.apply_action(action) 103 | 104 | return new_hyp 105 | 106 | def copy(self): 107 | new_hyp = Hypothesis() 108 | if self.tree: 109 | new_hyp.tree = self.tree.copy() 110 | 
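# note: the AST was deep-copied above; the action list and value buffer are re-wrapped in fresh lists, but the Action objects they contain are shared with the original hypothesis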
111 | new_hyp.actions = list(self.actions) 112 | new_hyp.score = self.score 113 | new_hyp._value_buffer = list(self._value_buffer) 114 | new_hyp.t = self.t 115 | 116 | new_hyp.update_frontier_info() 117 | 118 | return new_hyp 119 | 120 | @property 121 | def completed(self): 122 | return self.tree and self.frontier_field is None 123 | -------------------------------------------------------------------------------- /asdls/sql/grammar/sql_asdl_v0.txt: -------------------------------------------------------------------------------- 1 | # Assumptions: 2 | # 1. sql is correct 3 | # 2. only table name has alias 4 | # 3. only one intersect/union/except 5 | 6 | # val: value(float/string)/sql(dict)/col_unit(tuple) 7 | # col_unit: (agg_id, col_id, isDistinct(bool)) 8 | # val_unit: (unit_op, col_unit1, col_unit2) 9 | # table_unit: (table_type, tab_id/sql) 10 | # cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2) 11 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 12 | # sql { 13 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 14 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 15 | # 'where': condition 16 | # 'groupBy': [col_unit1, col_unit2, ...] 17 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 18 | # 'having': condition 19 | # 'limit': None/integer 20 | # 'intersect': None/sql 21 | # 'except': None/sql 22 | # 'union': None/sql 23 | # } 24 | 25 | # CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 26 | # JOIN_KEYWORDS = ('join', 'on', 'as') 27 | # CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 28 | # UNIT_OPS = ('none', '-', '+', "*", '/') 29 | # AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 30 | # TABLE_TYPE = ('sql', 'table_unit') 31 | # COND_OPS = ('and', 'or') 32 | # SQL_OPS = ('intersect', 'union', 'except') 33 | # ORDER_OPS = ('desc', 'asc') 34 | 35 | ########################################################################## 36 | 37 | # 1. eliminate ? by enumerating different number of items 38 | # 2. from conds, distinct and value generation are also not considered 39 | # 3. for select items, we use val_unit instead of (agg, val_unit) 40 | # 4. for orderby items, we use col_unit instead of val_unit 41 | # 5. for groupby items, we use col_unit instead of col_id 42 | # 6. 
predict from clause first, to obtain an overview of the entire query graph 43 | 44 | col_id, tab_id 45 | 46 | sql = Intersect(sql_unit sql_unit, sql_unit sql_unit) 47 | | Union(sql_unit sql_unit, sql_unit sql_unit) 48 | | Except(sql_unit sql_unit, sql_unit sql_unit) 49 | | Single(sql_unit sql_unit) 50 | 51 | sql_unit = Complete(from from_clause, val_unit* select_clause, cond where_clause, group_by group_by_clause, order_by order_by_clause) 52 | | NoWhere(from from_clause, val_unit* select_clause, group_by group_by_clause, order_by order_by_clause) 53 | | NoGroupBy(from from_clause, val_unit* select_clause, cond where_clause, order_by order_by_clause) 54 | | NoOrderBy(from from_clause, val_unit* select_clause, cond where_clause, group_by group_by_clause) 55 | | OnlyWhere(from from_clause, val_unit* select_clause, cond where_clause) 56 | | OnlyGroupBy(from from_clause, val_unit* select_clause, group_by group_by_clause) 57 | | OnlyOrderBy(from from_clause, val_unit* select_clause, order_by order_by_clause) 58 | | Simple(from from_clause, val_unit* select_clause) 59 | 60 | from = FromTable(tab_id* tab_id_list) 61 | | FromSQL(sql from_sql) 62 | 63 | group_by = Having(col_unit* col_unit_list, cond having_clause) 64 | | NoHaving(col_unit* col_unit_list) 65 | 66 | order_by = Asc(col_unit* col_unit_list) 67 | | Desc(col_unit* col_unit_list) 68 | | AscLimit(col_unit* col_unit_list) 69 | | DescLimit(col_unit* col_unit_list) 70 | 71 | cond = And(cond left, cond right) 72 | | Or(cond left, cond right) 73 | | Between(val_unit val_unit) 74 | | Eq(val_unit val_unit) 75 | | Gt(val_unit val_unit) 76 | | Lt(val_unit val_unit) 77 | | Ge(val_unit val_unit) 78 | | Le(val_unit val_unit) 79 | | Neq(val_unit val_unit) 80 | | Like(val_unit val_unit) 81 | | NotLike(val_unit val_unit) 82 | | BetweenSQL(val_unit val_unit, sql cond_sql) 83 | | EqSQL(val_unit val_unit, sql cond_sql) 84 | | GtSQL(val_unit val_unit, sql cond_sql) 85 | | LtSQL(val_unit val_unit, sql cond_sql) 86 | | GeSQL(val_unit val_unit, sql cond_sql) 87 | | LeSQL(val_unit val_unit, sql cond_sql) 88 | | NeqSQL(val_unit val_unit, sql cond_sql) 89 | | InSQL(val_unit val_unit, sql cond_sql) 90 | | NotInSQL(val_unit val_unit, sql cond_sql) 91 | 92 | val_unit = Unary(col_unit col_unit) 93 | | Minus(col_unit col_unit, col_unit col_unit) 94 | | Plus(col_unit col_unit, col_unit col_unit) 95 | | Times(col_unit col_unit, col_unit col_unit) 96 | | Divide(col_unit col_unit, col_unit col_unit) 97 | 98 | col_unit = None(col_id col_id) 99 | | Max(col_id col_id) 100 | | Min(col_id col_id) 101 | | Count(col_id col_id) 102 | | Sum(col_id col_id) 103 | | Avg(col_id col_id) -------------------------------------------------------------------------------- /asdls/sql/grammar/sql_asdl_v1.txt: -------------------------------------------------------------------------------- 1 | # Assumptions: 2 | # 1. sql is correct 3 | # 2. only table name has alias 4 | # 3. only one intersect/union/except 5 | 6 | # val: value(float/string)/sql(dict)/col_unit(tuple) 7 | # col_unit: (agg_id, col_id, isDistinct(bool)) 8 | # val_unit: (unit_op, col_unit1, col_unit2) 9 | # table_unit: (table_type, tab_id/sql) 10 | # cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2) 11 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 12 | # sql { 13 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 14 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 15 | # 'where': condition 16 | # 'groupBy': [col_unit1, col_unit2, ...] 
17 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 18 | # 'having': condition 19 | # 'limit': None/integer 20 | # 'intersect': None/sql 21 | # 'except': None/sql 22 | # 'union': None/sql 23 | # } 24 | 25 | # CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 26 | # JOIN_KEYWORDS = ('join', 'on', 'as') 27 | # CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 28 | # UNIT_OPS = ('none', '-', '+', "*", '/') 29 | # AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 30 | # TABLE_TYPE = ('sql', 'table_unit') 31 | # COND_OPS = ('and', 'or') 32 | # SQL_OPS = ('intersect', 'union', 'except') 33 | # ORDER_OPS = ('desc', 'asc') 34 | 35 | ########################################################################## 36 | 37 | # 1. eliminate * by enumerating different number of items 38 | # 2. from conds, distinct and value generation are also not considered 39 | # 3. for select items, we use val_unit instead of (agg, val_unit) 40 | # 4. for orderby items, we use col_unit instead of val_unit 41 | # 5. for groupby items, we use col_unit instead of col_id 42 | # 6. predict from clause first, to obtain an overview of the entire query graph 43 | 44 | col_id, tab_id 45 | 46 | sql = Intersect(sql_unit sql_unit, sql_unit sql_unit) 47 | | Union(sql_unit sql_unit, sql_unit sql_unit) 48 | | Except(sql_unit sql_unit, sql_unit sql_unit) 49 | | Single(sql_unit sql_unit) 50 | 51 | sql_unit = SQL(from from_clause, select select_clause, cond? where_clause, group_by? group_by_clause, order_by? order_by_clause) 52 | 53 | select = SelectOne(val_unit val_unit) 54 | | SelectTwo(val_unit val_unit, val_unit val_unit) 55 | | SelectThree(val_unit val_unit, val_unit val_unit, val_unit val_unit) 56 | | SelectFour(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit) 57 | | SelectFive(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit) 58 | 59 | from = FromOneTable(tab_id tab_id) 60 | | FromTwoTable(tab_id tab_id, tab_id tab_id) 61 | | FromThreeTable(tab_id tab_id, tab_id tab_id, tab_id tab_id) 62 | | FromFourTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id) 63 | | FromFiveTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id) 64 | | FromSixTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id) 65 | | FromSQL(sql from_sql) 66 | 67 | group_by = GroupByOne(col_unit col_unit, cond? having_clause) 68 | | GroupByTwo(col_unit col_unit, col_unit col_unit, cond? 
having_clause) 69 | 70 | order_by = OneAsc(col_unit col_unit) 71 | | OneDesc(col_unit col_unit) 72 | | OneAscLimit(col_unit col_unit) 73 | | OneDescLimit(col_unit col_unit) 74 | | TwoAsc(col_unit col_unit, col_unit col_unit) 75 | | TwoDesc(col_unit col_unit, col_unit col_unit) 76 | | TwoAscLimit(col_unit col_unit, col_unit col_unit) 77 | | TwoDescLimit(col_unit col_unit, col_unit col_unit) 78 | 79 | cond = And(cond left, cond right) 80 | | Or(cond left, cond right) 81 | | Between(val_unit val_unit) 82 | | Eq(val_unit val_unit) 83 | | Gt(val_unit val_unit) 84 | | Lt(val_unit val_unit) 85 | | Ge(val_unit val_unit) 86 | | Le(val_unit val_unit) 87 | | Neq(val_unit val_unit) 88 | | Like(val_unit val_unit) 89 | | NotLike(val_unit val_unit) 90 | | BetweenSQL(val_unit val_unit, sql cond_sql) 91 | | EqSQL(val_unit val_unit, sql cond_sql) 92 | | GtSQL(val_unit val_unit, sql cond_sql) 93 | | LtSQL(val_unit val_unit, sql cond_sql) 94 | | GeSQL(val_unit val_unit, sql cond_sql) 95 | | LeSQL(val_unit val_unit, sql cond_sql) 96 | | NeqSQL(val_unit val_unit, sql cond_sql) 97 | | InSQL(val_unit val_unit, sql cond_sql) 98 | | NotInSQL(val_unit val_unit, sql cond_sql) 99 | 100 | val_unit = Unary(col_unit col_unit) 101 | | Minus(col_unit col_unit, col_unit col_unit) 102 | | Plus(col_unit col_unit, col_unit col_unit) 103 | | Times(col_unit col_unit, col_unit col_unit) 104 | | Divide(col_unit col_unit, col_unit col_unit) 105 | 106 | col_unit = None(col_id col_id) 107 | | Max(col_id col_id) 108 | | Min(col_id col_id) 109 | | Count(col_id col_id) 110 | | Sum(col_id col_id) 111 | | Avg(col_id col_id) 112 | -------------------------------------------------------------------------------- /asdls/sql/grammar/sql_asdl_v2.txt: -------------------------------------------------------------------------------- 1 | # Assumptions: 2 | # 1. sql is correct 3 | # 2. only table name has alias 4 | # 3. only one intersect/union/except 5 | 6 | # val: value(float/string)/sql(dict)/col_unit(tuple) 7 | # col_unit: (agg_id, col_id, isDistinct(bool)) 8 | # val_unit: (unit_op, col_unit1, col_unit2) 9 | # table_unit: (table_type, tab_id/sql) 10 | # cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2) 11 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 12 | # sql { 13 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 14 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 15 | # 'where': condition 16 | # 'groupBy': [col_unit1, col_unit2, ...] 17 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 18 | # 'having': condition 19 | # 'limit': None/integer 20 | # 'intersect': None/sql 21 | # 'except': None/sql 22 | # 'union': None/sql 23 | # } 24 | 25 | # CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 26 | # JOIN_KEYWORDS = ('join', 'on', 'as') 27 | # CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 28 | # UNIT_OPS = ('none', '-', '+', "*", '/') 29 | # AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 30 | # TABLE_TYPE = ('sql', 'table_unit') 31 | # COND_OPS = ('and', 'or') 32 | # SQL_OPS = ('intersect', 'union', 'except') 33 | # ORDER_OPS = ('desc', 'asc') 34 | 35 | ########################################################################## 36 | 37 | # 1. eliminate both ? and * cardinality by enumerating different number of items 38 | # 2. from conds, distinct and value generation are also not considered 39 | # 3. 
for select items, we use val_unit instead of (agg, val_unit) 40 | # 4. for orderby items, we use col_unit instead of val_unit 41 | # 5. for groupby items, we use col_unit instead of col_id 42 | # 6. predict from clause first, to obtain an overview of the entire query graph 43 | 44 | col_id, tab_id 45 | 46 | sql = Intersect(sql_unit sql_unit, sql_unit sql_unit) 47 | | Union(sql_unit sql_unit, sql_unit sql_unit) 48 | | Except(sql_unit sql_unit, sql_unit sql_unit) 49 | | Single(sql_unit sql_unit) 50 | 51 | sql_unit = Complete(from from_clause, select select_clause, cond where_clause, group_by group_by_clause, order_by order_by_clause) 52 | | NoWhere(from from_clause, select select_clause, group_by group_by_clause, order_by order_by_clause) 53 | | NoGroupBy(from from_clause, select select_clause, cond where_clause, order_by order_by_clause) 54 | | NoOrderBy(from from_clause, select select_clause, cond where_clause, group_by group_by_clause) 55 | | OnlyWhere(from from_clause, select select_clause, cond where_clause) 56 | | OnlyGroupBy(from from_clause, select select_clause, group_by group_by_clause) 57 | | OnlyOrderBy(from from_clause, select select_clause, order_by order_by_clause) 58 | | Simple(from from_clause, select select_clause) 59 | 60 | select = SelectOne(val_unit val_unit) 61 | | SelectTwo(val_unit val_unit, val_unit val_unit) 62 | | SelectThree(val_unit val_unit, val_unit val_unit, val_unit val_unit) 63 | | SelectFour(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit) 64 | | SelectFive(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit) 65 | 66 | from = FromOneTable(tab_id tab_id) 67 | | FromTwoTable(tab_id tab_id, tab_id tab_id) 68 | | FromThreeTable(tab_id tab_id, tab_id tab_id, tab_id tab_id) 69 | | FromFourTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id) 70 | | FromFiveTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id) 71 | | FromSixTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id) 72 | | FromSQL(sql from_sql) 73 | 74 | group_by = OneNoHaving(col_unit col_unit) 75 | | TwoNoHaving(col_unit col_unit, col_unit col_unit) 76 | | OneHaving(col_unit col_unit, cond having_clause) 77 | | TwoHaving(col_unit col_unit, col_unit col_unit, cond having_clause) 78 | 79 | order_by = OneAsc(col_unit col_unit) 80 | | OneDesc(col_unit col_unit) 81 | | OneAscLimit(col_unit col_unit) 82 | | OneDescLimit(col_unit col_unit) 83 | | TwoAsc(col_unit col_unit, col_unit col_unit) 84 | | TwoDesc(col_unit col_unit, col_unit col_unit) 85 | | TwoAscLimit(col_unit col_unit, col_unit col_unit) 86 | | TwoDescLimit(col_unit col_unit, col_unit col_unit) 87 | 88 | cond = And(cond left, cond right) 89 | | Or(cond left, cond right) 90 | | Between(val_unit val_unit) 91 | | Eq(val_unit val_unit) 92 | | Gt(val_unit val_unit) 93 | | Lt(val_unit val_unit) 94 | | Ge(val_unit val_unit) 95 | | Le(val_unit val_unit) 96 | | Neq(val_unit val_unit) 97 | | Like(val_unit val_unit) 98 | | NotLike(val_unit val_unit) 99 | | BetweenSQL(val_unit val_unit, sql cond_sql) 100 | | EqSQL(val_unit val_unit, sql cond_sql) 101 | | GtSQL(val_unit val_unit, sql cond_sql) 102 | | LtSQL(val_unit val_unit, sql cond_sql) 103 | | GeSQL(val_unit val_unit, sql cond_sql) 104 | | LeSQL(val_unit val_unit, sql cond_sql) 105 | | NeqSQL(val_unit val_unit, sql cond_sql) 106 | | InSQL(val_unit val_unit, sql cond_sql) 107 | | NotInSQL(val_unit val_unit, sql cond_sql) 108 | 109 | val_unit = Unary(col_unit 
col_unit) 110 | | Minus(col_unit col_unit, col_unit col_unit) 111 | | Plus(col_unit col_unit, col_unit col_unit) 112 | | Times(col_unit col_unit, col_unit col_unit) 113 | | Divide(col_unit col_unit, col_unit col_unit) 114 | 115 | col_unit = None(col_id col_id) 116 | | Max(col_id col_id) 117 | | Min(col_id col_id) 118 | | Count(col_id col_id) 119 | | Sum(col_id col_id) 120 | | Avg(col_id col_id) 121 | -------------------------------------------------------------------------------- /asdls/sql/parser/parser_base.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | 3 | from asdls.asdl import ASDLGrammar 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class Parser(): 7 | """ Parse a sql dict into AbstractSyntaxTree object according to specified grammar rules 8 | Some common methods are implemented in this parent class. 9 | """ 10 | def __init__(self, grammar: ASDLGrammar): 11 | super(Parser, self).__init__() 12 | self.grammar = grammar 13 | 14 | @classmethod 15 | def from_grammar(cls, grammar: ASDLGrammar): 16 | grammar_name = grammar._grammar_name 17 | if 'v0' in grammar_name: 18 | from asdls.sql.parser.parser_v0 import ParserV0 19 | return ParserV0(grammar) 20 | elif 'v1' in grammar_name: 21 | from asdls.sql.parser.parser_v1 import ParserV1 22 | return ParserV1(grammar) 23 | elif 'v2' in grammar_name: 24 | from asdls.sql.parser.parser_v2 import ParserV2 25 | return ParserV2(grammar) 26 | else: 27 | raise ValueError('Not recognized grammar name %s' % (grammar_name)) 28 | 29 | def parse(self, sql_json: dict): 30 | """ sql_json is exactly the 'sql' field of each data sample 31 | return AbstractSyntaxTree of sql 32 | """ 33 | try: 34 | ast_node = self.parse_sql(sql_json) 35 | return ast_node 36 | except Exception as e: 37 | print('Something Error happened while parsing:', e) 38 | # if fail to parse, just return select * from table(id=0) 39 | error_sql = { 40 | "select": [False, [(0, [0, [0, 0, False], None])]], 41 | "from": {'table_units': [('table_unit', 0)], 'conds': []}, 42 | "where": [], "groupBy": [], "orderBy": [], "having": [], "limit": None, 43 | "intersect": [], "union": [], "except": [] 44 | } 45 | ast_node = self.parse_sql(error_sql) 46 | return ast_node 47 | 48 | def parse_sql(self, sql: dict): 49 | """ Determine whether sql has intersect/union/except, 50 | at most one in the current dict 51 | """ 52 | for choice in ['intersect', 'union', 'except']: 53 | if sql[choice]: 54 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(choice.title())) 55 | nested_sql = sql[choice] 56 | sql_field1, sql_field2 = ast_node.fields 57 | sql_field1.add_value(self.parse_sql_unit(sql)) 58 | sql_field2.add_value(self.parse_sql_unit(nested_sql)) 59 | return ast_node 60 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Single')) 61 | ast_node.fields[0].add_value(self.parse_sql_unit(sql)) 62 | return ast_node 63 | 64 | def parse_sql_unit(self, sql: dict): 65 | """ Parse a single sql unit, determine the existence of different clauses 66 | """ 67 | sql_ctr = ['Complete', 'NoWhere', 'NoGroupBy', 'NoOrderBy', 'OnlyWhere', 'OnlyGroupBy', 'OnlyOrderBy', 'Simple'] 68 | where_field, groupby_field, orderby_field = [None] * 3 69 | if sql['where'] and sql['groupBy'] and sql['orderBy']: 70 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[0])) 71 | from_field, select_field, where_field, groupby_field, orderby_field = ast_node.fields 72 | elif sql['groupBy'] and sql['orderBy']: 73 | 
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[1])) 74 | from_field, select_field, groupby_field, orderby_field = ast_node.fields 75 | elif sql['where'] and sql['orderBy']: 76 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[2])) 77 | from_field, select_field, where_field, orderby_field = ast_node.fields 78 | elif sql['where'] and sql['groupBy']: 79 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[3])) 80 | from_field, select_field, where_field, groupby_field = ast_node.fields 81 | elif sql['where']: 82 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[4])) 83 | from_field, select_field, where_field = ast_node.fields 84 | elif sql['groupBy']: 85 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[5])) 86 | from_field, select_field, groupby_field = ast_node.fields 87 | elif sql['orderBy']: 88 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[6])) 89 | from_field, select_field, orderby_field = ast_node.fields 90 | else: 91 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[7])) 92 | from_field, select_field = ast_node.fields 93 | self.parse_from(sql['from'], from_field) 94 | self.parse_select(sql['select'], select_field) 95 | if sql['where']: 96 | self.parse_where(sql['where'], where_field) 97 | if sql['groupBy']: # if having clause is not empty, groupBy must exist 98 | self.parse_groupby(sql['groupBy'], sql['having'], groupby_field) 99 | if sql['orderBy']: # if limit is not None, orderBY is not empty 100 | self.parse_orderby(sql['orderBy'], sql['limit'], orderby_field) 101 | return ast_node 102 | 103 | def parse_select(self, select_clause: list, select_field: RealizedField): 104 | raise NotImplementedError 105 | 106 | def parse_from(self, from_clause: dict, from_field: RealizedField): 107 | raise NotImplementedError 108 | 109 | def parse_where(self, where_clause: list, where_field: RealizedField): 110 | where_field.add_value(self.parse_conds(where_clause)) 111 | 112 | def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField): 113 | raise NotImplementedError 114 | 115 | def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField): 116 | raise NotImplementedError 117 | 118 | def parse_conds(self, conds: list): 119 | assert len(conds) > 0 120 | and_or = (len(conds) - 1) // 2 121 | root_node, left_field, right_field = [None] * 3 122 | for i in reversed(range(and_or)): 123 | and_or_idx = 2 * i + 1 124 | conj = conds[and_or_idx] 125 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(conj.title())) 126 | if root_node is None: 127 | root_node = ast_node 128 | if left_field is not None: 129 | left_field.add_value(ast_node) 130 | left_field, right_field = ast_node.fields 131 | right_field.add_value(self.parse_cond(conds[2 * (i + 1)])) 132 | if left_field is None: 133 | root_node = self.parse_cond(conds[0]) 134 | else: 135 | left_field.add_value(self.parse_cond(conds[0])) 136 | return root_node 137 | 138 | def parse_cond(self, cond: list): 139 | not_op, cmp_op, val_unit, val1, val2 = cond 140 | not_op = '^' if not_op else '' 141 | sql_val = 'sql' if type(val1) == dict else '' 142 | op_list = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 143 | cmp_op = not_op + op_list[cmp_op] + sql_val 144 | op_dict = { 145 | 'between': 'Between', '=': 'Eq', '>': 'Gt', '<': 'Lt', '>=': 'Ge', '<=': 'Le', '!=': 'Neq', 146 | 'insql': 
'InSQL', 'like': 'Like', '^insql': 'NotInSQL', '^like': 'NotLike', 'betweensql': 'BetweenSQL', '=sql': 'EqSQL', 147 | '>sql': 'GtSQL', '<sql': 'LtSQL', '>=sql': 'GeSQL', '<=sql': 'LeSQL', '!=sql': 'NeqSQL' 148 | } 149 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(op_dict[cmp_op])) 150 | val_unit_field = ast_node.fields[0] 151 | val_unit_field.add_value(self.parse_val_unit(val_unit)) 152 | if len(ast_node.fields) == 2: 153 | val_field = ast_node.fields[1] 154 | val_field.add_value(self.parse_sql(val1)) 155 | return ast_node 156 | 157 | def parse_val_unit(self, val_unit: list): 158 | unit_op, col_unit1, col_unit2 = val_unit 159 | unit_op_list = ['Unary', 'Minus', 'Plus', 'Times', 'Divide'] 160 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(unit_op_list[unit_op])) 161 | if unit_op == 0: 162 | ast_node.fields[0].add_value(self.parse_col_unit(col_unit1)) 163 | else: 164 | # ast_node.fields[0].add_value(int(col_unit1[1])) 165 | # ast_node.fields[1].add_value(int(col_unit2[1])) 166 | ast_node.fields[0].add_value(self.parse_col_unit(col_unit1)) 167 | ast_node.fields[1].add_value(self.parse_col_unit(col_unit2)) 168 | return ast_node 169 | 170 | def parse_col_unit(self, col_unit: list): 171 | agg_op, col_id, distinct_flag = col_unit 172 | agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg'] 173 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(agg_op_list[agg_op])) 174 | ast_node.fields[0].add_value(int(col_id)) 175 | return ast_node 176 | -------------------------------------------------------------------------------- /asdls/sql/parser/parser_v0.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from asdls.sql.parser.parser_base import Parser 3 | from asdls.asdl import ASDLGrammar 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class ParserV0(Parser): 7 | """ In this version, we eliminate all cardinality ?
and restrict that * must have at least one item 8 | """ 9 | def parse_select(self, select_clause: list, select_field: RealizedField): 10 | """ 11 | ignore cases agg(col_id1 op col_id2) and agg(col_id1) op agg(col_id2) 12 | """ 13 | select_clause = select_clause[1] # list of (agg, val_unit) 14 | unit_op_list = ['Unary', 'Minus', 'Plus', 'Times', 'Divide'] 15 | agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg'] 16 | for agg, val_unit in select_clause: 17 | if agg != 0: # agg col_id 18 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary')) 19 | col_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(agg_op_list[agg])) 20 | col_node.fields[0].add_value(int(val_unit[1][1])) 21 | ast_node.fields[0].add_value(col_node) 22 | else: # binary_op col_id1 col_id2 23 | ast_node = self.parse_val_unit(val_unit) 24 | select_field.add_value(ast_node) 25 | 26 | def parse_from(self, from_clause: dict, from_field: RealizedField): 27 | """ 28 | Ignore from conditions, since it is not evaluated in evaluation script 29 | """ 30 | table_units = from_clause['table_units'] 31 | t = table_units[0][0] 32 | if t == 'table_unit': 33 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromTable')) 34 | tables_field = ast_node.fields[0] 35 | for _, v in table_units: 36 | tables_field.add_value(int(v)) 37 | else: 38 | assert t == 'sql' 39 | v = table_units[0][1] 40 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL')) 41 | ast_node.fields[0].add_value(self.parse_sql(v)) 42 | from_field.add_value(ast_node) 43 | 44 | def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField): 45 | col_ids = [] 46 | for col_unit in groupby_clause: 47 | col_ids.append(col_unit[1]) # agg is None and isDistinct False 48 | if having_clause: 49 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Having')) 50 | col_units_field, having_fields = ast_node.fields 51 | having_fields.add_value(self.parse_conds(having_clause)) 52 | else: 53 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('NoHaving')) 54 | col_units_field = ast_node.fields[0] 55 | for col_unit in groupby_clause: 56 | col_units_field.add_value(self.parse_col_unit(col_unit)) 57 | groupby_field.add_value(ast_node) 58 | 59 | def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField): 60 | if limit is None: 61 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Asc')) if orderby_clause[0] == 'asc' \ 62 | else AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Desc')) 63 | else: 64 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('AscLimit')) if orderby_clause[0] == 'asc' \ 65 | else AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('DescLimit')) 66 | col_units_field = ast_node.fields[0] 67 | for val_unit in orderby_clause[1]: 68 | col_units_field.add_value(self.parse_col_unit(val_unit[1])) 69 | orderby_field.add_value(ast_node) -------------------------------------------------------------------------------- /asdls/sql/parser/parser_v1.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from asdls.sql.parser.parser_base import Parser 3 | from asdls.asdl import ASDLGrammar 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class ParserV1(Parser): 7 | """ In this version, we eliminate all cardinality * and use ? 
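(optional clauses such as where/groupBy/orderBy/having keep their '?' cardinality on the SQL and GroupBy constructors, while item lists are enumerated by length, e.g. SelectOne..SelectFive and FromOneTable..FromSixTable)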
8 | """ 9 | def parse_sql_unit(self, sql: dict): 10 | """ Parse a single sql unit, determine the existence of different clauses 11 | """ 12 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('SQL')) 13 | from_field, select_field, where_field, groupby_field, orderby_field = ast_node.fields 14 | self.parse_from(sql['from'], from_field) 15 | self.parse_select(sql['select'], select_field) 16 | if sql['where']: 17 | self.parse_where(sql['where'], where_field) 18 | if sql['groupBy']: # if having clause is not empty, groupBy must exist 19 | self.parse_groupby(sql['groupBy'], sql['having'], groupby_field) 20 | if sql['orderBy']: # if limit is not None, orderBY is not empty 21 | self.parse_orderby(sql['orderBy'], sql['limit'], orderby_field) 22 | return ast_node 23 | 24 | def parse_select(self, select_clause: list, select_field: RealizedField): 25 | select_clause = select_clause[1] # list of (agg, val_unit), ignore distinct flag 26 | select_num = min(5, len(select_clause)) 27 | select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive'] 28 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1])) 29 | for i, (agg, val_unit) in enumerate(select_clause): 30 | if i >= 5: break 31 | if agg != 0: # MAX/MIN/COUNT/SUM/AVG 32 | val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary')) 33 | col_unit = [agg] + val_unit[1][1:] 34 | val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit)) 35 | else: 36 | val_unit_ast = self.parse_val_unit(val_unit) 37 | ast_node.fields[i].add_value(val_unit_ast) 38 | select_field.add_value(ast_node) 39 | 40 | def parse_from(self, from_clause: dict, from_field: RealizedField): 41 | """ Ignore from conditions, since it is not evaluated in evaluation script 42 | """ 43 | table_units = from_clause['table_units'] 44 | t = table_units[0][0] 45 | if t == 'table_unit': 46 | table_num = min(6, len(table_units)) 47 | table_ctr = ['FromOneTable', 'FromTwoTable', 'FromThreeTable', 'FromFourTable', 'FromFiveTable', 'FromSixTable'] 48 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(table_ctr[table_num - 1])) 49 | for i, (_, tab_id) in enumerate(table_units): 50 | if i >= 6: break 51 | ast_node.fields[i].add_value(int(tab_id)) 52 | else: 53 | assert t == 'sql' 54 | v = table_units[0][1] 55 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL')) 56 | ast_node.fields[0].add_value(self.parse_sql(v)) 57 | from_field.add_value(ast_node) 58 | 59 | def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField): 60 | groupby_ctr = ['GroupByOne', 'GroupByTwo'] 61 | groupby_num = min(2, len(groupby_clause)) - 1 62 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num])) 63 | if having_clause: 64 | having_field = ast_node.fields[-1] 65 | having_field.add_value(self.parse_conds(having_clause)) 66 | for i, col_unit in enumerate(groupby_clause): 67 | if i >= 2: break 68 | # ast_node.fields[i].add_value(int(col_unit[1])) 69 | ast_node.fields[i].add_value(self.parse_col_unit(col_unit)) 70 | groupby_field.add_value(ast_node) 71 | 72 | def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField): 73 | orderby_num = min(2, len(orderby_clause[1])) 74 | num_str = 'One' if orderby_num == 1 else 'Two' 75 | order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc' 76 | limit_str = 'Limit' if limit else '' # e.g. 
OneAsc, TwoDescLimit 77 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str)) 78 | for i, val_unit in enumerate(orderby_clause[1]): 79 | if i >= 2: break 80 | col_unit = val_unit[1] 81 | ast_node.fields[i].add_value(self.parse_col_unit(col_unit)) 82 | # ast_node.fields[i].add_value(self.parse_val_unit(val_unit)) 83 | orderby_field.add_value(ast_node) -------------------------------------------------------------------------------- /asdls/sql/parser/parser_v2.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from asdls.sql.parser.parser_base import Parser 3 | from asdls.asdl import ASDLGrammar 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class ParserV2(Parser): 7 | """ In this version, we remove all cardinality ? or * 8 | by enumerating all different lengths of item list, such as SelectOne, SelectTwo 9 | """ 10 | def parse_select(self, select_clause: list, select_field: RealizedField): 11 | select_clause = select_clause[1] # list of (agg, val_unit), ignore distinct flag 12 | select_num = min(5, len(select_clause)) 13 | select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive'] 14 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1])) 15 | for i, (agg, val_unit) in enumerate(select_clause): 16 | if i >= 5: break 17 | if agg != 0: # MAX/MIN/COUNT/SUM/AVG 18 | val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary')) 19 | col_unit = [agg] + val_unit[1][1:] 20 | val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit)) 21 | else: 22 | val_unit_ast = self.parse_val_unit(val_unit) 23 | ast_node.fields[i].add_value(val_unit_ast) 24 | select_field.add_value(ast_node) 25 | 26 | def parse_from(self, from_clause: dict, from_field: RealizedField): 27 | """ Ignore from conditions, since it is not evaluated in evaluation script 28 | """ 29 | table_units = from_clause['table_units'] 30 | t = table_units[0][0] 31 | if t == 'table_unit': 32 | table_num = min(6, len(table_units)) 33 | table_ctr = ['FromOneTable', 'FromTwoTable', 'FromThreeTable', 'FromFourTable', 'FromFiveTable', 'FromSixTable'] 34 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(table_ctr[table_num - 1])) 35 | for i, (_, tab_id) in enumerate(table_units): 36 | if i >= 6: break 37 | ast_node.fields[i].add_value(int(tab_id)) 38 | else: 39 | assert t == 'sql' 40 | v = table_units[0][1] 41 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL')) 42 | ast_node.fields[0].add_value(self.parse_sql(v)) 43 | from_field.add_value(ast_node) 44 | 45 | def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField): 46 | groupby_ctr = ['OneNoHaving', 'TwoNoHaving', 'OneHaving', 'TwoHaving'] 47 | groupby_num = min(2, len(groupby_clause)) 48 | if having_clause: 49 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num + 1])) 50 | having_field = ast_node.fields[-1] 51 | having_field.add_value(self.parse_conds(having_clause)) 52 | else: 53 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num - 1])) 54 | for i, col_unit in enumerate(groupby_clause): 55 | if i >= 2: break 56 | # ast_node.fields[i].add_value(int(col_unit[1])) 57 | ast_node.fields[i].add_value(self.parse_col_unit(col_unit)) 58 | groupby_field.add_value(ast_node) 59 | 60 | def parse_orderby(self, orderby_clause: list, limit: int, 
orderby_field: RealizedField): 61 | orderby_num = min(2, len(orderby_clause[1])) 62 | num_str = 'One' if orderby_num == 1 else 'Two' 63 | order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc' 64 | limit_str = 'Limit' if limit else '' # e.g. OneAsc, TwoDescLimit 65 | ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str)) 66 | for i, val_unit in enumerate(orderby_clause[1]): 67 | if i >= 2: break 68 | col_unit = val_unit[1] 69 | ast_node.fields[i].add_value(self.parse_col_unit(col_unit)) 70 | # ast_node.fields[i].add_value(self.parse_val_unit(val_unit)) 71 | orderby_field.add_value(ast_node) -------------------------------------------------------------------------------- /asdls/sql/sql_transition_system.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json, os, sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 4 | from asdls.sql.parser.parser_base import Parser 5 | from asdls.sql.unparser.unparser_base import UnParser 6 | from asdls.asdl import ASDLGrammar 7 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 8 | from asdls.transition_system import GenTokenAction, TransitionSystem, ApplyRuleAction, ReduceAction 9 | 10 | class SelectColumnAction(GenTokenAction): 11 | def __init__(self, column_id): 12 | super(SelectColumnAction, self).__init__(column_id) 13 | 14 | @property 15 | def column_id(self): 16 | return self.token 17 | 18 | def __repr__(self): 19 | return 'SelectColumnAction[id=%s]' % self.column_id 20 | 21 | class SelectTableAction(GenTokenAction): 22 | def __init__(self, table_id): 23 | super(SelectTableAction, self).__init__(table_id) 24 | 25 | @property 26 | def table_id(self): 27 | return self.token 28 | 29 | def __repr__(self): 30 | return 'SelectTableAction[id=%s]' % self.table_id 31 | 32 | class SQLTransitionSystem(TransitionSystem): 33 | 34 | def __init__(self, grammar): 35 | self.grammar = grammar 36 | self.parser = Parser.from_grammar(self.grammar) 37 | self.unparser = UnParser.from_grammar(self.grammar) 38 | 39 | def ast_to_surface_code(self, asdl_ast, table, *args, **kargs): 40 | return self.unparser.unparse(asdl_ast, table, *args, **kargs) 41 | 42 | def compare_ast(self, hyp_ast, ref_ast): 43 | raise NotImplementedError 44 | 45 | def tokenize_code(self, code, mode): 46 | raise NotImplementedError 47 | 48 | def surface_code_to_ast(self, code): 49 | return self.parser.parse(code) 50 | 51 | def get_valid_continuation_types(self, hyp): 52 | if hyp.tree: 53 | if self.grammar.is_composite_type(hyp.frontier_field.type): 54 | if hyp.frontier_field.cardinality == 'single': 55 | return ApplyRuleAction, 56 | elif hyp.frontier_field.cardinality == 'multiple': 57 | if len(hyp.frontier_field.value) == 0: 58 | return ApplyRuleAction, 59 | else: 60 | return ApplyRuleAction, ReduceAction 61 | else: 62 | return ApplyRuleAction, ReduceAction 63 | elif hyp.frontier_field.type.name == 'col_id': 64 | if hyp.frontier_field.cardinality == 'single': 65 | return SelectColumnAction, 66 | elif hyp.frontier_field.cardinality == 'multiple': 67 | if len(hyp.frontier_field.value) == 0: 68 | return SelectColumnAction, 69 | else: 70 | return SelectColumnAction, ReduceAction 71 | else: # optional, not used 72 | return SelectColumnAction, ReduceAction 73 | elif hyp.frontier_field.type.name == 'tab_id': 74 | if hyp.frontier_field.cardinality == 'single': 75 | return SelectTableAction, 76 | elif hyp.frontier_field.cardinality == 'multiple': 77 | if 
len(hyp.frontier_field.value) == 0: 78 | return SelectTableAction, 79 | else: 80 | return SelectTableAction, ReduceAction 81 | else: # optional, not used 82 | return SelectTableAction, ReduceAction 83 | else: # not used now 84 | return GenTokenAction, 85 | else: 86 | return ApplyRuleAction, 87 | 88 | def get_primitive_field_actions(self, realized_field): 89 | if realized_field.type.name == 'col_id': 90 | if realized_field.cardinality == 'multiple': 91 | action_list = [] 92 | for idx in realized_field.value: 93 | action_list.append(SelectColumnAction(int(idx))) 94 | return action_list 95 | elif realized_field.value is not None: 96 | return [SelectColumnAction(int(realized_field.value))] 97 | else: 98 | return [] 99 | elif realized_field.type.name == 'tab_id': 100 | if realized_field.cardinality == 'multiple': 101 | action_list = [] 102 | for idx in realized_field.value: 103 | action_list.append(SelectTableAction(int(idx))) 104 | return action_list 105 | elif realized_field.value is not None: 106 | return [SelectTableAction(int(realized_field.value))] 107 | else: 108 | return [] 109 | else: 110 | raise ValueError('unknown primitive field type') 111 | 112 | if __name__ == '__main__': 113 | 114 | try: 115 | from evaluation import evaluate, build_foreign_key_map_from_json 116 | except Exception: 117 | print('Cannot find evaluator ...') 118 | grammar = ASDLGrammar.from_filepath('asdls/sql/grammar/sql_asdl_v2.txt') 119 | print('Total number of productions:', len(grammar)) 120 | for each in grammar.productions: 121 | print(each) 122 | print('Total number of types:', len(grammar.types)) 123 | for each in grammar.types: 124 | print(each) 125 | print('Total number of fields:', len(grammar.fields)) 126 | for each in grammar.fields: 127 | print(each) 128 | 129 | spider_trans = SQLTransitionSystem(grammar) 130 | kmaps = build_foreign_key_map_from_json('data/tables.json') 131 | dbs_list = json.load(open('data/tables.json', 'r')) 132 | dbs = {} 133 | for each in dbs_list: 134 | dbs[each['db_id']] = each 135 | 136 | train = json.load(open('data/train.json', 'r')) 137 | train_db = [ex['db_id'] for ex in train] 138 | train = [ex['sql'] for ex in train] 139 | dev = json.load(open('data/dev.json', 'r')) 140 | dev_db = [ex['db_id'] for ex in dev] 141 | dev = [ex['sql'] for ex in dev] 142 | 143 | recovered_sqls = [] 144 | for idx in range(len(train)): 145 | sql_ast = spider_trans.surface_code_to_ast(train[idx]) 146 | sql_ast.sanity_check() 147 | # print(spider_trans.get_actions(sql_ast)) 148 | recovered_sql = spider_trans.ast_to_surface_code(sql_ast, dbs[train_db[idx]]) 149 | # print(recovered_sql) 150 | recovered_sqls.append(recovered_sql) 151 | 152 | with open('data/train_pred.sql', 'w') as of: 153 | for each in recovered_sqls: 154 | of.write(each + '\n') 155 | with open('data/eval_train.log', 'w') as of: 156 | old_print = sys.stdout 157 | sys.stdout = of 158 | evaluate('data/train_gold.sql', 'data/train_pred.sql', 'data/database', 'match', kmaps) 159 | sys.stdout = old_print 160 | 161 | recovered_sqls = [] 162 | for idx in range(len(dev)): 163 | sql_ast = spider_trans.surface_code_to_ast(dev[idx]) 164 | sql_ast.sanity_check() 165 | recovered_sql = spider_trans.ast_to_surface_code(sql_ast, dbs[dev_db[idx]]) 166 | recovered_sqls.append(recovered_sql) 167 | with open('data/dev_pred.sql', 'w') as of: 168 | for each in recovered_sqls: 169 | of.write(each + '\n') 170 | with open('data/eval_dev.log', 'w') as of: 171 | old_print = sys.stdout 172 | sys.stdout = of 173 | evaluate('data/dev_gold.sql', 
'data/dev_pred.sql', 'data/database', 'match', kmaps) 174 | sys.stdout = old_print 175 | -------------------------------------------------------------------------------- /asdls/sql/unparser/unparser_base.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | 3 | from asdls.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class UnParser(): 7 | 8 | def __init__(self, grammar: ASDLGrammar): 9 | """ ASDLGrammar """ 10 | super(UnParser, self).__init__() 11 | self.grammar = grammar 12 | 13 | @classmethod 14 | def from_grammar(cls, grammar: ASDLGrammar): 15 | grammar_name = grammar._grammar_name 16 | if 'v0' in grammar_name: 17 | from asdls.sql.unparser.unparser_v0 import UnParserV0 18 | return UnParserV0(grammar) 19 | elif 'v1' in grammar_name: 20 | from asdls.sql.unparser.unparser_v1 import UnParserV1 21 | return UnParserV1(grammar) 22 | elif 'v2' in grammar_name: 23 | from asdls.sql.unparser.unparser_v2 import UnParserV2 24 | return UnParserV2(grammar) 25 | else: 26 | raise ValueError('Not recognized grammar name %s' % (grammar_name)) 27 | 28 | def unparse(self, sql_ast: AbstractSyntaxTree, db: dict, *args, **kargs): 29 | try: 30 | sql = self.unparse_sql(sql_ast, db, *args, **kargs) 31 | sql = ' '.join([i for i in sql.split(' ') if i != '']) 32 | return sql 33 | except Exception as e: 34 | print('Something Error happened while unparsing:', e) 35 | return 'SELECT * FROM %s' % (db['table_names_original'][0]) 36 | 37 | def unparse_sql(self, sql_ast: AbstractSyntaxTree, db: dict, *args, **kargs): 38 | prod_name = sql_ast.production.constructor.name 39 | if prod_name == 'Intersect': 40 | return '%s INTERSECT %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs)) 41 | elif prod_name == 'Union': 42 | return '%s UNION %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs)) 43 | elif prod_name == 'Except': 44 | return '%s EXCEPT %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs)) 45 | else: 46 | return self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs) 47 | 48 | def unparse_sql_unit(self, sql_field: RealizedField, db: dict, *args, **kargs): 49 | sql_ast = sql_field.value 50 | prod_name = sql_ast.production.constructor.name 51 | from_str = self.unparse_from(sql_ast.fields[0], db, *args, **kargs) 52 | select_str = self.unparse_select(sql_ast.fields[1], db, *args, **kargs) 53 | if prod_name == 'Complete': 54 | return 'SELECT %s FROM %s WHERE %s GROUP BY %s ORDER BY %s' % ( 55 | select_str, from_str, 56 | self.unparse_where(sql_ast.fields[2], db, *args, **kargs), 57 | self.unparse_groupby(sql_ast.fields[3], db, *args, **kargs), 58 | self.unparse_orderby(sql_ast.fields[4], db, *args, **kargs) 59 | ) 60 | elif prod_name == 'NoWhere': 61 | return 'SELECT %s FROM %s GROUP BY %s ORDER BY %s' % ( 62 | select_str, from_str, 63 | self.unparse_groupby(sql_ast.fields[2], db, *args, **kargs), 64 | self.unparse_orderby(sql_ast.fields[3], db, *args, **kargs), 65 | ) 66 | elif prod_name == 'NoGroupBy': 67 | return 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( 68 | select_str, from_str, 69 | self.unparse_where(sql_ast.fields[2], db, *args, **kargs), 70 | self.unparse_orderby(sql_ast.fields[3], db, *args, **kargs), 71 | ) 72 | elif prod_name == 'NoOrderBy': 
73 | return 'SELECT %s FROM %s WHERE %s GROUP BY %s' % ( 74 | select_str, from_str, 75 | self.unparse_where(sql_ast.fields[2], db, *args, **kargs), 76 | self.unparse_groupby(sql_ast.fields[3], db, *args, **kargs), 77 | ) 78 | elif prod_name == 'OnlyWhere': 79 | return 'SELECT %s FROM %s WHERE %s' % ( 80 | select_str, from_str, 81 | self.unparse_where(sql_ast.fields[2], db, *args, **kargs) 82 | ) 83 | elif prod_name == 'OnlyGroupBy': 84 | return 'SELECT %s FROM %s GROUP BY %s' % ( 85 | select_str, from_str, 86 | self.unparse_groupby(sql_ast.fields[2], db, *args, **kargs) 87 | ) 88 | elif prod_name == 'OnlyOrderBy': 89 | return 'SELECT %s FROM %s ORDER BY %s' % ( 90 | select_str, from_str, 91 | self.unparse_orderby(sql_ast.fields[2], db, *args, **kargs) 92 | ) 93 | else: 94 | return 'SELECT %s FROM %s' % (select_str, from_str) 95 | 96 | def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs): 97 | raise NotImplementedError 98 | 99 | def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs): 100 | raise NotImplementedError 101 | 102 | def unparse_where(self, where_field: RealizedField, db: dict, *args, **kargs): 103 | return self.unparse_conds(where_field.value, db, *args, **kargs) 104 | 105 | def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs): 106 | raise NotImplementedError 107 | 108 | def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs): 109 | raise NotImplementedError 110 | 111 | def unparse_conds(self, conds_ast: AbstractSyntaxTree, db: dict, *args, **kargs): 112 | ctr_name = conds_ast.production.constructor.name 113 | if ctr_name in ['And', 'Or']: 114 | left_cond, right_cond = conds_ast.fields 115 | return self.unparse_conds(left_cond.value, db, *args, **kargs) + ' ' + ctr_name.upper() + ' ' + \ 116 | self.unparse_conds(right_cond.value, db, *args, **kargs) 117 | else: 118 | return self.unparse_cond(conds_ast, db, *args, **kargs) 119 | 120 | def unparse_cond(self, cond_ast: AbstractSyntaxTree, db: dict, *args, **kargs): 121 | ctr_name = cond_ast.production.constructor.name 122 | val_unit_str = self.unparse_val_unit(cond_ast.fields[0].value, db, *args, **kargs) 123 | val_str = '( ' + self.unparse_sql(cond_ast.fields[1].value, db, *args, **kargs) + ' )' if len(cond_ast.fields) == 2 else '"value"' 124 | if ctr_name.startswith('Between'): 125 | return val_unit_str + ' BETWEEN ' + val_str + ' AND "value"' 126 | else: 127 | op_dict = { 128 | 'Between': ' BETWEEN ', 'Eq': ' = ', 'Gt': ' > ', 'Lt': ' < ', 'Ge': ' >= ', 'Le': ' <= ', 'Neq': ' != ', 129 | 'In': ' IN ', 'Like': ' LIKE ', 'NotIn': ' NOT IN ', 'NotLike': ' NOT LIKE ' 130 | } 131 | ctr_name = ctr_name if 'SQL' not in ctr_name else ctr_name[:ctr_name.index('SQL')] 132 | op = op_dict[ctr_name] 133 | return op.join([val_unit_str, val_str]) 134 | 135 | def unparse_val_unit(self, val_unit_ast: AbstractSyntaxTree, db: dict, *args, **kargs): 136 | unit_op = val_unit_ast.production.constructor.name 137 | if unit_op == 'Unary': 138 | return self.unparse_col_unit(val_unit_ast.fields[0].value, db, *args, **kargs) 139 | else: 140 | binary = {'Minus': ' - ', 'Plus': ' + ', 'Times': ' * ', 'Divide': ' / '} 141 | op = binary[unit_op] 142 | return op.join([self.unparse_col_unit(val_unit_ast.fields[0].value, db, *args, **kargs), 143 | self.unparse_col_unit(val_unit_ast.fields[1].value, db, *args, **kargs)]) 144 | # col_id1, col_id2 = int(val_unit_ast.fields[0].value), int(val_unit_ast.fields[1].value) 145 | # tab_id1, col_name1 = 
db['column_names_original'][col_id1] 146 | # if col_id1 != 0: 147 | # tab_name1 = db['table_names_original'][tab_id1] 148 | # col_name1 = tab_name1 + '.' + col_name1 149 | # tab_id2, col_name2 = db['column_names_original'][col_id2] 150 | # if col_id2 != 0: 151 | # tab_name2 = db['table_names_original'][tab_id2] 152 | # col_name2 = tab_name2 + '.' + col_name2 153 | # return op.join([col_name1, col_name2]) 154 | 155 | def unparse_col_unit(self, col_unit_ast: AbstractSyntaxTree, db: dict, *args, **kargs): 156 | agg = col_unit_ast.production.constructor.name 157 | col_id = int(col_unit_ast.fields[0].value) 158 | tab_id, col_name = db['column_names_original'][col_id] 159 | if col_id != 0: 160 | tab_name = db['table_names_original'][tab_id] 161 | col_name = tab_name + '.' + col_name 162 | if agg == 'None': 163 | return col_name 164 | else: # Max/Min/Count/Sum/Avg 165 | return agg.upper() + '(' + col_name + ')' 166 | -------------------------------------------------------------------------------- /asdls/sql/unparser/unparser_v0.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from asdls.sql.unparser.unparser_base import UnParser 3 | from asdls.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class UnParserV0(UnParser): 7 | 8 | def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs): 9 | select_list = select_field.value 10 | select_items = [] 11 | for val_unit_ast in select_list: 12 | val_unit_str = self.unparse_val_unit(val_unit_ast, db, *args, **kargs) 13 | select_items.append(val_unit_str) 14 | return ' , '.join(select_items) 15 | 16 | def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs): 17 | from_ast = from_field.value 18 | ctr_name = from_ast.production.constructor.name 19 | if ctr_name == 'FromTable': 20 | tab_ids = from_ast.fields[0].value 21 | if len(tab_ids) == 1: 22 | return db['table_names_original'][tab_ids[0]] 23 | else: 24 | tab_names = [db['table_names_original'][i] for i in tab_ids] 25 | return ' JOIN '.join(tab_names) 26 | else: 27 | sql_ast = from_ast.fields[0].value 28 | return '( ' + self.unparse_sql(sql_ast, db, *args, **kargs) + ' )' 29 | 30 | def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs): 31 | groupby_ast = groupby_field.value 32 | ctr_name = groupby_ast.production.constructor.name 33 | groupby_str = [] 34 | for col_unit_ast in groupby_ast.fields[0].value: 35 | groupby_str.append(self.unparse_col_unit(col_unit_ast, db, *args, **kargs)) 36 | groupby_str = ' , '.join(groupby_str) 37 | if ctr_name == 'Having': 38 | having = groupby_ast.fields[1].value 39 | having_str = self.unparse_conds(having, db, *args, **kargs) 40 | return groupby_str + ' HAVING ' + having_str 41 | else: 42 | return groupby_str 43 | 44 | def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs): 45 | orderby_ast = orderby_field.value 46 | ctr_name = orderby_ast.production.constructor.name.lower() 47 | val_unit_str = [] 48 | for val_unit_ast in orderby_ast.fields[0].value: 49 | val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs)) 50 | # val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs)) 51 | val_unit_str = ' , '.join(val_unit_str) 52 | if 'asc' in ctr_name and 'limit' in ctr_name: 53 | return '%s ASC LIMIT 1' % (val_unit_str) 54 | elif 'asc' in ctr_name: 55 | return '%s ASC' % (val_unit_str) 56 | elif 'desc' in 
ctr_name and 'limit' in ctr_name: 57 | return '%s DESC LIMIT 1' % (val_unit_str) 58 | else: 59 | return '%s DESC' % (val_unit_str) 60 | -------------------------------------------------------------------------------- /asdls/sql/unparser/unparser_v1.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from asdls.sql.unparser.unparser_base import UnParser 3 | from asdls.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class UnParserV1(UnParser): 7 | 8 | def unparse_sql_unit(self, sql_field: RealizedField, db: dict, *args, **kargs): 9 | sql_ast = sql_field.value 10 | from_field, select_field, where_field, groupby_field, orderby_field = sql_ast.fields 11 | from_str = 'FROM ' + self.unparse_from(from_field, db, *args, **kargs) 12 | select_str = 'SELECT ' + self.unparse_select(select_field, db, *args, **kargs) 13 | where_str, groupby_str, orderby_str = '', '', '' 14 | if where_field.value is not None: 15 | where_str = 'WHERE ' + self.unparse_where(where_field, db, *args, **kargs) 16 | if groupby_field.value is not None: 17 | groupby_str = 'GROUP BY ' + self.unparse_groupby(groupby_field, db, *args, **kargs) 18 | if orderby_field.value is not None: 19 | orderby_str = 'ORDER BY ' + self.unparse_orderby(orderby_field, db, *args, **kargs) 20 | return ' '.join([select_str, from_str, where_str, groupby_str, orderby_str]) 21 | 22 | def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs): 23 | select_ast = select_field.value 24 | select_list = select_ast.fields 25 | select_items = [] 26 | for val_unit_field in select_list: 27 | val_unit_str = self.unparse_val_unit(val_unit_field.value, db, *args, **kargs) 28 | select_items.append(val_unit_str) 29 | return ' , '.join(select_items) 30 | 31 | def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs): 32 | from_ast = from_field.value 33 | ctr_name = from_ast.production.constructor.name 34 | if 'Table' in ctr_name: 35 | tab_names = [] 36 | for tab_field in from_ast.fields: 37 | tab_name = db['table_names_original'][int(tab_field.value)] 38 | tab_names.append(tab_name) 39 | return ' JOIN '.join(tab_names) 40 | else: 41 | return '( ' + self.unparse_sql(from_ast.fields[0].value, db, *args, **kargs) + ' )' 42 | 43 | def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs): 44 | groupby_ast = groupby_field.value 45 | ctr_name = groupby_ast.production.constructor.name 46 | groupby_str = [] 47 | num = len(groupby_ast.fields) - 1 48 | for col_id_field in groupby_ast.fields[:num]: 49 | # col_id = int(col_id_field.value) 50 | # tab_id, col_name = db['column_names_original'][col_id] 51 | # if col_id != 0: 52 | # tab_name = db['table_names_original'][tab_id] 53 | # col_name = tab_name + '.' 
+ col_name 54 | col_name = self.unparse_col_unit(col_id_field.value, db, *args, **kargs) 55 | groupby_str.append(col_name) 56 | groupby_str = ' , '.join(groupby_str) 57 | if groupby_ast.fields[-1].value is not None: 58 | having = groupby_ast.fields[-1].value 59 | having_str = self.unparse_conds(having, db, *args, **kargs) 60 | return groupby_str + ' HAVING ' + having_str 61 | else: 62 | return groupby_str 63 | 64 | def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs): 65 | orderby_ast = orderby_field.value 66 | ctr_name = orderby_ast.production.constructor.name.lower() 67 | val_unit_str = [] 68 | for val_unit_field in orderby_ast.fields: 69 | val_unit_ast = val_unit_field.value 70 | val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs)) 71 | # val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs)) 72 | val_unit_str = ' , '.join(val_unit_str) 73 | if 'asc' in ctr_name and 'limit' in ctr_name: 74 | return '%s ASC LIMIT 1' % (val_unit_str) 75 | elif 'asc' in ctr_name: 76 | return '%s ASC' % (val_unit_str) 77 | elif 'desc' in ctr_name and 'limit' in ctr_name: 78 | return '%s DESC LIMIT 1' % (val_unit_str) 79 | else: 80 | return '%s DESC' % (val_unit_str) -------------------------------------------------------------------------------- /asdls/sql/unparser/unparser_v2.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from asdls.sql.unparser.unparser_base import UnParser 3 | from asdls.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction 4 | from asdls.asdl_ast import RealizedField, AbstractSyntaxTree 5 | 6 | class UnParserV2(UnParser): 7 | 8 | def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs): 9 | select_ast = select_field.value 10 | select_list = select_ast.fields 11 | select_items = [] 12 | for val_unit_field in select_list: 13 | val_unit_str = self.unparse_val_unit(val_unit_field.value, db, *args, **kargs) 14 | select_items.append(val_unit_str) 15 | return ' , '.join(select_items) 16 | 17 | def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs): 18 | from_ast = from_field.value 19 | ctr_name = from_ast.production.constructor.name 20 | if 'Table' in ctr_name: 21 | tab_names = [] 22 | for tab_field in from_ast.fields: 23 | tab_name = db['table_names_original'][int(tab_field.value)] 24 | tab_names.append(tab_name) 25 | return ' JOIN '.join(tab_names) 26 | else: 27 | return '( ' + self.unparse_sql(from_ast.fields[0].value, db, *args, **kargs) + ' )' 28 | 29 | def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs): 30 | groupby_ast = groupby_field.value 31 | ctr_name = groupby_ast.production.constructor.name 32 | groupby_str = [] 33 | num = len(groupby_ast.fields) if 'NoHaving' in ctr_name else len(groupby_ast.fields) - 1 34 | for col_id_field in groupby_ast.fields[:num]: 35 | # col_id = int(col_id_field.value) 36 | # tab_id, col_name = db['column_names_original'][col_id] 37 | # if col_id != 0: 38 | # tab_name = db['table_names_original'][tab_id] 39 | # col_name = tab_name + '.' 
+ col_name 40 | col_name = self.unparse_col_unit(col_id_field.value, db, *args, **kargs) 41 | groupby_str.append(col_name) 42 | groupby_str = ' , '.join(groupby_str) 43 | if 'NoHaving' in ctr_name: 44 | return groupby_str 45 | else: 46 | having = groupby_ast.fields[-1].value 47 | having_str = self.unparse_conds(having, db, *args, **kargs) 48 | return groupby_str + ' HAVING ' + having_str 49 | 50 | def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs): 51 | orderby_ast = orderby_field.value 52 | ctr_name = orderby_ast.production.constructor.name.lower() 53 | val_unit_str = [] 54 | for val_unit_field in orderby_ast.fields: 55 | val_unit_ast = val_unit_field.value 56 | val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs)) 57 | # val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs)) 58 | val_unit_str = ' , '.join(val_unit_str) 59 | if 'asc' in ctr_name and 'limit' in ctr_name: 60 | return '%s ASC LIMIT 1' % (val_unit_str) 61 | elif 'asc' in ctr_name: 62 | return '%s ASC' % (val_unit_str) 63 | elif 'desc' in ctr_name and 'limit' in ctr_name: 64 | return '%s DESC LIMIT 1' % (val_unit_str) 65 | else: 66 | return '%s DESC' % (val_unit_str) -------------------------------------------------------------------------------- /asdls/transition_system.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | class Action(object): 5 | pass 6 | 7 | 8 | class ApplyRuleAction(Action): 9 | def __init__(self, production): 10 | self.production = production 11 | 12 | def __hash__(self): 13 | return hash(self.production) 14 | 15 | def __eq__(self, other): 16 | return isinstance(other, ApplyRuleAction) and self.production == other.production 17 | 18 | def __ne__(self, other): 19 | return not self.__eq__(other) 20 | 21 | def __repr__(self): 22 | return 'ApplyRule[%s]' % self.production.__repr__() 23 | 24 | 25 | class GenTokenAction(Action): 26 | def __init__(self, token): 27 | self.token = token 28 | 29 | def is_stop_signal(self): 30 | return self.token == '' 31 | 32 | def __repr__(self): 33 | return 'GenToken[%s]' % self.token 34 | 35 | 36 | class ReduceAction(Action): 37 | def __repr__(self): 38 | return 'Reduce' 39 | 40 | 41 | class TransitionSystem(object): 42 | def __init__(self, grammar): 43 | self.grammar = grammar 44 | 45 | def get_actions(self, asdl_ast): 46 | """ 47 | generate action sequence given the ASDL Syntax Tree 48 | """ 49 | 50 | actions = [] 51 | 52 | parent_action = ApplyRuleAction(asdl_ast.production) 53 | actions.append(parent_action) 54 | 55 | for field in asdl_ast.fields: 56 | # is a composite field 57 | if self.grammar.is_composite_type(field.type): 58 | if field.cardinality == 'single': 59 | field_actions = self.get_actions(field.value) 60 | else: 61 | field_actions = [] 62 | 63 | if field.value is not None: 64 | if field.cardinality == 'multiple': 65 | for val in field.value: 66 | cur_child_actions = self.get_actions(val) 67 | field_actions.extend(cur_child_actions) 68 | elif field.cardinality == 'optional': 69 | field_actions = self.get_actions(field.value) 70 | 71 | # if an optional field is filled, then do not need Reduce action 72 | if field.cardinality == 'multiple' or field.cardinality == 'optional' and not field_actions: 73 | field_actions.append(ReduceAction()) 74 | else: # is a primitive field 75 | field_actions = self.get_primitive_field_actions(field) 76 | 77 | # if an optional field is filled, then do not need Reduce action 78 | if 
field.cardinality == 'multiple' or field.cardinality == 'optional' and not field_actions: 79 | # reduce action 80 | field_actions.append(ReduceAction()) 81 | 82 | actions.extend(field_actions) 83 | 84 | return actions 85 | 86 | def tokenize_code(self, code, mode): 87 | raise NotImplementedError 88 | 89 | def compare_ast(self, hyp_ast, ref_ast): 90 | raise NotImplementedError 91 | 92 | def ast_to_surface_code(self, asdl_ast): 93 | raise NotImplementedError 94 | 95 | def surface_code_to_ast(self, code): 96 | raise NotImplementedError 97 | 98 | def get_primitive_field_actions(self, realized_field): 99 | raise NotImplementedError 100 | 101 | def get_valid_continuation_types(self, hyp): 102 | if hyp.tree: 103 | if self.grammar.is_composite_type(hyp.frontier_field.type): 104 | if hyp.frontier_field.cardinality == 'single': 105 | return ApplyRuleAction, 106 | else: # optional, multiple 107 | return ApplyRuleAction, ReduceAction 108 | else: 109 | if hyp.frontier_field.cardinality == 'single': 110 | return GenTokenAction, 111 | elif hyp.frontier_field.cardinality == 'optional': 112 | if hyp._value_buffer: 113 | return GenTokenAction, 114 | else: 115 | return GenTokenAction, ReduceAction 116 | else: 117 | return GenTokenAction, ReduceAction 118 | else: 119 | return ApplyRuleAction, 120 | 121 | def get_valid_continuating_productions(self, hyp): 122 | if hyp.tree: 123 | if self.grammar.is_composite_type(hyp.frontier_field.type): 124 | return self.grammar[hyp.frontier_field.type] 125 | else: 126 | raise ValueError 127 | else: 128 | return self.grammar[self.grammar.root_type] 129 | 130 | @staticmethod 131 | def get_class_by_lang(lang): 132 | if lang == 'sql': 133 | from asdls.sql.sql_transition_system import SQLTransitionSystem 134 | else: 135 | raise ValueError('unknown language %s' % lang) 136 | return SQLTransitionSystem 137 | -------------------------------------------------------------------------------- /model/decoder/onlstm.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | """ ONLSTM and traditional LSTM with locked dropout """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def cumsoftmax(x, dim=-1): 8 | return torch.cumsum(F.softmax(x, dim=dim), dim=dim) 9 | 10 | class LinearDropConnect(nn.Linear): 11 | """ Used in recurrent connection dropout """ 12 | def __init__(self, in_features, out_features, bias=True, dropconnect=0.): 13 | super(LinearDropConnect, self).__init__(in_features=in_features, out_features=out_features, bias=bias) 14 | self.dropconnect = dropconnect 15 | 16 | def sample_mask(self): 17 | if self.dropconnect == 0.: 18 | self._weight = self.weight.clone() 19 | else: 20 | mask = self.weight.new_zeros(self.weight.size(), dtype=torch.bool) 21 | mask.bernoulli_(self.dropconnect) 22 | self._weight = self.weight.masked_fill(mask, 0.) 
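    # Note on the DropConnect scheme below: sample_mask() is called once per RNN forward pass
    # (when training with start=True) and zeroes each recurrent weight independently with
    # probability self.dropconnect; forward() then reuses the same masked weight matrix at
    # every timestep of the sequence, and at inference time it falls back to the expected
    # weight self.weight * (1 - self.dropconnect) instead of sampling.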
23 | 24 | def forward(self, inputs, sample_mask=False): 25 | if self.training: 26 | if sample_mask: 27 | self.sample_mask() 28 | return F.linear(inputs, self._weight, self.bias) # apply the same mask to weight matrix in linear module 29 | else: 30 | return F.linear(inputs, self.weight * (1 - self.dropconnect), self.bias) 31 | 32 | class LockedDropout(nn.Module): 33 | """ Used in dropout between layers """ 34 | def __init__(self, hidden_size, num_layers=1, dropout=0.2): 35 | super(LockedDropout, self).__init__() 36 | self.hidden_size = hidden_size 37 | self.num_layers = num_layers 38 | self.dropout = dropout 39 | 40 | def sample_masks(self, x): 41 | self.masks = [] 42 | for _ in range(self.num_layers - 1): 43 | mask = x.new_zeros(x.size(0), 1, self.hidden_size).bernoulli_(1 - self.dropout) 44 | mask = mask.div_(1 - self.dropout) 45 | mask.requires_grad = False 46 | self.masks.append(mask) 47 | 48 | def forward(self, x, layer=0): 49 | """ x: bsize x seqlen x hidden_size """ 50 | if (not self.training) or self.dropout == 0. or layer == self.num_layers - 1: # output hidden states, no dropout 51 | return x 52 | mask = self.masks[layer] 53 | mask = mask.expand_as(x) 54 | return mask * x 55 | 56 | class RecurrentNeuralNetwork(nn.Module): 57 | def init_hiddens(self, x): 58 | return x.new_zeros(self.num_layers, x.size(0), self.hidden_size), \ 59 | x.new_zeros(self.num_layers, x.size(0), self.hidden_size) 60 | 61 | def forward(self, inputs, hiddens=None, start=False, layerwise=False): 62 | """ 63 | @args: 64 | start: whether sampling locked masks for recurrent connections and between layers 65 | layerwise: whether return a list, results of intermediate layer outputs 66 | @return: 67 | outputs: bsize x seqlen x hidden_size 68 | final_hiddens: hT and cT, each of size: num_layers x bsize x hidden_size 69 | """ 70 | assert inputs.dim() == 3 71 | if hiddens is None: 72 | hiddens = self.init_hiddens(inputs) 73 | bsize, seqlen, _ = list(inputs.size()) 74 | prev_state = list(hiddens) # each of size: num_layers, bsize, hidden_size 75 | prev_layer = inputs # size: bsize, seqlen, input_size 76 | each_layer_outputs, final_h, final_c = [], [], [] 77 | 78 | if self.training and start: 79 | for c in self.cells: 80 | c.sample_masks() 81 | self.locked_dropout.sample_masks(inputs) 82 | 83 | for l in range(len(self.cells)): 84 | curr_layer = [None] * seqlen 85 | curr_inputs = self.cells[l].ih(prev_layer) 86 | next_h, next_c = prev_state[0][l], prev_state[1][l] 87 | for t in range(seqlen): 88 | hidden, cell = self.cells[l](None, (next_h, next_c), transformed_inputs=curr_inputs[:, t]) 89 | next_h, next_c = hidden, cell # overwritten every timestep 90 | curr_layer[t] = hidden 91 | 92 | prev_layer = torch.stack(curr_layer, dim=1) # bsize x seqlen x hidden_size 93 | each_layer_outputs.append(prev_layer) 94 | final_h.append(next_h) 95 | final_c.append(next_c) 96 | prev_layer = self.locked_dropout(prev_layer, layer=l) 97 | 98 | outputs, final_hiddens = prev_layer, (torch.stack(final_h, dim=0), torch.stack(final_c, dim=0)) 99 | if layerwise: 100 | return outputs, final_hiddens, each_layer_outputs 101 | else: 102 | return outputs, final_hiddens 103 | 104 | class LSTMCell(nn.Module): 105 | 106 | def __init__(self, input_size, hidden_size, bias=True, dropconnect=0.): 107 | super(LSTMCell, self).__init__() 108 | self.input_size = input_size 109 | self.hidden_size = hidden_size 110 | self.ih = nn.Linear(input_size, hidden_size * 4, bias=bias) 111 | self.hh = LinearDropConnect(hidden_size, hidden_size * 4, bias=bias, 
dropconnect=dropconnect) 112 | self.drop_weight_modules = [self.hh] 113 | 114 | def sample_masks(self): 115 | for m in self.drop_weight_modules: 116 | m.sample_mask() 117 | 118 | def forward(self, inputs, hiddens, transformed_inputs=None): 119 | """ 120 | inputs: bsize x input_size 121 | hiddens: tuple of h0 (bsize x hidden_size) and c0 (bsize x hidden_size) 122 | transformed_inputs: short cut for inputs, save time if seq len is already provied in training 123 | return tuple of h1 (bsize x hidden_size) and c1 (bsize x hidden_size) 124 | """ 125 | if transformed_inputs is None: 126 | transformed_inputs = self.ih(inputs) 127 | h0, c0 = hiddens 128 | gates = transformed_inputs + self.hh(h0) 129 | ingate, forgetgate, outgate, cell = gates.contiguous().\ 130 | view(-1, 4, self.hidden_size).chunk(4, 1) 131 | forgetgate = torch.sigmoid(forgetgate.squeeze(1)) 132 | ingate = torch.sigmoid(ingate.squeeze(1)) 133 | cell = torch.tanh(cell.squeeze(1)) 134 | outgate = torch.sigmoid(outgate.squeeze(1)) 135 | c1 = forgetgate * c0 + ingate * cell 136 | h1 = outgate * torch.tanh(c1) 137 | return h1, c1 138 | 139 | class LSTM(RecurrentNeuralNetwork): 140 | 141 | def __init__(self, input_size, hidden_size, num_layers=1, chunk_num=1, bias=True, dropout=0., dropconnect=0.): 142 | super(LSTM, self).__init__() 143 | self.input_size = input_size 144 | self.hidden_size = hidden_size 145 | self.num_layers = num_layers 146 | self.cells = nn.ModuleList( 147 | [LSTMCell(input_size, hidden_size, bias, dropconnect)] + 148 | [LSTMCell(hidden_size, hidden_size, bias, dropconnect) for i in range(num_layers - 1)] 149 | ) 150 | self.locked_dropout = LockedDropout(hidden_size, num_layers, dropout) # dropout rate between layers 151 | 152 | class ONLSTMCell(nn.Module): 153 | 154 | def __init__(self, input_size, hidden_size, chunk_num=8, bias=True, dropconnect=0.2): 155 | super(ONLSTMCell, self).__init__() 156 | self.input_size = input_size 157 | self.hidden_size = hidden_size 158 | self.chunk_num = chunk_num # chunk_num should be divided by hidden_size 159 | if self.hidden_size % self.chunk_num != 0: 160 | raise ValueError('[Error]: chunk number must be divided by hidden size in ONLSTM Cell') 161 | self.chunk_size = int(hidden_size / chunk_num) 162 | 163 | self.ih = nn.Linear(input_size, self.chunk_size * 2 + hidden_size * 4, bias=bias) 164 | self.hh = LinearDropConnect(hidden_size, self.chunk_size * 2 + hidden_size * 4, bias=bias, dropconnect=dropconnect) 165 | self.drop_weight_modules = [self.hh] 166 | 167 | def sample_masks(self): 168 | for m in self.drop_weight_modules: 169 | m.sample_mask() 170 | 171 | def forward(self, inputs, hiddens, transformed_inputs=None): 172 | """ 173 | inputs: bsize x input_size 174 | hiddens: tuple of h0 (bsize x hidden_size) and c0 (bsize x hidden_size) 175 | transformed_inputs: short cut for inputs, save time if seq len is already provied in training 176 | return tuple of h1 (bsize x hidden_size) and c1 (bsize x hidden_size) 177 | """ 178 | if transformed_inputs is None: 179 | transformed_inputs = self.ih(inputs) 180 | h0, c0 = hiddens 181 | gates = transformed_inputs + self.hh(h0) 182 | cingate, cforgetgate = gates[:, :self.chunk_size * 2].chunk(2, 1) 183 | ingate, forgetgate, outgate, cell = gates[:, self.chunk_size * 2:].contiguous().\ 184 | view(-1, self.chunk_size * 4, self.chunk_num).chunk(4, 1) 185 | 186 | cingate = 1. 
- cumsoftmax(cingate) 187 | cforgetgate = cumsoftmax(cforgetgate) 188 | cingate = cingate[:, :, None] 189 | cforgetgate = cforgetgate[:, :, None] 190 | 191 | forgetgate = torch.sigmoid(forgetgate) 192 | ingate = torch.sigmoid(ingate) 193 | cell = torch.tanh(cell) 194 | outgate = torch.sigmoid(outgate) 195 | 196 | overlap = cforgetgate * cingate 197 | forgetgate = forgetgate * overlap + (cforgetgate - overlap) 198 | ingate = ingate * overlap + (cingate - overlap) 199 | c0 = c0.contiguous().view(-1, self.chunk_size, self.chunk_num) 200 | c1 = forgetgate * c0 + ingate * cell 201 | h1 = outgate * torch.tanh(c1) 202 | return h1.contiguous().view(-1, self.hidden_size), c1.contiguous().view(-1, self.hidden_size) 203 | 204 | class ONLSTM(RecurrentNeuralNetwork): 205 | 206 | def __init__(self, input_size, hidden_size, num_layers=1, chunk_num=8, bias=True, dropout=0., dropconnect=0.): 207 | super(ONLSTM, self).__init__() 208 | self.input_size = input_size 209 | self.hidden_size = hidden_size 210 | self.num_layers = num_layers 211 | self.cells = nn.ModuleList( 212 | [ONLSTMCell(input_size, hidden_size, chunk_num, bias, dropconnect)] + 213 | [ONLSTMCell(hidden_size, hidden_size, chunk_num, bias, dropconnect) for i in range(num_layers - 1)] 214 | ) 215 | self.locked_dropout = LockedDropout(hidden_size, num_layers, dropout) # dropout rate between layers 216 | -------------------------------------------------------------------------------- /model/encoder/functions.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import dgl, math, torch 3 | 4 | def src_dot_dst(src_field, dst_field, out_field): 5 | def func(edges): 6 | return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)} 7 | 8 | return func 9 | 10 | def src_sum_edge_mul_dst(src_field, dst_field, e_field, out_field): 11 | def func(edges): 12 | return {out_field: ((edges.src[src_field] + edges.data[e_field]) * edges.dst[dst_field]).sum(-1, keepdim=True)} 13 | 14 | return func 15 | 16 | def src_sum_edge_mul_dst2(src_field, dst_field, e_field, e_field2, weight_field, out_field): 17 | def func(edges): 18 | return {out_field: ((edges.src[src_field] + edges.data[e_field] + edges.data[weight_field]*edges.data[e_field2]) * edges.dst[dst_field]).sum(-1, keepdim=True)} 19 | # return {out_field: ( (edges.src[src_field] + edges.data[weight_field] * edges.data[e_field2]) * edges.dst[dst_field]).sum(-1, keepdim=True)} 20 | 21 | return func 22 | 23 | def scaled_exp(field, scale_constant): 24 | def func(edges): 25 | # clamp for softmax numerical stability 26 | return {field: torch.exp((edges.data[field] / scale_constant).clamp(-10, 10))} 27 | 28 | return func 29 | 30 | def src_sum_edge_mul_edge(src_field, e_field1, e_field2, out_field): 31 | def func(edges): 32 | return {out_field: (edges.src[src_field] + edges.data[e_field1]) * edges.data[e_field2]} 33 | 34 | return func 35 | 36 | 37 | def src_sum_edge_mul_edge2(src_field, e_field1, e_field2, weight_field, score_field2, out_field): 38 | def func(edges): 39 | return {out_field: (edges.src[src_field] + edges.data[e_field1] + edges.data[weight_field]*edges.data[e_field2]) * edges.data[score_field2]} 40 | # return {out_field: (edges.src[src_field] + edges.data[weight_field] * edges.data[e_field2]) * edges.data[score_field2]} 41 | 42 | return func 43 | 44 | def div_by_z(in_field, norm_field, out_field): 45 | def func(nodes): 46 | return {out_field: nodes.data[in_field] / nodes.data[norm_field]} 47 | 48 | return func 49 | 
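# A minimal, self-contained usage sketch of the closures above (illustrative only: the toy
# graph and random features are made up, but the call pattern mirrors graph_output.py and
# rgatsql.py, which chain these UDFs into edge-wise scaled dot-product attention).
# fn.src_mul_edge / fn.copy_edge are the DGL builtins used elsewhere in this repository
# (newer DGL releases rename them u_mul_e / copy_e).
if __name__ == '__main__':
    import dgl.function as fn
    num_heads, d_k = 2, 4
    g = dgl.graph(([0, 1, 2, 0], [1, 2, 0, 2]))             # toy directed graph with 3 nodes
    g.ndata['q'] = torch.randn(3, num_heads, d_k)           # per-head queries
    g.ndata['k'] = torch.randn(3, num_heads, d_k)           # per-head keys
    g.ndata['v'] = torch.randn(3, num_heads, d_k)           # per-head values
    g.apply_edges(src_dot_dst('k', 'q', 'score'))           # attention logits per edge and head
    g.apply_edges(scaled_exp('score', math.sqrt(d_k)))      # clamped exp -> softmax numerator
    g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
    g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
    print(g.ndata['o'].shape)                               # torch.Size([3, 2, 4])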
-------------------------------------------------------------------------------- /model/encoder/graph_encoder.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import torch 3 | import torch.nn as nn 4 | from model.encoder.rgatsql import RGATSQL 5 | from model.encoder.graph_output import * 6 | from model.model_utils import Registrable 7 | 8 | 9 | class Text2SQLEncoder(nn.Module): 10 | 11 | def __init__(self, args): 12 | super(Text2SQLEncoder, self).__init__() 13 | self.hidden_layer = RGATSQL(args) 14 | self.output_layer = Registrable.by_name(args.output_model)(args) 15 | def forward(self, batch, x): 16 | outputs = self.hidden_layer(x, batch) 17 | return self.output_layer(outputs, batch) 18 | -------------------------------------------------------------------------------- /model/encoder/graph_output.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import dgl.function as fn 7 | from model.model_utils import Registrable 8 | from model.encoder.functions import scaled_exp, div_by_z, src_dot_dst 9 | 10 | class ScoreFunction(nn.Module): 11 | 12 | def __init__(self, hidden_size, mlp=1, method='biaffine'): 13 | super(ScoreFunction, self).__init__() 14 | assert method in ['dot', 'bilinear', 'affine', 'biaffine'] 15 | self.mlp = int(mlp) 16 | self.hidden_size = hidden_size // self.mlp 17 | if self.mlp > 1: # use mlp to perform dim reduction 18 | self.mlp_q = nn.Sequential(nn.Linear(hidden_size, self.hidden_size), nn.Tanh()) 19 | self.mlp_s = nn.Sequential(nn.Linear(hidden_size, self.hidden_size), nn.Tanh()) 20 | self.method = method 21 | if self.method == 'bilinear': 22 | self.W = nn.Linear(self.hidden_size, self.hidden_size) 23 | elif self.method == 'affine': 24 | self.affine = nn.Linear(self.hidden_size * 2, 1) 25 | elif self.method == 'biaffine': 26 | self.W = nn.Linear(self.hidden_size, self.hidden_size) 27 | self.affine = nn.Linear(self.hidden_size * 2, 1) 28 | 29 | def forward(self, context, node): 30 | """ 31 | @args: 32 | context(torch.FloatTensor): num_nodes x hidden_size 33 | node(torch.FloatTensor): num_nodes x hidden_size 34 | @return: 35 | scores(torch.FloatTensor): num_nodes 36 | """ 37 | if self.mlp > 1: 38 | context, node = self.mlp_q(context), self.mlp_s(node) 39 | if self.method == 'dot': 40 | scores = (context * node).sum(dim=-1) 41 | elif self.method == 'bilinear': 42 | scores = (context * self.W(node)).sum(dim=-1) 43 | elif self.method == 'affine': 44 | scores = self.affine(torch.cat([context, node], dim=-1)).squeeze(-1) 45 | elif self.method == 'biaffine': 46 | scores = (context * self.W(node)).sum(dim=-1) 47 | scores += self.affine(torch.cat([context, node], dim=-1)).squeeze(-1) 48 | else: 49 | raise ValueError('[Error]: Unrecognized score function method %s!' 
% (self.method)) 50 | return scores 51 | 52 | @Registrable.register('without_pruning') 53 | class GraphOutputLayer(nn.Module): 54 | 55 | def __init__(self, args): 56 | super(GraphOutputLayer, self).__init__() 57 | self.hidden_size = args.gnn_hidden_size 58 | 59 | def forward(self, inputs, batch): 60 | """ Re-scatter data format: 61 | inputs: sum(q_len + t_len + c_len) x hidden_size 62 | outputs: bsize x (max_q_len + max_t_len + max_c_len) x hidden_size 63 | """ 64 | outputs = inputs.new_zeros(len(batch), batch.mask.size(1), self.hidden_size) 65 | outputs = outputs.masked_scatter_(batch.mask.unsqueeze(-1), inputs) 66 | if self.training: 67 | return outputs, batch.mask, torch.tensor(0., dtype=torch.float).to(outputs.device) 68 | else: 69 | return outputs, batch.mask 70 | 71 | @Registrable.register('with_pruning') 72 | class GraphOutputLayerWithPruning(nn.Module): 73 | 74 | def __init__(self, args): 75 | super(GraphOutputLayerWithPruning, self).__init__() 76 | self.hidden_size = args.gnn_hidden_size 77 | self.graph_pruning = GraphPruning(self.hidden_size, args.num_heads, args.dropout, args.score_function) 78 | 79 | def forward(self, inputs, batch): 80 | outputs = inputs.new_zeros(len(batch), batch.mask.size(1), self.hidden_size) 81 | outputs = outputs.masked_scatter_(batch.mask.unsqueeze(-1), inputs) 82 | 83 | if self.training: 84 | g = batch.graph 85 | question = inputs.masked_select(g.question_mask.unsqueeze(-1)).view(-1, self.hidden_size) 86 | schema = inputs.masked_select(g.schema_mask.unsqueeze(-1)).view(-1, self.hidden_size) 87 | loss = self.graph_pruning(question, schema, g.gp, g.node_label) 88 | return outputs, batch.mask, loss 89 | else: 90 | return outputs, batch.mask 91 | 92 | class GraphPruning(nn.Module): 93 | 94 | def __init__(self, hidden_size, num_heads=8, feat_drop=0.2, score_function='affine'): 95 | super(GraphPruning, self).__init__() 96 | self.hidden_size = hidden_size 97 | self.node_mha = DGLMHA(hidden_size, hidden_size, num_heads, feat_drop) 98 | self.node_score_function = ScoreFunction(self.hidden_size, mlp=2, method=score_function) 99 | self.loss_function = nn.BCEWithLogitsLoss(reduction='sum') 100 | 101 | def forward(self, question, schema, graph, node_label): 102 | node_context = self.node_mha(question, schema, graph) 103 | node_score = self.node_score_function(node_context, schema) 104 | loss = self.loss_function(node_score, node_label) 105 | return loss 106 | 107 | class DGLMHA(nn.Module): 108 | """ Multi-head attention implemented with DGL lib 109 | """ 110 | def __init__(self, hidden_size, output_size, num_heads=8, feat_drop=0.2): 111 | super(DGLMHA, self).__init__() 112 | self.hidden_size = hidden_size 113 | self.output_size = output_size 114 | self.num_heads = num_heads 115 | self.d_k = self.hidden_size // self.num_heads 116 | self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.output_size, self.hidden_size),\ 117 | nn.Linear(self.hidden_size, self.hidden_size, bias=False), nn.Linear(self.hidden_size, self.hidden_size, bias=False) 118 | self.affine_o = nn.Linear(self.hidden_size, self.output_size) 119 | self.feat_dropout = nn.Dropout(p=feat_drop) 120 | 121 | def forward(self, context, node, g): 122 | q, k, v = self.affine_q(self.feat_dropout(node)), self.affine_k(self.feat_dropout(context)), self.affine_v(self.feat_dropout(context)) 123 | with g.local_scope(): 124 | g.nodes['schema'].data['q'] = q.view(-1, self.num_heads, self.d_k) 125 | g.nodes['question'].data['k'] = k.view(-1, self.num_heads, self.d_k) 126 | g.nodes['question'].data['v'] = 
v.view(-1, self.num_heads, self.d_k) 127 | out_x = self.propagate_attention(g) 128 | return self.affine_o(out_x.view(-1, self.hidden_size)) 129 | 130 | def propagate_attention(self, g): 131 | # Compute attention score 132 | g.apply_edges(src_dot_dst('k', 'q', 'score')) 133 | g.apply_edges(scaled_exp('score', math.sqrt(self.d_k))) 134 | # Update node state 135 | g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv')) 136 | g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o')) 137 | out_x = g.nodes['schema'].data['o'] 138 | return out_x -------------------------------------------------------------------------------- /model/encoder/rgatsql.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import copy, math 3 | import torch, dgl 4 | import dgl.function as fn 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from model.model_utils import Registrable, FFN 8 | from model.encoder.functions import * 9 | 10 | class RGATSQL(nn.Module): 11 | 12 | def __init__(self, args): 13 | super(RGATSQL, self).__init__() 14 | self.num_layers = args.gnn_num_layers 15 | self.relation_num = args.relation_num 16 | self.relation_share_layers, self.relation_share_heads = args.relation_share_layers, args.relation_share_heads 17 | edim = args.gnn_hidden_size // args.num_heads if self.relation_share_heads else args.gnn_hidden_size 18 | if self.relation_share_layers: 19 | self.relation_embed = nn.Embedding(args.relation_num, edim) 20 | else: 21 | self.relation_embed = nn.ModuleList([nn.Embedding(args.relation_num, edim) for _ in range(self.num_layers)]) 22 | gnn_layer = RGATLayer 23 | self.gnn_layers = nn.ModuleList([gnn_layer(args.gnn_hidden_size, edim, num_heads=args.num_heads, feat_drop=args.dropout, optimize_graph=args.optimize_graph) 24 | for _ in range(self.num_layers)]) 25 | self.optimize_graph = args.optimize_graph 26 | 27 | def forward(self, x, batch): 28 | if self.optimize_graph: 29 | global_edges2 = batch.graph.global_edges2 30 | 31 | graph = batch.graph.global_g 32 | edges = batch.graph.global_edges 33 | # edges, mask = batch.graph.global_edges, batch.graph.local_mask 34 | if self.relation_share_layers: 35 | lgx = self.relation_embed(edges) 36 | lgx = lgx 37 | for i in range(self.num_layers): 38 | lgx = lgx if self.relation_share_layers else self.relation_embed[i](edges) 39 | if self.optimize_graph: 40 | lgx2 = self.relation_embed[i](global_edges2) 41 | x, lgx = self.gnn_layers[i](x, lgx, graph, lgx2) 42 | else: 43 | x, lgx = self.gnn_layers[i](x, lgx, graph) 44 | return x 45 | 46 | 47 | class RGATLayer(nn.Module): 48 | 49 | def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2, optimize_graph=False): 50 | super(RGATLayer, self).__init__() 51 | self.ndim, self.edim = ndim, edim 52 | self.num_heads = num_heads 53 | dim = max([ndim, edim]) 54 | self.d_k = dim // self.num_heads 55 | self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.ndim, dim),\ 56 | nn.Linear(self.ndim, dim, bias=False), nn.Linear(self.ndim, dim, bias=False) 57 | self.affine_o = nn.Linear(dim, self.ndim) 58 | self.layernorm = nn.LayerNorm(self.ndim) 59 | self.feat_dropout = nn.Dropout(p=feat_drop) 60 | self.ffn = FFN(self.ndim) 61 | self.optimize_graph = optimize_graph 62 | 63 | def forward(self, x, lgx, g, lgx2=None): 64 | """ @Params: 65 | x: node feats, num_nodes x ndim 66 | lgx: edge feats, num_edges x edim 67 | g: dgl.graph 68 | """ 69 | # pre-mapping q/k/v affine 70 | q, k, v = 
self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x)) 71 | e = lgx.view(-1, self.num_heads, self.d_k) if lgx.size(-1) == q.size(-1) else \ 72 | lgx.unsqueeze(1).expand(-1, self.num_heads, -1) 73 | if self.optimize_graph: 74 | e2 = lgx2.view(-1, self.num_heads, self.d_k) if lgx2.size(-1) == q.size(-1) else \ 75 | lgx2.unsqueeze(1).expand(-1, self.num_heads, -1) 76 | with g.local_scope(): 77 | g.ndata['q'], g.ndata['k'] = q.view(-1, self.num_heads, self.d_k), k.view(-1, self.num_heads, self.d_k) 78 | g.ndata['v'] = v.view(-1, self.num_heads, self.d_k) 79 | g.edata['e'] = e 80 | if self.optimize_graph: 81 | g.edata['e2'] = e2 82 | g.edata['weight'] = g.edata['weight'].unsqueeze(1).expand(-1, self.num_heads) 83 | g.edata['weight'] = g.edata['weight'].unsqueeze(2).expand(-1, -1, self.d_k) 84 | 85 | out_x = self.propagate_attention(g) 86 | 87 | out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k))) 88 | # out_x = x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)) 89 | out_x = self.ffn(out_x) 90 | # import ipdb; ipdb.set_trace() 91 | return out_x, lgx 92 | 93 | def propagate_attention(self, g): 94 | # Compute attention score 95 | 96 | if self.optimize_graph: 97 | g.apply_edges(src_sum_edge_mul_dst2('k', 'q', 'e', 'e2', 'weight', 'score')) 98 | else: 99 | g.apply_edges(src_sum_edge_mul_dst('k', 'q', 'e', 'score')) 100 | 101 | g.apply_edges(scaled_exp('score', math.sqrt(self.d_k))) 102 | # Update node state 103 | if self.optimize_graph: 104 | g.update_all(src_sum_edge_mul_edge2('v', 'e', 'e2', 'weight', 'score', 'v'), fn.sum('v', 'wv')) 105 | else: 106 | g.update_all(src_sum_edge_mul_edge('v', 'e', 'score', 'v'), fn.sum('v', 'wv')) 107 | g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o')) 108 | out_x = g.ndata['o'] 109 | return out_x 110 | -------------------------------------------------------------------------------- /model/model_constructor.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import torch 3 | import torch.nn as nn 4 | from model.model_utils import Registrable, PoolingFunction 5 | from model.encoder.graph_encoder import Text2SQLEncoder 6 | from model.decoder.sql_parser import SqlParser 7 | from utils.example import Example 8 | from model.encoder.graph_input import * 9 | from utils.constants import RELATIONS, RELATIONS_INDEX 10 | import copy 11 | from asdls.sql.sql_transition_system import SelectColumnAction, SelectTableAction 12 | 13 | 14 | class Text2SQL(nn.Module): 15 | def __init__(self, args, transition_system): 16 | super(Text2SQL, self).__init__() 17 | lazy_load = args.lazy_load if hasattr(args, 'lazy_load') else False 18 | self.input_layer = GraphInputLayerPLM(args, args.plm, args.gnn_hidden_size, dropout=args.dropout, 19 | subword_aggregation=args.subword_aggregation, 20 | schema_aggregation=args.schema_aggregation, lazy_load=lazy_load) 21 | self.encoder = Text2SQLEncoder(args) 22 | self.encoder2decoder = PoolingFunction(args.gnn_hidden_size, args.lstm_hidden_size, method='attentive-pooling') 23 | self.decoder = SqlParser(args, transition_system) 24 | self.args = args 25 | 26 | def forward(self, batch, cur_dataset, train=False, use_standard=False): 27 | 28 | graph_list = [ex.graph for ex in cur_dataset] 29 | new_graph_list = copy.deepcopy(graph_list) 30 | 31 | outputs, loss_token, edges, weight_m = self.input_layer(batch, use_standard) 32 | 33 | if self.args.optimize_graph: 34 | for 
n_b in range(len(graph_list)): 35 | new_graph_list[n_b].global_g = new_graph_list[n_b].global_g.to("cuda:0") 36 | new_graph_list[n_b].global_g.edata['weight'] = weight_m[n_b].reshape(-1) 37 | 38 | if self.args.semantic: 39 | new_graph_list = self.add_edge(new_graph_list, edges, cur_dataset) 40 | 41 | batch.graph = Example.graph_factory.batch_graphs(new_graph_list, "cuda:0", train=train) 42 | encodings, mask, gp_loss = self.encoder(batch, outputs) 43 | 44 | h0 = self.encoder2decoder(encodings, mask=mask) 45 | loss = self.decoder.score(encodings, mask, h0, batch) 46 | if self.args.token_task or self.args.schema_loss: 47 | loss = loss + loss_token 48 | return loss 49 | 50 | def get_exact_match_column_showed(self, cur_dataset, batch): 51 | 52 | column_part_match = RELATIONS.index('question-column-partialmatch') 53 | column_exact_match = RELATIONS.index('question-column-exactmatch') 54 | table_part_match = RELATIONS.index('question-table-partialmatch') 55 | table_exact_match = RELATIONS.index('question-table-exactmatch') 56 | column_value_match = RELATIONS.index('question-column-valuematch') 57 | columns = [] 58 | tables = [] 59 | for _index in range(len(cur_dataset)): 60 | data = cur_dataset[_index] 61 | keys = list(data.graph.local_edge_map.keys()) 62 | cur_columns = [] 63 | cur_tables = [] 64 | for index in range(len(data.graph.local_edges)): 65 | edge = data.graph.local_edges[index] 66 | if edge == column_exact_match: 67 | column = keys[index][1] - len(batch.questions[_index]) - len(batch.table_names[_index]) 68 | if column not in cur_columns: 69 | cur_columns.append(column) 70 | if edge == table_exact_match: 71 | table = keys[index][1] - len(batch.questions[_index]) 72 | if table not in cur_tables: 73 | cur_tables.append(table) 74 | columns.append(cur_columns) 75 | tables.append(cur_tables) 76 | return columns, tables 77 | 78 | def check_edge(self, cur_dataset): 79 | for _index in range(len(cur_dataset)): 80 | data = cur_dataset[_index] 81 | for index in range(len(data.global_edges)): 82 | if data.global_edges[index] == 21: 83 | print('already semantic') 84 | import ipdb; ipdb.set_trace() 85 | 86 | def add_edge(self, graph_data, edges, cur_dataset): 87 | 88 | question_column_semanticmatch = RELATIONS.index('question-column-semanticmatch') 89 | column_question_semanticmatch = RELATIONS.index('column-question-semanticmatch') 90 | question_table_semanticmatch = RELATIONS.index('question-table-semanticmatch') 91 | table_question_semanticmatch = RELATIONS.index('table-question-semanticmatch') 92 | 93 | 94 | column_edges, table_edges = edges 95 | for _index in range(len(graph_data)): 96 | data = graph_data[_index] 97 | column_dict = cur_dataset[_index].column_dict 98 | table_dict = cur_dataset[_index].table_dict 99 | for edge in column_edges[_index]: 100 | column, question = edge 101 | if column in column_dict: 102 | question = column_dict[column] 103 | index = data.global_edge_map[(column, question)] 104 | data.global_edges[index] = column_question_semanticmatch 105 | index = data.global_edge_map[(question, column)] 106 | data.global_edges[index] = question_column_semanticmatch 107 | if edge in data.local_edge_map: 108 | index = data.local_edge_map[(column, question)] 109 | data.local_edges[index] = column_question_semanticmatch 110 | index = data.local_edge_map[(question, column)] 111 | data.local_edges[index] = question_column_semanticmatch 112 | 113 | for edge in table_edges[_index]: 114 | table, question = edge 115 | if table in table_dict: 116 | question = table_dict[table] 117 | index = 
data.global_edge_map[(table, question)] 118 | data.global_edges[index] = table_question_semanticmatch 119 | index = data.global_edge_map[(question, table)] 120 | data.global_edges[index] = question_table_semanticmatch 121 | if edge in data.local_edge_map: 122 | index = data.local_edge_map[(table, question)] 123 | data.local_edges[index] = table_question_semanticmatch 124 | index = data.local_edge_map[(question, table)] 125 | data.local_edges[index] = question_table_semanticmatch 126 | 127 | return graph_data 128 | 129 | def parse(self, batch, beam_size, cur_dataset, use_standard=False): 130 | """ This function is used for decoding, which returns a batch of [DecodeHypothesis()] * beam_size 131 | """ 132 | graph_list = [ex.graph for ex in cur_dataset] 133 | new_graph_list = copy.deepcopy(graph_list) 134 | outputs, loss_token, edges, weight_m = self.input_layer(batch, use_standard, train=False) 135 | 136 | if self.args.optimize_graph: 137 | for n_b in range(len(graph_list)): 138 | new_graph_list[n_b].global_g = new_graph_list[n_b].global_g.to("cuda:0") 139 | new_graph_list[n_b].global_g.edata['weight'] = weight_m[n_b].reshape(-1) 140 | 141 | if self.args.semantic: 142 | new_graph_list = self.add_edge(new_graph_list, edges, cur_dataset) 143 | batch.graph = Example.graph_factory.batch_graphs(new_graph_list, "cuda:0", train=False) 144 | encodings, mask = self.encoder(batch, outputs) 145 | h0 = self.encoder2decoder(encodings, mask=mask) 146 | hyps = [] 147 | table_in_sql, column_in_sql = [], [] 148 | for i in range(len(batch)): 149 | """ 150 | table_mappings and column_mappings are used to map original database ids to local ids, 151 | while reverse_mappings perform the opposite function, mapping local ids to database ids 152 | """ 153 | table_in_sql_batch, column_in_sql_batch = [], [] 154 | hyps_item = self.decoder.parse(encodings[i:i+1], mask[i:i+1], h0[i:i+1], batch, beam_size) 155 | for action in hyps_item[0].actions: 156 | if isinstance(action, SelectColumnAction): 157 | if action.token not in column_in_sql_batch: 158 | column_in_sql_batch.append(action.token) 159 | if isinstance(action, SelectTableAction): 160 | if action.token not in table_in_sql_batch: 161 | table_in_sql_batch.append(action.token) 162 | hyps.append(hyps_item) 163 | table_in_sql.append(table_in_sql_batch) 164 | column_in_sql.append(column_in_sql_batch) 165 | 166 | return hyps, (batch.table_used, batch.column_used), (table_in_sql, column_in_sql) 167 | 168 | def pad_embedding_grad_zero(self, index=None): 169 | """ 170 | For glove.42B.300d word vectors, gradients for symbol is always 0; 171 | Most words (starting from index) in the word vocab are also fixed except most frequent words 172 | """ 173 | self.input_layer.pad_embedding_grad_zero(index) 174 | -------------------------------------------------------------------------------- /model/model_utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import copy, math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.utils.rnn as rnn_utils 6 | 7 | def clones(module, N): 8 | "Produce N identical layers." 
9 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 10 | 11 | def lens2mask(lens): 12 | bsize = lens.numel() 13 | max_len = lens.max() 14 | masks = torch.arange(0, max_len).type_as(lens).to(lens.device).repeat(bsize, 1).lt(lens.unsqueeze(1)) 15 | masks.requires_grad = False 16 | return masks 17 | 18 | def mask2matrix(mask): 19 | col_mask, row_mask = mask.unsqueeze(-1), mask.unsqueeze(-2) 20 | return col_mask & row_mask 21 | 22 | def tile(x, count, dim=0): 23 | """ 24 | Tiles x on dimension dim count times. 25 | E.g. [1, 2, 3], count=2 ==> [1, 1, 2, 2, 3, 3] 26 | [[1, 2], [3, 4]], count=3, dim=1 ==> [[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]] 27 | Different from torch.repeat 28 | """ 29 | if x is None: 30 | return x 31 | elif type(x) in [list, tuple]: 32 | return type(x)([tile(each, count, dim) for each in x]) 33 | else: 34 | perm = list(range(len(x.size()))) 35 | if dim != 0: 36 | perm[0], perm[dim] = perm[dim], perm[0] 37 | x = x.permute(perm).contiguous() 38 | out_size = list(x.size()) 39 | out_size[0] *= count 40 | batch = x.size(0) 41 | x = x.contiguous().view(batch, -1) \ 42 | .transpose(0, 1) \ 43 | .repeat(count, 1) \ 44 | .transpose(0, 1) \ 45 | .contiguous() \ 46 | .view(*out_size) 47 | if dim != 0: 48 | x = x.permute(perm).contiguous() 49 | return x 50 | 51 | def rnn_wrapper(encoder, inputs, lens, cell='lstm'): 52 | """ 53 | @args: 54 | encoder(nn.Module): rnn series bidirectional encoder, batch_first=True 55 | inputs(torch.FloatTensor): rnn inputs, [bsize x max_seq_len x in_dim] 56 | lens(torch.LongTensor): seq len for each sample, allow length=0, padding with 0-vector, [bsize] 57 | @return: 58 | out(torch.FloatTensor): output of encoder, bsize x max_seq_len x hidden_dim*2 59 | hidden_states([tuple of ]torch.FloatTensor): final hidden states, num_layers*2 x bsize x hidden_dim 60 | """ 61 | # rerank according to lens and remove empty inputs 62 | sorted_lens, sort_key = torch.sort(lens, descending=True) 63 | nonzero_num, total_num = torch.sum(sorted_lens > 0).item(), sorted_lens.size(0) 64 | sort_key = sort_key[:nonzero_num] 65 | sorted_inputs = torch.index_select(inputs, dim=0, index=sort_key) 66 | # forward non empty inputs 67 | packed_inputs = rnn_utils.pack_padded_sequence(sorted_inputs, sorted_lens[:nonzero_num].tolist(), batch_first=True) 68 | packed_out, sorted_h = encoder(packed_inputs) # bsize x srclen x dim 69 | sorted_out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True) 70 | if cell.upper() == 'LSTM': 71 | sorted_h, sorted_c = sorted_h 72 | # rerank according to sort_key 73 | out_shape = list(sorted_out.size()) 74 | out_shape[0] = total_num 75 | out = sorted_out.new_zeros(*out_shape).scatter_(0, sort_key.unsqueeze(-1).unsqueeze(-1).repeat(1, *out_shape[1:]), sorted_out) 76 | h_shape = list(sorted_h.size()) 77 | h_shape[1] = total_num 78 | h = sorted_h.new_zeros(*h_shape).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).repeat(h_shape[0], 1, h_shape[-1]), sorted_h) 79 | if cell.upper() == 'LSTM': 80 | c = sorted_c.new_zeros(*h_shape).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).repeat(h_shape[0], 1, h_shape[-1]), sorted_c) 81 | return out, (h.contiguous(), c.contiguous()) 82 | return out, h.contiguous() 83 | 84 | class MultiHeadAttention(nn.Module): 85 | 86 | def __init__(self, hidden_size, q_size, kv_size, output_size, num_heads=8, bias=True, feat_drop=0.2, attn_drop=0.0): 87 | super(MultiHeadAttention, self).__init__() 88 | self.num_heads = int(num_heads) 89 | self.hidden_size = hidden_size 90 | assert self.hidden_size % self.num_heads == 
0, 'Hidden size %d must be divisible by the number of heads %d' % (hidden_size, num_heads) 91 | self.d_k = self.hidden_size // self.num_heads 92 | self.feat_drop = nn.Dropout(p=feat_drop) 93 | self.attn_drop = nn.Dropout(p=attn_drop) 94 | self.W_q = nn.Linear(q_size, self.hidden_size, bias=bias) 95 | self.W_k = nn.Linear(kv_size, self.hidden_size, bias=False) 96 | self.W_v = nn.Linear(kv_size, self.hidden_size, bias=False) 97 | self.W_o = nn.Linear(self.hidden_size, output_size, bias=bias) 98 | 99 | def forward(self, hiddens, query_hiddens, mask=None): 100 | ''' @params: 101 | hiddens : encoded sequence representations, bsize x seqlen x hidden_size 102 | query_hiddens : bsize [x tgtlen] x hidden_size 103 | mask : length mask for hiddens, ByteTensor, bsize x seqlen 104 | @return: 105 | context : bsize [x tgtlen] x hidden_size 106 | ''' 107 | remove_flag = False 108 | if query_hiddens.dim() == 2: 109 | query_hiddens, remove_flag = query_hiddens.unsqueeze(1), True 110 | Q, K, V = self.W_q(self.feat_drop(query_hiddens)), self.W_k(self.feat_drop(hiddens)), self.W_v(self.feat_drop(hiddens)) 111 | Q, K, V = Q.reshape(-1, Q.size(1), 1, self.num_heads, self.d_k), K.reshape(-1, 1, K.size(1), self.num_heads, self.d_k), V.reshape(-1, 1, V.size(1), self.num_heads, self.d_k) 112 | e = (Q * K).sum(-1) / math.sqrt(self.d_k) # bsize x tgtlen x seqlen x num_heads 113 | if mask is not None: 114 | e = e + ((1 - mask.float()) * (-1e20)).unsqueeze(1).unsqueeze(-1) 115 | a = torch.softmax(e, dim=2) 116 | concat = (a.unsqueeze(-1) * V).sum(dim=2).reshape(-1, query_hiddens.size(1), self.hidden_size) 117 | context = self.W_o(concat) 118 | if remove_flag: 119 | return context.squeeze(dim=1), a.mean(dim=-1).squeeze(dim=1) 120 | else: 121 | return context, a.mean(dim=-1) 122 | 123 | class PoolingFunction(nn.Module): 124 | """ Map a sequence of hidden_size dim vectors into one fixed size vector with dimension output_size """ 125 | def __init__(self, hidden_size=256, output_size=256, bias=True, method='attentive-pooling'): 126 | super(PoolingFunction, self).__init__() 127 | assert method in ['mean-pooling', 'max-pooling', 'attentive-pooling'] 128 | self.method = method 129 | if self.method == 'attentive-pooling': 130 | self.attn = nn.Sequential( 131 | nn.Linear(hidden_size, hidden_size, bias=bias), 132 | nn.Tanh(), 133 | nn.Linear(hidden_size, 1, bias=bias) 134 | ) 135 | self.mapping_function = nn.Sequential(nn.Linear(hidden_size, output_size, bias=bias), nn.Tanh()) \ 136 | if hidden_size != output_size else lambda x: x 137 | 138 | def forward(self, inputs, mask=None): 139 | """ @args: 140 | inputs(torch.FloatTensor): features, batch_size x seq_len x hidden_size 141 | mask(torch.BoolTensor): mask for inputs, batch_size x seq_len 142 | @return: 143 | outputs(torch.FloatTensor): aggregate seq_len dim for inputs, batch_size x output_size 144 | """ 145 | if self.method == 'max-pooling': 146 | outputs = inputs.masked_fill(~ mask.unsqueeze(-1), -1e8) 147 | outputs = outputs.max(dim=1)[0] 148 | elif self.method == 'mean-pooling': 149 | mask_float = mask.float().unsqueeze(-1) 150 | outputs = (inputs * mask_float).sum(dim=1) / mask_float.sum(dim=1) 151 | elif self.method == 'attentive-pooling': 152 | e = self.attn(inputs).squeeze(-1) 153 | e = e + (1 - mask.float()) * (-1e20) 154 | a = torch.softmax(e, dim=1).unsqueeze(1) 155 | outputs = torch.bmm(a, inputs).squeeze(1) 156 | else: 157 | raise ValueError('[Error]: Unrecognized pooling method %s !' 
% (self.method)) 158 | outputs = self.mapping_function(outputs) 159 | return outputs 160 | 161 | class FFN(nn.Module): 162 | 163 | def __init__(self, input_size): 164 | super(FFN, self).__init__() 165 | self.input_size = input_size 166 | self.feedforward = nn.Sequential( 167 | nn.Linear(self.input_size, self.input_size * 4), 168 | nn.ReLU(inplace=True), 169 | nn.Linear(self.input_size * 4, self.input_size) 170 | ) 171 | self.layernorm = nn.LayerNorm(self.input_size) 172 | 173 | def forward(self, inputs): 174 | return self.layernorm(inputs + self.feedforward(inputs)) 175 | 176 | class Registrable(object): 177 | """ 178 | A class that collects all registered components, 179 | adapted from `common.registrable.Registrable` from AllenNLP 180 | """ 181 | registered_components = dict() 182 | 183 | @staticmethod 184 | def register(name): 185 | def register_class(cls): 186 | if name in Registrable.registered_components: 187 | raise RuntimeError('class %s already registered' % name) 188 | 189 | Registrable.registered_components[name] = cls 190 | return cls 191 | 192 | return register_class 193 | 194 | @staticmethod 195 | def by_name(name): 196 | return Registrable.registered_components[name] 197 | 198 | class cached_property(object): 199 | """ A property that is only computed once per instance and then replaces 200 | itself with an ordinary attribute. Deleting the attribute resets the 201 | property. 202 | 203 | Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 204 | """ 205 | 206 | def __init__(self, func): 207 | self.__doc__ = getattr(func, '__doc__') 208 | self.func = func 209 | 210 | def __get__(self, obj, cls): 211 | if obj is None: 212 | return self 213 | value = obj.__dict__[self.func.__name__] = self.func(obj) 214 | return value 215 | -------------------------------------------------------------------------------- /preprocess/graph_utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import math, dgl, torch 3 | import numpy as np 4 | from utils.constants import MAX_RELATIVE_DIST 5 | from utils.graph_example import GraphExample 6 | import time 7 | 8 | # mapping special column * as an ordinary column 9 | special_column_mapping_dict = { 10 | 'question-*-generic': 'question-column-nomatch', 11 | '*-question-generic': 'column-question-nomatch', 12 | 'table-*-generic': 'table-column-has', 13 | '*-table-generic': 'column-table-has', 14 | '*-column-generic': 'column-column-generic', 15 | 'column-*-generic': 'column-column-generic', 16 | '*-*-identity': 'column-column-identity' 17 | } 18 | # nonlocal_relations = [ 19 | # 'question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic', 20 | # 'table-table-fk', 'table-table-fkr', 'table-table-fkb', 'column-column-sametable', 21 | # '*-column-generic', 'column-*-generic', '*-*-identity', '*-table-generic', 22 | # 'question-question-identity', 'table-table-identity', 'column-column-identity'] + [ 23 | # 'question-question-dist' + str(i) for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1) if i not in [-1, 0, 1] 24 | # ] 25 | class GraphProcessor(): 26 | 27 | def process_rgatsql(self, ex: dict, db: dict, relation: list, relation_semantic: list): 28 | graph = GraphExample() 29 | 30 | num_nodes = int(math.sqrt(len(relation))) 31 | 32 | # local_edges = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r)) 33 | # for idx, r in 
enumerate(relation) if r not in nonlocal_relations] 34 | # local_edges2 = [(idx // num_nodes, idx % num_nodes, 35 | # (special_column_mapping_dict[r] if r in special_column_mapping_dict else r)) 36 | # for idx, r in enumerate(relation) if r not in nonlocal_relations] 37 | # local_edges_mask = [True if r not in nonlocal_relations else False 38 | # for idx, r in enumerate(relation)] 39 | # graph.local_edges_mask = local_edges_mask 40 | 41 | global_edges = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r)) for idx, r in enumerate(relation)] 42 | global_edges2 = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r)) for idx, r in enumerate(relation_semantic)] 43 | 44 | src_ids, dst_ids = list(map(lambda r: r[0], global_edges)), list(map(lambda r: r[1], global_edges)) 45 | graph.global_g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes, idtype=torch.int32) 46 | graph.global_edges = global_edges 47 | graph.global_edges2 = global_edges2 48 | # src_ids, dst_ids = list(map(lambda r: r[0], local_edges)), list(map(lambda r: r[1], local_edges)) 49 | # graph.local_g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes, idtype=torch.int32) 50 | # graph.local_edges = local_edges 51 | # graph.local_edges2 = local_edges2 52 | 53 | # graph pruning for nodes 54 | q_num = len(ex['processed_question_toks']) 55 | s_num = num_nodes - q_num 56 | graph.question_mask = [1] * q_num + [0] * s_num 57 | graph.schema_mask = [0] * q_num + [1] * s_num 58 | graph.gp = dgl.heterograph({ 59 | ('question', 'to', 'schema'): (list(range(q_num)) * s_num, 60 | [i for i in range(s_num) for _ in range(q_num)]) 61 | }, num_nodes_dict={'question': q_num, 'schema': s_num}, idtype=torch.int32 62 | ) 63 | t_num = len(db['processed_table_toks']) 64 | def check_node(i): 65 | if i < t_num and i in ex['used_tables']: 66 | return 1.0 67 | elif i >= t_num and i - t_num in ex['used_columns']: 68 | return 1.0 69 | else: return 0.0 70 | graph.node_label = list(map(check_node, range(s_num))) 71 | graph.schema_weight = ex['schema_weight'] 72 | ex['graph'] = graph 73 | 74 | return ex 75 | 76 | def process_graph_utils(self, ex: dict, db: dict): 77 | """ Example should be preprocessed by self.pipeline 78 | """ 79 | q = np.array(ex['relations'], dtype=' 100: continue 39 | if verbose: 40 | print('*************** Processing %d-th sample **************' % (idx)) 41 | entry = process_example(processor, entry, tables[entry['db_id']], trans, verbose=verbose) 42 | processed_dataset.append(entry) 43 | print('In total, process %d samples , skip %d extremely large databases.' 
% (len(processed_dataset), len(dataset) - len(processed_dataset))) 44 | if output_path is not None: 45 | # serialize preprocessed dataset 46 | pickle.dump(processed_dataset, open(output_path, 'wb')) 47 | return processed_dataset 48 | 49 | if __name__ == '__main__': 50 | 51 | arg_parser = argparse.ArgumentParser() 52 | arg_parser.add_argument('--db_dir', type=str, default='data/database') 53 | arg_parser.add_argument('--dataset_path', type=str, required=True, help='dataset path') 54 | arg_parser.add_argument('--raw_table_path', type=str, help='raw tables path') 55 | arg_parser.add_argument('--table_path', type=str, required=True, help='processed table path') 56 | arg_parser.add_argument('--output_path', type=str, required=True, help='output preprocessed dataset') 57 | arg_parser.add_argument('--skip_large', action='store_true', help='whether to skip large databases') 58 | arg_parser.add_argument('--verbose', action='store_true', help='whether to print processing information') 59 | arg_parser.add_argument('--toy', action='store_true', help='whether to generate a small dataset') 60 | arg_parser.add_argument('--semantic_graph', action='store_true', help='generate semantic edges') 61 | arg_parser.add_argument('--semantic_feature_size', type=int, default=1024, required=False) 62 | arg_parser.add_argument('--semantic_batch_size', type=int, default=8, required=False) 63 | arg_parser.add_argument('--semantic_threshold', type=float, default=0.7, required=False) 64 | arg_parser.add_argument('--semantic_pretrain_model', type=str, default='google/electra-large-discriminator', required=False) 65 | args = arg_parser.parse_args() 66 | 67 | processor = Preprocessor(args, db_dir=args.db_dir, db_content=True) 68 | # loading database and dataset 69 | if args.raw_table_path: 70 | # need to preprocess database items 71 | tables_list = json.load(open(args.raw_table_path, 'r')) 72 | print('Firstly, preprocess the original databases ...') 73 | start_time = time.time() 74 | tables = process_tables(processor, tables_list, args.table_path, args.verbose) 75 | print('Databases preprocessing costs %.4fs .' % (time.time() - start_time)) 76 | else: 77 | tables = pickle.load(open(args.table_path, 'rb')) 78 | dataset = json.load(open(args.dataset_path, 'r')) 79 | if args.toy: 80 | dataset = dataset[:10] 81 | start_time = time.time() 82 | dataset = process_dataset(processor, dataset, tables, args.output_path, args.skip_large, verbose=args.verbose) 83 | print('Dataset preprocessing costs %.4fs .' % (time.time() - start_time)) 84 | -------------------------------------------------------------------------------- /preprocess/process_graphs.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import os, json, pickle, argparse, sys, time 3 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 4 | from preprocess.graph_utils import GraphProcessor 5 | import time 6 | 7 | 8 | def process_dataset_graph(processor, dataset, tables, output_path=None, skip_large=False): 9 | processed_dataset = [] 10 | for idx, entry in enumerate(dataset): 11 | db = tables[entry['db_id']] 12 | if skip_large and len(db['column_names']) > 100: 13 | continue 14 | if (idx + 1) % 500 == 0: 15 | print('Processing the %d-th example ...' % (idx + 1)) 16 | entry = processor.process_graph_utils(entry, db) 17 | processed_dataset.append(entry) 18 | 19 | print('In total, process %d samples, skip %d samples .' 
% (len(processed_dataset), len(dataset) - len(processed_dataset))) 20 | if output_path is not None: 21 | pickle.dump(processed_dataset, open(output_path, 'wb')) 22 | return processed_dataset 23 | 24 | 25 | if __name__ == '__main__': 26 | 27 | arg_parser = argparse.ArgumentParser() 28 | arg_parser.add_argument('--dataset_path', type=str, required=True, help='dataset path') 29 | arg_parser.add_argument('--table_path', type=str, required=True, help='processed table path') 30 | arg_parser.add_argument('--output_path', type=str, required=True, help='output preprocessed dataset') 31 | args = arg_parser.parse_args() 32 | 33 | processor = GraphProcessor() 34 | # loading database and dataset 35 | tables = pickle.load(open(args.table_path, 'rb')) 36 | dataset = pickle.load(open(args.dataset_path, 'rb')) 37 | start_time = time.time() 38 | dataset = process_dataset_graph(processor, dataset, tables, args.output_path) 39 | print('Dataset preprocessing costs %.4fs .' % (time.time() - start_time)) 40 | -------------------------------------------------------------------------------- /scripts/text2sql.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import sys, os, time, json, gc, datetime, csv 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | from utils.args import init_args 5 | from utils.hyperparams import hyperparam_path 6 | from utils.initialization import * 7 | from utils.example import Example 8 | from utils.batch import Batch 9 | from utils.optimization import set_optimizer 10 | from model.model_utils import Registrable 11 | from model.model_constructor import Text2SQL 12 | from tensorboardX import SummaryWriter 13 | from utils.constants import RELATIONS, RELATIONS_INDEX 14 | from utils.logger import Logger 15 | from utils.evaluator import acc_token 16 | from tqdm import tqdm 17 | 18 | 19 | def decode(dataset, output_path, acc_type='sql', use_checker=False, use_standard=False): 20 | assert acc_type in ['beam', 'ast', 'sql'] 21 | model.eval() 22 | all_hyps = [] 23 | table_predicts, column_predicts, table_labels, column_labels, table_sqls, column_sqls = [], [], [], [], [], [] 24 | _correct, _total, _total1 = 0, 0, 0 25 | with torch.no_grad(): 26 | for i in range(0, len(dataset), args.batch_size): 27 | current_batch = Batch.from_example_list(dataset[i: i + args.batch_size], device, train=False) 28 | hyps, labels, sql_schemas = model.parse(current_batch, args.beam_size, dataset[i: i + args.batch_size], use_standard=use_standard) 29 | all_hyps.extend(hyps) 30 | acc = evaluator.acc(all_hyps, dataset, output_path, acc_type=acc_type, etype='match', use_checker=use_checker) 31 | return acc 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | args = init_args(sys.argv[1:]) 37 | set_random_seed(args.seed) 38 | logger = Logger(args.logdir, vars(args)) 39 | device = set_torch_device(args.device) 40 | logger.log("Initialization finished ...") 41 | logger.log("Random seed is set to %d" % (args.seed)) 42 | 43 | # load dataset and vocabulary 44 | start_time = time.time() 45 | 46 | Example.configuration(plm=args.plm, method=args.model) 47 | dev_dk_dataset = Example.load_dataset(args.dev_dk_path) if args.dev_dk_path != "" else None 48 | dev_syn_dataset = Example.load_dataset(args.dev_syn_path) if args.dev_syn_path != "" else None 49 | train_dataset, dev_dataset = Example.load_dataset(args.train_path), Example.load_dataset(args.dev_path) 50 | 51 | logger.log("Load dataset and database finished, cost %.4fs ..." 
% (time.time() - start_time)) 52 | logger.log("Dataset size: train -> %d ; dev -> %d" % (len(train_dataset), len(dev_dataset))) 53 | sql_trans, evaluator = Example.trans, Example.evaluator 54 | args.word_vocab, args.relation_num = len(Example.word_vocab), len(Example.relation_vocab) 55 | 56 | # model init, set optimizer 57 | model = Text2SQL(args, sql_trans).to(device) 58 | 59 | if args.read_model_path and args.read_model_path != 'none': 60 | check_point = torch.load(open(os.path.join(args.read_model_path, 'model.bin'), 'rb'), map_location=device) 61 | model.load_state_dict(check_point['model']) 62 | logger.log("Load saved model from path: %s" % (args.read_model_path)) 63 | else: 64 | json.dump(vars(args), open(os.path.join(args.logdir, 'params.json'), 'w'), indent=4) 65 | if args.plm is None: 66 | ratio = Example.word2vec.load_embeddings(model.encoder.input_layer.word_embed, Example.word_vocab, device=device) 67 | logger.log("Init model and word embedding layer with a coverage %.2f" % (ratio)) 68 | 69 | if args.training: 70 | num_training_steps = ((len(train_dataset) + args.batch_size - 1) // args.batch_size) * args.max_epoch 71 | num_warmup_steps = int(num_training_steps * args.warmup_ratio) 72 | logger.log('Total training steps: %d;\t Warmup steps: %d' % (num_training_steps, num_warmup_steps)) 73 | 74 | optimizer, scheduler = set_optimizer(model, args, num_warmup_steps, num_training_steps) 75 | start_epoch, nsamples, best_result = 0, len(train_dataset), {'dev_acc': 0.} 76 | train_index, step_size = np.arange(nsamples), args.batch_size // args.grad_accumulate 77 | 78 | logger.log('Start training ......') 79 | 80 | for i in range(start_epoch, args.max_epoch): 81 | start_time = time.time() 82 | epoch_loss, epoch_gp_loss, count = 0, 0, 0 83 | np.random.shuffle(train_index) 84 | model.train() 85 | table_predicts, column_predicts, table_labels, column_labels = [], [], [], [] 86 | for j in range(0, nsamples, step_size): 87 | count += 1 88 | cur_dataset = [train_dataset[k] for k in train_index[j: j + step_size]] 89 | current_batch = Batch.from_example_list(cur_dataset, device, train=True, smoothing=args.smoothing) 90 | loss = model(current_batch, cur_dataset, train=True, use_standard=args.use_standard) # see utils/batch.py for batch elements 91 | 92 | epoch_loss += loss.item() 93 | loss.backward() 94 | if j % 1000 == 0: 95 | logger.write_metrics({'step': j, 'loss': loss.item()},'step') 96 | 97 | if count == args.grad_accumulate or j + step_size >= nsamples: 98 | count = 0 99 | model.pad_embedding_grad_zero() 100 | optimizer.step() 101 | scheduler.step() 102 | optimizer.zero_grad() 103 | 104 | if i < args.eval_after_epoch: 105 | continue 106 | 107 | start_time = time.time() 108 | dev_acc = decode(dev_dataset, os.path.join(args.logdir, 'dev.iter' + str(i)), acc_type='sql') 109 | logger.log('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f' % (i, time.time() - start_time, dev_acc)) 110 | 111 | if args.dev_syn_path != "": 112 | start_time = time.time() 113 | dev_syn_acc = decode(dev_syn_dataset, os.path.join(args.logdir, 'dev_syn.iter' + str(i)), acc_type='sql') 114 | logger.log('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f' % (i, time.time() - start_time, dev_syn_acc)) 115 | 116 | if args.dev_dk_path != "": 117 | start_time = time.time() 118 | dev_dk_acc = decode(dev_dk_dataset, os.path.join(args.logdir, 'dev_dk.iter' + str(i)), acc_type='sql') 119 | logger.log('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f' % (i, time.time() - start_time, dev_dk_acc)) 120 | 121 | if dev_acc > 
best_result['dev_acc']: 122 | best_result['dev_acc'], best_result['iter'] = dev_acc, i 123 | torch.save({ 124 | 'epoch': i, 'model': model.state_dict(), 125 | 'optim': optimizer.state_dict(), 126 | 'scheduler': scheduler.state_dict() 127 | }, open(os.path.join(args.logdir, 'model.bin'), 'wb')) 128 | logger.log('NEW BEST MODEL: \tEpoch: %d\tDev acc: %.4f' % (i, dev_acc)) 129 | 130 | logger.log('FINAL BEST RESULT: \tEpoch: %d\tDev acc: %.4f' % (best_result['iter'], best_result['dev_acc'])) 131 | 132 | if args.testing: 133 | start_time = time.time() 134 | dev_acc = decode(dev_dataset, output_path=os.path.join(args.logdir, 'dev.eval'), acc_type='sql') 135 | print(dev_acc) 136 | dev_acc_checker = decode(dev_dataset, output_path=os.path.join(args.logdir, 'dev.eval.checker'), acc_type='sql', use_checker=True, use_standard=False) 137 | print(dev_acc_checker) 138 | logger.log("Evaluation costs %.2fs ; Dev dataset exact match/checker/beam acc is %.4f/%.4f ." % (time.time() - start_time, dev_acc, dev_acc_checker)) 139 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import argparse 3 | import sys 4 | 5 | 6 | def init_args(params=sys.argv[1:]): 7 | arg_parser = argparse.ArgumentParser() 8 | arg_parser = add_argument_base(arg_parser) 9 | arg_parser = add_argument_encoder(arg_parser) 10 | arg_parser = add_argument_decoder(arg_parser) 11 | opt = arg_parser.parse_args(params) 12 | # if opt.model == 'rgatsql' and opt.local_and_nonlocal == 'msde': 13 | # opt.local_and_nonlocal = 'global' 14 | # if opt.model == 'lgesql' and opt.local_and_nonlocal == 'global': 15 | # opt.local_and_nonlocal = 'msde' 16 | return opt 17 | 18 | 19 | def add_argument_base(arg_parser): 20 | #### General configuration #### 21 | arg_parser.add_argument('--task', default='text2sql', help='task name') 22 | arg_parser.add_argument('--seed', default=999, type=int, help='Random seed') 23 | arg_parser.add_argument('--device', type=int, default=0, help='Use which device: -1 -> cpu ; the index of gpu o.w.') 24 | 25 | arg_parser.add_argument('--training', action='store_true', help='training or evaluation mode') 26 | arg_parser.add_argument('--testing', action='store_true', help='training or evaluation mode') 27 | arg_parser.add_argument('--read_model_path', type=str, help='read pretrained model path') 28 | #arg_parser.add_argument('--read_part_model_path', type=str, help='read pretrained model path') 29 | #### Training Hyperparams #### 30 | arg_parser.add_argument('--batch_size', default=20, type=int, help='Batch size') 31 | arg_parser.add_argument('--grad_accumulate', default=1, type=int, help='accumulate grad and update once every x steps') 32 | arg_parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 33 | arg_parser.add_argument('--layerwise_decay', type=float, default=1.0, help='layerwise decay rate for lr, used for PLM') 34 | arg_parser.add_argument('--l2', type=float, default=1e-4, help='weight decay coefficient') 35 | arg_parser.add_argument('--warmup_ratio', type=float, default=0.1, help='warmup steps proportion') 36 | arg_parser.add_argument('--lr_schedule', default='linear', choices=['constant', 'linear', 'ratsql', 'cosine'], help='lr scheduler') 37 | arg_parser.add_argument('--eval_after_epoch', default=40, type=int, help='Start to evaluate after x epoch') 38 | arg_parser.add_argument('--load_optimizer', action='store_true', default=False, help='Whether to 
load optimizer state') 39 | 40 | arg_parser.add_argument('--semantic', action='store_true', default=False, help='Whether to use semantic graph') 41 | arg_parser.add_argument('--use_standard', action='store_true', default=False) 42 | arg_parser.add_argument('--token_task', action='store_true', default=False, help='Whether to use token classification task loss') 43 | arg_parser.add_argument('--pruning_edge', action='store_true', default=False, help='Whether to pruning edge in standard schema') 44 | arg_parser.add_argument('--noise_edge', action='store_true', default=False, help='Whether to add noise edge in standard schema') 45 | arg_parser.add_argument('--schema_loss', action='store_true', default=False, help='Whether to calculate schema loss') 46 | arg_parser.add_argument('--filter_edge', action='store_true', default=False, help='whether to filter edge') 47 | 48 | arg_parser.add_argument('--optimize_graph', action='store_true', default=False, help='optimize graph during GNN') 49 | arg_parser.add_argument('--get_info_from_gold', action='store_true', default=False, help='get dict from gold align data') 50 | arg_parser.add_argument('--not_use_weight_transform', action='store_true', default=False, help='use weight metric to transform representation') 51 | arg_parser.add_argument('--random_question_edge', action='store_true', default=False, help='whether to use random question edge') 52 | arg_parser.add_argument('--filter_gold_edge', action='store_true', default=False, help='whether to set weight of edge to one') 53 | 54 | arg_parser.add_argument('--max_epoch', type=int, default=100, help='terminate after maximum epochs') 55 | arg_parser.add_argument('--max_norm', default=5., type=float, help='clip gradients') 56 | arg_parser.add_argument('--dynamic_rate', type=float, default=0.8) 57 | 58 | arg_parser.add_argument('--logdir', default="", type=str, help='location to save and log') 59 | 60 | arg_parser.add_argument('--train_path', default="train", type=str, help='location of the training data') 61 | arg_parser.add_argument('--dev_path', default="dev", type=str, help='location of the development data') 62 | arg_parser.add_argument('--dev_dk_path', default="", type=str, help='location of the development data') 63 | arg_parser.add_argument('--dev_syn_path', default="", type=str, help='location of the development data') 64 | return arg_parser 65 | 66 | def add_argument_encoder(arg_parser): 67 | # Encoder Hyperparams 68 | arg_parser.add_argument('--model', choices=['rgatsql', 'lgesql'], default='lgesql', help='which text2sql model to use') 69 | #arg_parser.add_argument('--local_and_nonlocal', choices=['mmc', 'msde', 'local', 'global'], default='mmc', 70 | # help='how to integrate local and non-local relations: mmc -> multi-head multi-view concatenation ; msde -> mixed static and dynamic embeddings') 71 | arg_parser.add_argument('--output_model', choices=['without_pruning', 'with_pruning'], default='without_pruning', help='whether add graph pruning') 72 | arg_parser.add_argument('--plm', type=str, help='pretrained model name') 73 | arg_parser.add_argument('--subword_aggregation', choices=['mean-pooling', 'max-pooling', 'attentive-pooling'], default='attentive-pooling', help='aggregate subword feats from PLM') 74 | arg_parser.add_argument('--schema_aggregation', choices=['mean-pooling', 'max-pooling', 'attentive-pooling', 'head+tail'], default='head+tail', help='aggregate schema words feats') 75 | arg_parser.add_argument('--dropout', type=float, default=0.2, help='feature dropout rate') 76 | 
arg_parser.add_argument('--attn_drop', type=float, default=0., help='dropout rate of attention weights') 77 | arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings, only used in glove.42B.300d') 78 | arg_parser.add_argument('--gnn_num_layers', default=8, type=int, help='num of GNN layers in encoder') 79 | arg_parser.add_argument('--gnn_hidden_size', default=256, type=int, help='size of GNN layers hidden states') 80 | arg_parser.add_argument('--num_heads', default=8, type=int, help='num of heads in multihead attn') 81 | arg_parser.add_argument('--relation_share_layers', action='store_true') 82 | arg_parser.add_argument('--relation_share_heads', action='store_true') 83 | arg_parser.add_argument('--score_function', choices=['affine', 'bilinear', 'biaffine', 'dot'], default='affine', help='graph pruning score function') 84 | arg_parser.add_argument('--smoothing', type=float, default=0.15, help='label smoothing factor for graph pruning') 85 | return arg_parser 86 | 87 | def add_argument_decoder(arg_parser): 88 | # Decoder Hyperparams 89 | arg_parser.add_argument('--lstm', choices=['lstm', 'onlstm'], default='onlstm', help='Type of LSTM used, ONLSTM or traditional LSTM') 90 | arg_parser.add_argument('--chunk_size', default=8, type=int, help='parameter of ONLSTM') 91 | arg_parser.add_argument('--att_vec_size', default=512, type=int, help='size of attentional vector') 92 | arg_parser.add_argument('--sep_cxt', action='store_true', help='when calculating context vectors, use seperate cxt for question and schema') 93 | arg_parser.add_argument('--drop_connect', type=float, default=0.2, help='recurrent connection dropout rate in decoder lstm') 94 | arg_parser.add_argument('--lstm_num_layers', type=int, default=1, help='num_layers of decoder') 95 | arg_parser.add_argument('--lstm_hidden_size', default=512, type=int, help='Size of LSTM hidden states') 96 | arg_parser.add_argument('--action_embed_size', default=128, type=int, help='Size of ApplyRule/GenToken action embeddings') 97 | arg_parser.add_argument('--field_embed_size', default=64, type=int, help='Embedding size of ASDL fields') 98 | arg_parser.add_argument('--type_embed_size', default=64, type=int, help='Embeddings ASDL types') 99 | arg_parser.add_argument('--no_context_feeding', action='store_true', default=False, 100 | help='Do not use embedding of context vectors') 101 | arg_parser.add_argument('--no_parent_production_embed', default=False, action='store_true', 102 | help='Do not use embedding of parent ASDL production to update decoder LSTM state') 103 | arg_parser.add_argument('--no_parent_field_embed', default=False, action='store_true', 104 | help='Do not use embedding of parent field to update decoder LSTM state') 105 | arg_parser.add_argument('--no_parent_field_type_embed', default=False, action='store_true', 106 | help='Do not use embedding of the ASDL type of parent field to update decoder LSTM state') 107 | arg_parser.add_argument('--no_parent_state', default=False, action='store_true', 108 | help='Do not use the parent hidden state to update decoder LSTM state') 109 | arg_parser.add_argument('--beam_size', default=5, type=int, help='Beam size for beam search') 110 | arg_parser.add_argument('--decode_max_step', default=100, type=int, help='Maximum number of time steps used in decoding') 111 | return arg_parser 112 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 
1 | PAD = '[PAD]' 2 | BOS = '[CLS]' 3 | EOS = '[SEP]' 4 | UNK = '[UNK]' 5 | 6 | GRAMMAR_FILEPATH = 'asdls/sql/grammar/sql_asdl_v2.txt' 7 | SCHEMA_TYPES = ['table', 'others', 'text', 'time', 'number', 'boolean'] 8 | MAX_RELATIVE_DIST = 2 9 | # relations: type_1-type_2-rel_name, r represents reverse edge, b represents bidirectional edge 10 | # RELATIONS = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity' for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1)] + \ 11 | # ['table-table-identity', 'table-table-fk', 'table-table-fkr', 'table-table-fkb'] + \ 12 | # ['column-column-identity', 'column-column-sametable', 'column-column-fk', 'column-column-fkr'] + \ 13 | # ['table-column-pk', 'column-table-pk', 'table-column-has', 'column-table-has'] + \ 14 | # ['question-column-exactmatch', 'question-column-partialmatch', 'question-column-nomatch', 'question-column-valuematch', 15 | # 'column-question-exactmatch', 'column-question-partialmatch', 'column-question-nomatch', 'column-question-valuematch'] + \ 16 | # ['question-table-exactmatch', 'question-table-partialmatch', 'question-table-nomatch', 17 | # 'table-question-exactmatch', 'table-question-partialmatch', 'table-question-nomatch'] + \ 18 | # ['question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic'] + \ 19 | # ['*-*-identity', '*-question-generic', 'question-*-generic', '*-table-generic', 'table-*-generic', '*-column-generic', 'column-*-generic'] 20 | 21 | RELATIONS = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity' for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1)] + \ 22 | ['table-table-identity', 'table-table-fk', 'table-table-fkr', 'table-table-fkb'] + \ 23 | ['column-column-identity', 'column-column-sametable', 'column-column-fk', 'column-column-fkr'] + \ 24 | ['table-column-pk', 'column-table-pk', 'table-column-has', 'column-table-has'] + \ 25 | ['question-column-exactmatch', 'question-column-partialmatch','question-column-partialsemanticmatch', 'question-column-nomatch','question-column-semanticmatch', 'question-column-valuematch', 26 | 'column-question-exactmatch', 'column-question-partialmatch','column-question-partialsemanticmatch', 'column-question-nomatch','column-question-semanticmatch', 'column-question-valuematch'] + \ 27 | ['question-table-exactmatch', 'question-table-partialmatch','question-table-partialsemanticmatch', 'question-table-nomatch','question-table-semanticmatch', 28 | 'table-question-exactmatch', 'table-question-partialmatch', 'table-question-partialsemanticmatch','table-question-semanticmatch','table-question-nomatch'] + \ 29 | ['question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic'] + \ 30 | ['*-*-identity', '*-question-generic', 'question-*-generic', '*-table-generic', 'table-*-generic', '*-column-generic', 'column-*-generic'] 31 | 32 | RELATIONS_INDEX = {} 33 | index = 0 34 | for item in RELATIONS: 35 | RELATIONS_INDEX[item] = index 36 | index += 1 37 | # import ipdb; ipdb.set_trace() 38 | -------------------------------------------------------------------------------- /utils/example.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import os, pickle, json 3 | import torch, random 4 | import numpy as np 5 | from asdls.asdl import ASDLGrammar 6 | from asdls.transition_system import TransitionSystem 7 | from utils.constants import UNK, 
GRAMMAR_FILEPATH, SCHEMA_TYPES, RELATIONS 8 | from utils.graph_example import GraphFactory 9 | from utils.vocab import Vocab 10 | from utils.word2vec import Word2vecUtils 11 | from transformers import AutoTokenizer 12 | from utils.evaluator import Evaluator 13 | from itertools import chain 14 | 15 | class Example(): 16 | 17 | @classmethod 18 | def configuration(cls, plm=None, method='lgesql', table_path='data/new_tables.json', tables='data/tables.bin', db_dir='data/database'): 19 | cls.plm, cls.method = plm, method 20 | cls.grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH) 21 | cls.trans = TransitionSystem.get_class_by_lang('sql')(cls.grammar) 22 | cls.tables = pickle.load(open(tables, 'rb')) if type(tables) == str else tables 23 | cls.evaluator = Evaluator(cls.trans, table_path, db_dir) 24 | cls.tokenizer = AutoTokenizer.from_pretrained(plm) 25 | cls.word_vocab = cls.tokenizer.get_vocab() 26 | cls.relation_vocab = Vocab(padding=False, unk=False, boundary=False, iterable=RELATIONS, default=None) 27 | cls.graph_factory = GraphFactory(cls.method, cls.relation_vocab) 28 | 29 | @classmethod 30 | def load_dataset(cls, choice, debug=False): 31 | # assert choice in ['train', 'dev'] 32 | print("loader ", choice) 33 | fp = os.path.join('data', choice + '.' + cls.method + '.bin') 34 | datasets = pickle.load(open(fp, 'rb')) 35 | # question_lens = [len(ex['processed_question_toks']) for ex in datasets] 36 | # print('Max/Min/Avg question length in %s dataset is: %d/%d/%.2f' % (choice, max(question_lens), min(question_lens), float(sum(question_lens))/len(question_lens))) 37 | # action_lens = [len(ex['actions']) for ex in datasets] 38 | # print('Max/Min/Avg action length in %s dataset is: %d/%d/%.2f' % (choice, max(action_lens), min(action_lens), float(sum(action_lens))/len(action_lens))) 39 | examples, outliers = [], 0 40 | for ex in datasets: 41 | if ex['db_id'] == 'new_concert_singer': 42 | ex['db_id'] = 'concert_singer' 43 | if choice == 'train' and len(cls.tables[ex['db_id']]['column_names']) > 100: 44 | outliers += 1 45 | continue 46 | examples.append(cls(ex, cls.tables[ex['db_id']])) 47 | if debug and len(examples) >= 100: 48 | return examples 49 | if choice == 'train': 50 | print("Skip %d extremely large samples in training dataset ..." 
% (outliers)) 51 | return examples 52 | 53 | def __init__(self, ex: dict, db: dict): 54 | super(Example, self).__init__() 55 | self.ex = ex 56 | self.db = db 57 | 58 | """ Mapping word to corresponding index """ 59 | t = Example.tokenizer 60 | self.question = [q.lower() for q in ex['raw_question_toks']] 61 | self.question_id = [t.cls_token_id] # map token to id 62 | self.question_mask_plm = [] # remove SEP token in our case 63 | self.question_subword_len = [] # subword len for each word, exclude SEP token 64 | self.question_total_len = 0 65 | for w in self.question: 66 | toks = t.convert_tokens_to_ids(t.tokenize(w)) 67 | self.question_id.extend(toks) 68 | self.question_subword_len.append(len(toks)) 69 | # self.question_total_len += len(w) 70 | self.question_total_len = len(self.question) 71 | self.question_mask_plm = [0] + [1] * (len(self.question_id) - 1) + [0] 72 | self.question_id.append(t.sep_token_id) 73 | 74 | self.table = [['table'] + t.lower().split() for t in db['table_names']] 75 | ### add label setup for tables 76 | self.schema_labels = [] 77 | self.table_labels = [] 78 | self.table_begin_end = [] 79 | for table_index in range(len(self.table)): 80 | label = 1 if table_index in ex['used_tables'] else 0 81 | self.schema_labels.append(label) 82 | begin = len(self.table_labels) + 1 83 | self.table_labels.extend([label for _ in range(len(self.table[table_index]))]) 84 | end = len(self.table_labels) - 1 85 | self.table_begin_end.append((begin, end)) 86 | 87 | self.table_id, self.table_mask_plm, self.table_subword_len = [], [], [] 88 | self.table_word_len = [] 89 | self.table_total_len = 0 90 | for s in self.table: 91 | l = 0 92 | for w in s: 93 | toks = t.convert_tokens_to_ids(t.tokenize(w)) 94 | self.table_id.extend(toks) 95 | self.table_subword_len.append(len(toks)) 96 | l += len(toks) 97 | self.table_word_len.append(l) 98 | self.table_total_len += len(s) 99 | self.table_mask_plm = [1] * len(self.table_id) 100 | 101 | self.column = [[db['column_types'][idx].lower()] + c.lower().split() for idx, (_, c) in enumerate(db['column_names'])] 102 | ### add label setup for columns 103 | self.column_labels = [] 104 | self.column_begin_end = [] 105 | for column_index in range(len(self.column)): 106 | if column_index in ex['used_columns']: 107 | label = 1 108 | else: 109 | label = 0 110 | self.schema_labels.append(label) 111 | begin = len(self.column_labels) + 1 112 | self.column_labels.extend([label for _ in range(len(self.column[column_index]))]) 113 | end = len(self.column_labels) - 1 114 | self.column_begin_end.append((begin, end)) 115 | 116 | self.column_id, self.column_mask_plm, self.column_subword_len = [], [], [] 117 | self.column_word_len = [] 118 | self.column_total_len = 0 119 | for s in self.column: 120 | l = 0 121 | for w in s: 122 | toks = t.convert_tokens_to_ids(t.tokenize(w)) 123 | self.column_id.extend(toks) 124 | self.column_subword_len.append(len(toks)) 125 | l += len(toks) 126 | self.column_word_len.append(l) 127 | self.column_total_len += len(s) 128 | self.column_mask_plm = [1] * len(self.column_id) + [0] 129 | self.column_id.append(t.sep_token_id) 130 | 131 | self.input_id = self.question_id + self.table_id + self.column_id 132 | self.segment_id = [0] * len(self.question_id) + [1] * (len(self.table_id) + len(self.column_id)) \ 133 | if Example.plm != 'grappa_large_jnt' and not Example.plm.startswith('roberta') \ 134 | else [0] * (len(self.question_id) + len(self.table_id) + len(self.column_id)) 135 | 136 | self.question_mask_plm = self.question_mask_plm + [0] * (len(self.table_id) + len(self.column_id)) 
137 | self.table_mask_plm = [0] * len(self.question_id) + self.table_mask_plm + [0] * len(self.column_id) 138 | self.column_mask_plm = [0] * (len(self.question_id) + len(self.table_id)) + self.column_mask_plm 139 | 140 | self.graph = Example.graph_factory.graph_construction(ex, db) 141 | 142 | # outputs 143 | if 'schema_weight' in ex['graph'].__dict__.keys(): 144 | self.schema_weight = ex['graph'].schema_weight 145 | self.query = ' '.join(ex['query'].split('\t')) 146 | self.ast = ex['ast'] 147 | self.tgt_action = ex['actions'] 148 | self.used_tables, self.used_columns = ex['used_tables'], ex['used_columns'] 149 | self.column_dict, self.table_dict = {}, {} 150 | 151 | 152 | def get_position_ids(ex, shuffle=True): 153 | # cluster columns with their corresponding table and randomly shuffle tables and columns 154 | # [CLS] q1 q2 ... [SEP] * t1 c1 c2 c3 t2 c4 c5 ... [SEP] 155 | db, table_word_len, column_word_len = ex.db, ex.table_word_len, ex.column_word_len 156 | table_num, column_num = len(db['table_names']), len(db['column_names']) 157 | question_position_id = list(range(len(ex.question_id))) 158 | start = len(question_position_id) 159 | table_position_id, column_position_id = [None] * table_num, [None] * column_num 160 | column_position_id[0] = list(range(start, start + column_word_len[0])) 161 | start += column_word_len[0] # special symbol * first 162 | table_idxs = list(range(table_num)) 163 | if shuffle: 164 | random.shuffle(table_idxs) 165 | for idx in table_idxs: 166 | col_idxs = db['table2columns'][idx] 167 | table_position_id[idx] = list(range(start, start + table_word_len[idx])) 168 | start += table_word_len[idx] 169 | if shuffle: 170 | random.shuffle(col_idxs) 171 | for col_id in col_idxs: 172 | column_position_id[col_id] = list(range(start, start + column_word_len[col_id])) 173 | start += column_word_len[col_id] 174 | position_id = question_position_id + list(chain.from_iterable(table_position_id)) + \ 175 | list(chain.from_iterable(column_position_id)) + [start] 176 | assert len(position_id) == len(ex.input_id) 177 | return position_id 178 | -------------------------------------------------------------------------------- /utils/graph_example.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import numpy as np 3 | import dgl, torch, math 4 | 5 | class GraphExample(): 6 | 7 | pass 8 | 9 | class BatchedGraph(): 10 | 11 | pass 12 | 13 | class GraphFactory(): 14 | 15 | def __init__(self, method='rgatsql', relation_vocab=None): 16 | super(GraphFactory, self).__init__() 17 | self.method = eval('self.' 
+ method) 18 | self.batch_method = eval('self.batch_' + method) 19 | self.relation_vocab = relation_vocab 20 | 21 | def graph_construction(self, ex: dict, db: dict): 22 | return self.method(ex, db) 23 | 24 | def rgatsql(self, ex, db): 25 | graph = GraphExample() 26 | 27 | # local_edges = ex['graph'].local_edges 28 | # rel_ids = list(map(lambda r: self.relation_vocab[r[2]], local_edges)) 29 | # graph.local_edges = torch.tensor(rel_ids, dtype=torch.long) 30 | 31 | global_edges = ex['graph'].global_edges 32 | rel_ids = list(map(lambda r: self.relation_vocab[r[2]], global_edges)) 33 | graph.global_edges = torch.tensor(rel_ids, dtype=torch.long) 34 | 35 | # if 'local_edges2' in ex['graph'].__dict__.keys(): 36 | # local_edges2 = ex['graph'].local_edges2 37 | # rel_ids = list(map(lambda r: self.relation_vocab[r[2]], local_edges2)) 38 | # graph.local_edges2 = torch.tensor(rel_ids, dtype=torch.long) 39 | 40 | if 'global_edges2' in ex['graph'].__dict__.keys(): 41 | global_edges2 = ex['graph'].global_edges2 42 | rel_ids = list(map(lambda r: self.relation_vocab[r[2]], global_edges2)) 43 | graph.global_edges2 = torch.tensor(rel_ids, dtype=torch.long) 44 | 45 | # local_edge_map = {} 46 | # for index in range(len(local_edges)): 47 | # edge = local_edges[index] 48 | # local_edge_map[(edge[0], edge[1])] = index 49 | global_edge_map = {} 50 | for index in range(len(global_edges)): 51 | edge = global_edges[index] 52 | global_edge_map[(edge[0], edge[1])] = index 53 | graph.global_edge_map = global_edge_map 54 | # graph.local_edge_map = local_edge_map 55 | 56 | graph.global_g = ex['graph'].global_g 57 | # graph.local_g, graph.global_g = ex['graph'].local_g, ex['graph'].global_g 58 | graph.gp = ex['graph'].gp 59 | graph.question_mask = torch.tensor(ex['graph'].question_mask, dtype=torch.bool) 60 | graph.schema_mask = torch.tensor(ex['graph'].schema_mask, dtype=torch.bool) 61 | graph.node_label = torch.tensor(ex['graph'].node_label, dtype=torch.float) 62 | # extract local relations (used in msde), global_edges = local_edges + nonlocal_edges 63 | global_enum = graph.global_edges.size(0) 64 | # local_enum, global_enum = graph.local_edges.size(0), graph.global_edges.size(0) 65 | # graph.local_mask = torch.tensor([1] * local_enum + [0] * (global_enum - local_enum), dtype=torch.bool) 66 | # if 'local_edges_mask' in ex['graph'].__dict__.keys(): 67 | # graph.local_edges_mask = ex['graph'].local_edges_mask 68 | return graph 69 | 70 | def batch_graphs(self, ex_list, device, train=True, **kwargs): 71 | """ Batch graphs in example list """ 72 | return self.batch_method(ex_list, device, train=train, **kwargs) 73 | 74 | def batch_rgatsql(self, ex_list, device, train=True, **kwargs): 75 | # method = kwargs.pop('local_and_nonlocal', 'global') 76 | graph_list = ex_list 77 | bg = BatchedGraph() 78 | # bg.local_g = dgl.batch([ex.local_g for ex in graph_list]).to(device) 79 | # bg.local_edges = torch.cat([ex.local_edges for ex in graph_list], dim=0).to(device) 80 | # bg.local_mask = torch.cat([ex.local_mask for ex in ex_list], dim=0).to(device) 81 | bg.global_g = dgl.batch([ex.global_g for ex in graph_list]).to(device) 82 | bg.global_edges = torch.cat([ex.global_edges for ex in graph_list], dim=0).to(device) 83 | 84 | if 'global_edges2' in graph_list[0].__dict__.keys(): 85 | bg.global_edges2 = torch.cat([ex.global_edges2 for ex in graph_list], dim=0).to(device) 86 | 87 | if train: 88 | bg.question_mask = torch.cat([ex.question_mask for ex in graph_list], dim=0).to(device) 89 | bg.schema_mask = torch.cat([ex.schema_mask for ex in 
graph_list], dim=0).to(device) 90 | smoothing = kwargs.pop('smoothing', 0.0) 91 | node_label = torch.cat([ex.node_label for ex in graph_list], dim=0) 92 | node_label = node_label.masked_fill_(~ node_label.bool(), 2 * smoothing) - smoothing 93 | bg.node_label = node_label.to(device) 94 | bg.gp = dgl.batch([ex.gp for ex in graph_list]).to(device) 95 | return bg 96 | -------------------------------------------------------------------------------- /utils/hyperparams.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import sys, os 3 | 4 | EXP_PATH = 'exp' 5 | 6 | def hyperparam_path(args): 7 | if args.read_model_path and args.testing: 8 | return args.read_model_path 9 | exp_path = hyperparam_path_text2sql(args) 10 | if not os.path.exists(exp_path): 11 | os.makedirs(exp_path) 12 | return exp_path 13 | 14 | def hyperparam_path_text2sql(args): 15 | task = 'task_%s__model_%s_view_%s' % (args.task, args.model, 'global') 16 | task += '' if 'without' in args.output_model else '_gp_%s' % (args.smoothing) 17 | # encoder params 18 | exp_path = 'emb_%s' % (args.embed_size) if args.plm is None else 'plm_%s' % (args.plm) 19 | exp_path += '__gnn_%s_x_%s' % (args.gnn_hidden_size, args.gnn_num_layers) 20 | exp_path += '__share' if args.relation_share_layers else '' 21 | exp_path += '__head_%s' % (args.num_heads) 22 | exp_path += '__share' if args.relation_share_heads else '' 23 | exp_path += '__dp_%s' % (args.dropout) 24 | exp_path += '__dpa_%s' % (args.attn_drop) 25 | exp_path += '__dpc_%s' % (args.drop_connect) 26 | # decoder params 27 | # exp_path += '__cell_%s_%s_x_%s' % (args.lstm, args.lstm_hidden_size, args.lstm_num_layers) 28 | # exp_path += '_chunk_%s' % (args.chunk_size) if args.lstm == 'onlstm' else '' 29 | # exp_path += '_no' if args.no_parent_state else '' 30 | # exp_path += '__attvec_%s' % (args.att_vec_size) 31 | # exp_path += '__sepcxt' if args.sep_cxt else '__jointcxt' 32 | # exp_path += '_no' if args.no_context_feeding else '' 33 | # exp_path += '__ae_%s' % (args.action_embed_size) 34 | # exp_path += '_no' if args.no_parent_production_embed else '' 35 | # exp_path += '__fe_%s' % ('no' if args.no_parent_field_embed else args.field_embed_size) 36 | # exp_path += '__te_%s' % ('no' if args.no_parent_field_type_embed else args.type_embed_size) 37 | # training params 38 | exp_path += '__bs_%s' % (args.batch_size) 39 | exp_path += '__lr_%s' % (args.lr) if args.plm is None else '__lr_%s_ld_%s' % (args.lr, args.layerwise_decay) 40 | exp_path += '__l2_%s' % (args.l2) 41 | exp_path += '__wp_%s' % (args.warmup_ratio) 42 | exp_path += '__sd_%s' % (args.lr_schedule) 43 | exp_path += '__me_%s' % (args.max_epoch) 44 | exp_path += '__mn_%s' % (args.max_norm) 45 | exp_path += '__bm_%s' % (args.beam_size) 46 | exp_path += '__seed_%s' % (args.seed) 47 | exp_path = os.path.join(EXP_PATH, task, exp_path) 48 | return exp_path 49 | -------------------------------------------------------------------------------- /utils/initialization.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | """ Utility functions include: 3 | 1. set output logging path 4 | 2. set random seed for all libs 5 | 3. 
select torch.device 6 | """ 7 | 8 | import sys, os, logging 9 | import random, torch, dgl 10 | import numpy as np 11 | 12 | def set_logger(exp_path, testing=False): 13 | logFormatter = logging.Formatter('%(asctime)s - %(message)s') #('%(asctime)s - %(levelname)s - %(message)s') 14 | logger = logging.getLogger('mylogger') 15 | logger.setLevel(logging.DEBUG) 16 | if testing: 17 | fileHandler = logging.FileHandler('%s/log_test.txt' % (exp_path), mode='w') 18 | else: 19 | fileHandler = logging.FileHandler('%s/log_train.txt' % (exp_path), mode='w') 20 | fileHandler.setFormatter(logFormatter) 21 | logger.addHandler(fileHandler) 22 | consoleHandler = logging.StreamHandler(sys.stdout) 23 | consoleHandler.setFormatter(logFormatter) 24 | logger.addHandler(consoleHandler) 25 | return logger 26 | 27 | def set_random_seed(random_seed=999): 28 | random.seed(random_seed) 29 | torch.manual_seed(random_seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed(random_seed) 32 | np.random.seed(random_seed) 33 | dgl.random.seed(random_seed) 34 | 35 | def set_torch_device(deviceId): 36 | if deviceId < 0: 37 | device = torch.device("cpu") 38 | else: 39 | assert torch.cuda.device_count() >= deviceId + 1 40 | device = torch.device("cuda:%d" % (deviceId)) 41 | ## These two sentences are used to ensure reproducibility with cudnnbacken 42 | # torch.backends.cudnn.deterministic = True 43 | # torch.backends.cudnn.benchmark = False 44 | torch.backends.cudnn.enabled = False 45 | return device 46 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tensorboardX import SummaryWriter 3 | import json 4 | import csv 5 | import sys 6 | import datetime 7 | 8 | class Logger: 9 | def __init__(self, log_path=None, config=None, reopen_to_flush=False): 10 | self.log_file = None 11 | self.reopen_to_flush = reopen_to_flush 12 | self.log_path = log_path 13 | if log_path is not None: 14 | os.makedirs(os.path.dirname(os.path.join(log_path, 'log.txt')), exist_ok=True) 15 | self.log_file = open(os.path.join(log_path, 'log.txt'), 'a+') 16 | with open(os.path.join(log_path, 'command.sh'), 'w') as f: 17 | f.write("CUDA_VISIBLE_DEVICES=1 python "+" ".join(sys.argv) + '\n') 18 | f.write('tensorboard --logdir {} --bind_all'.format(os.path.join(log_path, 'tensorboard'))) 19 | if config is not None: 20 | with open(os.path.join(log_path, 'log_config.json'), 'w') as f: 21 | json.dump(config, f, indent=4) 22 | self.summary_writer = SummaryWriter(os.path.join(log_path, 'tensorboard')) 23 | self.csv_writers = {} 24 | self.csv_files = {} 25 | 26 | def log(self, msg): 27 | formatted = f'[{datetime.datetime.now().replace(microsecond=0).isoformat()}] {msg}' 28 | print(formatted) 29 | if self.log_file: 30 | self.log_file.write(formatted + '\n') 31 | if self.reopen_to_flush: 32 | log_path = self.log_file.name 33 | self.log_file.close() 34 | self.log_file = open(log_path, 'a+') 35 | else: 36 | self.log_file.flush() 37 | 38 | def write_metrics(self, metrics, name): 39 | metrics = self.regular_metrics(metrics) 40 | for key in metrics: 41 | self.log("{}: {}".format(key, metrics[key])) 42 | 43 | if name not in self.csv_writers: 44 | csv_file = open(os.path.join(self.log_path, 'metrics_'+name+'.csv'), 'w') 45 | csv_writer = csv.DictWriter(csv_file, fieldnames=metrics.keys()) 46 | csv_writer.writeheader() 47 | self.csv_writers[name] = csv_writer 48 | self.csv_files[name] = csv_file 49 | 50 | 
self.csv_writers[name].writerow(metrics) 51 | self.csv_files[name].flush() 52 | 53 | def regular_metrics(self, metrics): 54 | for key in metrics: 55 | if isinstance(metrics[key],float): 56 | metrics[key] = format(metrics[key], '.4f') 57 | return metrics -------------------------------------------------------------------------------- /utils/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import re, math 19 | import torch 20 | from torch.optim import Optimizer 21 | from torch.optim.lr_scheduler import LambdaLR 22 | from torch.nn.utils import clip_grad_norm_ 23 | from collections import defaultdict 24 | import torch.nn.init as init 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def glorot_init(params): 30 | for p in params: 31 | if len(p.data.size()) > 1: 32 | init.xavier_normal_(p.data) 33 | 34 | 35 | def set_optimizer(model, args, num_warmup_steps, num_training_steps, last_epoch=-1): 36 | plm = hasattr(model.input_layer, 'plm_model') 37 | if plm and args.layerwise_decay <= 0.: # fix plm params 38 | for n, p in model.named_parameters(): 39 | if 'plm_model' in n: 40 | p.requires_grad = False 41 | params = [(n, p) for n, p in model.named_parameters() if p.requires_grad] 42 | no_decay = ['bias', 'LayerNorm.weight'] 43 | # import ipdb; ipdb.set_trace() 44 | # plm_params = [ 45 | # p for n, p in model.named_parameters() if 'plm_model' in n 46 | # ] 47 | other_params = [ 48 | p for n, p in model.named_parameters() if 'plm_model' not in n 49 | ] 50 | LayerNorm_params = [n for n, p in model.named_parameters() if 'plm_model' not in n and 'norm' in n] 51 | # import ipdb; ipdb.set_trace() 52 | if plm and 0. < args.layerwise_decay <= 0.5: # seperate lr for plm 53 | grouped_params = [ 54 | {'params': list(set([p for n, p in params if 'plm_model' in n and not any(nd in n for nd in no_decay)])), 'lr': args.layerwise_decay * args.lr, 'weight_decay': args.l2}, 55 | {'params': list(set([p for n, p in params if 'plm_model' in n and any(nd in n for nd in no_decay)])), 'lr': args.layerwise_decay * args.lr, 'weight_decay': 0.0}, 56 | {'params': list(set([p for n, p in params if 'plm_model' not in n and not any(nd in n for nd in no_decay)])), 'weight_decay': args.l2}, 57 | {'params': list(set([p for n, p in params if 'plm_model' not in n and any(nd in n for nd in no_decay)])), 'weight_decay': 0.0}, 58 | ] 59 | print('Use seperate lr %f for pretrained model ...' % (args.lr * args.layerwise_decay)) 60 | elif plm and 0.5 < args.layerwise_decay < 1.: # lr decay layerwise for plm 61 | pattern = r'encoder\.layer\.(.*?)\.' 
62 | num_layers = int(model.input_layer.plm_model.config.num_hidden_layers) 63 | groups = {"decay": defaultdict(list), "no_decay": defaultdict(list)} # record grouped params 64 | for n, p in params: 65 | res = re.search(pattern, n) if 'plm_model' in n else None 66 | depth = int(res.group(1)) if res is not None else 0 if 'plm_model' in n else num_layers 67 | if any(nd in n for nd in no_decay): 68 | groups["no_decay"][int(depth)].append(p) 69 | else: 70 | groups["decay"][int(depth)].append(p) 71 | grouped_params = [] 72 | for d in groups["decay"]: 73 | lr = args.lr * (args.layerwise_decay ** (num_layers - d)) 74 | grouped_params.append({'params': list(set(groups["decay"][d])), 'lr': lr, 'weight_decay': args.l2}) 75 | for d in groups["no_decay"]: 76 | lr = args.lr * (args.layerwise_decay ** (num_layers - d)) 77 | grouped_params.append({'params': list(set(groups["no_decay"][d])), 'lr': lr, 'weight_decay': 0.0}) 78 | print('Use layerwise decay (rate %f) lr %f for pretrained model ...' % (args.layerwise_decay, args.lr)) 79 | else: # the same lr for plm and other modules 80 | grouped_params = [ 81 | {'params': list(set([p for n, p in params if not any(nd in n for nd in no_decay)])), 'weight_decay': args.l2}, 82 | {'params': list(set([p for n, p in params if any(nd in n for nd in no_decay)])), 'weight_decay': 0.0}, 83 | ] 84 | print('Use the same lr %f for all parameters ...' % (args.lr)) 85 | # glorot_init(other_params) 86 | optimizer = AdamW(grouped_params, lr=args.lr, max_grad_norm=args.max_norm) 87 | schedule_func = schedule_dict[args.lr_schedule] 88 | scheduler = schedule_func(optimizer, num_warmup_steps, num_training_steps, last_epoch=last_epoch) 89 | return optimizer, scheduler 90 | 91 | def get_ratsql_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 92 | """ Create a schedule with a learning rate that decreases according to the formular 93 | in RATSQL model 94 | """ 95 | def lr_lambda(current_step): 96 | if current_step < num_warmup_steps: 97 | return float(current_step) / float(max(1.0, num_warmup_steps)) 98 | return max(0.0, math.sqrt((num_training_steps - current_step) / float(num_training_steps - num_warmup_steps))) 99 | return LambdaLR(optimizer, lr_lambda, last_epoch) 100 | 101 | def get_constant_schedule(optimizer, *args, last_epoch=-1): 102 | """ Create a schedule with a constant learning rate. 103 | """ 104 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 105 | 106 | 107 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 108 | """ Create a schedule with a constant learning rate preceded by a warmup 109 | period during which the learning rate increases linearly between 0 and 1. 110 | """ 111 | 112 | def lr_lambda(current_step): 113 | if current_step < num_warmup_steps: 114 | return float(current_step) / float(max(1.0, num_warmup_steps)) 115 | return 1.0 116 | 117 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 118 | 119 | 120 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 121 | """ Create a schedule with a learning rate that decreases linearly after 122 | linearly increasing during a warmup period. 
123 | """ 124 | 125 | def lr_lambda(current_step): 126 | if current_step < num_warmup_steps: 127 | return float(current_step) / float(max(1, num_warmup_steps)) 128 | return max( 129 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 130 | ) 131 | 132 | return LambdaLR(optimizer, lr_lambda, last_epoch) 133 | 134 | 135 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1): 136 | """ Create a schedule with a learning rate that decreases following the 137 | values of the cosine function between 0 and `pi * cycles` after a warmup 138 | period during which it increases linearly between 0 and 1. 139 | """ 140 | 141 | def lr_lambda(current_step): 142 | if current_step < num_warmup_steps: 143 | return float(current_step) / float(max(1, num_warmup_steps)) 144 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 145 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 146 | 147 | return LambdaLR(optimizer, lr_lambda, last_epoch) 148 | 149 | 150 | def get_cosine_with_hard_restarts_schedule_with_warmup( 151 | optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1 152 | ): 153 | """ Create a schedule with a learning rate that decreases following the 154 | values of the cosine function with several hard restarts, after a warmup 155 | period during which it increases linearly between 0 and 1. 156 | """ 157 | 158 | def lr_lambda(current_step): 159 | if current_step < num_warmup_steps: 160 | return float(current_step) / float(max(1, num_warmup_steps)) 161 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 162 | if progress >= 1.0: 163 | return 0.0 164 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 165 | 166 | return LambdaLR(optimizer, lr_lambda, last_epoch) 167 | 168 | schedule_dict = { 169 | "constant": get_constant_schedule, 170 | "linear": get_linear_schedule_with_warmup, 171 | "ratsql": get_ratsql_schedule_with_warmup, 172 | "cosine": get_cosine_schedule_with_warmup, 173 | } 174 | 175 | class AdamW(Optimizer): 176 | """ Implements Adam algorithm with weight decay fix. 177 | 178 | Parameters: 179 | lr (float): learning rate. Default 1e-3. 180 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 181 | eps (float): Adams epsilon. Default: 1e-6 182 | weight_decay (float): Weight decay. Default: 0.0 183 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 
184 | """ 185 | 186 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, max_grad_norm=-1, correct_bias=True): 187 | if lr < 0.0: 188 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 189 | if not 0.0 <= betas[0] < 1.0: 190 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 191 | if not 0.0 <= betas[1] < 1.0: 192 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 193 | if not 0.0 <= eps: 194 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 195 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, max_grad_norm=max_grad_norm, correct_bias=correct_bias) 196 | super().__init__(params, defaults) 197 | 198 | def step(self, closure=None): 199 | """Performs a single optimization step. 200 | 201 | Arguments: 202 | closure (callable, optional): A closure that reevaluates the model 203 | and returns the loss. 204 | """ 205 | loss = None 206 | if closure is not None: 207 | loss = closure() 208 | 209 | for group in self.param_groups: 210 | for p in group["params"]: 211 | if p.grad is None: 212 | continue 213 | grad = p.grad.data 214 | if grad.is_sparse: 215 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 216 | 217 | state = self.state[p] 218 | 219 | # State initialization 220 | if len(state) == 0: 221 | state["step"] = 0 222 | # Exponential moving average of gradient values 223 | state["exp_avg"] = torch.zeros_like(p.data) 224 | # Exponential moving average of squared gradient values 225 | state["exp_avg_sq"] = torch.zeros_like(p.data) 226 | 227 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 228 | beta1, beta2 = group["betas"] 229 | 230 | state["step"] += 1 231 | 232 | # Add grad clipping 233 | if group['max_grad_norm'] > 0: 234 | clip_grad_norm_(p, group['max_grad_norm']) 235 | 236 | # Decay the first and second moment running average coefficient 237 | # In-place operations to update the averages at the same time 238 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) 239 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 240 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 241 | 242 | step_size = group["lr"] 243 | if group["correct_bias"]: # No bias correction for Bert 244 | bias_correction1 = 1.0 - beta1 ** state["step"] 245 | bias_correction2 = 1.0 - beta2 ** state["step"] 246 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 247 | 248 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 249 | 250 | # Just adding the square of the weights to the loss function is *not* 251 | # the correct way of using L2 regularization/weight decay with Adam, 252 | # since that will interact with the m and v parameters in strange ways. 253 | # 254 | # Instead we want to decay the weights in a manner that doesn't interact 255 | # with the m/v parameters. This is equivalent to adding the square 256 | # of the weights to the loss with plain (non-momentum) SGD. 
257 | # Add weight decay at the end (fixed version) 258 | if group["weight_decay"] > 0.0: 259 | p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) 260 | 261 | return loss 262 | -------------------------------------------------------------------------------- /utils/vocab.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from utils.constants import PAD, UNK, BOS, EOS 3 | 4 | class Vocab(): 5 | 6 | def __init__(self, padding=False, unk=False, boundary=False, min_freq=1, 7 | filepath=None, iterable=None, default=UNK, specials=[]): 8 | super(Vocab, self).__init__() 9 | self.word2id = dict() 10 | self.id2word = dict() 11 | self.default = default # if default is None, ensure that no oov words 12 | if padding: 13 | idx = len(self.word2id) 14 | self.word2id[PAD], self.id2word[idx] = idx, PAD 15 | if unk: 16 | idx = len(self.word2id) 17 | self.word2id[UNK], self.id2word[idx] = idx, UNK 18 | if boundary: 19 | idx = len(self.word2id) 20 | self.word2id[BOS], self.id2word[idx] = idx, BOS 21 | self.word2id[EOS], self.id2word[idx + 1] = idx + 1, EOS 22 | for w in specials: 23 | if w not in self.word2id: 24 | idx = len(self.word2id) 25 | self.word2id[w], self.id2word[idx] = idx, w 26 | if filepath is not None: 27 | self.from_filepath(filepath, min_freq=min_freq) 28 | elif iterable is not None: 29 | self.from_iterable(iterable) 30 | assert (self.default is None) or (self.default in self.word2id) 31 | 32 | def from_filepath(self, filepath, min_freq=1): 33 | with open(filepath, 'r', encoding='utf-8') as inf: 34 | for line in inf: 35 | line = line.strip() 36 | if line == '': continue 37 | line = line.split('\t') # ignore count or frequency 38 | if len(line) == 1: 39 | word, freq = line[0], min_freq 40 | else: 41 | assert len(line) == 2 42 | word, freq = line 43 | word = word.lower() 44 | if word not in self.word2id and int(freq) >= min_freq: 45 | idx = len(self.word2id) 46 | self.word2id[word] = idx 47 | self.id2word[idx] = word 48 | 49 | def from_iterable(self, iterable): 50 | for item in iterable: 51 | if item not in self.word2id: 52 | idx = len(self.word2id) 53 | self.word2id[item] = idx 54 | self.id2word[idx] = item 55 | 56 | def __len__(self): 57 | return len(self.word2id) 58 | 59 | @property 60 | def vocab_size(self): 61 | return len(self.word2id) 62 | 63 | def __getitem__(self, key): 64 | """ If self.default is None, it means we do not allow out of vocabulary token; 65 | If self.default is not None, we get the idx of self.default if key does not exist. 
66 | """ 67 | if self.default is None: 68 | return self.word2id[key] 69 | else: 70 | return self.word2id.get(key, self.word2id[self.default]) 71 | -------------------------------------------------------------------------------- /utils/word2vec.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | 3 | from embeddings import GloveEmbedding 4 | import numpy as np 5 | from utils.constants import PAD 6 | import torch, random 7 | 8 | class Word2vecUtils(): 9 | 10 | def __init__(self): 11 | super(Word2vecUtils, self).__init__() 12 | self.word_embed = GloveEmbedding('common_crawl_48', d_emb=300) 13 | self.initializer = lambda: np.random.normal(size=300).tolist() 14 | 15 | def load_embeddings(self, module, vocab, device='cpu'): 16 | """ Initialize the embedding with glove and char embedding 17 | """ 18 | emb_size = module.weight.data.size(-1) 19 | assert emb_size == 300, 'Embedding size is not 300, cannot be initialized by GLOVE' 20 | outliers = 0 21 | for word in vocab.word2id: 22 | if word == PAD: # PAD symbol is always 0-vector 23 | module.weight.data[vocab[PAD]] = torch.zeros(emb_size, dtype=torch.float, device=device) 24 | continue 25 | word_emb = self.word_embed.emb(word, default='none') 26 | if word_emb[0] is None: # oov 27 | word_emb = self.initializer() 28 | outliers += 1 29 | module.weight.data[vocab[word]] = torch.tensor(word_emb, dtype=torch.float, device=device) 30 | return 1 - outliers / float(len(vocab)) 31 | 32 | def emb(self, word): 33 | word_emb = self.word_embed.emb(word, default='none') 34 | if word_emb[0] is None: 35 | return None 36 | else: 37 | return word_emb 38 | --------------------------------------------------------------------------------
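Note on `get_position_ids` (utils/example.py): position ids are laid out as question tokens, then the special `*` column, then each table immediately followed by its own columns, with table (and column) order shuffled during training while the node order of the flattened input stays fixed. The toy below is only a simplified illustration with a hypothetical two-table schema and no intra-table column shuffling; it is not the repository's code path.

```
import random
from itertools import chain

# hypothetical mini-schema, purely for illustration
question_len = 4                 # e.g. [CLS] q1 q2 [SEP]
table_word_len = [2, 1]          # subword lengths of t1, t2
column_word_len = [1, 2, 1]      # subword lengths of *, c1 (in t1), c2 (in t2)
table2columns = [[1], [2]]       # table index -> its column indices; index 0 is '*'

start = question_len
table_pos, column_pos = [None] * len(table_word_len), [None] * len(column_word_len)
column_pos[0] = list(range(start, start + column_word_len[0]))   # special '*' column first
start += column_word_len[0]
table_idxs = list(range(len(table_word_len)))
random.shuffle(table_idxs)                                        # shuffle table order
for t in table_idxs:
    table_pos[t] = list(range(start, start + table_word_len[t])); start += table_word_len[t]
    for c in table2columns[t]:                                    # columns follow their table
        column_pos[c] = list(range(start, start + column_word_len[c])); start += column_word_len[c]

# concatenated back in the fixed node order, so only the positions reflect the shuffled layout
position_id = list(range(question_len)) \
    + list(chain.from_iterable(table_pos)) \
    + list(chain.from_iterable(column_pos)) + [start]             # trailing [SEP]
print(position_id)
```

Because the result is concatenated in the fixed node order, shuffling only changes which positions each table and column receive, so `len(position_id) == len(input_id)` still holds.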
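Note on `batch_rgatsql` (utils/graph_example.py): the schema-linking node labels are label-smoothed in a single line, `masked_fill_(~node_label.bool(), 2 * smoothing) - smoothing`, which sends positives to `1 - smoothing` and negatives to `smoothing`. A minimal sketch of that transform, with an illustrative smoothing value:

```
import torch

smoothing = 0.15                              # e.g. the --smoothing value in the README command
node_label = torch.tensor([1., 0., 1., 0.])   # 1 = schema item used in the gold SQL
# negatives are filled with 2*smoothing, then smoothing is subtracted everywhere,
# so 1 -> 1 - smoothing and 0 -> smoothing
smoothed = node_label.masked_fill(~node_label.bool(), 2 * smoothing) - smoothing
print(smoothed)                               # tensor([0.8500, 0.1500, 0.8500, 0.1500])
```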
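Note on `set_optimizer` (utils/optimization.py): when `0.5 < layerwise_decay < 1`, PLM parameters are grouped by encoder depth and given learning rate `lr * layerwise_decay ** (num_layers - depth)`; PLM parameters without an `encoder.layer.<k>` index (e.g. embeddings) get depth 0, and all non-PLM modules get depth `num_layers`, i.e. the full learning rate. The sketch below only prints the resulting per-depth rates, assuming a 24-layer PLM such as electra-large and illustrative hyperparameters:

```
# per-depth learning rates produced by the layerwise-decay branch of set_optimizer
base_lr, layerwise_decay, num_layers = 1e-4, 0.8, 24
for depth in range(num_layers + 1):
    # depth 0: PLM embeddings (and encoder.layer.0); depth k: encoder.layer.k;
    # depth num_layers: every non-PLM module, which keeps the full base_lr
    print(depth, base_lr * layerwise_decay ** (num_layers - depth))
```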
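Note on the custom `AdamW` (utils/optimization.py): weight decay is decoupled (applied after the Adam update) and gradient clipping is performed per parameter inside `step()` via `max_grad_norm`, so the training loop does not need a separate `clip_grad_norm_` call. A minimal usage sketch with a toy model; the real training loop lives in scripts/text2sql.py and is not reproduced here.

```
import torch
from utils.optimization import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 1)                 # stand-in for the real text-to-SQL model
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.1,
                  max_grad_norm=5.0)           # clipping happens inside optimizer.step()
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100,
                                            num_training_steps=1000)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()                               # decoupled weight decay + per-parameter clipping
scheduler.step()
optimizer.zero_grad()
```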
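Note on `Vocab` (utils/vocab.py): special symbols are registered first, then words from a file or an iterable, and lookups fall back to `default` (usually UNK) for out-of-vocabulary tokens. A small usage sketch; the exact PAD/UNK symbol strings live in utils/constants.py, which is not shown in this dump.

```
from utils.constants import UNK
from utils.vocab import Vocab

vocab = Vocab(padding=True, unk=True, iterable=['select', 'from', 'where'], default=UNK)
print(len(vocab))            # 5: PAD, UNK, 'select', 'from', 'where'
print(vocab['select'])       # 2 (PAD and UNK occupy indices 0 and 1)
print(vocab['nonexistent'])  # index of UNK, because default=UNK
```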
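Note on `Word2vecUtils` (utils/word2vec.py): it copies 300-dimensional GloVe vectors (`common_crawl_48`, via the external `embeddings` package) into an `nn.Embedding` row by row, zeroes the PAD row, randomly initializes OOV words, and returns the in-vocabulary coverage. A sketch, assuming the GloVe cache can be downloaded or is already available locally:

```
import torch
from utils.vocab import Vocab
from utils.word2vec import Word2vecUtils

vocab = Vocab(padding=True, unk=True, iterable=['city', 'population'])
embedding = torch.nn.Embedding(len(vocab), 300)      # must be 300-d to match GloVe
w2v = Word2vecUtils()                                # loads/caches common_crawl_48 GloVe
coverage = w2v.load_embeddings(embedding, vocab)     # PAD row is zeroed, OOV rows random
print('GloVe coverage: %.2f%%' % (100 * coverage))
```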