├── LICENSE
├── README.md
├── data_util
│   ├── anonymization.py
│   ├── atis_batch.py
│   ├── atis_data.py
│   ├── atis_vocab.py
│   ├── dataset_split.py
│   ├── entities.py
│   ├── interaction.py
│   ├── snippets.py
│   ├── sql_util.py
│   ├── tokenizers.py
│   ├── util.py
│   ├── utterance.py
│   └── vocabulary.py
├── eval_scripts
│   ├── evaluation.py
│   ├── evaluation_sqa.py
│   ├── metric_averages.py
│   ├── process_sql.py
│   └── process_sql.pyc
├── logger.py
├── model
│   ├── attention.py
│   ├── bert
│   │   ├── LICENSE_bert
│   │   ├── README_bert.md
│   │   ├── convert_tf_checkpoint_to_pytorch.py
│   │   ├── data
│   │   │   └── annotated_wikisql_and_PyTorch_bert_param
│   │   │       ├── bert_config_uncased_L-12_H-768_A-12.json
│   │   │       └── vocab_uncased_L-12_H-768_A-12.txt
│   │   ├── modeling.py
│   │   ├── notebooks
│   │   │   ├── Comparing TF and PT models SQuAD predictions.ipynb
│   │   │   └── Comparing TF and PT models.ipynb
│   │   └── tokenization.py
│   ├── decoder.py
│   ├── embedder.py
│   ├── encoder.py
│   ├── model.py
│   ├── schema_interaction_model.py
│   ├── token_predictor.py
│   ├── torch_utils.py
│   └── utils_bert.py
├── model_util.py
├── parse_args.py
├── postprocess_eval.py
├── preprocess.py
├── requirements.txt
├── run.py
├── run_cosql.sh
├── run_sparc.sh
├── test_cosql.sh
└── test_sparc.sh
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IGSQL
2 | 
3 | This is the implementation of the EMNLP 2020 paper [IGSQL: Database Schema Interaction Graph Based Neural Model for Context-Dependent Text-to-SQL Generation](https://www.aclweb.org/anthology/2020.emnlp-main.560.pdf).
4 | 
5 | ### Dependencies
6 | 
7 | The folder `data/` (including `data/database`) and the pretrained BERT checkpoint `model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin` can be found [here](https://github.com/ryanzhumich/editsql).
8 | 
9 | ### How to run
10 | 
11 | Use `./run_sparc.sh` or `./run_cosql.sh` to train and evaluate the best model.
12 | 
13 | Use `./test_sparc.sh` or `./test_cosql.sh` (changing the `save_file` path) to evaluate a specific saved model.
14 | 
15 | ### Acknowledgement
16 | 
17 | We adapt the code from [editsql](https://github.com/ryanzhumich/editsql). Thanks to the authors for open-sourcing their project.
18 | 
--------------------------------------------------------------------------------
/data_util/anonymization.py:
--------------------------------------------------------------------------------
1 | """Code for identifying and anonymizing entities in NL and SQL."""
2 | 
3 | import copy
4 | import json
5 | from . import util
6 | 
7 | ENTITY_NAME = "ENTITY"
8 | CONSTANT_NAME = "CONSTANT"
9 | TIME_NAME = "TIME"
10 | SEPARATOR = "#"
11 | 
12 | 
13 | def timeval(string):
14 |     """Returns the numeric version of a time.
15 | 
16 |     Inputs:
17 |         string (str): String representing a time.
18 | 
19 |     Returns:
20 |         String representing the absolute time.
21 |     """
22 |     if (string.endswith("am") or string.endswith(
23 |             "pm")) and string[:-2].isdigit():
24 |         numval = int(string[:-2])
25 |         if len(string) == 3 or len(string) == 4:
26 |             numval *= 100
27 |         if string.endswith("pm"):
28 |             numval += 1200
29 |         return str(numval)
30 |     return ""
31 | 
32 | 
33 | def is_time(string):
34 |     """Returns whether a string represents a time.
35 | 
36 |     Inputs:
37 |         string (str): String to check.
38 | 
39 |     Returns:
40 |         Whether the string represents a time.
41 |     """
42 |     if string.endswith("am") or string.endswith("pm"):
43 |         if string[:-2].isdigit():
44 |             return True
45 | 
46 |     return False
47 | 
48 | 
49 | def deanonymize(sequence, ent_dict, key):
50 |     """Deanonymizes a sequence.
51 | 
52 |     Inputs:
53 |         sequence (list of str): List of tokens to deanonymize.
54 |         ent_dict (dict str->(dict str->str)): Maps from tokens to the entity dictionary.
55 |         key (str): The key to use, in this case either natural language or SQL.
56 | 
57 |     Returns:
58 |         Deanonymized sequence of tokens.
59 |     """
60 |     new_sequence = []
61 |     for token in sequence:
62 |         if token in ent_dict:
63 |             new_sequence.extend(ent_dict[token][key])
64 |         else:
65 |             new_sequence.append(token)
66 | 
67 |     return new_sequence
68 | 
69 | 
70 | class Anonymizer:
71 |     """Anonymization class for keeping track of entities in this domain and
72 |     scripts for anonymizing/deanonymizing.
73 | 
74 |     Members:
75 |         anonymization_map (list of dict (str->str)): Containing entities from
76 |             the anonymization file.
77 |         entity_types (list of str): All entities in the anonymization file.
78 |         keys (set of str): Possible keys (types of text handled); in this case it should be
79 |             one for natural language and another for SQL.
80 |         entity_set (set of str): entity_types as a set.
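
    Example (an illustrative sketch; the file path and the "CITY_NAME"
    entity type below are assumptions, not values shipped with this repo):

        anonymizer = Anonymizer("data/anonymization.txt")
        anonymizer.is_anon_tok("CITY_NAME#0")                  # True iff CITY_NAME is a type in the file
        anonymizer.get_entity_type_from_token("CITY_NAME#0")   # returns "CITY_NAME"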
81 | """ 82 | def __init__(self, filename): 83 | self.anonymization_map = [] 84 | self.entity_types = [] 85 | self.keys = set() 86 | 87 | pairs = [json.loads(line) for line in open(filename).readlines()] 88 | for pair in pairs: 89 | for key in pair: 90 | if key != "type": 91 | self.keys.add(key) 92 | self.anonymization_map.append(pair) 93 | if pair["type"] not in self.entity_types: 94 | self.entity_types.append(pair["type"]) 95 | 96 | self.entity_types.append(ENTITY_NAME) 97 | self.entity_types.append(CONSTANT_NAME) 98 | self.entity_types.append(TIME_NAME) 99 | 100 | self.entity_set = set(self.entity_types) 101 | 102 | def get_entity_type_from_token(self, token): 103 | """Gets the type of an entity given an anonymized token. 104 | 105 | Inputs: 106 | token (str): The entity token. 107 | 108 | Returns: 109 | str, representing the type of the entity. 110 | """ 111 | # these are in the pattern NAME:#, so just strip the thing after the 112 | # colon 113 | colon_loc = token.index(SEPARATOR) 114 | entity_type = token[:colon_loc] 115 | assert entity_type in self.entity_set 116 | 117 | return entity_type 118 | 119 | def is_anon_tok(self, token): 120 | """Returns whether a token is an anonymized token or not. 121 | 122 | Input: 123 | token (str): The token to check. 124 | 125 | Returns: 126 | bool, whether the token is an anonymized token. 127 | """ 128 | return token.split(SEPARATOR)[0] in self.entity_set 129 | 130 | def get_anon_id(self, token): 131 | """Gets the entity index (unique ID) for a token. 132 | 133 | Input: 134 | token (str): The token to get the index from. 135 | 136 | Returns: 137 | int, the token ID if it is an anonymized token; otherwise -1. 138 | """ 139 | if self.is_anon_tok(token): 140 | return self.entity_types.index(token.split(SEPARATOR)[0]) 141 | else: 142 | return -1 143 | 144 | def anonymize(self, 145 | sequence, 146 | tok_to_entity_dict, 147 | key, 148 | add_new_anon_toks=False): 149 | """Anonymizes a sequence. 150 | 151 | Inputs: 152 | sequence (list of str): Sequence to anonymize. 153 | tok_to_entity_dict (dict): Existing dictionary mapping from anonymized 154 | tokens to entities. 155 | key (str): Which kind of text this is (natural language or SQL) 156 | add_new_anon_toks (bool): Whether to add new entities to tok_to_entity_dict. 157 | 158 | Returns: 159 | list of str, the anonymized sequence. 160 | """ 161 | # Sort the token-tok-entity dict by the length of the modality. 162 | sorted_dict = sorted(tok_to_entity_dict.items(), 163 | key=lambda k: len(k[1][key]))[::-1] 164 | 165 | anonymized_sequence = copy.deepcopy(sequence) 166 | 167 | if add_new_anon_toks: 168 | type_counts = {} 169 | for entity_type in self.entity_types: 170 | type_counts[entity_type] = 0 171 | for token in tok_to_entity_dict: 172 | entity_type = self.get_entity_type_from_token(token) 173 | type_counts[entity_type] += 1 174 | 175 | # First find occurrences of things in the anonymization dictionary. 176 | for token, modalities in sorted_dict: 177 | our_modality = modalities[key] 178 | 179 | # Check if this key's version of the anonymized thing is in our 180 | # sequence. 
181 | while util.subsequence(our_modality, anonymized_sequence): 182 | found = False 183 | for startidx in range( 184 | len(anonymized_sequence) - len(our_modality) + 1): 185 | if anonymized_sequence[startidx:startidx + 186 | len(our_modality)] == our_modality: 187 | anonymized_sequence = anonymized_sequence[:startidx] + [ 188 | token] + anonymized_sequence[startidx + len(our_modality):] 189 | found = True 190 | break 191 | assert found, "Thought " \ 192 | + str(our_modality) + " was in [" \ 193 | + str(anonymized_sequence) + "] but could not find it" 194 | 195 | # Now add new keys if they are present. 196 | if add_new_anon_toks: 197 | 198 | # For every span in the sequence, check whether it is in the anon map 199 | # for this modality 200 | sorted_anon_map = sorted(self.anonymization_map, 201 | key=lambda k: len(k[key]))[::-1] 202 | 203 | for pair in sorted_anon_map: 204 | our_modality = pair[key] 205 | 206 | token_type = pair["type"] 207 | new_token = token_type + SEPARATOR + \ 208 | str(type_counts[token_type]) 209 | 210 | while util.subsequence(our_modality, anonymized_sequence): 211 | found = False 212 | for startidx in range( 213 | len(anonymized_sequence) - len(our_modality) + 1): 214 | if anonymized_sequence[startidx:startidx + \ 215 | len(our_modality)] == our_modality: 216 | if new_token not in tok_to_entity_dict: 217 | type_counts[token_type] += 1 218 | tok_to_entity_dict[new_token] = pair 219 | 220 | anonymized_sequence = anonymized_sequence[:startidx] + [ 221 | new_token] + anonymized_sequence[startidx + len(our_modality):] 222 | found = True 223 | break 224 | assert found, "Thought " \ 225 | + str(our_modality) + " was in [" \ 226 | + str(anonymized_sequence) + "] but could not find it" 227 | 228 | # Also replace integers with constants 229 | for index, token in enumerate(anonymized_sequence): 230 | if token.isdigit() or is_time(token): 231 | if token.isdigit(): 232 | entity_type = CONSTANT_NAME 233 | value = new_token 234 | if is_time(token): 235 | entity_type = TIME_NAME 236 | value = timeval(token) 237 | 238 | # First try to find the constant in the entity dictionary already, 239 | # and get the name if it's found. 240 | new_token = "" 241 | new_dict = {} 242 | found = False 243 | for entity, value in tok_to_entity_dict.items(): 244 | if value[key][0] == token: 245 | new_token = entity 246 | new_dict = value 247 | found = True 248 | break 249 | 250 | if not found: 251 | new_token = entity_type + SEPARATOR + \ 252 | str(type_counts[entity_type]) 253 | new_dict = {} 254 | for tempkey in self.keys: 255 | new_dict[tempkey] = [token] 256 | 257 | tok_to_entity_dict[new_token] = new_dict 258 | type_counts[entity_type] += 1 259 | 260 | anonymized_sequence[index] = new_token 261 | 262 | return anonymized_sequence 263 | -------------------------------------------------------------------------------- /data_util/atis_batch.py: -------------------------------------------------------------------------------- 1 | # TODO: review this entire file and make it much simpler. 2 | 3 | import copy 4 | from . import snippets as snip 5 | from . import sql_util 6 | from . 
import vocabulary as vocab 7 | 8 | 9 | class UtteranceItem(): 10 | def __init__(self, interaction, index): 11 | self.interaction = interaction 12 | self.utterance_index = index 13 | 14 | def __str__(self): 15 | return str(self.interaction.utterances[self.utterance_index]) 16 | 17 | def histories(self, maximum): 18 | if maximum > 0: 19 | history_seqs = [] 20 | for utterance in self.interaction.utterances[:self.utterance_index]: 21 | history_seqs.append(utterance.input_seq_to_use) 22 | 23 | if len(history_seqs) > maximum: 24 | history_seqs = history_seqs[-maximum:] 25 | 26 | return history_seqs 27 | return [] 28 | 29 | def input_sequence(self): 30 | return self.interaction.utterances[self.utterance_index].input_seq_to_use 31 | 32 | def previous_query(self): 33 | if self.utterance_index == 0: 34 | return [] 35 | return self.interaction.utterances[self.utterance_index - 36 | 1].anonymized_gold_query 37 | 38 | def anonymized_gold_query(self): 39 | return self.interaction.utterances[self.utterance_index].anonymized_gold_query 40 | 41 | def snippets(self): 42 | return self.interaction.utterances[self.utterance_index].available_snippets 43 | 44 | def original_gold_query(self): 45 | return self.interaction.utterances[self.utterance_index].original_gold_query 46 | 47 | def contained_entities(self): 48 | return self.interaction.utterances[self.utterance_index].contained_entities 49 | 50 | def original_gold_queries(self): 51 | return [ 52 | q[0] for q in self.interaction.utterances[self.utterance_index].all_gold_queries] 53 | 54 | def gold_tables(self): 55 | return [ 56 | q[1] for q in self.interaction.utterances[self.utterance_index].all_gold_queries] 57 | 58 | def gold_query(self): 59 | return self.interaction.utterances[self.utterance_index].gold_query_to_use + [ 60 | vocab.EOS_TOK] 61 | 62 | def gold_edit_sequence(self): 63 | return self.interaction.utterances[self.utterance_index].gold_edit_sequence 64 | 65 | def gold_table(self): 66 | return self.interaction.utterances[self.utterance_index].gold_sql_results 67 | 68 | def all_snippets(self): 69 | return self.interaction.snippets 70 | 71 | def within_limits(self, 72 | max_input_length=float('inf'), 73 | max_output_length=float('inf')): 74 | return self.interaction.utterances[self.utterance_index].length_valid( 75 | max_input_length, max_output_length) 76 | 77 | def expand_snippets(self, sequence): 78 | # Remove the EOS 79 | if sequence[-1] == vocab.EOS_TOK: 80 | sequence = sequence[:-1] 81 | 82 | # First remove the snippets 83 | no_snippets_sequence = self.interaction.expand_snippets(sequence) 84 | no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence) 85 | return no_snippets_sequence 86 | 87 | def flatten_sequence(self, sequence): 88 | # Remove the EOS 89 | if sequence[-1] == vocab.EOS_TOK: 90 | sequence = sequence[:-1] 91 | 92 | # First remove the snippets 93 | no_snippets_sequence = self.interaction.expand_snippets(sequence) 94 | 95 | # Deanonymize 96 | deanon_sequence = self.interaction.deanonymize( 97 | no_snippets_sequence, "sql") 98 | return deanon_sequence 99 | 100 | 101 | class UtteranceBatch(): 102 | def __init__(self, items): 103 | self.items = items 104 | 105 | def __len__(self): 106 | return len(self.items) 107 | 108 | def start(self): 109 | self.index = 0 110 | 111 | def next(self): 112 | item = self.items[self.index] 113 | self.index += 1 114 | return item 115 | 116 | def done(self): 117 | return self.index >= len(self.items) 118 | 119 | class PredUtteranceItem(): 120 | def __init__(self, 121 | input_sequence, 122 | 
                 interaction_item,
123 |                  previous_query,
124 |                  index,
125 |                  available_snippets):
126 |         self.input_seq_to_use = input_sequence
127 |         self.interaction_item = interaction_item
128 |         self.index = index
129 |         self.available_snippets = available_snippets
130 |         self.prev_pred_query = previous_query
131 | 
132 |     def input_sequence(self):
133 |         return self.input_seq_to_use
134 | 
135 |     def histories(self, maximum):
136 |         if maximum == 0:
137 |             return []
138 |         histories = []
139 |         for utterance in self.interaction_item.processed_utterances[:self.index]:
140 |             histories.append(utterance.input_sequence())
141 |         if len(histories) > maximum:
142 |             histories = histories[-maximum:]
143 |         return histories
144 | 
145 |     def snippets(self):
146 |         return self.available_snippets
147 | 
148 |     def previous_query(self):
149 |         return self.prev_pred_query
150 | 
151 |     def flatten_sequence(self, sequence):
152 |         return self.interaction_item.flatten_sequence(sequence)
153 | 
154 |     def remove_snippets(self, sequence):
155 |         return sql_util.fix_parentheses(
156 |             self.interaction_item.expand_snippets(sequence))
157 | 
158 |     def set_predicted_query(self, query):
159 |         self.anonymized_pred_query = query
160 | 
161 | # Mocks an Interaction item, but allows for the parameters to be updated during
162 | # the process
163 | 
164 | 
165 | class InteractionItem():
166 |     def __init__(self,
167 |                  interaction,
168 |                  max_input_length=float('inf'),
169 |                  max_output_length=float('inf'),
170 |                  nl_to_sql_dict={},
171 |                  maximum_length=float('inf')):
172 |         if maximum_length != float('inf'):
173 |             self.interaction = copy.deepcopy(interaction)
174 |             self.interaction.utterances = self.interaction.utterances[:maximum_length]
175 |         else:
176 |             self.interaction = interaction
177 |         self.processed_utterances = []
178 |         self.snippet_bank = []
179 |         self.identifier = self.interaction.identifier
180 | 
181 |         self.max_input_length = max_input_length
182 |         self.max_output_length = max_output_length
183 | 
184 |         self.nl_to_sql_dict = nl_to_sql_dict
185 | 
186 |         self.index = 0
187 | 
188 |     def __len__(self):
189 |         return len(self.interaction)
190 | 
191 |     def __str__(self):
192 |         s = "Utterances, gold queries, and predictions:\n"
193 |         for i, utterance in enumerate(self.interaction.utterances):
194 |             s += " ".join(utterance.input_seq_to_use) + "\n"
195 |             pred_utterance = self.processed_utterances[i]
196 |             s += " ".join(pred_utterance.gold_query()) + "\n"
197 |             s += " ".join(pred_utterance.anonymized_query()) + "\n"
198 |         s += "\n"
199 |         s += "Snippets:\n"
200 |         for snippet in self.snippet_bank:
201 |             s += str(snippet) + "\n"
202 | 
203 |         return s
204 | 
205 |     def start_interaction(self):
206 |         assert len(self.snippet_bank) == 0
207 |         assert len(self.processed_utterances) == 0
208 |         assert self.index == 0
209 | 
210 |     def next_utterance(self):
211 |         utterance = self.interaction.utterances[self.index]
212 |         self.index += 1
213 | 
214 |         available_snippets = self.available_snippets(snippet_keep_age=1)
215 | 
216 |         return PredUtteranceItem(utterance.input_seq_to_use,
217 |                                  self,
218 |                                  self.processed_utterances[-1].anonymized_pred_query if len(self.processed_utterances) > 0 else [],
219 |                                  self.index - 1,
220 |                                  available_snippets)
221 | 
222 |     def done(self):
223 |         return len(self.processed_utterances) == len(self.interaction)
224 | 
225 |     def finish(self):
226 |         self.snippet_bank = []
227 |         self.processed_utterances = []
228 |         self.index = 0
229 | 
230 |     def utterance_within_limits(self, utterance_item):
231 |         return utterance_item.within_limits(self.max_input_length,
232 | 
self.max_output_length) 233 | 234 | def available_snippets(self, snippet_keep_age): 235 | return [ 236 | snippet for snippet in self.snippet_bank if snippet.index <= snippet_keep_age] 237 | 238 | def gold_utterances(self): 239 | utterances = [] 240 | for i, utterance in enumerate(self.interaction.utterances): 241 | utterances.append(UtteranceItem(self.interaction, i)) 242 | return utterances 243 | 244 | def get_schema(self): 245 | return self.interaction.schema 246 | 247 | def add_utterance( 248 | self, 249 | utterance, 250 | predicted_sequence, 251 | snippets=None, 252 | previous_snippets=[], 253 | simple=False): 254 | if not snippets: 255 | self.add_snippets( 256 | predicted_sequence, 257 | previous_snippets=previous_snippets, simple=simple) 258 | else: 259 | for snippet in snippets: 260 | snippet.assign_id(len(self.snippet_bank)) 261 | self.snippet_bank.append(snippet) 262 | 263 | for snippet in self.snippet_bank: 264 | snippet.increase_age() 265 | self.processed_utterances.append(utterance) 266 | 267 | def add_snippets(self, sequence, previous_snippets=[], simple=False): 268 | if sequence: 269 | if simple: 270 | snippets = sql_util.get_subtrees_simple( 271 | sequence, oldsnippets=previous_snippets) 272 | else: 273 | snippets = sql_util.get_subtrees( 274 | sequence, oldsnippets=previous_snippets) 275 | for snippet in snippets: 276 | snippet.assign_id(len(self.snippet_bank)) 277 | self.snippet_bank.append(snippet) 278 | 279 | for snippet in self.snippet_bank: 280 | snippet.increase_age() 281 | 282 | def expand_snippets(self, sequence): 283 | return sql_util.fix_parentheses( 284 | snip.expand_snippets( 285 | sequence, self.snippet_bank)) 286 | 287 | def remove_snippets(self, sequence): 288 | if sequence[-1] == vocab.EOS_TOK: 289 | sequence = sequence[:-1] 290 | 291 | no_snippets_sequence = self.expand_snippets(sequence) 292 | no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence) 293 | return no_snippets_sequence 294 | 295 | def flatten_sequence(self, sequence, gold_snippets=False): 296 | if sequence[-1] == vocab.EOS_TOK: 297 | sequence = sequence[:-1] 298 | 299 | if gold_snippets: 300 | no_snippets_sequence = self.interaction.expand_snippets(sequence) 301 | else: 302 | no_snippets_sequence = self.expand_snippets(sequence) 303 | no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence) 304 | 305 | deanon_sequence = self.interaction.deanonymize( 306 | no_snippets_sequence, "sql") 307 | return deanon_sequence 308 | 309 | def gold_query(self, index): 310 | return self.interaction.utterances[index].gold_query_to_use + [ 311 | vocab.EOS_TOK] 312 | 313 | def original_gold_query(self, index): 314 | return self.interaction.utterances[index].original_gold_query 315 | 316 | def gold_table(self, index): 317 | return self.interaction.utterances[index].gold_sql_results 318 | 319 | 320 | class InteractionBatch(): 321 | def __init__(self, items): 322 | self.items = items 323 | 324 | def __len__(self): 325 | return len(self.items) 326 | 327 | def start(self): 328 | self.timestep = 0 329 | self.current_interactions = [] 330 | 331 | def get_next_utterance_batch(self, snippet_keep_age, use_gold=False): 332 | items = [] 333 | self.current_interactions = [] 334 | for interaction in self.items: 335 | if self.timestep < len(interaction): 336 | utterance_item = interaction.original_utterances( 337 | snippet_keep_age, use_gold)[self.timestep] 338 | self.current_interactions.append(interaction) 339 | items.append(utterance_item) 340 | 341 | self.timestep += 1 342 | return 
UtteranceBatch(items) 343 | 344 | def done(self): 345 | finished = True 346 | for interaction in self.items: 347 | if self.timestep < len(interaction): 348 | finished = False 349 | return finished 350 | return finished 351 | -------------------------------------------------------------------------------- /data_util/atis_vocab.py: -------------------------------------------------------------------------------- 1 | """Gets and stores vocabulary for the ATIS data.""" 2 | 3 | from . import snippets 4 | from .vocabulary import Vocabulary, UNK_TOK, DEL_TOK, EOS_TOK 5 | 6 | INPUT_FN_TYPES = [UNK_TOK, DEL_TOK, EOS_TOK] 7 | OUTPUT_FN_TYPES = [UNK_TOK, EOS_TOK] 8 | 9 | MIN_INPUT_OCCUR = 1 10 | MIN_OUTPUT_OCCUR = 1 11 | 12 | class ATISVocabulary(): 13 | """ Stores the vocabulary for the ATIS data. 14 | 15 | Attributes: 16 | raw_vocab (Vocabulary): Vocabulary object. 17 | tokens (set of str): Set of all of the strings in the vocabulary. 18 | inorder_tokens (list of str): List of all tokens, with a strict and 19 | unchanging order. 20 | """ 21 | def __init__(self, 22 | token_sequences, 23 | filename, 24 | params, 25 | is_input='input', 26 | min_occur=1, 27 | anonymizer=None, 28 | skip=None): 29 | 30 | if is_input=='input': 31 | functional_types = INPUT_FN_TYPES 32 | elif is_input=='output': 33 | functional_types = OUTPUT_FN_TYPES 34 | elif is_input=='schema': 35 | functional_types = [UNK_TOK] 36 | else: 37 | functional_types = [] 38 | 39 | self.raw_vocab = Vocabulary( 40 | token_sequences, 41 | filename, 42 | functional_types=functional_types, 43 | min_occur=min_occur, 44 | ignore_fn=lambda x: snippets.is_snippet(x) or ( 45 | anonymizer and anonymizer.is_anon_tok(x)) or (skip and x in skip) ) 46 | self.tokens = set(self.raw_vocab.token_to_id.keys()) 47 | self.inorder_tokens = self.raw_vocab.id_to_token 48 | 49 | assert len(self.inorder_tokens) == len(self.raw_vocab) 50 | 51 | def __len__(self): 52 | return len(self.raw_vocab) 53 | 54 | def token_to_id(self, token): 55 | """ Maps from a token to a unique ID. 56 | 57 | Inputs: 58 | token (str): The token to look up. 59 | 60 | Returns: 61 | int, uniquely identifying the token. 62 | """ 63 | return self.raw_vocab.token_to_id[token] 64 | 65 | def id_to_token(self, identifier): 66 | """ Maps from a unique integer to an identifier. 67 | 68 | Inputs: 69 | identifier (int): The unique ID. 70 | 71 | Returns: 72 | string, representing the token. 73 | """ 74 | return self.raw_vocab.id_to_token[identifier] 75 | -------------------------------------------------------------------------------- /data_util/dataset_split.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for loading and processing ATIS data. 2 | """ 3 | import os 4 | import pickle 5 | 6 | class DatasetSplit: 7 | """Stores a split of the ATIS dataset. 8 | 9 | Attributes: 10 | examples (list of Interaction): Stores the examples in the split. 
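
    Example (an illustrative sketch; the paths and the load function are
    assumptions):

        split = DatasetSplit("processed/train.pkl", "raw/train.pkl", load_fn)
        interaction_lengths = split.get_ex_properties(len)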
11 | """ 12 | def __init__(self, processed_filename, raw_filename, load_function): 13 | if os.path.exists(processed_filename): 14 | print("Loading preprocessed data from " + processed_filename) 15 | with open(processed_filename, 'rb') as infile: 16 | self.examples = pickle.load(infile) 17 | else: 18 | print( 19 | "Loading raw data from " + 20 | raw_filename + 21 | " and writing to " + 22 | processed_filename) 23 | 24 | infile = open(raw_filename, 'rb') 25 | examples_from_file = pickle.load(infile) 26 | assert isinstance(examples_from_file, list), raw_filename + \ 27 | " does not contain a list of examples" 28 | infile.close() 29 | 30 | self.examples = [] 31 | for example in examples_from_file: 32 | obj, keep = load_function(example) 33 | 34 | if keep: 35 | self.examples.append(obj) 36 | 37 | 38 | print("Loaded " + str(len(self.examples)) + " examples") 39 | outfile = open(processed_filename, 'wb') 40 | pickle.dump(self.examples, outfile) 41 | outfile.close() 42 | 43 | def get_ex_properties(self, function): 44 | """ Applies some function to the examples in the dataset. 45 | 46 | Inputs: 47 | function: (lambda Interaction -> T): Function to apply to all 48 | examples. 49 | 50 | Returns 51 | list of the return value of the function 52 | """ 53 | elems = [] 54 | for example in self.examples: 55 | elems.append(function(example)) 56 | return elems 57 | -------------------------------------------------------------------------------- /data_util/entities.py: -------------------------------------------------------------------------------- 1 | """ Classes for keeping track of the entities in a natural language string. """ 2 | import json 3 | 4 | 5 | class NLtoSQLDict: 6 | """ 7 | Entity dict file should contain, on each line, a JSON dictionary with 8 | "input" and "output" keys specifying the string for the input and output 9 | pairs. The idea is that the existence of the key in an input sequence 10 | likely corresponds to the existence of the value in the output sequence. 11 | 12 | The entity_dict should map keys (input strings) to a list of values (output 13 | strings) where this property holds. This allows keys to map to multiple 14 | output strings (e.g. for times). 15 | """ 16 | def __init__(self, entity_dict_filename): 17 | self.entity_dict = {} 18 | 19 | pairs = [json.loads(line) 20 | for line in open(entity_dict_filename).readlines()] 21 | for pair in pairs: 22 | input_seq = pair["input"] 23 | output_seq = pair["output"] 24 | if input_seq not in self.entity_dict: 25 | self.entity_dict[input_seq] = [] 26 | self.entity_dict[input_seq].append(output_seq) 27 | 28 | def get_sql_entities(self, tokenized_nl_string): 29 | """ 30 | Gets the output-side entities which correspond to the input entities in 31 | the input sequence. 32 | Inputs: 33 | tokenized_input_string: list of tokens in the input string. 34 | Outputs: 35 | set of output strings. 36 | """ 37 | assert len(tokenized_nl_string) > 0 38 | flat_input_string = " ".join(tokenized_nl_string) 39 | entities = [] 40 | 41 | # See if any input strings are in our input sequence, and add the 42 | # corresponding output strings if so. 
43 | for entry, values in self.entity_dict.items(): 44 | in_middle = " " + entry + " " in flat_input_string 45 | 46 | leftspace = " " + entry 47 | at_end = leftspace in flat_input_string and flat_input_string.endswith( 48 | leftspace) 49 | 50 | rightspace = entry + " " 51 | at_beginning = rightspace in flat_input_string and flat_input_string.startswith( 52 | rightspace) 53 | if in_middle or at_end or at_beginning: 54 | for out_string in values: 55 | entities.append(out_string) 56 | 57 | # Also add any integers in the input string (these aren't in the entity) 58 | # dict. 59 | for token in tokenized_nl_string: 60 | if token.isnumeric(): 61 | entities.append(token) 62 | 63 | return entities 64 | -------------------------------------------------------------------------------- /data_util/interaction.py: -------------------------------------------------------------------------------- 1 | """ Contains the class for an interaction in ATIS. """ 2 | 3 | from . import anonymization as anon 4 | from . import sql_util 5 | from .snippets import expand_snippets 6 | from .utterance import Utterance, OUTPUT_KEY, ANON_INPUT_KEY 7 | 8 | import torch 9 | 10 | class Schema: 11 | def __init__(self, table_schema, simple=False): 12 | if simple: 13 | self.helper1(table_schema) 14 | else: 15 | self.helper2(table_schema) 16 | 17 | def helper1(self, table_schema): 18 | self.table_schema = table_schema 19 | column_names = table_schema['column_names'] 20 | column_names_original = table_schema['column_names_original'] 21 | table_names = table_schema['table_names'] 22 | table_names_original = table_schema['table_names_original'] 23 | assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original) 24 | 25 | column_keep_index = [] 26 | 27 | self.column_names_surface_form = [] 28 | self.column_names_surface_form_to_id = {} 29 | for i, (table_id, column_name) in enumerate(column_names_original): 30 | column_name_surface_form = column_name 31 | column_name_surface_form = column_name_surface_form.lower() 32 | if column_name_surface_form not in self.column_names_surface_form_to_id: 33 | self.column_names_surface_form.append(column_name_surface_form) 34 | self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1 35 | column_keep_index.append(i) 36 | 37 | column_keep_index_2 = [] 38 | for i, table_name in enumerate(table_names_original): 39 | column_name_surface_form = table_name.lower() 40 | if column_name_surface_form not in self.column_names_surface_form_to_id: 41 | self.column_names_surface_form.append(column_name_surface_form) 42 | self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1 43 | column_keep_index_2.append(i) 44 | 45 | self.column_names_embedder_input = [] 46 | self.column_names_embedder_input_to_id = {} 47 | for i, (table_id, column_name) in enumerate(column_names): 48 | column_name_embedder_input = column_name 49 | if i in column_keep_index: 50 | self.column_names_embedder_input.append(column_name_embedder_input) 51 | self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1 52 | 53 | for i, table_name in enumerate(table_names): 54 | column_name_embedder_input = table_name 55 | if i in column_keep_index_2: 56 | self.column_names_embedder_input.append(column_name_embedder_input) 57 | self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1 58 | 59 | max_id_1 = max(v for 
k,v in self.column_names_surface_form_to_id.items()) 60 | max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items()) 61 | assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1 62 | 63 | self.num_col = len(self.column_names_surface_form) 64 | 65 | def helper2(self, table_schema): 66 | self.table_schema = table_schema 67 | column_names = table_schema['column_names'] 68 | column_names_original = table_schema['column_names_original'] 69 | table_names = table_schema['table_names'] 70 | table_names_original = table_schema['table_names_original'] 71 | assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original) 72 | 73 | column_keep_index = [] 74 | 75 | self.column_names_surface_form = [] 76 | self.column_names_surface_form_to_id = {} 77 | for i, (table_id, column_name) in enumerate(column_names_original): 78 | if table_id >= 0: 79 | table_name = table_names_original[table_id] 80 | column_name_surface_form = '{}.{}'.format(table_name,column_name) 81 | else: 82 | column_name_surface_form = column_name 83 | column_name_surface_form = column_name_surface_form.lower() 84 | if column_name_surface_form not in self.column_names_surface_form_to_id: 85 | self.column_names_surface_form.append(column_name_surface_form) 86 | self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1 87 | column_keep_index.append(i) 88 | 89 | start_i = len(self.column_names_surface_form_to_id) 90 | for i, table_name in enumerate(table_names_original): 91 | column_name_surface_form = '{}.*'.format(table_name.lower()) 92 | self.column_names_surface_form.append(column_name_surface_form) 93 | self.column_names_surface_form_to_id[column_name_surface_form] = i + start_i 94 | 95 | self.column_names_embedder_input = [] 96 | self.column_names_embedder_input_to_id = {} 97 | for i, (table_id, column_name) in enumerate(column_names): 98 | if table_id >= 0: 99 | table_name = table_names[table_id] 100 | column_name_embedder_input = table_name + ' . ' + column_name 101 | else: 102 | column_name_embedder_input = column_name 103 | if i in column_keep_index: 104 | self.column_names_embedder_input.append(column_name_embedder_input) 105 | self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1 106 | 107 | start_i = len(self.column_names_embedder_input_to_id) 108 | for i, table_name in enumerate(table_names): 109 | column_name_embedder_input = table_name + ' . 
*' 110 | self.column_names_embedder_input.append(column_name_embedder_input) 111 | self.column_names_embedder_input_to_id[column_name_embedder_input] = i + start_i 112 | 113 | assert len(self.column_names_surface_form) == len(self.column_names_surface_form_to_id) == len(self.column_names_embedder_input) == len(self.column_names_embedder_input_to_id) 114 | 115 | max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items()) 116 | max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items()) 117 | assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1 118 | 119 | self.num_col = len(self.column_names_surface_form) 120 | 121 | def __len__(self): 122 | return self.num_col 123 | 124 | def in_vocabulary(self, column_name, surface_form=False): 125 | if surface_form: 126 | return column_name in self.column_names_surface_form_to_id 127 | else: 128 | return column_name in self.column_names_embedder_input_to_id 129 | 130 | def column_name_embedder_bow(self, column_name, surface_form=False, column_name_token_embedder=None): 131 | assert self.in_vocabulary(column_name, surface_form) 132 | if surface_form: 133 | column_name_id = self.column_names_surface_form_to_id[column_name] 134 | column_name_embedder_input = self.column_names_embedder_input[column_name_id] 135 | else: 136 | column_name_embedder_input = column_name 137 | 138 | column_name_embeddings = [column_name_token_embedder(token) for token in column_name_embedder_input.split()] 139 | column_name_embeddings = torch.stack(column_name_embeddings, dim=0) 140 | return torch.mean(column_name_embeddings, dim=0) 141 | 142 | def set_column_name_embeddings(self, column_name_embeddings): 143 | self.column_name_embeddings = column_name_embeddings 144 | assert len(self.column_name_embeddings) == self.num_col 145 | 146 | def column_name_embedder(self, column_name, surface_form=False): 147 | assert self.in_vocabulary(column_name, surface_form) 148 | if surface_form: 149 | column_name_id = self.column_names_surface_form_to_id[column_name] 150 | else: 151 | column_name_id = self.column_names_embedder_input_to_id[column_name] 152 | 153 | return self.column_name_embeddings[column_name_id] 154 | 155 | class Interaction: 156 | """ ATIS interaction class. 157 | 158 | Attributes: 159 | utterances (list of Utterance): The utterances in the interaction. 160 | snippets (list of Snippet): The snippets that appear through the interaction. 161 | anon_tok_to_ent: 162 | identifier (str): Unique identifier for the interaction in the dataset. 163 | """ 164 | def __init__(self, 165 | utterances, 166 | schema, 167 | snippets, 168 | anon_tok_to_ent, 169 | identifier, 170 | params): 171 | self.utterances = utterances 172 | self.schema = schema 173 | self.snippets = snippets 174 | self.anon_tok_to_ent = anon_tok_to_ent 175 | self.identifier = identifier 176 | 177 | # Ensure that each utterance's input and output sequences, when remapped 178 | # without anonymization or snippets, are the same as the original 179 | # version. 
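        # For example (token names illustrative), an anonymized input
        # ["flights", "to", "CITY_NAME#0"] must deanonymize back to the
        # original ["flights", "to", "san", "francisco"].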
180 | for i, utterance in enumerate(self.utterances): 181 | deanon_input = self.deanonymize(utterance.input_seq_to_use, 182 | ANON_INPUT_KEY) 183 | assert deanon_input == utterance.original_input_seq, "Anonymized sequence [" \ 184 | + " ".join(utterance.input_seq_to_use) + "] is not the same as [" \ 185 | + " ".join(utterance.original_input_seq) + "] when deanonymized (is [" \ 186 | + " ".join(deanon_input) + "] instead)" 187 | desnippet_gold = self.expand_snippets(utterance.gold_query_to_use) 188 | deanon_gold = self.deanonymize(desnippet_gold, OUTPUT_KEY) 189 | assert deanon_gold == utterance.original_gold_query, \ 190 | "Anonymized and/or snippet'd query " \ 191 | + " ".join(utterance.gold_query_to_use) + " is not the same as " \ 192 | + " ".join(utterance.original_gold_query) 193 | 194 | def __str__(self): 195 | string = "Utterances:\n" 196 | for utterance in self.utterances: 197 | string += str(utterance) + "\n" 198 | string += "Anonymization dictionary:\n" 199 | for ent_tok, deanon in self.anon_tok_to_ent.items(): 200 | string += ent_tok + "\t" + str(deanon) + "\n" 201 | 202 | return string 203 | 204 | def __len__(self): 205 | return len(self.utterances) 206 | 207 | def deanonymize(self, sequence, key): 208 | """ Deanonymizes a predicted query or an input utterance. 209 | 210 | Inputs: 211 | sequence (list of str): The sequence to deanonymize. 212 | key (str): The key in the anonymization table, e.g. NL or SQL. 213 | """ 214 | return anon.deanonymize(sequence, self.anon_tok_to_ent, key) 215 | 216 | def expand_snippets(self, sequence): 217 | """ Expands snippets for a sequence. 218 | 219 | Inputs: 220 | sequence (list of str): A SQL query. 221 | 222 | """ 223 | return expand_snippets(sequence, self.snippets) 224 | 225 | def input_seqs(self): 226 | in_seqs = [] 227 | for utterance in self.utterances: 228 | in_seqs.append(utterance.input_seq_to_use) 229 | return in_seqs 230 | 231 | def output_seqs(self): 232 | out_seqs = [] 233 | for utterance in self.utterances: 234 | out_seqs.append(utterance.gold_query_to_use) 235 | return out_seqs 236 | 237 | def load_function(parameters, 238 | nl_to_sql_dict, 239 | anonymizer, 240 | database_schema=None): 241 | def fn(interaction_example): 242 | keep = False 243 | 244 | raw_utterances = interaction_example["interaction"] 245 | 246 | if "database_id" in interaction_example: 247 | database_id = interaction_example["database_id"] 248 | interaction_id = interaction_example["interaction_id"] 249 | identifier = database_id + '/' + str(interaction_id) 250 | else: 251 | identifier = interaction_example["id"] 252 | 253 | schema = None 254 | if database_schema: 255 | if 'removefrom' not in parameters.data_directory: 256 | schema = Schema(database_schema[database_id], simple=True) 257 | else: 258 | schema = Schema(database_schema[database_id]) 259 | 260 | snippet_bank = [] 261 | 262 | utterance_examples = [] 263 | 264 | anon_tok_to_ent = {} 265 | 266 | for utterance in raw_utterances: 267 | available_snippets = [ 268 | snippet for snippet in snippet_bank if snippet.index <= 1] 269 | 270 | proc_utterance = Utterance(utterance, 271 | available_snippets, 272 | nl_to_sql_dict, 273 | parameters, 274 | anon_tok_to_ent, 275 | anonymizer) 276 | keep_utterance = proc_utterance.keep 277 | 278 | if schema: 279 | assert keep_utterance 280 | 281 | if keep_utterance: 282 | keep = True 283 | utterance_examples.append(proc_utterance) 284 | 285 | # Update the snippet bank, and age each snippet in it. 
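            # (Only snippets with index <= 1, i.e. those from the most recent
            # turns, are offered to the next utterance via available_snippets
            # above.)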
286 | if parameters.use_snippets: 287 | if 'atis' in parameters.data_directory: 288 | snippets = sql_util.get_subtrees( 289 | proc_utterance.anonymized_gold_query, 290 | proc_utterance.available_snippets) 291 | else: 292 | snippets = sql_util.get_subtrees_simple( 293 | proc_utterance.anonymized_gold_query, 294 | proc_utterance.available_snippets) 295 | 296 | for snippet in snippets: 297 | snippet.assign_id(len(snippet_bank)) 298 | snippet_bank.append(snippet) 299 | 300 | for snippet in snippet_bank: 301 | snippet.increase_age() 302 | 303 | interaction = Interaction(utterance_examples, 304 | schema, 305 | snippet_bank, 306 | anon_tok_to_ent, 307 | identifier, 308 | parameters) 309 | 310 | return interaction, keep 311 | 312 | return fn 313 | -------------------------------------------------------------------------------- /data_util/snippets.py: -------------------------------------------------------------------------------- 1 | """ Contains the Snippet class and methods for handling snippets. 2 | 3 | Attributes: 4 | SNIPPET_PREFIX: string prefix for snippets. 5 | """ 6 | 7 | SNIPPET_PREFIX = "SNIPPET_" 8 | 9 | 10 | def is_snippet(token): 11 | """ Determines whether a token is a snippet or not. 12 | 13 | Inputs: 14 | token (str): The token to check. 15 | 16 | Returns: 17 | bool, indicating whether it's a snippet. 18 | """ 19 | return token.startswith(SNIPPET_PREFIX) 20 | 21 | def expand_snippets(sequence, snippets): 22 | """ Given a sequence and a list of snippets, expand the snippets in the sequence. 23 | 24 | Inputs: 25 | sequence (list of str): Query containing snippet references. 26 | snippets (list of Snippet): List of available snippets. 27 | 28 | return list of str representing the expanded sequence 29 | """ 30 | snippet_id_to_snippet = {} 31 | for snippet in snippets: 32 | assert snippet.name not in snippet_id_to_snippet 33 | snippet_id_to_snippet[snippet.name] = snippet 34 | expanded_seq = [] 35 | for token in sequence: 36 | if token in snippet_id_to_snippet: 37 | expanded_seq.extend(snippet_id_to_snippet[token].sequence) 38 | else: 39 | assert not is_snippet(token) 40 | expanded_seq.append(token) 41 | 42 | return expanded_seq 43 | 44 | def snippet_index(token): 45 | """ Returns the index of a snippet. 46 | 47 | Inputs: 48 | token (str): The snippet to check. 49 | 50 | Returns: 51 | integer, the index of the snippet. 52 | """ 53 | assert is_snippet(token) 54 | return int(token.split("_")[-1]) 55 | 56 | 57 | class Snippet(): 58 | """ Contains a snippet. """ 59 | def __init__(self, 60 | sequence, 61 | startpos, 62 | sql, 63 | age=0): 64 | self.sequence = sequence 65 | self.startpos = startpos 66 | self.sql = sql 67 | 68 | # TODO: age vs. index? 69 | self.age = age 70 | self.index = 0 71 | 72 | self.name = "" 73 | self.embedding = None 74 | 75 | self.endpos = self.startpos + len(self.sequence) 76 | assert self.endpos < len(self.sql), "End position of snippet is " + str( 77 | self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")" 78 | assert self.sequence == self.sql[self.startpos:self.endpos], \ 79 | "Value of snippet (" + " ".join(self.sequence) + ") " \ 80 | "is not the same as SQL at the same positions (" \ 81 | + " ".join(self.sql[self.startpos:self.endpos]) + ")" 82 | 83 | def __str__(self): 84 | return self.name + "\t" + \ 85 | str(self.age) + "\t" + " ".join(self.sequence) 86 | 87 | def __len__(self): 88 | return len(self.sequence) 89 | 90 | def increase_age(self): 91 | """ Ages a snippet by one. 
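        (Note: this increments the snippet's index, not its age; age itself is
        set when the snippet is constructed. See the age vs. index TODO above.)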
""" 92 | self.index += 1 93 | 94 | def assign_id(self, number): 95 | """ Assigns the name of the snippet to be the prefix + the number. """ 96 | self.name = SNIPPET_PREFIX + str(number) 97 | 98 | def set_embedding(self, embedding): 99 | """ Sets the embedding of the snippet. 100 | 101 | Inputs: 102 | embedding (dy.Expression) 103 | 104 | """ 105 | self.embedding = embedding 106 | -------------------------------------------------------------------------------- /data_util/sql_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pymysql 3 | import random 4 | import signal 5 | import sqlparse 6 | from . import util 7 | 8 | from .snippets import Snippet 9 | from sqlparse import tokens as token_types 10 | from sqlparse import sql as sql_types 11 | 12 | interesting_selects = ["DISTINCT", "MAX", "MIN", "count"] 13 | ignored_subtrees = [["1", "=", "1"]] 14 | 15 | # strip_whitespace_front 16 | # Strips whitespace and punctuation from the front of a SQL token list. 17 | # 18 | # Inputs: 19 | # token_list: the token list. 20 | # 21 | # Outputs: 22 | # new token list. 23 | 24 | 25 | def strip_whitespace_front(token_list): 26 | new_token_list = [] 27 | found_valid = False 28 | 29 | for token in token_list: 30 | if not (token.is_whitespace or token.ttype == 31 | token_types.Punctuation) or found_valid: 32 | found_valid = True 33 | new_token_list.append(token) 34 | 35 | return new_token_list 36 | 37 | # strip_whitespace 38 | # Strips whitespace from a token list. 39 | # 40 | # Inputs: 41 | # token_list: the token list. 42 | # 43 | # Outputs: 44 | # new token list with no whitespace/punctuation surrounding. 45 | 46 | 47 | def strip_whitespace(token_list): 48 | subtokens = strip_whitespace_front(token_list) 49 | subtokens = strip_whitespace_front(subtokens[::-1])[::-1] 50 | return subtokens 51 | 52 | # token_list_to_seq 53 | # Converts a Token list to a sequence of strings, stripping out surrounding 54 | # punctuation and all whitespace. 55 | # 56 | # Inputs: 57 | # token_list: the list of tokens. 58 | # 59 | # Outputs: 60 | # list of strings 61 | 62 | 63 | def token_list_to_seq(token_list): 64 | subtokens = strip_whitespace(token_list) 65 | 66 | seq = [] 67 | flat = sqlparse.sql.TokenList(subtokens).flatten() 68 | for i, token in enumerate(flat): 69 | strip_token = str(token).strip() 70 | if len(strip_token) > 0: 71 | seq.append(strip_token) 72 | if len(seq) > 0: 73 | if seq[0] == "(" and seq[-1] == ")": 74 | seq = seq[1:-1] 75 | 76 | return seq 77 | 78 | # TODO: clean this up 79 | # find_subtrees 80 | # Finds subtrees for a subsequence of SQL. 81 | # 82 | # Inputs: 83 | # sequence: sequence of SQL tokens. 84 | # current_subtrees: current list of subtrees. 85 | # 86 | # Optional inputs: 87 | # where_parent: whether the parent of the current sequence was a where clause 88 | # keep_conj_subtrees: whether to look for a conjunction in this sequence and 89 | # keep its arguments 90 | 91 | 92 | def find_subtrees(sequence, 93 | current_subtrees, 94 | where_parent=False, 95 | keep_conj_subtrees=False): 96 | # If the parent of the subsequence was a WHERE clause, keep everything in the 97 | # sequence except for the beginning WHERE and any surrounding parentheses. 98 | if where_parent: 99 | # Strip out the beginning WHERE, and any punctuation or whitespace at the 100 | # beginning or end of the token list. 
101 | seq = token_list_to_seq(sequence.tokens[1:]) 102 | if len(seq) > 0 and seq not in current_subtrees: 103 | current_subtrees.append(seq) 104 | 105 | # If the current sequence has subtokens, i.e. if it's a node that can be 106 | # expanded, check for a conjunction in its subtrees, and expand its subtrees. 107 | # Also check for any SELECT statements and keep track of what follows. 108 | if sequence.is_group: 109 | if keep_conj_subtrees: 110 | subtokens = strip_whitespace(sequence.tokens) 111 | 112 | # Check if there is a conjunction in the subsequence. If so, keep the 113 | # children. Also make sure you don't split where AND is used within a 114 | # child -- the subtokens sequence won't treat those ANDs differently (a 115 | # bit hacky but it works) 116 | has_and = False 117 | for i, token in enumerate(subtokens): 118 | if token.value == "OR" or token.value == "AND": 119 | has_and = True 120 | break 121 | 122 | if has_and: 123 | and_subtrees = [] 124 | current_subtree = [] 125 | for i, token in enumerate(subtokens): 126 | if token.value == "OR" or (token.value == "AND" and i - 4 >= 0 and i - 4 < len( 127 | subtokens) and subtokens[i - 4].value != "BETWEEN"): 128 | and_subtrees.append(current_subtree) 129 | current_subtree = [] 130 | else: 131 | current_subtree.append(token) 132 | and_subtrees.append(current_subtree) 133 | 134 | for subtree in and_subtrees: 135 | seq = token_list_to_seq(subtree) 136 | if len(seq) > 0 and seq[0] == "WHERE": 137 | seq = seq[1:] 138 | if seq not in current_subtrees: 139 | current_subtrees.append(seq) 140 | 141 | in_select = False 142 | select_toks = [] 143 | for i, token in enumerate(sequence.tokens): 144 | # Mark whether this current token is a WHERE. 145 | is_where = (isinstance(token, sql_types.Where)) 146 | 147 | # If you are in a SELECT, start recording what follows until you hit a 148 | # FROM 149 | if token.value == "SELECT": 150 | in_select = True 151 | elif in_select: 152 | select_toks.append(token) 153 | if token.value == "FROM": 154 | in_select = False 155 | 156 | seq = [] 157 | if len(sequence.tokens) > i + 2: 158 | seq = token_list_to_seq( 159 | select_toks + [sequence.tokens[i + 2]]) 160 | 161 | if seq not in current_subtrees and len( 162 | seq) > 0 and seq[0] in interesting_selects: 163 | current_subtrees.append(seq) 164 | 165 | select_toks = [] 166 | 167 | # Recursively find subtrees in the children of the node. 168 | find_subtrees(token, 169 | current_subtrees, 170 | is_where, 171 | where_parent or keep_conj_subtrees) 172 | 173 | # get_subtrees 174 | 175 | 176 | def get_subtrees(sql, oldsnippets=[]): 177 | parsed = sqlparse.parse(" ".join(sql))[0] 178 | 179 | subtrees = [] 180 | find_subtrees(parsed, subtrees) 181 | 182 | final_subtrees = [] 183 | for subtree in subtrees: 184 | if subtree not in ignored_subtrees: 185 | final_version = [] 186 | keep = True 187 | 188 | parens_counts = 0 189 | for i, token in enumerate(subtree): 190 | if token == ".": 191 | newtoken = final_version[-1] + "." 
+ subtree[i + 1] 192 | final_version = final_version[:-1] + [newtoken] 193 | keep = False 194 | elif keep: 195 | final_version.append(token) 196 | else: 197 | keep = True 198 | 199 | if token == "(": 200 | parens_counts -= 1 201 | elif token == ")": 202 | parens_counts += 1 203 | 204 | if parens_counts == 0: 205 | final_subtrees.append(final_version) 206 | 207 | snippets = [] 208 | sql = [str(tok) for tok in sql] 209 | for subtree in final_subtrees: 210 | startpos = -1 211 | for i in range(len(sql) - len(subtree) + 1): 212 | if sql[i:i + len(subtree)] == subtree: 213 | startpos = i 214 | if startpos >= 0 and startpos + len(subtree) < len(sql): 215 | age = 0 216 | for prevsnippet in oldsnippets: 217 | if prevsnippet.sequence == subtree: 218 | age = prevsnippet.age + 1 219 | snippet = Snippet(subtree, startpos, sql, age=age) 220 | snippets.append(snippet) 221 | 222 | return snippets 223 | 224 | 225 | def get_subtrees_simple(sql, oldsnippets=[]): 226 | sql_string = " ".join(sql) 227 | format_sql = sqlparse.format(sql_string, reindent=True) 228 | 229 | # get subtrees 230 | subtrees = [] 231 | for sub_sql in format_sql.split('\n'): 232 | sub_sql = sub_sql.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ') 233 | 234 | subtree = sub_sql.strip().split() 235 | if len(subtree) > 1: 236 | subtrees.append(subtree) 237 | 238 | final_subtrees = subtrees 239 | 240 | snippets = [] 241 | sql = [str(tok) for tok in sql] 242 | for subtree in final_subtrees: 243 | startpos = -1 244 | for i in range(len(sql) - len(subtree) + 1): 245 | if sql[i:i + len(subtree)] == subtree: 246 | startpos = i 247 | 248 | if startpos >= 0 and startpos + len(subtree) <= len(sql): 249 | age = 0 250 | for prevsnippet in oldsnippets: 251 | if prevsnippet.sequence == subtree: 252 | age = prevsnippet.age + 1 253 | new_sql = sql + [';'] 254 | snippet = Snippet(subtree, startpos, new_sql, age=age) 255 | snippets.append(snippet) 256 | 257 | return snippets 258 | 259 | 260 | conjunctions = {"AND", "OR", "WHERE"} 261 | 262 | 263 | def get_all_in_parens(sequence): 264 | if sequence[-1] == ";": 265 | sequence = sequence[:-1] 266 | 267 | if not "(" in sequence: 268 | return [] 269 | 270 | if sequence[0] == "(" and sequence[-1] == ")": 271 | in_parens = sequence[1:-1] 272 | return [in_parens] + get_all_in_parens(in_parens) 273 | else: 274 | paren_subseqs = [] 275 | current_seq = [] 276 | num_parens = 0 277 | in_parens = False 278 | for token in sequence: 279 | if in_parens: 280 | current_seq.append(token) 281 | if token == ")": 282 | num_parens -= 1 283 | if num_parens == 0: 284 | in_parens = False 285 | paren_subseqs.append(current_seq) 286 | current_seq = [] 287 | elif token == "(": 288 | in_parens = True 289 | current_seq.append(token) 290 | if token == "(": 291 | num_parens += 1 292 | 293 | all_subseqs = [] 294 | for subseq in paren_subseqs: 295 | all_subseqs.extend(get_all_in_parens(subseq)) 296 | return all_subseqs 297 | 298 | 299 | def split_by_conj(sequence): 300 | num_parens = 0 301 | current_seq = [] 302 | subsequences = [] 303 | 304 | for token in sequence: 305 | if num_parens == 0: 306 | if token in conjunctions: 307 | subsequences.append(current_seq) 308 | current_seq = [] 309 | break 310 | current_seq.append(token) 311 | if token == "(": 312 | num_parens += 1 313 | elif token == ")": 314 | num_parens -= 1 315 | 316 | assert num_parens >= 0 317 | 318 | return subsequences 319 | 320 | 321 | def get_sql_snippets(sequence): 322 | # First, get all subsequences of the sequence that are surrounded by 323 | # parentheses. 
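# For example (hypothetical token sequence), for
#   SELECT a FROM t WHERE ( x = 1 AND ( y = 2 ) ) ;
# get_all_in_parens returns the contents of every parenthesized group,
# outermost first:
#   [["x", "=", "1", "AND", "(", "y", "=", "2", ")"], ["y", "=", "2"]]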
324 | all_in_parens = get_all_in_parens(sequence) 325 | all_subseq = [] 326 | 327 | # Then for each one, split the sequence on conjunctions (AND/OR). 328 | for seq in all_in_parens: 329 | subsequences = split_by_conj(seq) 330 | all_subseq.append(seq) 331 | all_subseq.extend(subsequences) 332 | 333 | # Finally, also get "interesting" selects 334 | 335 | for i, seq in enumerate(all_subseq): 336 | print(str(i) + "\t" + " ".join(seq)) 337 | exit() 338 | 339 | # add_snippets_to_query 340 | 341 | 342 | def add_snippets_to_query(snippets, ignored_entities, query, prob_align=1.): 343 | query_copy = copy.copy(query) 344 | 345 | # Replace the longest snippets first, so sort by length descending. 346 | sorted_snippets = sorted(snippets, key=lambda s: len(s.sequence))[::-1] 347 | 348 | for snippet in sorted_snippets: 349 | ignore = False 350 | snippet_seq = snippet.sequence 351 | 352 | # TODO: continue here 353 | # If it contains an ignored entity, then don't use it. 354 | for entity in ignored_entities: 355 | ignore = ignore or util.subsequence(entity, snippet_seq) 356 | 357 | # No NL entities found in snippet, then see if snippet is a substring of 358 | # the gold sequence 359 | if not ignore: 360 | snippet_length = len(snippet_seq) 361 | 362 | # Iterate through gold sequence to see if it's a subsequence. 363 | for start_idx in range(len(query_copy) - snippet_length + 1): 364 | if query_copy[start_idx:start_idx + 365 | snippet_length] == snippet_seq: 366 | align = random.random() < prob_align 367 | 368 | if align: 369 | prev_length = len(query_copy) 370 | 371 | # At the start position of the snippet, replace with an 372 | # identifier. 373 | query_copy[start_idx] = snippet.name 374 | 375 | # Then cut out the indices which were collapsed into 376 | # the snippet. 
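# For example (hypothetical values), with
#   query_copy = ["SELECT", "a", "FROM", "t", "WHERE", "x", "=", "1"]
# and a snippet ["x", "=", "1"] named SNIPPET_0 matched at start_idx 5,
# the two steps below produce
#   ["SELECT", "a", "FROM", "t", "WHERE", "SNIPPET_0"].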
377 | query_copy = query_copy[:start_idx + 1] + \ 378 | query_copy[start_idx + snippet_length:] 379 | 380 | # Make sure the length is as expected 381 | assert len(query_copy) == prev_length - \ 382 | (snippet_length - 1) 383 | 384 | return query_copy 385 | 386 | 387 | def execution_results(query, username, password, timeout=3): 388 | connection = pymysql.connect(user=username, password=password) 389 | 390 | class TimeoutException(Exception): 391 | pass 392 | 393 | def timeout_handler(signum, frame): 394 | raise TimeoutException 395 | 396 | signal.signal(signal.SIGALRM, timeout_handler) 397 | 398 | syntactic = True 399 | semantic = True 400 | 401 | table = [] 402 | 403 | with connection.cursor() as cursor: 404 | signal.alarm(timeout) 405 | try: 406 | cursor.execute("SET sql_mode='IGNORE_SPACE';") 407 | cursor.execute("use atis3;") 408 | cursor.execute(query) 409 | table = cursor.fetchall() 410 | cursor.close() 411 | except TimeoutException: 412 | signal.alarm(0) 413 | cursor.close() 414 | except pymysql.err.ProgrammingError: 415 | syntactic = False 416 | semantic = False 417 | cursor.close() 418 | except pymysql.err.InternalError: 419 | semantic = False 420 | cursor.close() 421 | except Exception as e: 422 | signal.alarm(0) 423 | signal.alarm(0) 424 | cursor.close() 425 | signal.alarm(0) 426 | 427 | connection.close() 428 | 429 | return (syntactic, semantic, sorted(table)) 430 | 431 | 432 | def executable(query, username, password, timeout=2): 433 | return execution_results(query, username, password, timeout)[1] 434 | 435 | 436 | def fix_parentheses(sequence): 437 | num_left = sequence.count("(") 438 | num_right = sequence.count(")") 439 | 440 | if num_right < num_left: 441 | fixed_sequence = sequence[:-1] + \ 442 | [")" for _ in range(num_left - num_right)] + [sequence[-1]] 443 | return fixed_sequence 444 | 445 | return sequence 446 | -------------------------------------------------------------------------------- /data_util/tokenizers.py: -------------------------------------------------------------------------------- 1 | """Tokenizers for natural language SQL queries, and lambda calculus.""" 2 | import nltk 3 | import sqlparse 4 | 5 | def nl_tokenize(string): 6 | """Tokenizes a natural language string into tokens. 7 | 8 | Inputs: 9 | string: the string to tokenize. 10 | Outputs: 11 | a list of tokens. 12 | 13 | Assumes data is space-separated (this is true of ZC07 data in ATIS2/3). 14 | """ 15 | return nltk.word_tokenize(string) 16 | 17 | def sql_tokenize(string): 18 | """ Tokenizes a SQL statement into tokens. 19 | 20 | Inputs: 21 | string: string to tokenize. 22 | 23 | Outputs: 24 | a list of tokens. 25 | """ 26 | tokens = [] 27 | statements = sqlparse.parse(string) 28 | 29 | # SQLparse gives you a list of statements. 30 | for statement in statements: 31 | # Flatten the tokens in each statement and add to the tokens list. 32 | flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten() 33 | for token in flat_tokens: 34 | strip_token = str(token).strip() 35 | if len(strip_token) > 0: 36 | tokens.append(strip_token) 37 | 38 | newtokens = [] 39 | keep = True 40 | for i, token in enumerate(tokens): 41 | if token == ".": 42 | newtoken = newtokens[-1] + "." + tokens[i + 1] 43 | newtokens = newtokens[:-1] + [newtoken] 44 | keep = False 45 | elif keep: 46 | newtokens.append(token) 47 | else: 48 | keep = True 49 | 50 | return newtokens 51 | 52 | def lambda_tokenize(string): 53 | """ Tokenizes a lambda-calculus statement into tokens. 
54 | 55 | Inputs: 56 | string: a lambda-calculus string 57 | 58 | Outputs: 59 | a list of tokens. 60 | """ 61 | 62 | space_separated = string.split(" ") 63 | 64 | new_tokens = [] 65 | 66 | # Separate the string by spaces, then separate based on existence of ( or 67 | # ). 68 | for token in space_separated: 69 | tokens = [] 70 | 71 | current_token = "" 72 | for char in token: 73 | if char == ")" or char == "(": 74 | tokens.append(current_token) 75 | tokens.append(char) 76 | current_token = "" 77 | else: 78 | current_token += char 79 | tokens.append(current_token) 80 | new_tokens.extend([tok for tok in tokens if tok]) 81 | 82 | return new_tokens 83 | -------------------------------------------------------------------------------- /data_util/util.py: -------------------------------------------------------------------------------- 1 | """Contains various utility functions.""" 2 | def subsequence(first_sequence, second_sequence): 3 | """ 4 | Returns whether the first sequence is a subsequence of the second sequence. 5 | 6 | Inputs: 7 | first_sequence (list): A sequence. 8 | second_sequence (list): Another sequence. 9 | 10 | Returns: 11 | Boolean indicating whether first_sequence is a subsequence of second_sequence. 12 | """ 13 | for startidx in range(len(second_sequence) - len(first_sequence) + 1): 14 | if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence: 15 | return True 16 | return False 17 | -------------------------------------------------------------------------------- /data_util/utterance.py: -------------------------------------------------------------------------------- 1 | """ Contains the Utterance class. """ 2 | 3 | from . import sql_util 4 | from . import tokenizers 5 | 6 | ANON_INPUT_KEY = "cleaned_nl" 7 | OUTPUT_KEY = "sql" 8 | 9 | class Utterance: 10 | """ Utterance class. """ 11 | def process_input_seq(self, 12 | anonymize, 13 | anonymizer, 14 | anon_tok_to_ent): 15 | assert not anon_tok_to_ent or anonymize 16 | assert not anonymize or anonymizer 17 | 18 | if anonymize: 19 | assert anonymizer 20 | 21 | self.input_seq_to_use = anonymizer.anonymize( 22 | self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True) 23 | else: 24 | self.input_seq_to_use = self.original_input_seq 25 | 26 | def process_gold_seq(self, 27 | output_sequences, 28 | nl_to_sql_dict, 29 | available_snippets, 30 | anonymize, 31 | anonymizer, 32 | anon_tok_to_ent): 33 | # Get entities in the input sequence: 34 | # anonymized entity types 35 | # other recognized entities (this includes "flight") 36 | entities_in_input = [ 37 | [tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent] 38 | entities_in_input.extend( 39 | nl_to_sql_dict.get_sql_entities( 40 | self.input_seq_to_use)) 41 | 42 | # Get the shortest gold query (this is what we use to train) 43 | shortest_gold_and_results = min(output_sequences, 44 | key=lambda x: len(x[0])) 45 | 46 | # Tokenize and anonymize it if necessary. 47 | self.original_gold_query = shortest_gold_and_results[0] 48 | self.gold_sql_results = shortest_gold_and_results[1] 49 | 50 | self.contained_entities = entities_in_input 51 | 52 | # Keep track of all gold queries and the resulting tables so that we can 53 | # give credit if it predicts a different correct sequence. 
54 | self.all_gold_queries = output_sequences 55 | 56 | self.anonymized_gold_query = self.original_gold_query 57 | if anonymize: 58 | self.anonymized_gold_query = anonymizer.anonymize( 59 | self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False) 60 | 61 | # Add snippets to it. 62 | self.gold_query_to_use = sql_util.add_snippets_to_query( 63 | available_snippets, entities_in_input, self.anonymized_gold_query) 64 | 65 | def __init__(self, 66 | example, 67 | available_snippets, 68 | nl_to_sql_dict, 69 | params, 70 | anon_tok_to_ent={}, 71 | anonymizer=None): 72 | # Get output and input sequences from the dictionary representation. 73 | output_sequences = example[OUTPUT_KEY] 74 | self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key]) 75 | self.available_snippets = available_snippets 76 | self.keep = False 77 | 78 | # pruned_output_sequences = [] 79 | # for sequence in output_sequences: 80 | # if len(sequence[0]) > 3: 81 | # pruned_output_sequences.append(sequence) 82 | 83 | # output_sequences = pruned_output_sequences 84 | if len(output_sequences) > 0 and len(self.original_input_seq) > 0: 85 | # Only keep this example if there is at least one output sequence. 86 | self.keep = True 87 | if len(output_sequences) == 0 or len(self.original_input_seq) == 0: 88 | return 89 | 90 | # Process the input sequence 91 | self.process_input_seq(params.anonymize, 92 | anonymizer, 93 | anon_tok_to_ent) 94 | 95 | # Process the gold sequence 96 | self.process_gold_seq(output_sequences, 97 | nl_to_sql_dict, 98 | self.available_snippets, 99 | params.anonymize, 100 | anonymizer, 101 | anon_tok_to_ent) 102 | 103 | def __str__(self): 104 | string = "Original input: " + " ".join(self.original_input_seq) + "\n" 105 | string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n" 106 | string += "Original output: " + " ".join(self.original_gold_query) + "\n" 107 | string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n" 108 | string += "Snippets:\n" 109 | for snippet in self.available_snippets: 110 | string += str(snippet) + "\n" 111 | return string 112 | 113 | def length_valid(self, input_limit, output_limit): 114 | return (len(self.input_seq_to_use) < input_limit \ 115 | and len(self.gold_query_to_use) < output_limit) 116 | -------------------------------------------------------------------------------- /data_util/vocabulary.py: -------------------------------------------------------------------------------- 1 | """Contains class and methods for storing and computing a vocabulary from text.""" 2 | import operator 3 | import os 4 | import pickle 5 | 6 | # Special sequencing tokens. 7 | UNK_TOK = "_UNK" # Replaces out-of-vocabulary words. 8 | EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end. 9 | DEL_TOK = ";" 10 | 11 | 12 | class Vocabulary: 13 | """Vocabulary class: stores information about words in a corpus. 14 | 15 | Members: 16 | functional_types (list of str): Functional vocabulary words, such as EOS. 17 | max_size (int): The maximum size of vocabulary to keep. 18 | min_occur (int): The minimum number of times a word should occur to keep it. 19 | id_to_token (list of str): Ordered list of word types. 20 | token_to_id (dict str->int): Maps from each unique word type to its index. 21 | """ 22 | def get_vocab(self, sequences, ignore_fn): 23 | """Gets vocabulary from a list of sequences. 24 | 25 | Inputs: 26 | sequences (list of list of str): Sequences from which to compute the vocabulary. 
27 | ignore_fn (lambda str: bool): Function used to tell whether to ignore a 28 | token during computation of the vocabulary. 29 | 30 | Returns: 31 | list of str, representing the unique word types in the vocabulary. 32 | """ 33 | type_counts = {} 34 | 35 | for sequence in sequences: 36 | for token in sequence: 37 | if not ignore_fn(token): 38 | if token not in type_counts: 39 | type_counts[token] = 0 40 | type_counts[token] += 1 41 | 42 | # Create sorted list of tokens, by their counts. Reverse so it is in order of 43 | # most frequent to least frequent. 44 | sorted_type_counts = sorted(sorted(type_counts.items()), 45 | key=operator.itemgetter(1))[::-1] 46 | 47 | sorted_types = [typecount[0] 48 | for typecount in sorted_type_counts if typecount[1] >= self.min_occur] 49 | 50 | # Append the necessary functional tokens (if any). 51 | sorted_types = (self.functional_types or []) + sorted_types 52 | 53 | # Cut off if vocab_size is set (nonnegative) 54 | if self.max_size >= 0: 55 | vocab = sorted_types[:min(self.max_size, len(sorted_types))] 56 | else: 57 | vocab = sorted_types 58 | 59 | return vocab 60 | 61 | def __init__(self, 62 | sequences, 63 | filename, 64 | functional_types=None, 65 | max_size=-1, 66 | min_occur=0, 67 | ignore_fn=lambda x: False): 68 | self.functional_types = functional_types 69 | self.max_size = max_size 70 | self.min_occur = min_occur 71 | 72 | vocab = self.get_vocab(sequences, ignore_fn) 73 | 74 | self.id_to_token = [] 75 | self.token_to_id = {} 76 | 77 | for i, word_type in enumerate(vocab): 78 | self.id_to_token.append(word_type) 79 | self.token_to_id[word_type] = i 80 | 81 | # Load the previous vocab, if it exists. 82 | if os.path.exists(filename): 83 | infile = open(filename, 'rb') 84 | loaded_vocab = pickle.load(infile) 85 | infile.close() 86 | 87 | print("Loaded vocabulary from " + str(filename)) 88 | if loaded_vocab.id_to_token != self.id_to_token \ 89 | or loaded_vocab.token_to_id != self.token_to_id: 90 | print("Loaded vocabulary is different than generated vocabulary.") 91 | else: 92 | print("Writing vocabulary to " + str(filename)) 93 | outfile = open(filename, 'wb') 94 | pickle.dump(self, outfile) 95 | outfile.close() 96 | 97 | def __len__(self): 98 | return len(self.id_to_token) 99 | -------------------------------------------------------------------------------- /eval_scripts/metric_averages.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | predictions = [json.loads(line) for line in open(sys.argv[1]).readlines() if line] 5 | 6 | string_count = 0. 7 | sem_count = 0. 8 | syn_count = 0. 9 | table_count = 0. 10 | strict_table_count = 0. 11 | 12 | precision_denom = 0. 13 | precision = 0. 14 | recall_denom = 0. 15 | recall = 0. 16 | f1_score = 0. 17 | f1_denom = 0. 18 | 19 | time = 0. 20 | 21 | for prediction in predictions: 22 | if prediction["correct_string"]: 23 | string_count += 1. 24 | if prediction["semantic"]: 25 | sem_count += 1. 26 | if prediction["syntactic"]: 27 | syn_count += 1. 28 | if prediction["correct_table"]: 29 | table_count += 1. 30 | if prediction["strict_correct_table"]: 31 | strict_table_count += 1. 
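# Note: the row-level averages below use different denominators: precision
# and F1 are averaged only over predictions whose gold table is non-empty
# ("[[]]" encodes an empty gold table), while recall is averaged only over
# non-empty predicted tables.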
32 | if prediction["gold_tables"] !="[[]]": 33 | precision += prediction["table_prec"] 34 | precision_denom += 1 35 | if prediction["pred_table"] != "[]": 36 | recall += prediction["table_rec"] 37 | recall_denom += 1 38 | 39 | if prediction["gold_tables"] != "[[]]": 40 | f1_score += prediction["table_f1"] 41 | f1_denom += 1 42 | 43 | num_p = len(predictions) 44 | print("string precision: " + str(string_count / num_p)) 45 | print("% semantic: " + str(sem_count / num_p)) 46 | print("% syntactic: " + str(syn_count / num_p)) 47 | print("table prec: " + str(table_count / num_p)) 48 | print("strict table prec: " + str(strict_table_count / num_p)) 49 | print("table row prec: " + str(precision / precision_denom)) 50 | print("table row recall: " + str(recall / recall_denom)) 51 | print("table row f1: " + str(f1_score / f1_denom)) 52 | print("inference time: " + str(time / num_p)) 53 | 54 | -------------------------------------------------------------------------------- /eval_scripts/process_sql.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/headacheboy/IGSQL/5cdbdbf5e530be14ff6441d2a4a514f3546d76ba/eval_scripts/process_sql.pyc -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | """Contains the logging class.""" 2 | 3 | class Logger(): 4 | """Attributes: 5 | 6 | fileptr (file): File pointer for input/output. 7 | lines (list of str): The lines read from the log. 8 | """ 9 | def __init__(self, filename, option): 10 | self.fileptr = open(filename, option) 11 | if option == "r": 12 | self.lines = self.fileptr.readlines() 13 | else: 14 | self.lines = [] 15 | 16 | def put(self, string): 17 | """Writes to the file.""" 18 | self.fileptr.write(string + "\n") 19 | self.fileptr.flush() 20 | 21 | def close(self): 22 | """Closes the logger.""" 23 | self.fileptr.close() 24 | 25 | def findlast(self, identifier, default=0.): 26 | """Finds the last line in the log with a certain value.""" 27 | for line in self.lines[::-1]: 28 | if line.lower().startswith(identifier): 29 | string = line.strip().split("\t")[1] 30 | if string.replace(".", "").isdigit(): 31 | return float(string) 32 | elif string.lower() == "true": 33 | return True 34 | elif string.lower() == "false": 35 | return False 36 | else: 37 | return string 38 | return default 39 | 40 | def contains(self, string): 41 | """Dtermines whether the string is present in the log.""" 42 | for line in self.lines[::-1]: 43 | if string.lower() in line.lower(): 44 | return True 45 | return False 46 | 47 | def findlast_log_before(self, before_str): 48 | """Finds the last entry in the log before another entry.""" 49 | loglines = [] 50 | in_line = False 51 | for line in self.lines[::-1]: 52 | if line.startswith(before_str): 53 | in_line = True 54 | elif in_line: 55 | loglines.append(line) 56 | if line.strip() == "" and in_line: 57 | return "".join(loglines[::-1]) 58 | return "".join(loglines[::-1]) 59 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | """Contains classes for computing and keeping track of attention distributions. 2 | """ 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from . 
import torch_utils 8 | 9 | class AttentionResult(namedtuple('AttentionResult', 10 | ('scores', 11 | 'distribution', 12 | 'vector'))): 13 | """Stores the result of an attention calculation.""" 14 | __slots__ = () 15 | 16 | 17 | class Attention(torch.nn.Module): 18 | """Attention mechanism class. Stores parameters for and computes attention. 19 | 20 | Attributes: 21 | transform_query (bool): Whether or not to transform the query being 22 | passed in with a weight transformation before computing attention. 23 | transform_key (bool): Whether or not to transform the key being 24 | passed in with a weight transformation before computing attention. 25 | transform_value (bool): Whether or not to transform the value being 26 | passed in with a weight transformation before computing attention. 27 | key_size (int): The size of the key vectors. 28 | value_size (int): The size of the value vectors. 29 | 30 | query_weights (torch.nn.Parameter): Weights for transforming the query. 31 | key_weights (torch.nn.Parameter): Weights for transforming the key. 32 | value_weights (torch.nn.Parameter): Weights for transforming the value. 33 | """ 34 | def __init__(self, query_size, key_size, value_size): 35 | super().__init__() 36 | self.key_size = key_size 37 | self.value_size = value_size 38 | 39 | self.query_weights = torch_utils.add_params((query_size, self.key_size), "weights-attention-q") 40 | 41 | def transform_arguments(self, query, keys, values): 42 | """ Transforms the query/key/value inputs before attention calculations. 43 | 44 | Arguments: 45 | query (torch.Tensor): Vector representing the query (e.g., hidden state). 46 | keys (list of torch.Tensor): List of vectors representing the key 47 | values. 48 | values (list of torch.Tensor): List of vectors representing the values. 49 | 50 | Returns: 51 | triple of torch.Tensor, where the first represents the (transformed) 52 | query, the second represents the (transformed and concatenated) 53 | keys, and the third represents the (transformed and concatenated) 54 | values. 55 | """ 56 | assert len(keys) == len(values) 57 | 58 | all_keys = torch.stack(keys, dim=1) 59 | all_values = torch.stack(values, dim=1) 60 | 61 | assert all_keys.size()[0] == self.key_size, "Expected key size of " + str(self.key_size) + " but got " + str(all_keys.size()[0]) 62 | assert all_values.size()[0] == self.value_size 63 | 64 | query = torch_utils.linear_layer(query, self.query_weights) 65 | 66 | return query, all_keys, all_values 67 | 68 | def forward(self, query, keys, values=None): 69 | if not values: 70 | values = keys 71 | 72 | query_t, keys_t, values_t = self.transform_arguments(query, keys, values) 73 | 74 | scores = torch.t(torch.mm(query_t, keys_t)) # len(key) x len(query) 75 | 76 | distribution = F.softmax(scores, dim=0) # len(key) x len(query) 77 | 78 | context_vector = torch.mm(values_t, distribution).squeeze() # value_size x len(query) 79 | 80 | return AttentionResult(scores, distribution, context_vector) 81 | -------------------------------------------------------------------------------- /model/bert/LICENSE_bert: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /model/bert/README_bert.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of Google AI's BERT model with a script to load Google's pre-trained models 2 | 3 | ## Forked for wikisql application 4 | 5 | ## NSML 6 | 7 | ### SQuAD1.1 finetuning 8 | 9 | ``` 10 | nsml run -d squad_bert -g 4 -e run_squad.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu" 11 | 12 | ``` 13 | 14 | ### SQuAD2.0 finetuning 15 | 16 | ``` 17 | nsml run -d squad_bert -g 4 -e run_squad2.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu" 18 | ``` 19 | 20 | ### Evaluation 21 | 22 | 1. Download prediction file from NSML session 23 | 24 | ``` 25 | nsml download -f /app/squad_base [NSML_ID]/squad_bert/[SESSION] . 26 | ``` 27 | 28 | 2. Run official evaluation file 29 | 30 | ``` 31 | python3 evaluate-v1.1.py [dev.json] [predictions.json] 32 | 33 | python3 evaluate-v2.0.py [dev.json] [predictions.json] -n [na_probs.json] 34 | ``` 35 | 36 | ## Introduction 37 | 38 | This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 39 | 40 | This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below). 41 | 42 | The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated). 43 | 44 | ## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models)) 45 | 46 | You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script. 
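Once the conversion has run, loading the dumped weights back into a PyTorch model is a plain `torch.load()` call. A minimal sketch (assuming you run it from this directory so that `modeling` is importable, and that the config/weight paths below point at your converted files):

```python
import torch

from modeling import BertConfig, BertModel

# Rebuild the architecture from the saved config, then load the converted weights.
config = BertConfig.from_json_file("bert_config.json")
model = BertModel(config)
model.load_state_dict(torch.load("pytorch_model.bin"))
model.eval()  # inference mode
```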
47 | 48 | This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`). 49 | 50 | You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too. 51 | 52 | To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. 53 | 54 | Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: 55 | 56 | ```shell 57 | export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 58 | 59 | python convert_tf_checkpoint_to_pytorch.py \ 60 | --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ 61 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 62 | --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin 63 | ``` 64 | 65 | You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models). 66 | 67 | ## PyTorch models for BERT 68 | 69 | We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py): 70 | 71 | - `BertModel` - the basic BERT Transformer model 72 | - `BertForSequenceClassification` - the BERT model with a sequence classification head on top 73 | - `BertForQuestionAnswering` - the BERT model with a token classification head on top 74 | 75 | Here are some details on each class. 76 | 77 | ### 1. `BertModel` 78 | 79 | `BertModel` is the basic BERT Transformer model with a layer of summed token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large). 80 | 81 | The inputs and output are **identical to the TensorFlow model inputs and outputs**. 82 | 83 | We detail them here. This model takes as inputs: 84 | 85 | - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and 86 | - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 87 | - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. 88 | 89 | This model outputs a tuple composed of: 90 | 91 | - `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden-states at the end of each attention block (i.e. 
12 full sequences for BERT-base, 24 for BERT-large), and 92 | - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated with the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper). 93 | 94 | An example of how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input. 95 | 96 | ### 2. `BertForSequenceClassification` 97 | 98 | `BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`. 99 | 100 | The sequence-level classifier is a linear layer that takes as input the last hidden state of the first character in the input sequence (see Figures 3a and 3b in the BERT paper). 101 | 102 | An example of how to use this class is given in the `run_classifier.py` script which can be used to fine-tune a single sequence (or pair of sequences) classifier using BERT, for example for the MRPC task. 103 | 104 | ### 3. `BertForQuestionAnswering` 105 | 106 | `BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifier on top of the full sequence of last hidden states. 107 | 108 | The token-level classifier takes as input the full sequence of the last hidden state and computes several (e.g. two) scores for each token that can, for example, respectively be the score that a given token is a `start_span` or an `end_span` token (see Figures 3c and 3d in the BERT paper). 109 | 110 | An example of how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task. 111 | 112 | ## Installation, requirements, test 113 | 114 | This code was tested on Python 3.5+. The requirements are: 115 | 116 | - PyTorch (>= 0.4.1) 117 | - tqdm 118 | 119 | To install the dependencies: 120 | 121 | ````bash 122 | pip install -r ./requirements.txt 123 | ```` 124 | 125 | A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). 126 | 127 | You can run the tests with the command: 128 | ```bash 129 | python -m pytest -sv tests/ 130 | ``` 131 | 132 | ## Training on large batches: gradient accumulation, multi-GPU and distributed training 133 | 134 | BERT-base and BERT-large are respectively 110M- and 340M-parameter models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32). 135 | 136 | To help with fine-tuning these models, we have included four techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: optimize on CPU, gradient-accumulation, multi-gpu and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month. 
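To make the gradient-accumulation technique concrete before the usage notes below, here is a minimal, self-contained sketch (not taken from `run_squad.py`; the model, data, and `accumulation_steps` value are placeholders):

```python
import torch

def train_epoch(model, dataloader, optimizer, accumulation_steps=2):
    """Emulate a batch that is accumulation_steps times larger by summing
    gradients over several mini-batches before each optimizer step."""
    model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        # Scale the loss so the accumulated gradient is an average, not a sum.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```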
137 | 138 | Here is how to use these techniques in our scripts: 139 | 140 | - **Optimize on CPU**: The Adam optimizer comprises 2 moving averages of all the weights of the model, which means that if you keep them on GPU 1 (typical behavior), your first GPU will have to store 3 times the size of the model. This is not optimal when using a large model like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU to free more room on the GPU(s). As the most computationally intensive operation is the backward pass, this usually doesn't increase the computation time by a lot. This is the only way to fine-tune `BERT-large` in a reasonable time on GPU(s) (see below). Activate this option with `--optimize_on_cpu` on the `run_squad.py` script. 141 | - **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps. 142 | - **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are split across the GPUs. 143 | - **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to 0 to the `--local_rank` argument. To use distributed training, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see the above blog post for more details): 144 | 145 | ```bash 146 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script) 147 | ``` 148 | 149 | Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`. 150 | 151 | ## TPU support and pretraining scripts 152 | 153 | TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPUs and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)). 154 | 155 | We will add TPU support when this next release is published. 156 | 157 | The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py). 158 | 159 | Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to complete in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch to convert these pre-training scripts. 160 | 161 | ## Comparing the PyTorch model and the TensorFlow model predictions 162 | 163 | We also include [two Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model. 
164 | 165 | - The first NoteBook ([Comparing TF and PT models.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models.ipynb)) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models. 166 | 167 | - The second NoteBook ([Comparing TF and PT models SQuAD predictions.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models%20SQuAD%20predictions.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models. 168 | 169 | Please follow the instructions given in the notebooks to run and modify them. They are also nice examples of how to use the models in a simpler way than the full fine-tuning scripts we provide. 170 | 171 | ## Fine-tuning with BERT: running the examples 172 | 173 | We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD. 174 | 175 | Before running these examples you should download the 176 | [GLUE data](https://gluebenchmark.com/tasks) by running 177 | [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) 178 | and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base` 179 | checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section. 180 | 181 | This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase 182 | Corpus (MRPC) and runs in less than 10 minutes on a single K-80. 183 | 184 | ```shell 185 | export GLUE_DIR=/path/to/glue 186 | 187 | python run_classifier.py \ 188 | --task_name MRPC \ 189 | --do_train \ 190 | --do_eval \ 191 | --do_lower_case \ 192 | --data_dir $GLUE_DIR/MRPC/ \ 193 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 194 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 195 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 196 | --max_seq_length 128 \ 197 | --train_batch_size 32 \ 198 | --learning_rate 2e-5 \ 199 | --num_train_epochs 3.0 \ 200 | --output_dir /tmp/mrpc_output/ 201 | ``` 202 | 203 | Our test, run on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks), gave evaluation results between 84% and 88%. 204 | 205 | The second example fine-tunes `BERT-Base` on the SQuAD question answering task. 206 | 207 | The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory. 
208 | 209 | * [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json) 210 | * [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json) 211 | * [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py) 212 | 213 | ```shell 214 | export SQUAD_DIR=/path/to/SQUAD 215 | 216 | python run_squad.py \ 217 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 218 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 219 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 220 | --do_train \ 221 | --do_predict \ 222 | --do_lower_case \ 223 | --train_file $SQUAD_DIR/train-v1.1.json \ 224 | --predict_file $SQUAD_DIR/dev-v1.1.json \ 225 | --train_batch_size 12 \ 226 | --learning_rate 3e-5 \ 227 | --num_train_epochs 2.0 \ 228 | --max_seq_length 384 \ 229 | --doc_stride 128 \ 230 | --output_dir ../debug_squad/ 231 | ``` 232 | 233 | Training with the previous hyper-parameters gave us the following results: 234 | ```bash 235 | {"f1": 88.52381567990474, "exact_match": 81.22043519394512} 236 | ``` 237 | 238 | # Fine-tuning BERT-large on GPUs 239 | 240 | The options we list above allow you to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation. 241 | 242 | For example, fine-tuning BERT-large on SQuAD can be done on a server with 4 k-80 (these are pretty old now) in 18 hours. Our results are similar to the TensorFlow implementation results (actually slightly higher): 243 | ```bash 244 | {"exact_match": 84.56953642384106, "f1": 91.04028647786927} 245 | ``` 246 | To get these results we used a combination of: 247 | - multi-GPU training (automatically activated on a multi-GPU server), 248 | - 2 steps of gradient accumulation, and 249 | - performing the optimization step on CPU to store Adam's averages in RAM. 250 | 251 | Here is the full list of hyper-parameters we used for this run: 252 | ```bash 253 | python ./run_squad.py --vocab_file $BERT_LARGE_DIR/vocab.txt --bert_config_file $BERT_LARGE_DIR/bert_config.json --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin --do_lower_case --do_train --do_predict --train_file $SQUAD_TRAIN --predict_file $SQUAD_EVAL --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir $OUTPUT_DIR/bert_large_bsz_24 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu 254 | ``` 255 | 256 | -------------------------------------------------------------------------------- /model/bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. \n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if name[0] in ['redictions', 'eq_relationship']: 75 | print("Skipping") 76 | continue 77 | pointer = model 78 | for m_name in name: 79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 80 | l = re.split(r'_(\d+)', m_name) 81 | else: 82 | l = [m_name] 83 | if l[0] == 'kernel': 84 | pointer = getattr(pointer, 'weight') 85 | else: 86 | pointer = getattr(pointer, l[0]) 87 | if len(l) >= 2: 88 | num = int(l[1]) 89 | pointer = pointer[num] 90 | if m_name[-11:] == '_embeddings': 91 | pointer = getattr(pointer, 'weight') 92 | elif m_name == 'kernel': 93 | array = np.transpose(array) 94 | try: 95 | assert pointer.shape == array.shape 96 | except AssertionError as e: 97 | e.args += (pointer.shape, array.shape) 98 | raise 99 | pointer.data = torch.from_numpy(array) 100 | 101 | # Save pytorch-model 102 | torch.save(model.state_dict(), args.pytorch_dump_path) 103 | 104 | if __name__ == "__main__": 105 | convert() 106 | -------------------------------------------------------------------------------- /model/bert/data/annotated_wikisql_and_PyTorch_bert_param/bert_config_uncased_L-12_H-768_A-12.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /model/bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google 
AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import collections
22 | import unicodedata
23 | import six
24 | 
25 | 
26 | def convert_to_unicode(text):
27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
28 | if six.PY3:
29 | if isinstance(text, str):
30 | return text
31 | elif isinstance(text, bytes):
32 | return text.decode("utf-8", "ignore")
33 | else:
34 | raise ValueError("Unsupported string type: %s" % (type(text)))
35 | elif six.PY2:
36 | if isinstance(text, str):
37 | return text.decode("utf-8", "ignore")
38 | elif isinstance(text, unicode):
39 | return text
40 | else:
41 | raise ValueError("Unsupported string type: %s" % (type(text)))
42 | else:
43 | raise ValueError("Not running on Python 2 or Python 3?")
44 | 
45 | 
46 | def printable_text(text):
47 | """Returns text encoded in a way suitable for print or `tf.logging`."""
48 | 
49 | # These functions want `str` for both Python 2 and Python 3, but in one case
50 | # it's a Unicode string and in the other it's a byte string.
51 | if six.PY3:
52 | if isinstance(text, str):
53 | return text
54 | elif isinstance(text, bytes):
55 | return text.decode("utf-8", "ignore")
56 | else:
57 | raise ValueError("Unsupported string type: %s" % (type(text)))
58 | elif six.PY2:
59 | if isinstance(text, str):
60 | return text
61 | elif isinstance(text, unicode):
62 | return text.encode("utf-8")
63 | else:
64 | raise ValueError("Unsupported string type: %s" % (type(text)))
65 | else:
66 | raise ValueError("Not running on Python 2 or Python 3?")
67 | 
68 | 
69 | def load_vocab(vocab_file):
70 | """Loads a vocabulary file into a dictionary."""
71 | vocab = collections.OrderedDict()
72 | index = 0
73 | with open(vocab_file, "r", encoding="utf-8") as reader:
74 | while True:
75 | token = convert_to_unicode(reader.readline())
76 | if not token:
77 | break
78 | token = token.strip()
79 | vocab[token] = index
80 | index += 1
81 | return vocab
82 | 
83 | 
84 | def convert_tokens_to_ids(vocab, tokens):
85 | """Converts a sequence of tokens into ids using the vocab."""
86 | ids = []
87 | for token in tokens:
88 | ids.append(vocab[token])
89 | return ids
90 | 
91 | 
92 | def whitespace_tokenize(text):
93 | """Runs basic whitespace cleaning and splitting on a piece of text."""
94 | text = text.strip()
95 | if not text:
96 | return []
97 | tokens = text.split()
98 | return tokens
99 | 
100 | 
101 | class FullTokenizer(object):
102 | """Runs end-to-end tokenization."""
103 | 
104 | def __init__(self, vocab_file, do_lower_case=True):
105 | self.vocab = load_vocab(vocab_file)
106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
108 | 
109 | def tokenize(self, text):
110 | split_tokens = []
111 | for token in self.basic_tokenizer.tokenize(text):
112 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
113 | split_tokens.append(sub_token)
114 | 
115 | return split_tokens
116 | 
117 | def convert_tokens_to_ids(self, tokens):
118 | return convert_tokens_to_ids(self.vocab, tokens)
119 | 
120 | 
121 | class BasicTokenizer(object):
122 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
123 | 
124 | def __init__(self, do_lower_case=True):
125 | """Constructs a BasicTokenizer.
126 | 
127 | Args:
128 | do_lower_case: Whether to lower case the input.
129 | """
130 | self.do_lower_case = do_lower_case
131 | 
132 | def tokenize(self, text):
133 | """Tokenizes a piece of text."""
134 | text = convert_to_unicode(text)
135 | text = self._clean_text(text)
136 | # This was added on November 1st, 2018 for the multilingual and Chinese
137 | # models. This is also applied to the English models now, but it doesn't
138 | # matter since the English models were not trained on any Chinese data
139 | # and generally don't have any Chinese data in them (there are Chinese
140 | # characters in the vocabulary because Wikipedia does have some Chinese
141 | # words in the English Wikipedia.).
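# The call below surrounds every CJK codepoint with spaces before the
# whitespace split; see _tokenize_chinese_chars further down.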
142 | text = self._tokenize_chinese_chars(text) 143 | orig_tokens = whitespace_tokenize(text) 144 | split_tokens = [] 145 | for token in orig_tokens: 146 | if self.do_lower_case: 147 | token = token.lower() 148 | token = self._run_strip_accents(token) 149 | split_tokens.extend(self._run_split_on_punc(token)) 150 | 151 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 152 | return output_tokens 153 | 154 | def _run_strip_accents(self, text): 155 | """Strips accents from a piece of text.""" 156 | text = unicodedata.normalize("NFD", text) 157 | output = [] 158 | for char in text: 159 | cat = unicodedata.category(char) 160 | if cat == "Mn": 161 | continue 162 | output.append(char) 163 | return "".join(output) 164 | 165 | def _run_split_on_punc(self, text): 166 | """Splits punctuation on a piece of text.""" 167 | chars = list(text) 168 | i = 0 169 | start_new_word = True 170 | output = [] 171 | while i < len(chars): 172 | char = chars[i] 173 | if _is_punctuation(char): 174 | output.append([char]) 175 | start_new_word = True 176 | else: 177 | if start_new_word: 178 | output.append([]) 179 | start_new_word = False 180 | output[-1].append(char) 181 | i += 1 182 | 183 | return ["".join(x) for x in output] 184 | 185 | def _tokenize_chinese_chars(self, text): 186 | """Adds whitespace around any CJK character.""" 187 | output = [] 188 | for char in text: 189 | cp = ord(char) 190 | if self._is_chinese_char(cp): 191 | output.append(" ") 192 | output.append(char) 193 | output.append(" ") 194 | else: 195 | output.append(char) 196 | return "".join(output) 197 | 198 | def _is_chinese_char(self, cp): 199 | """Checks whether CP is the codepoint of a CJK character.""" 200 | # This defines a "chinese character" as anything in the CJK Unicode block: 201 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 202 | # 203 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 204 | # despite its name. The modern Korean Hangul alphabet is a different block, 205 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 206 | # space-separated words, so they are not treated specially and handled 207 | # like the all of the other languages. 208 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 209 | (cp >= 0x3400 and cp <= 0x4DBF) or # 210 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 211 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 212 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 213 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 214 | (cp >= 0xF900 and cp <= 0xFAFF) or # 215 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 216 | return True 217 | 218 | return False 219 | 220 | def _clean_text(self, text): 221 | """Performs invalid character removal and whitespace cleanup on text.""" 222 | output = [] 223 | for char in text: 224 | cp = ord(char) 225 | if cp == 0 or cp == 0xfffd or _is_control(char): 226 | continue 227 | if _is_whitespace(char): 228 | output.append(" ") 229 | else: 230 | output.append(char) 231 | return "".join(output) 232 | 233 | 234 | class WordpieceTokenizer(object): 235 | """Runs WordPiece tokenization.""" 236 | 237 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 238 | self.vocab = vocab 239 | self.unk_token = unk_token 240 | self.max_input_chars_per_word = max_input_chars_per_word 241 | 242 | def tokenize(self, text): 243 | """Tokenizes a piece of text into its word pieces. 244 | 245 | This uses a greedy longest-match-first algorithm to perform tokenization 246 | using the given vocabulary. 
247 | 
248 | For example:
249 | input = "unaffable"
250 | output = ["un", "##aff", "##able"]
251 | 
252 | Args:
253 | text: A single token or whitespace separated tokens. This should have
254 | already been passed through `BasicTokenizer`.
255 | 
256 | Returns:
257 | A list of wordpiece tokens.
258 | """
259 | 
260 | text = convert_to_unicode(text)
261 | 
262 | output_tokens = []
263 | for token in whitespace_tokenize(text):
264 | chars = list(token)
265 | if len(chars) > self.max_input_chars_per_word:
266 | output_tokens.append(self.unk_token)
267 | continue
268 | 
269 | is_bad = False
270 | start = 0
271 | sub_tokens = []
272 | while start < len(chars):
273 | end = len(chars)
274 | cur_substr = None
275 | while start < end:
276 | substr = "".join(chars[start:end])
277 | if start > 0:
278 | substr = "##" + substr
279 | if substr in self.vocab:
280 | cur_substr = substr
281 | break
282 | end -= 1
283 | if cur_substr is None:
284 | is_bad = True
285 | break
286 | sub_tokens.append(cur_substr)
287 | start = end
288 | 
289 | if is_bad:
290 | output_tokens.append(self.unk_token)
291 | else:
292 | output_tokens.extend(sub_tokens)
293 | return output_tokens
294 | 
295 | 
296 | def _is_whitespace(char):
297 | """Checks whether `char` is a whitespace character."""
298 | # \t, \n, and \r are technically control characters but we treat them
299 | # as whitespace since they are generally considered as such.
300 | if char == " " or char == "\t" or char == "\n" or char == "\r":
301 | return True
302 | cat = unicodedata.category(char)
303 | if cat == "Zs":
304 | return True
305 | return False
306 | 
307 | 
308 | def _is_control(char):
309 | """Checks whether `char` is a control character."""
310 | # These are technically control characters but we count them as whitespace
311 | # characters.
312 | if char == "\t" or char == "\n" or char == "\r":
313 | return False
314 | cat = unicodedata.category(char)
315 | if cat.startswith("C"):
316 | return True
317 | return False
318 | 
319 | 
320 | def _is_punctuation(char):
321 | """Checks whether `char` is a punctuation character."""
322 | cp = ord(char)
323 | # We treat all non-letter/number ASCII as punctuation.
324 | # Characters such as "^", "$", and "`" are not in the Unicode
325 | # Punctuation class but we treat them as punctuation anyways, for
326 | # consistency.
327 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
328 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
329 | return True
330 | cat = unicodedata.category(char)
331 | if cat.startswith("P"):
332 | return True
333 | return False
334 | 
-------------------------------------------------------------------------------- /model/decoder.py: --------------------------------------------------------------------------------
1 | """ Decoder for the SQL generation problem."""
2 | 
3 | from collections import namedtuple
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn.functional as F
8 | from . import torch_utils
9 | 
10 | from .token_predictor import PredictionInput, PredictionInputWithSchema
11 | import data_util.snippets as snippet_handler
12 | from . import embedder
13 | from data_util.vocabulary import EOS_TOK, UNK_TOK
14 | 
15 | def flatten_distribution(distribution_map, probabilities):
16 | """ Flattens a probability distribution given a map of "unique" values.
17 | All duplicate values in distribution_map should get the sum
18 | of their probabilities.
19 | 
20 | Arguments:
21 | distribution_map (list of str): List of values to get the probability for.
22 | probabilities (np.ndarray): Probabilities corresponding to the values in 23 | distribution_map. 24 | 25 | Returns: 26 | list, np.ndarray of the same size where probabilities for duplicates 27 | in distribution_map are given the sum of the probabilities in probabilities. 28 | """ 29 | assert len(distribution_map) == len(probabilities) 30 | if len(distribution_map) != len(set(distribution_map)): 31 | idx_first_dup = 0 32 | seen_set = set() 33 | for i, tok in enumerate(distribution_map): 34 | if tok in seen_set: 35 | idx_first_dup = i 36 | break 37 | seen_set.add(tok) 38 | new_dist_map = distribution_map[:idx_first_dup] + list( 39 | set(distribution_map) - set(distribution_map[:idx_first_dup])) 40 | assert len(new_dist_map) == len(set(new_dist_map)) 41 | new_probs = np.array( 42 | probabilities[:idx_first_dup] \ 43 | + [0. for _ in range(len(set(distribution_map)) \ 44 | - idx_first_dup)]) 45 | assert len(new_probs) == len(new_dist_map) 46 | 47 | for i, token_name in enumerate( 48 | distribution_map[idx_first_dup:]): 49 | if token_name not in new_dist_map: 50 | new_dist_map.append(token_name) 51 | 52 | new_index = new_dist_map.index(token_name) 53 | new_probs[new_index] += probabilities[i + 54 | idx_first_dup] 55 | new_probs = new_probs.tolist() 56 | else: 57 | new_dist_map = distribution_map 58 | new_probs = probabilities 59 | 60 | assert len(new_dist_map) == len(new_probs) 61 | 62 | return new_dist_map, new_probs 63 | 64 | class SQLPrediction(namedtuple('SQLPrediction', 65 | ('predictions', 66 | 'sequence', 67 | 'probability'))): 68 | """Contains prediction for a sequence.""" 69 | __slots__ = () 70 | 71 | def __str__(self): 72 | return str(self.probability) + "\t" + " ".join(self.sequence) 73 | 74 | class SequencePredictorWithSchema(torch.nn.Module): 75 | """ Predicts a sequence. 76 | 77 | Attributes: 78 | lstms (list of dy.RNNBuilder): The RNN used. 79 | token_predictor (TokenPredictor): Used to actually predict tokens. 
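output_embedder (Embedder): Embeds output (SQL) tokens.
column_name_token_embedder: Embeds tokens of schema column names.
start_token_embedding (torch.nn.Parameter): Learned embedding fed as the first decoder input.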
80 | """ 81 | def __init__(self, 82 | params, 83 | input_size, 84 | output_embedder, 85 | column_name_token_embedder, 86 | token_predictor): 87 | super().__init__() 88 | 89 | self.lstms = torch_utils.create_multilayer_lstm_params(params.decoder_num_layers, input_size, params.decoder_state_size, "LSTM-d") 90 | self.token_predictor = token_predictor 91 | self.output_embedder = output_embedder 92 | self.column_name_token_embedder = column_name_token_embedder 93 | self.start_token_embedding = torch_utils.add_params((params.output_embedding_size,), "y-0") 94 | 95 | self.input_size = input_size 96 | self.params = params 97 | 98 | def _initialize_decoder_lstm(self, encoder_state): 99 | decoder_lstm_states = [] 100 | for i, lstm in enumerate(self.lstms): 101 | encoder_layer_num = 0 102 | if len(encoder_state[0]) > 1: 103 | encoder_layer_num = i 104 | 105 | # check which one is h_0, which is c_0 106 | c_0 = encoder_state[0][encoder_layer_num].view(1,-1) 107 | h_0 = encoder_state[1][encoder_layer_num].view(1,-1) 108 | 109 | decoder_lstm_states.append((h_0, c_0)) 110 | return decoder_lstm_states 111 | 112 | def get_output_token_embedding(self, output_token, input_schema, snippets): 113 | if self.params.use_snippets and snippet_handler.is_snippet(output_token): 114 | output_token_embedding = embedder.bow_snippets(output_token, snippets, self.output_embedder, input_schema) 115 | else: 116 | if input_schema: 117 | assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True) 118 | if self.output_embedder.in_vocabulary(output_token): 119 | output_token_embedding = self.output_embedder(output_token) 120 | else: 121 | output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True) 122 | else: 123 | output_token_embedding = self.output_embedder(output_token) 124 | return output_token_embedding 125 | 126 | def get_decoder_input(self, output_token_embedding, prediction): 127 | if self.params.use_schema_attention and self.params.use_query_attention: 128 | decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector, prediction.query_attention_results.vector], dim=0) 129 | elif self.params.use_schema_attention: 130 | decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector], dim=0) 131 | else: 132 | decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector], dim=0) 133 | return decoder_input 134 | 135 | def forward(self, 136 | final_encoder_state, 137 | encoder_states, 138 | schema_states, 139 | max_generation_length, 140 | snippets=None, 141 | gold_sequence=None, 142 | input_sequence=None, 143 | previous_queries=None, 144 | previous_query_states=None, 145 | input_schema=None, 146 | dropout_amount=0.): 147 | """ Generates a sequence. """ 148 | index = 0 149 | 150 | context_vector_size = self.input_size - self.params.output_embedding_size 151 | 152 | # Decoder states: just the initialized decoder. 153 | # Current input to decoder: phi(start_token) ; zeros the size of the 154 | # context vector 155 | predictions = [] 156 | sequence = [] 157 | probability = 1. 
158 | 159 | decoder_states = self._initialize_decoder_lstm(final_encoder_state) 160 | 161 | if self.start_token_embedding.is_cuda: 162 | decoder_input = torch.cat([self.start_token_embedding, torch.cuda.FloatTensor(context_vector_size).fill_(0)], dim=0) 163 | else: 164 | decoder_input = torch.cat([self.start_token_embedding, torch.zeros(context_vector_size)], dim=0) 165 | 166 | continue_generating = True 167 | while continue_generating: 168 | if len(sequence) == 0 or sequence[-1] != EOS_TOK: 169 | _, decoder_state, decoder_states = torch_utils.forward_one_multilayer(self.lstms, decoder_input, decoder_states, dropout_amount) 170 | prediction_input = PredictionInputWithSchema(decoder_state=decoder_state, 171 | input_hidden_states=encoder_states, 172 | schema_states=schema_states, 173 | snippets=snippets, 174 | input_sequence=input_sequence, 175 | previous_queries=previous_queries, 176 | previous_query_states=previous_query_states, 177 | input_schema=input_schema) 178 | 179 | prediction = self.token_predictor(prediction_input, dropout_amount=dropout_amount) 180 | 181 | predictions.append(prediction) 182 | 183 | if gold_sequence: 184 | output_token = gold_sequence[index] 185 | 186 | output_token_embedding = self.get_output_token_embedding(output_token, input_schema, snippets) 187 | 188 | decoder_input = self.get_decoder_input(output_token_embedding, prediction) 189 | 190 | sequence.append(gold_sequence[index]) 191 | 192 | if index >= len(gold_sequence) - 1: 193 | continue_generating = False 194 | else: 195 | assert prediction.scores.dim() == 1 196 | probabilities = F.softmax(prediction.scores, dim=0).cpu().data.numpy().tolist() 197 | 198 | distribution_map = prediction.aligned_tokens 199 | assert len(probabilities) == len(distribution_map) 200 | 201 | if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0: 202 | assert prediction.query_scores.dim() == 1 203 | query_token_probabilities = F.softmax(prediction.query_scores, dim=0).cpu().data.numpy().tolist() 204 | 205 | query_token_distribution_map = prediction.query_tokens 206 | 207 | assert len(query_token_probabilities) == len(query_token_distribution_map) 208 | 209 | copy_switch = prediction.copy_switch.cpu().data.numpy() 210 | 211 | # Merge the two 212 | probabilities = ((np.array(probabilities) * (1 - copy_switch)).tolist() + 213 | (np.array(query_token_probabilities) * copy_switch).tolist() 214 | ) 215 | distribution_map = distribution_map + query_token_distribution_map 216 | assert len(probabilities) == len(distribution_map) 217 | 218 | # Get a new probabilities and distribution_map consolidating duplicates 219 | distribution_map, probabilities = flatten_distribution(distribution_map, probabilities) 220 | 221 | # Modify the probability distribution so that the UNK token can never be produced 222 | probabilities[distribution_map.index(UNK_TOK)] = 0. 
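# Greedy decoding: with UNK masked out above, the argmax token is emitted at
# each step and fed back as the next input; no beam search is performed here.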
223 | argmax_index = int(np.argmax(probabilities)) 224 | 225 | argmax_token = distribution_map[argmax_index] 226 | sequence.append(argmax_token) 227 | 228 | output_token_embedding = self.get_output_token_embedding(argmax_token, input_schema, snippets) 229 | 230 | decoder_input = self.get_decoder_input(output_token_embedding, prediction) 231 | 232 | probability *= probabilities[argmax_index] 233 | 234 | continue_generating = False 235 | if index < max_generation_length and argmax_token != EOS_TOK: 236 | continue_generating = True 237 | 238 | index += 1 239 | 240 | return SQLPrediction(predictions, 241 | sequence, 242 | probability) -------------------------------------------------------------------------------- /model/embedder.py: -------------------------------------------------------------------------------- 1 | """ Embedder for tokens. """ 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import data_util.snippets as snippet_handler 7 | import data_util.vocabulary as vocabulary_handler 8 | 9 | class Embedder(torch.nn.Module): 10 | """ Embeds tokens. """ 11 | def __init__(self, embedding_size, name="", initializer=None, vocabulary=None, num_tokens=-1, anonymizer=None, freeze=False, use_unk=True): 12 | super().__init__() 13 | 14 | if vocabulary: 15 | assert num_tokens < 0, "Specified a vocabulary but also set number of tokens to " + \ 16 | str(num_tokens) 17 | self.in_vocabulary = lambda token: token in vocabulary.tokens 18 | self.vocab_token_lookup = lambda token: vocabulary.token_to_id(token) 19 | if use_unk: 20 | self.unknown_token_id = vocabulary.token_to_id(vocabulary_handler.UNK_TOK) 21 | else: 22 | self.unknown_token_id = -1 23 | self.vocabulary_size = len(vocabulary) 24 | else: 25 | def check_vocab(index): 26 | """ Makes sure the index is in the vocabulary.""" 27 | assert index < num_tokens, "Passed token ID " + \ 28 | str(index) + "; expecting something less than " + str(num_tokens) 29 | return index < num_tokens 30 | self.in_vocabulary = check_vocab 31 | self.vocab_token_lookup = lambda x: x 32 | self.unknown_token_id = num_tokens # Deliberately throws an error here, 33 | # But should crash before this 34 | self.vocabulary_size = num_tokens 35 | 36 | self.anonymizer = anonymizer 37 | 38 | emb_name = name + "-tokens" 39 | print("Creating token embedder called " + emb_name + " of size " + str(self.vocabulary_size) + " x " + str(embedding_size)) 40 | 41 | if initializer is not None: 42 | word_embeddings_tensor = torch.FloatTensor(initializer) 43 | self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(word_embeddings_tensor, freeze=freeze) 44 | else: 45 | init_tensor = torch.empty(self.vocabulary_size, embedding_size).uniform_(-0.1, 0.1) 46 | self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False) 47 | 48 | if self.anonymizer: 49 | emb_name = name + "-entities" 50 | entity_size = len(self.anonymizer.entity_types) 51 | print("Creating entity embedder called " + emb_name + " of size " + str(entity_size) + " x " + str(embedding_size)) 52 | init_tensor = torch.empty(entity_size, embedding_size).uniform_(-0.1, 0.1) 53 | self.entity_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False) 54 | 55 | 56 | def forward(self, token): 57 | assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets" 58 | 59 | if self.in_vocabulary(token): 60 | index_list = 
torch.LongTensor([self.vocab_token_lookup(token)]) 61 | if self.token_embedding_matrix.weight.is_cuda: 62 | index_list = index_list.cuda() 63 | return self.token_embedding_matrix(index_list).squeeze() 64 | elif self.anonymizer and self.anonymizer.is_anon_tok(token): 65 | index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)]) 66 | if self.token_embedding_matrix.weight.is_cuda: 67 | index_list = index_list.cuda() 68 | return self.entity_embedding_matrix(index_list).squeeze() 69 | else: 70 | index_list = torch.LongTensor([self.unknown_token_id]) 71 | if self.token_embedding_matrix.weight.is_cuda: 72 | index_list = index_list.cuda() 73 | return self.token_embedding_matrix(index_list).squeeze() 74 | 75 | 76 | def bow_snippets(token, snippets, output_embedder, input_schema): 77 | """ Bag of words embedding for snippets""" 78 | assert snippet_handler.is_snippet(token) and snippets 79 | 80 | snippet_sequence = [] 81 | for snippet in snippets: 82 | if snippet.name == token: 83 | snippet_sequence = snippet.sequence 84 | break 85 | assert snippet_sequence 86 | 87 | if input_schema: 88 | snippet_embeddings = [] 89 | for output_token in snippet_sequence: 90 | assert output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True) 91 | if output_embedder.in_vocabulary(output_token): 92 | snippet_embeddings.append(output_embedder(output_token)) 93 | else: 94 | snippet_embeddings.append(input_schema.column_name_embedder(output_token, surface_form=True)) 95 | else: 96 | snippet_embeddings = [output_embedder(subtoken) for subtoken in snippet_sequence] 97 | 98 | snippet_embeddings = torch.stack(snippet_embeddings, dim=0) # len(snippet_sequence) x emb_size 99 | return torch.mean(snippet_embeddings, dim=0) # emb_size 100 | 101 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | """ Contains code for encoding an input sequence. """ 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from .torch_utils import create_multilayer_lstm_params, encode_sequence 6 | 7 | class Encoder(torch.nn.Module): 8 | """ Encodes an input sequence. """ 9 | def __init__(self, num_layers, input_size, state_size): 10 | super().__init__() 11 | 12 | self.num_layers = num_layers 13 | self.forward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-ef") 14 | self.backward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-eb") 15 | 16 | def forward(self, sequence, embedder, dropout_amount=0.): 17 | """ Encodes a sequence forward and backward. 18 | Inputs: 19 | forward_seq (list of str): The string forwards. 20 | backward_seq (list of str): The string backwards. 21 | f_rnns (list of dy.RNNBuilder): The forward RNNs. 22 | b_rnns (list of dy.RNNBuilder): The backward RNNS. 23 | emb_fn (dict str->dy.Expression): Embedding function for tokens in the 24 | sequence. 25 | size (int): The size of the RNNs. 26 | dropout_amount (float, optional): The amount of dropout to apply. 27 | 28 | Returns: 29 | (list of dy.Expression, list of dy.Expression), list of dy.Expression, 30 | where the first pair is the (final cell memories, final cell states) of 31 | all layers, and the second list is a list of the final layer's cell 32 | state for all tokens in the sequence. 
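(Note: the dy.* types in this and other docstrings are left over from the
original Dynet implementation; in this PyTorch port they correspond to
torch.Tensor objects.)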
33 | """ 34 | forward_state, forward_outputs = encode_sequence( 35 | sequence, 36 | self.forward_lstms, 37 | embedder, 38 | dropout_amount=dropout_amount) 39 | 40 | backward_state, backward_outputs = encode_sequence( 41 | sequence[::-1], 42 | self.backward_lstms, 43 | embedder, 44 | dropout_amount=dropout_amount) 45 | 46 | cell_memories = [] 47 | hidden_states = [] 48 | for i in range(self.num_layers): 49 | cell_memories.append(torch.cat([forward_state[0][i], backward_state[0][i]], dim=0)) 50 | hidden_states.append(torch.cat([forward_state[1][i], backward_state[1][i]], dim=0)) 51 | 52 | assert len(forward_outputs) == len(backward_outputs) 53 | 54 | backward_outputs = backward_outputs[::-1] 55 | 56 | final_outputs = [] 57 | for i in range(len(sequence)): 58 | final_outputs.append(torch.cat([forward_outputs[i], backward_outputs[i]], dim=0)) 59 | 60 | return (cell_memories, hidden_states), final_outputs -------------------------------------------------------------------------------- /model/torch_utils.py: -------------------------------------------------------------------------------- 1 | """Contains various utility functions for Dynet models.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | def linear_layer(exp, weights, biases=None): 11 | # exp: input as size_1 or 1 x size_1 12 | # weight: size_1 x size_2 13 | # bias: size_2 14 | if exp.dim() == 1: 15 | exp = torch.unsqueeze(exp, 0) 16 | assert exp.size()[1] == weights.size()[0] 17 | if biases is not None: 18 | assert weights.size()[1] == biases.size()[0] 19 | result = torch.mm(exp, weights) + biases 20 | else: 21 | result = torch.mm(exp, weights) 22 | return result 23 | 24 | 25 | def compute_loss(gold_seq, 26 | scores, 27 | index_to_token_maps, 28 | gold_tok_to_id, 29 | noise=0.00000001): 30 | """ Computes the loss of a gold sequence given scores. 31 | 32 | Inputs: 33 | gold_seq (list of str): A sequence of gold tokens. 34 | scores (list of dy.Expression): Expressions representing the scores of 35 | potential output tokens for each token in gold_seq. 36 | index_to_token_maps (list of dict str->list of int): Maps from index in the 37 | sequence to a dictionary mapping from a string to a set of integers. 38 | gold_tok_to_id (lambda (str, str)->list of int): Maps from the gold token 39 | and some lookup function to the indices in the probability distribution 40 | where the gold token occurs. 41 | noise (float, optional): The amount of noise to add to the loss. 42 | 43 | Returns: 44 | dy.Expression representing the sum of losses over the sequence. 45 | """ 46 | assert len(gold_seq) == len(scores) == len(index_to_token_maps) 47 | 48 | losses = [] 49 | for i, gold_tok in enumerate(gold_seq): 50 | score = scores[i] 51 | token_map = index_to_token_maps[i] 52 | 53 | gold_indices = gold_tok_to_id(gold_tok, token_map) 54 | assert len(gold_indices) > 0 55 | noise_i = noise 56 | ''' 57 | if len(gold_indices) == 1: 58 | noise_i = 0 59 | ''' 60 | 61 | probdist = score 62 | prob_of_tok = torch.sum(probdist[gold_indices]) 63 | if prob_of_tok < noise_i: 64 | prob_of_tok = prob_of_tok + noise_i 65 | elif prob_of_tok > 1 - noise_i: 66 | prob_of_tok = prob_of_tok - noise_i 67 | losses.append(-torch.log(prob_of_tok)) 68 | 69 | return torch.sum(torch.stack(losses)) 70 | 71 | 72 | def get_seq_from_scores(scores, index_to_token_maps): 73 | """Gets the argmax sequence from a set of scores. 
74 | 75 | Inputs: 76 | scores (list of dy.Expression): Sequences of output scores. 77 | index_to_token_maps (list of list of str): For each output token, maps 78 | the index in the probability distribution to a string. 79 | 80 | Returns: 81 | list of str, representing the argmax sequence. 82 | """ 83 | seq = [] 84 | for score, tok_map in zip(scores, index_to_token_maps): 85 | # score_numpy_list = score.cpu().detach().numpy() 86 | score_numpy_list = score.cpu().data.numpy() 87 | assert score.size()[0] == len(tok_map) == len(list(score_numpy_list)) 88 | seq.append(tok_map[np.argmax(score_numpy_list)]) 89 | return seq 90 | 91 | def per_token_accuracy(gold_seq, pred_seq): 92 | """ Returns the per-token accuracy comparing two strings (recall). 93 | 94 | Inputs: 95 | gold_seq (list of str): A list of gold tokens. 96 | pred_seq (list of str): A list of predicted tokens. 97 | 98 | Returns: 99 | float, representing the accuracy. 100 | """ 101 | num_correct = 0 102 | for i, gold_token in enumerate(gold_seq): 103 | if i < len(pred_seq) and pred_seq[i] == gold_token: 104 | num_correct += 1 105 | 106 | return float(num_correct) / len(gold_seq) 107 | 108 | def forward_one_multilayer(rnns, lstm_input, layer_states, dropout_amount=0.): 109 | """ Goes forward for one multilayer RNN cell step. 110 | 111 | Inputs: 112 | lstm_input (dy.Expression): Some input to the step. 113 | layer_states (list of dy.RNNState): The states of each layer in the cell. 114 | dropout_amount (float, optional): The amount of dropout to apply, in 115 | between the layers. 116 | 117 | Returns: 118 | (list of dy.Expression, list of dy.Expression), dy.Expression, (list of dy.RNNSTate), 119 | representing (each layer's cell memory, each layer's cell hidden state), 120 | the final hidden state, and (each layer's updated RNNState). 121 | """ 122 | num_layers = len(layer_states) 123 | new_states = [] 124 | cell_states = [] 125 | hidden_states = [] 126 | state = lstm_input 127 | for i in range(num_layers): 128 | # view as (1, input_size) 129 | layer_h, layer_c = rnns[i](torch.unsqueeze(state,0), layer_states[i]) 130 | new_states.append((layer_h, layer_c)) 131 | 132 | layer_h = layer_h.squeeze() 133 | layer_c = layer_c.squeeze() 134 | 135 | state = layer_h 136 | if i < num_layers - 1: 137 | # In both Dynet and Pytorch 138 | # p stands for probability of an element to be zeroed. i.e. p=1 means switch off all activations. 139 | state = F.dropout(state, p=dropout_amount) 140 | 141 | cell_states.append(layer_c) 142 | hidden_states.append(layer_h) 143 | 144 | return (cell_states, hidden_states), state, new_states 145 | 146 | 147 | def encode_sequence(sequence, rnns, embedder, dropout_amount=0.): 148 | """ Encodes a sequence given RNN cells and an embedding function. 149 | 150 | Inputs: 151 | seq (list of str): The sequence to encode. 152 | rnns (list of dy._RNNBuilder): The RNNs to use. 153 | emb_fn (dict str->dy.Expression): Function that embeds strings to 154 | word vectors. 155 | size (int): The size of the RNN. 156 | dropout_amount (float, optional): The amount of dropout to apply. 157 | 158 | Returns: 159 | (list of dy.Expression, list of dy.Expression), list of dy.Expression, 160 | where the first pair is the (final cell memories, final cell states) of 161 | all layers, and the second list is a list of the final layer's cell 162 | state for all tokens in the sequence. 
163 | """ 164 | 165 | batch_size = 1 166 | layer_states = [] 167 | for rnn in rnns: 168 | hidden_size = rnn.weight_hh.size()[1] 169 | 170 | # h_0 of shape (batch, hidden_size) 171 | # c_0 of shape (batch, hidden_size) 172 | if rnn.weight_hh.is_cuda: 173 | h_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0) 174 | c_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0) 175 | else: 176 | h_0 = torch.zeros(batch_size,hidden_size) 177 | c_0 = torch.zeros(batch_size,hidden_size) 178 | 179 | layer_states.append((h_0, c_0)) 180 | 181 | outputs = [] 182 | for token in sequence: 183 | rnn_input = embedder(token) 184 | (cell_states, hidden_states), output, layer_states = forward_one_multilayer(rnns,rnn_input,layer_states,dropout_amount) 185 | 186 | outputs.append(output) 187 | 188 | return (cell_states, hidden_states), outputs 189 | 190 | def create_multilayer_lstm_params(num_layers, in_size, state_size, name=""): 191 | """ Adds a multilayer LSTM to the model parameters. 192 | 193 | Inputs: 194 | num_layers (int): Number of layers to create. 195 | in_size (int): The input size to the first layer. 196 | state_size (int): The size of the states. 197 | model (dy.ParameterCollection): The parameter collection for the model. 198 | name (str, optional): The name of the multilayer LSTM. 199 | """ 200 | lstm_layers = [] 201 | for i in range(num_layers): 202 | layer_name = name + "-" + str(i) 203 | print("LSTM " + layer_name + ": " + str(in_size) + " x " + str(state_size) + "; default Dynet initialization of hidden weights") 204 | lstm_layer = torch.nn.LSTMCell(input_size=int(in_size), hidden_size=int(state_size), bias=True) 205 | lstm_layers.append(lstm_layer) 206 | in_size = state_size 207 | return torch.nn.ModuleList(lstm_layers) 208 | 209 | def add_params(size, name=""): 210 | """ Adds parameters to the model. 211 | 212 | Inputs: 213 | model (dy.ParameterCollection): The parameter collection for the model. 214 | size (tuple of int): The size to create. 215 | name (str, optional): The name of the parameters. 
216 | """ 217 | if len(size) == 1: 218 | print("vector " + name + ": " + str(size[0]) + "; uniform in [-0.1, 0.1]") 219 | else: 220 | print("matrix " + name + ": " + str(size[0]) + " x " + str(size[1]) + "; uniform in [-0.1, 0.1]") 221 | 222 | size_int = tuple([int(ss) for ss in size]) 223 | #return torch.nn.Parameter(torch.empty(size_int).uniform_(-0.1, 0.1)) 224 | if len(size) == 1: 225 | #return torch.nn.Parameter(torch.zeros(size_int)) 226 | return torch.nn.Parameter(torch.empty(size_int).uniform_(-0.1, 0.1)) 227 | else: 228 | tmp_ret = torch.empty(size_int) 229 | torch.nn.init.xavier_uniform_(tmp_ret) 230 | return torch.nn.Parameter(tmp_ret) 231 | -------------------------------------------------------------------------------- /model/utils_bert.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/naver/sqlova 2 | 3 | import os, json 4 | import random as rd 5 | from copy import deepcopy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .bert import tokenization as tokenization 12 | from .bert.modeling import BertConfig, BertModel 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | def get_bert(params): 17 | BERT_PT_PATH = './model/bert/data/annotated_wikisql_and_PyTorch_bert_param' 18 | map_bert_type_abb = {'uS': 'uncased_L-12_H-768_A-12', 19 | 'uL': 'uncased_L-24_H-1024_A-16', 20 | 'cS': 'cased_L-12_H-768_A-12', 21 | 'cL': 'cased_L-24_H-1024_A-16', 22 | 'mcS': 'multi_cased_L-12_H-768_A-12'} 23 | bert_type = map_bert_type_abb[params.bert_type_abb] 24 | if params.bert_type_abb == 'cS' or params.bert_type_abb == 'cL' or params.bert_type_abb == 'mcS': 25 | do_lower_case = False 26 | else: 27 | do_lower_case = True 28 | no_pretraining = False 29 | 30 | bert_config_file = os.path.join(BERT_PT_PATH, 'bert_config_{0}.json'.format(bert_type)) 31 | vocab_file = os.path.join(BERT_PT_PATH, 'vocab_{0}.txt'.format(bert_type)) 32 | init_checkpoint = os.path.join(BERT_PT_PATH, 'pytorch_model_{0}.bin'.format(bert_type)) 33 | 34 | print('bert_config_file', bert_config_file) 35 | print('vocab_file', vocab_file) 36 | print('init_checkpoint', init_checkpoint) 37 | 38 | bert_config = BertConfig.from_json_file(bert_config_file) 39 | tokenizer = tokenization.FullTokenizer( 40 | vocab_file=vocab_file, do_lower_case=do_lower_case) 41 | bert_config.print_status() 42 | 43 | model_bert = BertModel(bert_config) 44 | if no_pretraining: 45 | pass 46 | else: 47 | model_bert.load_state_dict(torch.load(init_checkpoint, map_location='cpu')) 48 | print("Load pre-trained parameters.") 49 | model_bert.to(device) 50 | 51 | return model_bert, tokenizer, bert_config 52 | 53 | def generate_inputs(tokenizer, nlu1_tok, hds1): 54 | tokens = [] 55 | segment_ids = [] 56 | 57 | t_to_tt_idx_hds1 = [] 58 | 59 | tokens.append("[CLS]") 60 | i_st_nlu = len(tokens) # to use it later 61 | 62 | segment_ids.append(0) 63 | for token in nlu1_tok: 64 | tokens.append(token) 65 | segment_ids.append(0) 66 | i_ed_nlu = len(tokens) 67 | tokens.append("[SEP]") 68 | segment_ids.append(0) 69 | 70 | i_hds = [] 71 | for i, hds11 in enumerate(hds1): 72 | i_st_hd = len(tokens) 73 | t_to_tt_idx_hds11 = [] 74 | sub_tok = [] 75 | for sub_tok1 in hds11.split(): 76 | t_to_tt_idx_hds11.append(len(sub_tok)) 77 | sub_tok += tokenizer.tokenize(sub_tok1) 78 | t_to_tt_idx_hds1.append(t_to_tt_idx_hds11) 79 | tokens += sub_tok 80 | 81 | i_ed_hd = len(tokens) 82 | i_hds.append((i_st_hd, i_ed_hd)) 83 | segment_ids += 
[1] * len(sub_tok)
84 | if i < len(hds1)-1:
85 | tokens.append("[SEP]")
86 | segment_ids.append(0)
87 | elif i == len(hds1)-1:
88 | tokens.append("[SEP]")
89 | segment_ids.append(1)
90 | else:
91 | raise EnvironmentError
92 | 
93 | i_nlu = (i_st_nlu, i_ed_nlu)
94 | 
95 | return tokens, segment_ids, i_nlu, i_hds, t_to_tt_idx_hds1
96 | 
97 | def gen_l_hpu(i_hds):
98 | """
99 | # Treat columns as if they were a batch of natural language utterances with batch size = # of columns * batch size
100 | i_hds = [(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)]
101 | """
102 | l_hpu = []
103 | for i_hds1 in i_hds:
104 | for i_hds11 in i_hds1:
105 | l_hpu.append(i_hds11[1] - i_hds11[0])
106 | 
107 | return l_hpu
108 | 
109 | def get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length):
110 | """
111 | Here, the input is tokenized further by the WordPiece (WP) tokenizer and fed into BERT.
112 | 
113 | INPUT
114 | :param model_bert:
115 | :param tokenizer: WordPiece tokenizer
116 | :param nlu: Question
117 | :param nlu_t: CoreNLP tokenized nlu.
118 | :param hds: Headers
119 | :param hs_t: None or 1st-level tokenized headers
120 | :param max_seq_length: max input token length
121 | 
122 | OUTPUT
123 | tokens: BERT input tokens
124 | nlu_tt: WP-tokenized input natural language questions
125 | orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token
126 | tok_to_orig_index: inverse map.
127 | 
128 | """
129 | 
130 | l_n = []
131 | l_hs = [] # The length of columns for each batch
132 | 
133 | input_ids = []
134 | tokens = []
135 | segment_ids = []
136 | input_mask = []
137 | 
138 | i_nlu = [] # index to retrieve the position of the contextual vector later.
139 | i_hds = []
140 | 
141 | doc_tokens = []
142 | nlu_tt = []
143 | 
144 | t_to_tt_idx = []
145 | tt_to_t_idx = []
146 | 
147 | t_to_tt_idx_hds = []
148 | 
149 | for b, nlu_t1 in enumerate(nlu_t):
150 | hds1 = hds[b]
151 | l_hs.append(len(hds1))
152 | 
153 | # 1. 2nd tokenization using WordPiece
154 | tt_to_t_idx1 = [] # number indicates where sub-token belongs to in 1st-level-tokens (here, CoreNLP).
155 | t_to_tt_idx1 = [] # orig_to_tok_idx[i] = start index of i-th-1st-level-token in all_tokens.
156 | nlu_tt1 = [] # all_doc_tokens[ orig_to_tok_idx[i] ] returns the first sub-token segment of the i-th-1st-level-token
157 | for (i, token) in enumerate(nlu_t1):
158 | t_to_tt_idx1.append(
159 | len(nlu_tt1)) # indicates the start position of each original 'white-space' token.
160 | sub_tokens = tokenizer.tokenize(token)
161 | for sub_token in sub_tokens:
162 | tt_to_t_idx1.append(i)
163 | nlu_tt1.append(sub_token) # all_doc_tokens are further tokenized using WordPiece tokenizer
164 | nlu_tt.append(nlu_tt1)
165 | tt_to_t_idx.append(tt_to_t_idx1)
166 | t_to_tt_idx.append(t_to_tt_idx1)
167 | 
168 | l_n.append(len(nlu_tt1))
169 | 
170 | # [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP]
171 | # 2. Generate BERT inputs & indices.
172 | tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, hds1)
173 | 
174 | assert len(t_to_tt_idx_hds1) == len(hds1)
175 | 
176 | t_to_tt_idx_hds.append(t_to_tt_idx_hds1)
177 | 
178 | input_ids1 = tokenizer.convert_tokens_to_ids(tokens1)
179 | 
180 | # Input masks
181 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
182 | # tokens are attended to.
183 | input_mask1 = [1] * len(input_ids1)
184 | 
185 | # 3. Zero-pad up to the sequence length.
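# When the batch holds a single example, max_seq_length is shrunk below to the
# actual input length, so no padding is added in that case; otherwise ids,
# mask, and segment ids are padded with zeros up to max_seq_length.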
186 | if len(nlu_t) == 1: 187 | max_seq_length = len(input_ids1) 188 | while len(input_ids1) < max_seq_length: 189 | input_ids1.append(0) 190 | input_mask1.append(0) 191 | segment_ids1.append(0) 192 | 193 | assert len(input_ids1) == max_seq_length 194 | assert len(input_mask1) == max_seq_length 195 | assert len(segment_ids1) == max_seq_length 196 | 197 | input_ids.append(input_ids1) 198 | tokens.append(tokens1) 199 | segment_ids.append(segment_ids1) 200 | input_mask.append(input_mask1) 201 | 202 | i_nlu.append(i_nlu1) 203 | i_hds.append(i_hds1) 204 | 205 | # Convert to tensor 206 | all_input_ids = torch.tensor(input_ids, dtype=torch.long).to(device) 207 | all_input_mask = torch.tensor(input_mask, dtype=torch.long).to(device) 208 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long).to(device) 209 | 210 | # 4. Generate BERT output. 211 | all_encoder_layer, pooled_output = model_bert(all_input_ids, all_segment_ids, all_input_mask) 212 | 213 | # 5. generate l_hpu from i_hds 214 | l_hpu = gen_l_hpu(i_hds) 215 | 216 | assert len(set(l_n)) == 1 and len(set(i_nlu)) == 1 217 | assert l_n[0] == i_nlu[0][1] - i_nlu[0][0] 218 | 219 | return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \ 220 | l_n, l_hpu, l_hs, \ 221 | nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds 222 | 223 | def get_wemb_n(i_nlu, l_n, hS, num_hidden_layers, all_encoder_layer, num_out_layers_n): 224 | """ 225 | Get the representation of each tokens. 226 | """ 227 | bS = len(l_n) 228 | l_n_max = max(l_n) 229 | # print('wemb_n: [bS, l_n_max, hS * num_out_layers_n] = ', bS, l_n_max, hS * num_out_layers_n) 230 | wemb_n = torch.zeros([bS, l_n_max, hS * num_out_layers_n]).to(device) 231 | for b in range(bS): 232 | # [B, max_len, dim] 233 | # Fill zero for non-exist part. 234 | l_n1 = l_n[b] 235 | i_nlu1 = i_nlu[b] 236 | for i_noln in range(num_out_layers_n): 237 | i_layer = num_hidden_layers - 1 - i_noln 238 | st = i_noln * hS 239 | ed = (i_noln + 1) * hS 240 | wemb_n[b, 0:(i_nlu1[1] - i_nlu1[0]), st:ed] = all_encoder_layer[i_layer][b, i_nlu1[0]:i_nlu1[1], :] 241 | return wemb_n 242 | 243 | def get_wemb_h(i_hds, l_hpu, l_hs, hS, num_hidden_layers, all_encoder_layer, num_out_layers_h): 244 | """ 245 | As if 246 | [ [table-1-col-1-tok1, t1-c1-t2, ...], 247 | [t1-c2-t1, t1-c2-t2, ...]. 248 | ... 249 | [t2-c1-t1, ...,] 250 | ] 251 | """ 252 | bS = len(l_hs) 253 | l_hpu_max = max(l_hpu) 254 | num_of_all_hds = sum(l_hs) 255 | wemb_h = torch.zeros([num_of_all_hds, l_hpu_max, hS * num_out_layers_h]).to(device) 256 | # print('wemb_h: [num_of_all_hds, l_hpu_max, hS * num_out_layers_h] = ', wemb_h.size()) 257 | b_pu = -1 258 | for b, i_hds1 in enumerate(i_hds): 259 | for b1, i_hds11 in enumerate(i_hds1): 260 | b_pu += 1 261 | for i_nolh in range(num_out_layers_h): 262 | i_layer = num_hidden_layers - 1 - i_nolh 263 | st = i_nolh * hS 264 | ed = (i_nolh + 1) * hS 265 | wemb_h[b_pu, 0:(i_hds11[1] - i_hds11[0]), st:ed] \ 266 | = all_encoder_layer[i_layer][b, i_hds11[0]:i_hds11[1],:] 267 | 268 | 269 | return wemb_h 270 | 271 | def get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n=1, num_out_layers_h=1): 272 | 273 | # get contextual output of all tokens from bert 274 | all_encoder_layer, pooled_output, tokens, i_nlu, i_hds,\ 275 | l_n, l_hpu, l_hs, \ 276 | nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length) 277 | # all_encoder_layer: BERT outputs from all layers. 278 | # pooled_output: output of [CLS] vec. 
279 | # tokens: BERT intput tokens 280 | # i_nlu: start and end indices of question in tokens 281 | # i_hds: start and end indices of headers 282 | 283 | # get the wemb 284 | wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size, bert_config.num_hidden_layers, all_encoder_layer, 285 | num_out_layers_n) 286 | 287 | wemb_h = get_wemb_h(i_hds, l_hpu, l_hs, bert_config.hidden_size, bert_config.num_hidden_layers, all_encoder_layer, 288 | num_out_layers_h) 289 | 290 | return wemb_n, wemb_h, l_n, l_hpu, l_hs, \ 291 | nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds 292 | 293 | def prepare_input(tokenizer, input_sequence, input_schema, max_seq_length): 294 | nlu_t = [] 295 | hds = [] 296 | 297 | nlu_t1 = input_sequence 298 | all_hds = input_schema.column_names_embedder_input 299 | 300 | nlu_tt1 = [] 301 | for (i, token) in enumerate(nlu_t1): 302 | nlu_tt1 += tokenizer.tokenize(token) 303 | 304 | current_hds1 = [] 305 | for hds1 in all_hds: 306 | new_hds1 = current_hds1 + [hds1] 307 | tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, new_hds1) 308 | if len(segment_ids1) > max_seq_length: 309 | nlu_t.append(nlu_t1) 310 | hds.append(current_hds1) 311 | current_hds1 = [hds1] 312 | else: 313 | current_hds1 = new_hds1 314 | 315 | if len(current_hds1) > 0: 316 | nlu_t.append(nlu_t1) 317 | hds.append(current_hds1) 318 | 319 | return nlu_t, hds 320 | 321 | def prepare_input_v2(tokenizer, input_sequence, input_schema): 322 | nlu_t = [] 323 | hds = [] 324 | max_seq_length = 0 325 | 326 | nlu_t1 = input_sequence 327 | all_hds = input_schema.column_names_embedder_input 328 | 329 | nlu_tt1 = [] 330 | for (i, token) in enumerate(nlu_t1): 331 | nlu_tt1 += tokenizer.tokenize(token) 332 | 333 | current_hds1 = [] 334 | current_table = '' 335 | for hds1 in all_hds: 336 | hds1_table = hds1.split('.')[0].strip() 337 | if hds1_table == current_table: 338 | current_hds1.append(hds1) 339 | else: 340 | tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, current_hds1) 341 | max_seq_length = max(max_seq_length, len(segment_ids1)) 342 | 343 | nlu_t.append(nlu_t1) 344 | hds.append(current_hds1) 345 | current_hds1 = [hds1] 346 | current_table = hds1_table 347 | 348 | if len(current_hds1) > 0: 349 | tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, current_hds1) 350 | max_seq_length = max(max_seq_length, len(segment_ids1)) 351 | nlu_t.append(nlu_t1) 352 | hds.append(current_hds1) 353 | 354 | return nlu_t, hds, max_seq_length 355 | 356 | def get_bert_encoding(bert_config, model_bert, tokenizer, input_sequence, input_schema, bert_input_version='v1', max_seq_length=512, num_out_layers_n=1, num_out_layers_h=1): 357 | if bert_input_version == 'v1': 358 | nlu_t, hds = prepare_input(tokenizer, input_sequence, input_schema, max_seq_length) 359 | elif bert_input_version == 'v2': 360 | nlu_t, hds, max_seq_length = prepare_input_v2(tokenizer, input_sequence, input_schema) 361 | 362 | wemb_n, wemb_h, l_n, l_hpu, l_hs, nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n, num_out_layers_h) 363 | 364 | t_to_tt_idx = t_to_tt_idx[0] 365 | assert len(t_to_tt_idx) == len(input_sequence) 366 | assert sum(len(t_to_tt_idx_hds1) for t_to_tt_idx_hds1 in t_to_tt_idx_hds) == len(input_schema.column_names_embedder_input) 367 | 368 | assert list(wemb_h.size())[0] == len(input_schema.column_names_embedder_input) 369 | 370 | 
utterance_states = [] 371 | for i in range(len(t_to_tt_idx)): 372 | start = t_to_tt_idx[i] 373 | if i == len(t_to_tt_idx)-1: 374 | end = l_n[0] 375 | else: 376 | end = t_to_tt_idx[i+1] 377 | utterance_states.append(torch.mean(wemb_n[:,start:end,:], dim=[0,1])) 378 | assert len(utterance_states) == len(input_sequence) 379 | 380 | schema_token_states = [] 381 | cnt = -1 382 | for t_to_tt_idx_hds1 in t_to_tt_idx_hds: 383 | for t_to_tt_idx_hds11 in t_to_tt_idx_hds1: 384 | cnt += 1 385 | schema_token_states1 = [] 386 | for i in range(len(t_to_tt_idx_hds11)): 387 | start = t_to_tt_idx_hds11[i] 388 | if i == len(t_to_tt_idx_hds11)-1: 389 | end = l_hpu[cnt] 390 | else: 391 | end = t_to_tt_idx_hds11[i+1] 392 | schema_token_states1.append(torch.mean(wemb_h[cnt,start:end,:], dim=0)) 393 | assert len(schema_token_states1) == len(input_schema.column_names_embedder_input[cnt].split()) 394 | schema_token_states.append(schema_token_states1) 395 | 396 | assert len(schema_token_states) == len(input_schema.column_names_embedder_input) 397 | 398 | return utterance_states, schema_token_states 399 | -------------------------------------------------------------------------------- /parse_args.py: -------------------------------------------------------------------------------- 1 | import sys 2 | args = sys.argv 3 | 4 | import os 5 | import argparse 6 | 7 | def interpret_args(): 8 | """ Interprets the command line arguments, and returns a dictionary. """ 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--no_gpus", type=bool, default=1) 12 | 13 | ### Data parameters 14 | parser.add_argument( 15 | '--raw_train_filename', 16 | type=str, 17 | default='../atis_data/data/resplit/processed/train_with_tables.pkl') 18 | parser.add_argument( 19 | '--raw_dev_filename', 20 | type=str, 21 | default='../atis_data/data/resplit/processed/dev_with_tables.pkl') 22 | parser.add_argument( 23 | '--raw_validation_filename', 24 | type=str, 25 | default='../atis_data/data/resplit/processed/valid_with_tables.pkl') 26 | parser.add_argument( 27 | '--raw_test_filename', 28 | type=str, 29 | default='../atis_data/data/resplit/processed/test_with_tables.pkl') 30 | 31 | parser.add_argument('--data_directory', type=str, default='processed_data') 32 | 33 | parser.add_argument('--processed_train_filename', type=str, default='train.pkl') 34 | parser.add_argument('--processed_dev_filename', type=str, default='dev.pkl') 35 | parser.add_argument('--processed_validation_filename', type=str, default='validation.pkl') 36 | parser.add_argument('--processed_test_filename', type=str, default='test.pkl') 37 | 38 | parser.add_argument('--database_schema_filename', type=str, default=None) 39 | parser.add_argument('--embedding_filename', type=str, default=None) 40 | 41 | parser.add_argument('--input_vocabulary_filename', type=str, default='input_vocabulary.pkl') 42 | parser.add_argument('--output_vocabulary_filename', 43 | type=str, 44 | default='output_vocabulary.pkl') 45 | 46 | parser.add_argument('--input_key', type=str, default='nl_with_dates') 47 | 48 | parser.add_argument('--anonymize', type=bool, default=False) 49 | parser.add_argument('--anonymization_scoring', type=bool, default=False) 50 | parser.add_argument('--use_snippets', type=bool, default=False) 51 | 52 | parser.add_argument('--use_previous_query', type=bool, default=False) 53 | parser.add_argument('--maximum_queries', type=int, default=1) 54 | parser.add_argument('--use_copy_switch', type=bool, default=False) 55 | parser.add_argument('--use_query_attention', type=bool, 

    parser.add_argument('--use_utterance_attention', type=bool, default=False)

    parser.add_argument('--freeze', type=bool, default=False)
    parser.add_argument('--scheduler', type=bool, default=False)

    parser.add_argument('--use_bert', type=bool, default=False)
    parser.add_argument("--bert_type_abb", type=str, help="Type of BERT model to load, e.g. uS, uL, cS, cL, or mcS.")
    parser.add_argument("--bert_input_version", type=str, default='v1')
    parser.add_argument('--fine_tune_bert', type=bool, default=False)
    parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.')

    ### Debugging/logging parameters
    parser.add_argument('--reload_embedding', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='logs')
    parser.add_argument('--deterministic', type=bool, default=False)
    parser.add_argument('--num_train', type=int, default=-1)

    parser.add_argument('--logfile', type=str, default='log.txt')
    parser.add_argument('--results_file', type=str, default='results.txt')

    ### Model architecture
    parser.add_argument('--input_embedding_size', type=int, default=300)
    parser.add_argument('--output_embedding_size', type=int, default=300)

    parser.add_argument('--encoder_state_size', type=int, default=300)
    parser.add_argument('--decoder_state_size', type=int, default=300)

    parser.add_argument('--encoder_num_layers', type=int, default=1)
    parser.add_argument('--decoder_num_layers', type=int, default=1)
    parser.add_argument('--snippet_num_layers', type=int, default=1)

    parser.add_argument('--maximum_utterances', type=int, default=5)
    parser.add_argument('--state_positional_embeddings', type=bool, default=False)
    parser.add_argument('--positional_embedding_size', type=int, default=50)

    parser.add_argument('--snippet_age_embedding', type=bool, default=False)
    parser.add_argument('--snippet_age_embedding_size', type=int, default=64)
    parser.add_argument('--max_snippet_age_embedding', type=int, default=4)
    parser.add_argument('--previous_decoder_snippet_encoding', type=bool, default=False)

    parser.add_argument('--discourse_level_lstm', type=bool, default=False)

    parser.add_argument('--use_schema_attention', type=bool, default=False)
    parser.add_argument('--use_encoder_attention', type=bool, default=False)

    parser.add_argument('--use_schema_encoder', type=bool, default=False)
    parser.add_argument('--use_schema_self_attention', type=bool, default=False)
    parser.add_argument('--use_schema_encoder_2', type=bool, default=False)

    ### Training parameters
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--train_maximum_sql_length', type=int, default=400)  #200
    parser.add_argument('--train_evaluation_size', type=int, default=100)

    parser.add_argument('--dropout_amount', type=float, default=0.5)

    parser.add_argument('--initial_patience', type=float, default=10.)
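
    # How patience drives training length (see train() in run.py): the
    # countdown starts at int(initial_patience); each epoch decrements it,
    # and every new best string accuracy multiplies patience by
    # patience_ratio and resets the countdown. With the defaults, the
    # first improvement raises patience from 10 to roughly 10.1, and
    # training stops once the countdown hits zero.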
    parser.add_argument('--patience_ratio', type=float, default=1.01)

    parser.add_argument('--initial_learning_rate', type=float, default=1e-3)
    parser.add_argument('--learning_rate_ratio', type=float, default=0.8)

    parser.add_argument('--interaction_level', type=bool, default=False)
    parser.add_argument('--reweight_batch', type=bool, default=False)
    parser.add_argument('--gnn_layer_number', type=int, default=1)
    parser.add_argument('--clip', type=float, default=5.0)
    parser.add_argument('--warmup_step', type=int, default=1000)

    ### Setting
    parser.add_argument('--train', type=bool, default=False)
    parser.add_argument('--debug', type=bool, default=False)

    parser.add_argument('--evaluate', type=bool, default=False)
    parser.add_argument('--attention', type=bool, default=False)
    parser.add_argument('--save_file', type=str, default="")
    parser.add_argument('--enable_testing', type=bool, default=False)
    parser.add_argument('--use_predicted_queries', type=bool, default=False)
    parser.add_argument('--evaluate_split', type=str, default='dev')
    parser.add_argument('--evaluate_with_gold_forcing', type=bool, default=False)
    parser.add_argument('--eval_maximum_sql_length', type=int, default=400)
    parser.add_argument('--results_note', type=str, default='')
    parser.add_argument('--compute_metrics', type=bool, default=False)

    parser.add_argument('--reference_results', type=str, default='')

    parser.add_argument('--interactive', type=bool, default=False)

    parser.add_argument('--database_username', type=str, default="aviarmy")
    parser.add_argument('--database_password', type=str, default="aviarmy")
    parser.add_argument('--database_timeout', type=int, default=2)

    args = parser.parse_args()

    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    if not (args.train or args.evaluate or args.interactive or args.attention):
        raise ValueError('Must specify at least one of --train, --evaluate, --interactive, or --attention.')
    if args.enable_testing and not args.evaluate:
        raise ValueError('--enable_testing requires --evaluate.')

    if args.train:
        args_file = args.logdir + '/args.log'
        if os.path.exists(args_file):
            raise ValueError('Arguments log already exists at ' + str(args_file))
        with open(args_file, 'w') as outfile:
            outfile.write(str(args))

    return args
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.0.1
sqlparse
pymysql
progressbar
nltk
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
"""Contains a main function for training and/or evaluating a model."""

import os
import sys

import numpy as np
import random

from parse_args import interpret_args

import data_util
from data_util import atis_data
from model.schema_interaction_model import SchemaInteractionATISModel
from logger import Logger
from model.model import ATISModel
from model_util import Metrics, evaluate_utterance_sample, evaluate_interaction_sample, \
    train_epoch_with_utterances, train_epoch_with_interactions, evaluate_using_predicted_queries

import torch
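
# Typical invocation (abbreviated from run_sparc.sh; see the shell
# scripts at the repository root for the full flag sets):
#
#   python3 run.py --raw_train_filename="data/sparc_data_removefrom/train.pkl" \
#       --interaction_level=1 --use_bert=1 --bert_type_abb=uS \
#       --train=1 --evaluate=1 --evaluate_split="valid" \
#       --logdir=logs_sparc_editsql ...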

#np.random.seed(0)
#random.seed(0)

VALID_EVAL_METRICS = [Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY]
TRAIN_EVAL_METRICS = [Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY]
FINAL_EVAL_METRICS = [Metrics.STRING_ACCURACY, Metrics.TOKEN_ACCURACY]

def train(model, data, params):
    """ Trains a model.

    Inputs:
        model (ATISModel): The model to train.
        data (ATISData): The data that is used to train.
        params (namespace): Training parameters.
    """
    # Get the training batches.
    log = Logger(os.path.join(params.logdir, params.logfile), "w")
    num_train_original = atis_data.num_utterances(data.train_data)
    log.put("Original number of training utterances:\t" + str(num_train_original))

    eval_fn = evaluate_utterance_sample
    trainbatch_fn = data.get_utterance_batches
    trainsample_fn = data.get_random_utterances
    validsample_fn = data.get_all_utterances
    batch_size = params.batch_size
    if params.interaction_level:
        batch_size = 1
        eval_fn = evaluate_interaction_sample
        trainbatch_fn = data.get_interaction_batches
        trainsample_fn = data.get_random_interactions
        validsample_fn = data.get_all_interactions

    maximum_output_length = params.train_maximum_sql_length
    train_batches = trainbatch_fn(batch_size,
                                  max_output_length=maximum_output_length,
                                  randomize=not params.deterministic)

    if params.num_train >= 0:
        train_batches = train_batches[:params.num_train]

    training_sample = trainsample_fn(params.train_evaluation_size,
                                     max_output_length=maximum_output_length)
    valid_examples = validsample_fn(data.valid_data,
                                    max_output_length=maximum_output_length)

    num_train_examples = sum([len(batch) for batch in train_batches])
    num_steps_per_epoch = len(train_batches)

    log.put("Actual number of used training examples:\t" + str(num_train_examples))
    log.put("(Shortened by output limit of " + str(maximum_output_length) + ")")
    log.put("Number of steps per epoch:\t" + str(num_steps_per_epoch))
    log.put("Batch size:\t" + str(batch_size))

    print("Kept " + str(num_train_examples) + "/" + str(num_train_original) + " examples")
    print("Batch size of " + str(batch_size) + " gives " + str(num_steps_per_epoch) + " steps per epoch")

    # Keeping track of things during training.
    epochs = 0
    patience = params.initial_patience
    learning_rate_coefficient = 1.
    previous_epoch_loss = float('inf')
    previous_valid_acc = 0.
    maximum_validation_accuracy = 0.
    maximum_string_accuracy = 0.

    countdown = int(patience)

    if params.scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model.trainer, mode='min')

    keep_training = True
    step = 0
    while keep_training:
        log.put("Epoch:\t" + str(epochs))
        model.set_dropout(params.dropout_amount)
        model.train()

        if not params.scheduler:
            model.set_learning_rate(learning_rate_coefficient * params.initial_learning_rate)

        # Run a training step.
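        # Interaction-level training walks whole interactions (batch size
        # is forced to 1 above) and threads the database ids and global
        # step through; utterance-level training treats each utterance
        # independently.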
        if params.interaction_level:
            epoch_loss, step = train_epoch_with_interactions(
                train_batches,
                params,
                model,
                randomize=not params.deterministic,
                db2id=data.db2id,
                id2db=data.id2db,
                step=step)
        else:
            epoch_loss = train_epoch_with_utterances(
                train_batches,
                model,
                randomize=not params.deterministic)

        log.put("train epoch loss:\t" + str(epoch_loss))

        model.set_dropout(0.)
        model.eval()

        # Run an evaluation step on a sample of the training data.
        train_eval_results = eval_fn(training_sample,
                                     model,
                                     params.train_maximum_sql_length,
                                     name=os.path.join(params.logdir, "train-eval"),
                                     write_results=True,
                                     gold_forcing=True,
                                     metrics=TRAIN_EVAL_METRICS)[0]

        for name, value in train_eval_results.items():
            log.put("train final gold-passing " + name.name + ":\t" + "%.2f" % value)

        # Run an evaluation step on the validation set.
        valid_eval_results = eval_fn(valid_examples,
                                     model,
                                     params.eval_maximum_sql_length,
                                     name=os.path.join(params.logdir, "valid-eval"),
                                     write_results=True,
                                     gold_forcing=True,
                                     metrics=VALID_EVAL_METRICS)[0]
        for name, value in valid_eval_results.items():
            log.put("valid gold-passing " + name.name + ":\t" + "%.2f" % value)

        valid_loss = valid_eval_results[Metrics.LOSS]
        valid_token_accuracy = valid_eval_results[Metrics.TOKEN_ACCURACY]
        string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]

        if params.scheduler:
            scheduler.step(valid_loss)

        if valid_loss > previous_epoch_loss and valid_token_accuracy < previous_valid_acc and step >= params.warmup_step:
            learning_rate_coefficient *= params.learning_rate_ratio
            log.put("learning rate coefficient:\t" + str(learning_rate_coefficient))

        previous_epoch_loss = valid_loss
        previous_valid_acc = valid_token_accuracy
        saved = False

        if not saved and string_accuracy > maximum_string_accuracy:
            maximum_string_accuracy = string_accuracy
            patience = patience * params.patience_ratio
            countdown = int(patience)
            last_save_file = os.path.join(params.logdir, "save_" + str(epochs))
            model.save(last_save_file)

            log.put("maximum string accuracy:\t" + str(maximum_string_accuracy))
            log.put("patience:\t" + str(patience))
            log.put("save file:\t" + str(last_save_file))
        else:
            log.put("no improvement; saved checkpoint anyway")
            last_save_file = os.path.join(params.logdir, "save_" + str(epochs))
            model.save(last_save_file)

        if countdown <= 0:
            keep_training = False

        countdown -= 1
        log.put("countdown:\t" + str(countdown))
        log.put("")

        epochs += 1

    log.put("Finished training!")
    log.close()

    return last_save_file


def evaluate(model, data, params, last_save_file, split):
    """Evaluates a pretrained model on a dataset.

    Inputs:
        model (ATISModel): Model class.
        data (ATISData): All of the data.
        params (namespace): Parameters for the model.
        last_save_file (str): Location where the model save file is.
    """
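    # If train() just ran, last_save_file is the checkpoint it kept;
    # otherwise the caller must have passed --save_file explicitly.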
222 | """ 223 | if last_save_file: 224 | model.load(last_save_file) 225 | else: 226 | if not params.save_file: 227 | raise ValueError( 228 | "Must provide a save file name if not training first.") 229 | print('0------0') 230 | print(params.save_file) 231 | print('0------0') 232 | model.load(params.save_file) 233 | 234 | filename = split 235 | 236 | if filename == 'dev': 237 | split = data.dev_data 238 | elif filename == 'train': 239 | split = data.train_data 240 | elif filename == 'test': 241 | split = data.test_data 242 | elif filename == 'valid': 243 | split = data.valid_data 244 | else: 245 | raise ValueError("Split not recognized: " + str(params.evaluate_split)) 246 | 247 | if params.use_predicted_queries: 248 | filename += "_use_predicted_queries" 249 | else: 250 | filename += "_use_gold_queries" 251 | 252 | full_name = os.path.join(params.logdir, filename) + params.results_note 253 | 254 | if params.interaction_level or params.use_predicted_queries: 255 | examples = data.get_all_interactions(split) 256 | if params.interaction_level: 257 | evaluate_interaction_sample( 258 | examples, 259 | model, 260 | name=full_name, 261 | metrics=FINAL_EVAL_METRICS, 262 | total_num=atis_data.num_utterances(split), 263 | database_username=params.database_username, 264 | database_password=params.database_password, 265 | database_timeout=params.database_timeout, 266 | use_predicted_queries=params.use_predicted_queries, 267 | max_generation_length=params.eval_maximum_sql_length, 268 | write_results=True, 269 | use_gpu=True, 270 | compute_metrics=params.compute_metrics) 271 | else: 272 | evaluate_using_predicted_queries( 273 | examples, 274 | model, 275 | name=full_name, 276 | metrics=FINAL_EVAL_METRICS, 277 | total_num=atis_data.num_utterances(split), 278 | database_username=params.database_username, 279 | database_password=params.database_password, 280 | database_timeout=params.database_timeout) 281 | else: 282 | examples = data.get_all_utterances(split) 283 | evaluate_utterance_sample( 284 | examples, 285 | model, 286 | name=full_name, 287 | gold_forcing=False, 288 | metrics=FINAL_EVAL_METRICS, 289 | total_num=atis_data.num_utterances(split), 290 | max_generation_length=params.eval_maximum_sql_length, 291 | database_username=params.database_username, 292 | database_password=params.database_password, 293 | database_timeout=params.database_timeout, 294 | write_results=True) 295 | 296 | 297 | def main(): 298 | """Main function that trains and/or evaluates a model.""" 299 | params = interpret_args() 300 | 301 | # Prepare the dataset into the proper form. 302 | data = atis_data.ATISDataset(params) 303 | params.num_db = len(data.db2id) 304 | 305 | # Construct the model object. 
    if params.interaction_level:
        model_type = SchemaInteractionATISModel
    else:
        print('Only interaction-level models are implemented; rerun with --interaction_level=1.')
        sys.exit(1)

    model = model_type(
        params,
        data.input_vocabulary,
        data.output_vocabulary,
        data.output_vocabulary_schema,
        data.anonymizer if params.anonymize and params.anonymization_scoring else None)

    model = model.cuda()
    print('=====================Model Parameters=====================')
    for name, param in model.named_parameters():
        print(name, param.requires_grad, param.is_cuda, param.size())
        assert param.is_cuda

    model.build_optim()

    print('=====================Parameters in Optimizer==============')
    for param_group in model.trainer.param_groups:
        print(param_group.keys())
        for param in param_group['params']:
            print(param.size())

    if params.fine_tune_bert:
        print('=====================Parameters in BERT Optimizer==============')
        for param_group in model.bert_trainer.param_groups:
            print(param_group.keys())
            for param in param_group['params']:
                print(param.size())

    sys.stdout.flush()

    last_save_file = ""

    if params.train:
        last_save_file = train(model, data, params)
    if params.evaluate and 'valid' in params.evaluate_split:
        evaluate(model, data, params, last_save_file, split='valid')
    if params.evaluate and 'dev' in params.evaluate_split:
        evaluate(model, data, params, last_save_file, split='dev')
    if params.evaluate and 'test' in params.evaluate_split:
        evaluate(model, data, params, last_save_file, split='test')

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/run_cosql.sh:
--------------------------------------------------------------------------------
#! /bin/bash

# 1. Preprocess the dataset. This produces data/cosql_data_removefrom/

python3 preprocess.py --dataset=cosql --remove_from

# 2. Train and evaluate.
# The results (models, logs, prediction outputs) are saved in $LOGDIR.

GLOVE_PATH="/home/caiyitao/glove.840B.300d.txt" # you need to change this
LOGDIR="logs_cosql_editsql"

CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
    --raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
    --database_schema_filename="data/cosql_data_removefrom/tables.json" \
    --embedding_filename=$GLOVE_PATH \
    --data_directory="processed_data_cosql_removefrom" \
    --input_key="utterance" \
    --state_positional_embeddings=1 \
    --discourse_level_lstm=1 \
    --use_schema_encoder=1 \
    --use_schema_attention=1 \
    --use_bert=1 \
    --fine_tune_bert=1 \
    --bert_type_abb=uS \
    --interaction_level=1 \
    --reweight_batch=1 \
    --train=1 \
    --logdir=$LOGDIR \
    --evaluate=1 \
    --evaluate_split="valid" \
    --use_query_attention=1 \
    --use_previous_query=1 \
    --use_encoder_attention=1 \
    --use_utterance_attention=1 \
    --use_predicted_queries=1

# 3. Get the evaluation result.

python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
--------------------------------------------------------------------------------
/run_sparc.sh:
--------------------------------------------------------------------------------
#! /bin/bash

# 1. Preprocess the dataset. This produces data/sparc_data_removefrom/

python3 preprocess.py --dataset=sparc --remove_from

# 2. Train and evaluate.
# The results (models, logs, prediction outputs) are saved in $LOGDIR.

GLOVE_PATH="/home/caiyitao/glove.840B.300d.txt" # you need to change this
LOGDIR="logs_sparc_editsql"

CUDA_VISIBLE_DEVICES=3 python3 run.py --raw_train_filename="data/sparc_data_removefrom/train.pkl" \
    --raw_validation_filename="data/sparc_data_removefrom/dev.pkl" \
    --database_schema_filename="data/sparc_data_removefrom/tables.json" \
    --embedding_filename=$GLOVE_PATH \
    --data_directory="processed_data_sparc_removefrom" \
    --input_key="utterance" \
    --state_positional_embeddings=1 \
    --discourse_level_lstm=1 \
    --use_schema_encoder=1 \
    --use_schema_attention=1 \
    --use_encoder_attention=1 \
    --use_bert=1 \
    --fine_tune_bert=1 \
    --bert_type_abb=uS \
    --interaction_level=1 \
    --reweight_batch=1 \
    --train=1 \
    --use_previous_query=1 \
    --use_query_attention=1 \
    --logdir=$LOGDIR \
    --evaluate=1 \
    --evaluate_split="valid" \
    --use_utterance_attention=1 \
    --use_predicted_queries=1

# 3. Get the evaluation result.

python3 postprocess_eval.py --dataset=sparc --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
--------------------------------------------------------------------------------
/test_cosql.sh:
--------------------------------------------------------------------------------
#! /bin/bash

# 1. Preprocess the dataset. This produces data/cosql_data_removefrom/
#    (commented out here; step 1 of run_cosql.sh already covers it).

# python3 preprocess.py --dataset=cosql --remove_from

# 2. Evaluate a trained model with predicted queries.
# The results (logs, prediction outputs) are saved in $LOGDIR.

GLOVE_PATH="/home/caiyitao/glove.840B.300d.txt" # you need to change this
LOGDIR="logs_cosql_editsql"

CUDA_VISIBLE_DEVICES=3 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
    --raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
    --database_schema_filename="data/cosql_data_removefrom/tables.json" \
    --embedding_filename=$GLOVE_PATH \
    --data_directory="processed_data_cosql_removefrom" \
    --input_key="utterance" \
    --state_positional_embeddings=1 \
    --reload_embedding=1 \
    --discourse_level_lstm=1 \
    --use_schema_encoder=1 \
    --use_schema_attention=1 \
    --use_bert=1 \
    --bert_type_abb=uS \
    --fine_tune_bert=1 \
    --interaction_level=1 \
    --reweight_batch=1 \
    --freeze=1 \
    --use_previous_query=1 \
    --use_query_attention=1 \
    --logdir=$LOGDIR \
    --evaluate=1 \
    --evaluate_split="valid" \
    --use_predicted_queries=1 \
    --use_encoder_attention=1 \
    --use_utterance_attention=1 \
    --save_file="" # must point to a trained checkpoint, e.g. "$LOGDIR/save_<epoch>"

# 3. Get the evaluation result.

python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
--------------------------------------------------------------------------------
/test_sparc.sh:
--------------------------------------------------------------------------------
#! /bin/bash

# 1. Preprocess the dataset. This produces data/sparc_data_removefrom/
#    (commented out here; step 1 of run_sparc.sh already covers it).

#python3 preprocess.py --dataset=sparc --remove_from

# 2. Evaluate a trained model with predicted queries.
# The results (logs, prediction outputs) are saved in $LOGDIR.

GLOVE_PATH="/home/caiyitao/glove.840B.300d.txt" # you need to change this
LOGDIR="logs_sparc_editsql"

CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/sparc_data_removefrom/train.pkl" \
    --raw_validation_filename="data/sparc_data_removefrom/dev.pkl" \
    --database_schema_filename="data/sparc_data_removefrom/tables.json" \
    --embedding_filename=$GLOVE_PATH \
    --data_directory="processed_data_sparc_removefrom" \
    --input_key="utterance" \
    --state_positional_embeddings=1 \
    --discourse_level_lstm=1 \
    --use_schema_encoder=1 \
    --use_schema_attention=1 \
    --use_encoder_attention=1 \
    --use_bert=1 \
    --bert_type_abb=uS \
    --fine_tune_bert=1 \
    --interaction_level=1 \
    --reweight_batch=1 \
    --freeze=1 \
    --logdir=$LOGDIR \
    --evaluate=1 \
    --evaluate_split="valid" \
    --use_predicted_queries=1 \
    --use_previous_query=1 \
    --use_query_attention=1 \
    --use_utterance_attention=1 \
    --save_file="" # must point to a trained checkpoint, e.g. "$LOGDIR/save_<epoch>"

# 3. Get the evaluation result.

python3 postprocess_eval.py --dataset=sparc --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
--------------------------------------------------------------------------------