├── LICENSE ├── README.md ├── __init__.py ├── create_pretraining_data.py ├── data └── data.rar ├── extract_features.py ├── make_data ├── make_data_for_augment.py ├── make_data_for_squad.py └── utils.py ├── matrix_code ├── data_util.py ├── main_modeling.py ├── model_wrapper.py └── train.py ├── modeling.py ├── modeling_test.py ├── optimization.py ├── optimization_test.py ├── run_classifier.py ├── run_pretraining.py ├── run_squad.py ├── squad_code ├── data_util.py ├── model_wrapper.py └── train.py ├── squad_data ├── my_train.json └── my_valid.json ├── tokenization.py ├── tokenization_test.py └── uncased_L-12_H-768_A-12 ├── bert_config.json └── vocab.txt /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # table2answer 2 | 3 | [![LICENSE](https://img.shields.io/badge/license-Anti%20996-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE) 4 | 5 | Table2answer: Read the database and answer without SQL 6 | 7 | https://arxiv.org/abs/1902.04260 8 | 9 | # Rationale 10 | 11 | We believe the logic-form step can be removed because humans can do the text2sql task without constructing an explicit logic form. 12 | 13 | # Requirements 14 | 15 | python3 16 | 17 | tensorflow >= 1.12.0 18 | 19 | # Train 20 | 21 | ### Step 1. 22 | 23 | Download the pre-trained BERT model from https://github.com/google-research/bert and unzip it into `uncased_L-12_H-768_A-12`. 24 | 25 | ### Step 2. 26 | 27 | Download the SQuAD v1.1 data from https://github.com/rajpurkar/SQuAD-explorer/tree/master/dataset. 28 | 29 | ### Step 3. 30 | 31 | Use the scripts in `make_data` (e.g. `make_data_for_squad.py`) to create the training data under `squad_data`. 32 | 33 | ### Step 4. 34 | 35 | Run `matrix_code/train.py`. (An example command sequence is sketched further below.) 36 | 37 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
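# Example end-to-end workflow for the "Train" steps in README.md above. This is
# only a sketch: the entry points, flags, and working directories are
# assumptions inferred from the repository layout, not documented commands.
#
#   1. Unzip the BERT base uncased checkpoint into uncased_L-12_H-768_A-12/
#      (bert_config.json and vocab.txt already live there in this repository).
#   2. Put the SQuAD v1.1 json files under squad_data/ and unpack data/data.rar
#      so the WikiSQL-style files (data/train_tok.jsonl, data/train.db, ...) exist.
#   3. python make_data/make_data_for_squad.py   # builds squad_data/my_*.json
#   4. python matrix_code/train.py               # trains the matrix model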
15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import tensorflow as tf 24 | import tokenization 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("input_file", None, 31 | "Input raw text file (or comma-separated list of files).") 32 | 33 | flags.DEFINE_string( 34 | "output_file", None, 35 | "Output TF example file (or comma-separated list of files).") 36 | 37 | flags.DEFINE_string("vocab_file", None, 38 | "The vocabulary file that the BERT model was trained on.") 39 | 40 | flags.DEFINE_bool( 41 | "do_lower_case", True, 42 | "Whether to lower case the input text. Should be True for uncased " 43 | "models and False for cased models.") 44 | 45 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 46 | 47 | flags.DEFINE_integer("max_predictions_per_seq", 20, 48 | "Maximum number of masked LM predictions per sequence.") 49 | 50 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 51 | 52 | flags.DEFINE_integer( 53 | "dupe_factor", 10, 54 | "Number of times to duplicate the input data (with different masks).") 55 | 56 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 57 | 58 | flags.DEFINE_float( 59 | "short_seq_prob", 0.1, 60 | "Probability of creating sequences which are shorter than the " 61 | "maximum length.") 62 | 63 | 64 | class TrainingInstance(object): 65 | """A single training instance (sentence pair).""" 66 | 67 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 68 | is_random_next): 69 | self.tokens = tokens 70 | self.segment_ids = segment_ids 71 | self.is_random_next = is_random_next 72 | self.masked_lm_positions = masked_lm_positions 73 | self.masked_lm_labels = masked_lm_labels 74 | 75 | def __str__(self): 76 | s = "" 77 | s += "tokens: %s\n" % (" ".join( 78 | [tokenization.printable_text(x) for x in self.tokens])) 79 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 80 | s += "is_random_next: %s\n" % self.is_random_next 81 | s += "masked_lm_positions: %s\n" % (" ".join( 82 | [str(x) for x in self.masked_lm_positions])) 83 | s += "masked_lm_labels: %s\n" % (" ".join( 84 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 85 | s += "\n" 86 | return s 87 | 88 | def __repr__(self): 89 | return self.__str__() 90 | 91 | 92 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 93 | max_predictions_per_seq, output_files): 94 | """Create TF example files from `TrainingInstance`s.""" 95 | writers = [] 96 | for output_file in output_files: 97 | writers.append(tf.python_io.TFRecordWriter(output_file)) 98 | 99 | writer_index = 0 100 | 101 | total_written = 0 102 | for (inst_index, instance) in enumerate(instances): 103 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 104 | input_mask = [1] * len(input_ids) 105 | segment_ids = list(instance.segment_ids) 106 | assert len(input_ids) <= max_seq_length 107 | 108 | while len(input_ids) < max_seq_length: 109 | input_ids.append(0) 110 | input_mask.append(0) 111 | segment_ids.append(0) 112 | 113 | assert len(input_ids) == max_seq_length 114 | assert len(input_mask) == max_seq_length 115 | assert len(segment_ids) == max_seq_length 116 | 117 | masked_lm_positions = list(instance.masked_lm_positions) 118 | masked_lm_ids = 
tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 119 | masked_lm_weights = [1.0] * len(masked_lm_ids) 120 | 121 | while len(masked_lm_positions) < max_predictions_per_seq: 122 | masked_lm_positions.append(0) 123 | masked_lm_ids.append(0) 124 | masked_lm_weights.append(0.0) 125 | 126 | next_sentence_label = 1 if instance.is_random_next else 0 127 | 128 | features = collections.OrderedDict() 129 | features["input_ids"] = create_int_feature(input_ids) 130 | features["input_mask"] = create_int_feature(input_mask) 131 | features["segment_ids"] = create_int_feature(segment_ids) 132 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 133 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 134 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 135 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 136 | 137 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 138 | 139 | writers[writer_index].write(tf_example.SerializeToString()) 140 | writer_index = (writer_index + 1) % len(writers) 141 | 142 | total_written += 1 143 | 144 | if inst_index < 20: 145 | tf.logging.info("*** Example ***") 146 | tf.logging.info("tokens: %s" % " ".join( 147 | [tokenization.printable_text(x) for x in instance.tokens])) 148 | 149 | for feature_name in features.keys(): 150 | feature = features[feature_name] 151 | values = [] 152 | if feature.int64_list.value: 153 | values = feature.int64_list.value 154 | elif feature.float_list.value: 155 | values = feature.float_list.value 156 | tf.logging.info( 157 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 158 | 159 | for writer in writers: 160 | writer.close() 161 | 162 | tf.logging.info("Wrote %d total instances", total_written) 163 | 164 | 165 | def create_int_feature(values): 166 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 167 | return feature 168 | 169 | 170 | def create_float_feature(values): 171 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 172 | return feature 173 | 174 | 175 | def create_training_instances(input_files, tokenizer, max_seq_length, 176 | dupe_factor, short_seq_prob, masked_lm_prob, 177 | max_predictions_per_seq, rng): 178 | """Create `TrainingInstance`s from raw text.""" 179 | all_documents = [[]] 180 | 181 | # Input file format: 182 | # (1) One sentence per line. These should ideally be actual sentences, not 183 | # entire paragraphs or arbitrary spans of text. (Because we use the 184 | # sentence boundaries for the "next sentence prediction" task). 185 | # (2) Blank lines between documents. Document boundaries are needed so 186 | # that the "next sentence prediction" task doesn't span between documents. 
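  # To make the expected format concrete, a toy input file (hypothetical
  # contents, two documents separated by a blank line) would look like:
  #
  #   The quick brown fox jumps over the lazy dog.
  #   It was chased by the farmer.
  #
  #   BERT is a bidirectional transformer encoder.
  #   It is pre-trained with a masked language model objective.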
187 | for input_file in input_files: 188 | with tf.gfile.GFile(input_file, "r") as reader: 189 | while True: 190 | line = tokenization.convert_to_unicode(reader.readline()) 191 | if not line: 192 | break 193 | line = line.strip() 194 | 195 | # Empty lines are used as document delimiters 196 | if not line: 197 | all_documents.append([]) 198 | tokens = tokenizer.tokenize(line) 199 | if tokens: 200 | all_documents[-1].append(tokens) 201 | 202 | # Remove empty documents 203 | all_documents = [x for x in all_documents if x] 204 | rng.shuffle(all_documents) 205 | 206 | vocab_words = list(tokenizer.vocab.keys()) 207 | instances = [] 208 | for _ in range(dupe_factor): 209 | for document_index in range(len(all_documents)): 210 | instances.extend( 211 | create_instances_from_document( 212 | all_documents, document_index, max_seq_length, short_seq_prob, 213 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 214 | 215 | rng.shuffle(instances) 216 | return instances 217 | 218 | 219 | def create_instances_from_document( 220 | all_documents, document_index, max_seq_length, short_seq_prob, 221 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 222 | """Creates `TrainingInstance`s for a single document.""" 223 | document = all_documents[document_index] 224 | 225 | # Account for [CLS], [SEP], [SEP] 226 | max_num_tokens = max_seq_length - 3 227 | 228 | # We *usually* want to fill up the entire sequence since we are padding 229 | # to `max_seq_length` anyways, so short sequences are generally wasted 230 | # computation. However, we *sometimes* 231 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 232 | # sequences to minimize the mismatch between pre-training and fine-tuning. 233 | # The `target_seq_length` is just a rough target however, whereas 234 | # `max_seq_length` is a hard limit. 235 | target_seq_length = max_num_tokens 236 | if rng.random() < short_seq_prob: 237 | target_seq_length = rng.randint(2, max_num_tokens) 238 | 239 | # We DON'T just concatenate all of the tokens from a document into a long 240 | # sequence and choose an arbitrary split point because this would make the 241 | # next sentence prediction task too easy. Instead, we split the input into 242 | # segments "A" and "B" based on the actual "sentences" provided by the user 243 | # input. 244 | instances = [] 245 | current_chunk = [] 246 | current_length = 0 247 | i = 0 248 | while i < len(document): 249 | segment = document[i] 250 | current_chunk.append(segment) 251 | current_length += len(segment) 252 | if i == len(document) - 1 or current_length >= target_seq_length: 253 | if current_chunk: 254 | # `a_end` is how many segments from `current_chunk` go into the `A` 255 | # (first) sentence. 256 | a_end = 1 257 | if len(current_chunk) >= 2: 258 | a_end = rng.randint(1, len(current_chunk) - 1) 259 | 260 | tokens_a = [] 261 | for j in range(a_end): 262 | tokens_a.extend(current_chunk[j]) 263 | 264 | tokens_b = [] 265 | # Random next 266 | is_random_next = False 267 | if len(current_chunk) == 1 or rng.random() < 0.5: 268 | is_random_next = True 269 | target_b_length = target_seq_length - len(tokens_a) 270 | 271 | # This should rarely go for more than one iteration for large 272 | # corpora. However, just to be careful, we try to make sure that 273 | # the random document is not the same as the document 274 | # we're processing. 
275 | for _ in range(10): 276 | random_document_index = rng.randint(0, len(all_documents) - 1) 277 | if random_document_index != document_index: 278 | break 279 | 280 | random_document = all_documents[random_document_index] 281 | random_start = rng.randint(0, len(random_document) - 1) 282 | for j in range(random_start, len(random_document)): 283 | tokens_b.extend(random_document[j]) 284 | if len(tokens_b) >= target_b_length: 285 | break 286 | # We didn't actually use these segments so we "put them back" so 287 | # they don't go to waste. 288 | num_unused_segments = len(current_chunk) - a_end 289 | i -= num_unused_segments 290 | # Actual next 291 | else: 292 | is_random_next = False 293 | for j in range(a_end, len(current_chunk)): 294 | tokens_b.extend(current_chunk[j]) 295 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 296 | 297 | assert len(tokens_a) >= 1 298 | assert len(tokens_b) >= 1 299 | 300 | tokens = [] 301 | segment_ids = [] 302 | tokens.append("[CLS]") 303 | segment_ids.append(0) 304 | for token in tokens_a: 305 | tokens.append(token) 306 | segment_ids.append(0) 307 | 308 | tokens.append("[SEP]") 309 | segment_ids.append(0) 310 | 311 | for token in tokens_b: 312 | tokens.append(token) 313 | segment_ids.append(1) 314 | tokens.append("[SEP]") 315 | segment_ids.append(1) 316 | 317 | (tokens, masked_lm_positions, 318 | masked_lm_labels) = create_masked_lm_predictions( 319 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 320 | instance = TrainingInstance( 321 | tokens=tokens, 322 | segment_ids=segment_ids, 323 | is_random_next=is_random_next, 324 | masked_lm_positions=masked_lm_positions, 325 | masked_lm_labels=masked_lm_labels) 326 | instances.append(instance) 327 | current_chunk = [] 328 | current_length = 0 329 | i += 1 330 | 331 | return instances 332 | 333 | 334 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 335 | ["index", "label"]) 336 | 337 | 338 | def create_masked_lm_predictions(tokens, masked_lm_prob, 339 | max_predictions_per_seq, vocab_words, rng): 340 | """Creates the predictions for the masked LM objective.""" 341 | 342 | cand_indexes = [] 343 | for (i, token) in enumerate(tokens): 344 | if token == "[CLS]" or token == "[SEP]": 345 | continue 346 | cand_indexes.append(i) 347 | 348 | rng.shuffle(cand_indexes) 349 | 350 | output_tokens = list(tokens) 351 | 352 | num_to_predict = min(max_predictions_per_seq, 353 | max(1, int(round(len(tokens) * masked_lm_prob)))) 354 | 355 | masked_lms = [] 356 | covered_indexes = set() 357 | for index in cand_indexes: 358 | if len(masked_lms) >= num_to_predict: 359 | break 360 | if index in covered_indexes: 361 | continue 362 | covered_indexes.add(index) 363 | 364 | masked_token = None 365 | # 80% of the time, replace with [MASK] 366 | if rng.random() < 0.8: 367 | masked_token = "[MASK]" 368 | else: 369 | # 10% of the time, keep original 370 | if rng.random() < 0.5: 371 | masked_token = tokens[index] 372 | # 10% of the time, replace with random word 373 | else: 374 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 375 | 376 | output_tokens[index] = masked_token 377 | 378 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 379 | 380 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 381 | 382 | masked_lm_positions = [] 383 | masked_lm_labels = [] 384 | for p in masked_lms: 385 | masked_lm_positions.append(p.index) 386 | masked_lm_labels.append(p.label) 387 | 388 | return (output_tokens, masked_lm_positions, masked_lm_labels) 389 | 390 | 391 | def 
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 392 | """Truncates a pair of sequences to a maximum sequence length.""" 393 | while True: 394 | total_length = len(tokens_a) + len(tokens_b) 395 | if total_length <= max_num_tokens: 396 | break 397 | 398 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 399 | assert len(trunc_tokens) >= 1 400 | 401 | # We want to sometimes truncate from the front and sometimes from the 402 | # back to add more randomness and avoid biases. 403 | if rng.random() < 0.5: 404 | del trunc_tokens[0] 405 | else: 406 | trunc_tokens.pop() 407 | 408 | 409 | def main(_): 410 | tf.logging.set_verbosity(tf.logging.INFO) 411 | 412 | tokenizer = tokenization.FullTokenizer( 413 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 414 | 415 | input_files = [] 416 | for input_pattern in FLAGS.input_file.split(","): 417 | input_files.extend(tf.gfile.Glob(input_pattern)) 418 | 419 | tf.logging.info("*** Reading from input files ***") 420 | for input_file in input_files: 421 | tf.logging.info(" %s", input_file) 422 | 423 | rng = random.Random(FLAGS.random_seed) 424 | instances = create_training_instances( 425 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 426 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 427 | rng) 428 | 429 | output_files = FLAGS.output_file.split(",") 430 | tf.logging.info("*** Writing to output files ***") 431 | for output_file in output_files: 432 | tf.logging.info(" %s", output_file) 433 | 434 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 435 | FLAGS.max_predictions_per_seq, output_files) 436 | 437 | 438 | if __name__ == "__main__": 439 | flags.mark_flag_as_required("input_file") 440 | flags.mark_flag_as_required("output_file") 441 | flags.mark_flag_as_required("vocab_file") 442 | tf.app.run() 443 | -------------------------------------------------------------------------------- /data/data.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guotong1988/table2answer/af8b86b3a19ff18bd16e6a223368110d743f3d71/data/data.rar -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
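# Example invocation (a sketch only; the checkpoint filename is an assumption,
# while every flag used here is defined further down in this file):
#
#   python extract_features.py \
#     --input_file=input.txt \
#     --output_file=features.jsonl \
#     --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
#     --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
#     --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
#     --layers=-1,-2,-3,-4 \
#     --max_seq_length=128 \
#     --batch_size=8
#
# Each input line holds one sentence, or "text_a ||| text_b" for a sentence
# pair; the output is one JSON line per example with per-token vectors from
# the requested layers.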
15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import modeling 27 | import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. " 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 
123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, 174 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 175 | tvars, init_checkpoint) 176 | if use_tpu: 177 | 178 | def tpu_scaffold(): 179 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 180 | return tf.train.Scaffold() 181 | 182 | scaffold_fn = tpu_scaffold 183 | else: 184 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 185 | 186 | tf.logging.info("**** Trainable Variables ****") 187 | for var in tvars: 188 | init_string = "" 189 | if var.name in initialized_variable_names: 190 | init_string = ", *INIT_FROM_CKPT*" 191 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 192 | init_string) 193 | 194 | all_layers = model.get_all_encoder_layers() 195 | 196 | predictions = { 197 | "unique_id": unique_ids, 198 | } 199 | 200 | for (i, layer_index) in enumerate(layer_indexes): 201 | predictions["layer_output_%d" % i] = all_layers[layer_index] 202 | 203 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 204 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 205 | return output_spec 206 | 207 | return model_fn 208 | 209 | 210 | def convert_examples_to_features(examples, seq_length, tokenizer): 211 | """Loads a data file into a list of `InputBatch`s.""" 212 | 213 | features = [] 214 | for (ex_index, example) in enumerate(examples): 215 | tokens_a = tokenizer.tokenize(example.text_a) 216 | 217 | tokens_b = None 218 | if example.text_b: 219 | tokens_b = tokenizer.tokenize(example.text_b) 220 | 221 | if tokens_b: 222 | # Modifies `tokens_a` and `tokens_b` in place so that the total 223 | # length is less than the specified length. 
224 | # Account for [CLS], [SEP], [SEP] with "- 3" 225 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 226 | else: 227 | # Account for [CLS] and [SEP] with "- 2" 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | if tokens_b: 260 | for token in tokens_b: 261 | tokens.append(token) 262 | input_type_ids.append(1) 263 | tokens.append("[SEP]") 264 | input_type_ids.append(1) 265 | 266 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 267 | 268 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 269 | # tokens are attended to. 270 | input_mask = [1] * len(input_ids) 271 | 272 | # Zero-pad up to the sequence length. 273 | while len(input_ids) < seq_length: 274 | input_ids.append(0) 275 | input_mask.append(0) 276 | input_type_ids.append(0) 277 | 278 | assert len(input_ids) == seq_length 279 | assert len(input_mask) == seq_length 280 | assert len(input_type_ids) == seq_length 281 | 282 | if ex_index < 5: 283 | tf.logging.info("*** Example ***") 284 | tf.logging.info("unique_id: %s" % (example.unique_id)) 285 | tf.logging.info("tokens: %s" % " ".join( 286 | [tokenization.printable_text(x) for x in tokens])) 287 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 288 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 289 | tf.logging.info( 290 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 291 | 292 | features.append( 293 | InputFeatures( 294 | unique_id=example.unique_id, 295 | tokens=tokens, 296 | input_ids=input_ids, 297 | input_mask=input_mask, 298 | input_type_ids=input_type_ids)) 299 | return features 300 | 301 | 302 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 303 | """Truncates a sequence pair in place to the maximum length.""" 304 | 305 | # This is a simple heuristic which will always truncate the longer sequence 306 | # one token at a time. This makes more sense than truncating an equal percent 307 | # of tokens from each, since if one sequence is very short then each token 308 | # that's truncated likely contains more information than a longer sequence. 
309 | while True: 310 | total_length = len(tokens_a) + len(tokens_b) 311 | if total_length <= max_length: 312 | break 313 | if len(tokens_a) > len(tokens_b): 314 | tokens_a.pop() 315 | else: 316 | tokens_b.pop() 317 | 318 | 319 | def read_examples(input_file): 320 | """Read a list of `InputExample`s from an input file.""" 321 | examples = [] 322 | unique_id = 0 323 | with tf.gfile.GFile(input_file, "r") as reader: 324 | while True: 325 | line = tokenization.convert_to_unicode(reader.readline()) 326 | if not line: 327 | break 328 | line = line.strip() 329 | text_a = None 330 | text_b = None 331 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 332 | if m is None: 333 | text_a = line 334 | else: 335 | text_a = m.group(1) 336 | text_b = m.group(2) 337 | examples.append( 338 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 339 | unique_id += 1 340 | return examples 341 | 342 | 343 | def main(_): 344 | tf.logging.set_verbosity(tf.logging.INFO) 345 | 346 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 347 | 348 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 349 | 350 | tokenizer = tokenization.FullTokenizer( 351 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 354 | run_config = tf.contrib.tpu.RunConfig( 355 | master=FLAGS.master, 356 | tpu_config=tf.contrib.tpu.TPUConfig( 357 | num_shards=FLAGS.num_tpu_cores, 358 | per_host_input_for_training=is_per_host)) 359 | 360 | examples = read_examples(FLAGS.input_file) 361 | 362 | features = convert_examples_to_features( 363 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 364 | 365 | unique_id_to_feature = {} 366 | for feature in features: 367 | unique_id_to_feature[feature.unique_id] = feature 368 | 369 | model_fn = model_fn_builder( 370 | bert_config=bert_config, 371 | init_checkpoint=FLAGS.init_checkpoint, 372 | layer_indexes=layer_indexes, 373 | use_tpu=FLAGS.use_tpu, 374 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 375 | 376 | # If TPU is not available, this will fall back to normal Estimator on CPU 377 | # or GPU. 
378 | estimator = tf.contrib.tpu.TPUEstimator( 379 | use_tpu=FLAGS.use_tpu, 380 | model_fn=model_fn, 381 | config=run_config, 382 | predict_batch_size=FLAGS.batch_size) 383 | 384 | input_fn = input_fn_builder( 385 | features=features, seq_length=FLAGS.max_seq_length) 386 | 387 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 388 | "w")) as writer: 389 | for result in estimator.predict(input_fn, yield_single_examples=True): 390 | unique_id = int(result["unique_id"]) 391 | feature = unique_id_to_feature[unique_id] 392 | output_json = collections.OrderedDict() 393 | output_json["linex_index"] = unique_id 394 | all_features = [] 395 | for (i, token) in enumerate(feature.tokens): 396 | all_layers = [] 397 | for (j, layer_index) in enumerate(layer_indexes): 398 | layer_output = result["layer_output_%d" % j] 399 | layers = collections.OrderedDict() 400 | layers["index"] = layer_index 401 | layers["values"] = [ 402 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 403 | ] 404 | all_layers.append(layers) 405 | features = collections.OrderedDict() 406 | features["token"] = token 407 | features["layers"] = all_layers 408 | all_features.append(features) 409 | output_json["features"] = all_features 410 | writer.write(json.dumps(output_json) + "\n") 411 | 412 | 413 | if __name__ == "__main__": 414 | flags.mark_flag_as_required("input_file") 415 | flags.mark_flag_as_required("vocab_file") 416 | flags.mark_flag_as_required("bert_config_file") 417 | flags.mark_flag_as_required("init_checkpoint") 418 | flags.mark_flag_as_required("output_file") 419 | tf.app.run() 420 | -------------------------------------------------------------------------------- /make_data/make_data_for_augment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import datetime 3 | import argparse 4 | import numpy as np 5 | from utils import * 6 | 7 | 8 | 9 | f = open("squad_data/train-v1.1.json", mode="r", encoding="utf-8") 10 | squad_data = json.load(f) 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--toy', action='store_true', default=False, 14 | help='If set, use small data; used for fast debugging.') 15 | parser.add_argument('--sd', type=str, default='saved_model_kg', 16 | help='set model save directory.') 17 | parser.add_argument('--db_content', type=int, default=0, 18 | help='0: use knowledge graph type, 1: use db content to get type info') 19 | parser.add_argument('--train_emb', action='store_true', 20 | help='Use trained word embedding for SQLNet.') 21 | args = parser.parse_args() 22 | 23 | N_word=600 24 | B_word=42 25 | if args.toy: 26 | USE_SMALL=True 27 | GPU=True 28 | BATCH_SIZE=15 29 | else: 30 | USE_SMALL=False 31 | GPU=True 32 | BATCH_SIZE=64 33 | TEST_ENTRY=(True, True, True) # (AGG, SEL, COND) 34 | 35 | sql_data, table_data, val_sql_data, val_table_data, \ 36 | test_sql_data, test_table_data, \ 37 | TRAIN_DB, DEV_DB, TEST_DB = load_dataset(use_small=USE_SMALL) 38 | 39 | 40 | 41 | def process(indata1,indata2,indata3,outfile): 42 | DB = indata3 43 | engine = DBEngine(DB) 44 | sql_data = indata1 45 | table_data = indata2 46 | # point to the position of matrix 47 | def input_table_output_matrix(input): 48 | result = [] 49 | header = input["header_tok"] 50 | result_header = [] 51 | for item in header: 52 | concat_str = "" 53 | for str_ in item: 54 | concat_str += str_ 55 | concat_str += " " 56 | concat_str = concat_str.strip() 57 | result_header.append(concat_str.lower()) 58 | result.append(result_header) 59 | for row in 
input["rows"]: 60 | result.append([str(item).lower() for item in row]) 61 | return result 62 | 63 | tableid2qa={} 64 | paragraphs = [] 65 | 66 | 67 | for item in squad_data["data"]: 68 | for context_qas in item["paragraphs"]: 69 | context = context_qas["context"] 70 | qas = context_qas["qas"] 71 | import sys 72 | import traceback 73 | count = 0 74 | for i in range(len(sql_data)): 75 | try: 76 | # if i>100: 77 | # break 78 | count += 1 79 | # if count==16070: 80 | # print() 81 | one_sql = sql_data[i] 82 | if one_sql["sql"]["agg"]!=0: 83 | continue 84 | if len(one_sql["sql"]["conds"])>1: 85 | continue 86 | # only consider one condition and no agg situation 87 | answer = engine.execute(one_sql["table_id"], one_sql["sql"]["sel"],one_sql["sql"]["agg"],one_sql["sql"]["conds"]) 88 | question = one_sql["question"] 89 | 90 | # table_info = table_data[one_sql["table_id"]] 91 | 92 | if one_sql["table_id"] in tableid2qa: 93 | qa = {} 94 | qa["question"] = question.lower() 95 | qa["answer"] = [str(item).lower() for item in answer] 96 | tableid2qa[one_sql["table_id"]].append(qa) 97 | else: 98 | tableid2qa[one_sql["table_id"]] = [] 99 | qa = {} 100 | qa["question"] = question.lower() 101 | qa["answer"] = [str(item).lower() for item in answer] 102 | tableid2qa[one_sql["table_id"]].append(qa) 103 | except: 104 | # exc_type, exc_value, exc_traceback = sys.exc_info() 105 | # traceback.print_tb(exc_traceback, limit=1, file=sys.stdout) 106 | print(count) 107 | continue 108 | 109 | print(count) 110 | def find_answer_position_in_matrix(matrix,text,row1,row2,col1,col2): 111 | result = 0 112 | # if row1==1 and row2==1 and col1==0 and col2==0: 113 | if True: 114 | for row_index,row in enumerate(matrix): 115 | for col_index,one_col in enumerate(row): 116 | if one_col==text: 117 | return result 118 | else: 119 | try: 120 | if one_col == str(text).split(".")[0]: 121 | return result 122 | except: 123 | continue 124 | result += 1 125 | print("error") 126 | print(matrix) 127 | print(text) 128 | print() 129 | else: 130 | for row_index,row in enumerate(matrix): 131 | if row_index==row1 or row_index==row2: 132 | for col_index,one_col in enumerate(row): 133 | if col_index==col1 or col_index==col2: 134 | if one_col==text: 135 | return result 136 | else: 137 | try: 138 | if one_col == str(text).split(".")[0]: 139 | return result 140 | except: 141 | continue 142 | result += 1 143 | 144 | return None 145 | 146 | 147 | paragraphs = [] 148 | 149 | def add_to_paragraphs_with_shuffle(row1=1,row2=1,col1=0,col2=0): 150 | for tableid in tableid2qa: 151 | context = input_table_output_matrix(table_data[tableid]) 152 | 153 | if col1100: 76 | # break 77 | count += 1 78 | # if count==16070: 79 | # print() 80 | one_sql = sql_data_[i] 81 | if one_sql["sql"]["agg"]!=0: 82 | continue 83 | if len(one_sql["sql"]["conds"])>1: 84 | continue 85 | # only consider one condition and no agg situation 86 | answer = engine.execute(one_sql["table_id"], one_sql["sql"]["sel"],one_sql["sql"]["agg"],one_sql["sql"]["conds"]) 87 | question = one_sql["question"] 88 | 89 | # table_info = table_data[one_sql["table_id"]] 90 | 91 | if one_sql["table_id"] in tableid2qa: 92 | qa = {} 93 | qa["question"] = question.lower() 94 | qa["answer"] = [str(item).lower() for item in answer] 95 | tableid2qa[one_sql["table_id"]].append(qa) 96 | else: 97 | tableid2qa[one_sql["table_id"]] = [] 98 | qa = {} 99 | qa["question"] = question.lower() 100 | qa["answer"] = [str(item).lower() for item in answer] 101 | tableid2qa[one_sql["table_id"]].append(qa) 102 | except: 103 | # exc_type, 
exc_value, exc_traceback = sys.exc_info() 104 | # traceback.print_tb(exc_traceback, limit=1, file=sys.stdout) 105 | print(count) 106 | continue 107 | 108 | print(count) 109 | def find_answer_position_in_matrix(matrix,text): 110 | result = 0 111 | for row in matrix: 112 | for one_col in row: 113 | if one_col==text: 114 | return result 115 | else: 116 | try: 117 | if one_col == str(text).split(".")[0]: 118 | return result 119 | except: 120 | continue 121 | result += 1 122 | print("error") 123 | print(matrix) 124 | print(text) 125 | print() 126 | 127 | paragraphs = [] 128 | for tableid in tableid2qa: 129 | 130 | headers.append(table_data_[tableid]["header"]) 131 | 132 | context = input_table_output_matrix(table_data_[tableid]) 133 | context_qas ={} 134 | context_qas["context"] = context 135 | context_qas["qas"] = [] 136 | 137 | for item in tableid2qa[tableid]: 138 | one_qas = {} 139 | one_qas["answers"] = [] 140 | one_answer = {} 141 | if len(item["answer"])==1:# only consider one answer 142 | one_answer["text"] = item["answer"][0] 143 | else: 144 | continue 145 | one_answer["answer_start"] = find_answer_position_in_matrix(context,one_answer["text"]) 146 | if one_answer["answer_start"]==None: 147 | continue 148 | one_qas["question"] = item["question"] 149 | one_qas["answers"].append(one_answer) 150 | context_qas["qas"].append(one_qas) 151 | if len(context_qas["qas"])==0: # only context, no answer 152 | continue 153 | else: 154 | paragraphs.append(context_qas) 155 | 156 | 157 | squad_data["data"]=[] 158 | one_data = {} 159 | one_data["paragraphs"] = paragraphs 160 | squad_data["data"].append(one_data) 161 | 162 | # f2 = open(outfile, mode="w", encoding="utf-8") 163 | # json.dump(squad_data,f2) 164 | return headers 165 | 166 | 167 | def process_test(indata1, indata2, indata3, outfile, headers_train): 168 | DB = indata3 169 | engine = DBEngine(DB) 170 | sql_data_ = indata1 171 | table_data_ = indata2 172 | # point to the position of matrix 173 | def input_table_output_matrix(input): 174 | result = [] 175 | header = input["header_tok"] 176 | result_header = [] 177 | for item in header: 178 | concat_str = "" 179 | for str_ in item: 180 | concat_str += str_ 181 | concat_str += " " 182 | concat_str = concat_str.strip() 183 | result_header.append(concat_str.lower()) 184 | result.append(result_header) 185 | for row in input["rows"]: 186 | result.append([str(item).lower() for item in row]) 187 | return result 188 | 189 | tableid2qa_test = {} 190 | paragraphs = [] 191 | 192 | 193 | for item in squad_data["data"]: 194 | for context_qas in item["paragraphs"]: 195 | context = context_qas["context"] 196 | qas = context_qas["qas"] 197 | import sys 198 | import traceback 199 | count = 0 200 | for i in range(len(sql_data_)): 201 | try: 202 | # if i>100: 203 | # break 204 | count += 1 205 | # if count==16070: 206 | # print() 207 | one_sql = sql_data_[i] 208 | if one_sql["sql"]["agg"]!=0: 209 | continue 210 | if len(one_sql["sql"]["conds"])>1: 211 | continue 212 | # only consider one condition and no agg situation 213 | answer = engine.execute(one_sql["table_id"], one_sql["sql"]["sel"],one_sql["sql"]["agg"],one_sql["sql"]["conds"]) 214 | question = one_sql["question"] 215 | 216 | # table_info = table_data[one_sql["table_id"]] 217 | if table_data_[one_sql["table_id"]]["header"] in headers_train: 218 | print("in!!!!!") 219 | if one_sql["table_id"] in tableid2qa_test: 220 | qa = {} 221 | qa["question"] = question.lower() 222 | qa["answer"] = [str(item).lower() for item in answer] 223 | 
tableid2qa_test[one_sql["table_id"]].append(qa) 224 | else: 225 | tableid2qa_test[one_sql["table_id"]] = [] 226 | qa = {} 227 | qa["question"] = question.lower() 228 | qa["answer"] = [str(item).lower() for item in answer] 229 | tableid2qa_test[one_sql["table_id"]].append(qa) 230 | else: 231 | print("out!!!!!") 232 | except: 233 | # exc_type, exc_value, exc_traceback = sys.exc_info() 234 | # traceback.print_tb(exc_traceback, limit=1, file=sys.stdout) 235 | print(count) 236 | continue 237 | 238 | print(count) 239 | def find_answer_position_in_matrix(matrix,text): 240 | result = 0 241 | for row in matrix: 242 | for one_col in row: 243 | if one_col==text: 244 | return result 245 | else: 246 | try: 247 | if one_col == str(text).split(".")[0]: 248 | return result 249 | except: 250 | continue 251 | result += 1 252 | print("error") 253 | print(matrix) 254 | print(text) 255 | print() 256 | 257 | paragraphs = [] 258 | for tableid in tableid2qa_test: 259 | context = input_table_output_matrix(table_data_[tableid]) 260 | context_qas ={} 261 | context_qas["context"] = context 262 | context_qas["qas"] = [] 263 | 264 | for item in tableid2qa_test[tableid]: 265 | one_qas = {} 266 | one_qas["answers"] = [] 267 | one_answer = {} 268 | if len(item["answer"])==1:# only consider one answer 269 | one_answer["text"] = item["answer"][0] 270 | else: 271 | continue 272 | one_answer["answer_start"] = find_answer_position_in_matrix(context,one_answer["text"]) 273 | if one_answer["answer_start"]==None: 274 | continue 275 | one_qas["question"] = item["question"] 276 | one_qas["answers"].append(one_answer) 277 | context_qas["qas"].append(one_qas) 278 | if len(context_qas["qas"])==0: # only context, no answer 279 | continue 280 | else: 281 | paragraphs.append(context_qas) 282 | 283 | 284 | squad_data["data"]=[] 285 | one_data = {} 286 | one_data["paragraphs"] = paragraphs 287 | squad_data["data"].append(one_data) 288 | 289 | f2 = open(outfile, mode="w", encoding="utf-8") 290 | json.dump(squad_data,f2) 291 | 292 | headers = process_train(sql_data,table_data,TRAIN_DB,"squad_data/my_train.json") 293 | process_test(val_sql_data,val_table_data,DEV_DB,"squad_data/my_valid.json",headers) -------------------------------------------------------------------------------- /make_data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import io 3 | import json 4 | import numpy as np 5 | from typesql.lib.dbengine import DBEngine 6 | 7 | def load_data(sql_paths, table_paths, use_small=False): 8 | if not isinstance(sql_paths, list): 9 | sql_paths = (sql_paths, ) 10 | if not isinstance(table_paths, list): 11 | table_paths = (table_paths, ) 12 | sql_data = [] 13 | table_data = {} 14 | 15 | max_col_num = 0 16 | for SQL_PATH in sql_paths: 17 | print("Loading data from %s"%SQL_PATH) 18 | with open(SQL_PATH) as inf: 19 | for idx, line in enumerate(inf): 20 | if use_small and idx >= 1000: 21 | break 22 | sql = json.loads(line.strip()) 23 | sql_data.append(sql) 24 | 25 | for TABLE_PATH in table_paths: 26 | print("Loading data from %s"%TABLE_PATH) 27 | with open(TABLE_PATH) as inf: 28 | for line in inf: 29 | tab = json.loads(line.strip()) 30 | table_data[tab[u'id']] = tab 31 | 32 | for sql in sql_data: 33 | assert sql[u'table_id'] in table_data 34 | 35 | return sql_data, table_data 36 | 37 | 38 | def load_dataset(use_small=False): 39 | print("Loading from original dataset") 40 | sql_data, table_data = load_data('data/train_tok.jsonl', 41 | 'data/train_tok.tables.jsonl', use_small=use_small) 
42 | val_sql_data, val_table_data = load_data('data/dev_tok.jsonl', 43 | 'data/dev_tok.tables.jsonl', use_small=use_small) 44 | 45 | test_sql_data, test_table_data = load_data('data/test_tok.jsonl', 46 | 'data/test_tok.tables.jsonl', use_small=use_small) 47 | TRAIN_DB = 'data/train.db' 48 | DEV_DB = 'data/dev.db' 49 | TEST_DB = 'data/test.db' 50 | 51 | return sql_data, table_data, val_sql_data, val_table_data,\ 52 | test_sql_data, test_table_data, TRAIN_DB, DEV_DB, TEST_DB 53 | 54 | def best_model_name(args, for_load=False): 55 | new_data = 'old' 56 | mode = 'sqlnet' 57 | if for_load: 58 | use_emb = '' 59 | else: 60 | use_emb = '_train_emb' if args.train_emb else '' 61 | 62 | agg_model_name = args.sd + '/%s_%s%s.agg_model'%(new_data, 63 | mode, use_emb) 64 | sel_model_name = args.sd + '/%s_%s%s.sel_model'%(new_data, 65 | mode, use_emb) 66 | cond_model_name = args.sd + '/%s_%s%s.cond_model'%(new_data, 67 | mode, use_emb) 68 | 69 | agg_embed_name = args.sd + '/%s_%s%s.agg_embed'%(new_data, mode, use_emb) 70 | sel_embed_name = args.sd + '/%s_%s%s.sel_embed'%(new_data, mode, use_emb) 71 | cond_embed_name = args.sd + '/%s_%s%s.cond_embed'%(new_data, mode, use_emb) 72 | 73 | return agg_model_name, sel_model_name, cond_model_name,\ 74 | agg_embed_name, sel_embed_name, cond_embed_name 75 | 76 | 77 | def to_batch_seq(sql_data, table_data, idxes, st, ed, db_content=0, ret_vis_data=False): 78 | q_seq = [] 79 | col_seq = [] 80 | col_num = [] 81 | ans_seq = [] 82 | query_seq = [] 83 | gt_cond_seq = [] 84 | vis_seq = [] 85 | 86 | q_type = [] 87 | col_type = [] 88 | for i in range(st, ed): 89 | sql = sql_data[idxes[i]] 90 | if db_content == 0: 91 | q_seq.append([[x] for x in sql['question_tok']]) 92 | q_type.append([[x] for x in sql["question_type_org_kgcol"]]) 93 | else: 94 | q_seq.append(sql['question_tok_concol']) 95 | q_type.append(sql["question_type_concol_list"]) 96 | col_type.append(table_data[sql['table_id']]['header_type_kg']) 97 | col_seq.append(table_data[sql['table_id']]['header_tok']) 98 | col_num.append(len(table_data[sql['table_id']]['header'])) 99 | ans_seq.append((sql['sql']['agg'], 100 | sql['sql']['sel'], 101 | len(sql['sql']['conds']), #number of conditions + selection 102 | tuple(x[0] for x in sql['sql']['conds']), #col num rep in condition 103 | tuple(x[1] for x in sql['sql']['conds']))) #op num rep in condition, then where is str in cond? 
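# Layout of each ans_seq entry built above: (aggregator index, selected column index,
# number of WHERE conditions, tuple of condition column indices, tuple of condition operator indices).
# The condition value strings themselves travel separately in gt_cond_seq below,
# where each condition is a [column, operator, value] triple.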
104 | query_seq.append(sql['query_tok']) # real query string toks 105 | gt_cond_seq.append(sql['sql']['conds']) # list of conds (a list of col, op, str) 106 | vis_seq.append((sql['question'], 107 | table_data[sql['table_id']]['header'], sql['query'], [[x] for x in sql['question_tok']])) 108 | if ret_vis_data: 109 | return q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, vis_seq 110 | else: 111 | return q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type 112 | 113 | 114 | def to_batch_query(sql_data, idxes, st, ed): 115 | query_gt = [] 116 | table_ids = [] 117 | for i in range(st, ed): 118 | query_gt.append(sql_data[idxes[i]]['sql']) 119 | table_ids.append(sql_data[idxes[i]]['table_id']) 120 | return query_gt, table_ids 121 | 122 | 123 | def epoch_train(model, optimizer, batch_size, sql_data, table_data, pred_entry, db_content): 124 | model.train() 125 | perm=np.random.permutation(len(sql_data)) 126 | cum_loss = 0.0 127 | st = 0 128 | while st < len(sql_data): 129 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 130 | 131 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type = \ 132 | to_batch_seq(sql_data, table_data, perm, st, ed, db_content) 133 | gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq) 134 | gt_sel_seq = [x[1] for x in ans_seq] 135 | gt_agg_seq = [x[0] for x in ans_seq] 136 | score = model.forward(q_seq, col_seq, col_num, q_type, col_type, pred_entry, 137 | gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq) 138 | loss = model.loss(score, ans_seq, pred_entry, gt_where_seq) 139 | # cum_loss += loss.data.cpu().numpy()[0]*(ed - st) 140 | cum_loss += loss.data.cpu().numpy() * (ed - st) 141 | optimizer.zero_grad() 142 | loss.backward() 143 | optimizer.step() 144 | 145 | st = ed 146 | 147 | return cum_loss / len(sql_data) 148 | 149 | 150 | def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path, db_content): 151 | engine = DBEngine(db_path) 152 | model.eval() 153 | perm = list(range(len(sql_data))) 154 | tot_acc_num = 0.0 155 | acc_of_log = 0.0 156 | st = 0 157 | while st < len(sql_data): 158 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 159 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, raw_data = \ 160 | to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True) 161 | raw_q_seq = [x[0] for x in raw_data] 162 | raw_col_seq = [x[1] for x in raw_data] 163 | gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq) 164 | query_gt, table_ids = to_batch_query(sql_data, perm, st, ed) 165 | gt_sel_seq = [x[1] for x in ans_seq] 166 | gt_agg_seq = [x[0] for x in ans_seq] 167 | score = model.forward(q_seq, col_seq, col_num, q_type, col_type, (True, True, True)) 168 | pred_queries = model.gen_query(score, q_seq, col_seq, 169 | raw_q_seq, raw_col_seq, (True, True, True)) 170 | 171 | for idx, (sql_gt, sql_pred, tid) in enumerate( 172 | zip(query_gt, pred_queries, table_ids)): 173 | ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds']) 174 | try: 175 | ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds']) 176 | except: 177 | ret_pred = None 178 | tot_acc_num += (ret_gt == ret_pred) 179 | 180 | st = ed 181 | 182 | return tot_acc_num / len(sql_data) 183 | 184 | 185 | def epoch_acc(model, batch_size, sql_data, table_data, pred_entry, db_content, error_print=False): 186 | model.eval() 187 | perm = list(range(len(sql_data))) 188 | st = 0 
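# Running counts for the two accuracy ratios returned at the end of this function;
# the per-batch error counts come from model.check_acc on predicted vs. ground-truth queries.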
189 | one_acc_num = 0.0 190 | tot_acc_num = 0.0 191 | while st < len(sql_data): 192 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 193 | 194 | q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type,\ 195 | raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True) 196 | raw_q_seq = [x[0] for x in raw_data] 197 | raw_col_seq = [x[1] for x in raw_data] 198 | query_gt, table_ids = to_batch_query(sql_data, perm, st, ed) 199 | gt_sel_seq = [x[1] for x in ans_seq] 200 | score = model.forward(q_seq, col_seq, col_num, q_type, col_type, pred_entry) 201 | pred_queries = model.gen_query(score, q_seq, col_seq, 202 | raw_q_seq, raw_col_seq, pred_entry) 203 | one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt, pred_entry, error_print) 204 | 205 | one_acc_num += (ed-st-one_err) 206 | tot_acc_num += (ed-st-tot_err) 207 | 208 | st = ed 209 | return tot_acc_num / len(sql_data), one_acc_num / len(sql_data) 210 | 211 | 212 | def load_para_wemb(file_name): 213 | f = open(file_name, 'r', encoding='utf-8') 214 | lines = f.readlines() 215 | ret = {} 216 | if len(lines[0].split()) == 2: 217 | lines.pop(0) 218 | for (n,line) in enumerate(lines): 219 | info = line.strip().split(' ') 220 | if info[0].lower() not in ret: 221 | ret[info[0].lower()] = np.array(info[1:]).astype(float) 222 | 223 | return ret 224 | 225 | 226 | def load_comb_wemb(fn1, fn2): 227 | wemb1 = load_word_emb(fn1) 228 | wemb2 = load_para_wemb(fn2) 229 | comb_emb = {k: wemb1.get(k, 0) + wemb2.get(k, 0) for k in set(wemb1) | set(wemb2)} 230 | 231 | return comb_emb 232 | 233 | 234 | def load_concat_wemb(fn1, fn2,use_small=False): 235 | wemb1 = load_word_emb(fn1,use_small=use_small) 236 | wemb2 = load_para_wemb(fn2) 237 | backup = np.zeros(300, dtype=np.float32) 238 | comb_emb = {k: np.concatenate((wemb1.get(k, backup), wemb2.get(k, backup)), axis=0) for k in set(wemb1) | set(wemb2)} 239 | 240 | return None, None, comb_emb 241 | 242 | 243 | def load_word_emb(file_name, load_used=False, use_small=False): 244 | if not load_used: 245 | print ('Loading word embedding from %s'%file_name) 246 | ret = {} 247 | with open(file_name,encoding="utf-8",mode="r") as inf: 248 | for idx, line in enumerate(inf): 249 | if (use_small and idx >= 1000): 250 | break 251 | info = line.strip().split(' ') 252 | if info[0].lower() not in ret: 253 | ret[info[0]] = np.array(info[1:]).astype(float) 254 | return ret 255 | else: 256 | print ('Load used word embedding') 257 | with open('glove/word2idx.json',mode="r",encoding="utf-8") as inf: 258 | w2i = json.load(inf) 259 | with open('glove/usedwordemb.npy',mode="r",encoding="utf-8") as inf: 260 | word_emb_val = np.load(inf) 261 | return w2i, word_emb_val 262 | 263 | 264 | def load_word_and_type_emb(fn1, fn2, sql_data, table_data, db_content, is_list=False, use_htype=False,use_small = False): 265 | word_to_idx = {'':0, '':1, '':2} 266 | word_num = 3 267 | N_word = 300 268 | embs = [np.zeros(N_word, dtype=np.float32) for _ in range(word_num)] 269 | _, _, word_emb = load_concat_wemb(fn1, fn2, use_small) 270 | 271 | if is_list: 272 | for sql in sql_data: 273 | if db_content == 0: 274 | qtype = [[x] for x in sql["question_type_org_kgcol"]] 275 | else: 276 | qtype = sql['question_type_concol_list'] 277 | for tok_typl in qtype: 278 | tys = " ".join(sorted(tok_typl)) 279 | if tys not in word_to_idx: 280 | emb_list = [] 281 | ws_len = len(tok_typl) 282 | for w in tok_typl: 283 | if w in word_emb: 284 | emb_list.append(word_emb[w][:N_word]) 285 | else: 286 
| emb_list.append(np.zeros(N_word, dtype=np.float32)) 287 | word_to_idx[tys] = word_num 288 | word_num += 1 289 | embs.append(sum(emb_list) / float(ws_len)) 290 | 291 | if use_htype: 292 | for tab in table_data.values(): 293 | for col in tab['header_type_kg']: 294 | cts = " ".join(sorted(col)) 295 | if cts not in word_to_idx: 296 | emb_list = [] 297 | ws_len = len(col) 298 | for w in col: 299 | if w in word_emb: 300 | emb_list.append(word_emb[w][:N_word]) 301 | else: 302 | emb_list.append(np.zeros(N_word, dtype=np.float32)) 303 | word_to_idx[cts] = word_num 304 | word_num += 1 305 | embs.append(sum(emb_list) / float(ws_len)) 306 | 307 | else: 308 | for sql in sql_data: 309 | if db_content == 0: 310 | qtype = sql['question_tok_type'] 311 | else: 312 | qtype = sql['question_type_concol_list'] 313 | for tok in qtype: 314 | if tok not in word_to_idx: 315 | word_to_idx[tok] = word_num 316 | word_num += 1 317 | embs.append(word_emb[tok][:N_word]) 318 | 319 | if use_htype: 320 | for tab in table_data.values(): 321 | for tok in tab['header_type_kg']: 322 | if tok not in word_to_idx: 323 | word_to_idx[tok] = word_num 324 | word_num += 1 325 | embs.append(word_emb[tok][:N_word]) 326 | 327 | 328 | agg_ops = ['null', 'maximum', 'minimum', 'count', 'total', 'average'] 329 | for tok in agg_ops: 330 | if tok not in word_to_idx: 331 | word_to_idx[tok] = word_num 332 | word_num += 1 333 | embs.append(word_emb[tok][:N_word]) 334 | 335 | emb_array = np.stack(embs, axis=0) 336 | 337 | return (word_to_idx, emb_array, word_emb) 338 | -------------------------------------------------------------------------------- /matrix_code/data_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | 5 | def make_batch(size, batch_size): 6 | nb_batch = int(np.ceil(size/float(batch_size))) 7 | return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)] # zgwang: starting point of each batch 8 | 9 | def pad_2d(in_vals, dim1_size, dim2_size, dtype=np.int32): 10 | out_val = np.zeros((dim1_size, dim2_size), dtype=dtype) 11 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 12 | for i in range(dim1_size): 13 | cur_in_vals = in_vals[i] 14 | cur_dim2_size = dim2_size 15 | if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals) 16 | out_val[i,:cur_dim2_size] = cur_in_vals[:cur_dim2_size] 17 | return out_val 18 | 19 | def pad_3d(in_vals, dim1_size, dim2_size, dim3_size, dtype=np.int32): 20 | out_val = np.zeros((dim1_size, dim2_size, dim3_size), dtype=dtype) 21 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 22 | for i in range(dim1_size): 23 | in_vals_i = in_vals[i] 24 | cur_dim2_size = dim2_size 25 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i) 26 | for j in range(cur_dim2_size): 27 | in_vals_ij = in_vals_i[j] 28 | cur_dim3_size = dim3_size 29 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij) 30 | out_val[i, j, :cur_dim3_size] = in_vals_ij[:cur_dim3_size] 31 | return out_val 32 | 33 | 34 | class OneBatch(object): 35 | def __init__(self, current_batch, config): 36 | self.target_batch = [] 37 | self.all_cell_batch = [] 38 | self.question_batch = [] 39 | self.type_batch = [] 40 | self.header_len_batch = [] 41 | self.all_cell_len_batch = [] 42 | self.question_len_batch = [] 43 | self.input_mask_batch = [] 44 | # self.answer_batch = [] 45 | for (target_id,all_cell_id,question_id,answer_id, 46 | type_id,header_len,all_cell_len,question_len,input_mask) in current_batch: 47 | 
self.all_cell_batch.append(all_cell_id) 48 | self.question_batch.append(question_id) 49 | self.target_batch.append(target_id) 50 | self.type_batch.append(type_id) 51 | self.header_len_batch.append(header_len) 52 | self.all_cell_len_batch.append(all_cell_len) 53 | self.question_len_batch.append(question_len) 54 | self.input_mask_batch.append(input_mask) 55 | """ 56 | to numpy 57 | """ 58 | # self.all_cell_batch=np.array(self.all_cell_batch, dtype=np.int32) 59 | # self.question_batch = np.array(self.question_batch, dtype=np.int32) 60 | # self.target_batch = np.array(self.target_batch, dtype=np.int32) 61 | """ 62 | padding 63 | """ 64 | self.all_cell_batch=pad_3d(self.all_cell_batch,dim1_size=config["batch_size"], 65 | dim2_size=config["max_cell_num"],dim3_size=config["max_word_num"]) 66 | self.question_batch=pad_2d(self.question_batch,dim1_size=config["batch_size"], 67 | dim2_size=config["max_question_len"]) 68 | # self.target_batch = pad_2d(self.target_batch, dim1_size=config["batch_size"], 69 | # dim2_size=config["max_cell_num"]+config["max_question_len"]) 70 | # self.type_batch = np.array(self.type_batch) 71 | 72 | class DataUtil(object): 73 | def __init__(self, json_path="../squad_data/my_train.json",config=None,tokenizer=None): 74 | 75 | f = open(json_path, mode="r", encoding="utf-8") 76 | jdata = json.load(f) 77 | all_case_list = [] 78 | count = 0 79 | max_content_len = 0 80 | max_question_len = 0 81 | for context_qas in jdata["data"][0]["paragraphs"]: 82 | count += 1 83 | if config["debug"]==True and count>100: 84 | break 85 | if count % 1000==0: 86 | print("data process ", count) 87 | context = [] 88 | for row in context_qas["context"]: 89 | context.extend(row) 90 | context_id = [] 91 | for cell in context: 92 | context_id.append(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(cell))) 93 | 94 | if len(context_id)>config["max_cell_num"]: 95 | continue 96 | if len(context_id)>max_content_len: 97 | max_content_len=len(context_id) 98 | 99 | for qas in context_qas["qas"]: 100 | one_case = {} 101 | one_case["all_cell_id"] = context_id 102 | question_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(qas["question"])) 103 | answer_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(qas["answers"][0]["text"])) 104 | # target_id = [0] * len(context_id) 105 | target_id = [0] * (config["max_cell_num"]+config["max_question_len"]) 106 | if qas["answers"][0]["answer_start"]>len(context_id): 107 | print("error: answer start > context") 108 | target_id[qas["answers"][0]["answer_start"]] = 1 # sequence tag problem 109 | # flag = False 110 | # for item in target_id: 111 | # if item==1: 112 | # flag = True 113 | # assert flag==True 114 | header_len = len(context_qas["context"][0]) 115 | context_len = len(context_id) 116 | question_len = len(question_id) 117 | 118 | if question_len > config["max_question_len"]: 119 | continue 120 | if question_len > max_question_len: 121 | max_question_len = question_len 122 | 123 | input_mask = context_len*[1] + (config["max_cell_num"]-context_len)*[0] + \ 124 | question_len*[1] + (config["max_question_len"]-question_len)*[0] 125 | 126 | type_id = [1]*header_len+ (context_len-header_len)*[2] + \ 127 | (config["max_cell_num"] - context_len) * [0]+ \ 128 | question_len * [3] + (config["max_question_len"] - question_len) * [0] 129 | 130 | all_case_list.append((target_id,context_id,question_id,answer_id,type_id, 131 | header_len,context_len,question_len,input_mask)) 132 | print("max content len", max_content_len) 133 | print("max question len", 
max_question_len) 134 | print("data num", len(all_case_list)) 135 | batch_spans = make_batch(len(all_case_list), config["batch_size"] ) 136 | self.all_batch = [] 137 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 138 | current_batch = [] 139 | for i in range(batch_start, batch_end): 140 | current_batch.append(all_case_list[i]) 141 | if len(current_batch)= len(self.all_batch): 150 | self.cur_pointer = 0 151 | np.random.shuffle(self.index_array) 152 | cur_batch = self.all_batch[self.index_array[self.cur_pointer]] 153 | self.cur_pointer += 1 154 | return cur_batch 155 | 156 | def get_batch(self, i): 157 | if i >= len(self.all_batch): return None 158 | return self.all_batch[self.index_array[i]] 159 | 160 | 161 | class DataUtil_bert(object): 162 | def __init__(self, json_path="../squad_data/my_train.json",config=None,tokenizer=None): 163 | 164 | f = open(json_path, mode="r", encoding="utf-8") 165 | jdata = json.load(f) 166 | all_case_list = [] 167 | count = 0 168 | max_content_len = 0 169 | max_question_len = 0 170 | for context_qas in jdata["data"][0]["paragraphs"]: 171 | count += 1 172 | if config["debug"]==True and count>100: 173 | break 174 | if count % 1000==0: 175 | print("data process ", count) 176 | context = [] 177 | for row in context_qas["context"]: 178 | context.extend(row) 179 | context_id = [] 180 | context_id.append(tokenizer.convert_tokens_to_ids(["[CLS]"])) 181 | for cell in context: 182 | context_id.append(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(cell))) 183 | context_id.append(tokenizer.convert_tokens_to_ids(["[SEP]"])) 184 | if len(context_id)>config["max_cell_num"]: 185 | continue 186 | if len(context_id)>max_content_len: 187 | max_content_len=len(context_id) 188 | 189 | for qas in context_qas["qas"]: 190 | question = tokenizer.tokenize(qas["question"]) 191 | question = question + ["[SEP]"] 192 | question_id = tokenizer.convert_tokens_to_ids(question) 193 | answer_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(qas["answers"][0]["text"])) 194 | # target_id = [0] * len(context_id) 195 | target_id = [0] * (config["max_cell_num"]+config["max_question_len"]) 196 | if qas["answers"][0]["answer_start"]>len(context_id)-1: 197 | print("error: answer start > context") 198 | continue 199 | target_id[qas["answers"][0]["answer_start"]+1] = 1 200 | # flag = False 201 | # for item in target_id: 202 | # if item==1: 203 | # flag = True 204 | # assert flag==True 205 | header_len = len(context_qas["context"][0]) + 1 206 | context_len = len(context_id) 207 | question_len = len(question_id) 208 | 209 | if question_len > config["max_question_len"]: 210 | continue 211 | if question_len > max_question_len: 212 | max_question_len = question_len 213 | 214 | input_mask = context_len*[1] + (config["max_cell_num"]-context_len)*[0] + \ 215 | question_len*[1] + (config["max_question_len"]-question_len)*[0] 216 | 217 | type_id = config["max_cell_num"]*[0] + config["max_question_len"] * [1] 218 | 219 | all_case_list.append((target_id,context_id,question_id,answer_id,type_id, 220 | header_len,context_len,question_len,input_mask)) 221 | print("max content len", max_content_len) 222 | print("max question len", max_question_len) 223 | print("data num", len(all_case_list)) 224 | batch_spans = make_batch(len(all_case_list), config["batch_size"] ) 225 | self.all_batch = [] 226 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 227 | current_batch = [] 228 | for i in range(batch_start, batch_end): 229 | current_batch.append(all_case_list[i]) 230 | if 
len(current_batch)= len(self.all_batch): 239 | self.cur_pointer = 0 240 | np.random.shuffle(self.index_array) 241 | cur_batch = self.all_batch[self.index_array[self.cur_pointer]] 242 | self.cur_pointer += 1 243 | return cur_batch 244 | 245 | def get_batch(self, i): 246 | if i >= len(self.all_batch): return None 247 | return self.all_batch[self.index_array[i]] -------------------------------------------------------------------------------- /matrix_code/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matrix.main_modeling as main_modeling 3 | 4 | class ModelWrapper(): 5 | def __init__(self, config, is_train): 6 | self.config = config 7 | self.header_len_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"]]) 8 | self.all_cell_len_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"]]) 9 | self.question_len_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"]]) 10 | self.matrix_placeholder = tf.placeholder(dtype=tf.int32,shape=[config["batch_size"], config["max_cell_num"], config["max_word_num"]]) 11 | self.input_mask_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"] , config["max_cell_num"] + config["max_question_len"]]) 12 | self.target_placeholder = tf.placeholder(dtype=tf.int32,shape=[config["batch_size"], config["max_cell_num"]+config["max_question_len"]]) 13 | self.question_placeholder = tf.placeholder(dtype=tf.int32,shape=[config["batch_size"], config["max_question_len"]]) 14 | self.type_placeholder = tf.placeholder(dtype=tf.int32,shape=[config["batch_size"], config["max_cell_num"] + config["max_question_len"]]) 15 | main_model_config = main_modeling.BertConfig.from_json_file("../uncased_L-12_H-768_A-12/bert_config.json") 16 | 17 | main_model = main_modeling.BertModel(main_model_config, 18 | config, 19 | input_matrix=self.matrix_placeholder, 20 | input_question=self.question_placeholder, 21 | token_type_ids=self.type_placeholder, 22 | input_mask=self.input_mask_placeholder, 23 | is_training=is_train, 24 | use_one_hot_embeddings=False) 25 | 26 | final_hidden = main_model.get_sequence_output()#[:,:config["max_cell_num"],:] 27 | 28 | """ 29 | logits = [] 30 | init_scale = 0.01 31 | initializer = tf.random_uniform_initializer(-init_scale, init_scale) 32 | W_projection_slot = tf.get_variable("W_projection_slot", shape=[200, 2], 33 | initializer=initializer) # [embed_size,label_size] 34 | b_projection_slot = tf.get_variable("b_projection_slot", shape=[2]) 35 | dense = tf.layers.Dense(200, activation=tf.nn.tanh) 36 | for i in range(config["max_cell_num"]): 37 | feature = final_hidden[:, i, :] # [none,self.hidden_size*2] 38 | hidden_states = dense(feature) # [none,hidden_size] 39 | output = tf.matmul(hidden_states, W_projection_slot) + b_projection_slot # [none,slots_num_classes] 40 | logits.append(output) 41 | # logits is a list. 
each element is:[none,slots_num_classes] 42 | logits = tf.stack(logits, axis=1) # [none,sequence_length,slots_num_classes] 43 | self.predictions_slots = tf.argmax(logits, axis=2, name="predictions_slots") 44 | correct_prediction_slot = tf.equal(tf.cast(self.predictions_slots, tf.int32), 45 | self.target_placeholder) # [batch_size, self.sequence_length] 46 | accuracy_slot = tf.reduce_mean(tf.cast(correct_prediction_slot, tf.float32), name="accuracy_slot") # shape=() 47 | loss_slot = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.target_placeholder, logits=logits) 48 | self.loss = tf.reduce_mean(loss_slot) 49 | """ 50 | 51 | batch_size = config["batch_size"] 52 | seq_length = config["max_cell_num"]+config["max_question_len"] 53 | hidden_size = main_model_config.hidden_size 54 | 55 | output_weights = tf.get_variable( 56 | "cls/squad/output_weights", [1, hidden_size], 57 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 58 | 59 | output_bias = tf.get_variable( 60 | "cls/squad/output_bias", [1], initializer=tf.zeros_initializer()) 61 | 62 | final_hidden_matrix = tf.reshape(final_hidden, 63 | [batch_size * seq_length, hidden_size]) 64 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) 65 | logits = tf.nn.bias_add(logits, output_bias) 66 | 67 | self.start_logits = tf.reshape(logits, [batch_size, seq_length]) 68 | 69 | self.all_cell_mask = tf.sequence_mask(self.all_cell_len_placeholder, maxlen=config["max_cell_num"] + config["max_question_len"], dtype=tf.float32) 70 | self.header_mask = tf.sequence_mask(self.header_len_placeholder,maxlen=config["max_cell_num"]+config["max_question_len"],dtype=tf.float32) 71 | self.target_mask = self.all_cell_mask - self.header_mask 72 | 73 | self.output_logits = self.start_logits * self.target_mask 74 | 75 | log_probs = tf.nn.log_softmax(self.start_logits*self.target_mask, axis=-1) 76 | self.loss = -tf.reduce_mean( 77 | tf.reduce_sum(tf.cast(self.target_placeholder,tf.float32) * log_probs*self.target_mask, axis=-1)) 78 | 79 | if not is_train: return 80 | optimizer = tf.train.AdamOptimizer(learning_rate=config["learning_rate"]) 81 | 82 | # tvars = tf.trainable_variables() 83 | # l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1]) 84 | # self.loss = self.loss + 0.01 * l2_loss 85 | # self.train_op = optimizer.minimize(self.loss) 86 | 87 | def var_filter(var_list, last_layers): 88 | filter_keywords = ['layer_11', 'layer_10', 'layer_9', 'layer_8'] 89 | for var in var_list: 90 | if "bert" not in var.name: 91 | yield var 92 | else: 93 | for layer in last_layers: 94 | kw = filter_keywords[layer] 95 | if kw in var.name: 96 | yield var 97 | 98 | def compute_gradients(tensor, var_list): 99 | grads = tf.gradients(tensor, var_list) 100 | return [grad if grad is not None else tf.zeros_like(var) for var, grad in zip(var_list, grads)] 101 | 102 | tvars = list(var_filter(tf.trainable_variables(), last_layers=range(3))) 103 | grads = compute_gradients(self.loss, tvars) 104 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 105 | 106 | -------------------------------------------------------------------------------- /matrix_code/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import time 3 | import matrix.data_util as data_util 4 | import matrix.model_wrapper as model_wrapper 5 | import numpy as np 6 | import tokenization 7 | config = {} 8 | config["batch_size"] = 16 9 | config["max_cell_num"] = 100 10 | 
config["max_word_num"] = 10 # max word num in a cell 11 | config["max_question_len"] = 30 12 | config["debug"] = False 13 | config["learning_rate"] = 5e-5 14 | 15 | tokenizer = tokenization.FullTokenizer(vocab_file="../uncased_L-12_H-768_A-12/vocab.txt") 16 | 17 | valid_data_util = data_util.DataUtil_bert(json_path="../squad_data/my_valid.json",config=config,tokenizer=tokenizer) 18 | train_data_util = data_util.DataUtil_bert(json_path="../squad_data/my_train.json",config=config,tokenizer=tokenizer) 19 | 20 | with tf.variable_scope("Model", reuse=False): 21 | train_model = model_wrapper.ModelWrapper(config, is_train = True) 22 | 23 | with tf.variable_scope("Model", reuse=True): 24 | valid_model = model_wrapper.ModelWrapper(config, is_train = False) 25 | 26 | 27 | sess = tf.Session() 28 | saver = tf.train.Saver() 29 | sess.run(tf.global_variables_initializer()) 30 | # import os 31 | # if os.path.exists("save_model"): 32 | # saver.restore(sess,"save_model/model") 33 | 34 | def optimistic_restore(session, save_file): 35 | """ 36 | restore only those variable that exists in the model 37 | :param session: 38 | :param save_file: 39 | :return: 40 | """ 41 | reader = tf.train.NewCheckpointReader(save_file) 42 | # reader.get_tensor() 43 | saved_shapes = reader.get_variable_to_shape_map() 44 | print(saved_shapes) 45 | print() 46 | print([var.name for var in tf.global_variables()]) 47 | 48 | restore_vars = { (v.name.split(':')[0].replace("Model/","").replace("Model/","")): v for v in tf.trainable_variables() if 'bert' in v.name} 49 | saver = tf.train.Saver(restore_vars) 50 | saver.restore(session, save_file) 51 | 52 | optimistic_restore(sess, "../uncased_L-12_H-768_A-12/bert_model.ckpt") 53 | 54 | correct_count = 0 55 | total_count = 0 56 | max_acc = 0 57 | best_accuracy = -1 58 | 59 | for epoch in range(50): 60 | print('Train in epoch %d' % epoch) 61 | num_batch = len(train_data_util.all_batch) 62 | start_time = time.time() 63 | total_loss = 0 64 | for batch_index in range(num_batch): # for each batch 65 | cur_batch = train_data_util.next_batch() 66 | train_feed_dict = { 67 | train_model.matrix_placeholder:cur_batch.all_cell_batch, 68 | train_model.question_placeholder:cur_batch.question_batch, 69 | train_model.target_placeholder:cur_batch.target_batch, 70 | train_model.type_placeholder:cur_batch.type_batch, 71 | train_model.input_mask_placeholder:cur_batch.input_mask_batch, 72 | train_model.all_cell_len_placeholder:cur_batch.all_cell_len_batch, 73 | train_model.question_len_placeholder:cur_batch.question_len_batch, 74 | train_model.header_len_placeholder:cur_batch.header_len_batch, 75 | } 76 | _, loss_value,predictions_slots_train = sess.run([train_model.train_op, 77 | train_model.loss, 78 | train_model.output_logits], 79 | feed_dict=train_feed_dict) 80 | total_loss += loss_value 81 | for i in range(len(predictions_slots_train)): 82 | predict_cell_index = np.argmax(predictions_slots_train[i]) 83 | true_cell_index = np.argmax(cur_batch.target_batch[i]) 84 | if predict_cell_index == true_cell_index or \ 85 | predict_cell_index < config["max_cell_num"] and \ 86 | (cur_batch.all_cell_batch[i][predict_cell_index]== 87 | cur_batch.all_cell_batch[i][true_cell_index]).all(): 88 | correct_count += 1 89 | total_count += 1 90 | if total_count!=0: 91 | print("train acc" , (correct_count / total_count)) 92 | duration = time.time() - start_time 93 | start_time = time.time() 94 | print('train loss = %.4f (%.3f sec)' % (total_loss / num_batch, duration)) 95 | 96 | correct_count = 0 97 | total_count = 0 98 | 
total_loss = 0 99 | 100 | num_valid_batch = len(valid_data_util.all_batch) 101 | for valid_batch_index in range(num_valid_batch): 102 | cur_valid_batch = valid_data_util.next_batch() 103 | valid_feed_dict = { 104 | valid_model.matrix_placeholder: cur_valid_batch.all_cell_batch, 105 | valid_model.question_placeholder: cur_valid_batch.question_batch, 106 | valid_model.target_placeholder: cur_valid_batch.target_batch, 107 | valid_model.type_placeholder: cur_valid_batch.type_batch, 108 | valid_model.input_mask_placeholder: cur_valid_batch.input_mask_batch, 109 | valid_model.all_cell_len_placeholder: cur_valid_batch.all_cell_len_batch, 110 | valid_model.question_len_placeholder: cur_valid_batch.question_len_batch, 111 | valid_model.header_len_placeholder: cur_valid_batch.header_len_batch, 112 | } 113 | valid_loss_value, predictions_slots_ = sess.run( 114 | [valid_model.loss, valid_model.output_logits], 115 | feed_dict=valid_feed_dict) 116 | total_loss += valid_loss_value 117 | for i in range(len(predictions_slots_)): 118 | predict_cell_index = np.argmax(predictions_slots_[i]) 119 | true_cell_index = np.argmax(cur_valid_batch.target_batch[i]) 120 | if predict_cell_index == true_cell_index or \ 121 | predict_cell_index < config["max_cell_num"] and \ 122 | (cur_valid_batch.all_cell_batch[i][predict_cell_index]== 123 | cur_valid_batch.all_cell_batch[i][true_cell_index]).all(): 124 | correct_count += 1 125 | total_count += 1 126 | duration = time.time() - start_time 127 | start_time = time.time() 128 | print('valid loss = %.4f (%.3f sec)' % (total_loss / num_valid_batch, duration)) 129 | if total_count!=0: 130 | print("valid acc" , (correct_count / total_count)) 131 | if total_count != 0: 132 | if correct_count / total_count > max_acc and epoch>5: 133 | max_acc = correct_count / total_count 134 | saver.save(sess,"save_model/model") 135 | correct_count = 0 136 | total_count = 0 137 | print("max acc",max_acc) 138 | -------------------------------------------------------------------------------- /modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, 
self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | "^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | for input_name in 
op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 | 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
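# Illustrative numbers (not taken from this repo's configs): with init_lr = 5e-5 and
# num_warmup_steps = 1000, global step 100 trains at (100 / 1000) * 5e-5 = 5e-6; once
# global_step passes num_warmup_steps, the linear polynomial decay above takes over and
# the rate falls to 0.0 at num_train_steps.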
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 
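# m_t = beta_1 * m + (1 - beta_1) * grad
# v_t = beta_2 * v + (1 - beta_2) * grad^2
# update = m_t / (sqrt(v_t) + epsilon)
# Note: no bias correction is applied to m_t / v_t, matching how BERT was pre-trained.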
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization 26 | import tokenization 27 | import tensorflow as tf 28 | 29 | flags = tf.flags 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | ## Required parameters 34 | flags.DEFINE_string( 35 | "data_dir", None, 36 | "The input data dir. Should contain the .tsv files (or other data files) " 37 | "for the task.") 38 | 39 | flags.DEFINE_string( 40 | "bert_config_file", None, 41 | "The config json file corresponding to the pre-trained BERT model. " 42 | "This specifies the model architecture.") 43 | 44 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 45 | 46 | flags.DEFINE_string("vocab_file", None, 47 | "The vocabulary file that the BERT model was trained on.") 48 | 49 | flags.DEFINE_string( 50 | "output_dir", None, 51 | "The output directory where the model checkpoints will be written.") 52 | 53 | ## Other parameters 54 | 55 | flags.DEFINE_string( 56 | "init_checkpoint", None, 57 | "Initial checkpoint (usually from a pre-trained BERT model).") 58 | 59 | flags.DEFINE_bool( 60 | "do_lower_case", True, 61 | "Whether to lower case the input text. Should be True for uncased " 62 | "models and False for cased models.") 63 | 64 | flags.DEFINE_integer( 65 | "max_seq_length", 128, 66 | "The maximum total input sequence length after WordPiece tokenization. 
" 67 | "Sequences longer than this will be truncated, and sequences shorter " 68 | "than this will be padded.") 69 | 70 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 71 | 72 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 73 | 74 | flags.DEFINE_bool( 75 | "do_predict", False, 76 | "Whether to run the model in inference mode on the test set.") 77 | 78 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 79 | 80 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 81 | 82 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 83 | 84 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 85 | 86 | flags.DEFINE_float("num_train_epochs", 3.0, 87 | "Total number of training epochs to perform.") 88 | 89 | flags.DEFINE_float( 90 | "warmup_proportion", 0.1, 91 | "Proportion of training to perform linear learning rate warmup for. " 92 | "E.g., 0.1 = 10% of training.") 93 | 94 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 95 | "How often to save the model checkpoint.") 96 | 97 | flags.DEFINE_integer("iterations_per_loop", 1000, 98 | "How many steps to make in each estimator call.") 99 | 100 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 101 | 102 | tf.flags.DEFINE_string( 103 | "tpu_name", None, 104 | "The Cloud TPU to use for training. This should be either the name " 105 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 106 | "url.") 107 | 108 | tf.flags.DEFINE_string( 109 | "tpu_zone", None, 110 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 111 | "specified, we will attempt to automatically detect the GCE project from " 112 | "metadata.") 113 | 114 | tf.flags.DEFINE_string( 115 | "gcp_project", None, 116 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 117 | "specified, we will attempt to automatically detect the GCE project from " 118 | "metadata.") 119 | 120 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 121 | 122 | flags.DEFINE_integer( 123 | "num_tpu_cores", 8, 124 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 125 | 126 | 127 | class InputExample(object): 128 | """A single training/test example for simple sequence classification.""" 129 | 130 | def __init__(self, guid, text_a, text_b=None, label=None): 131 | """Constructs a InputExample. 132 | 133 | Args: 134 | guid: Unique id for the example. 135 | text_a: string. The untokenized text of the first sequence. For single 136 | sequence tasks, only this sequence must be specified. 137 | text_b: (Optional) string. The untokenized text of the second sequence. 138 | Only must be specified for sequence pair tasks. 139 | label: (Optional) string. The label of the example. This should be 140 | specified for train and dev examples, but not for test examples. 141 | """ 142 | self.guid = guid 143 | self.text_a = text_a 144 | self.text_b = text_b 145 | self.label = label 146 | 147 | 148 | class PaddingInputExample(object): 149 | """Fake example so the num input examples is a multiple of the batch size. 150 | 151 | When running eval/predict on the TPU, we need to pad the number of examples 152 | to be a multiple of the batch size, because the TPU requires a fixed batch 153 | size. The alternative is to drop the last batch, which is bad because it means 154 | the entire output data won't be generated. 
155 | 156 | We use this class instead of `None` because treating `None` as padding 157 | battches could cause silent errors. 158 | """ 159 | 160 | 161 | class InputFeatures(object): 162 | """A single set of features of data.""" 163 | 164 | def __init__(self, 165 | input_ids, 166 | input_mask, 167 | segment_ids, 168 | label_id, 169 | is_real_example=True): 170 | self.input_ids = input_ids 171 | self.input_mask = input_mask 172 | self.segment_ids = segment_ids 173 | self.label_id = label_id 174 | self.is_real_example = is_real_example 175 | 176 | 177 | class DataProcessor(object): 178 | """Base class for data converters for sequence classification data sets.""" 179 | 180 | def get_train_examples(self, data_dir): 181 | """Gets a collection of `InputExample`s for the train set.""" 182 | raise NotImplementedError() 183 | 184 | def get_dev_examples(self, data_dir): 185 | """Gets a collection of `InputExample`s for the dev set.""" 186 | raise NotImplementedError() 187 | 188 | def get_test_examples(self, data_dir): 189 | """Gets a collection of `InputExample`s for prediction.""" 190 | raise NotImplementedError() 191 | 192 | def get_labels(self): 193 | """Gets the list of labels for this data set.""" 194 | raise NotImplementedError() 195 | 196 | @classmethod 197 | def _read_tsv(cls, input_file, quotechar=None): 198 | """Reads a tab separated value file.""" 199 | with tf.gfile.Open(input_file, "r") as f: 200 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 201 | lines = [] 202 | for line in reader: 203 | lines.append(line) 204 | return lines 205 | 206 | 207 | class XnliProcessor(DataProcessor): 208 | """Processor for the XNLI data set.""" 209 | 210 | def __init__(self): 211 | self.language = "zh" 212 | 213 | def get_train_examples(self, data_dir): 214 | """See base class.""" 215 | lines = self._read_tsv( 216 | os.path.join(data_dir, "multinli", 217 | "multinli.train.%s.tsv" % self.language)) 218 | examples = [] 219 | for (i, line) in enumerate(lines): 220 | if i == 0: 221 | continue 222 | guid = "train-%d" % (i) 223 | text_a = tokenization.convert_to_unicode(line[0]) 224 | text_b = tokenization.convert_to_unicode(line[1]) 225 | label = tokenization.convert_to_unicode(line[2]) 226 | if label == tokenization.convert_to_unicode("contradictory"): 227 | label = tokenization.convert_to_unicode("contradiction") 228 | examples.append( 229 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 230 | return examples 231 | 232 | def get_dev_examples(self, data_dir): 233 | """See base class.""" 234 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) 235 | examples = [] 236 | for (i, line) in enumerate(lines): 237 | if i == 0: 238 | continue 239 | guid = "dev-%d" % (i) 240 | language = tokenization.convert_to_unicode(line[0]) 241 | if language != tokenization.convert_to_unicode(self.language): 242 | continue 243 | text_a = tokenization.convert_to_unicode(line[6]) 244 | text_b = tokenization.convert_to_unicode(line[7]) 245 | label = tokenization.convert_to_unicode(line[1]) 246 | examples.append( 247 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 248 | return examples 249 | 250 | def get_labels(self): 251 | """See base class.""" 252 | return ["contradiction", "entailment", "neutral"] 253 | 254 | 255 | class MnliProcessor(DataProcessor): 256 | """Processor for the MultiNLI data set (GLUE version).""" 257 | 258 | def get_train_examples(self, data_dir): 259 | """See base class.""" 260 | return self._create_examples( 261 | 
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 262 | 263 | def get_dev_examples(self, data_dir): 264 | """See base class.""" 265 | return self._create_examples( 266 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 267 | "dev_matched") 268 | 269 | def get_test_examples(self, data_dir): 270 | """See base class.""" 271 | return self._create_examples( 272 | self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") 273 | 274 | def get_labels(self): 275 | """See base class.""" 276 | return ["contradiction", "entailment", "neutral"] 277 | 278 | def _create_examples(self, lines, set_type): 279 | """Creates examples for the training and dev sets.""" 280 | examples = [] 281 | for (i, line) in enumerate(lines): 282 | if i == 0: 283 | continue 284 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 285 | text_a = tokenization.convert_to_unicode(line[8]) 286 | text_b = tokenization.convert_to_unicode(line[9]) 287 | if set_type == "test": 288 | label = "contradiction" 289 | else: 290 | label = tokenization.convert_to_unicode(line[-1]) 291 | examples.append( 292 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 293 | return examples 294 | 295 | 296 | class MrpcProcessor(DataProcessor): 297 | """Processor for the MRPC data set (GLUE version).""" 298 | 299 | def get_train_examples(self, data_dir): 300 | """See base class.""" 301 | return self._create_examples( 302 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 303 | 304 | def get_dev_examples(self, data_dir): 305 | """See base class.""" 306 | return self._create_examples( 307 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 308 | 309 | def get_test_examples(self, data_dir): 310 | """See base class.""" 311 | return self._create_examples( 312 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 313 | 314 | def get_labels(self): 315 | """See base class.""" 316 | return ["0", "1"] 317 | 318 | def _create_examples(self, lines, set_type): 319 | """Creates examples for the training and dev sets.""" 320 | examples = [] 321 | for (i, line) in enumerate(lines): 322 | if i == 0: 323 | continue 324 | guid = "%s-%s" % (set_type, i) 325 | text_a = tokenization.convert_to_unicode(line[3]) 326 | text_b = tokenization.convert_to_unicode(line[4]) 327 | if set_type == "test": 328 | label = "0" 329 | else: 330 | label = tokenization.convert_to_unicode(line[0]) 331 | examples.append( 332 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 333 | return examples 334 | 335 | 336 | class ColaProcessor(DataProcessor): 337 | """Processor for the CoLA data set (GLUE version).""" 338 | 339 | def get_train_examples(self, data_dir): 340 | """See base class.""" 341 | return self._create_examples( 342 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 343 | 344 | def get_dev_examples(self, data_dir): 345 | """See base class.""" 346 | return self._create_examples( 347 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 348 | 349 | def get_test_examples(self, data_dir): 350 | """See base class.""" 351 | return self._create_examples( 352 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 353 | 354 | def get_labels(self): 355 | """See base class.""" 356 | return ["0", "1"] 357 | 358 | def _create_examples(self, lines, set_type): 359 | """Creates examples for the training and dev sets.""" 360 | examples = [] 361 | for (i, line) in enumerate(lines): 362 | # Only the test set has a header 363 | if set_type == "test" and i == 0: 364 | 
continue 365 | guid = "%s-%s" % (set_type, i) 366 | if set_type == "test": 367 | text_a = tokenization.convert_to_unicode(line[1]) 368 | label = "0" 369 | else: 370 | text_a = tokenization.convert_to_unicode(line[3]) 371 | label = tokenization.convert_to_unicode(line[1]) 372 | examples.append( 373 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 374 | return examples 375 | 376 | 377 | def convert_single_example(ex_index, example, label_list, max_seq_length, 378 | tokenizer): 379 | """Converts a single `InputExample` into a single `InputFeatures`.""" 380 | 381 | if isinstance(example, PaddingInputExample): 382 | return InputFeatures( 383 | input_ids=[0] * max_seq_length, 384 | input_mask=[0] * max_seq_length, 385 | segment_ids=[0] * max_seq_length, 386 | label_id=0, 387 | is_real_example=False) 388 | 389 | label_map = {} 390 | for (i, label) in enumerate(label_list): 391 | label_map[label] = i 392 | 393 | tokens_a = tokenizer.tokenize(example.text_a) 394 | tokens_b = None 395 | if example.text_b: 396 | tokens_b = tokenizer.tokenize(example.text_b) 397 | 398 | if tokens_b: 399 | # Modifies `tokens_a` and `tokens_b` in place so that the total 400 | # length is less than the specified length. 401 | # Account for [CLS], [SEP], [SEP] with "- 3" 402 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 403 | else: 404 | # Account for [CLS] and [SEP] with "- 2" 405 | if len(tokens_a) > max_seq_length - 2: 406 | tokens_a = tokens_a[0:(max_seq_length - 2)] 407 | 408 | # The convention in BERT is: 409 | # (a) For sequence pairs: 410 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 411 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 412 | # (b) For single sequences: 413 | # tokens: [CLS] the dog is hairy . [SEP] 414 | # type_ids: 0 0 0 0 0 0 0 415 | # 416 | # Where "type_ids" are used to indicate whether this is the first 417 | # sequence or the second sequence. The embedding vectors for `type=0` and 418 | # `type=1` were learned during pre-training and are added to the wordpiece 419 | # embedding vector (and position vector). This is not *strictly* necessary 420 | # since the [SEP] token unambiguously separates the sequences, but it makes 421 | # it easier for the model to learn the concept of sequences. 422 | # 423 | # For classification tasks, the first vector (corresponding to [CLS]) is 424 | # used as the "sentence vector". Note that this only makes sense because 425 | # the entire model is fine-tuned. 426 | tokens = [] 427 | segment_ids = [] 428 | tokens.append("[CLS]") 429 | segment_ids.append(0) 430 | for token in tokens_a: 431 | tokens.append(token) 432 | segment_ids.append(0) 433 | tokens.append("[SEP]") 434 | segment_ids.append(0) 435 | 436 | if tokens_b: 437 | for token in tokens_b: 438 | tokens.append(token) 439 | segment_ids.append(1) 440 | tokens.append("[SEP]") 441 | segment_ids.append(1) 442 | 443 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 444 | 445 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 446 | # tokens are attended to. 447 | input_mask = [1] * len(input_ids) 448 | 449 | # Zero-pad up to the sequence length. 
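# A rough sketch of what the padding below produces, using the single-sequence example
# from the convention comment above and a hypothetical max_seq_length of 10 (the token
# ids are illustrative placeholders, not taken from the real vocab):
#
#   tokens:      [CLS] the  dog   is    hairy .     [SEP]
#   input_ids:   [101, 1996, 3899, 2003, 9891, 1012, 102,  0, 0, 0]
#   input_mask:  [1,   1,    1,    1,    1,    1,    1,    0, 0, 0]
#   segment_ids: [0,   0,    0,    0,    0,    0,    0,    0, 0, 0]
#
# Real tokens keep mask 1; the trailing zeros are padding and are never attended to.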
450 | while len(input_ids) < max_seq_length: 451 | input_ids.append(0) 452 | input_mask.append(0) 453 | segment_ids.append(0) 454 | 455 | assert len(input_ids) == max_seq_length 456 | assert len(input_mask) == max_seq_length 457 | assert len(segment_ids) == max_seq_length 458 | 459 | label_id = label_map[example.label] 460 | if ex_index < 5: 461 | tf.logging.info("*** Example ***") 462 | tf.logging.info("guid: %s" % (example.guid)) 463 | tf.logging.info("tokens: %s" % " ".join( 464 | [tokenization.printable_text(x) for x in tokens])) 465 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 466 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 467 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 468 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 469 | 470 | feature = InputFeatures( 471 | input_ids=input_ids, 472 | input_mask=input_mask, 473 | segment_ids=segment_ids, 474 | label_id=label_id, 475 | is_real_example=True) 476 | return feature 477 | 478 | 479 | def file_based_convert_examples_to_features( 480 | examples, label_list, max_seq_length, tokenizer, output_file): 481 | """Convert a set of `InputExample`s to a TFRecord file.""" 482 | 483 | writer = tf.python_io.TFRecordWriter(output_file) 484 | 485 | for (ex_index, example) in enumerate(examples): 486 | if ex_index % 10000 == 0: 487 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 488 | 489 | feature = convert_single_example(ex_index, example, label_list, 490 | max_seq_length, tokenizer) 491 | 492 | def create_int_feature(values): 493 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 494 | return f 495 | 496 | features = collections.OrderedDict() 497 | features["input_ids"] = create_int_feature(feature.input_ids) 498 | features["input_mask"] = create_int_feature(feature.input_mask) 499 | features["segment_ids"] = create_int_feature(feature.segment_ids) 500 | features["label_ids"] = create_int_feature([feature.label_id]) 501 | features["is_real_example"] = create_int_feature( 502 | [int(feature.is_real_example)]) 503 | 504 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 505 | writer.write(tf_example.SerializeToString()) 506 | writer.close() 507 | 508 | 509 | def file_based_input_fn_builder(input_file, seq_length, is_training, 510 | drop_remainder): 511 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 512 | 513 | name_to_features = { 514 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 515 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 516 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 517 | "label_ids": tf.FixedLenFeature([], tf.int64), 518 | "is_real_example": tf.FixedLenFeature([], tf.int64), 519 | } 520 | 521 | def _decode_record(record, name_to_features): 522 | """Decodes a record to a TensorFlow example.""" 523 | example = tf.parse_single_example(record, name_to_features) 524 | 525 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 526 | # So cast all int64 to int32. 527 | for name in list(example.keys()): 528 | t = example[name] 529 | if t.dtype == tf.int64: 530 | t = tf.to_int32(t) 531 | example[name] = t 532 | 533 | return example 534 | 535 | def input_fn(params): 536 | """The actual input function.""" 537 | batch_size = params["batch_size"] 538 | 539 | # For training, we want a lot of parallel reading and shuffling. 
540 | # For eval, we want no shuffling and parallel reading doesn't matter. 541 | d = tf.data.TFRecordDataset(input_file) 542 | if is_training: 543 | d = d.repeat() 544 | d = d.shuffle(buffer_size=100) 545 | 546 | d = d.apply( 547 | tf.contrib.data.map_and_batch( 548 | lambda record: _decode_record(record, name_to_features), 549 | batch_size=batch_size, 550 | drop_remainder=drop_remainder)) 551 | 552 | return d 553 | 554 | return input_fn 555 | 556 | 557 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 558 | """Truncates a sequence pair in place to the maximum length.""" 559 | 560 | # This is a simple heuristic which will always truncate the longer sequence 561 | # one token at a time. This makes more sense than truncating an equal percent 562 | # of tokens from each, since if one sequence is very short then each token 563 | # that's truncated likely contains more information than a longer sequence. 564 | while True: 565 | total_length = len(tokens_a) + len(tokens_b) 566 | if total_length <= max_length: 567 | break 568 | if len(tokens_a) > len(tokens_b): 569 | tokens_a.pop() 570 | else: 571 | tokens_b.pop() 572 | 573 | 574 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 575 | labels, num_labels, use_one_hot_embeddings): 576 | """Creates a classification model.""" 577 | model = modeling.BertModel( 578 | config=bert_config, 579 | is_training=is_training, 580 | input_ids=input_ids, 581 | input_mask=input_mask, 582 | token_type_ids=segment_ids, 583 | use_one_hot_embeddings=use_one_hot_embeddings) 584 | 585 | # In the demo, we are doing a simple classification task on the entire 586 | # segment. 587 | # 588 | # If you want to use the token-level output, use model.get_sequence_output() 589 | # instead. 590 | output_layer = model.get_pooled_output() 591 | 592 | hidden_size = output_layer.shape[-1].value 593 | 594 | output_weights = tf.get_variable( 595 | "output_weights", [num_labels, hidden_size], 596 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 597 | 598 | output_bias = tf.get_variable( 599 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 600 | 601 | with tf.variable_scope("loss"): 602 | if is_training: 603 | # I.e., 0.1 dropout 604 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 605 | 606 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 607 | logits = tf.nn.bias_add(logits, output_bias) 608 | probabilities = tf.nn.softmax(logits, axis=-1) 609 | log_probs = tf.nn.log_softmax(logits, axis=-1) 610 | 611 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 612 | 613 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 614 | loss = tf.reduce_mean(per_example_loss) 615 | 616 | return (loss, per_example_loss, logits, probabilities) 617 | 618 | 619 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 620 | num_train_steps, num_warmup_steps, use_tpu, 621 | use_one_hot_embeddings): 622 | """Returns `model_fn` closure for TPUEstimator.""" 623 | 624 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 625 | """The `model_fn` for TPUEstimator.""" 626 | 627 | tf.logging.info("*** Features ***") 628 | for name in sorted(features.keys()): 629 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 630 | 631 | input_ids = features["input_ids"] 632 | input_mask = features["input_mask"] 633 | segment_ids = features["segment_ids"] 634 | label_ids = features["label_ids"] 635 | is_real_example = 
None 636 | if "is_real_example" in features: 637 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 638 | else: 639 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 640 | 641 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 642 | 643 | (total_loss, per_example_loss, logits, probabilities) = create_model( 644 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 645 | num_labels, use_one_hot_embeddings) 646 | 647 | tvars = tf.trainable_variables() 648 | initialized_variable_names = {} 649 | scaffold_fn = None 650 | if init_checkpoint: 651 | (assignment_map, initialized_variable_names 652 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 653 | if use_tpu: 654 | 655 | def tpu_scaffold(): 656 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 657 | return tf.train.Scaffold() 658 | 659 | scaffold_fn = tpu_scaffold 660 | else: 661 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 662 | 663 | tf.logging.info("**** Trainable Variables ****") 664 | for var in tvars: 665 | init_string = "" 666 | if var.name in initialized_variable_names: 667 | init_string = ", *INIT_FROM_CKPT*" 668 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 669 | init_string) 670 | 671 | output_spec = None 672 | if mode == tf.estimator.ModeKeys.TRAIN: 673 | 674 | train_op = optimization.create_optimizer( 675 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 676 | 677 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 678 | mode=mode, 679 | loss=total_loss, 680 | train_op=train_op, 681 | scaffold_fn=scaffold_fn) 682 | elif mode == tf.estimator.ModeKeys.EVAL: 683 | 684 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 685 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 686 | accuracy = tf.metrics.accuracy( 687 | labels=label_ids, predictions=predictions, weights=is_real_example) 688 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 689 | return { 690 | "eval_accuracy": accuracy, 691 | "eval_loss": loss, 692 | } 693 | 694 | eval_metrics = (metric_fn, 695 | [per_example_loss, label_ids, logits, is_real_example]) 696 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 697 | mode=mode, 698 | loss=total_loss, 699 | eval_metrics=eval_metrics, 700 | scaffold_fn=scaffold_fn) 701 | else: 702 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 703 | mode=mode, 704 | predictions={"probabilities": probabilities}, 705 | scaffold_fn=scaffold_fn) 706 | return output_spec 707 | 708 | return model_fn 709 | 710 | 711 | # This function is not used by this file but is still used by the Colab and 712 | # people who depend on it. 713 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 714 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 715 | 716 | all_input_ids = [] 717 | all_input_mask = [] 718 | all_segment_ids = [] 719 | all_label_ids = [] 720 | 721 | for feature in features: 722 | all_input_ids.append(feature.input_ids) 723 | all_input_mask.append(feature.input_mask) 724 | all_segment_ids.append(feature.segment_ids) 725 | all_label_ids.append(feature.label_id) 726 | 727 | def input_fn(params): 728 | """The actual input function.""" 729 | batch_size = params["batch_size"] 730 | 731 | num_examples = len(features) 732 | 733 | # This is for demo purposes and does NOT scale to large data sets. 
We do 734 | # not use Dataset.from_generator() because that uses tf.py_func which is 735 | # not TPU compatible. The right way to load data is with TFRecordReader. 736 | d = tf.data.Dataset.from_tensor_slices({ 737 | "input_ids": 738 | tf.constant( 739 | all_input_ids, shape=[num_examples, seq_length], 740 | dtype=tf.int32), 741 | "input_mask": 742 | tf.constant( 743 | all_input_mask, 744 | shape=[num_examples, seq_length], 745 | dtype=tf.int32), 746 | "segment_ids": 747 | tf.constant( 748 | all_segment_ids, 749 | shape=[num_examples, seq_length], 750 | dtype=tf.int32), 751 | "label_ids": 752 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 753 | }) 754 | 755 | if is_training: 756 | d = d.repeat() 757 | d = d.shuffle(buffer_size=100) 758 | 759 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 760 | return d 761 | 762 | return input_fn 763 | 764 | 765 | # This function is not used by this file but is still used by the Colab and 766 | # people who depend on it. 767 | def convert_examples_to_features(examples, label_list, max_seq_length, 768 | tokenizer): 769 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 770 | 771 | features = [] 772 | for (ex_index, example) in enumerate(examples): 773 | if ex_index % 10000 == 0: 774 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 775 | 776 | feature = convert_single_example(ex_index, example, label_list, 777 | max_seq_length, tokenizer) 778 | 779 | features.append(feature) 780 | return features 781 | 782 | 783 | def main(_): 784 | tf.logging.set_verbosity(tf.logging.INFO) 785 | 786 | processors = { 787 | "cola": ColaProcessor, 788 | "mnli": MnliProcessor, 789 | "mrpc": MrpcProcessor, 790 | "xnli": XnliProcessor, 791 | } 792 | 793 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 794 | FLAGS.init_checkpoint) 795 | 796 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 797 | raise ValueError( 798 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 799 | 800 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 801 | 802 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 803 | raise ValueError( 804 | "Cannot use sequence length %d because the BERT model " 805 | "was only trained up to sequence length %d" % 806 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 807 | 808 | tf.gfile.MakeDirs(FLAGS.output_dir) 809 | 810 | task_name = FLAGS.task_name.lower() 811 | 812 | if task_name not in processors: 813 | raise ValueError("Task not found: %s" % (task_name)) 814 | 815 | processor = processors[task_name]() 816 | 817 | label_list = processor.get_labels() 818 | 819 | tokenizer = tokenization.FullTokenizer( 820 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 821 | 822 | tpu_cluster_resolver = None 823 | if FLAGS.use_tpu and FLAGS.tpu_name: 824 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 825 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 826 | 827 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 828 | run_config = tf.contrib.tpu.RunConfig( 829 | cluster=tpu_cluster_resolver, 830 | master=FLAGS.master, 831 | model_dir=FLAGS.output_dir, 832 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 833 | tpu_config=tf.contrib.tpu.TPUConfig( 834 | iterations_per_loop=FLAGS.iterations_per_loop, 835 | num_shards=FLAGS.num_tpu_cores, 836 | per_host_input_for_training=is_per_host)) 837 | 838 | train_examples 
= None 839 | num_train_steps = None 840 | num_warmup_steps = None 841 | if FLAGS.do_train: 842 | train_examples = processor.get_train_examples(FLAGS.data_dir) 843 | num_train_steps = int( 844 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 845 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 846 | 847 | model_fn = model_fn_builder( 848 | bert_config=bert_config, 849 | num_labels=len(label_list), 850 | init_checkpoint=FLAGS.init_checkpoint, 851 | learning_rate=FLAGS.learning_rate, 852 | num_train_steps=num_train_steps, 853 | num_warmup_steps=num_warmup_steps, 854 | use_tpu=FLAGS.use_tpu, 855 | use_one_hot_embeddings=FLAGS.use_tpu) 856 | 857 | # If TPU is not available, this will fall back to normal Estimator on CPU 858 | # or GPU. 859 | estimator = tf.contrib.tpu.TPUEstimator( 860 | use_tpu=FLAGS.use_tpu, 861 | model_fn=model_fn, 862 | config=run_config, 863 | train_batch_size=FLAGS.train_batch_size, 864 | eval_batch_size=FLAGS.eval_batch_size, 865 | predict_batch_size=FLAGS.predict_batch_size) 866 | 867 | if FLAGS.do_train: 868 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 869 | file_based_convert_examples_to_features( 870 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 871 | tf.logging.info("***** Running training *****") 872 | tf.logging.info(" Num examples = %d", len(train_examples)) 873 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 874 | tf.logging.info(" Num steps = %d", num_train_steps) 875 | train_input_fn = file_based_input_fn_builder( 876 | input_file=train_file, 877 | seq_length=FLAGS.max_seq_length, 878 | is_training=True, 879 | drop_remainder=True) 880 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 881 | 882 | if FLAGS.do_eval: 883 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 884 | num_actual_eval_examples = len(eval_examples) 885 | if FLAGS.use_tpu: 886 | # TPU requires a fixed batch size for all batches, therefore the number 887 | # of examples must be a multiple of the batch size, or else examples 888 | # will get dropped. So we pad with fake examples which are ignored 889 | # later on. These do NOT count towards the metric (all tf.metrics 890 | # support a per-instance weight, and these get a weight of 0.0). 891 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 892 | eval_examples.append(PaddingInputExample()) 893 | 894 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 895 | file_based_convert_examples_to_features( 896 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 897 | 898 | tf.logging.info("***** Running evaluation *****") 899 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 900 | len(eval_examples), num_actual_eval_examples, 901 | len(eval_examples) - num_actual_eval_examples) 902 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 903 | 904 | # This tells the estimator to run through the entire set. 905 | eval_steps = None 906 | # However, if running eval on the TPU, you will need to specify the 907 | # number of steps. 
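# A worked example of the padding/step arithmetic used below (sizes are hypothetical):
# with eval_batch_size = 8 and 1043 real dev examples, 5 PaddingInputExamples were
# appended above so that 1048 % 8 == 0, giving eval_steps = 1048 // 8 = 131. The padded
# examples carry is_real_example = 0 and therefore receive weight 0.0 in metric_fn.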
908 | if FLAGS.use_tpu: 909 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 910 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 911 | 912 | eval_drop_remainder = True if FLAGS.use_tpu else False 913 | eval_input_fn = file_based_input_fn_builder( 914 | input_file=eval_file, 915 | seq_length=FLAGS.max_seq_length, 916 | is_training=False, 917 | drop_remainder=eval_drop_remainder) 918 | 919 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 920 | 921 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 922 | with tf.gfile.GFile(output_eval_file, "w") as writer: 923 | tf.logging.info("***** Eval results *****") 924 | for key in sorted(result.keys()): 925 | tf.logging.info(" %s = %s", key, str(result[key])) 926 | writer.write("%s = %s\n" % (key, str(result[key]))) 927 | 928 | if FLAGS.do_predict: 929 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 930 | num_actual_predict_examples = len(predict_examples) 931 | if FLAGS.use_tpu: 932 | # TPU requires a fixed batch size for all batches, therefore the number 933 | # of examples must be a multiple of the batch size, or else examples 934 | # will get dropped. So we pad with fake examples which are ignored 935 | # later on. 936 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 937 | predict_examples.append(PaddingInputExample()) 938 | 939 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 940 | file_based_convert_examples_to_features(predict_examples, label_list, 941 | FLAGS.max_seq_length, tokenizer, 942 | predict_file) 943 | 944 | tf.logging.info("***** Running prediction*****") 945 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 946 | len(predict_examples), num_actual_predict_examples, 947 | len(predict_examples) - num_actual_predict_examples) 948 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 949 | 950 | predict_drop_remainder = True if FLAGS.use_tpu else False 951 | predict_input_fn = file_based_input_fn_builder( 952 | input_file=predict_file, 953 | seq_length=FLAGS.max_seq_length, 954 | is_training=False, 955 | drop_remainder=predict_drop_remainder) 956 | 957 | result = estimator.predict(input_fn=predict_input_fn) 958 | 959 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 960 | with tf.gfile.GFile(output_predict_file, "w") as writer: 961 | num_written_lines = 0 962 | tf.logging.info("***** Predict results *****") 963 | for (i, prediction) in enumerate(result): 964 | probabilities = prediction["probabilities"] 965 | if i >= num_actual_predict_examples: 966 | break 967 | output_line = "\t".join( 968 | str(class_probability) 969 | for class_probability in probabilities) + "\n" 970 | writer.write(output_line) 971 | num_written_lines += 1 972 | assert num_written_lines == num_actual_predict_examples 973 | 974 | 975 | if __name__ == "__main__": 976 | flags.mark_flag_as_required("data_dir") 977 | flags.mark_flag_as_required("task_name") 978 | flags.mark_flag_as_required("vocab_file") 979 | flags.mark_flag_as_required("bert_config_file") 980 | flags.mark_flag_as_required("output_dir") 981 | tf.app.run() 982 | -------------------------------------------------------------------------------- /run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | scaffold_fn = None 154 | if init_checkpoint: 155 | (assignment_map, initialized_variable_names 156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 157 | if use_tpu: 158 | 159 | def tpu_scaffold(): 160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 161 | return tf.train.Scaffold() 162 | 163 | scaffold_fn = tpu_scaffold 164 | else: 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | 167 | tf.logging.info("**** Trainable Variables ****") 168 | for var in tvars: 169 | init_string = "" 170 | if var.name in initialized_variable_names: 171 | init_string = ", *INIT_FROM_CKPT*" 172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 173 | init_string) 174 | 175 | output_spec = None 176 | if mode == tf.estimator.ModeKeys.TRAIN: 177 | train_op = optimization.create_optimizer( 178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 179 | 180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | train_op=train_op, 184 | scaffold_fn=scaffold_fn) 185 | elif mode == tf.estimator.ModeKeys.EVAL: 186 | 187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, 
masked_lm_ids, 188 | masked_lm_weights, next_sentence_example_loss, 189 | next_sentence_log_probs, next_sentence_labels): 190 | """Computes the loss and accuracy of the model.""" 191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 192 | [-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax( 194 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 197 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 198 | masked_lm_accuracy = tf.metrics.accuracy( 199 | labels=masked_lm_ids, 200 | predictions=masked_lm_predictions, 201 | weights=masked_lm_weights) 202 | masked_lm_mean_loss = tf.metrics.mean( 203 | values=masked_lm_example_loss, weights=masked_lm_weights) 204 | 205 | next_sentence_log_probs = tf.reshape( 206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 207 | next_sentence_predictions = tf.argmax( 208 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 210 | next_sentence_accuracy = tf.metrics.accuracy( 211 | labels=next_sentence_labels, predictions=next_sentence_predictions) 212 | next_sentence_mean_loss = tf.metrics.mean( 213 | values=next_sentence_example_loss) 214 | 215 | return { 216 | "masked_lm_accuracy": masked_lm_accuracy, 217 | "masked_lm_loss": masked_lm_mean_loss, 218 | "next_sentence_accuracy": next_sentence_accuracy, 219 | "next_sentence_loss": next_sentence_mean_loss, 220 | } 221 | 222 | eval_metrics = (metric_fn, [ 223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 224 | masked_lm_weights, next_sentence_example_loss, 225 | next_sentence_log_probs, next_sentence_labels 226 | ]) 227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 228 | mode=mode, 229 | loss=total_loss, 230 | eval_metrics=eval_metrics, 231 | scaffold_fn=scaffold_fn) 232 | else: 233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 234 | 235 | return output_spec 236 | 237 | return model_fn 238 | 239 | 240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 241 | label_ids, label_weights): 242 | """Get loss and log probs for the masked LM.""" 243 | input_tensor = gather_indexes(input_tensor, positions) 244 | 245 | with tf.variable_scope("cls/predictions"): 246 | # We apply one more non-linear transformation before the output layer. 247 | # This matrix is not used after pre-training. 248 | with tf.variable_scope("transform"): 249 | input_tensor = tf.layers.dense( 250 | input_tensor, 251 | units=bert_config.hidden_size, 252 | activation=modeling.get_activation(bert_config.hidden_act), 253 | kernel_initializer=modeling.create_initializer( 254 | bert_config.initializer_range)) 255 | input_tensor = modeling.layer_norm(input_tensor) 256 | 257 | # The output weights are the same as the input embeddings, but there is 258 | # an output-only bias for each token. 
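# Shape sketch for the tied-embedding projection a few lines below (symbolic, derived
# from the surrounding code): `input_tensor`, gathered at the masked positions, is
# [batch_size * max_predictions_per_seq, hidden_size]; `output_weights` is the word
# embedding table of shape [vocab_size, hidden_size]; the matmul with transpose_b=True
# therefore yields logits of shape [batch_size * max_predictions_per_seq, vocab_size],
# i.e. one distribution over the vocabulary per masked position.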
259 | output_bias = tf.get_variable( 260 | "output_bias", 261 | shape=[bert_config.vocab_size], 262 | initializer=tf.zeros_initializer()) 263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 264 | logits = tf.nn.bias_add(logits, output_bias) 265 | log_probs = tf.nn.log_softmax(logits, axis=-1) 266 | 267 | label_ids = tf.reshape(label_ids, [-1]) 268 | label_weights = tf.reshape(label_weights, [-1]) 269 | 270 | one_hot_labels = tf.one_hot( 271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 272 | 273 | # The `positions` tensor might be zero-padded (if the sequence is too 274 | # short to have the maximum number of predictions). The `label_weights` 275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 276 | # padding predictions. 277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 278 | numerator = tf.reduce_sum(label_weights * per_example_loss) 279 | denominator = tf.reduce_sum(label_weights) + 1e-5 280 | loss = numerator / denominator 281 | 282 | return (loss, per_example_loss, log_probs) 283 | 284 | 285 | def get_next_sentence_output(bert_config, input_tensor, labels): 286 | """Get loss and log probs for the next sentence prediction.""" 287 | 288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 289 | # "random sentence". This weight matrix is not used after pre-training. 290 | with tf.variable_scope("cls/seq_relationship"): 291 | output_weights = tf.get_variable( 292 | "output_weights", 293 | shape=[2, bert_config.hidden_size], 294 | initializer=modeling.create_initializer(bert_config.initializer_range)) 295 | output_bias = tf.get_variable( 296 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 297 | 298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 299 | logits = tf.nn.bias_add(logits, output_bias) 300 | log_probs = tf.nn.log_softmax(logits, axis=-1) 301 | labels = tf.reshape(labels, [-1]) 302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 304 | loss = tf.reduce_mean(per_example_loss) 305 | return (loss, per_example_loss, log_probs) 306 | 307 | 308 | def gather_indexes(sequence_tensor, positions): 309 | """Gathers the vectors at the specific positions over a minibatch.""" 310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 311 | batch_size = sequence_shape[0] 312 | seq_length = sequence_shape[1] 313 | width = sequence_shape[2] 314 | 315 | flat_offsets = tf.reshape( 316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 317 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 318 | flat_sequence_tensor = tf.reshape(sequence_tensor, 319 | [batch_size * seq_length, width]) 320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 321 | return output_tensor 322 | 323 | 324 | def input_fn_builder(input_files, 325 | max_seq_length, 326 | max_predictions_per_seq, 327 | is_training, 328 | num_cpu_threads=4): 329 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 330 | 331 | def input_fn(params): 332 | """The actual input function.""" 333 | batch_size = params["batch_size"] 334 | 335 | name_to_features = { 336 | "input_ids": 337 | tf.FixedLenFeature([max_seq_length], tf.int64), 338 | "input_mask": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "segment_ids": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "masked_lm_positions": 343 | 
tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 344 | "masked_lm_ids": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_weights": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 348 | "next_sentence_labels": 349 | tf.FixedLenFeature([1], tf.int64), 350 | } 351 | 352 | # For training, we want a lot of parallel reading and shuffling. 353 | # For eval, we want no shuffling and parallel reading doesn't matter. 354 | if is_training: 355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 356 | d = d.repeat() 357 | d = d.shuffle(buffer_size=len(input_files)) 358 | 359 | # `cycle_length` is the number of parallel files that get read. 360 | cycle_length = min(num_cpu_threads, len(input_files)) 361 | 362 | # `sloppy` mode means that the interleaving is not exact. This adds 363 | # even more randomness to the training pipeline. 364 | d = d.apply( 365 | tf.contrib.data.parallel_interleave( 366 | tf.data.TFRecordDataset, 367 | sloppy=is_training, 368 | cycle_length=cycle_length)) 369 | d = d.shuffle(buffer_size=100) 370 | else: 371 | d = tf.data.TFRecordDataset(input_files) 372 | # Since we evaluate for a fixed number of steps we don't want to encounter 373 | # out-of-range exceptions. 374 | d = d.repeat() 375 | 376 | # We must `drop_remainder` on training because the TPU requires fixed 377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 378 | # and we *don't* want to drop the remainder, otherwise we wont cover 379 | # every sample. 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size, 384 | num_parallel_batches=num_cpu_threads, 385 | drop_remainder=True)) 386 | return d 387 | 388 | return input_fn 389 | 390 | 391 | def _decode_record(record, name_to_features): 392 | """Decodes a record to a TensorFlow example.""" 393 | example = tf.parse_single_example(record, name_to_features) 394 | 395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 396 | # So cast all int64 to int32. 
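# After the cast below, a decoded example is a dict of fixed-shape tensors matching the
# feature spec above: input_ids / input_mask / segment_ids are int32 of shape
# [max_seq_length]; masked_lm_positions / masked_lm_ids are int32 of shape
# [max_predictions_per_seq]; masked_lm_weights stays float32 with the same shape; and
# next_sentence_labels is int32 of shape [1].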
397 | for name in list(example.keys()): 398 | t = example[name] 399 | if t.dtype == tf.int64: 400 | t = tf.to_int32(t) 401 | example[name] = t 402 | 403 | return example 404 | 405 | 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | if not FLAGS.do_train and not FLAGS.do_eval: 410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 411 | 412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 413 | 414 | tf.gfile.MakeDirs(FLAGS.output_dir) 415 | 416 | input_files = [] 417 | for input_pattern in FLAGS.input_file.split(","): 418 | input_files.extend(tf.gfile.Glob(input_pattern)) 419 | 420 | tf.logging.info("*** Input Files ***") 421 | for input_file in input_files: 422 | tf.logging.info(" %s" % input_file) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | model_fn = model_fn_builder( 441 | bert_config=bert_config, 442 | init_checkpoint=FLAGS.init_checkpoint, 443 | learning_rate=FLAGS.learning_rate, 444 | num_train_steps=FLAGS.num_train_steps, 445 | num_warmup_steps=FLAGS.num_warmup_steps, 446 | use_tpu=FLAGS.use_tpu, 447 | use_one_hot_embeddings=FLAGS.use_tpu) 448 | 449 | # If TPU is not available, this will fall back to normal Estimator on CPU 450 | # or GPU. 
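# Note (sketch of the intended fallback behavior): even with use_tpu=False, TPUEstimator
# still fills in params["batch_size"] (train_batch_size for training, eval_batch_size for
# eval), which is how the input_fn returned by input_fn_builder above obtains its batch
# size when running on CPU/GPU.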
451 | estimator = tf.contrib.tpu.TPUEstimator( 452 | use_tpu=FLAGS.use_tpu, 453 | model_fn=model_fn, 454 | config=run_config, 455 | train_batch_size=FLAGS.train_batch_size, 456 | eval_batch_size=FLAGS.eval_batch_size) 457 | 458 | if FLAGS.do_train: 459 | tf.logging.info("***** Running training *****") 460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 461 | train_input_fn = input_fn_builder( 462 | input_files=input_files, 463 | max_seq_length=FLAGS.max_seq_length, 464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 465 | is_training=True) 466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 467 | 468 | if FLAGS.do_eval: 469 | tf.logging.info("***** Running evaluation *****") 470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 471 | 472 | eval_input_fn = input_fn_builder( 473 | input_files=input_files, 474 | max_seq_length=FLAGS.max_seq_length, 475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 476 | is_training=False) 477 | 478 | result = estimator.evaluate( 479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 480 | 481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 482 | with tf.gfile.GFile(output_eval_file, "w") as writer: 483 | tf.logging.info("***** Eval results *****") 484 | for key in sorted(result.keys()): 485 | tf.logging.info(" %s = %s", key, str(result[key])) 486 | writer.write("%s = %s\n" % (key, str(result[key]))) 487 | 488 | 489 | if __name__ == "__main__": 490 | flags.mark_flag_as_required("input_file") 491 | flags.mark_flag_as_required("bert_config_file") 492 | flags.mark_flag_as_required("output_dir") 493 | tf.app.run() 494 | -------------------------------------------------------------------------------- /squad_code/data_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | 5 | def make_batch(size, batch_size): 6 | nb_batch = int(np.ceil(size/float(batch_size))) 7 | return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)] # zgwang: starting point of each batch 8 | 9 | def pad_2d(in_vals, dim1_size, dim2_size, dtype=np.int32): 10 | out_val = np.zeros((dim1_size, dim2_size), dtype=dtype) 11 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 12 | for i in range(dim1_size): 13 | cur_in_vals = in_vals[i] 14 | cur_dim2_size = dim2_size 15 | if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals) 16 | out_val[i,:cur_dim2_size] = cur_in_vals[:cur_dim2_size] 17 | return out_val 18 | 19 | def pad_3d(in_vals, dim1_size, dim2_size, dim3_size, dtype=np.int32): 20 | out_val = np.zeros((dim1_size, dim2_size, dim3_size), dtype=dtype) 21 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 22 | for i in range(dim1_size): 23 | in_vals_i = in_vals[i] 24 | cur_dim2_size = dim2_size 25 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i) 26 | for j in range(cur_dim2_size): 27 | in_vals_ij = in_vals_i[j] 28 | cur_dim3_size = dim3_size 29 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij) 30 | out_val[i, j, :cur_dim3_size] = in_vals_ij[:cur_dim3_size] 31 | return out_val 32 | 33 | 34 | class OneBatch(object): 35 | def __init__(self, current_batch, config): 36 | self.target_batch1 = [] 37 | self.target_batch2 = [] 38 | self.input_batch = [] 39 | self.type_batch = [] 40 | self.header_len_batch = [] 41 | self.context_len_batch = [] 42 | self.question_len_batch = [] 43 | self.input_mask_batch = [] 44 | # self.answer_batch = [] 45 | for 
(answer_start,answer_end,input_id,type_id, 46 | header_len,context_len,question_len,input_mask) in current_batch: 47 | self.input_batch.append(input_id) 48 | self.target_batch1.append(answer_start) 49 | self.target_batch2.append(answer_end) 50 | self.type_batch.append(type_id) 51 | self.header_len_batch.append(header_len) 52 | self.context_len_batch.append(context_len) 53 | self.question_len_batch.append(question_len) 54 | self.input_mask_batch.append(input_mask) 55 | """ 56 | to numpy 57 | """ 58 | # self.all_cell_batch=np.array(self.all_cell_batch, dtype=np.int32) 59 | # self.question_batch = np.array(self.question_batch, dtype=np.int32) 60 | # self.target_batch = np.array(self.target_batch, dtype=np.int32) 61 | """ 62 | padding 63 | """ 64 | # self.all_cell_batch=pad_3d(self.all_cell_batch,dim1_size=config["batch_size"], 65 | # dim2_size=config["max_cell_num"],dim3_size=config["max_word_num"]) 66 | # self.question_batch=pad_2d(self.question_batch,dim1_size=config["batch_size"], 67 | # dim2_size=config["max_question_len"]) 68 | # self.target_batch = pad_2d(self.target_batch, dim1_size=config["batch_size"], 69 | # dim2_size=config["max_cell_num"]+config["max_question_len"]) 70 | # self.type_batch = np.array(self.type_batch) 71 | 72 | 73 | class DataUtil_bert(object): 74 | def __init__(self, json_path=None,config=None,tokenizer=None): 75 | 76 | f = open(json_path, mode="r", encoding="utf-8") 77 | jdata = json.load(f) 78 | all_case_list = [] 79 | count = 0 80 | for context_qas in jdata["data"][0]["paragraphs"]: 81 | count += 1 82 | if config["debug"]==True and count>100: 83 | break 84 | if count % 1000==0: 85 | print("data process ", count) 86 | context = [] 87 | for row in context_qas["context"]: 88 | context.extend(row) 89 | context_id = [] 90 | context_id.append(tokenizer.convert_tokens_to_ids(["[CLS]"])) 91 | for cell in context: 92 | context_id.append(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(cell))) 93 | context_id.append(tokenizer.convert_tokens_to_ids(["[SEP]"])) 94 | 95 | for qas in context_qas["qas"]: 96 | question = tokenizer.tokenize(qas["question"]) 97 | question = question + ["[SEP]"] 98 | question_id = tokenizer.convert_tokens_to_ids(question) 99 | # answer_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(qas["answers"][0]["text"])) 100 | 101 | flat_context_id = [] 102 | word_index = 0 103 | answer_start = [0] * config["max_input_len"] 104 | answer_end = [0] * config["max_input_len"] 105 | 106 | for cell_index,cell in enumerate(context_id): 107 | if word_index > 200: 108 | break 109 | if cell_index == qas["answers"][0]["answer_start"] + 1: 110 | answer_start[word_index]=1 111 | if cell_index == qas["answers"][0]["answer_start"] + 2: 112 | answer_end[word_index-1]=1 # include or not 113 | for word_id in cell: 114 | flat_context_id.append(word_id) 115 | word_index += 1 116 | 117 | header_len = len(context_qas["context"][0]) + 1 118 | context_len = len(flat_context_id) 119 | question_len = len(question_id) 120 | 121 | input_mask = (context_len + question_len)*[1] + (config["max_input_len"]-context_len-question_len)*[0] 122 | 123 | type_id = context_len*[0] + question_len*[1] + (config["max_input_len"]-context_len-question_len)*[0] 124 | 125 | input_id = flat_context_id + question_id + (config["max_input_len"]-context_len-question_len)*[0] 126 | 127 | if len(input_id) > config["max_input_len"]: 128 | continue 129 | 130 | all_case_list.append((answer_start,answer_end,input_id,type_id, 131 | header_len,context_len,question_len,input_mask)) 132 | print("data num", 
len(all_case_list)) 133 | batch_spans = make_batch(len(all_case_list), config["batch_size"] ) 134 | self.all_batch = [] 135 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 136 | current_batch = [] 137 | for i in range(batch_start, batch_end): 138 | current_batch.append(all_case_list[i]) 139 | if len(current_batch) < config["batch_size"]: 140 | continue 141 | self.all_batch.append(OneBatch(current_batch, config)) 142 | 143 | self.index_array = np.arange(len(self.all_batch)) 144 | self.cur_pointer = 0 145 | 146 | def next_batch(self): 147 | if self.cur_pointer >= len(self.all_batch): 148 | self.cur_pointer = 0 149 | np.random.shuffle(self.index_array) 150 | cur_batch = self.all_batch[self.index_array[self.cur_pointer]] 151 | self.cur_pointer += 1 152 | return cur_batch 153 | 154 | def get_batch(self, i): 155 | if i >= len(self.all_batch): return None 156 | return self.all_batch[self.index_array[i]] -------------------------------------------------------------------------------- /squad_code/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import modeling 3 | 4 | class ModelWrapper(): 5 | def __init__(self, config, is_train): 6 | self.config = config 7 | self.header_len_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"]]) 8 | self.context_len_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"]]) 9 | self.question_len_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"]]) 10 | self.input_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"], 11 | config["max_input_len"]]) 12 | self.input_mask_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"], 13 | config["max_input_len"]]) 14 | self.target1_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"], 15 | config["max_input_len"]]) 16 | self.target2_placeholder = tf.placeholder(dtype=tf.int32, shape=[config["batch_size"], 17 | config["max_input_len"]]) 18 | 19 | self.type_placeholder = tf.placeholder(dtype=tf.int32,shape=[config["batch_size"], 20 | config["max_input_len"]]) 21 | main_model_config = modeling.BertConfig.from_json_file("../uncased_L-12_H-768_A-12/bert_config.json") 22 | 23 | main_model = modeling.BertModel(main_model_config, 24 | input_ids=self.input_placeholder, 25 | token_type_ids=self.type_placeholder, 26 | input_mask=self.input_mask_placeholder, 27 | is_training=is_train, 28 | use_one_hot_embeddings=False) 29 | 30 | final_hidden = main_model.get_sequence_output()#[:,:config["max_cell_num"],:] 31 | 32 | batch_size = config["batch_size"] 33 | seq_length = config["max_input_len"] 34 | hidden_size = main_model_config.hidden_size 35 | 36 | output_weights = tf.get_variable( 37 | "cls/squad/output_weights", [2, hidden_size], 38 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 39 | 40 | output_bias = tf.get_variable( 41 | "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()) 42 | 43 | final_hidden_matrix = tf.reshape(final_hidden, 44 | [batch_size * seq_length, hidden_size]) 45 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) 46 | logits = tf.nn.bias_add(logits, output_bias) 47 | logits = tf.reshape(logits, [batch_size, seq_length, 2]) 48 | logits = tf.transpose(logits, [2, 0, 1]) 49 | unstacked_logits = tf.unstack(logits, axis=0) 50 | (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1]) 51 | 52 | self.start_logits = tf.reshape(start_logits, [batch_size, seq_length]) 53 | self.end_logits = tf.reshape(end_logits, [batch_size, seq_length]) 54 | 55 | self.context_mask = tf.sequence_mask(self.context_len_placeholder, maxlen=config["max_input_len"], dtype=tf.float32) 56 | self.header_mask = 
tf.sequence_mask(self.header_len_placeholder,maxlen=config["max_input_len"],dtype=tf.float32) 57 | self.target_mask = self.context_mask - self.header_mask 58 | 59 | self.output_logits1 = self.start_logits * self.target_mask 60 | self.output_logits2 = self.end_logits * self.target_mask 61 | 62 | log_probs1 = tf.nn.log_softmax(self.start_logits*self.target_mask, axis=-1) 63 | self.loss = -tf.reduce_mean( 64 | tf.reduce_sum(tf.cast(self.target1_placeholder, tf.float32) * log_probs1*self.target_mask, axis=-1)) 65 | 66 | log_probs2 = tf.nn.log_softmax(self.end_logits*self.target_mask, axis=-1) 67 | self.loss += -tf.reduce_mean( 68 | tf.reduce_sum(tf.cast(self.target2_placeholder, tf.float32) * log_probs2*self.target_mask, axis=-1)) 69 | 70 | if not is_train: return 71 | optimizer = tf.train.AdamOptimizer(learning_rate=config["learning_rate"]) 72 | 73 | # tvars = tf.trainable_variables() 74 | # l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1]) 75 | # self.loss = self.loss + 0.01 * l2_loss 76 | # self.train_op = optimizer.minimize(self.loss) 77 | 78 | def var_filter(var_list, last_layers): 79 | filter_keywords = ['layer_11', 'layer_10', 'layer_9', 'layer_8'] 80 | for var in var_list: 81 | if "bert" not in var.name: 82 | yield var 83 | else: 84 | for layer in last_layers: 85 | kw = filter_keywords[layer] 86 | if kw in var.name: 87 | yield var 88 | 89 | def compute_gradients(tensor, var_list): 90 | grads = tf.gradients(tensor, var_list) 91 | return [grad if grad is not None else tf.zeros_like(var) for var, grad in zip(var_list, grads)] 92 | 93 | tvars = list(var_filter(tf.trainable_variables(), last_layers=range(3))) 94 | grads = compute_gradients(self.loss, tvars) 95 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 96 | 97 | -------------------------------------------------------------------------------- /squad_code/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import time 3 | import text.data_util as data_util 4 | import text.model_wrapper as model_wrapper 5 | import numpy as np 6 | import tokenization 7 | config = {} 8 | config["batch_size"] = 16 9 | config["max_input_len"] = 256 10 | # config["max_word_num"] = 10 # max word num in a cell 11 | # config["max_question_len"] = 30 12 | config["debug"] = False 13 | config["learning_rate"] = 5e-5 14 | 15 | tokenizer = tokenization.FullTokenizer(vocab_file="../uncased_L-12_H-768_A-12/vocab.txt") 16 | 17 | valid_data_util = data_util.DataUtil_bert(json_path="../squad_data/my_valid.json",config=config,tokenizer=tokenizer) 18 | train_data_util = data_util.DataUtil_bert(json_path="../squad_data/my_train.json",config=config,tokenizer=tokenizer) 19 | 20 | with tf.variable_scope("Model", reuse=False): 21 | train_model = model_wrapper.ModelWrapper(config, is_train = True) 22 | 23 | with tf.variable_scope("Model", reuse=True): 24 | valid_model = model_wrapper.ModelWrapper(config, is_train = False) 25 | 26 | 27 | sess = tf.Session() 28 | saver = tf.train.Saver() 29 | sess.run(tf.global_variables_initializer()) 30 | # import os 31 | # if os.path.exists("save_model"): 32 | # saver.restore(sess,"save_model/model") 33 | 34 | def optimistic_restore(session, save_file): 35 | """ 36 | restore only those variable that exists in the model 37 | :param session: 38 | :param save_file: 39 | :return: 40 | """ 41 | reader = tf.train.NewCheckpointReader(save_file) 42 | # reader.get_tensor() 43 | saved_shapes = reader.get_variable_to_shape_map() 44 | 
print(saved_shapes) 45 | print() 46 | print([var.name for var in tf.global_variables()]) 47 | 48 | restore_vars = { (v.name.split(':')[0].replace("Model/","").replace("Model/","")): v for v in tf.trainable_variables() if 'bert' in v.name} 49 | saver = tf.train.Saver(restore_vars) 50 | saver.restore(session, save_file) 51 | 52 | optimistic_restore(sess, "../uncased_L-12_H-768_A-12/bert_model.ckpt") 53 | 54 | correct_count = 0 55 | total_count = 0 56 | max_acc = 0 57 | best_accuracy = -1 58 | for epoch in range(50): 59 | print('Train in epoch %d' % epoch) 60 | num_batch = len(train_data_util.all_batch) 61 | start_time = time.time() 62 | total_loss = 0 63 | for batch_index in range(num_batch): # for each batch 64 | cur_batch = train_data_util.next_batch() 65 | train_feed_dict = { 66 | train_model.input_placeholder:cur_batch.input_batch, 67 | train_model.target2_placeholder:cur_batch.target_batch2, 68 | train_model.target1_placeholder:cur_batch.target_batch1, 69 | train_model.type_placeholder:cur_batch.type_batch, 70 | train_model.input_mask_placeholder:cur_batch.input_mask_batch, 71 | train_model.context_len_placeholder:cur_batch.context_len_batch, 72 | train_model.question_len_placeholder:cur_batch.question_len_batch, 73 | train_model.header_len_placeholder:cur_batch.header_len_batch, 74 | } 75 | _, loss_value,start_predict_logits,end_predict_logits = sess.run([train_model.train_op, 76 | train_model.loss, 77 | train_model.output_logits1, 78 | train_model.output_logits2], 79 | feed_dict=train_feed_dict) 80 | total_loss += loss_value 81 | for i in range(len(start_predict_logits)): 82 | predict_start_index = np.argmax(start_predict_logits[i]) 83 | true_start_index = np.argmax(cur_batch.target_batch1[i]) 84 | 85 | predict_end_index = np.argmax(end_predict_logits[i]) 86 | true_end_index = np.argmax(cur_batch.target_batch2[i]) 87 | 88 | if predict_start_index == true_start_index\ 89 | and predict_end_index == true_end_index \ 90 | or predict_start_index < config["max_input_len"] and \ 91 | (cur_batch.input_batch[i][predict_start_index:predict_end_index+1]== 92 | cur_batch.input_batch[i][true_start_index:true_end_index+1]): 93 | correct_count += 1 94 | total_count += 1 95 | if total_count!=0: 96 | print("train acc" , (correct_count / total_count)) 97 | duration = time.time() - start_time 98 | start_time = time.time() 99 | print('train loss = %.4f (%.3f sec)' % (total_loss / num_batch, duration)) 100 | 101 | correct_count = 0 102 | total_count = 0 103 | total_loss = 0 104 | 105 | num_valid_batch = len(valid_data_util.all_batch) 106 | for valid_batch_index in range(num_valid_batch): 107 | cur_valid_batch = valid_data_util.next_batch() 108 | valid_feed_dict = { 109 | valid_model.input_placeholder: cur_valid_batch.input_batch, 110 | valid_model.target2_placeholder: cur_valid_batch.target_batch2, 111 | valid_model.target1_placeholder: cur_valid_batch.target_batch1, 112 | valid_model.type_placeholder: cur_valid_batch.type_batch, 113 | valid_model.input_mask_placeholder: cur_valid_batch.input_mask_batch, 114 | valid_model.context_len_placeholder: cur_valid_batch.context_len_batch, 115 | valid_model.question_len_placeholder: cur_valid_batch.question_len_batch, 116 | valid_model.header_len_placeholder: cur_valid_batch.header_len_batch, 117 | } 118 | valid_loss_value, start_predict_logits, end_predict_logits = sess.run( 119 | [valid_model.loss, valid_model.output_logits1, valid_model.output_logits2], 120 | feed_dict=valid_feed_dict) 121 | total_loss += valid_loss_value 122 | for i in 
range(len(start_predict_logits)): 123 | predict_start_index = np.argmax(start_predict_logits[i]) 124 | true_start_index = np.argmax(cur_valid_batch.target_batch1[i]) 125 | 126 | predict_end_index = np.argmax(end_predict_logits[i]) 127 | true_end_index = np.argmax(cur_valid_batch.target_batch2[i]) 128 | 129 | if predict_start_index == true_start_index\ 130 | and predict_end_index == true_end_index \ 131 | or predict_start_index < config["max_input_len"] and \ 132 | (cur_valid_batch.input_batch[i][predict_start_index:predict_end_index+1]== 133 | cur_valid_batch.input_batch[i][true_start_index:true_end_index+1]): 134 | correct_count += 1 135 | total_count += 1 136 | duration = time.time() - start_time 137 | start_time = time.time() 138 | print('valid loss = %.4f (%.3f sec)' % (total_loss / num_valid_batch, duration)) 139 | if total_count!=0: 140 | print("valid acc" , (correct_count / total_count)) 141 | if total_count != 0: 142 | if correct_count / total_count > max_acc and epoch>5: 143 | max_acc = correct_count / total_count 144 | saver.save(sess,"save_model/model") 145 | correct_count = 0 146 | total_count = 0 147 | print("max acc",max_acc) -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 
35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import six 22 | import tensorflow as tf 23 | import tokenization 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /uncased_L-12_H-768_A-12/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | --------------------------------------------------------------------------------