├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── common_model.py ├── download.sh ├── feature_model.py ├── kernel_model.py ├── mips-qa.pdf ├── requirements.txt ├── squad_data.py ├── squad_prepro.py ├── squad_prepro_main.py ├── tf_utils.py └── train_and_eval.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to https://cla.developers.google.com/ to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2017 Google Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | **This is not an official Google product** 4 | 5 | This project contains code for training and running an extractive question answering model on the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). All methods and models contained in this project are described in the [technical report](https://github.com/google/mipsqa/blob/master/mips-qa.pdf). Any extensions of this work should cite the report as: 6 | 7 | ``` 8 | @misc{SeoKwiatParikh:2017, 9 | title = {Question Answering with Maximum Inner Product Search}, 10 | author = {Minjoon Seo and Tom Kwiatkowski and Ankur Parikh}, 11 | url = {...}, 12 | } 13 | ``` 14 | 15 | # 0. Requirements and data 16 | - Basic requirements: Python 2 or 3, wget (if using MacOS. 
You can also download the data yourself by following the `download.sh` script) 17 | - Python packages: tensorflow 1.3.0 or higher, nltk, tqdm 18 | - Data: SQuAD, GloVe, nltk tokenizer 19 | 20 | To install required packages, run: 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | To download data, run: 26 | ```bash 27 | chmod +x download.sh; ./download.sh 28 | ``` 29 | 30 | Change the directories where the data is stored if needed, and use them in the runs below. 31 | 32 | # 1. Train and test (draft mode) 33 | If you are using default directories for the data: 34 | ```bash 35 | export SQUAD_DIR=$HOME/data/squad 36 | export GLOVE_DIR=$HOME/data/glove 37 | ``` 38 | 39 | First, preprocess the train data: 40 | ```bash 41 | python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/draft/sort_filter --glove_dir $GLOVE_DIR --sort --filter --draft 42 | ``` 43 | Note the `--draft` flag, which only processes a portion of the data for a fast sanity check. Make sure to remove this flag when doing real training and testing. 44 | `--filter` filters out very long examples, which can slow down training and cause memory issues. 45 | 46 | Second, preprocess the test data, which does not filter out any examples: 47 | ```bash 48 | python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/draft/sort_filter/sort --glove_dir $GLOVE_DIR --sort --draft --indexer_dir prepro/draft/sort_filter 49 | ``` 50 | 51 | Third, train a model: 52 | ```bash 53 | python train_and_eval.py --output_dir /tmp/squad_ckpts --root_data_dir prepro/draft/sort_filter/ --glove_dir $GLOVE_DIR --oom_test 54 | ``` 55 | In general, `--oom_test` is a flag for testing whether your GPU has enough memory for the model, but it can also serve as a quick test to make sure everything runs. 
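For a full (non-draft) run, the same preprocessing and training commands apply with the `--draft` and `--oom_test` flags removed. A sketch, reusing the directories from the steps above (the `prepro/full/...` output directories here are just example names):

```bash
# Full preprocessing and training: same flags as the draft-mode commands,
# minus --draft and --oom_test.
python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/full/sort_filter --glove_dir $GLOVE_DIR --sort --filter
python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/full/sort_filter/sort --glove_dir $GLOVE_DIR --sort --indexer_dir prepro/full/sort_filter
python train_and_eval.py --output_dir /tmp/squad_ckpts --root_data_dir prepro/full/sort_filter/ --glove_dir $GLOVE_DIR
```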
56 | 57 | Fourth, test the model: 58 | ```bash 59 | python train_and_eval.py --root_data_dir prepro/draft/sort_filter/sort --glove_dir $GLOVE_DIR --oom_test --infer --restore_dir /tmp/squad_ckpts 60 | ``` 61 | Note that `--output_dir` has been replaced with `--restore_dir`, and the `--infer` flag has been added. 62 | Instead of using data from `prepro/draft/sort_filter/`, this uses `prepro/draft/sort_filter/sort`, which does not filter out any examples. 63 | This outputs a JSON file in `--restore_dir` that is compatible with the official SQuAD evaluator. 64 | If you want to run this fully (no draft mode), remove the `--draft` and `--oom_test` flags where applicable. 65 | -------------------------------------------------------------------------------- /common_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common components for feature and kernel models.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | import tensorflow.contrib.learn as learn 23 | import squad_data 24 | import tf_utils 25 | 26 | # This value is needed to compute the best answer span with GPUs. 27 | # Set this to a high value for large contexts, but note that it will consume 28 | # GPU memory. 
29 | MAX_CONTEXT_SIZE = 2000 30 | 31 | 32 | def embedding_layer(features, mode, params, reuse=False): 33 | """Common embedding layer for feature and kernel functions. 34 | 35 | Args: 36 | features: A dictionary containing features, directly copied from `model_fn`. 37 | mode: Mode. 38 | params: Contains parameters, directly copied from `model_fn`. 39 | reuse: Reuse variables. 40 | Returns: 41 | `(x, q)` where `x` is the embedded representation of the context, and `q` is 42 | the embedded representation of the question. 43 | """ 44 | with tf.variable_scope('embedding_layer', reuse=reuse): 45 | training = mode == learn.ModeKeys.TRAIN 46 | with tf.variable_scope('embedding'): 47 | char_emb_mat = tf.get_variable('char_emb_mat', 48 | [params.char_vocab_size, params.emb_size]) 49 | xc = tf.nn.embedding_lookup(char_emb_mat, 50 | features['indexed_context_chars']) 51 | qc = tf.nn.embedding_lookup(char_emb_mat, 52 | features['indexed_question_chars']) 53 | 54 | xc = tf.reduce_max(xc, 2) 55 | qc = tf.reduce_max(qc, 2) 56 | 57 | _, xv, qv = glove_layer(features) 58 | 59 | # Concat 60 | x = tf.concat([xc, xv], 2) 61 | q = tf.concat([qc, qv], 2) 62 | 63 | x = tf_utils.highway_net( 64 | x, 2, training=training, dropout_rate=params.dropout_rate) 65 | q = tf_utils.highway_net( 66 | q, 2, training=training, dropout_rate=params.dropout_rate, reuse=True) 67 | 68 | return x, q 69 | 70 | 71 | def glove_layer(features, scope=None): 72 | """GloVe embedding layer. 73 | 74 | The first two rows of `features['emb_mat']` are reserved for special tokens; 75 | the other rows are actual words. So we learn the representations of the 76 | first two entries, but the representations of the other words are fixed (GloVe). 77 | 78 | Args: 79 | features: `dict` of feature tensors. 80 | scope: `str` for scope name. 81 | Returns: 82 | A tuple of tensors, `(glove_emb_mat, context_emb, question_emb)`. 
83 | """ 84 | with tf.variable_scope(scope or 'glove_layer'): 85 | glove_emb_mat_const = tf.slice(features['emb_mat'], [2, 0], [-1, -1]) 86 | glove_emb_mat_var = tf.get_variable('glove_emb_mat_var', 87 | [2, 88 | glove_emb_mat_const.get_shape()[1]]) 89 | glove_emb_mat = tf.concat([glove_emb_mat_var, glove_emb_mat_const], 0) 90 | xv = tf.nn.embedding_lookup(glove_emb_mat, 91 | features['glove_indexed_context_words']) 92 | qv = tf.nn.embedding_lookup(glove_emb_mat, 93 | features['glove_indexed_question_words']) 94 | return glove_emb_mat, xv, qv 95 | 96 | 97 | def char_layer(features, params, scope=None): 98 | """Character embedding layer. 99 | 100 | Args: 101 | features: `dict` of feature tensors. 102 | params: `HParams` object. 103 | scope: `str` for scope name. 104 | Returns: 105 | a tuple of tensors, `(char_emb_mat, context_emb, question_emb)`. 106 | """ 107 | with tf.variable_scope(scope or 'char_layer'): 108 | char_emb_mat = tf.get_variable('char_emb_mat', 109 | [params.char_vocab_size, params.emb_size]) 110 | xc = tf.nn.embedding_lookup(char_emb_mat, features['indexed_context_chars']) 111 | qc = tf.nn.embedding_lookup(char_emb_mat, 112 | features['indexed_question_chars']) 113 | 114 | xc = tf.reduce_max(xc, 2) 115 | qc = tf.reduce_max(qc, 2) 116 | return char_emb_mat, xc, qc 117 | 118 | 119 | def get_pred_ops(features, params, logits_start, logits_end, no_answer_bias): 120 | """Get prediction op dictionary given start & end logits. 121 | 122 | This dictionary will contain predictions as well as everything needed 123 | to produce the nominal answer and identifier (ids). 124 | 125 | Args: 126 | features: Features. 127 | params: `HParams` object. 128 | logits_start: [batch_size, context_size]-shaped tensor of logits for start. 129 | logits_end: Similar to `logits_start`, but for end. This tensor can be also 130 | [batch_size, context_size, context_size], in which case the true answer 131 | start is used to index on dim 1 (context_size). 
132 | no_answer_bias: [batch_size, 1]-shaped tensor, bias for no answer decision. 133 | Returns: 134 | A dictionary of prediction tensors. 135 | """ 136 | max_x_len = tf.shape(logits_start)[1] 137 | 138 | if len(logits_end.get_shape()) == 3: 139 | prob_end_given_start = tf.nn.softmax(logits_end) 140 | prob_start = tf.nn.softmax(logits_start) 141 | prob_start_end = prob_end_given_start * tf.expand_dims(prob_start, -1) 142 | 143 | upper_tri_mat = tf.slice( 144 | np.triu( 145 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32') - 146 | np.triu( 147 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32'), 148 | k=params.max_answer_size)), [0, 0], [max_x_len, max_x_len]) 149 | prob_start_end *= tf.expand_dims(upper_tri_mat, 0) 150 | 151 | prob_end = tf.reduce_sum(prob_start_end, 1) 152 | answer_pred_start = tf.argmax(tf.reduce_max(prob_start_end, 2), 1) 153 | answer_pred_end = tf.argmax(tf.reduce_max(prob_start_end, 1), 1) 154 | answer = squad_data.get_answer_op(features['context'], 155 | features['context_words'], 156 | answer_pred_start, answer_pred_end) 157 | answer_prob = tf.reduce_max(prob_start_end, [1, 2]) 158 | 159 | predictions = { 160 | 'yp1': answer_pred_start, 161 | 'yp2': answer_pred_end, 162 | 'p1': prob_start, 163 | 'p2': prob_end, 164 | 'a': answer, 165 | 'id': features['id'], 166 | 'context': features['context'], 167 | 'context_words': features['context_words'], 168 | 'answer_prob': answer_prob, 169 | 'has_answer': answer_prob > 0.0, 170 | } 171 | 172 | else: 173 | # Predictions and metrics. 
174 | concat_logits_start = tf.concat([no_answer_bias, logits_start], 1) 175 | concat_logits_end = tf.concat([no_answer_bias, logits_end], 1) 176 | 177 | concat_prob_start = tf.nn.softmax(concat_logits_start) 178 | concat_prob_end = tf.nn.softmax(concat_logits_end) 179 | 180 | no_answer_prob_start = tf.squeeze( 181 | tf.slice(concat_prob_start, [0, 0], [-1, 1]), 1) 182 | no_answer_prob_end = tf.squeeze( 183 | tf.slice(concat_prob_end, [0, 0], [-1, 1]), 1) 184 | no_answer_prob = no_answer_prob_start * no_answer_prob_end 185 | has_answer = no_answer_prob < 0.5 186 | prob_start = tf.slice(concat_prob_start, [0, 1], [-1, -1]) 187 | prob_end = tf.slice(concat_prob_end, [0, 1], [-1, -1]) 188 | 189 | # This is only for computing span accuracy and not used for training. 190 | # Masking with `upper_triangular_matrix` only allows valid spans, 191 | # i.e. `answer_pred_start` <= `answer_pred_end`. 192 | # TODO(seominjoon): Replace with dynamic upper triangular matrix. 193 | upper_tri_mat = tf.slice( 194 | np.triu( 195 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32') - 196 | np.triu( 197 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32'), 198 | k=params.max_answer_size)), [0, 0], [max_x_len, max_x_len]) 199 | prob_mat = tf.expand_dims(prob_start, -1) * tf.expand_dims( 200 | prob_end, 1) * tf.expand_dims(upper_tri_mat, 0) 201 | # TODO(seominjoon): Handle this. 
202 | logits_mat = tf_utils.exp_mask( 203 | tf.expand_dims(logits_start, -1) + tf.expand_dims(logits_end, 1), 204 | tf.expand_dims(upper_tri_mat, 0), 205 | mask_is_length=False) 206 | del logits_mat 207 | 208 | answer_pred_start = tf.argmax(tf.reduce_max(prob_mat, 2), 1) 209 | answer_pred_end = tf.argmax(tf.reduce_max(prob_mat, 1), 1) # [batch_size] 210 | answer = squad_data.get_answer_op(features['context'], 211 | features['context_words'], 212 | answer_pred_start, answer_pred_end) 213 | answer_prob = tf.reduce_max(prob_mat, [1, 2]) 214 | 215 | predictions = { 216 | 'yp1': answer_pred_start, 217 | 'yp2': answer_pred_end, 218 | 'p1': prob_start, 219 | 'p2': prob_end, 220 | 'a': answer, 221 | 'id': features['id'], 222 | 'context': features['context'], 223 | 'context_words': features['context_words'], 224 | 'no_answer_prob': no_answer_prob, 225 | 'no_answer_prob_start': no_answer_prob_start, 226 | 'no_answer_prob_end': no_answer_prob_end, 227 | 'answer_prob': answer_prob, 228 | 'has_answer': has_answer, 229 | } 230 | return predictions 231 | 232 | 233 | def get_loss(answer_start, 234 | answer_end, 235 | logits_start, 236 | logits_end, 237 | no_answer_bias, 238 | sparse=True): 239 | """Get loss given answer and logits. 240 | 241 | Args: 242 | answer_start: [batch_size, num_answers] shaped tensor if `sparse=True`, or 243 | [batch_size, context_size] shaped if `sparse=False`. 244 | answer_end: Similar to `answer_start` but for end. 245 | logits_start: [batch_size, context_size]-shaped tensor for answer start 246 | logits. 247 | logits_end: Similar to `logits_start`, but for end. This tensor can be also 248 | [batch_size, context_size, context_size], in which case the true answer 249 | start is used to index on dim 1 (context_size). 250 | no_answer_bias: [batch_size, 1] shaped tensor, bias for no answer decision. 251 | sparse: Indicates whether `answer_start` and `answer_end` are sparse or 252 | dense. 253 | Returns: 254 | Float loss tensor. 
255 | """ 256 | if sparse: 257 | # During training, only one answer. During eval, multiple answers. 258 | # TODO(seominjoon): Make eval loss minimum over multiple answers. 259 | # Loss for start. 260 | answer_start = tf.squeeze(tf.slice(answer_start, [0, 0], [-1, 1]), 1) 261 | answer_start += 1 262 | logits_start = tf.concat([no_answer_bias, logits_start], 1) 263 | losses_start = tf.nn.sparse_softmax_cross_entropy_with_logits( 264 | labels=answer_start, logits=logits_start) 265 | loss_start = tf.reduce_mean(losses_start) 266 | tf.add_to_collection('losses', loss_start) 267 | 268 | # Loss for end. 269 | answer_end = tf.squeeze(tf.slice(answer_end, [0, 0], [-1, 1]), 1) 270 | # Below are start-conditional loss, where every start position has its 271 | # own logits for end position. 272 | if len(logits_end.get_shape()) == 3: 273 | mask = tf.one_hot(answer_start, tf.shape(logits_end)[1]) 274 | mask = tf.cast(tf.expand_dims(mask, -1), 'float') 275 | logits_end = tf.reduce_sum(mask * logits_end, 1) 276 | answer_end += 1 277 | logits_end = tf.concat([no_answer_bias, logits_end], 1) 278 | losses_end = tf.nn.sparse_softmax_cross_entropy_with_logits( 279 | labels=answer_end, logits=logits_end) 280 | loss_end = tf.reduce_mean(losses_end) 281 | tf.add_to_collection('losses', loss_end) 282 | else: 283 | # TODO(seominjoon): Implement no answer capability for non sparse labels. 
284 | losses_start = tf.nn.softmax_cross_entropy_with_logits( 285 | labels=answer_start, logits=logits_start) 286 | loss_start = tf.reduce_mean(losses_start) 287 | tf.add_to_collection('losses', loss_start) 288 | 289 | losses_end = tf.nn.softmax_cross_entropy_with_logits( 290 | labels=answer_end, logits=logits_end) 291 | loss_end = tf.reduce_mean(losses_end) 292 | tf.add_to_collection('losses', loss_end) 293 | 294 | return tf.add_n(tf.get_collection('losses')) 295 | 296 | 297 | def get_train_op(loss, 298 | var_list=None, 299 | post_ops=None, 300 | inc_step=True, 301 | learning_rate=0.001, 302 | clip_norm=0.0): 303 | """Get train op for the given loss. 304 | 305 | Args: 306 | loss: Loss tensor. 307 | var_list: A list of variables that the train op will minimize. 308 | post_ops: A list of ops that will be run after the train op. If not defined, 309 | no op is run after train op. 310 | inc_step: If `True`, will increase the `global_step` variable by 1 after 311 | step. 312 | learning_rate: Initial learning rate for the optimizer. 313 | clip_norm: If specified, clips the gradient of each variable by this value. 314 | Returns: 315 | Train op to be used for training. 
316 | """ 317 | 318 | global_step = tf.train.get_global_step() if inc_step else None 319 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 320 | grads = optimizer.compute_gradients(loss, var_list=var_list) 321 | grads = [(grad, var) for grad, var in grads if grad is not None] 322 | for grad, var in grads: 323 | tf.summary.histogram(var.op.name, var) 324 | tf.summary.histogram('gradients/' + var.op.name, grad) 325 | if clip_norm: 326 | grads = [(tf.clip_by_norm(grad, clip_norm), var) for grad, var in grads] 327 | train_op = optimizer.apply_gradients(grads, global_step=global_step) 328 | 329 | if post_ops is not None: 330 | with tf.control_dependencies([train_op]): 331 | train_op = tf.group(*post_ops) 332 | 333 | return train_op 334 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | #!/bin/bash 16 | 17 | DATA_DIR=$HOME/data 18 | mkdir -p $DATA_DIR 19 | 20 | # Download SQuAD 21 | SQUAD_DIR=$DATA_DIR/squad 22 | mkdir -p $SQUAD_DIR 23 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $SQUAD_DIR/train-v1.1.json 24 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $SQUAD_DIR/dev-v1.1.json 25 | 26 | # Download GloVe 27 | GLOVE_DIR=$DATA_DIR/glove 28 | mkdir -p $GLOVE_DIR 29 | wget http://nlp.stanford.edu/data/glove.6B.zip -O $GLOVE_DIR/glove.6B.zip 30 | unzip $GLOVE_DIR/glove.6B.zip -d $GLOVE_DIR 31 | 32 | # Download NLTK tokenizer data 33 | # Make sure nltk is already installed! 34 | python -m nltk.downloader -d $HOME/nltk_data punkt 35 | -------------------------------------------------------------------------------- /feature_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Feature map model. 15 | 16 | Obtains separate embeddings for the question and each word of the context, then 17 | uses a distance metric (dot, l1, l2) to obtain the closest word in the context. 18 | """ 19 | # TODO(seominjoon): Refactor file and function names (e.g. drop squad). 
20 | import sys 21 | 22 | import tensorflow as tf 23 | import tensorflow.contrib.learn as learn 24 | 25 | import tf_utils 26 | from common_model import char_layer 27 | from common_model import embedding_layer 28 | from common_model import glove_layer 29 | 30 | 31 | def feature_model(features, mode, params, scope=None): 32 | """Handler for SQuAD feature models: only allow feature mapping. 33 | 34 | Every feature model is called via this function. Which model to call can 35 | be controlled via `params.model_id` (e.g. `params.model_id = m00`). 36 | 37 | Function requirement is in: 38 | https://www.tensorflow.org/extend/estimators 39 | 40 | This function does not have any dependency on FLAGS. All parameters must be 41 | passed through `params` argument for all models in this script. 42 | 43 | Args: 44 | features: A dict of feature tensors. 45 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 46 | params: `params` passed during initialization of `Estimator` object. 47 | scope: Variable scope. 48 | Returns: 49 | `(logits_start, logits_end, tensors)` pair. `tensors` is a dictionary of 50 | tensors that can be useful outside of this function, e.g. visualization. 51 | """ 52 | this_module = sys.modules[__name__] 53 | model_fn = getattr(this_module, '_model_%s' % params.model_id) 54 | return model_fn(features, mode, params, scope=scope) 55 | 56 | 57 | def _get_logits_from_multihead_x_and_q(x_start, 58 | x_end, 59 | q_start, 60 | q_end, 61 | x_len, 62 | dist, 63 | x_start_reduce_fn=None, 64 | x_end_reduce_fn=None, 65 | q_start_reduce_fn=None, 66 | q_end_reduce_fn=None): 67 | """Helper function for getting logits from context and question vectors. 68 | 69 | Args: 70 | x_start: [batch_size, context_num_words, hidden_size]-shaped float tensor, 71 | representing context vectors for answer start. 72 | Can have one more dim at the end, which will be reduced after computing 73 | distance via `reduce_fn`. 
74 | x_end: [batch_size, context_num_words, hidden_size]-shaped float tensor, 75 | representing context vectors for answer end. 76 | Can have one more dim at the end, which will be reduced after computing 77 | distance via `reduce_fn`. 78 | q_start: [batch_size, hidden_size]-shaped float tensor, 79 | representing question vector for answer start. 80 | Can have one more dim at the end, which will be reduced after computing 81 | distance via `reduce_fn`. 82 | q_end: [batch_size, hidden_size]-shaped float tensor, 83 | representing question vector for answer end. 84 | Can have one more dim at the end, which will be reduced after computing 85 | distance via `reduce_fn`. 86 | x_len: [batch_size]-shaped int64 tensor, containing length of each context. 87 | dist: distance function, `dot`, `l1`, or `l2`. 88 | x_start_reduce_fn: reduction function that takes in the tensor as first 89 | argument and the axis as the second argument. Default is `tf.reduce_max`. 90 | Reduction for `x_start` and `x_end` if the extra dim is provided. 91 | Note that `l1` and `l2` distances are first negated and then the reduction 92 | is applied, so that `reduce_max` effectively gets minimal distance. 93 | Note that, for correct performance during inference when using nearest 94 | neighbor, reduction must be the default one (None, i.e. `tf.reduce_max`). 95 | x_end_reduce_fn: ditto, for end. 96 | q_start_reduce_fn: reduction function that takes in the tensor as first 97 | argument and the axis as the second argument. Default is `tf.reduce_max`. 98 | Reduction for `q_start` and `q_end` if the extra dim is provided. 99 | Note that `l1` and `l2` distances are first negated and then the reduction 100 | is applied, so that `reduce_max` effectively gets minimal distance. 101 | This can be any reduction function, unlike `x_reduce_fn`. 102 | q_end_reduce_fn: ditto, for end. 103 | 104 | Returns: 105 | a tuple `(logits_start, logits_end)` where each tensor's shape is 106 | [batch_size, context_num_words]. 
107 | 108 | """ 109 | if len(q_start.get_shape()) == 1: 110 | # q can be universal, e.g. trainable weights. 111 | q_start = tf.expand_dims(q_start, 0) 112 | q_end = tf.expand_dims(q_end, 0) 113 | 114 | # Expand q first to broadcast for `context_num_words` dim. 115 | q_start = tf.expand_dims(q_start, 1) 116 | q_end = tf.expand_dims(q_end, 1) 117 | 118 | # Add one dim at the end if no additional dim at the end, to make them rank-4. 119 | if len(x_start.get_shape()) == 3: 120 | x_start = tf.expand_dims(x_start, -1) 121 | x_end = tf.expand_dims(x_end, -1) 122 | if len(q_start.get_shape()) == 3: 123 | q_start = tf.expand_dims(q_start, -1) 124 | q_end = tf.expand_dims(q_end, -1) 125 | 126 | # Add dim to outer-product x and q. This makes them rank-5 tensors. 127 | # shape : [batch_size, context_words, hidden_size, num_heads, 1] 128 | x_start = tf.expand_dims(x_start, -1) 129 | x_end = tf.expand_dims(x_end, -1) 130 | 131 | # shape : [batch_size, context_words, hidden_size, 1, num_heads] 132 | q_start = tf.expand_dims(q_start, 3) 133 | q_end = tf.expand_dims(q_end, 3) 134 | 135 | if x_start_reduce_fn is None: 136 | x_start_reduce_fn = tf.reduce_max 137 | if x_end_reduce_fn is None: 138 | x_end_reduce_fn = tf.reduce_max 139 | if q_start_reduce_fn is None: 140 | q_start_reduce_fn = tf.reduce_max 141 | if q_end_reduce_fn is None: 142 | q_end_reduce_fn = tf.reduce_max 143 | 144 | if dist == 'dot': 145 | logits_start = q_start_reduce_fn( 146 | x_start_reduce_fn(tf.reduce_sum(x_start * q_start, 2), 2), 2) 147 | logits_end = q_end_reduce_fn( 148 | x_end_reduce_fn(tf.reduce_sum(x_end * q_end, 2), 2), 2) 149 | elif dist == 'l1': 150 | logits_start = q_start_reduce_fn( 151 | x_start_reduce_fn( 152 | -tf.norm(x_start - q_start, ord=1, axis=2, keep_dims=True), 2), 2) 153 | logits_start = tf.squeeze( 154 | tf.layers.dense(logits_start, 1, name='logits_start'), 2) 155 | logits_end = q_end_reduce_fn( 156 | x_end_reduce_fn(-tf.norm(x_end - q_end, ord=1, axis=2, keep_dims=True), 157 | 2), 
2) 158 | logits_end = tf.squeeze( 159 | tf.layers.dense(logits_end, 1, name='logits_end'), 2) 160 | elif dist == 'l2': 161 | logits_start = q_start_reduce_fn( 162 | x_start_reduce_fn( 163 | -tf.norm(x_start - q_start, ord=2, axis=2, keep_dims=True), 2), 2) 164 | logits_start = tf.squeeze( 165 | tf.layers.dense(logits_start, 1, name='logits_start'), 2) 166 | logits_end = q_end_reduce_fn( 167 | x_end_reduce_fn(-tf.norm(x_end - q_end, ord=2, axis=2, keep_dims=True), 168 | 2), 2) 169 | logits_end = tf.squeeze( 170 | tf.layers.dense(logits_end, 1, name='logits_end'), 2) 171 | 172 | logits_start = tf_utils.exp_mask(logits_start, x_len) 173 | logits_end = tf_utils.exp_mask(logits_end, x_len) 174 | 175 | return logits_start, logits_end 176 | 177 | 178 | def _model_m00(features, mode, params, scope=None): 179 | """LSTM-based model. 180 | 181 | This model uses two stacked LSTMs to output vectors for context, and 182 | self-attention to output vectors for question. This model reaches 57~58% F1. 183 | 184 | Args: 185 | features: A dict of feature tensors. 186 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 187 | params: `params` passed during initialization of `Estimator` object. 188 | scope: Variable scope, default is `feature_model`. 189 | Returns: 190 | `(logits_start, logits_end, tensors)` pair. `tensors` is a dictionary of 191 | tensors that can be useful outside of this function, e.g. visualization. 
192 | """ 193 | 194 | with tf.variable_scope(scope or 'feature_model'): 195 | training = mode == learn.ModeKeys.TRAIN 196 | 197 | x, q = embedding_layer(features, mode, params) 198 | 199 | x1 = tf_utils.bi_rnn( 200 | params.hidden_size, 201 | x, 202 | sequence_length_list=features['context_num_words'], 203 | scope='x_bi_rnn_1', 204 | training=training, 205 | dropout_rate=params.dropout_rate) 206 | 207 | x2 = tf_utils.bi_rnn( 208 | params.hidden_size, 209 | x1, 210 | sequence_length_list=features['context_num_words'], 211 | scope='x_bi_rnn_2', 212 | training=training, 213 | dropout_rate=params.dropout_rate) 214 | 215 | q1 = tf_utils.bi_rnn( 216 | params.hidden_size, 217 | q, 218 | sequence_length_list=features['question_num_words'], 219 | scope='q_bi_rnn_1', 220 | training=training, 221 | dropout_rate=params.dropout_rate) 222 | 223 | q2 = tf_utils.bi_rnn( 224 | params.hidden_size, 225 | q1, 226 | sequence_length_list=features['question_num_words'], 227 | scope='q_bi_rnn_2', 228 | training=training, 229 | dropout_rate=params.dropout_rate) 230 | 231 | # Self-attention to obtain single vector representation. 232 | q_start = tf_utils.self_att( 233 | q1, mask=features['question_num_words'], scope='q_start') 234 | q_end = tf_utils.self_att( 235 | q2, mask=features['question_num_words'], scope='q_end') 236 | 237 | logits_start, logits_end = _get_logits_from_multihead_x_and_q( 238 | x1, x2, q_start, q_end, features['context_num_words'], params.dist) 239 | return logits_start, logits_end, dict() 240 | 241 | 242 | def _model_m01(features, mode, params, scope=None): 243 | """Self-attention with MLP, reaching 55~56% F1. 244 | 245 | Args: 246 | features: A dict of feature tensors. 247 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 248 | params: `params` passed during initialization of `Estimator` object. 249 | scope: Variable scope, default is `feature_model`. 250 | Returns: 251 | `(logits_start, logits_end, tensors)` pair. 
`tensors` is a dictionary of 252 | tensors that can be useful outside of this function, e.g. visualization. 253 | """ 254 | 255 | with tf.variable_scope(scope or 'feature_model'): 256 | training = mode == learn.ModeKeys.TRAIN 257 | tensors = {} 258 | 259 | x, q = embedding_layer(features, mode, params) 260 | 261 | x0 = tf_utils.bi_rnn( 262 | params.hidden_size, 263 | x, 264 | sequence_length_list=features['context_num_words'], 265 | scope='bi_rnn_x0', 266 | training=training, 267 | dropout_rate=params.dropout_rate) 268 | 269 | x1 = tf_utils.bi_rnn( 270 | params.hidden_size, 271 | x0, 272 | sequence_length_list=features['context_num_words'], 273 | scope='bi_rnn_x1', 274 | training=training, 275 | dropout_rate=params.dropout_rate) 276 | 277 | x1 += x0 278 | 279 | q1 = tf_utils.bi_rnn( 280 | params.hidden_size, 281 | q, 282 | sequence_length_list=features['question_num_words'], 283 | scope='bi_rnn_q1', 284 | training=training, 285 | dropout_rate=params.dropout_rate) 286 | 287 | def get_x(x_, scope=None): 288 | with tf.variable_scope(scope or 'get_x_clue'): 289 | hidden_sizes = [params.hidden_size, params.hidden_size] 290 | attender = tf_utils.mlp( 291 | x_, 292 | hidden_sizes, 293 | activate_last=False, 294 | training=training, 295 | dropout_rate=params.dropout_rate, 296 | scope='attender') 297 | attendee = tf_utils.mlp( 298 | x_, 299 | hidden_sizes, 300 | activate_last=False, 301 | training=training, 302 | dropout_rate=params.dropout_rate, 303 | scope='attendee') 304 | clue = tf_utils.att2d( 305 | attendee, 306 | attender, 307 | a_val=x_, 308 | mask=features['context_num_words'], 309 | logit_fn='dot', 310 | tensors=tensors) 311 | return tf.concat([x_, clue], 2) 312 | 313 | x_start = get_x(x1, scope='get_x_start') 314 | x_end = get_x(x1, scope='get_x_end') 315 | 316 | q_type = tf_utils.self_att( 317 | q1, 318 | mask=features['question_num_words'], 319 | scope='self_att_q_type', 320 | tensors=tensors) 321 | q_clue = tf_utils.self_att( 322 | q1, 323 | 
mask=features['question_num_words'], 324 | scope='self_att_q_clue', 325 | tensors=tensors) 326 | q_start = q_end = tf.concat([q_type, q_clue], 1) 327 | 328 | logits_start, logits_end = _get_logits_from_multihead_x_and_q( 329 | x_start, x_end, q_start, q_end, features['context_num_words'], 330 | params.dist) 331 | return logits_start, logits_end, tensors 332 | 333 | 334 | def _model_m02(features, mode, params, scope=None): 335 | """Self-attention with LSTM, reaching 59~60% F1. 336 | 337 | Args: 338 | features: A dict of feature tensors. 339 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 340 | params: `params` passed during initialization of `Estimator` object. 341 | scope: Variable scope, default is `feature_model`. 342 | Returns: 343 | `(logits_start, logits_end, tensors)` pair. `tensors` is a dictionary of 344 | tensors that can be useful outside of this function, e.g. visualization. 345 | """ 346 | 347 | with tf.variable_scope(scope or 'feature_model'): 348 | training = mode == learn.ModeKeys.TRAIN 349 | tensors = {} 350 | 351 | x, q = embedding_layer(features, mode, params) 352 | 353 | x1 = tf_utils.bi_rnn( 354 | params.hidden_size, 355 | x, 356 | sequence_length_list=features['context_num_words'], 357 | scope='bi_rnn_x1', 358 | training=training, 359 | dropout_rate=params.dropout_rate) 360 | 361 | def get_clue(x_, scope=None): 362 | with tf.variable_scope(scope or 'get_clue'): 363 | attendee = tf_utils.bi_rnn( 364 | params.hidden_size, 365 | x_, 366 | sequence_length_list=features['context_num_words'], 367 | scope='bi_rnn_attendee', 368 | training=training, 369 | dropout_rate=params.dropout_rate) 370 | attender = tf_utils.bi_rnn( 371 | params.hidden_size, 372 | x_, 373 | sequence_length_list=features['context_num_words'], 374 | scope='bi_rnn_attender', 375 | training=training, 376 | dropout_rate=params.dropout_rate) 377 | clue = tf_utils.att2d( 378 | attendee, 379 | attender, 380 | a_val=x_, 381 | 
mask=features['context_num_words'], 382 | logit_fn='dot', 383 | tensors=tensors) 384 | return clue 385 | 386 | x1_clue = get_clue(x1) 387 | x_start = tf.concat([x1, x1_clue], 2) 388 | 389 | x2 = tf_utils.bi_rnn( 390 | params.hidden_size, 391 | x1, 392 | sequence_length_list=features['context_num_words'], 393 | scope='bi_rnn_x2', 394 | training=training, 395 | dropout_rate=params.dropout_rate) 396 | x2_clue = tf_utils.bi_rnn( 397 | params.hidden_size, 398 | x1_clue, 399 | sequence_length_list=features['context_num_words'], 400 | scope='bi_rnn_x2_clue', 401 | training=training, 402 | dropout_rate=params.dropout_rate) 403 | x_end = tf.concat([x2, x2_clue], 2) 404 | 405 | q1 = tf_utils.bi_rnn( 406 | params.hidden_size, 407 | q, 408 | sequence_length_list=features['question_num_words'], 409 | scope='bi_rnn_q1', 410 | training=training, 411 | dropout_rate=params.dropout_rate) 412 | 413 | q2 = tf_utils.bi_rnn( 414 | params.hidden_size, 415 | q1, 416 | sequence_length_list=features['question_num_words'], 417 | scope='bi_rnn_q2', 418 | training=training, 419 | dropout_rate=params.dropout_rate) 420 | 421 | # Self-attention to obtain single vector representation. 
422 | q1_type = tf_utils.self_att( 423 | q1, 424 | mask=features['question_num_words'], 425 | tensors=tensors, 426 | scope='self_att_q1_type') 427 | q1_clue = tf_utils.self_att( 428 | q1, 429 | mask=features['question_num_words'], 430 | tensors=tensors, 431 | scope='self_att_q1_clue') 432 | q_start = tf.concat([q1_type, q1_clue], 1) 433 | q2_type = tf_utils.self_att( 434 | q2, mask=features['question_num_words'], scope='self_att_q2_type') 435 | q2_clue = tf_utils.self_att( 436 | q2, mask=features['question_num_words'], scope='self_att_q2_clue') 437 | q_end = tf.concat([q2_type, q2_clue], 1) 438 | 439 | logits_start, logits_end = _get_logits_from_multihead_x_and_q( 440 | x_start, x_end, q_start, q_end, features['context_num_words'], 441 | params.dist) 442 | return logits_start, logits_end, tensors 443 | 444 | 445 | def _model_m03(features, mode, params, scope=None): 446 | """Independent self-attention with LSTM, reaching 60~61%. 447 | 448 | Args: 449 | features: A dict of feature tensors. 450 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 451 | params: `params` passed during initialization of `Estimator` object. 452 | scope: Variable scope, default is `feature_model`. 453 | Returns: 454 | `(logits_start, logits_end, tensors)` pair. `tensors` is a dictionary of 455 | tensors that can be useful outside of this function, e.g. visualization. 
456 | """ 457 | 458 | with tf.variable_scope(scope or 'feature_model'): 459 | training = mode == learn.ModeKeys.TRAIN 460 | tensors = {} 461 | 462 | x, q = embedding_layer(features, mode, params) 463 | 464 | def get_x_and_q(scope=None): 465 | with tf.variable_scope(scope or 'get_x_and_q'): 466 | x1 = tf_utils.bi_rnn( 467 | params.hidden_size, 468 | x, 469 | sequence_length_list=features['context_num_words'], 470 | scope='bi_rnn_x1', 471 | training=training, 472 | dropout_rate=params.dropout_rate) 473 | 474 | attendee = tf_utils.bi_rnn( 475 | params.hidden_size, 476 | x1, 477 | sequence_length_list=features['context_num_words'], 478 | scope='bi_rnn_attendee', 479 | training=training, 480 | dropout_rate=params.dropout_rate) 481 | attender = tf_utils.bi_rnn( 482 | params.hidden_size, 483 | x1, 484 | sequence_length_list=features['context_num_words'], 485 | scope='bi_rnn_attender', 486 | training=training, 487 | dropout_rate=params.dropout_rate) 488 | clue = tf_utils.att2d( 489 | attendee, 490 | attender, 491 | a_val=x1, 492 | mask=features['context_num_words'], 493 | logit_fn='dot', 494 | tensors=tensors) 495 | 496 | x_out = tf.concat([x1, clue], 2) 497 | 498 | q1 = tf_utils.bi_rnn( 499 | params.hidden_size, 500 | q, 501 | sequence_length_list=features['question_num_words'], 502 | scope='bi_rnn_q1', 503 | training=training, 504 | dropout_rate=params.dropout_rate) 505 | 506 | q1_type = tf_utils.self_att( 507 | q1, 508 | mask=features['question_num_words'], 509 | tensors=tensors, 510 | scope='self_att_q1_type') 511 | q1_clue = tf_utils.self_att( 512 | q1, 513 | mask=features['question_num_words'], 514 | tensors=tensors, 515 | scope='self_att_q1_clue') 516 | q_out = tf.concat([q1_type, q1_clue], 1) 517 | 518 | return x_out, q_out 519 | 520 | x_start, q_start = get_x_and_q('start') 521 | x_end, q_end = get_x_and_q('end') 522 | 523 | logits_start, logits_end = _get_logits_from_multihead_x_and_q( 524 | x_start, x_end, q_start, q_end, features['context_num_words'], 525 | 
params.dist) 526 | return logits_start, logits_end, tensors 527 | 528 | 529 | def _model_m04(features, mode, params, scope=None): 530 | """Regularization with query generation loss on top of m03, reaching 63~64%. 531 | 532 | Note that most part of this model is identical to m03, except for the function 533 | `reg_gen`, which adds additional generation loss. 534 | 535 | Args: 536 | features: A dict of feature tensors. 537 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 538 | params: `params` passed during initialization of `Estimator` object. 539 | scope: Variable scope, default is `feature_model`. 540 | Returns: 541 | `(logits_start, logits_end, tensors)` pair. `tensors` is a dictionary of 542 | tensors that can be useful outside of this function, e.g. visualization. 543 | """ 544 | 545 | with tf.variable_scope(scope or 'feature_model'): 546 | training = mode == learn.ModeKeys.TRAIN 547 | inference = mode == learn.ModeKeys.INFER 548 | tensors = {} 549 | 550 | with tf.variable_scope('embedding'): 551 | glove_emb_mat, xv, qv = glove_layer(features) 552 | _, xc, qc = char_layer(features, params) 553 | x = tf.concat([xc, xv], 2) 554 | q = tf.concat([qc, qv], 2) 555 | x = tf_utils.highway_net( 556 | x, 2, training=training, dropout_rate=params.dropout_rate) 557 | q = tf_utils.highway_net( 558 | q, 2, training=training, dropout_rate=params.dropout_rate, reuse=True) 559 | 560 | def get_x_and_q(scope=None): 561 | with tf.variable_scope(scope or 'get_x_and_q'): 562 | x1 = tf_utils.bi_rnn( 563 | params.hidden_size, 564 | x, 565 | sequence_length_list=features['context_num_words'], 566 | scope='bi_rnn_x1', 567 | training=training, 568 | dropout_rate=params.dropout_rate) 569 | 570 | attendee = tf_utils.bi_rnn( 571 | params.hidden_size, 572 | x1, 573 | sequence_length_list=features['context_num_words'], 574 | scope='bi_rnn_attendee', 575 | training=training, 576 | dropout_rate=params.dropout_rate) 577 | attender = tf_utils.bi_rnn( 578 | 
params.hidden_size, 579 | x1, 580 | sequence_length_list=features['context_num_words'], 581 | scope='bi_rnn_attender', 582 | training=training, 583 | dropout_rate=params.dropout_rate) 584 | clue = tf_utils.att2d( 585 | attendee, 586 | attender, 587 | a_val=x1, 588 | mask=features['context_num_words'], 589 | logit_fn='dot', 590 | tensors=tensors) 591 | 592 | x_out = tf.concat([x1, clue], 2) 593 | 594 | q1 = tf_utils.bi_rnn( 595 | params.hidden_size, 596 | q, 597 | sequence_length_list=features['question_num_words'], 598 | scope='bi_rnn_q1', 599 | training=training, 600 | dropout_rate=params.dropout_rate) 601 | 602 | q1_type = tf_utils.self_att( 603 | q1, 604 | mask=features['question_num_words'], 605 | tensors=tensors, 606 | scope='self_att_q1_type') 607 | q1_clue = tf_utils.self_att( 608 | q1, 609 | mask=features['question_num_words'], 610 | tensors=tensors, 611 | scope='self_att_q1_clue') 612 | q_out = tf.concat([q1_type, q1_clue], 1) 613 | 614 | return x_out, q_out 615 | 616 | x_start, q_start = get_x_and_q('start') 617 | x_end, q_end = get_x_and_q('end') 618 | 619 | # TODO(seominjoon): Separate regularization and model parts. 
620 | def reg_gen(glove_emb_mat, memory, scope): 621 | """Add query generation loss to `losses` collection as regularization.""" 622 | with tf.variable_scope(scope): 623 | start_vec = tf.get_variable( 624 | 'start_vec', shape=glove_emb_mat.get_shape()[1]) 625 | end_vec = tf.get_variable('end_vec', shape=glove_emb_mat.get_shape()[1]) 626 | glove_emb_mat = tf.concat([ 627 | glove_emb_mat, 628 | tf.expand_dims(start_vec, 0), 629 | tf.expand_dims(end_vec, 0) 630 | ], 0) 631 | vocab_size = glove_emb_mat.get_shape().as_list()[0] 632 | start_idx = vocab_size - 2 633 | end_idx = vocab_size - 1 634 | batch_size = tf.shape(x)[0] 635 | 636 | # Index memory 637 | memory_mask = tf.one_hot( 638 | tf.slice(features['word_answer_%ss' % scope], [0, 0], [-1, 1]), 639 | tf.shape(x)[1]) 640 | # Transposing below is just a convenient way to do reduction at dim 2 641 | # and expansion at dim 1 with one operation. 642 | memory_mask = tf.transpose(memory_mask, [0, 2, 1]) 643 | initial_state = tf.reduce_sum(memory * tf.cast(memory_mask, 'float'), 1) 644 | cell = tf.contrib.rnn.GRUCell(memory.get_shape().as_list()[-1]) 645 | 646 | glove_emb_mat_dense = tf.layers.dense(glove_emb_mat, cell.output_size) 647 | 648 | def deembed(inputs): 649 | shape = tf.shape(inputs) 650 | inputs = tf.reshape(inputs, [-1, shape[-1]]) 651 | outputs = tf.matmul(inputs, tf.transpose(glove_emb_mat_dense)) 652 | outputs = tf.reshape(outputs, tf.concat([shape[:-1], [vocab_size]], 653 | 0)) 654 | return outputs 655 | 656 | if inference: 657 | # During inference, feed previous output to the next input. 658 | start_tokens = tf.tile(tf.reshape(start_idx, [1]), [batch_size]) 659 | helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( 660 | glove_emb_mat, start_tokens, end_idx) 661 | helper = tf_utils.DeembedWrapper(helper, deembed) 662 | maximum_iterations = params.max_gen_length 663 | else: 664 | # During training and eval, feed ground truth input all the time. 
665 | q_in = tf_utils.concat_seq_and_tok(qv, start_vec, 'start') 666 | indexed_q_out = tf_utils.concat_seq_and_tok( 667 | tf.cast(features['glove_indexed_question_words'], 'int32'), 668 | end_idx, 669 | 'end', 670 | sequence_length=features['question_num_words']) 671 | q_len = tf.cast(features['question_num_words'], 'int32') + 1 672 | helper = tf.contrib.seq2seq.TrainingHelper(q_in, q_len) 673 | maximum_iterations = None 674 | 675 | decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, initial_state) 676 | (outputs, _), _, _ = tf.contrib.seq2seq.dynamic_decode( 677 | decoder, maximum_iterations=maximum_iterations) 678 | logits = deembed(outputs) 679 | indexed_q_pred = tf.argmax(logits, axis=2, name='indexed_q_pred') 680 | tensors[indexed_q_pred.op.name] = indexed_q_pred 681 | 682 | if not inference: 683 | # Add sequence loss to the `losses` collection. 684 | weights = tf.sequence_mask(q_len, maxlen=tf.shape(indexed_q_out)[1]) 685 | loss = tf.contrib.seq2seq.sequence_loss(logits, indexed_q_out, 686 | tf.cast(weights, 'float')) 687 | cf = params.reg_cf * tf.exp(-tf.log(2.0) * tf.cast( 688 | tf.train.get_global_step(), 'float') / params.reg_half_life) 689 | tf.add_to_collection('losses', cf * loss) 690 | 691 | if params.reg_gen: 692 | reg_gen(glove_emb_mat, x_start, 'start') 693 | reg_gen(glove_emb_mat, x_end, 'end') 694 | 695 | logits_start, logits_end = _get_logits_from_multihead_x_and_q( 696 | x_start, x_end, q_start, q_end, features['context_num_words'], 697 | params.dist) 698 | return logits_start, logits_end, tensors 699 | -------------------------------------------------------------------------------- /kernel_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Kernel model. 15 | 16 | This model allows interactions between the two inputs, such as attention. 17 | """ 18 | import sys 19 | 20 | import tensorflow as tf 21 | import tensorflow.contrib.learn as learn 22 | 23 | import tf_utils 24 | from common_model import embedding_layer 25 | 26 | 27 | def kernel_model(features, mode, params, scope=None): 28 | """Kernel models that allow interaction between question and context. 29 | 30 | This is the handler for all kernel models in this script. Models are selected via 31 | `params.model_id` (e.g. `params.model_id = m00`). 32 | 33 | The function requirement for each model is in: 34 | https://www.tensorflow.org/extend/estimators 35 | 36 | This function does not have any dependency on FLAGS. All parameters must be 37 | passed through the `params` argument. 38 | 39 | Args: 40 | features: A dict of feature tensors. 41 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 42 | params: `params` passed during initialization of `Estimator` object. 43 | scope: Variable name scope. 44 | Returns: 45 | `(logits_start, logits_end, tensors)` pair. `tensors` is a dictionary of 46 | tensors that can be useful outside of this function, e.g. visualization. 47 | """ 48 | this_module = sys.modules[__name__] 49 | model_fn = getattr(this_module, '_model_%s' % params.model_id) 50 | return model_fn( 51 | features, mode, params, scope=scope) 52 | 53 | 54 | def _model_m00(features, mode, params, scope=None): 55 | """Simplified BiDAF, reaching 74~75% F1.
56 | 57 | Args: 58 | features: A dict of feature tensors. 59 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys 60 | params: `params` passed during initialization of `Estimator` object. 61 | scope: Variable name scope. 62 | Returns: 63 | `(logits_start, logits_end, tensors)` pair. Tensors is a dictionary of 64 | tensors that can be useful outside of this function, e.g. visualization. 65 | """ 66 | with tf.variable_scope(scope or 'kernel_model'): 67 | training = mode == learn.ModeKeys.TRAIN 68 | tensors = {} 69 | 70 | x, q = embedding_layer(features, mode, params) 71 | 72 | x0 = tf_utils.bi_rnn( 73 | params.hidden_size, 74 | x, 75 | sequence_length_list=features['context_num_words'], 76 | scope='x_bi_rnn_0', 77 | training=training, 78 | dropout_rate=params.dropout_rate) 79 | 80 | q0 = tf_utils.bi_rnn( 81 | params.hidden_size, 82 | q, 83 | sequence_length_list=features['question_num_words'], 84 | scope='q_bi_rnn_0', 85 | training=training, 86 | dropout_rate=params.dropout_rate) 87 | 88 | xq = tf_utils.att2d( 89 | q0, 90 | x0, 91 | mask=features['question_num_words'], 92 | tensors=tensors, 93 | scope='xq') 94 | xq = tf.concat([x0, xq, x0 * xq], 2) 95 | x1 = tf_utils.bi_rnn( 96 | params.hidden_size, 97 | xq, 98 | sequence_length_list=features['context_num_words'], 99 | training=training, 100 | scope='x1_bi_rnn', 101 | dropout_rate=params.dropout_rate) 102 | x2 = tf_utils.bi_rnn( 103 | params.hidden_size, 104 | x1, 105 | sequence_length_list=features['context_num_words'], 106 | training=training, 107 | scope='x2_bi_rnn', 108 | dropout_rate=params.dropout_rate) 109 | x3 = tf_utils.bi_rnn( 110 | params.hidden_size, 111 | x2, 112 | sequence_length_list=features['context_num_words'], 113 | training=training, 114 | scope='x3_bi_rnn', 115 | dropout_rate=params.dropout_rate) 116 | 117 | logits_start = tf_utils.exp_mask( 118 | tf.squeeze( 119 | tf.layers.dense(tf.concat([x1, x2], 2), 1, name='logits1'), 2), 120 | features['context_num_words']) 121 | 
logits_end = tf_utils.exp_mask( 122 | tf.squeeze( 123 | tf.layers.dense(tf.concat([x1, x3], 2), 1, name='logits2'), 2), 124 | features['context_num_words']) 125 | 126 | return logits_start, logits_end, tensors 127 | -------------------------------------------------------------------------------- /mips-qa.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/mipsqa/5ce97002e16069cac85a267f759014d5cb30cf9e/mips-qa.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=1.3.0 2 | nltk 3 | tqdm 4 | 5 | -------------------------------------------------------------------------------- /squad_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """SQuAD data parsing module for tf.learn model. 15 | 16 | This module loads TFRecord and hyperparameters from a specified directory 17 | (files dumped by `squad_prepro.py`) and provides tensors for data feeding. 18 | This module also provides data-specific functions for evaluation. 
19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from collections import Counter 26 | import json 27 | import os 28 | import re 29 | import string 30 | 31 | import numpy as np 32 | import tensorflow as tf 33 | 34 | import squad_prepro 35 | 36 | 37 | def get_input_fn(root_data_dir, 38 | glove_dir, 39 | data_type, 40 | batch_size, 41 | glove_size, 42 | shuffle_files=True, 43 | shuffle_examples=False, 44 | queue_capacity=5000, 45 | min_after_dequeue=1000, 46 | num_epochs=None, 47 | oom_test=False): 48 | """Get input function for the given data type from the given data directory. 49 | 50 | Args: 51 | root_data_dir: The directory to load data from. Corresponds to `to_dir` 52 | of the `squad_prepro_main.py` file. 53 | glove_dir: path to the directory that contains GloVe files. 54 | data_type: `str` object, either `train` or `dev`. 55 | batch_size: Batch size of the inputs. 56 | glove_size: size of GloVe vector to load. 57 | shuffle_files: If `True`, shuffle the queue for the input files. 58 | shuffle_examples: If `True`, shuffle the queue for the examples. 59 | queue_capacity: `int`, maximum number of examples in the input queue. 60 | min_after_dequeue: `int`, for `RandomShuffleQueue`, minimum number of 61 | examples before dequeueing to ensure randomness. 62 | num_epochs: Number of epochs on the data. `None` means infinite. 63 | This queue comes after the file queue. 64 | oom_test: Stress test to see if the current dataset and model cause an 65 | out-of-memory error on GPU. 66 | Returns: 67 | Function definition `input_fn` compatible with `Experiment` object.
68 | """ 69 | filenames = tf.gfile.Glob( 70 | os.path.join(root_data_dir, data_type, 'data', 'squad_data_*')) 71 | tf.logging.info('reading examples from following files:') 72 | for filename in filenames: 73 | tf.logging.info(filename) 74 | sequence_feature = tf.FixedLenSequenceFeature( 75 | [], tf.int64, allow_missing=True) 76 | str_sequence_feature = tf.FixedLenSequenceFeature( 77 | [], tf.string, allow_missing=True) 78 | int_feature = tf.FixedLenFeature([], tf.int64) 79 | str_feature = tf.FixedLenFeature([], tf.string) 80 | # Let N = batch_size, JX = max num context words, JQ = max num ques words, 81 | # C = num chars per word (fixed, default = 16) 82 | features = { 83 | 'indexed_context_words': sequence_feature, # Shape = [JX] 84 | 'glove_indexed_context_words': sequence_feature, 85 | 'indexed_context_chars': sequence_feature, # Shape = [JX * C] 86 | 'indexed_question_words': sequence_feature, # Shape = [JQ] 87 | 'glove_indexed_question_words': sequence_feature, 88 | 'indexed_question_chars': sequence_feature, # Shape = [JQ * C] 89 | 'word_answer_starts': sequence_feature, # Answer start index. 90 | 'word_answer_ends': sequence_feature, # Answer end index. 91 | 'context_num_words': 92 | int_feature, # Number of context words in each example. [A] 93 | 'question_num_words': 94 | int_feature, # Number of question words in each example. [A] 95 | 'answers': str_sequence_feature, # List of answers in each example. 
[A] 96 | 'context_words': str_sequence_feature, # [JX] 97 | 'question_words': str_sequence_feature, # [JQ] 98 | 'context': str_feature, 99 | 'id': str_feature, 100 | 'num_answers': int_feature, 101 | 'question': str_feature, 102 | } 103 | 104 | exp_metadata_path = os.path.join(root_data_dir, 'metadata.json') 105 | with tf.gfile.GFile(exp_metadata_path, 'r') as fp: 106 | exp_metadata = json.load(fp) 107 | 108 | metadata_path = os.path.join(root_data_dir, data_type, 'metadata.json') 109 | with tf.gfile.GFile(metadata_path, 'r') as fp: 110 | metadata = json.load(fp) 111 | emb_mat = squad_prepro.get_idx2vec_mat(glove_dir, glove_size, 112 | metadata['glove_word2idx']) 113 | 114 | def _input_fn(): 115 | """Input function compatible with `Experiment` object. 116 | 117 | Returns: 118 | A tuple of feature tensors and target tensors. 119 | """ 120 | # TODO(seominjoon): There is bottleneck in data feeding, slow for N >= 128. 121 | filename_queue = tf.train.string_input_producer( 122 | filenames, shuffle=shuffle_files, num_epochs=num_epochs) 123 | reader = tf.TFRecordReader() 124 | _, se = reader.read(filename_queue) 125 | # TODO(seominjoon): Consider moving data filtering to here. 126 | features_op = tf.parse_single_example(se, features=features) 127 | 128 | names = list(features_op.keys()) 129 | dtypes = [features_op[name].dtype for name in names] 130 | shapes = [features_op[name].shape for name in names] 131 | 132 | if shuffle_examples: 133 | # Data shuffling. 
134 |       rq = tf.RandomShuffleQueue(
135 |           queue_capacity, min_after_dequeue, dtypes, names=names)
136 |     else:
137 |       rq = tf.FIFOQueue(queue_capacity, dtypes, names=names)
138 |     enqueue_op = rq.enqueue(features_op)
139 |     dequeue_op = rq.dequeue()
140 |     dequeue_op = [dequeue_op[name] for name in names]
141 |     qr = tf.train.QueueRunner(rq, [enqueue_op])
142 |     tf.train.add_queue_runner(qr)
143 | 
144 |     batch = tf.train.batch(
145 |         dequeue_op,
146 |         batch_size,
147 |         capacity=queue_capacity,
148 |         dynamic_pad=True,
149 |         shapes=shapes,
150 |         allow_smaller_final_batch=True,
151 |         num_threads=5)
152 |     batch = {name: each for name, each in zip(names, batch)}
153 |     target_keys = [
154 |         'word_answer_starts', 'word_answer_ends', 'answers', 'num_answers'
155 |     ]
156 |     # TODO(seominjoon): To be cheating-safe, remove the `#` below so that
157 |     # target keys are filtered out of the features.
158 |     features_batch = {
159 |         key: val
160 |         for key, val in batch.items()  # if key not in target_keys
161 |     }
162 | 
163 |     # `metadata['emb_mat']` contains the GloVe embedding, and `xv` in
164 |     # `features_batch` indexes into the vectors.
165 |     features_batch['emb_mat'] = tf.constant(emb_mat)
166 |     targets_batch = {key: batch[key] for key in target_keys}
167 | 
168 |     # Postprocessing for character data.
169 |     # Due to the limitation of the python wrapper for prototxt,
170 |     # the characters (by index) need to be flattened when saving on prototxt.
171 |     # The following 'unflattens' the character tensor.
172 |     actual_batch_size = tf.shape(batch['indexed_context_chars'])[0]
173 |     features_batch['indexed_context_chars'] = tf.reshape(
174 |         features_batch['indexed_context_chars'],
175 |         [actual_batch_size, -1, metadata['num_chars_per_word']])
176 |     features_batch['indexed_question_chars'] = tf.reshape(
177 |         features_batch['indexed_question_chars'],
178 |         [actual_batch_size, -1, metadata['num_chars_per_word']])
179 | 
180 |     # Make sure answer start and end positions are less than sequence lengths.
181 |     # TODO(seominjoon) This will need to move to a separate test.
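The 'unflattening' described in the comments above is just a reshape: the character ids are stored flat, `num_chars_per_word` per word, and the per-word axis is recovered at load time. A minimal pure-Python sketch of the same operation for a single example (toy sizes; the real code uses `tf.reshape` with shape `[batch, -1, num_chars_per_word]` on the batched tensors):

```python
# Hypothetical toy sizes: 2 words, 4 chars per word, for one example.
num_chars_per_word = 4
flat_chars = list(range(8))  # Flattened character ids, as stored in the records.

# Unflatten: recover the per-word character axis, as tf.reshape does above.
unflat = [flat_chars[i:i + num_chars_per_word]
          for i in range(0, len(flat_chars), num_chars_per_word)]

print(unflat)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```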
181 | with tf.control_dependencies([ 182 | tf.assert_less( 183 | tf.reduce_max(targets_batch['word_answer_starts'], 1), 184 | features_batch['context_num_words']) 185 | ]): 186 | targets_batch['word_answer_starts'] = tf.identity( 187 | targets_batch['word_answer_starts']) 188 | with tf.control_dependencies([ 189 | tf.assert_less( 190 | tf.reduce_max(targets_batch['word_answer_ends'], 1), 191 | features_batch['context_num_words']) 192 | ]): 193 | targets_batch['word_answer_ends'] = tf.identity( 194 | targets_batch['word_answer_ends']) 195 | 196 | # Stress test to ensure no OOM for GPU occurs. 197 | if oom_test: 198 | features_batch['indexed_context_words'] = tf.constant( 199 | np.ones( 200 | [batch_size, exp_metadata['max_context_size']], dtype='int64')) 201 | features_batch['glove_indexed_context_words'] = tf.constant( 202 | np.ones( 203 | [batch_size, exp_metadata['max_context_size']], dtype='int64')) 204 | features_batch['indexed_context_chars'] = tf.constant( 205 | np.ones( 206 | [ 207 | batch_size, exp_metadata['max_context_size'], exp_metadata[ 208 | 'num_chars_per_word'] 209 | ], 210 | dtype='int64')) 211 | features_batch['indexed_question_words'] = tf.constant( 212 | np.ones([batch_size, exp_metadata['max_ques_size']], dtype='int64')) 213 | features_batch['glove_indexed_question_words'] = tf.constant( 214 | np.ones([batch_size, exp_metadata['max_ques_size']], dtype='int64')) 215 | features_batch['indexed_question_chars'] = tf.constant( 216 | np.ones( 217 | [ 218 | batch_size, exp_metadata['max_ques_size'], exp_metadata[ 219 | 'num_chars_per_word'] 220 | ], 221 | dtype='int64')) 222 | features_batch['question_num_words'] = tf.constant( 223 | np.ones([batch_size], dtype='int64') * exp_metadata['max_ques_size']) 224 | features_batch['context_num_words'] = tf.constant( 225 | np.ones([batch_size], dtype='int64') * 226 | exp_metadata['max_context_size']) 227 | 228 | return features_batch, targets_batch 229 | 230 | return _input_fn 231 | 232 | 233 | def 
get_params(root_data_dir): 234 | """Load data-specific parameters from `root_data_dir`. 235 | 236 | Args: 237 | root_data_dir: The data directory to load parameter files from. 238 | This is equivalent to the `output_dir` of `data/squad_prepro.py`. 239 | Returns: 240 | A dict of hyperparameters. 241 | """ 242 | indexer_path = os.path.join(root_data_dir, 'indexer.json') 243 | with tf.gfile.GFile(indexer_path, 'r') as fp: 244 | indexer = json.load(fp) 245 | 246 | return { 247 | 'vocab_size': len(indexer['word2idx']), 248 | 'char_vocab_size': len(indexer['char2idx']), 249 | } 250 | 251 | 252 | def get_eval_metric_ops(targets, predictions): 253 | """Get a dictionary of eval metrics for `Experiment` object. 254 | 255 | Args: 256 | targets: `targets` that go into `model_fn` of `Experiment`. 257 | predictions: Dictionary of predictions, output of `get_preds`. 258 | Returns: 259 | A dictionary of eval metrics. 260 | """ 261 | # TODO(seominjoon): yp should also consider no answer case. 262 | yp1 = tf.expand_dims(predictions['yp1'], -1) 263 | yp2 = tf.expand_dims(predictions['yp2'], -1) 264 | answer_mask = tf.sequence_mask(targets['num_answers']) 265 | start_correct = tf.reduce_any( 266 | tf.equal(targets['word_answer_starts'], yp1) & answer_mask, 1) 267 | end_correct = tf.reduce_any( 268 | tf.equal(targets['word_answer_ends'], yp2) & answer_mask, 1) 269 | correct = start_correct & end_correct 270 | em = tf.py_func( 271 | _enum_fn(_exact_match_score, dtype='float32'), [ 272 | predictions['a'], targets['answers'], predictions['has_answer'], 273 | answer_mask 274 | ], 'float32') 275 | f1 = tf.py_func( 276 | _enum_fn(_f1_score, dtype='float32'), [ 277 | predictions['a'], targets['answers'], predictions['has_answer'], 278 | answer_mask 279 | ], 'float32') 280 | 281 | eval_metric_ops = { 282 | 'acc1': tf.metrics.mean(tf.cast(start_correct, 'float')), 283 | 'acc2': tf.metrics.mean(tf.cast(end_correct, 'float')), 284 | 'acc': tf.metrics.mean(tf.cast(correct, 'float')), 285 | 'em': 
tf.metrics.mean(em), 286 | 'f1': tf.metrics.mean(f1), 287 | } 288 | return eval_metric_ops 289 | 290 | 291 | def get_answer_op(context, context_words, answer_start, answer_end): 292 | return tf.py_func( 293 | _enum_fn(_get_answer), [context, context_words, answer_start, answer_end], 294 | 'string') 295 | 296 | 297 | def _get_answer(context, context_words, answer_start, answer_end): 298 | """Get answer given context, context_words, and span. 299 | 300 | Args: 301 | context: A list of bytes, to be decoded with utf-8. 302 | context_words: A list of a list of bytes, to be decoded with utf-8. 303 | answer_start: An int for answer start. 304 | answer_end: An int for answer end. 305 | Returns: 306 | A list of bytes, encoded with utf-8, for the answer. 307 | """ 308 | context = context.decode('utf-8') 309 | context_words = [word.decode('utf-8') for word in context_words] 310 | pos = 0 311 | answer_start_char = None 312 | answer_end_char = None 313 | for i, word in enumerate(context_words): 314 | pos = context.index(word, pos) 315 | if answer_start == i: 316 | answer_start_char = pos 317 | pos += len(word) 318 | if answer_end == i: 319 | answer_end_char = pos 320 | break 321 | assert answer_start_char is not None, ( 322 | '`answer_start` is not found in context. ' 323 | 'context=`%s`, context_words=`%r`, ' 324 | 'answer_start=%d, answer_end=%d') % (context, context_words, answer_start, 325 | answer_end) 326 | assert answer_end_char is not None, ( 327 | '`answer_end` is not found in context. 
' 328 | 'context=`%s`, context_words=`%r`, ' 329 | 'answer_start=%d, answer_end=%d') % (context, context_words, answer_start, 330 | answer_end) 331 | answer = context[answer_start_char:answer_end_char].encode('utf-8') 332 | return answer 333 | 334 | 335 | def _f1_score(prediction, ground_truths, has_answer, answer_mask): 336 | prediction = prediction.decode('utf-8') 337 | ground_truths = [ 338 | ground_truth.decode('utf-8') for ground_truth in ground_truths 339 | ] 340 | if not has_answer: 341 | return float(ground_truths[0] == squad_prepro.NO_ANSWER) 342 | elif ground_truths[0] == squad_prepro.NO_ANSWER: 343 | return 0.0 344 | else: 345 | scores = np.array([ 346 | _f1_score_(prediction, ground_truth) for ground_truth in ground_truths 347 | ]) 348 | return max(scores * answer_mask.astype(float)) 349 | 350 | 351 | def _exact_match_score(prediction, ground_truths, has_answer, answer_mask): 352 | prediction = prediction.decode('utf-8') 353 | ground_truths = [ 354 | ground_truth.decode('utf-8') for ground_truth in ground_truths 355 | ] 356 | if not has_answer: 357 | return float(ground_truths[0] == squad_prepro.NO_ANSWER) 358 | elif ground_truths[0] == squad_prepro.NO_ANSWER: 359 | return 0.0 360 | else: 361 | scores = np.array([ 362 | float(_exact_match_score_(prediction, ground_truth)) 363 | for ground_truth in ground_truths 364 | ]) 365 | return max(scores * answer_mask.astype(float)) 366 | 367 | 368 | def _enum_fn(fn, dtype='object'): 369 | 370 | def new_fn(*args): 371 | return np.array([fn(*each_args) for each_args in zip(*args)], dtype=dtype) 372 | 373 | return new_fn 374 | 375 | 376 | # Functions below are copied from official SQuAD eval script and SHOULD NOT 377 | # BE MODIFIED. 378 | 379 | 380 | def _normalize_answer(s): 381 | """Lower text and remove punctuation, articles and extra whitespace. 382 | 383 | Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED. 384 | 385 | Args: 386 | s: Input text. 387 | Returns: 388 | Normalized text. 
389 | """ 390 | 391 | def remove_articles(text): 392 | return re.sub(r'\b(a|an|the)\b', ' ', text) 393 | 394 | def white_space_fix(text): 395 | return ' '.join(text.split()) 396 | 397 | def remove_punc(text): 398 | exclude = set(string.punctuation) 399 | return ''.join(ch for ch in text if ch not in exclude) 400 | 401 | def lower(text): 402 | return text.lower() 403 | 404 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 405 | 406 | 407 | def _f1_score_(prediction, ground_truth): 408 | """Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED.""" 409 | prediction_tokens = _normalize_answer(prediction).split() 410 | ground_truth_tokens = _normalize_answer(ground_truth).split() 411 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 412 | num_same = sum(common.values()) 413 | if num_same == 0: 414 | return 0 415 | precision = 1.0 * num_same / len(prediction_tokens) 416 | recall = 1.0 * num_same / len(ground_truth_tokens) 417 | f1 = (2 * precision * recall) / (precision + recall) 418 | return f1 419 | 420 | 421 | def _exact_match_score_(prediction, ground_truth): 422 | """Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED.""" 423 | return _normalize_answer(prediction) == _normalize_answer(ground_truth) 424 | -------------------------------------------------------------------------------- /squad_prepro.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
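The official-metric helpers above (`_normalize_answer`, `_f1_score_`, `_exact_match_score_`) can be exercised in isolation. The following is a standalone re-implementation sketch of the same normalization and token-level F1, mirroring (not importing) the functions above:

```python
import re
import string
from collections import Counter


def normalize_answer(s):
  """Lowercase, strip punctuation and articles, fix whitespace (as above)."""
  s = s.lower()
  s = ''.join(ch for ch in s if ch not in set(string.punctuation))
  s = re.sub(r'\b(a|an|the)\b', ' ', s)
  return ' '.join(s.split())


def f1_score(prediction, ground_truth):
  """Token-level F1 between normalized prediction and ground truth."""
  pred_tokens = normalize_answer(prediction).split()
  gt_tokens = normalize_answer(ground_truth).split()
  common = Counter(pred_tokens) & Counter(gt_tokens)
  num_same = sum(common.values())
  if num_same == 0:
    return 0.0
  precision = num_same / len(pred_tokens)
  recall = num_same / len(gt_tokens)
  return 2 * precision * recall / (precision + recall)


print(f1_score('The Cat sat.', 'cat sat'))  # 1.0 after normalization
```

Note how normalization makes case, punctuation, and leading articles irrelevant, which is why the eval functions above must not be modified: any change silently changes the reported SQuAD scores.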
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Library for preprocessing SQuAD.
15 | """
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from collections import Counter
22 | from collections import OrderedDict
23 | import itertools
24 | import json
25 | import os
26 | 
27 | import nltk
28 | import numpy as np
29 | import tensorflow as tf
30 | from tqdm import tqdm
31 | 
32 | PAD = u'<pad>'
33 | UNK = u'<unk>'
34 | NO_ANSWER = u'<no_answer>'
35 | 
36 | 
37 | class SquadIndexer(object):
38 |   """Indexer for SQuAD.
39 | 
40 |   Instantiating this class loads GloVe. The object can fit examples (creating
41 |   vocab out of the examples) and index the examples.
42 |   """
43 | 
44 |   def __init__(self, glove_dir, tokenizer=None, draft=False):
45 |     self._glove_words = list(_get_glove(glove_dir, size=50, draft=draft))
46 |     self._glove_vocab = _get_glove_vocab(self._glove_words)
47 |     self._glove_dir = glove_dir
48 |     self._draft = draft
49 |     self._tokenizer = tokenizer or _word_tokenize
50 |     self._word2idx_dict = None
51 |     self._char2idx_dict = None
52 | 
53 |   def word2idx(self, word):
54 |     word = word.lower()
55 |     if word not in self._word2idx_dict:
56 |       return 1
57 |     return self._word2idx_dict[word]
58 | 
59 |   def char2idx(self, char):
60 |     if char not in self._char2idx_dict:
61 |       return 1
62 |     return self._char2idx_dict[char]
63 | 
64 |   def fit(self,
65 |           examples,
66 |           min_word_count=10,
67 |           min_char_count=100,
68 |           num_chars_per_word=16):
69 |     """Fits examples and returns indexed examples with metadata.
70 | 
71 |     Fitting examples means the vocab is created out of the examples.
72 |     The vocab can be saved via `save` and loaded via `load` methods.
73 | 
74 |     Args:
75 |       examples: list of dictionaries, where each dictionary is an example.
76 |       min_word_count: `int` value, minimum word count to be included in vocab.
77 | min_char_count: `int` value, minimum char count to be included in vocab. 78 | num_chars_per_word: `int` value, number of chars to store per word. 79 | This is fixed, so if word is shorter, then the rest is padded with 0. 80 | The characters are flattened, so need to be reshaped when using them. 81 | Returns: 82 | a tuple `(indexed_examples, metadata)`, where `indexed_examples` is a 83 | list of dict (each dict being indexed example) and `metadata` is a dict 84 | of `glove_word2idx_dict` and statistics of the examples. 85 | """ 86 | tokenized_examples = [ 87 | _tokenize(example, self._tokenizer, num_chars_per_word) 88 | for example in tqdm(examples, 'tokenizing') 89 | ] 90 | word_counter = _get_word_counter(tokenized_examples) 91 | char_counter = _get_char_counter(tokenized_examples) 92 | self._word2idx_dict = _counter2vocab(word_counter, min_word_count) 93 | tf.logging.info('Word vocab size: %d' % len(self._word2idx_dict)) 94 | self._char2idx_dict = _counter2vocab(char_counter, min_char_count) 95 | tf.logging.info('Char vocab size: %d' % len(self._char2idx_dict)) 96 | glove_word2idx_dict = _get_glove_vocab( 97 | self._glove_words, counter=word_counter) 98 | tf.logging.info('Glove word vocab size: %d' % len(glove_word2idx_dict)) 99 | 100 | def glove_word2idx(word): 101 | word = word.lower() 102 | return glove_word2idx_dict[word] if word in glove_word2idx_dict else 1 103 | 104 | indexed_examples = [ 105 | _index(example, self.word2idx, glove_word2idx, self.char2idx) 106 | for example in tqdm(tokenized_examples, desc='indexing') 107 | ] 108 | 109 | metadata = self._get_metadata(indexed_examples) 110 | metadata['glove_word2idx'] = glove_word2idx_dict 111 | metadata['num_chars_per_word'] = num_chars_per_word 112 | 113 | return indexed_examples, metadata 114 | 115 | def prepro_eval(self, examples, num_chars_per_word=16): 116 | """Tokenizes and indexes examples (usually non-train examples). 
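The indexing convention used by `word2idx` and `char2idx` above reserves index 0 for padding and index 1 for unknown tokens, which is why out-of-vocabulary lookups fall back to `1`. A minimal sketch of that convention (toy counts; the string sentinels stand in for the module's `PAD`/`UNK` constants):

```python
def counter_to_vocab(counter, min_count):
  """Keep tokens at/above min_count; reserve 0 for PAD and 1 for UNK."""
  tokens = ['PAD', 'UNK'] + [t for t, c in counter.items() if c >= min_count]
  return {token: idx for idx, token in enumerate(tokens)}


vocab = counter_to_vocab({'the': 12, 'cat': 3, 'zyzzyva': 1}, min_count=3)


def word2idx(word):
  # Out-of-vocabulary words map to the UNK index, as in SquadIndexer.
  return vocab.get(word.lower(), 1)


print(word2idx('the'), word2idx('zyzzyva'))  # 2 1
```

The same scheme is applied twice: once for the trained vocab (`_counter2vocab`) and once for the GloVe-restricted vocab (`_get_glove_vocab`), so both index spaces share the PAD/UNK convention.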
117 | 
118 |     In order to use this, `fit` must have been already executed on train data.
119 |     Other than that, this function has the same functionality as `fit`,
120 |     returning indexed examples.
121 | 
122 |     Args:
123 |       examples: a list of dict, where each dict is an example.
124 |       num_chars_per_word: `int` value, number of chars to store per word.
125 |         This is fixed, so if word is shorter, then the rest is padded with 0.
126 |         The characters are flattened, so need to be reshaped when using them.
127 |     Returns:
128 |       a tuple `(indexed_examples, metadata)`, where `indexed_examples` is a
129 |       list of dict (each dict being an indexed example) and `metadata` is a
130 |       dict of `glove_word2idx_dict` and statistics of the examples.
131 |     """
132 |     tokenized_examples = [
133 |         _tokenize(example, self._tokenizer, num_chars_per_word)
134 |         for example in tqdm(examples, desc='tokenizing')
135 |     ]
136 |     word_counter = _get_word_counter(tokenized_examples)
137 |     glove_word2idx_dict = _get_glove_vocab(
138 |         self._glove_words, counter=word_counter)
139 |     tf.logging.info('Glove word vocab size: %d' % len(glove_word2idx_dict))
140 | 
141 |     def glove_word2idx(word):
142 |       word = word.lower()
143 |       return glove_word2idx_dict[word] if word in glove_word2idx_dict else 1
144 | 
145 |     indexed_examples = [
146 |         _index(example, self.word2idx, glove_word2idx, self.char2idx)
147 |         for example in tqdm(tokenized_examples, desc='indexing')
148 |     ]
149 | 
150 |     metadata = self._get_metadata(indexed_examples)
151 |     metadata['glove_word2idx'] = glove_word2idx_dict
152 |     metadata['num_chars_per_word'] = num_chars_per_word
153 |     return indexed_examples, metadata
154 | 
155 |   @property
156 |   def savable(self):
157 |     return {'word2idx': self._word2idx_dict, 'char2idx': self._char2idx_dict}
158 | 
159 |   def save(self, save_path):
160 |     with tf.gfile.GFile(save_path, 'w') as fp:
161 |       json.dump(self.savable, fp)
162 | 
163 |   def load(self, load_path):
164 |     with tf.gfile.GFile(load_path, 'r') as fp:
165 |       savable
= json.load(fp) 166 | self._word2idx_dict = savable['word2idx'] 167 | self._char2idx_dict = savable['char2idx'] 168 | 169 | def _get_metadata(self, examples): 170 | metadata = { 171 | 'max_context_size': max(len(e['context_words']) for e in examples), 172 | 'max_ques_size': max(len(e['question_words']) for e in examples), 173 | 'word_vocab_size': len(self._word2idx_dict), 174 | 'char_vocab_size': len(self._char2idx_dict), 175 | } 176 | return metadata 177 | 178 | 179 | def split(example, para2sents_fn=None, positive_augment_factor=0): 180 | """Splits context in example into sentences and create multiple examples. 181 | 182 | Args: 183 | example: `dict` object, each element of `get_examples()`. 184 | para2sents_fn: function that maps `str` to a list of `str`, splitting 185 | paragraph into sentences. 186 | positive_augment_factor: Multiply positive examples by this factor. 187 | For handling class imbalance problem. 188 | Returns: 189 | a list of examples, with modified fields: `id`, `context`, `answers` and 190 | `answer_starts`. Will add `has_answer` bool field. 191 | """ 192 | if para2sents_fn is None: 193 | para2sents_fn = nltk.sent_tokenize 194 | sents = para2sents_fn(example['context']) 195 | sent_start_idxs = _tokens2idxs(example['context'], sents) 196 | 197 | context = example['context'] 198 | examples = [] 199 | for i, (sent, sent_start_idx) in enumerate(zip(sents, sent_start_idxs)): 200 | sent_end_idx = sent_start_idx + len(sent) 201 | e = dict(example.items()) # Copying dict content. 
202 | e['context'] = sent 203 | e['id'] = '%s %d' % (e['id'], i) 204 | e['answers'] = [] 205 | e['answer_starts'] = [] 206 | for answer, answer_start in zip(example['answers'], 207 | example['answer_starts']): 208 | answer_end = answer_start + len(answer) 209 | if (sent_start_idx <= answer_start < sent_end_idx or 210 | sent_start_idx < answer_end <= sent_end_idx): 211 | new_answer = context[max(sent_start_idx, answer_start):min( 212 | sent_end_idx, answer_end)] 213 | new_answer_start = max(answer_start, sent_start_idx) - sent_start_idx 214 | e['answers'].append(new_answer) 215 | e['answer_starts'].append(new_answer_start) 216 | if not e['answers']: 217 | e['answers'].append(NO_ANSWER) 218 | e['answer_starts'].append(-1) 219 | e['num_answers'] = len(e['answers']) 220 | # If the list is empty, then the example has no answer. 221 | examples.append(e) 222 | if positive_augment_factor and e['answers'][0] != NO_ANSWER: 223 | for _ in range(positive_augment_factor): 224 | examples.append(e) 225 | return examples 226 | 227 | 228 | def get_idx2vec_mat(glove_dir, size, glove_word2idx_dict): 229 | """Gets embedding matrix for given GloVe vocab.""" 230 | glove = _get_glove(glove_dir, size=size) 231 | glove[PAD] = glove[UNK] = [0.0] * size 232 | idx2vec_dict = {idx: glove[word] for word, idx in glove_word2idx_dict.items()} 233 | idx2vec_mat = np.array( 234 | [idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32') 235 | return idx2vec_mat 236 | 237 | 238 | def get_examples(squad_path): 239 | """Obtain a list of examples from official SQuAD file. 240 | 241 | Args: 242 | squad_path: path to the official SQuAD file (e.g. `train-v1.1.json`). 243 | Returns: 244 | a list of dict, where each dict is example. 
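The `split` function above clips each char-level answer span to the sentence it overlaps, re-basing the start offset to the sentence. That overlap arithmetic can be sketched standalone (hypothetical helper name; the real code additionally handles the no-answer sentinel and positive augmentation):

```python
def clip_answer_to_sentence(context, sent_start, sent_end, answer, answer_start):
  """Clip a char-level answer span to a sentence window, as split() does."""
  answer_end = answer_start + len(answer)
  # Keep the answer only if it overlaps the sentence at all.
  if not (sent_start <= answer_start < sent_end or
          sent_start < answer_end <= sent_end):
    return None, None
  new_answer = context[max(sent_start, answer_start):min(sent_end, answer_end)]
  new_start = max(answer_start, sent_start) - sent_start
  return new_answer, new_start


context = 'Cats purr. Dogs bark.'
# The second sentence spans chars [11, 21); the answer 'Dogs' starts at char 11.
print(clip_answer_to_sentence(context, 11, 21, 'Dogs', 11))  # ('Dogs', 0)
```

A sentence with no overlapping answer gets `(None, None)`, which is where the `NO_ANSWER` sentinel and `answer_starts = [-1]` branch in `split` come from.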
245 |   """
246 |   with tf.gfile.GFile(squad_path, 'r') as fp:
247 |     squad = json.load(fp)
248 | 
249 |   examples = []
250 |   version = squad['version']
251 |   for article in squad['data']:
252 |     title = article['title']
253 |     for paragraph in article['paragraphs']:
254 |       context = paragraph['context']
255 |       for qa in paragraph['qas']:
256 |         question = qa['question']
257 |         id_ = qa['id']
258 | 
259 |         answer_starts = [answer['answer_start'] for answer in qa['answers']]
260 |         answers = [answer['text'] for answer in qa['answers']]
261 | 
262 |         example = {
263 |             'version': version,
264 |             'title': title,
265 |             'context': context,
266 |             'question': question,
267 |             'id': id_,
268 |             'answer_starts': answer_starts,
269 |             'answers': answers,
270 |             'num_answers': len(answers),
271 |             'is_supervised': True,
272 |         }
273 |         example = normalize_example(example)
274 |         examples.append(example)
275 |   return examples
276 | 
277 | 
278 | def normalize_example(example):
279 |   n_example = dict(example.items())
280 |   n_example['context'] = _replace_quotations(n_example['context'])
281 |   n_example['answers'] = [_replace_quotations(a) for a in n_example['answers']]
282 |   return n_example
283 | 
284 | 
285 | def _replace_quotations(text):
286 |   return text.replace('``', '" ').replace("''", '" ')
287 | 
288 | 
289 | def _word_tokenize(text):
290 |   # TODO(seominjoon): Consider using Stanford Tokenizer or other tokenizers.
291 |   return [
292 |       word.replace('``', '"').replace("''", '"')
293 |       for word in nltk.word_tokenize(text)
294 |   ]
295 | 
296 | 
297 | def _tokens2idxs(text, tokens):
298 |   idxs = []
299 |   idx = 0
300 |   for token in tokens:
301 |     idx = text.find(token, idx)
302 |     assert idx >= 0, (text, tokens)
303 |     idxs.append(idx)
304 |     idx += len(token)
305 |   return idxs
306 | 
307 | 
308 | def _tokenize(example, text2words_fn, num_chars_per_word):
309 |   """Tokenize each example using provided tokenizer (`text2words_fn`).
310 | 
311 |   Args:
312 |     example: `dict` value, an example.
313 |     text2words_fn: tokenizer.
314 |     num_chars_per_word: `int` value, number of chars to store per word.
315 |       This is fixed, so if word is shorter, then the rest is padded with 0.
316 |       The characters are flattened, so need to be reshaped when using them.
317 |   Returns:
318 |     `dict`, representing tokenized example.
319 |   """
320 |   new_example = dict(example.items())
321 |   new_example['question_words'] = text2words_fn(example['question'])
322 |   new_example['question_num_words'] = len(new_example['question_words'])
323 |   new_example['context_words'] = text2words_fn(example['context'])
324 |   new_example['context_num_words'] = len(new_example['context_words'])
325 | 
326 |   def word2chars(word):
327 |     chars = list(word)
328 |     if len(chars) > num_chars_per_word:
329 |       return chars[:num_chars_per_word]
330 |     else:
331 |       return chars + [PAD] * (num_chars_per_word - len(chars))
332 | 
333 |   new_example['question_chars'] = list(
334 |       itertools.chain(
335 |           * [word2chars(word) for word in new_example['question_words']]))
336 |   new_example['context_chars'] = list(
337 |       itertools.chain(
338 |           * [word2chars(word) for word in new_example['context_words']]))
339 |   return new_example
340 | 
341 | 
342 | def _index(example, word2idx_fn, glove_word2idx_fn, char2idx_fn):
343 |   """Indexes each tokenized example, using provided vocabs.
344 | 
345 |   Args:
346 |     example: `dict` representing tokenized example.
347 |     word2idx_fn: indexer for word vocab.
348 |     glove_word2idx_fn: indexer for glove word vocab.
349 |     char2idx_fn: indexer for character vocab.
350 |   Returns:
351 |     `dict` representing indexed example.
352 | """ 353 | new_example = dict(example.items()) 354 | new_example['indexed_question_words'] = [ 355 | word2idx_fn(word) for word in example['question_words'] 356 | ] 357 | new_example['indexed_context_words'] = [ 358 | word2idx_fn(word) for word in example['context_words'] 359 | ] 360 | new_example['indexed_question_chars'] = [ 361 | char2idx_fn(word) for word in example['question_chars'] 362 | ] 363 | new_example['indexed_context_chars'] = [ 364 | char2idx_fn(word) for word in example['context_chars'] 365 | ] 366 | new_example['glove_indexed_question_words'] = [ 367 | glove_word2idx_fn(word) for word in example['question_words'] 368 | ] 369 | new_example['glove_indexed_context_words'] = [ 370 | glove_word2idx_fn(word) for word in example['context_words'] 371 | ] 372 | 373 | word_answer_starts = [] 374 | word_answer_ends = [] 375 | for answer_start, answer in zip(new_example['answer_starts'], 376 | new_example['answers']): 377 | if answer_start < 0: 378 | word_answer_starts.append(-1) 379 | word_answer_ends.append(-1) 380 | break 381 | word_answer_start, word_answer_end = _get_word_answer( 382 | new_example['context'], new_example['context_words'], answer_start, 383 | answer) 384 | word_answer_starts.append(word_answer_start) 385 | word_answer_ends.append(word_answer_end) 386 | new_example['word_answer_starts'] = word_answer_starts 387 | new_example['word_answer_ends'] = word_answer_ends 388 | 389 | return new_example 390 | 391 | 392 | def _get_glove(glove_path, size=None, draft=False): 393 | """Get an `OrderedDict` that maps word to vector. 394 | 395 | Args: 396 | glove_path: `str` value, 397 | path to the glove file (e.g. `glove.6B.50d.txt`) or directory. 398 | size: `int` value, size of the vector, if `glove_path` is a directory. 399 | draft: `bool` value, whether to only load first 99 for draft mode. 400 | Returns: 401 | `OrderedDict` object, mapping word to vector. 
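The GloVe files consumed by `_get_glove` are plain text: one token per line, followed by its space-separated float components. A minimal parser for that format, using toy in-memory lines instead of `glove.6B.50d.txt` (a sketch mirroring the loop in `_get_glove`, not a replacement for it):

```python
from collections import OrderedDict


def parse_glove_lines(lines):
  """Map each word to its float vector, preserving file order."""
  glove = OrderedDict()
  for line in lines:
    tokens = line.strip().split(' ')
    glove[tokens[0]] = [float(t) for t in tokens[1:]]
  return glove


toy_file = ['the 0.1 0.2 0.3', 'cat -0.5 0.0 0.25']
glove = parse_glove_lines(toy_file)
print(glove['cat'])  # [-0.5, 0.0, 0.25]
```

Preserving file order matters because downstream code (`get_idx2vec_mat`) builds an index-to-vector matrix whose row order must agree with the vocab indices.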
402 | """ 403 | if size is not None: 404 | glove_path = os.path.join(glove_path, 'glove.6B.%dd.txt' % size) 405 | glove = OrderedDict() 406 | with tf.gfile.GFile(glove_path, 'rb') as fp: 407 | for idx, line in enumerate(fp): 408 | line = line.decode('utf-8') 409 | tokens = line.strip().split(u' ') 410 | word = tokens[0] 411 | vec = list(map(float, tokens[1:])) 412 | glove[word] = vec 413 | if draft and idx > 99: 414 | break 415 | return glove 416 | 417 | 418 | def _get_word_counter(examples): 419 | # TODO(seominjoon): Consider not ignoring uppercase. 420 | counter = Counter() 421 | for example in tqdm(examples, desc='word counter'): 422 | for word in example['question_words']: 423 | counter[word.lower()] += 1 424 | for word in example['context_words']: 425 | counter[word.lower()] += 1 426 | return counter 427 | 428 | 429 | def _get_char_counter(examples): 430 | counter = Counter() 431 | for example in tqdm(examples, desc='char counter'): 432 | for chars in example['question_chars']: 433 | for char in chars: 434 | counter[char] += 1 435 | for chars in example['context_chars']: 436 | for char in chars: 437 | counter[char] += 1 438 | return counter 439 | 440 | 441 | def _counter2vocab(counter, min_count): 442 | tokens = [token for token, count in counter.items() if count >= min_count] 443 | tokens = [PAD, UNK] + tokens 444 | vocab = {token: idx for idx, token in enumerate(tokens)} 445 | return vocab 446 | 447 | 448 | def _get_glove_vocab(words, counter=None): 449 | if counter is not None: 450 | words = [word for word in counter if word in set(words)] 451 | words = [PAD, UNK] + words 452 | vocab = {word: idx for idx, word in enumerate(words)} 453 | return vocab 454 | 455 | 456 | def _get_word_answer(context, context_words, answer_start, answer): 457 | """Get word-level answer index. 458 | 459 | Args: 460 | context: `unicode`, representing the context of the question. 461 | context_words: a list of `unicode`, tokenized context. 
462 | answer_start: `int`, the char-level start index of the answer. 463 | answer: `unicode`, the answer that is substring of context. 464 | Returns: 465 | a tuple of `(word_answer_start, word_answer_end)`, representing the start 466 | and end indices of the answer in respect to `context_words`. 467 | """ 468 | assert answer, 'Encountered length-0 answer.' 469 | answer_end = answer_start + len(answer) 470 | char_idxs = _tokens2idxs(context, context_words) 471 | word_answer_start = None 472 | word_answer_end = None 473 | for word_idx, char_idx in enumerate(char_idxs): 474 | if char_idx <= answer_start: 475 | word_answer_start = word_idx 476 | if char_idx < answer_end: 477 | word_answer_end = word_idx 478 | return word_answer_start, word_answer_end 479 | -------------------------------------------------------------------------------- /squad_prepro_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Preprocesser for SQuAD. 15 | 16 | `from_dir` will need to contain original SQuAD train/dev data: 17 | 18 | - `train-v1.1.json` 19 | - `dev-v1.1.json` 20 | 21 | In each TFRecord file, following features will be provided. Note that all 22 | strings are encoded with utf-8. 23 | 24 | Directly from the data. 25 | - `id` : string id from original SQuAD data. 26 | - `version` : the version of SQuAD data. 
27 | - `title` : the title of the article. 28 | - `question` : original question string. 29 | - `context` : original context string. 30 | 31 | - `answers` : original list of answer strings. Variable length. 32 | - `answer_starts` : original list of integers for answer starts. 33 | 34 | Processed. 35 | - `question_words` : tokenized question. A variable len list of strings. 36 | - `context_words` : tokenized context. Variable length. 37 | - `question_chars` : question chars. 38 | - `context_chars` : context chars. 39 | - `indexed_question_words` : question words indexed by vocab. 40 | - `indexed_context_words` : context words indexed by vocab. 41 | - `glove_indexed_question_words` : question words indexed by GloVe. 42 | - `glove_indexed_context_words` : context words indexed by GloVe. 43 | - `indexed_question_chars` : question chars, flattened and indexed. 44 | - `indexed_context_chars` : ditto. 45 | - `question_num_words` : number of words in question. 46 | - `context_num_words` : number of words in context. 47 | - `is_supervised` : whether the dataset is supervised or not. 48 | 49 | - `num_answers': integer indicating number of answers. 0 means no answer. 50 | - `word_answer_starts` : word-level answer start positions. 51 | - `word_answer_ends` : word-level answer end positions. 52 | """ 53 | 54 | from __future__ import absolute_import 55 | from __future__ import division 56 | from __future__ import print_function 57 | 58 | import itertools 59 | import json 60 | import os 61 | import random 62 | from six import string_types 63 | 64 | import tensorflow as tf 65 | from tqdm import tqdm 66 | import squad_prepro 67 | 68 | tf.flags.DEFINE_string('from_dir', '', 'Directory for original SQuAD data.') 69 | tf.flags.DEFINE_string('to_dir', '', 'Directory for preprocessed SQuAD data.') 70 | tf.flags.DEFINE_string('glove_dir', '', 'Directory for GloVe files.') 71 | tf.flags.DEFINE_string('indexer_dir', '', 'Directory for indexer. 
' 72 | 'If specified, does not load train and uses this.') 73 | tf.flags.DEFINE_integer('word_count_th', 100, 'Word count threshold for vocab.') 74 | tf.flags.DEFINE_integer('char_count_th', 100, 'Char count threshold for vocab.') 75 | tf.flags.DEFINE_integer('max_context_size', 256, 76 | 'Maximum context size. Set this to `0` for test, which ' 77 | 'sets no limit to the maximum context size.') 78 | tf.flags.DEFINE_integer('max_ques_size', 32, 79 | 'Maximum question size. Set this to `0` for test.') 80 | tf.flags.DEFINE_integer('num_chars_per_word', 16, 81 | 'Fixed number of characters per word.') 82 | tf.flags.DEFINE_boolean( 83 | 'split', False, 84 | 'If `True`, each context will be a sentence instead of a paragraph, ' 85 | 'and answer label (word index) will be `None` in case of no answer.') 86 | tf.flags.DEFINE_boolean('filter', False, 87 | 'If `True`, filters data by context and question sizes.') 88 | tf.flags.DEFINE_boolean('sort', False, 89 | 'If `True`, sorts data by context length.') 90 | tf.flags.DEFINE_boolean('shuffle', True, 'If `True`, shuffle examples.') 91 | tf.flags.DEFINE_boolean('draft', False, 'If `True`, fast draft mode, ' 92 | 'which only loads first few examples.') 93 | tf.flags.DEFINE_integer('max_shard_size', 1000, 'Max size of each shard.') 94 | tf.flags.DEFINE_integer('positive_augment_factor', 0, 'Augment positive ' 95 | 'examples by this factor.') 96 | tf.flags.DEFINE_boolean('answerable', False, 'This flag ' 97 | 'allows one to use only examples that have answers.') 98 | 99 | FLAGS = tf.flags.FLAGS 100 | 101 | 102 | def get_tf_example(example): 103 | """Get `tf.train.Example` object from example dict. 104 | 105 | Args: 106 | example: tokenized, indexed example. 107 | Returns: 108 | `tf.train.Example` object corresponding to the example. 109 | Raises: 110 | ValueError: if a key with an empty value in `example` is not recognized. TypeError: if a value in `example` has an unsupported type.
111 | """ 112 | feature = {} 113 | for key, val in example.items(): 114 | if not isinstance(val, list): 115 | val = [val] 116 | if val: 117 | if isinstance(val[0], string_types): 118 | dtype = 'bytes' 119 | elif isinstance(val[0], int): 120 | dtype = 'int64' 121 | else: 122 | raise TypeError('`%s` has an invalid type: %r' % (key, type(val[0]))) 123 | else: 124 | if key == 'answers': 125 | dtype = 'bytes' 126 | elif key in ['answer_starts', 'word_answer_starts', 'word_answer_ends']: 127 | dtype = 'int64' 128 | else: 129 | raise ValueError(key) 130 | 131 | if dtype == 'bytes': 132 | # Transform unicode into bytes if necessary. 133 | val = [each.encode('utf-8') for each in val] 134 | feature[key] = tf.train.Feature(bytes_list=tf.train.BytesList(value=val)) 135 | elif dtype == 'int64': 136 | feature[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=val)) 137 | else: 138 | raise TypeError('`%s` has an invalid type: %r' % (key, type(val[0]))) 139 | return tf.train.Example(features=tf.train.Features(feature=feature)) 140 | 141 | 142 | def dump(examples, metadata, data_type): 143 | """Dumps examples as TFRecord files. 144 | 145 | Args: 146 | examples: a `list` of `dict`, where each dict is indexed example. 147 | metadata: `dict`, metadata corresponding to the examples. 148 | data_type: `str`, representing the data type of the examples (e.g. `train`). 
149 | """ 150 | out_dir = os.path.join(FLAGS.to_dir, data_type) 151 | metadata_path = os.path.join(out_dir, 'metadata.json') 152 | data_dir = os.path.join(out_dir, 'data') 153 | tf.gfile.MakeDirs(out_dir) 154 | tf.gfile.MakeDirs(data_dir) 155 | 156 | with tf.gfile.GFile(metadata_path, 'w') as fp: 157 | json.dump(metadata, fp) 158 | 159 | # Dump stuff 160 | writer = None 161 | counter = 0 162 | num_shards = 0 163 | for example in tqdm(examples): 164 | if writer is None: 165 | path = os.path.join(data_dir, 166 | 'squad_data_{}'.format(str(num_shards).zfill(4))) 167 | writer = tf.python_io.TFRecordWriter(path) 168 | tf_example = get_tf_example(example) 169 | writer.write(tf_example.SerializeToString()) 170 | counter += 1 171 | if counter == FLAGS.max_shard_size: 172 | counter = 0 173 | writer.close() 174 | writer = None 175 | num_shards += 1 176 | if writer is not None: 177 | writer.close() 178 | 179 | 180 | def prepro(data_type, indexer=None): 181 | """Preprocesses the given data type.""" 182 | squad_path = os.path.join(FLAGS.from_dir, '%s-v1.1.json' % data_type) 183 | tf.logging.info('Loading %s' % squad_path) 184 | examples = squad_prepro.get_examples(squad_path) 185 | 186 | if FLAGS.draft: 187 | examples = random.sample(examples, 100) 188 | 189 | if FLAGS.split: 190 | tf.logging.info('Splitting each example') 191 | tf.logging.info('Before splitting: %d %s examples' % (len(examples), 192 | data_type)) 193 | examples = list( 194 | itertools.chain(* [ 195 | squad_prepro.split( 196 | e, positive_augment_factor=FLAGS.positive_augment_factor) 197 | for e in tqdm(examples) 198 | ])) 199 | tf.logging.info('After splitting: %d %s examples' % (len(examples), 200 | data_type)) 201 | 202 | if FLAGS.answerable: 203 | tf.logging.info('Using only answerable examples.') 204 | examples = [ 205 | example for example in examples 206 | if example['answers'][0] != squad_prepro.NO_ANSWER 207 | ] 208 | 209 | if FLAGS.shuffle: 210 | tf.logging.info('Shuffling examples') 211 | 
random.shuffle(examples) 212 | 213 | if indexer is None: 214 | tf.logging.info('Creating indexer') 215 | indexer = squad_prepro.SquadIndexer(FLAGS.glove_dir, draft=FLAGS.draft) 216 | 217 | tf.logging.info('Indexing %s data' % data_type) 218 | indexed_examples, metadata = indexer.fit( 219 | examples, 220 | min_word_count=FLAGS.word_count_th, 221 | min_char_count=FLAGS.char_count_th, 222 | num_chars_per_word=FLAGS.num_chars_per_word) 223 | else: 224 | indexed_examples, metadata = indexer.prepro_eval( 225 | examples, num_chars_per_word=FLAGS.num_chars_per_word) 226 | tf.gfile.MakeDirs(FLAGS.to_dir) 227 | indexer_save_path = os.path.join(FLAGS.to_dir, 'indexer.json') 228 | tf.logging.info('Saving indexer') 229 | indexer.save(indexer_save_path) 230 | 231 | if FLAGS.filter: 232 | tf.logging.info('Filtering examples') 233 | tf.logging.info('Before filtering: %d %s examples' % (len(indexed_examples), 234 | data_type)) 235 | indexed_examples = [ 236 | e for e in indexed_examples 237 | if len(e['context_words']) <= FLAGS.max_context_size and 238 | len(e['question_words']) <= FLAGS.max_ques_size 239 | ] 240 | tf.logging.info('After filtering: %d %s examples' % (len(indexed_examples), 241 | data_type)) 242 | tf.logging.info('Has answers: %d %s examples' % 243 | (sum(1 for e in indexed_examples 244 | if e['answer_starts'][0] >= 0), data_type)) 245 | metadata['max_context_size'] = max( 246 | len(e['context_words']) for e in indexed_examples) 247 | metadata['max_ques_size'] = max( 248 | len(e['question_words']) for e in indexed_examples) 249 | 250 | if FLAGS.sort: 251 | tf.logging.info('Sorting examples') 252 | indexed_examples = sorted( 253 | indexed_examples, key=lambda e: len(e['context_words'])) 254 | 255 | tf.logging.info('Dumping %s examples' % data_type) 256 | dump(indexed_examples, metadata, data_type) 257 | 258 | return indexer, metadata 259 | 260 | 261 | def main(argv): 262 | del argv 263 | assert not tf.gfile.Exists(FLAGS.to_dir), '%s already exists.' 
% FLAGS.to_dir 264 | 265 | if FLAGS.indexer_dir: 266 | indexer_path = os.path.join(FLAGS.indexer_dir, 'indexer.json') 267 | tf.logging.info('Loading indexer from %s' % indexer_path) 268 | indexer = squad_prepro.SquadIndexer(FLAGS.glove_dir, draft=FLAGS.draft) 269 | indexer.load(indexer_path) 270 | else: 271 | indexer, train_metadata = prepro('train') 272 | _, dev_metadata = prepro('dev', indexer=indexer) 273 | 274 | if FLAGS.indexer_dir: 275 | exp_metadata = dev_metadata 276 | else: 277 | exp_metadata = { 278 | 'max_context_size': 279 | max(train_metadata['max_context_size'], 280 | dev_metadata['max_context_size']), 281 | 'max_ques_size': 282 | max(train_metadata['max_ques_size'], dev_metadata['max_ques_size']), 283 | 'num_chars_per_word': 284 | max(train_metadata['num_chars_per_word'], 285 | dev_metadata['num_chars_per_word']) 286 | } 287 | 288 | tf.logging.info('Dumping experiment metadata') 289 | exp_metadata_path = os.path.join(FLAGS.to_dir, 'metadata.json') 290 | with tf.gfile.GFile(exp_metadata_path, 'w') as fp: 291 | json.dump(exp_metadata, fp) 292 | 293 | 294 | if __name__ == '__main__': 295 | tf.app.run(main) 296 | -------------------------------------------------------------------------------- /tf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """TensorFlow utilities for extractive question answering models. 15 | """ 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | VERY_LARGE_NEGATIVE_VALUE = -1e12 23 | VERY_SMALL_POSITIVE_VALUE = 1e-12 24 | 25 | 26 | def bi_rnn(hidden_size, 27 | inputs_list, 28 | sequence_length_list=None, 29 | scope=None, 30 | dropout_rate=0.0, 31 | training=False, 32 | stack=False, 33 | cells=None, 34 | postprocess='concat', 35 | return_outputs=True, 36 | out_dim=None, 37 | reuse=False): 38 | """Bidirectional RNN with `BasicLSTMCell`. 39 | 40 | Args: 41 | hidden_size: `int` value, the hidden state size of the LSTM. 42 | inputs_list: A list of `inputs` tensors, where each `inputs` is a 43 | single sequence tensor with shape [batch_size, seq_len, hidden_size]. 44 | Can be a single element instead of a list. 45 | sequence_length_list: A list of `sequence_length` tensors. 46 | The size of the list should be equal to that of `inputs_list`. 47 | Can be a single element instead of a list. 48 | scope: `str` value, variable scope for this function. 49 | dropout_rate: `float` value, dropout rate of LSTM, applied at the inputs. 50 | training: `bool` value, whether the current run is training. 51 | stack: `bool` value, whether to stack instead of simultaneous bi-LSTM. 52 | cells: two `RNNCell` instances. If provided, `hidden_size` is ignored. 53 | postprocess: `str` value: `raw`, `concat`, `add`, `max`, or `linear`, or a 54 | callable `postprocess(return_fw, return_bw)`. Postprocessing on forward and backward outputs of LSTM. 55 | return_outputs: `bool` value, whether to return sequence outputs. 56 | Otherwise, return the last state. 57 | out_dim: `int` value. If `postprocess` is `linear`, then this indicates 58 | the output dim of the linearity. 59 | reuse: `bool` value, whether to reuse variables. 60 | Returns: 61 | A list `return_list` where each element corresponds to each element of 62 | `input_list`.
If the `input_list` is a tensor, also returns a tensor. 63 | Raises: 64 | ValueError: If argument `postprocess` is an invalid value. 65 | """ 66 | if not isinstance(inputs_list, list): 67 | inputs_list = [inputs_list] 68 | if sequence_length_list is None: 69 | sequence_length_list = [None] * len(inputs_list) 70 | elif not isinstance(sequence_length_list, list): 71 | sequence_length_list = [sequence_length_list] 72 | assert len(inputs_list) == len( 73 | sequence_length_list 74 | ), '`inputs_list` and `sequence_length_list` must have same lengths.' 75 | with tf.variable_scope(scope or 'bi_rnn', reuse=reuse) as vs: 76 | if cells is not None: 77 | cell_fw = cells[0] 78 | cell_bw = cells[1] 79 | else: 80 | cell_fw = tf.contrib.rnn.BasicLSTMCell(hidden_size, reuse=reuse) 81 | cell_bw = tf.contrib.rnn.BasicLSTMCell(hidden_size, reuse=reuse) 82 | return_list = [] 83 | for inputs, sequence_length in zip(inputs_list, sequence_length_list): 84 | if return_list: 85 | vs.reuse_variables() 86 | if dropout_rate > 0.0: 87 | inputs = tf.layers.dropout(inputs, rate=dropout_rate, training=training) 88 | if stack: 89 | o_bw, state_bw = tf.nn.dynamic_rnn( 90 | cell_bw, 91 | tf.reverse_sequence(inputs, sequence_length, seq_dim=1), 92 | sequence_length=sequence_length, 93 | dtype='float', 94 | scope='rnn_bw') 95 | o_bw = tf.reverse_sequence(o_bw, sequence_length, seq_dim=1) 96 | if dropout_rate > 0.0: 97 | o_bw = tf.layers.dropout(o_bw, rate=dropout_rate, training=training) 98 | o_fw, state_fw = tf.nn.dynamic_rnn( 99 | cell_fw, 100 | o_bw, 101 | sequence_length=sequence_length, 102 | dtype='float', 103 | scope='rnn_fw') 104 | else: 105 | (o_fw, o_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn( 106 | cell_fw, 107 | cell_bw, 108 | inputs, 109 | sequence_length=sequence_length, 110 | dtype='float') 111 | return_fw = o_fw if return_outputs else state_fw[-1] 112 | return_bw = o_bw if return_outputs else state_bw[-1] 113 | if postprocess == 'raw': 114 | return_ = return_fw, 
return_bw 115 | elif postprocess == 'concat': 116 | return_ = tf.concat([return_fw, return_bw], 2 if return_outputs else 1) 117 | elif postprocess == 'add': 118 | return_ = return_fw + return_bw 119 | elif postprocess == 'max': 120 | return_ = tf.maximum(return_fw, return_bw) 121 | elif postprocess == 'linear': 122 | if out_dim is None: 123 | out_dim = 2 * hidden_size 124 | return_ = tf.concat([return_fw, return_bw], 2 if return_outputs else 1) 125 | return_ = tf.layers.dense(return_, out_dim) 126 | else: 127 | return_ = postprocess(return_fw, return_bw) 128 | return_list.append(return_) 129 | if len(return_list) == 1: 130 | return return_list[0] 131 | return return_list 132 | 133 | 134 | def exp_mask(logits, mask, mask_is_length=True): 135 | """Exponential mask for logits. 136 | 137 | Logits cannot be masked with 0 (i.e. multiplying boolean mask) 138 | because exponentiating 0 becomes 1. `exp_mask` adds a very large negative value 139 | to the `False` portion of `mask` so that the portion is effectively ignored 140 | when exponentiated, e.g. softmaxed. 141 | 142 | Args: 143 | logits: Arbitrary-rank logits tensor to be masked. 144 | mask: `boolean` type mask tensor. 145 | Could be same shape as logits (`mask_is_length=False`) 146 | or could be length tensor of the logits (`mask_is_length=True`). 147 | mask_is_length: `bool` value, whether `mask` is a length mask (`True`) or a boolean mask (`False`). 148 | Returns: 149 | Masked logits with the same shape as `logits`. 150 | """ 151 | if mask_is_length: 152 | mask = tf.sequence_mask(mask, maxlen=tf.shape(logits)[-1]) 153 | return logits + (1.0 - tf.cast(mask, 'float')) * VERY_LARGE_NEGATIVE_VALUE 154 | 155 | 156 | def self_att(tensor, 157 | tensor_val=None, 158 | mask=None, 159 | mask_is_length=True, 160 | logit_fn=None, 161 | scale_dot=False, 162 | normalizer=tf.nn.softmax, 163 | tensors=None, 164 | scope=None, 165 | reuse=False): 166 | """Performs self attention.
167 | 168 | Performs self attention to obtain a single vector representation for a sequence 169 | of vectors. 170 | 171 | Args: 172 | tensor: [batch_size, sequence_length, hidden_size]-shaped tensor 173 | tensor_val: If specified, attention is applied on `tensor_val`, i.e. 174 | `tensor` is the key. 175 | mask: Length mask (shape of [batch_size]) or 176 | boolean mask ([batch_size, sequence_length]) 177 | mask_is_length: `True` if `mask` is a length mask, `False` if it is a boolean 178 | mask 179 | logit_fn: `logit_fn(tensor)` to obtain logits. 180 | scale_dot: `bool`, whether to scale the dot product by dividing by 181 | sqrt(hidden_size). 182 | normalizer: function to normalize logits. 183 | tensors: `dict`. If specified, add useful tensors (e.g. attention weights) 184 | to the `dict` with their (scope) names. 185 | scope: `string` for defining variable scope 186 | reuse: Reuse if `True`. 187 | Returns: 188 | [batch_size, hidden_size]-shaped tensor. 189 | """ 190 | assert len(tensor.get_shape() 191 | ) == 3, 'The rank of `tensor` must be 3 but got {}.'.format( 192 | len(tensor.get_shape())) 193 | with tf.variable_scope(scope or 'self_att', reuse=reuse): 194 | hidden_size = tensor.get_shape().as_list()[-1] 195 | if logit_fn is None: 196 | logits = tf.layers.dense(tensor, hidden_size, activation=tf.tanh) 197 | logits = tf.squeeze(tf.layers.dense(logits, 1), 2) 198 | else: 199 | logits = logit_fn(tensor) 200 | if scale_dot: 201 | logits /= tf.sqrt(tf.cast(hidden_size, 'float')) 202 | if mask is not None: 203 | logits = exp_mask(logits, mask, mask_is_length=mask_is_length) 204 | weights = normalizer(logits) 205 | if tensors is not None: 206 | weights = tf.identity(weights, name='attention') 207 | tensors[weights.op.name] = weights 208 | out = tf.reduce_sum( 209 | tf.expand_dims(weights, -1) * (tensor 210 | if tensor_val is None else tensor_val), 211 | 1) 212 | return out 213 | 214 | 215 | def highway(inputs, 216 | outputs=None, 217 | dropout_rate=0.0, 218 | batch_norm=False, 219 |
training=False, 220 | scope=None, 221 | reuse=False): 222 | """Single-layer highway networks (https://arxiv.org/abs/1505.00387). 223 | 224 | Args: 225 | inputs: Arbitrary-rank `float` tensor, where the first dim is batch size 226 | and the last dim is where the highway network is applied. 227 | outputs: If provided, will replace the perceptron layer (i.e. gating only). 228 | dropout_rate: `float` value, input dropout rate. 229 | batch_norm: `bool` value, whether to use batch normalization. 230 | training: `bool` value, whether the current run is training. 231 | scope: `str` value, variable scope. Defaults to `highway`. 232 | reuse: `bool` value, whether to reuse variables. 233 | Returns: 234 | The output of the highway network, same shape as `inputs`. 235 | """ 236 | with tf.variable_scope(scope or 'highway', reuse=reuse): 237 | if dropout_rate > 0.0: 238 | inputs = tf.layers.dropout(inputs, rate=dropout_rate, training=training) 239 | dim = inputs.get_shape()[-1] 240 | if outputs is None: 241 | outputs = tf.layers.dense(inputs, dim, name='outputs') 242 | if batch_norm: 243 | outputs = tf.layers.batch_normalization(outputs, training=training) 244 | outputs = tf.nn.relu(outputs) 245 | gate = tf.layers.dense(inputs, dim, activation=tf.nn.sigmoid, name='gate') 246 | return gate * inputs + (1 - gate) * outputs 247 | 248 | 249 | def highway_net(inputs, 250 | num_layers, 251 | dropout_rate=0.0, 252 | batch_norm=False, 253 | training=False, 254 | scope=None, 255 | reuse=False): 256 | """Multi-layer highway networks (https://arxiv.org/abs/1505.00387). 257 | 258 | Args: 259 | inputs: `float` input tensor to the highway networks. 260 | num_layers: `int` value, indicating the number of highway layers to build. 261 | dropout_rate: `float` value for the input dropout rate. 262 | batch_norm: `bool` value, indicating whether to use batch normalization 263 | or not. 264 | training: `bool` value, indicating whether the current run is training 265 | or not (e.g. eval or inference).
266 | scope: `str` value, variable scope. Default is `highway_net`. 267 | reuse: `bool` value, indicating whether the variables in this function 268 | are reused. 269 | Returns: 270 | The output of the highway networks, which is the same shape as `inputs`. 271 | """ 272 | with tf.variable_scope(scope or 'highway_net', reuse=reuse): 273 | outputs = inputs 274 | for i in range(num_layers): 275 | outputs = highway( 276 | outputs, 277 | dropout_rate=dropout_rate, 278 | batch_norm=batch_norm, 279 | training=training, 280 | scope='layer_{}'.format(i)) 281 | return outputs 282 | 283 | 284 | def char_cnn(inputs, 285 | out_dim, 286 | kernel_size, 287 | dropout_rate=0.0, 288 | name=None, 289 | reuse=False, 290 | activation=None, 291 | batch_norm=False, 292 | training=False): 293 | """Character-level CNN. 294 | 295 | Args: 296 | inputs: Input tensor of shape [batch_size, num_words, num_chars, in_dim]. 297 | out_dim: `int` value, output dimension of CNN. 298 | kernel_size: `int` value, the width of the kernel for CNN. 299 | dropout_rate: `float` value, input dropout rate. 300 | name: `str` value, variable scope for variables in this function. 301 | reuse: `bool` value, indicating whether to reuse CNN variables. 302 | activation: function for activation. Default is `tf.nn.relu`. 303 | batch_norm: `bool` value, whether to perform batch normalization. 304 | training: `bool` value, whether the current run is training or not. 305 | Returns: 306 | Output tensor of shape [batch_size, num_words, out_dim]. 
307 | 308 | """ 309 | with tf.variable_scope(name or 'char_cnn', reuse=reuse): 310 | batch_size = tf.shape(inputs)[0] 311 | num_words = tf.shape(inputs)[1] 312 | num_chars = tf.shape(inputs)[2] 313 | in_dim = inputs.get_shape().as_list()[3] 314 | if dropout_rate > 0.0: 315 | inputs = tf.layers.dropout(inputs, rate=dropout_rate, training=training) 316 | outputs = tf.reshape( 317 | tf.layers.conv1d( 318 | tf.reshape(inputs, [batch_size * num_words, num_chars, in_dim]), 319 | out_dim, 320 | kernel_size, 321 | name=name), [batch_size, num_words, -1, out_dim]) 322 | if batch_norm: 323 | outputs = tf.layers.batch_normalization(outputs, training=training) 324 | if activation is None: 325 | activation = tf.nn.relu 326 | outputs = activation(outputs) 327 | outputs = tf.reduce_max(outputs, 2) 328 | return outputs 329 | 330 | 331 | def att2d(a, 332 | b, 333 | a_val=None, 334 | mask=None, 335 | b_mask=None, 336 | a_null=None, 337 | logit_fn='bahdanau', 338 | return_weights=False, 339 | normalizer=tf.nn.softmax, 340 | transpose=False, 341 | scale_logits=False, 342 | reduce_fn=None, 343 | tensors=None, 344 | scope=None): 345 | """2D-attention on a pair of sequences. 346 | 347 | Obtain the most similar (most attended) vector among the vectors in `a` to 348 | each vector in `b`. That is, `b` can be considered as the key and `a` as the value. 349 | Or in other words, `b` is the attender and `a` is the attendee. 350 | 351 | Args: 352 | a: [batch_size, a_len, hidden_size] shaped tensor. 353 | b: [batch_size, b_len, hidden_size] shaped tensor. 354 | a_val: If specified, attention is performed on `a_val` instead of `a`. 355 | mask: length mask tensor, boolean mask tensor for `a`, or 2d mask of size 356 | [b_len, a_len]. 357 | b_mask: length mask or boolean mask tensor for `b`. 358 | a_null: If specified, this becomes a possible vector to be attended. 359 | logit_fn: `logit_fn(a, b)` computes logits for attention. By default, 360 | uses the attention function by Bahdanau et al. (2014).
Can be string `dot`, 361 | in which case the dot product is memory-efficiently computed via 362 | `tf.matmul`. 363 | return_weights: `bool` value, whether to return weights instead of 364 | the attended vector. 365 | normalizer: function that normalizes the weights. 366 | transpose: `bool`, whether to transpose the normalizer axis. 367 | scale_logits: `bool`, whether to scale the logits by 368 | sqrt(hidden_size), as shown in https://arxiv.org/abs/1706.03762. 369 | reduce_fn: python fn, which reduces the logit matrix along `b`'s axis 370 | if specified. 371 | tensors: `dict`. If specified, add useful tensors (e.g. attention weights) 372 | to the `dict` with their (scope) names. 373 | scope: `str` value, indicating the variable scope of this function. 374 | Returns: 375 | [batch_size, b_len, hidden_size] shaped tensor, where each vector 376 | represents the most similar vector among `a` for each vector in `b`. 377 | If `return_weights` is `True`, return tensor is 378 | [batch_size, b_len, a_len] shape. If `reduce_fn` is specified, the return 379 | tensor shape is [batch_size, 1, a_len]. 380 | """ 381 | with tf.variable_scope(scope or 'att_2d'): 382 | batch_size = tf.shape(a)[0] 383 | hidden_size = a.get_shape().as_list()[-1] 384 | if a_null is not None: 385 | a_null = tf.tile( 386 | tf.expand_dims(tf.expand_dims(a_null, 0), 0), [batch_size, 1, 1]) 387 | a = tf.concat([a_null, a], 1) 388 | if mask is not None: 389 | # TODO(seominjoon) : To support other shapes of masks. 390 | assert len(mask.get_shape()) == 1 391 | mask += 1 392 | 393 | # Memory-efficient operation for dot product logit function.
394 | if logit_fn == 'wdot': 395 | weights = tf.get_variable('weights', shape=[hidden_size], dtype='float') 396 | bias = tf.get_variable('bias', shape=[], dtype='float') 397 | logits = tf.matmul(b * weights, a, transpose_b=True) + bias 398 | elif logit_fn == 'dot': 399 | logits = tf.matmul(b, a, transpose_b=True) 400 | elif logit_fn == 'bilinear': 401 | logits = tf.matmul( 402 | tf.layers.dense(b, hidden_size, use_bias=False), a, transpose_b=True) 403 | elif logit_fn == 'l2': 404 | ba = tf.matmul(b, a, transpose_b=True) 405 | aa = tf.expand_dims(tf.reduce_sum(a * a, 2), 1) 406 | bb = tf.expand_dims(tf.reduce_sum(b * b, 2), 2) 407 | logits = 2 * ba - aa - bb 408 | else: 409 | # WARNING : This is memory-intensive! 410 | aa = tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[1], 1, 1]) 411 | bb = tf.tile(tf.expand_dims(b, 2), [1, 1, tf.shape(a)[1], 1]) 412 | if logit_fn == 'bahdanau': 413 | logits = tf.layers.dense( 414 | tf.concat([aa, bb], 3), hidden_size, activation=tf.tanh) 415 | logits = tf.squeeze(tf.layers.dense(logits, 1), 3) 416 | else: 417 | logits = logit_fn(aa, bb) # [batch_size, b_len, a_len]-shaped tensor.
418 | if scale_logits: 419 | logits /= tf.sqrt(tf.cast(hidden_size, 'float')) 420 | if mask is not None: 421 | if len(mask.get_shape()) == 1: 422 | mask = tf.sequence_mask(mask, tf.shape(a)[1]) 423 | if len(mask.get_shape()) == 2: 424 | mask = tf.expand_dims(mask, 1) 425 | if b_mask is not None: 426 | if len(b_mask.get_shape()) == 1: 427 | b_mask = tf.sequence_mask(b_mask, tf.shape(b)[1]) 428 | mask &= tf.expand_dims(b_mask, -1) 429 | logits = exp_mask(logits, mask, mask_is_length=False) 430 | if reduce_fn: 431 | logits = tf.expand_dims(reduce_fn(logits, 1), 1) 432 | if transpose: 433 | logits = tf.transpose(logits, [0, 2, 1]) 434 | p = logits if normalizer is None else normalizer(logits) 435 | if transpose: 436 | p = tf.transpose(p, [0, 2, 1]) 437 | logits = tf.transpose(logits, [0, 2, 1]) 438 | if tensors is not None: 439 | p = tf.identity(p, name='attention') 440 | tensors[p.op.name] = p 441 | 442 | if return_weights: 443 | return p 444 | 445 | # Memory-efficient application of attention weights. 446 | # [batch_size, b_len, hidden_size] 447 | a_b = tf.matmul(p, a if a_val is None else a_val) 448 | return a_b 449 | 450 | 451 | def mlp(a, 452 | hidden_sizes, 453 | activation=tf.nn.relu, 454 | activate_last=True, 455 | dropout_rate=0.0, 456 | training=False, 457 | scope=None): 458 | """Multi-layer perceptron. 459 | 460 | Args: 461 | a: input tensor. 462 | hidden_sizes: `list` of `int`, hidden state sizes for perceptron layers. 463 | activation: function handler for activation. 464 | activate_last: `bool`, whether to activate the last layer or not. 465 | dropout_rate: `float`, dropout rate at the input of each layer. 466 | training: `bool`, whether the current run is training or not. 467 | scope: `str`, variable scope of all tensors and weights in this function. 468 | Returns: 469 | Tensor with same shape as `a` except for the last dim, whose size is equal 470 | to `hidden_sizes[-1]`. 
471 | """ 472 | with tf.variable_scope(scope or 'mlp'): 473 | for idx, hidden_size in enumerate(hidden_sizes): 474 | with tf.variable_scope('layer_%d' % idx): 475 | if dropout_rate > 0.0: 476 | a = tf.layers.dropout(a, rate=dropout_rate, training=training) 477 | activate = idx < len(hidden_sizes) - 1 or activate_last 478 | a = tf.layers.dense( 479 | a, hidden_size, activation=activation if activate else None) 480 | 481 | return a 482 | 483 | 484 | def split_concat(a, b, num, axis=None): 485 | if axis is None: 486 | axis = len(a.get_shape()) - 1 487 | a_list = tf.split(a, num, axis=axis) 488 | b_list = tf.split(b, num, axis=axis) 489 | t_list = tuple( 490 | tf.concat([aa, bb], axis=axis) for aa, bb in zip(a_list, b_list)) 491 | return t_list 492 | 493 | 494 | def concat_seq_and_tok(sequence, token, position, sequence_length=None): 495 | """Concatenates a token to the given sequence, either at the start or end. 496 | 497 | The token's dimension should match the last dimension of the sequence. 498 | 499 | Args: 500 | sequence: [batch_size, sequence_length] shaped tensor or 501 | [batch_size, sequence_length, hidden_size] shaped tensor. 502 | token: scalar tensor or [hidden_size] shaped tensor. 503 | position: `str`, either 'start' or 'end'. 504 | sequence_length: [batch_size] shaped `int64` tensor. Must be specified 505 | if `position` is 'end'. 506 | Returns: 507 | [batch_size, sequence_length+1] or 508 | [batch_size, sequence_length+1, hidden_size] shaped tensor. 509 | Raises: 510 | ValueError: If `position` is not 'start' or 'end'. 
511 | """ 512 | batch_size = tf.shape(sequence)[0] 513 | if len(sequence.get_shape()) == 3: 514 | token = tf.tile( 515 | tf.expand_dims(tf.expand_dims(token, 0), 0), [batch_size, 1, 1]) 516 | elif len(sequence.get_shape()) == 2: 517 | token = tf.tile(tf.reshape(token, [1, 1]), [batch_size, 1]) 518 | 519 | if position == 'start': 520 | sequence = tf.concat([token, sequence], 1) 521 | elif position == 'end': 522 | assert sequence_length is not None 523 | sequence = tf.reverse_sequence(sequence, sequence_length, seq_axis=1) 524 | sequence = tf.concat([token, sequence], 1) 525 | sequence = tf.reverse_sequence(sequence, sequence_length + 1, seq_axis=1) 526 | else: 527 | raise ValueError('%r is an invalid argument for `position`.' % position) 528 | return sequence 529 | 530 | 531 | class ExternalInputWrapper(tf.contrib.rnn.RNNCell): 532 | """Wrapper for `RNNCell`, concatenates an external tensor to the input.""" 533 | 534 | def __init__(self, cell, external_input, reuse=False): 535 | super(ExternalInputWrapper, self).__init__(_reuse=reuse) 536 | self._cell = cell 537 | self._external_input = external_input 538 | 539 | @property 540 | def state_size(self): 541 | return self._cell.state_size 542 | 543 | @property 544 | def output_size(self): 545 | return self._cell.output_size 546 | 547 | def zero_state(self, batch_size, dtype): 548 | with tf.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]): 549 | return self._cell.zero_state(batch_size, dtype) 550 | 551 | def call(self, inputs, state): 552 | inputs = tf.concat([self._external_input, inputs], 1) 553 | return self._cell(inputs, state) 554 | 555 | 556 | class DeembedWrapper(tf.contrib.seq2seq.Helper): 557 | """Wrapper for `Helper`, applies a given deembed function to the output. 558 | 559 | The deembed function has a single input and a single output. 560 | It is applied on the output of the previous RNN before feeding it into the 561 | next time step.
562 | """ 563 | 564 | def __init__(self, helper, deembed_fn): 565 | self._helper = helper 566 | self._deembed_fn = deembed_fn 567 | 568 | @property 569 | def batch_size(self): 570 | return self._helper.batch_size 571 | 572 | def initialize(self, name=None): 573 | return self._helper.initialize(name=name) 574 | 575 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 576 | outputs = self._deembed_fn(outputs) 577 | return self._helper.next_inputs(time, outputs, state, sample_ids, name=name) 578 | 579 | def sample(self, time, outputs, state, name=None): 580 | outputs = self._deembed_fn(outputs) 581 | return self._helper.sample(time, outputs, state, name=name) 582 | -------------------------------------------------------------------------------- /train_and_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Experiments for kernel vs. feature-map models on SQuAD. 15 | 16 | The `feature` model does not allow any interaction between question and context 17 | except at the end, where the dot product (or L1/L2 distance) is used to get the 18 | answer. 19 | The `kernel` model allows any interaction between question and context 20 | (e.g. cross attention). 21 | This script establishes baselines for both the feature and kernel models.
22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | from collections import defaultdict 29 | import json 30 | import os 31 | 32 | import tensorflow as tf 33 | from tqdm import tqdm 34 | import tensorflow.contrib.learn as learn 35 | 36 | # This is required for importing google specific flags: 37 | # `output_dir`, `schedule` 38 | # (`learn` above is not sufficient). Will need to add these flags when 39 | # removing this import for open-sourcing. 40 | from tensorflow.contrib.learn import learn_runner 41 | 42 | import squad_data 43 | from common_model import get_loss 44 | from common_model import get_pred_ops 45 | from common_model import get_train_op 46 | from feature_model import feature_model 47 | from kernel_model import kernel_model 48 | 49 | tf.flags.DEFINE_integer('emb_size', 200, 'embedding size') 50 | tf.flags.DEFINE_integer('glove_size', 200, 'GloVe size') 51 | tf.flags.DEFINE_integer('hidden_size', 200, 'hidden state size') 52 | tf.flags.DEFINE_integer('num_train_steps', 20000, 'num train steps') 53 | tf.flags.DEFINE_integer('num_eval_steps', 500, 'num eval steps') 54 | tf.flags.DEFINE_boolean('draft', False, 'draft?') 55 | tf.flags.DEFINE_integer('batch_size', 64, 'batch size') 56 | tf.flags.DEFINE_float('dropout_rate', 0.2, 57 | 'dropout rate, applied to the input of LSTMs.') 58 | tf.flags.DEFINE_string( 59 | 'root_data_dir', 60 | '/cns/ok-d/home/neon-core/ker2vec/squad/prepro/sort_filter', 61 | 'root data dir') 62 | tf.flags.DEFINE_integer('save_checkpoints_steps', 500, '') 63 | tf.flags.DEFINE_integer('num_eval_delay_secs', 1, 'eval delay secs') 64 | tf.flags.DEFINE_boolean('shuffle_examples', False, 'Use shuffle example queue?') 65 | tf.flags.DEFINE_boolean('shuffle_files', True, 'Use shuffle file queue?') 66 | tf.flags.DEFINE_string('model', 'feature', '`feature` or `kernel`.') 67 | tf.flags.DEFINE_boolean('oom_test', False, 'Performs out-of-memory test') 68 
| tf.flags.DEFINE_string( 69 | 'dist', 'dot', 'Distance function for feature model. `dot`, `l1` or `l2`.') 70 | tf.flags.DEFINE_float('learning_rate', 0.001, 71 | '(Initial) learning rate for optimizer') 72 | tf.flags.DEFINE_boolean( 73 | 'infer', False, 74 | 'If `True`, obtains and saves predictions for the test dataset ' 75 | 'at `answers_path`.') 76 | tf.flags.DEFINE_string('answers_path', '', 77 | 'The path for saving predictions on test dataset. ' 78 | 'If not specified, saves in `restore_dir` directory.') 79 | tf.flags.DEFINE_float('clip_norm', 0, 'Clip norm threshold, 0 for no clip.') 80 | tf.flags.DEFINE_integer( 81 | 'restore_step', 0, 82 | 'The global step for which the model is restored in the beginning. ' 83 | '`0` for the most recent save file.') 84 | tf.flags.DEFINE_float( 85 | 'restore_decay', 1.0, 86 | 'The decay rate for exponential moving average of variables that ' 87 | 'will be restored upon eval or infer. ' 88 | '`1.0` for restoring variables without decay.') 89 | tf.flags.DEFINE_string( 90 | 'ema_decays', '', 91 | 'List of exponential moving average (EMA) decay rates (float) ' 92 | 'to track for variables during training. Values are separated by commas.') 93 | tf.flags.DEFINE_string( 94 | 'restore_dir', '', 95 | 'Directory from which variables are restored. If not specified, `output_dir` ' 96 | 'will be used instead. 
For inference mode, this needs to be specified.') 97 | tf.flags.DEFINE_string('model_id', 'm00', 'Model id.') 98 | tf.flags.DEFINE_string('glove_dir', '/cns/ok-d/home/neon-core/ker2vec/glove', 99 | 'GloVe dir.') 100 | tf.flags.DEFINE_boolean('merge', False, 'If `True`, merges answers from the ' 101 | 'same paragraph that were split in the preprocessing step.') 102 | tf.flags.DEFINE_integer('queue_capacity', 5000, 'Input queue capacity.') 103 | tf.flags.DEFINE_integer('min_after_dequeue', 1000, 'Minimum number of examples ' 104 | 'after queue dequeue.') 105 | tf.flags.DEFINE_integer('max_answer_size', 7, 'Max number of answer words.') 106 | tf.flags.DEFINE_string('restore_scopes', '', 'Restore scopes, separated by commas.') 107 | tf.flags.DEFINE_boolean('reg_gen', True, 'Whether to regularize training ' 108 | 'with question generation (reconstruction) loss.') 109 | tf.flags.DEFINE_float('reg_cf', 3.0, 'Regularization initial coefficient.') 110 | tf.flags.DEFINE_float('reg_half_life', 6000, 'Regularization decay half life. ' 111 | 'Set it to a very high value to effectively disable decay.') 112 | tf.flags.DEFINE_integer('max_gen_length', 32, 'During inference, maximum ' 113 | 'length of the generated question.') 114 | 115 | # Below are added for the third-party release. 116 | tf.flags.DEFINE_string('schedule', 'train_and_evaluate', 117 | 'Schedule for learn_runner.') 118 | tf.flags.DEFINE_string('output_dir', '/tmp/squad_ckpts', 119 | 'Output directory for saving model.') 120 | 121 | FLAGS = tf.flags.FLAGS 122 | tf.logging.set_verbosity(tf.logging.INFO) 123 | 124 | 125 | def model_fn(features, targets, mode, params): 126 | """Model function to be used for `Experiment` object. 127 | 128 | Should not access `flags.FLAGS`. 129 | 130 | Args: 131 | features: a dictionary of feature tensors. 132 | targets: a dictionary of target tensors. 133 | mode: `learn.ModeKeys.TRAIN`, `learn.ModeKeys.EVAL` or `learn.ModeKeys.INFER`. 134 | params: `HParams` object. 135 | Returns: 136 | `ModelFnOps` object.
137 | Raises: 138 | ValueError: raised if `params.model` is not an appropriate value. 139 | """ 140 | with tf.variable_scope('model'): 141 | if params.model == 'feature': 142 | logits_start, logits_end, tensors = feature_model( 143 | features, mode, params) 144 | elif params.model == 'kernel': 145 | logits_start, logits_end, tensors = kernel_model( 146 | features, mode, params) 147 | else: 148 | raise ValueError( 149 | '`%s` is an invalid argument for `model` parameter.' % params.model) 150 | no_answer_bias = tf.get_variable('no_answer_bias', shape=[], dtype='float') 151 | no_answer_bias = tf.tile( 152 | tf.reshape(no_answer_bias, [1, 1]), 153 | [tf.shape(features['context_words'])[0], 1]) 154 | 155 | predictions = get_pred_ops(features, params, logits_start, logits_end, 156 | no_answer_bias) 157 | predictions.update(tensors) 158 | predictions.update(features) 159 | 160 | if mode == learn.ModeKeys.INFER: 161 | eval_metric_ops, loss = None, None 162 | else: 163 | eval_metric_ops = squad_data.get_eval_metric_ops(targets, predictions) 164 | loss = get_loss(targets['word_answer_starts'], targets['word_answer_ends'], 165 | logits_start, logits_end, no_answer_bias) 166 | 167 | emas = { 168 | decay: tf.train.ExponentialMovingAverage( 169 | decay=decay, name='EMA_%f' % decay) 170 | for decay in params.ema_decays 171 | } 172 | 173 | ema_ops = [ema.apply() for ema in emas.values()] 174 | if mode == learn.ModeKeys.TRAIN: 175 | train_op = get_train_op( 176 | loss, 177 | learning_rate=params.learning_rate, 178 | clip_norm=params.clip_norm, 179 | post_ops=ema_ops) 180 | # TODO(seominjoon): Checking `Exists` is not the best way to do this.
181 | if params.restore_dir and not tf.gfile.Exists(params.output_dir): 182 | assert params.restore_scopes 183 | checkpoint_dir = params.restore_dir 184 | if params.restore_step: 185 | checkpoint_dir = os.path.join(params.restore_dir, 186 | 'model.ckpt-%d' % params.restore_step) 187 | restore_vars = [] 188 | for restore_scope in params.restore_scopes: 189 | restore_vars.extend( 190 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, restore_scope)) 191 | assignment_map = {var.op.name: var for var in restore_vars} 192 | tf.contrib.framework.init_from_checkpoint(checkpoint_dir, assignment_map) 193 | else: 194 | if params.restore_decay < 1.0: 195 | ema = emas[params.restore_decay] 196 | assign_ops = [] 197 | for var in tf.trainable_variables(): 198 | assign_op = tf.assign(var, ema.average(var)) 199 | assign_ops.append(assign_op) 200 | with tf.control_dependencies(assign_ops): 201 | for key, val in predictions.items(): 202 | predictions[key] = tf.identity(val) 203 | train_op = None 204 | 205 | return learn.ModelFnOps( 206 | mode=mode, 207 | predictions=predictions, 208 | loss=loss, 209 | train_op=train_op, 210 | eval_metric_ops=eval_metric_ops) 211 | 212 | 213 | def _experiment_fn(run_config, hparams): 214 | """Outputs `Experiment` object given `output_dir`. 215 | 216 | Args: 217 | run_config: `EstimatorConfig` object for run configuration. 218 | hparams: `HParams` object that contains hyperparameters.
219 | 220 | Returns: 221 | `Experiment` object 222 | """ 223 | estimator = learn.Estimator( 224 | model_fn=model_fn, config=run_config, params=hparams) 225 | 226 | num_train_steps = 1 if FLAGS.oom_test else FLAGS.num_train_steps 227 | num_eval_steps = 1 if FLAGS.oom_test else FLAGS.num_eval_steps 228 | 229 | return learn.Experiment( 230 | estimator=estimator, 231 | train_input_fn=_get_train_input_fn(), 232 | eval_input_fn=_get_eval_input_fn(), 233 | train_steps=num_train_steps, 234 | eval_steps=num_eval_steps, 235 | eval_delay_secs=FLAGS.num_eval_delay_secs) 236 | 237 | 238 | def _get_train_input_fn(): 239 | """Get train input function.""" 240 | train_input_fn = squad_data.get_input_fn( 241 | FLAGS.root_data_dir, 242 | FLAGS.glove_dir, 243 | 'train', 244 | FLAGS.batch_size, 245 | FLAGS.glove_size, 246 | shuffle_files=FLAGS.shuffle_files, 247 | shuffle_examples=FLAGS.shuffle_examples, 248 | queue_capacity=FLAGS.queue_capacity, 249 | min_after_dequeue=FLAGS.min_after_dequeue, 250 | oom_test=FLAGS.oom_test) 251 | return train_input_fn 252 | 253 | 254 | def _get_eval_input_fn(): 255 | """Get eval input function.""" 256 | eval_input_fn = squad_data.get_input_fn( 257 | FLAGS.root_data_dir, 258 | FLAGS.glove_dir, 259 | 'dev', 260 | FLAGS.batch_size, 261 | FLAGS.glove_size, 262 | shuffle_files=True, 263 | shuffle_examples=True, 264 | queue_capacity=FLAGS.queue_capacity, 265 | min_after_dequeue=FLAGS.min_after_dequeue, 266 | num_epochs=1, 267 | oom_test=FLAGS.oom_test) 268 | return eval_input_fn 269 | 270 | 271 | def _get_test_input_fn(): 272 | """Get test input function.""" 273 | # TODO(seominjoon) For now, test input is same as eval input (dev). 
274 | test_input_fn = squad_data.get_input_fn( 275 | FLAGS.root_data_dir, 276 | FLAGS.glove_dir, 277 | 'dev', 278 | FLAGS.batch_size, 279 | FLAGS.glove_size, 280 | shuffle_files=FLAGS.shuffle_files, 281 | shuffle_examples=FLAGS.shuffle_examples, 282 | queue_capacity=FLAGS.queue_capacity, 283 | min_after_dequeue=FLAGS.min_after_dequeue, 284 | num_epochs=1, 285 | oom_test=FLAGS.oom_test) 286 | return test_input_fn 287 | 288 | 289 | def _get_config(): 290 | """Get configuration object for `Estimator` object. 291 | 292 | For open-sourcing, `EstimatorConfig` has been replaced with `RunConfig`. 293 | Depends on `flags.FLAGS`, and should not be used outside of this main script. 294 | 295 | Returns: 296 | `RunConfig` object. 297 | """ 298 | config = learn.RunConfig( 299 | model_dir=FLAGS.restore_dir if FLAGS.infer else FLAGS.output_dir, 300 | keep_checkpoint_max=0, # Keep all checkpoints. 301 | save_checkpoints_steps=FLAGS.save_checkpoints_steps) 302 | return config 303 | 304 | 305 | def _get_hparams(): 306 | """Model-specific hyperparameters go here. 307 | 308 | All model parameters go here, since `model_fn()` should not access 309 | `flags.FLAGS`. 310 | Depends on `flags.FLAGS`, and should not be used outside of this main script. 311 | 312 | Returns: 313 | `HParams` object.
314 | """ 315 | hparams = tf.contrib.training.HParams() 316 | data_hparams = squad_data.get_params(FLAGS.root_data_dir) 317 | hparams.vocab_size = data_hparams['vocab_size'] 318 | hparams.char_vocab_size = data_hparams['char_vocab_size'] 319 | hparams.batch_size = FLAGS.batch_size 320 | hparams.hidden_size = FLAGS.hidden_size 321 | hparams.emb_size = FLAGS.emb_size 322 | hparams.dropout_rate = FLAGS.dropout_rate 323 | hparams.dist = FLAGS.dist 324 | hparams.learning_rate = FLAGS.learning_rate 325 | hparams.model = FLAGS.model 326 | hparams.restore_dir = FLAGS.restore_dir 327 | hparams.output_dir = FLAGS.output_dir 328 | hparams.clip_norm = FLAGS.clip_norm 329 | hparams.restore_decay = FLAGS.restore_decay 330 | if FLAGS.ema_decays: 331 | hparams.ema_decays = list(map(float, FLAGS.ema_decays.split(','))) 332 | else: 333 | hparams.ema_decays = [] 334 | hparams.restore_step = FLAGS.restore_step 335 | hparams.model_id = FLAGS.model_id 336 | hparams.max_answer_size = FLAGS.max_answer_size 337 | hparams.restore_scopes = FLAGS.restore_scopes.split(',') 338 | hparams.glove_size = FLAGS.glove_size 339 | 340 | # Regularization by Query Generation (reconstruction) parameters. 341 | hparams.reg_gen = FLAGS.reg_gen 342 | hparams.reg_cf = FLAGS.reg_cf 343 | hparams.reg_half_life = FLAGS.reg_half_life 344 | 345 | return hparams 346 | 347 | 348 | def train_and_eval(): 349 | """Train and eval routine.""" 350 | learn_runner.run( 351 | experiment_fn=_experiment_fn, 352 | schedule=FLAGS.schedule, 353 | run_config=_get_config(), 354 | hparams=_get_hparams()) 355 | 356 | 357 | def _set_ckpt(): 358 | # TODO(seominjoon): This is ad hoc. Need better ckpt loading during inference.
359 | if FLAGS.restore_step: 360 | path = os.path.join(FLAGS.restore_dir, 'checkpoint') 361 | with tf.gfile.GFile(path, 'w') as fp: 362 | fp.write('model_checkpoint_path: "model.ckpt-%d"\n' % FLAGS.restore_step) 363 | 364 | 365 | def infer(): 366 | """Inference routine, outputting answers to `FLAGS.answers_path`.""" 367 | _set_ckpt() 368 | estimator = learn.Estimator( 369 | model_fn=model_fn, config=_get_config(), params=_get_hparams()) 370 | predictions = estimator.predict( 371 | input_fn=_get_test_input_fn(), as_iterable=True) 372 | global_step = estimator.get_variable_value('global_step') 373 | path = FLAGS.answers_path or os.path.join(FLAGS.restore_dir, 374 | 'answers-%d.json' % global_step) 375 | answer_dict = {'no_answer_prob': {}, 'answer_prob': {}} 376 | for prediction in tqdm(predictions): 377 | id_ = prediction['id'].decode('utf-8') 378 | answer_dict[id_] = prediction['a'].decode('utf-8') 379 | answer_dict['answer_prob'][id_] = prediction['answer_prob'].tolist() 380 | answer_dict['no_answer_prob'][id_] = prediction['no_answer_prob'].tolist() 381 | if FLAGS.oom_test: 382 | break 383 | 384 | # TODO(seominjoon): use sum of logits instead of normalized prob. 385 | if FLAGS.merge: 386 | new_answer_dict = defaultdict(list) 387 | for id_, answer_prob in answer_dict['answer_prob'].items(): 388 | answer = answer_dict[id_] 389 | id_ = id_.split(' ')[0] # retrieve true id 390 | new_answer_dict[id_].append([answer_prob, answer]) 391 | answer_dict = { 392 | id_: max(each, key=lambda pair: pair[0])[1] 393 | for id_, each in new_answer_dict.items() 394 | } 395 | 396 | with tf.gfile.GFile(path, 'w') as fp: 397 | json.dump(answer_dict, fp) 398 | tf.logging.info('Dumped predictions at: %s' % path) 399 | 400 | 401 | def main(_): 402 | if FLAGS.infer: 403 | infer() 404 | else: 405 | train_and_eval() 406 | 407 | 408 | if __name__ == '__main__': 409 | tf.app.run() 410 | --------------------------------------------------------------------------------
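A note on the `merge` branch of `infer()` in `train_and_eval.py`: preprocessing may split one paragraph into several sub-examples whose ids look like `'<true_id> <split_index>'`, and merging keeps, for each true id, the answer whose (normalized) probability is highest, mirroring `id_.split(' ')[0]` and the `max(..., key=lambda pair: pair[0])` selection. A minimal pure-Python sketch of that selection logic; the function name and example data below are illustrative, not from the repo:

```python
from collections import defaultdict


def merge_answers(answers, answer_probs):
    """Pick the highest-probability answer among sub-examples sharing a true id.

    Sub-example ids are of the form '<true_id> <split_index>'; the true id is
    the part before the first space.
    """
    grouped = defaultdict(list)
    for sub_id, prob in answer_probs.items():
        true_id = sub_id.split(' ')[0]  # retrieve true id, as in infer()
        grouped[true_id].append((prob, answers[sub_id]))
    # For each true id, keep the answer from the most confident sub-example.
    return {true_id: max(pairs, key=lambda pair: pair[0])[1]
            for true_id, pairs in grouped.items()}


answers = {'q1 0': 'in 1997', 'q1 1': 'in 1998', 'q2 0': 'Paris'}
probs = {'q1 0': 0.3, 'q1 1': 0.8, 'q2 0': 0.9}
print(merge_answers(answers, probs))  # → {'q1': 'in 1998', 'q2': 'Paris'}
```

As the TODO in `infer()` notes, selecting by normalized probability is a stopgap; summing unnormalized logits across sub-examples would weight them more consistently.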