├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── common_model.py
├── download.sh
├── feature_model.py
├── kernel_model.py
├── mips-qa.pdf
├── requirements.txt
├── squad_data.py
├── squad_prepro.py
├── squad_prepro_main.py
├── tf_utils.py
└── train_and_eval.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution,
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2017 Google Inc.
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # README
2 |
3 | **This is not an official Google product**
4 |
5 | This project contains code for training and running an extractive question answering model on the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). All methods and models contained in this project are described in the [technical report](https://github.com/google/mipsqa/blob/master/mips-qa.pdf). Any extensions of this work should cite the report as:
6 |
7 | ```
8 | @misc{SeoKwiatParikh:2017,
9 | title = {Question Answering with Maximum Inner Product Search},
10 | author = {Minjoon Seo and Tom Kwiatkowski and Ankur Parikh},
11 | url = {...},
12 | }
13 | ```
14 |
15 | # 0. Requirements and data
16 | - Basic requirements: Python 2 or 3, and wget (needed on macOS; alternatively, download the files yourself by following the URLs in the `download.sh` script)
17 | - Python packages: tensorflow 1.3.0 or higher, nltk, tqdm
18 | - Data: SQuAD, GloVe, nltk tokenizer
19 |
20 | To install required packages, run:
21 | ```bash
22 | pip install -r requirements.txt
23 | ```
24 |
25 | To download data, run:
26 | ```bash
27 | chmod +x download.sh; ./download.sh
28 | ```
29 |
30 | Change the directories where the data is stored if needed, and use them in the runs below.
31 |
32 | # 1. Train and test (draft mode)
33 | If you are using default directories for the data:
34 | ```bash
35 | export SQUAD_DIR=$HOME/data/squad
36 | export GLOVE_DIR=$HOME/data/glove
37 | ```
38 |
39 | First, preprocess the training data:
40 | ```bash
41 | python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/draft/sort_filter --glove_dir $GLOVE_DIR --sort --filter --draft
42 | ```
43 | Note the `--draft` flag, which processes only a portion of the data for a fast sanity check. Make sure to remove this flag when doing real training and testing.
44 | `--filter` filters out very long examples, which can slow down training and cause memory issues.
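   | 
   | For reference, a full (non-draft) preprocessing run is the same command without `--draft`; the output directory below is just an illustrative choice:
   | ```bash
   | python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/full/sort_filter --glove_dir $GLOVE_DIR --sort --filter
   | ```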
45 |
46 | Second, preprocess the test data, which does not filter out any examples:
47 | ```bash
48 | python squad_prepro_main.py --from_dir $SQUAD_DIR --to_dir prepro/draft/sort_filter/sort --glove_dir $GLOVE_DIR --sort --draft --indexer_dir prepro/draft/sort_filter
49 | ```
50 |
51 | Third, train a model:
52 | ```bash
53 | python train_and_eval.py --output_dir /tmp/squad_ckpts --root_data_dir prepro/draft/sort_filter/ --glove_dir $GLOVE_DIR --oom_test
54 | ```
55 | In general, `--oom_test` is a flag for testing if your GPU has enough memory for the model, but it can also serve as a quick test to make sure everything runs.
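   | 
   | For a real training run, drop `--oom_test` and point `--root_data_dir` at your non-draft preprocessing output (the directory below is illustrative):
   | ```bash
   | python train_and_eval.py --output_dir /tmp/squad_ckpts --root_data_dir prepro/full/sort_filter/ --glove_dir $GLOVE_DIR
   | ```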
56 |
57 | Fourth, test the model:
58 | ```bash
59 | python train_and_eval.py --root_data_dir prepro/draft/sort_filter/sort --glove_dir $GLOVE_DIR --oom_test --infer --restore_dir /tmp/squad_ckpts
60 | ```
61 | Note that `--output_dir` has been replaced with `--restore_dir`, and the `--infer` flag has been added.
62 | Instead of using data from `prepro/draft/sort_filter/`, this run uses `prepro/draft/sort_filter/sort`, which does not filter out any examples.
63 | This outputs a JSON file in `--restore_dir` that is compatible with the official SQuAD evaluator.
64 | If you want to run this fully (no draft mode), remove the `--draft` and `--oom_test` flags where applicable.
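   | 
   | The resulting predictions can then be scored with the official SQuAD v1.1 evaluation script from the SQuAD website; the prediction file name below is illustrative, so check `--restore_dir` for the actual output name:
   | ```bash
   | python evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json /tmp/squad_ckpts/predictions.json
   | ```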
65 |
--------------------------------------------------------------------------------
/common_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common components for feature and kernel models."""
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import numpy as np
21 | import tensorflow as tf
22 | import tensorflow.contrib.learn as learn
23 | import squad_data
24 | import tf_utils
25 |
26 | # This value is needed to compute the best answer span on GPUs.
27 | # Set it to a higher value for larger contexts, but note that this will
28 | # consume more GPU memory.
29 | MAX_CONTEXT_SIZE = 2000
30 |
31 |
32 | def embedding_layer(features, mode, params, reuse=False):
33 | """Common embedding layer for feature and kernel functions.
34 |
35 | Args:
36 | features: A dictionary containing features, directly copied from `model_fn`.
37 | mode: Mode.
38 | params: Contains parameters, directly copied from `model_fn`.
39 | reuse: Reuse variables.
40 | Returns:
41 |     `(x, q)` where `x` is the embedded representation of the context, and `q` is the
42 | embedded representation of the question.
43 | """
44 | with tf.variable_scope('embedding_layer', reuse=reuse):
45 | training = mode == learn.ModeKeys.TRAIN
46 | with tf.variable_scope('embedding'):
47 | char_emb_mat = tf.get_variable('char_emb_mat',
48 | [params.char_vocab_size, params.emb_size])
49 | xc = tf.nn.embedding_lookup(char_emb_mat,
50 | features['indexed_context_chars'])
51 | qc = tf.nn.embedding_lookup(char_emb_mat,
52 | features['indexed_question_chars'])
53 |
54 | xc = tf.reduce_max(xc, 2)
55 | qc = tf.reduce_max(qc, 2)
56 |
57 | _, xv, qv = glove_layer(features)
58 |
59 | # Concat
60 | x = tf.concat([xc, xv], 2)
61 | q = tf.concat([qc, qv], 2)
62 |
63 | x = tf_utils.highway_net(
64 | x, 2, training=training, dropout_rate=params.dropout_rate)
65 | q = tf_utils.highway_net(
66 | q, 2, training=training, dropout_rate=params.dropout_rate, reuse=True)
67 |
68 | return x, q
69 |
70 |
71 | def glove_layer(features, scope=None):
72 | """GloVe embedding layer.
73 |
74 |   The first two entries of `features['emb_mat']` are special tokens (padding
75 |   and unknown words); the remaining entries are actual words. So we learn the
76 |   representations of the first two, but the other representations stay fixed (GloVe).
77 |
78 | Args:
79 | features: `dict` of feature tensors.
80 | scope: `str` for scope name.
81 | Returns:
82 | A tuple of tensors, `(glove_emb_mat, context_emb, question_emb)`.
83 | """
84 | with tf.variable_scope(scope or 'glove_layer'):
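   |     # Rows [2:] of `emb_mat` are kept as fixed GloVe vectors, while rows [0:2]
   |     # (the special tokens) are re-created below as a trainable variable.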
85 | glove_emb_mat_const = tf.slice(features['emb_mat'], [2, 0], [-1, -1])
86 | glove_emb_mat_var = tf.get_variable('glove_emb_mat_var',
87 | [2,
88 | glove_emb_mat_const.get_shape()[1]])
89 | glove_emb_mat = tf.concat([glove_emb_mat_var, glove_emb_mat_const], 0)
90 | xv = tf.nn.embedding_lookup(glove_emb_mat,
91 | features['glove_indexed_context_words'])
92 | qv = tf.nn.embedding_lookup(glove_emb_mat,
93 | features['glove_indexed_question_words'])
94 | return glove_emb_mat, xv, qv
95 |
96 |
97 | def char_layer(features, params, scope=None):
98 | """Character embedding layer.
99 |
100 | Args:
101 | features: `dict` of feature tensors.
102 | params: `HParams` object.
103 | scope: `str` for scope name.
104 | Returns:
105 | a tuple of tensors, `(char_emb_mat, context_emb, question_emb)`.
106 | """
107 | with tf.variable_scope(scope or 'char_layer'):
108 | char_emb_mat = tf.get_variable('char_emb_mat',
109 | [params.char_vocab_size, params.emb_size])
110 | xc = tf.nn.embedding_lookup(char_emb_mat, features['indexed_context_chars'])
111 | qc = tf.nn.embedding_lookup(char_emb_mat,
112 | features['indexed_question_chars'])
113 |
114 | xc = tf.reduce_max(xc, 2)
115 | qc = tf.reduce_max(qc, 2)
116 | return char_emb_mat, xc, qc
117 |
118 |
119 | def get_pred_ops(features, params, logits_start, logits_end, no_answer_bias):
120 | """Get prediction op dictionary given start & end logits.
121 |
122 | This dictionary will contain predictions as well as everything needed
123 | to produce the nominal answer and identifier (ids).
124 |
125 | Args:
126 | features: Features.
127 | params: `HParams` object.
128 | logits_start: [batch_size, context_size]-shaped tensor of logits for start.
129 |     logits_end: Similar to `logits_start`, but for end. This tensor can also be
130 | [batch_size, context_size, context_size], in which case the true answer
131 | start is used to index on dim 1 (context_size).
132 | no_answer_bias: [batch_size, 1]-shaped tensor, bias for no answer decision.
133 | Returns:
134 | A dictionary of prediction tensors.
135 | """
136 | max_x_len = tf.shape(logits_start)[1]
137 |
138 | if len(logits_end.get_shape()) == 3:
139 | prob_end_given_start = tf.nn.softmax(logits_end)
140 | prob_start = tf.nn.softmax(logits_start)
141 | prob_start_end = prob_end_given_start * tf.expand_dims(prob_start, -1)
142 |
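   |     # `upper_tri_mat[i, j]` is 1 iff i <= j < i + max_answer_size, so it zeroes
   |     # out spans that end before they start or are longer than max_answer_size.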
143 | upper_tri_mat = tf.slice(
144 | np.triu(
145 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32') -
146 | np.triu(
147 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32'),
148 | k=params.max_answer_size)), [0, 0], [max_x_len, max_x_len])
149 | prob_start_end *= tf.expand_dims(upper_tri_mat, 0)
150 |
151 | prob_end = tf.reduce_sum(prob_start_end, 1)
152 | answer_pred_start = tf.argmax(tf.reduce_max(prob_start_end, 2), 1)
153 | answer_pred_end = tf.argmax(tf.reduce_max(prob_start_end, 1), 1)
154 | answer = squad_data.get_answer_op(features['context'],
155 | features['context_words'],
156 | answer_pred_start, answer_pred_end)
157 | answer_prob = tf.reduce_max(prob_start_end, [1, 2])
158 |
159 | predictions = {
160 | 'yp1': answer_pred_start,
161 | 'yp2': answer_pred_end,
162 | 'p1': prob_start,
163 | 'p2': prob_end,
164 | 'a': answer,
165 | 'id': features['id'],
166 | 'context': features['context'],
167 | 'context_words': features['context_words'],
168 | 'answer_prob': answer_prob,
169 | 'has_answer': answer_prob > 0.0,
170 | }
171 |
172 | else:
173 | # Predictions and metrics.
174 | concat_logits_start = tf.concat([no_answer_bias, logits_start], 1)
175 | concat_logits_end = tf.concat([no_answer_bias, logits_end], 1)
176 |
177 | concat_prob_start = tf.nn.softmax(concat_logits_start)
178 | concat_prob_end = tf.nn.softmax(concat_logits_end)
179 |
180 | no_answer_prob_start = tf.squeeze(
181 | tf.slice(concat_prob_start, [0, 0], [-1, 1]), 1)
182 | no_answer_prob_end = tf.squeeze(
183 | tf.slice(concat_prob_end, [0, 0], [-1, 1]), 1)
184 | no_answer_prob = no_answer_prob_start * no_answer_prob_end
185 | has_answer = no_answer_prob < 0.5
186 | prob_start = tf.slice(concat_prob_start, [0, 1], [-1, -1])
187 | prob_end = tf.slice(concat_prob_end, [0, 1], [-1, -1])
188 |
189 | # This is only for computing span accuracy and not used for training.
190 |     # Masking with `upper_tri_mat` only allows valid spans,
191 | # i.e. `answer_pred_start` <= `answer_pred_end`.
192 | # TODO(seominjoon): Replace with dynamic upper triangular matrix.
193 | upper_tri_mat = tf.slice(
194 | np.triu(
195 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32') -
196 | np.triu(
197 | np.ones([MAX_CONTEXT_SIZE, MAX_CONTEXT_SIZE], dtype='float32'),
198 | k=params.max_answer_size)), [0, 0], [max_x_len, max_x_len])
199 | prob_mat = tf.expand_dims(prob_start, -1) * tf.expand_dims(
200 | prob_end, 1) * tf.expand_dims(upper_tri_mat, 0)
201 | # TODO(seominjoon): Handle this.
202 | logits_mat = tf_utils.exp_mask(
203 | tf.expand_dims(logits_start, -1) + tf.expand_dims(logits_end, 1),
204 | tf.expand_dims(upper_tri_mat, 0),
205 | mask_is_length=False)
206 | del logits_mat
207 |
208 | answer_pred_start = tf.argmax(tf.reduce_max(prob_mat, 2), 1)
209 | answer_pred_end = tf.argmax(tf.reduce_max(prob_mat, 1), 1) # [batch_size]
210 | answer = squad_data.get_answer_op(features['context'],
211 | features['context_words'],
212 | answer_pred_start, answer_pred_end)
213 | answer_prob = tf.reduce_max(prob_mat, [1, 2])
214 |
215 | predictions = {
216 | 'yp1': answer_pred_start,
217 | 'yp2': answer_pred_end,
218 | 'p1': prob_start,
219 | 'p2': prob_end,
220 | 'a': answer,
221 | 'id': features['id'],
222 | 'context': features['context'],
223 | 'context_words': features['context_words'],
224 | 'no_answer_prob': no_answer_prob,
225 | 'no_answer_prob_start': no_answer_prob_start,
226 | 'no_answer_prob_end': no_answer_prob_end,
227 | 'answer_prob': answer_prob,
228 | 'has_answer': has_answer,
229 | }
230 | return predictions
231 |
232 |
233 | def get_loss(answer_start,
234 | answer_end,
235 | logits_start,
236 | logits_end,
237 | no_answer_bias,
238 | sparse=True):
239 | """Get loss given answer and logits.
240 |
241 | Args:
242 | answer_start: [batch_size, num_answers] shaped tensor if `sparse=True`, or
243 | [batch_size, context_size] shaped if `sparse=False`.
244 | answer_end: Similar to `answer_start` but for end.
245 | logits_start: [batch_size, context_size]-shaped tensor for answer start
246 | logits.
247 |     logits_end: Similar to `logits_start`, but for end. This tensor can also be
248 | [batch_size, context_size, context_size], in which case the true answer
249 | start is used to index on dim 1 (context_size).
250 | no_answer_bias: [batch_size, 1] shaped tensor, bias for no answer decision.
251 | sparse: Indicates whether `answer_start` and `answer_end` are sparse or
252 | dense.
253 | Returns:
254 | Float loss tensor.
255 | """
256 | if sparse:
257 | # During training, only one answer. During eval, multiple answers.
258 | # TODO(seominjoon): Make eval loss minimum over multiple answers.
259 | # Loss for start.
260 | answer_start = tf.squeeze(tf.slice(answer_start, [0, 0], [-1, 1]), 1)
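   |     # Shift the gold index by one: once `no_answer_bias` is prepended to the
   |     # logits below, index 0 means "no answer" and word i moves to index i + 1.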
261 | answer_start += 1
262 | logits_start = tf.concat([no_answer_bias, logits_start], 1)
263 | losses_start = tf.nn.sparse_softmax_cross_entropy_with_logits(
264 | labels=answer_start, logits=logits_start)
265 | loss_start = tf.reduce_mean(losses_start)
266 | tf.add_to_collection('losses', loss_start)
267 |
268 | # Loss for end.
269 | answer_end = tf.squeeze(tf.slice(answer_end, [0, 0], [-1, 1]), 1)
270 |     # Below is the start-conditional loss, where every start position has its
271 |     # own logits for the end position.
272 | if len(logits_end.get_shape()) == 3:
273 | mask = tf.one_hot(answer_start, tf.shape(logits_end)[1])
274 | mask = tf.cast(tf.expand_dims(mask, -1), 'float')
275 | logits_end = tf.reduce_sum(mask * logits_end, 1)
276 | answer_end += 1
277 | logits_end = tf.concat([no_answer_bias, logits_end], 1)
278 | losses_end = tf.nn.sparse_softmax_cross_entropy_with_logits(
279 | labels=answer_end, logits=logits_end)
280 | loss_end = tf.reduce_mean(losses_end)
281 | tf.add_to_collection('losses', loss_end)
282 | else:
283 |     # TODO(seominjoon): Implement no-answer capability for non-sparse labels.
284 | losses_start = tf.nn.softmax_cross_entropy_with_logits(
285 | labels=answer_start, logits=logits_start)
286 | loss_start = tf.reduce_mean(losses_start)
287 | tf.add_to_collection('losses', loss_start)
288 |
289 | losses_end = tf.nn.softmax_cross_entropy_with_logits(
290 | labels=answer_end, logits=logits_end)
291 | loss_end = tf.reduce_mean(losses_end)
292 | tf.add_to_collection('losses', loss_end)
293 |
294 | return tf.add_n(tf.get_collection('losses'))
295 |
296 |
297 | def get_train_op(loss,
298 | var_list=None,
299 | post_ops=None,
300 | inc_step=True,
301 | learning_rate=0.001,
302 | clip_norm=0.0):
303 | """Get train op for the given loss.
304 |
305 | Args:
306 | loss: Loss tensor.
307 | var_list: A list of variables that the train op will minimize.
308 |     post_ops: A list of ops that will be run after the train op. If not given,
309 |       no extra op is run after the train op.
310 | inc_step: If `True`, will increase the `global_step` variable by 1 after
311 | step.
312 | learning_rate: Initial learning rate for the optimizer.
313 | clip_norm: If specified, clips the gradient of each variable by this value.
314 | Returns:
315 | Train op to be used for training.
316 | """
317 |
318 | global_step = tf.train.get_global_step() if inc_step else None
319 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
320 | grads = optimizer.compute_gradients(loss, var_list=var_list)
321 | grads = [(grad, var) for grad, var in grads if grad is not None]
322 | for grad, var in grads:
323 | tf.summary.histogram(var.op.name, var)
324 | tf.summary.histogram('gradients/' + var.op.name, grad)
325 | if clip_norm:
326 | grads = [(tf.clip_by_norm(grad, clip_norm), var) for grad, var in grads]
327 | train_op = optimizer.apply_gradients(grads, global_step=global_step)
328 |
329 | if post_ops is not None:
330 | with tf.control_dependencies([train_op]):
331 | train_op = tf.group(*post_ops)
332 |
333 | return train_op
334 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2017 Google Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 |
17 | DATA_DIR=$HOME/data
18 | mkdir -p $DATA_DIR
19 |
20 | # Download SQuAD
21 | SQUAD_DIR=$DATA_DIR/squad
22 | mkdir -p $SQUAD_DIR
23 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $SQUAD_DIR/train-v1.1.json
24 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $SQUAD_DIR/dev-v1.1.json
25 |
26 | # Download GloVe
27 | GLOVE_DIR=$DATA_DIR/glove
28 | mkdir -p $GLOVE_DIR
29 | wget http://nlp.stanford.edu/data/glove.6B.zip -O $GLOVE_DIR/glove.6B.zip
30 | unzip $GLOVE_DIR/glove.6B.zip -d $GLOVE_DIR
31 |
32 | # Download NLTK tokenizer data
33 | # Make sure nltk is already installed!
34 | python -m nltk.downloader -d $HOME/nltk_data punkt
35 |
--------------------------------------------------------------------------------
/feature_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Feature map model.
15 |
16 | Obtains separate embeddings for the question and for each word of the context,
17 | then uses a distance metric (dot, l1, l2) to find the closest word in the context.
18 | """
19 | # TODO(seominjoon): Refactor file and function names (e.g. drop squad).
20 | import sys
21 |
22 | import tensorflow as tf
23 | import tensorflow.contrib.learn as learn
24 |
25 | import tf_utils
26 | from common_model import char_layer
27 | from common_model import embedding_layer
28 | from common_model import glove_layer
29 |
30 |
31 | def feature_model(features, mode, params, scope=None):
32 |   """Handler for SQuAD feature models, which only allow feature mapping.
33 |
34 | Every feature model is called via this function. Which model to call can
35 | be controlled via `params.model_id` (e.g. `params.model_id = m00`).
36 |
37 | Function requirement is in:
38 | https://www.tensorflow.org/extend/estimators
39 |
40 | This function does not have any dependency on FLAGS. All parameters must be
41 | passed through `params` argument for all models in this script.
42 |
43 | Args:
44 | features: A dict of feature tensors.
45 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
46 | params: `params` passed during initialization of `Estimator` object.
47 | scope: Variable scope.
48 | Returns:
49 |     `(logits_start, logits_end, tensors)` triple. `tensors` is a dictionary of
50 | tensors that can be useful outside of this function, e.g. visualization.
51 | """
52 | this_module = sys.modules[__name__]
53 | model_fn = getattr(this_module, '_model_%s' % params.model_id)
54 | return model_fn(features, mode, params, scope=scope)
55 |
56 |
57 | def _get_logits_from_multihead_x_and_q(x_start,
58 | x_end,
59 | q_start,
60 | q_end,
61 | x_len,
62 | dist,
63 | x_start_reduce_fn=None,
64 | x_end_reduce_fn=None,
65 | q_start_reduce_fn=None,
66 | q_end_reduce_fn=None):
67 | """Helper function for getting logits from context and question vectors.
68 |
69 | Args:
70 | x_start: [batch_size, context_num_words, hidden_size]-shaped float tensor,
71 | representing context vectors for answer start.
72 | Can have one more dim at the end, which will be reduced after computing
73 | distance via `reduce_fn`.
74 | x_end: [batch_size, context_num_words, hidden_size]-shaped float tensor,
75 | representing context vectors for answer end.
76 | Can have one more dim at the end, which will be reduced after computing
77 | distance via `reduce_fn`.
78 | q_start: [batch_size, hidden_size]-shaped float tensor,
79 | representing question vector for answer start.
80 | Can have one more dim at the end, which will be reduced after computing
81 | distance via `reduce_fn`.
82 | q_end: [batch_size, hidden_size]-shaped float tensor,
83 | representing question vector for answer end.
84 | Can have one more dim at the end, which will be reduced after computing
85 | distance via `reduce_fn`.
86 | x_len: [batch_size]-shaped int64 tensor, containing length of each context.
87 | dist: distance function, `dot`, `l1`, or `l2`.
88 | x_start_reduce_fn: reduction function that takes in the tensor as first
89 | argument and the axis as the second argument. Default is `tf.reduce_max`.
90 | Reduction for `x_start` and `x_end` if the extra dim is provided.
91 | Note that `l1` and `l2` distances are first negated and then the reduction
92 | is applied, so that `reduce_max` effectively gets minimal distance.
93 | Note that, for correct performance during inference when using nearest
94 | neighbor, reduction must be the default one (None, i.e. `tf.reduce_max`).
95 | x_end_reduce_fn: ditto, for end.
96 | q_start_reduce_fn: reduction function that takes in the tensor as first
97 | argument and the axis as the second argument. Default is `tf.reduce_max`.
98 | Reduction for `q_start` and `q_end` if the extra dim is provided.
99 | Note that `l1` and `l2` distances are first negated and then the reduction
100 | is applied, so that `reduce_max` effectively gets minimal distance.
101 |       This can be any reduction function, unlike the `x_*_reduce_fn`s.
102 | q_end_reduce_fn: ditto, for end.
103 |
104 | Returns:
105 | a tuple `(logits_start, logits_end)` where each tensor's shape is
106 | [batch_size, context_num_words].
107 |
108 | """
109 | if len(q_start.get_shape()) == 1:
110 | # q can be universal, e.g. trainable weights.
111 | q_start = tf.expand_dims(q_start, 0)
112 | q_end = tf.expand_dims(q_end, 0)
113 |
114 | # Expand q first to broadcast for `context_num_words` dim.
115 | q_start = tf.expand_dims(q_start, 1)
116 | q_end = tf.expand_dims(q_end, 1)
117 |
118 |   # If there is no additional dim at the end, add one to make them rank-4.
119 | if len(x_start.get_shape()) == 3:
120 | x_start = tf.expand_dims(x_start, -1)
121 | x_end = tf.expand_dims(x_end, -1)
122 | if len(q_start.get_shape()) == 3:
123 | q_start = tf.expand_dims(q_start, -1)
124 | q_end = tf.expand_dims(q_end, -1)
125 |
126 | # Add dim to outer-product x and q. This makes them rank-5 tensors.
127 | # shape : [batch_size, context_words, hidden_size, num_heads, 1]
128 | x_start = tf.expand_dims(x_start, -1)
129 | x_end = tf.expand_dims(x_end, -1)
130 |
131 | # shape : [batch_size, context_words, hidden_size, 1, num_heads]
132 | q_start = tf.expand_dims(q_start, 3)
133 | q_end = tf.expand_dims(q_end, 3)
134 |
135 | if x_start_reduce_fn is None:
136 | x_start_reduce_fn = tf.reduce_max
137 | if x_end_reduce_fn is None:
138 | x_end_reduce_fn = tf.reduce_max
139 | if q_start_reduce_fn is None:
140 | q_start_reduce_fn = tf.reduce_max
141 | if q_end_reduce_fn is None:
142 | q_end_reduce_fn = tf.reduce_max
143 |
144 | if dist == 'dot':
145 | logits_start = q_start_reduce_fn(
146 | x_start_reduce_fn(tf.reduce_sum(x_start * q_start, 2), 2), 2)
147 | logits_end = q_end_reduce_fn(
148 | x_end_reduce_fn(tf.reduce_sum(x_end * q_end, 2), 2), 2)
149 | elif dist == 'l1':
150 | logits_start = q_start_reduce_fn(
151 | x_start_reduce_fn(
152 | -tf.norm(x_start - q_start, ord=1, axis=2, keep_dims=True), 2), 2)
153 | logits_start = tf.squeeze(
154 | tf.layers.dense(logits_start, 1, name='logits_start'), 2)
155 | logits_end = q_end_reduce_fn(
156 | x_end_reduce_fn(-tf.norm(x_end - q_end, ord=1, axis=2, keep_dims=True),
157 | 2), 2)
158 | logits_end = tf.squeeze(
159 | tf.layers.dense(logits_end, 1, name='logits_end'), 2)
160 | elif dist == 'l2':
161 | logits_start = q_start_reduce_fn(
162 | x_start_reduce_fn(
163 | -tf.norm(x_start - q_start, ord=2, axis=2, keep_dims=True), 2), 2)
164 | logits_start = tf.squeeze(
165 | tf.layers.dense(logits_start, 1, name='logits_start'), 2)
166 | logits_end = q_end_reduce_fn(
167 | x_end_reduce_fn(-tf.norm(x_end - q_end, ord=2, axis=2, keep_dims=True),
168 | 2), 2)
169 | logits_end = tf.squeeze(
170 | tf.layers.dense(logits_end, 1, name='logits_end'), 2)
171 |
172 | logits_start = tf_utils.exp_mask(logits_start, x_len)
173 | logits_end = tf_utils.exp_mask(logits_end, x_len)
174 |
175 | return logits_start, logits_end
176 |
177 |
178 | def _model_m00(features, mode, params, scope=None):
179 | """LSTM-based model.
180 |
181 | This model uses two stacked LSTMs to output vectors for context, and
182 | self-attention to output vectors for question. This model reaches 57~58% F1.
183 |
184 | Args:
185 | features: A dict of feature tensors.
186 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
187 | params: `params` passed during initialization of `Estimator` object.
188 | scope: Variable scope, default is `feature_model`.
189 | Returns:
190 |     `(logits_start, logits_end, tensors)` triple. `tensors` is a dictionary of
191 | tensors that can be useful outside of this function, e.g. visualization.
192 | """
193 |
194 | with tf.variable_scope(scope or 'feature_model'):
195 | training = mode == learn.ModeKeys.TRAIN
196 |
197 | x, q = embedding_layer(features, mode, params)
198 |
199 | x1 = tf_utils.bi_rnn(
200 | params.hidden_size,
201 | x,
202 | sequence_length_list=features['context_num_words'],
203 | scope='x_bi_rnn_1',
204 | training=training,
205 | dropout_rate=params.dropout_rate)
206 |
207 | x2 = tf_utils.bi_rnn(
208 | params.hidden_size,
209 | x1,
210 | sequence_length_list=features['context_num_words'],
211 | scope='x_bi_rnn_2',
212 | training=training,
213 | dropout_rate=params.dropout_rate)
214 |
215 | q1 = tf_utils.bi_rnn(
216 | params.hidden_size,
217 | q,
218 | sequence_length_list=features['question_num_words'],
219 | scope='q_bi_rnn_1',
220 | training=training,
221 | dropout_rate=params.dropout_rate)
222 |
223 | q2 = tf_utils.bi_rnn(
224 | params.hidden_size,
225 | q1,
226 | sequence_length_list=features['question_num_words'],
227 | scope='q_bi_rnn_2',
228 | training=training,
229 | dropout_rate=params.dropout_rate)
230 |
231 | # Self-attention to obtain single vector representation.
232 | q_start = tf_utils.self_att(
233 | q1, mask=features['question_num_words'], scope='q_start')
234 | q_end = tf_utils.self_att(
235 | q2, mask=features['question_num_words'], scope='q_end')
236 |
237 | logits_start, logits_end = _get_logits_from_multihead_x_and_q(
238 | x1, x2, q_start, q_end, features['context_num_words'], params.dist)
239 | return logits_start, logits_end, dict()
240 |
241 |
242 | def _model_m01(features, mode, params, scope=None):
243 | """Self-attention with MLP, reaching 55~56% F1.
244 |
245 | Args:
246 | features: A dict of feature tensors.
247 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
248 | params: `params` passed during initialization of `Estimator` object.
249 | scope: Variable scope, default is `feature_model`.
250 | Returns:
251 |     `(logits_start, logits_end, tensors)` triple. `tensors` is a dictionary of
252 | tensors that can be useful outside of this function, e.g. visualization.
253 | """
254 |
255 | with tf.variable_scope(scope or 'feature_model'):
256 | training = mode == learn.ModeKeys.TRAIN
257 | tensors = {}
258 |
259 | x, q = embedding_layer(features, mode, params)
260 |
261 | x0 = tf_utils.bi_rnn(
262 | params.hidden_size,
263 | x,
264 | sequence_length_list=features['context_num_words'],
265 | scope='bi_rnn_x0',
266 | training=training,
267 | dropout_rate=params.dropout_rate)
268 |
269 | x1 = tf_utils.bi_rnn(
270 | params.hidden_size,
271 | x0,
272 | sequence_length_list=features['context_num_words'],
273 | scope='bi_rnn_x1',
274 | training=training,
275 | dropout_rate=params.dropout_rate)
276 |
277 | x1 += x0
278 |
279 | q1 = tf_utils.bi_rnn(
280 | params.hidden_size,
281 | q,
282 | sequence_length_list=features['question_num_words'],
283 | scope='bi_rnn_q1',
284 | training=training,
285 | dropout_rate=params.dropout_rate)
286 |
287 | def get_x(x_, scope=None):
288 | with tf.variable_scope(scope or 'get_x_clue'):
289 | hidden_sizes = [params.hidden_size, params.hidden_size]
290 | attender = tf_utils.mlp(
291 | x_,
292 | hidden_sizes,
293 | activate_last=False,
294 | training=training,
295 | dropout_rate=params.dropout_rate,
296 | scope='attender')
297 | attendee = tf_utils.mlp(
298 | x_,
299 | hidden_sizes,
300 | activate_last=False,
301 | training=training,
302 | dropout_rate=params.dropout_rate,
303 | scope='attendee')
304 | clue = tf_utils.att2d(
305 | attendee,
306 | attender,
307 | a_val=x_,
308 | mask=features['context_num_words'],
309 | logit_fn='dot',
310 | tensors=tensors)
311 | return tf.concat([x_, clue], 2)
312 |
313 | x_start = get_x(x1, scope='get_x_start')
314 | x_end = get_x(x1, scope='get_x_end')
315 |
316 | q_type = tf_utils.self_att(
317 | q1,
318 | mask=features['question_num_words'],
319 | scope='self_att_q_type',
320 | tensors=tensors)
321 | q_clue = tf_utils.self_att(
322 | q1,
323 | mask=features['question_num_words'],
324 | scope='self_att_q_clue',
325 | tensors=tensors)
326 | q_start = q_end = tf.concat([q_type, q_clue], 1)
327 |
328 | logits_start, logits_end = _get_logits_from_multihead_x_and_q(
329 | x_start, x_end, q_start, q_end, features['context_num_words'],
330 | params.dist)
331 | return logits_start, logits_end, tensors
332 |
333 |
334 | def _model_m02(features, mode, params, scope=None):
335 | """Self-attention with LSTM, reaching 59~60% F1.
336 |
337 | Args:
338 | features: A dict of feature tensors.
339 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
340 | params: `params` passed during initialization of `Estimator` object.
341 | scope: Variable scope, default is `feature_model`.
342 | Returns:
343 |     `(logits_start, logits_end, tensors)` triple. `tensors` is a dictionary of
344 | tensors that can be useful outside of this function, e.g. visualization.
345 | """
346 |
347 | with tf.variable_scope(scope or 'feature_model'):
348 | training = mode == learn.ModeKeys.TRAIN
349 | tensors = {}
350 |
351 | x, q = embedding_layer(features, mode, params)
352 |
353 | x1 = tf_utils.bi_rnn(
354 | params.hidden_size,
355 | x,
356 | sequence_length_list=features['context_num_words'],
357 | scope='bi_rnn_x1',
358 | training=training,
359 | dropout_rate=params.dropout_rate)
360 |
361 | def get_clue(x_, scope=None):
362 | with tf.variable_scope(scope or 'get_clue'):
363 | attendee = tf_utils.bi_rnn(
364 | params.hidden_size,
365 | x_,
366 | sequence_length_list=features['context_num_words'],
367 | scope='bi_rnn_attendee',
368 | training=training,
369 | dropout_rate=params.dropout_rate)
370 | attender = tf_utils.bi_rnn(
371 | params.hidden_size,
372 | x_,
373 | sequence_length_list=features['context_num_words'],
374 | scope='bi_rnn_attender',
375 | training=training,
376 | dropout_rate=params.dropout_rate)
377 | clue = tf_utils.att2d(
378 | attendee,
379 | attender,
380 | a_val=x_,
381 | mask=features['context_num_words'],
382 | logit_fn='dot',
383 | tensors=tensors)
384 | return clue
385 |
386 | x1_clue = get_clue(x1)
387 | x_start = tf.concat([x1, x1_clue], 2)
388 |
389 | x2 = tf_utils.bi_rnn(
390 | params.hidden_size,
391 | x1,
392 | sequence_length_list=features['context_num_words'],
393 | scope='bi_rnn_x2',
394 | training=training,
395 | dropout_rate=params.dropout_rate)
396 | x2_clue = tf_utils.bi_rnn(
397 | params.hidden_size,
398 | x1_clue,
399 | sequence_length_list=features['context_num_words'],
400 | scope='bi_rnn_x2_clue',
401 | training=training,
402 | dropout_rate=params.dropout_rate)
403 | x_end = tf.concat([x2, x2_clue], 2)
404 |
405 | q1 = tf_utils.bi_rnn(
406 | params.hidden_size,
407 | q,
408 | sequence_length_list=features['question_num_words'],
409 | scope='bi_rnn_q1',
410 | training=training,
411 | dropout_rate=params.dropout_rate)
412 |
413 | q2 = tf_utils.bi_rnn(
414 | params.hidden_size,
415 | q1,
416 | sequence_length_list=features['question_num_words'],
417 | scope='bi_rnn_q2',
418 | training=training,
419 | dropout_rate=params.dropout_rate)
420 |
421 | # Self-attention to obtain single vector representation.
422 | q1_type = tf_utils.self_att(
423 | q1,
424 | mask=features['question_num_words'],
425 | tensors=tensors,
426 | scope='self_att_q1_type')
427 | q1_clue = tf_utils.self_att(
428 | q1,
429 | mask=features['question_num_words'],
430 | tensors=tensors,
431 | scope='self_att_q1_clue')
432 | q_start = tf.concat([q1_type, q1_clue], 1)
433 | q2_type = tf_utils.self_att(
434 | q2, mask=features['question_num_words'], scope='self_att_q2_type')
435 | q2_clue = tf_utils.self_att(
436 | q2, mask=features['question_num_words'], scope='self_att_q2_clue')
437 | q_end = tf.concat([q2_type, q2_clue], 1)
438 |
439 | logits_start, logits_end = _get_logits_from_multihead_x_and_q(
440 | x_start, x_end, q_start, q_end, features['context_num_words'],
441 | params.dist)
442 | return logits_start, logits_end, tensors
443 |
444 |
445 | def _model_m03(features, mode, params, scope=None):
446 |   """Independent self-attention with LSTM, reaching 60~61% F1.
447 |
448 | Args:
449 | features: A dict of feature tensors.
450 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
451 | params: `params` passed during initialization of `Estimator` object.
452 | scope: Variable scope, default is `feature_model`.
453 | Returns:
454 |     `(logits_start, logits_end, tensors)` triple. `tensors` is a dictionary of
455 | tensors that can be useful outside of this function, e.g. visualization.
456 | """
457 |
458 | with tf.variable_scope(scope or 'feature_model'):
459 | training = mode == learn.ModeKeys.TRAIN
460 | tensors = {}
461 |
462 | x, q = embedding_layer(features, mode, params)
463 |
464 | def get_x_and_q(scope=None):
465 | with tf.variable_scope(scope or 'get_x_and_q'):
466 | x1 = tf_utils.bi_rnn(
467 | params.hidden_size,
468 | x,
469 | sequence_length_list=features['context_num_words'],
470 | scope='bi_rnn_x1',
471 | training=training,
472 | dropout_rate=params.dropout_rate)
473 |
474 | attendee = tf_utils.bi_rnn(
475 | params.hidden_size,
476 | x1,
477 | sequence_length_list=features['context_num_words'],
478 | scope='bi_rnn_attendee',
479 | training=training,
480 | dropout_rate=params.dropout_rate)
481 | attender = tf_utils.bi_rnn(
482 | params.hidden_size,
483 | x1,
484 | sequence_length_list=features['context_num_words'],
485 | scope='bi_rnn_attender',
486 | training=training,
487 | dropout_rate=params.dropout_rate)
488 | clue = tf_utils.att2d(
489 | attendee,
490 | attender,
491 | a_val=x1,
492 | mask=features['context_num_words'],
493 | logit_fn='dot',
494 | tensors=tensors)
495 |
496 | x_out = tf.concat([x1, clue], 2)
497 |
498 | q1 = tf_utils.bi_rnn(
499 | params.hidden_size,
500 | q,
501 | sequence_length_list=features['question_num_words'],
502 | scope='bi_rnn_q1',
503 | training=training,
504 | dropout_rate=params.dropout_rate)
505 |
506 | q1_type = tf_utils.self_att(
507 | q1,
508 | mask=features['question_num_words'],
509 | tensors=tensors,
510 | scope='self_att_q1_type')
511 | q1_clue = tf_utils.self_att(
512 | q1,
513 | mask=features['question_num_words'],
514 | tensors=tensors,
515 | scope='self_att_q1_clue')
516 | q_out = tf.concat([q1_type, q1_clue], 1)
517 |
518 | return x_out, q_out
519 |
520 | x_start, q_start = get_x_and_q('start')
521 | x_end, q_end = get_x_and_q('end')
522 |
523 | logits_start, logits_end = _get_logits_from_multihead_x_and_q(
524 | x_start, x_end, q_start, q_end, features['context_num_words'],
525 | params.dist)
526 | return logits_start, logits_end, tensors
527 |
528 |
529 | def _model_m04(features, mode, params, scope=None):
530 |   """Regularization with query generation loss on top of m03, reaching 63~64% F1.
531 |
532 |   Note that most of this model is identical to m03, except for the function
533 |   `reg_gen`, which adds an additional generation loss.
534 |
535 | Args:
536 | features: A dict of feature tensors.
537 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
538 | params: `params` passed during initialization of `Estimator` object.
539 | scope: Variable scope, default is `feature_model`.
540 | Returns:
541 |     `(logits_start, logits_end, tensors)` triple. `tensors` is a dictionary of
542 | tensors that can be useful outside of this function, e.g. visualization.
543 | """
544 |
545 | with tf.variable_scope(scope or 'feature_model'):
546 | training = mode == learn.ModeKeys.TRAIN
547 | inference = mode == learn.ModeKeys.INFER
548 | tensors = {}
549 |
550 | with tf.variable_scope('embedding'):
551 | glove_emb_mat, xv, qv = glove_layer(features)
552 | _, xc, qc = char_layer(features, params)
553 | x = tf.concat([xc, xv], 2)
554 | q = tf.concat([qc, qv], 2)
555 | x = tf_utils.highway_net(
556 | x, 2, training=training, dropout_rate=params.dropout_rate)
557 | q = tf_utils.highway_net(
558 | q, 2, training=training, dropout_rate=params.dropout_rate, reuse=True)
559 |
560 | def get_x_and_q(scope=None):
561 | with tf.variable_scope(scope or 'get_x_and_q'):
562 | x1 = tf_utils.bi_rnn(
563 | params.hidden_size,
564 | x,
565 | sequence_length_list=features['context_num_words'],
566 | scope='bi_rnn_x1',
567 | training=training,
568 | dropout_rate=params.dropout_rate)
569 |
570 | attendee = tf_utils.bi_rnn(
571 | params.hidden_size,
572 | x1,
573 | sequence_length_list=features['context_num_words'],
574 | scope='bi_rnn_attendee',
575 | training=training,
576 | dropout_rate=params.dropout_rate)
577 | attender = tf_utils.bi_rnn(
578 | params.hidden_size,
579 | x1,
580 | sequence_length_list=features['context_num_words'],
581 | scope='bi_rnn_attender',
582 | training=training,
583 | dropout_rate=params.dropout_rate)
584 | clue = tf_utils.att2d(
585 | attendee,
586 | attender,
587 | a_val=x1,
588 | mask=features['context_num_words'],
589 | logit_fn='dot',
590 | tensors=tensors)
591 |
592 | x_out = tf.concat([x1, clue], 2)
593 |
594 | q1 = tf_utils.bi_rnn(
595 | params.hidden_size,
596 | q,
597 | sequence_length_list=features['question_num_words'],
598 | scope='bi_rnn_q1',
599 | training=training,
600 | dropout_rate=params.dropout_rate)
601 |
602 | q1_type = tf_utils.self_att(
603 | q1,
604 | mask=features['question_num_words'],
605 | tensors=tensors,
606 | scope='self_att_q1_type')
607 | q1_clue = tf_utils.self_att(
608 | q1,
609 | mask=features['question_num_words'],
610 | tensors=tensors,
611 | scope='self_att_q1_clue')
612 | q_out = tf.concat([q1_type, q1_clue], 1)
613 |
614 | return x_out, q_out
615 |
616 | x_start, q_start = get_x_and_q('start')
617 | x_end, q_end = get_x_and_q('end')
618 |
619 | # TODO(seominjoon): Separate regularization and model parts.
620 | def reg_gen(glove_emb_mat, memory, scope):
621 | """Add query generation loss to `losses` collection as regularization."""
622 | with tf.variable_scope(scope):
623 | start_vec = tf.get_variable(
624 | 'start_vec', shape=glove_emb_mat.get_shape()[1])
625 | end_vec = tf.get_variable('end_vec', shape=glove_emb_mat.get_shape()[1])
626 | glove_emb_mat = tf.concat([
627 | glove_emb_mat,
628 | tf.expand_dims(start_vec, 0),
629 | tf.expand_dims(end_vec, 0)
630 | ], 0)
631 | vocab_size = glove_emb_mat.get_shape().as_list()[0]
632 | start_idx = vocab_size - 2
633 | end_idx = vocab_size - 1
634 | batch_size = tf.shape(x)[0]
635 |
636 | # Index memory
637 | memory_mask = tf.one_hot(
638 | tf.slice(features['word_answer_%ss' % scope], [0, 0], [-1, 1]),
639 | tf.shape(x)[1])
640 | # Transposing below is just a convenient way to do reduction at dim 2
641 | # and expansion at dim 1 with one operation.
642 | memory_mask = tf.transpose(memory_mask, [0, 2, 1])
643 | initial_state = tf.reduce_sum(memory * tf.cast(memory_mask, 'float'), 1)
644 | cell = tf.contrib.rnn.GRUCell(memory.get_shape().as_list()[-1])
645 |
646 | glove_emb_mat_dense = tf.layers.dense(glove_emb_mat, cell.output_size)
647 |
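   |       # `deembed` maps decoder outputs back to vocabulary logits via inner
   |       # products with the (projected) embedding matrix (tied embeddings).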
648 | def deembed(inputs):
649 | shape = tf.shape(inputs)
650 | inputs = tf.reshape(inputs, [-1, shape[-1]])
651 | outputs = tf.matmul(inputs, tf.transpose(glove_emb_mat_dense))
652 | outputs = tf.reshape(outputs, tf.concat([shape[:-1], [vocab_size]],
653 | 0))
654 | return outputs
655 |
656 | if inference:
657 |           # During inference, feed the previous output as the next input.
658 | start_tokens = tf.tile(tf.reshape(start_idx, [1]), [batch_size])
659 | helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
660 | glove_emb_mat, start_tokens, end_idx)
661 | helper = tf_utils.DeembedWrapper(helper, deembed)
662 | maximum_iterations = params.max_gen_length
663 | else:
664 | # During training and eval, feed ground truth input all the time.
665 | q_in = tf_utils.concat_seq_and_tok(qv, start_vec, 'start')
666 | indexed_q_out = tf_utils.concat_seq_and_tok(
667 | tf.cast(features['glove_indexed_question_words'], 'int32'),
668 | end_idx,
669 | 'end',
670 | sequence_length=features['question_num_words'])
671 | q_len = tf.cast(features['question_num_words'], 'int32') + 1
672 | helper = tf.contrib.seq2seq.TrainingHelper(q_in, q_len)
673 | maximum_iterations = None
674 |
675 | decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, initial_state)
676 | (outputs, _), _, _ = tf.contrib.seq2seq.dynamic_decode(
677 | decoder, maximum_iterations=maximum_iterations)
678 | logits = deembed(outputs)
679 | indexed_q_pred = tf.argmax(logits, axis=2, name='indexed_q_pred')
680 | tensors[indexed_q_pred.op.name] = indexed_q_pred
681 |
682 | if not inference:
683 | # Add sequence loss to the `losses` collection.
684 | weights = tf.sequence_mask(q_len, maxlen=tf.shape(indexed_q_out)[1])
685 | loss = tf.contrib.seq2seq.sequence_loss(logits, indexed_q_out,
686 | tf.cast(weights, 'float'))
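   |         # Exponentially decay the regularization weight; it halves every
   |         # `reg_half_life` steps.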
687 | cf = params.reg_cf * tf.exp(-tf.log(2.0) * tf.cast(
688 | tf.train.get_global_step(), 'float') / params.reg_half_life)
689 | tf.add_to_collection('losses', cf * loss)
690 |
691 | if params.reg_gen:
692 | reg_gen(glove_emb_mat, x_start, 'start')
693 | reg_gen(glove_emb_mat, x_end, 'end')
694 |
695 | logits_start, logits_end = _get_logits_from_multihead_x_and_q(
696 | x_start, x_end, q_start, q_end, features['context_num_words'],
697 | params.dist)
698 | return logits_start, logits_end, tensors
699 |
--------------------------------------------------------------------------------
/kernel_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Kernel model.
15 |
16 | This model allows interactions between the two inputs, such as attention.
17 | """
18 | import sys
19 |
20 | import tensorflow as tf
21 | import tensorflow.contrib.learn as learn
22 |
23 | import tf_utils
24 | from common_model import embedding_layer
25 |
26 |
27 | def kernel_model(features, mode, params, scope=None):
28 | """Kernel models that allow interaction between question and context.
29 |
30 |   This is the handler for all kernel models in this script. Models are
31 |   dispatched via `params.model_id` (e.g. `params.model_id = 'm00'`).
32 |
33 |   The function signature required for each model is described in:
34 |   https://www.tensorflow.org/extend/estimators
35 |
36 | This function does not have any dependency on FLAGS. All parameters must be
37 |   passed through the `params` argument.
38 |
39 | Args:
40 | features: A dict of feature tensors.
41 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
42 | params: `params` passed during initialization of `Estimator` object.
43 | scope: Variable name scope.
44 | Returns:
45 |     `(logits_start, logits_end, tensors)` tuple. `tensors` is a dictionary of
46 |     tensors that can be useful outside of this function, e.g. for visualization.
47 | """
48 | this_module = sys.modules[__name__]
49 | model_fn = getattr(this_module, '_model_%s' % params.model_id)
50 | return model_fn(
51 | features, mode, params, scope=scope)
52 |
53 |
54 | def _model_m00(features, mode, params, scope=None):
55 | """Simplified BiDAF, reaching 74~75% F1.
56 |
57 | Args:
58 | features: A dict of feature tensors.
59 | mode: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ModeKeys
60 | params: `params` passed during initialization of `Estimator` object.
61 | scope: Variable name scope.
62 | Returns:
63 |     `(logits_start, logits_end, tensors)` tuple. `tensors` is a dictionary of
64 |     tensors that can be useful outside of this function, e.g. for visualization.
65 | """
66 | with tf.variable_scope(scope or 'kernel_model'):
67 | training = mode == learn.ModeKeys.TRAIN
68 | tensors = {}
69 |
70 | x, q = embedding_layer(features, mode, params)
71 |
72 | x0 = tf_utils.bi_rnn(
73 | params.hidden_size,
74 | x,
75 | sequence_length_list=features['context_num_words'],
76 | scope='x_bi_rnn_0',
77 | training=training,
78 | dropout_rate=params.dropout_rate)
79 |
80 | q0 = tf_utils.bi_rnn(
81 | params.hidden_size,
82 | q,
83 | sequence_length_list=features['question_num_words'],
84 | scope='q_bi_rnn_0',
85 | training=training,
86 | dropout_rate=params.dropout_rate)
87 |
88 | xq = tf_utils.att2d(
89 | q0,
90 | x0,
91 | mask=features['question_num_words'],
92 | tensors=tensors,
93 | scope='xq')
94 | xq = tf.concat([x0, xq, x0 * xq], 2)
95 | x1 = tf_utils.bi_rnn(
96 | params.hidden_size,
97 | xq,
98 | sequence_length_list=features['context_num_words'],
99 | training=training,
100 | scope='x1_bi_rnn',
101 | dropout_rate=params.dropout_rate)
102 | x2 = tf_utils.bi_rnn(
103 | params.hidden_size,
104 | x1,
105 | sequence_length_list=features['context_num_words'],
106 | training=training,
107 | scope='x2_bi_rnn',
108 | dropout_rate=params.dropout_rate)
109 | x3 = tf_utils.bi_rnn(
110 | params.hidden_size,
111 | x2,
112 | sequence_length_list=features['context_num_words'],
113 | training=training,
114 | scope='x3_bi_rnn',
115 | dropout_rate=params.dropout_rate)
116 |
117 | logits_start = tf_utils.exp_mask(
118 | tf.squeeze(
119 | tf.layers.dense(tf.concat([x1, x2], 2), 1, name='logits1'), 2),
120 | features['context_num_words'])
121 | logits_end = tf_utils.exp_mask(
122 | tf.squeeze(
123 | tf.layers.dense(tf.concat([x1, x3], 2), 1, name='logits2'), 2),
124 | features['context_num_words'])
125 |
126 | return logits_start, logits_end, tensors
127 |
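# A note on the fusion step in `_model_m00` above: assuming `tf_utils.att2d`
# returns an attended question representation `xq` aligned with (same shape
# as) the context encoding `x0`, the `tf.concat([x0, xq, x0 * xq], 2)` step
# triples the feature dimension. A minimal NumPy sketch (shapes illustrative):
#
#   import numpy as np
#
#   x0 = np.random.rand(2, 5, 4)  # [batch, context_len, 2 * hidden_size]
#   xq = np.random.rand(2, 5, 4)  # attended question, same shape as x0
#   fused = np.concatenate([x0, xq, x0 * xq], axis=2)
#   assert fused.shape == (2, 5, 12)  # fed to the next bi-RNN layer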
--------------------------------------------------------------------------------
/mips-qa.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/mipsqa/5ce97002e16069cac85a267f759014d5cb30cf9e/mips-qa.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=1.3.0
2 | nltk
3 | tqdm
4 |
5 |
--------------------------------------------------------------------------------
/squad_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """SQuAD data parsing module for tf.learn model.
15 |
16 | This module loads TFRecord and hyperparameters from a specified directory
17 | (files dumped by `squad_prepro.py`) and provides tensors for data feeding.
18 | This module also provides data-specific functions for evaluation.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | from collections import Counter
26 | import json
27 | import os
28 | import re
29 | import string
30 |
31 | import numpy as np
32 | import tensorflow as tf
33 |
34 | import squad_prepro
35 |
36 |
37 | def get_input_fn(root_data_dir,
38 | glove_dir,
39 | data_type,
40 | batch_size,
41 | glove_size,
42 | shuffle_files=True,
43 | shuffle_examples=False,
44 | queue_capacity=5000,
45 | min_after_dequeue=1000,
46 | num_epochs=None,
47 | oom_test=False):
48 | """Get input function for the given data type from the given data directory.
49 |
50 | Args:
51 | root_data_dir: The directory to load data from. Corresponds to `to_dir`
52 | of `squad_prepro_main.py` file.
53 | glove_dir: path to the directory that contains GloVe files.
54 | data_type: `str` object, either `train` or `dev`.
55 | batch_size: Batch size of the inputs.
56 | glove_size: size of GloVe vector to load.
57 | shuffle_files: If `True`, shuffle the queue for the input files.
58 | shuffle_examples: If `True`, shuffle the queue for the examples.
59 | queue_capacity: `int`, maximum number of examples in input queue.
60 |     min_after_dequeue: `int`, for `RandomShuffleQueue`, minimum number of
61 |       examples before dequeueing to ensure randomness.
62 | num_epochs: Number of epochs on the data. `None` means infinite.
63 | This queue comes after the file queue.
64 |     oom_test: Stress test to see if the current dataset and model cause an
65 |       out-of-memory error on GPU.
66 | Returns:
67 | Function definition `input_fn` compatible with `Experiment` object.
68 | """
69 | filenames = tf.gfile.Glob(
70 | os.path.join(root_data_dir, data_type, 'data', 'squad_data_*'))
71 | tf.logging.info('reading examples from following files:')
72 | for filename in filenames:
73 | tf.logging.info(filename)
74 | sequence_feature = tf.FixedLenSequenceFeature(
75 | [], tf.int64, allow_missing=True)
76 | str_sequence_feature = tf.FixedLenSequenceFeature(
77 | [], tf.string, allow_missing=True)
78 | int_feature = tf.FixedLenFeature([], tf.int64)
79 | str_feature = tf.FixedLenFeature([], tf.string)
80 | # Let N = batch_size, JX = max num context words, JQ = max num ques words,
81 | # C = num chars per word (fixed, default = 16)
82 | features = {
83 | 'indexed_context_words': sequence_feature, # Shape = [JX]
84 | 'glove_indexed_context_words': sequence_feature,
85 | 'indexed_context_chars': sequence_feature, # Shape = [JX * C]
86 | 'indexed_question_words': sequence_feature, # Shape = [JQ]
87 | 'glove_indexed_question_words': sequence_feature,
88 | 'indexed_question_chars': sequence_feature, # Shape = [JQ * C]
89 | 'word_answer_starts': sequence_feature, # Answer start index.
90 | 'word_answer_ends': sequence_feature, # Answer end index.
91 | 'context_num_words':
92 | int_feature, # Number of context words in each example. [A]
93 | 'question_num_words':
94 | int_feature, # Number of question words in each example. [A]
95 | 'answers': str_sequence_feature, # List of answers in each example. [A]
96 | 'context_words': str_sequence_feature, # [JX]
97 | 'question_words': str_sequence_feature, # [JQ]
98 | 'context': str_feature,
99 | 'id': str_feature,
100 | 'num_answers': int_feature,
101 | 'question': str_feature,
102 | }
103 |
104 | exp_metadata_path = os.path.join(root_data_dir, 'metadata.json')
105 | with tf.gfile.GFile(exp_metadata_path, 'r') as fp:
106 | exp_metadata = json.load(fp)
107 |
108 | metadata_path = os.path.join(root_data_dir, data_type, 'metadata.json')
109 | with tf.gfile.GFile(metadata_path, 'r') as fp:
110 | metadata = json.load(fp)
111 | emb_mat = squad_prepro.get_idx2vec_mat(glove_dir, glove_size,
112 | metadata['glove_word2idx'])
113 |
114 | def _input_fn():
115 | """Input function compatible with `Experiment` object.
116 |
117 | Returns:
118 | A tuple of feature tensors and target tensors.
119 | """
120 |     # TODO(seominjoon): There is a bottleneck in data feeding; slow for N >= 128.
121 | filename_queue = tf.train.string_input_producer(
122 | filenames, shuffle=shuffle_files, num_epochs=num_epochs)
123 | reader = tf.TFRecordReader()
124 | _, se = reader.read(filename_queue)
125 | # TODO(seominjoon): Consider moving data filtering to here.
126 | features_op = tf.parse_single_example(se, features=features)
127 |
128 | names = list(features_op.keys())
129 | dtypes = [features_op[name].dtype for name in names]
130 | shapes = [features_op[name].shape for name in names]
131 |
132 | if shuffle_examples:
133 | # Data shuffling.
134 | rq = tf.RandomShuffleQueue(
135 | queue_capacity, min_after_dequeue, dtypes, names=names)
136 | else:
137 | rq = tf.FIFOQueue(queue_capacity, dtypes, names=names)
138 | enqueue_op = rq.enqueue(features_op)
139 | dequeue_op = rq.dequeue()
140 | dequeue_op = [dequeue_op[name] for name in names]
141 | qr = tf.train.QueueRunner(rq, [enqueue_op])
142 | tf.train.add_queue_runner(qr)
143 |
144 | batch = tf.train.batch(
145 | dequeue_op,
146 | batch_size,
147 | capacity=queue_capacity,
148 | dynamic_pad=True,
149 | shapes=shapes,
150 | allow_smaller_final_batch=True,
151 | num_threads=5)
152 | batch = {name: each for name, each in zip(names, batch)}
153 | target_keys = [
154 | 'word_answer_starts', 'word_answer_ends', 'answers', 'num_answers'
155 | ]
156 |     # TODO(seominjoon): For cheating-safety, enable the commented-out key filter below.
157 | features_batch = {
158 | key: val
159 | for key, val in batch.items() # if key not in target_keys
160 | }
161 |
162 |     # `features_batch['emb_mat']` contains the GloVe embeddings; the
163 |     # glove-indexed word tensors in `features_batch` index into these vectors.
164 | features_batch['emb_mat'] = tf.constant(emb_mat)
165 | targets_batch = {key: batch[key] for key in target_keys}
166 |
167 |     # Postprocessing for character data.
168 |     # Due to a limitation of the python wrapper for the proto format,
169 |     # the indexed characters were flattened when serialized.
170 |     # The following 'unflattens' the character tensors.
171 | actual_batch_size = tf.shape(batch['indexed_context_chars'])[0]
172 | features_batch['indexed_context_chars'] = tf.reshape(
173 | features_batch['indexed_context_chars'],
174 | [actual_batch_size, -1, metadata['num_chars_per_word']])
175 | features_batch['indexed_question_chars'] = tf.reshape(
176 | features_batch['indexed_question_chars'],
177 | [actual_batch_size, -1, metadata['num_chars_per_word']])
178 |
179 | # Make sure answer start and end positions are less than sequence lengths.
180 | # TODO(seominjoon) This will need to move to a separate test.
181 | with tf.control_dependencies([
182 | tf.assert_less(
183 | tf.reduce_max(targets_batch['word_answer_starts'], 1),
184 | features_batch['context_num_words'])
185 | ]):
186 | targets_batch['word_answer_starts'] = tf.identity(
187 | targets_batch['word_answer_starts'])
188 | with tf.control_dependencies([
189 | tf.assert_less(
190 | tf.reduce_max(targets_batch['word_answer_ends'], 1),
191 | features_batch['context_num_words'])
192 | ]):
193 | targets_batch['word_answer_ends'] = tf.identity(
194 | targets_batch['word_answer_ends'])
195 |
196 |     # Stress test to ensure no GPU out-of-memory error occurs.
197 | if oom_test:
198 | features_batch['indexed_context_words'] = tf.constant(
199 | np.ones(
200 | [batch_size, exp_metadata['max_context_size']], dtype='int64'))
201 | features_batch['glove_indexed_context_words'] = tf.constant(
202 | np.ones(
203 | [batch_size, exp_metadata['max_context_size']], dtype='int64'))
204 | features_batch['indexed_context_chars'] = tf.constant(
205 | np.ones(
206 | [
207 | batch_size, exp_metadata['max_context_size'], exp_metadata[
208 | 'num_chars_per_word']
209 | ],
210 | dtype='int64'))
211 | features_batch['indexed_question_words'] = tf.constant(
212 | np.ones([batch_size, exp_metadata['max_ques_size']], dtype='int64'))
213 | features_batch['glove_indexed_question_words'] = tf.constant(
214 | np.ones([batch_size, exp_metadata['max_ques_size']], dtype='int64'))
215 | features_batch['indexed_question_chars'] = tf.constant(
216 | np.ones(
217 | [
218 | batch_size, exp_metadata['max_ques_size'], exp_metadata[
219 | 'num_chars_per_word']
220 | ],
221 | dtype='int64'))
222 | features_batch['question_num_words'] = tf.constant(
223 | np.ones([batch_size], dtype='int64') * exp_metadata['max_ques_size'])
224 | features_batch['context_num_words'] = tf.constant(
225 | np.ones([batch_size], dtype='int64') *
226 | exp_metadata['max_context_size'])
227 |
228 | return features_batch, targets_batch
229 |
230 | return _input_fn
231 |
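# A minimal usage sketch of `get_input_fn` (paths are illustrative):
#
#   input_fn = get_input_fn('/path/to/prepro_out', '/path/to/glove', 'train',
#                           batch_size=32, glove_size=50)
#   features, targets = input_fn()
#   # `features` and `targets` are dicts of batched tensors; the underlying
#   # queues start filling once `tf.train.start_queue_runners` runs in a
#   # session.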
232 |
233 | def get_params(root_data_dir):
234 | """Load data-specific parameters from `root_data_dir`.
235 |
236 | Args:
237 | root_data_dir: The data directory to load parameter files from.
238 |       This is equivalent to the `to_dir` of `squad_prepro_main.py`.
239 | Returns:
240 | A dict of hyperparameters.
241 | """
242 | indexer_path = os.path.join(root_data_dir, 'indexer.json')
243 | with tf.gfile.GFile(indexer_path, 'r') as fp:
244 | indexer = json.load(fp)
245 |
246 | return {
247 | 'vocab_size': len(indexer['word2idx']),
248 | 'char_vocab_size': len(indexer['char2idx']),
249 | }
250 |
251 |
252 | def get_eval_metric_ops(targets, predictions):
253 | """Get a dictionary of eval metrics for `Experiment` object.
254 |
255 | Args:
256 | targets: `targets` that go into `model_fn` of `Experiment`.
257 | predictions: Dictionary of predictions, output of `get_preds`.
258 | Returns:
259 | A dictionary of eval metrics.
260 | """
261 | # TODO(seominjoon): yp should also consider no answer case.
262 | yp1 = tf.expand_dims(predictions['yp1'], -1)
263 | yp2 = tf.expand_dims(predictions['yp2'], -1)
264 | answer_mask = tf.sequence_mask(targets['num_answers'])
265 | start_correct = tf.reduce_any(
266 | tf.equal(targets['word_answer_starts'], yp1) & answer_mask, 1)
267 | end_correct = tf.reduce_any(
268 | tf.equal(targets['word_answer_ends'], yp2) & answer_mask, 1)
269 | correct = start_correct & end_correct
270 | em = tf.py_func(
271 | _enum_fn(_exact_match_score, dtype='float32'), [
272 | predictions['a'], targets['answers'], predictions['has_answer'],
273 | answer_mask
274 | ], 'float32')
275 | f1 = tf.py_func(
276 | _enum_fn(_f1_score, dtype='float32'), [
277 | predictions['a'], targets['answers'], predictions['has_answer'],
278 | answer_mask
279 | ], 'float32')
280 |
281 | eval_metric_ops = {
282 | 'acc1': tf.metrics.mean(tf.cast(start_correct, 'float')),
283 | 'acc2': tf.metrics.mean(tf.cast(end_correct, 'float')),
284 | 'acc': tf.metrics.mean(tf.cast(correct, 'float')),
285 | 'em': tf.metrics.mean(em),
286 | 'f1': tf.metrics.mean(f1),
287 | }
288 | return eval_metric_ops
289 |
290 |
291 | def get_answer_op(context, context_words, answer_start, answer_end):
292 | return tf.py_func(
293 | _enum_fn(_get_answer), [context, context_words, answer_start, answer_end],
294 | 'string')
295 |
296 |
297 | def _get_answer(context, context_words, answer_start, answer_end):
298 | """Get answer given context, context_words, and span.
299 |
300 | Args:
301 | context: A list of bytes, to be decoded with utf-8.
302 | context_words: A list of a list of bytes, to be decoded with utf-8.
303 | answer_start: An int for answer start.
304 | answer_end: An int for answer end.
305 | Returns:
306 | A list of bytes, encoded with utf-8, for the answer.
307 | """
308 | context = context.decode('utf-8')
309 | context_words = [word.decode('utf-8') for word in context_words]
310 | pos = 0
311 | answer_start_char = None
312 | answer_end_char = None
313 | for i, word in enumerate(context_words):
314 | pos = context.index(word, pos)
315 | if answer_start == i:
316 | answer_start_char = pos
317 | pos += len(word)
318 | if answer_end == i:
319 | answer_end_char = pos
320 | break
321 | assert answer_start_char is not None, (
322 | '`answer_start` is not found in context. '
323 | 'context=`%s`, context_words=`%r`, '
324 | 'answer_start=%d, answer_end=%d') % (context, context_words, answer_start,
325 | answer_end)
326 | assert answer_end_char is not None, (
327 | '`answer_end` is not found in context. '
328 | 'context=`%s`, context_words=`%r`, '
329 | 'answer_start=%d, answer_end=%d') % (context, context_words, answer_start,
330 | answer_end)
331 | answer = context[answer_start_char:answer_end_char].encode('utf-8')
332 | return answer
333 |
334 |
335 | def _f1_score(prediction, ground_truths, has_answer, answer_mask):
336 | prediction = prediction.decode('utf-8')
337 | ground_truths = [
338 | ground_truth.decode('utf-8') for ground_truth in ground_truths
339 | ]
340 | if not has_answer:
341 | return float(ground_truths[0] == squad_prepro.NO_ANSWER)
342 | elif ground_truths[0] == squad_prepro.NO_ANSWER:
343 | return 0.0
344 | else:
345 | scores = np.array([
346 | _f1_score_(prediction, ground_truth) for ground_truth in ground_truths
347 | ])
348 | return max(scores * answer_mask.astype(float))
349 |
350 |
351 | def _exact_match_score(prediction, ground_truths, has_answer, answer_mask):
352 | prediction = prediction.decode('utf-8')
353 | ground_truths = [
354 | ground_truth.decode('utf-8') for ground_truth in ground_truths
355 | ]
356 | if not has_answer:
357 | return float(ground_truths[0] == squad_prepro.NO_ANSWER)
358 | elif ground_truths[0] == squad_prepro.NO_ANSWER:
359 | return 0.0
360 | else:
361 | scores = np.array([
362 | float(_exact_match_score_(prediction, ground_truth))
363 | for ground_truth in ground_truths
364 | ])
365 | return max(scores * answer_mask.astype(float))
366 |
367 |
368 | def _enum_fn(fn, dtype='object'):
369 |
370 | def new_fn(*args):
371 | return np.array([fn(*each_args) for each_args in zip(*args)], dtype=dtype)
372 |
373 | return new_fn
374 |
375 |
376 | # Functions below are copied from official SQuAD eval script and SHOULD NOT
377 | # BE MODIFIED.
378 |
379 |
380 | def _normalize_answer(s):
381 | """Lower text and remove punctuation, articles and extra whitespace.
382 |
383 | Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED.
384 |
385 | Args:
386 | s: Input text.
387 | Returns:
388 | Normalized text.
389 | """
390 |
391 | def remove_articles(text):
392 | return re.sub(r'\b(a|an|the)\b', ' ', text)
393 |
394 | def white_space_fix(text):
395 | return ' '.join(text.split())
396 |
397 | def remove_punc(text):
398 | exclude = set(string.punctuation)
399 | return ''.join(ch for ch in text if ch not in exclude)
400 |
401 | def lower(text):
402 | return text.lower()
403 |
404 | return white_space_fix(remove_articles(remove_punc(lower(s))))
405 |
406 |
407 | def _f1_score_(prediction, ground_truth):
408 | """Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED."""
409 | prediction_tokens = _normalize_answer(prediction).split()
410 | ground_truth_tokens = _normalize_answer(ground_truth).split()
411 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
412 | num_same = sum(common.values())
413 | if num_same == 0:
414 | return 0
415 | precision = 1.0 * num_same / len(prediction_tokens)
416 | recall = 1.0 * num_same / len(ground_truth_tokens)
417 | f1 = (2 * precision * recall) / (precision + recall)
418 | return f1
419 |
420 |
421 | def _exact_match_score_(prediction, ground_truth):
422 | """Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED."""
423 | return _normalize_answer(prediction) == _normalize_answer(ground_truth)
424 |
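# A worked example of the metrics above: `_normalize_answer` lowercases,
# strips punctuation and articles, and collapses whitespace, so
# 'the cat sat' vs. 'a cat sat down' compare as 'cat sat' vs. 'cat sat down':
# 2 shared tokens, precision 2/2, recall 2/3, F1 = 0.8.
#
#   assert abs(_f1_score_('the cat sat', 'a cat sat down') - 0.8) < 1e-9
#   assert _exact_match_score_('The Cat!', 'a cat')  # both normalize to 'cat'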
--------------------------------------------------------------------------------
/squad_prepro.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Library for preprocessing SQuAD.
15 | """
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from collections import Counter
22 | from collections import OrderedDict
23 | import itertools
24 | import json
25 | import os
26 |
27 | import nltk
28 | import numpy as np
29 | import tensorflow as tf
30 | from tqdm import tqdm
31 |
32 | PAD = u'<pad>'
33 | UNK = u'<unk>'
34 | NO_ANSWER = u'<no_answer>'
35 |
36 |
37 | class SquadIndexer(object):
38 | """Indexer for SQuAD.
39 |
40 | Instantiating this class loads GloVe. The object can fit examples (creating
41 | vocab out of the examples) and index the examples.
42 | """
43 |
44 | def __init__(self, glove_dir, tokenizer=None, draft=False):
45 | self._glove_words = list(_get_glove(glove_dir, size=50, draft=draft))
46 | self._glove_vocab = _get_glove_vocab(self._glove_words)
47 | self._glove_dir = glove_dir
48 | self._draft = draft
49 | self._tokenizer = tokenizer or _word_tokenize
50 | self._word2idx_dict = None
51 | self._char2idx_dict = None
52 |
53 | def word2idx(self, word):
54 | word = word.lower()
55 | if word not in self._word2idx_dict:
56 | return 1
57 | return self._word2idx_dict[word]
58 |
59 | def char2idx(self, char):
60 | if char not in self._char2idx_dict:
61 | return 1
62 | return self._char2idx_dict[char]
63 |
64 | def fit(self,
65 | examples,
66 | min_word_count=10,
67 | min_char_count=100,
68 | num_chars_per_word=16):
69 |     """Fits examples and returns indexed examples with metadata.
70 |
71 | Fitting examples means the vocab is created out of the examples.
72 | The vocab can be saved via `save` and loaded via `load` methods.
73 |
74 | Args:
75 | examples: list of dictionary, where each dictionary is an example.
76 | min_word_count: `int` value, minimum word count to be included in vocab.
77 | min_char_count: `int` value, minimum char count to be included in vocab.
78 | num_chars_per_word: `int` value, number of chars to store per word.
79 |         This is fixed, so if a word is shorter, the rest is padded with 0.
80 |         The characters are flattened, so they need to be reshaped when used.
81 | Returns:
82 | a tuple `(indexed_examples, metadata)`, where `indexed_examples` is a
83 | list of dict (each dict being indexed example) and `metadata` is a dict
84 | of `glove_word2idx_dict` and statistics of the examples.
85 | """
86 | tokenized_examples = [
87 | _tokenize(example, self._tokenizer, num_chars_per_word)
88 |         for example in tqdm(examples, desc='tokenizing')
89 | ]
90 | word_counter = _get_word_counter(tokenized_examples)
91 | char_counter = _get_char_counter(tokenized_examples)
92 | self._word2idx_dict = _counter2vocab(word_counter, min_word_count)
93 | tf.logging.info('Word vocab size: %d' % len(self._word2idx_dict))
94 | self._char2idx_dict = _counter2vocab(char_counter, min_char_count)
95 | tf.logging.info('Char vocab size: %d' % len(self._char2idx_dict))
96 | glove_word2idx_dict = _get_glove_vocab(
97 | self._glove_words, counter=word_counter)
98 | tf.logging.info('Glove word vocab size: %d' % len(glove_word2idx_dict))
99 |
100 | def glove_word2idx(word):
101 | word = word.lower()
102 | return glove_word2idx_dict[word] if word in glove_word2idx_dict else 1
103 |
104 | indexed_examples = [
105 | _index(example, self.word2idx, glove_word2idx, self.char2idx)
106 | for example in tqdm(tokenized_examples, desc='indexing')
107 | ]
108 |
109 | metadata = self._get_metadata(indexed_examples)
110 | metadata['glove_word2idx'] = glove_word2idx_dict
111 | metadata['num_chars_per_word'] = num_chars_per_word
112 |
113 | return indexed_examples, metadata
114 |
115 | def prepro_eval(self, examples, num_chars_per_word=16):
116 | """Tokenizes and indexes examples (usually non-train examples).
117 |
118 | In order to use this, `fit` must have been already executed on train data.
119 |     Other than that, this function has the same functionality as `fit`,
120 |     returning indexed examples.
121 |
122 | Args:
123 | examples: a list of dict, where each dict is an example.
124 | num_chars_per_word: `int` value, number of chars to store per word.
125 |         This is fixed, so if a word is shorter, the rest is padded with 0.
126 |         The characters are flattened, so they need to be reshaped when used.
127 | Returns:
128 | a tuple `(indexed_examples, metadata)`, where `indexed_examples` is a
129 | list of dict (each dict being indexed example) and `metadata` is a dict
130 | of `glove_word2idx_dict` and statistics of the examples.
131 | """
132 | tokenized_examples = [
133 | _tokenize(example, self._tokenizer, num_chars_per_word)
134 | for example in tqdm(examples, desc='tokenizing')
135 | ]
136 | word_counter = _get_word_counter(tokenized_examples)
137 | glove_word2idx_dict = _get_glove_vocab(
138 | self._glove_words, counter=word_counter)
139 | tf.logging.info('Glove word vocab size: %d' % len(glove_word2idx_dict))
140 |
141 | def glove_word2idx(word):
142 | word = word.lower()
143 | return glove_word2idx_dict[word] if word in glove_word2idx_dict else 1
144 |
145 | indexed_examples = [
146 | _index(example, self.word2idx, glove_word2idx, self.char2idx)
147 | for example in tqdm(tokenized_examples, desc='indexing')
148 | ]
149 |
150 | metadata = self._get_metadata(indexed_examples)
151 | metadata['glove_word2idx'] = glove_word2idx_dict
152 | metadata['num_chars_per_word'] = num_chars_per_word
153 | return indexed_examples, metadata
154 |
155 | @property
156 | def savable(self):
157 | return {'word2idx': self._word2idx_dict, 'char2idx': self._char2idx_dict}
158 |
159 | def save(self, save_path):
160 | with tf.gfile.GFile(save_path, 'w') as fp:
161 | json.dump(self.savable, fp)
162 |
163 | def load(self, load_path):
164 | with tf.gfile.GFile(load_path, 'r') as fp:
165 | savable = json.load(fp)
166 | self._word2idx_dict = savable['word2idx']
167 | self._char2idx_dict = savable['char2idx']
168 |
169 | def _get_metadata(self, examples):
170 | metadata = {
171 | 'max_context_size': max(len(e['context_words']) for e in examples),
172 | 'max_ques_size': max(len(e['question_words']) for e in examples),
173 | 'word_vocab_size': len(self._word2idx_dict),
174 | 'char_vocab_size': len(self._char2idx_dict),
175 | }
176 | return metadata
177 |
178 |
179 | def split(example, para2sents_fn=None, positive_augment_factor=0):
180 | """Splits context in example into sentences and create multiple examples.
181 |
182 | Args:
183 | example: `dict` object, each element of `get_examples()`.
184 | para2sents_fn: function that maps `str` to a list of `str`, splitting
185 |       a paragraph into sentences.
186 |     positive_augment_factor: Multiply positive examples by this factor,
187 |       for handling the class imbalance problem.
188 | Returns:
189 |     a list of examples, with modified fields: `id`, `context`, `answers` and
190 |     `answer_starts`. Also sets the `num_answers` field on each example.
191 | """
192 | if para2sents_fn is None:
193 | para2sents_fn = nltk.sent_tokenize
194 | sents = para2sents_fn(example['context'])
195 | sent_start_idxs = _tokens2idxs(example['context'], sents)
196 |
197 | context = example['context']
198 | examples = []
199 | for i, (sent, sent_start_idx) in enumerate(zip(sents, sent_start_idxs)):
200 | sent_end_idx = sent_start_idx + len(sent)
201 | e = dict(example.items()) # Copying dict content.
202 | e['context'] = sent
203 | e['id'] = '%s %d' % (e['id'], i)
204 | e['answers'] = []
205 | e['answer_starts'] = []
206 | for answer, answer_start in zip(example['answers'],
207 | example['answer_starts']):
208 | answer_end = answer_start + len(answer)
209 | if (sent_start_idx <= answer_start < sent_end_idx or
210 | sent_start_idx < answer_end <= sent_end_idx):
211 | new_answer = context[max(sent_start_idx, answer_start):min(
212 | sent_end_idx, answer_end)]
213 | new_answer_start = max(answer_start, sent_start_idx) - sent_start_idx
214 | e['answers'].append(new_answer)
215 | e['answer_starts'].append(new_answer_start)
216 | if not e['answers']:
217 | e['answers'].append(NO_ANSWER)
218 | e['answer_starts'].append(-1)
219 | e['num_answers'] = len(e['answers'])
220 | # If the list is empty, then the example has no answer.
221 | examples.append(e)
222 | if positive_augment_factor and e['answers'][0] != NO_ANSWER:
223 | for _ in range(positive_augment_factor):
224 | examples.append(e)
225 | return examples
226 |
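# A minimal sketch of `split` (values illustrative; assumes nltk's
# `sent_tokenize` splits this context into two sentences):
#
#   example = {'id': 'q0', 'context': u'A b. C d.', 'question': u'?',
#              'answers': [u'C'], 'answer_starts': [5], 'num_answers': 1}
#   sents = split(example)
#   # Two sentence-level examples: the first gets answers == [NO_ANSWER] and
#   # answer_starts == [-1]; the second gets answers == [u'C'] with
#   # answer_starts == [0], re-anchored to its own sentence.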
227 |
228 | def get_idx2vec_mat(glove_dir, size, glove_word2idx_dict):
229 | """Gets embedding matrix for given GloVe vocab."""
230 | glove = _get_glove(glove_dir, size=size)
231 | glove[PAD] = glove[UNK] = [0.0] * size
232 | idx2vec_dict = {idx: glove[word] for word, idx in glove_word2idx_dict.items()}
233 | idx2vec_mat = np.array(
234 | [idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
235 | return idx2vec_mat
236 |
237 |
238 | def get_examples(squad_path):
239 | """Obtain a list of examples from official SQuAD file.
240 |
241 | Args:
242 | squad_path: path to the official SQuAD file (e.g. `train-v1.1.json`).
243 | Returns:
244 |     a list of dict, where each dict is an example.
245 | """
246 | with tf.gfile.GFile(squad_path, 'r') as fp:
247 | squad = json.load(fp)
248 |
249 | examples = []
250 | version = squad['version']
251 | for article in squad['data']:
252 | title = article['title']
253 | for paragraph in article['paragraphs']:
254 | context = paragraph['context']
255 | for qa in paragraph['qas']:
256 | question = qa['question']
257 | id_ = qa['id']
258 |
259 | answer_starts = [answer['answer_start'] for answer in qa['answers']]
260 | answers = [answer['text'] for answer in qa['answers']]
261 |
262 | example = {
263 | 'version': version,
264 | 'title': title,
265 | 'context': context,
266 | 'question': question,
267 | 'id': id_,
268 | 'answer_starts': answer_starts,
269 | 'answers': answers,
270 | 'num_answers': len(answers),
271 | 'is_supervised': True,
272 | }
273 | example = normalize_example(example)
274 | examples.append(example)
275 | return examples
276 |
277 |
278 | def normalize_example(example):
279 | n_example = dict(example.items())
280 | n_example['context'] = _replace_quotations(n_example['context'])
281 | n_example['answers'] = [_replace_quotations(a) for a in n_example['answers']]
282 | return n_example
283 |
284 |
285 | def _replace_quotations(text):
286 | return text.replace('``', '" ').replace("''", '" ')
287 |
288 |
289 | def _word_tokenize(text):
290 |   # TODO(seominjoon): Consider using Stanford Tokenizer or other tokenizers.
291 | return [
292 | word.replace('``', '"').replace("''", '"')
293 | for word in nltk.word_tokenize(text)
294 | ]
295 |
296 |
297 | def _tokens2idxs(text, tokens):
298 | idxs = []
299 | idx = 0
300 | for token in tokens:
301 | idx = text.find(token, idx)
302 | assert idx >= 0, (text, tokens)
303 | idxs.append(idx)
304 | idx += len(token)
305 | return idxs
306 |
307 |
308 | def _tokenize(example, text2words_fn, num_chars_per_word):
309 | """Tokenize each example using provided tokenizer (`text2words_fn`).
310 |
311 | Args:
312 | example: `dict` value, an example.
313 | text2words_fn: tokenizer.
314 | num_chars_per_word: `int` value, number of chars to store per word.
315 |       This is fixed, so if a word is shorter, the rest is padded with 0.
316 |       The characters are flattened, so they need to be reshaped when used.
317 | Returns:
318 | `dict`, representing tokenized example.
319 | """
320 | new_example = dict(example.items())
321 | new_example['question_words'] = text2words_fn(example['question'])
322 | new_example['question_num_words'] = len(new_example['question_words'])
323 | new_example['context_words'] = text2words_fn(example['context'])
324 | new_example['context_num_words'] = len(new_example['context_words'])
325 |
326 | def word2chars(word):
327 | chars = list(word)
328 | if len(chars) > num_chars_per_word:
329 | return chars[:num_chars_per_word]
330 | else:
331 | return chars + [PAD] * (num_chars_per_word - len(chars))
332 |
333 | new_example['question_chars'] = list(
334 | itertools.chain(
335 | * [word2chars(word) for word in new_example['question_words']]))
336 | new_example['context_chars'] = list(
337 | itertools.chain(
338 | * [word2chars(word) for word in new_example['context_words']]))
339 | return new_example
340 |
341 |
342 | def _index(example, word2idx_fn, glove_word2idx_fn, char2idx_fn):
343 | """Indexes each tokenized example, using provided vocabs.
344 |
345 | Args:
346 | example: `dict` representing tokenized example.
347 | word2idx_fn: indexer for word vocab.
348 | glove_word2idx_fn: indexer for glove word vocab.
349 | char2idx_fn: indexer for character vocab.
350 | Returns:
351 | `dict` representing indexed example.
352 | """
353 | new_example = dict(example.items())
354 | new_example['indexed_question_words'] = [
355 | word2idx_fn(word) for word in example['question_words']
356 | ]
357 | new_example['indexed_context_words'] = [
358 | word2idx_fn(word) for word in example['context_words']
359 | ]
360 | new_example['indexed_question_chars'] = [
361 | char2idx_fn(word) for word in example['question_chars']
362 | ]
363 | new_example['indexed_context_chars'] = [
364 | char2idx_fn(word) for word in example['context_chars']
365 | ]
366 | new_example['glove_indexed_question_words'] = [
367 | glove_word2idx_fn(word) for word in example['question_words']
368 | ]
369 | new_example['glove_indexed_context_words'] = [
370 | glove_word2idx_fn(word) for word in example['context_words']
371 | ]
372 |
373 | word_answer_starts = []
374 | word_answer_ends = []
375 | for answer_start, answer in zip(new_example['answer_starts'],
376 | new_example['answers']):
377 | if answer_start < 0:
378 | word_answer_starts.append(-1)
379 | word_answer_ends.append(-1)
380 | break
381 | word_answer_start, word_answer_end = _get_word_answer(
382 | new_example['context'], new_example['context_words'], answer_start,
383 | answer)
384 | word_answer_starts.append(word_answer_start)
385 | word_answer_ends.append(word_answer_end)
386 | new_example['word_answer_starts'] = word_answer_starts
387 | new_example['word_answer_ends'] = word_answer_ends
388 |
389 | return new_example
390 |
391 |
392 | def _get_glove(glove_path, size=None, draft=False):
393 | """Get an `OrderedDict` that maps word to vector.
394 |
395 | Args:
396 | glove_path: `str` value,
397 | path to the glove file (e.g. `glove.6B.50d.txt`) or directory.
398 | size: `int` value, size of the vector, if `glove_path` is a directory.
399 | draft: `bool` value, whether to only load first 99 for draft mode.
400 | Returns:
401 | `OrderedDict` object, mapping word to vector.
402 | """
403 | if size is not None:
404 | glove_path = os.path.join(glove_path, 'glove.6B.%dd.txt' % size)
405 | glove = OrderedDict()
406 | with tf.gfile.GFile(glove_path, 'rb') as fp:
407 | for idx, line in enumerate(fp):
408 | line = line.decode('utf-8')
409 | tokens = line.strip().split(u' ')
410 | word = tokens[0]
411 | vec = list(map(float, tokens[1:]))
412 | glove[word] = vec
413 | if draft and idx > 99:
414 | break
415 | return glove
416 |
417 |
418 | def _get_word_counter(examples):
419 | # TODO(seominjoon): Consider not ignoring uppercase.
420 | counter = Counter()
421 | for example in tqdm(examples, desc='word counter'):
422 | for word in example['question_words']:
423 | counter[word.lower()] += 1
424 | for word in example['context_words']:
425 | counter[word.lower()] += 1
426 | return counter
427 |
428 |
429 | def _get_char_counter(examples):
430 | counter = Counter()
431 | for example in tqdm(examples, desc='char counter'):
432 | for chars in example['question_chars']:
433 | for char in chars:
434 | counter[char] += 1
435 | for chars in example['context_chars']:
436 | for char in chars:
437 | counter[char] += 1
438 | return counter
439 |
440 |
441 | def _counter2vocab(counter, min_count):
442 | tokens = [token for token, count in counter.items() if count >= min_count]
443 | tokens = [PAD, UNK] + tokens
444 | vocab = {token: idx for idx, token in enumerate(tokens)}
445 | return vocab
446 |
447 |
448 | def _get_glove_vocab(words, counter=None):
449 | if counter is not None:
450 | words = [word for word in counter if word in set(words)]
451 | words = [PAD, UNK] + words
452 | vocab = {word: idx for idx, word in enumerate(words)}
453 | return vocab
454 |
455 |
456 | def _get_word_answer(context, context_words, answer_start, answer):
457 | """Get word-level answer index.
458 |
459 | Args:
460 | context: `unicode`, representing the context of the question.
461 | context_words: a list of `unicode`, tokenized context.
462 | answer_start: `int`, the char-level start index of the answer.
463 | answer: `unicode`, the answer that is substring of context.
464 | Returns:
465 | a tuple of `(word_answer_start, word_answer_end)`, representing the start
466 | and end indices of the answer in respect to `context_words`.
467 | """
468 | assert answer, 'Encountered length-0 answer.'
469 | answer_end = answer_start + len(answer)
470 | char_idxs = _tokens2idxs(context, context_words)
471 | word_answer_start = None
472 | word_answer_end = None
473 | for word_idx, char_idx in enumerate(char_idxs):
474 | if char_idx <= answer_start:
475 | word_answer_start = word_idx
476 | if char_idx < answer_end:
477 | word_answer_end = word_idx
478 | return word_answer_start, word_answer_end
479 |
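# A minimal check of the char-to-word mapping above (values illustrative):
# in 'The quick brown fox', the answer 'quick brown' starts at character 4
# and spans word indices 1 through 2.
#
#   context = 'The quick brown fox'
#   words = ['The', 'quick', 'brown', 'fox']
#   assert _get_word_answer(context, words, 4, 'quick brown') == (1, 2)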
--------------------------------------------------------------------------------
/squad_prepro_main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Preprocessor for SQuAD.
15 |
16 | `from_dir` will need to contain the original SQuAD train/dev data:
17 |
18 | - `train-v1.1.json`
19 | - `dev-v1.1.json`
20 |
21 | In each TFRecord file, the following features will be provided. Note that all
22 | strings are encoded with utf-8.
23 |
24 | Directly from the data.
25 | - `id` : string id from original SQuAD data.
26 | - `version` : the version of SQuAD data.
27 | - `title` : the title of the article.
28 | - `question` : original question string.
29 | - `context` : original context string.
30 |
31 | - `answers` : original list of answer strings. Variable length.
32 | - `answer_starts` : original list of integers for answer starts.
33 |
34 | Processed.
35 | - `question_words` : tokenized question. A variable len list of strings.
36 | - `context_words` : tokenized context. Variable length.
37 | - `question_chars` : question chars.
38 | - `context_chars` : context chars.
39 | - `indexed_question_words` : question words indexed by vocab.
40 | - `indexed_context_words` : context words indexed by vocab.
41 | - `glove_indexed_question_words` : question words indexed by GloVe.
42 | - `glove_indexed_context_words` : context words indexed by GloVe.
43 | - `indexed_question_chars` : question chars, flattened and indexed.
44 | - `indexed_context_chars` : ditto.
45 | - `question_num_words` : number of words in question.
46 | - `context_num_words` : number of words in context.
47 | - `is_supervised` : whether the dataset is supervised or not.
48 |
49 | - `num_answers` : integer indicating number of answers. 0 means no answer.
50 | - `word_answer_starts` : word-level answer start positions.
51 | - `word_answer_ends` : word-level answer end positions.
52 | """
53 |
54 | from __future__ import absolute_import
55 | from __future__ import division
56 | from __future__ import print_function
57 |
58 | import itertools
59 | import json
60 | import os
61 | import random
62 | from six import string_types
63 |
64 | import tensorflow as tf
65 | from tqdm import tqdm
66 | import squad_prepro
67 |
68 | tf.flags.DEFINE_string('from_dir', '', 'Directory for original SQuAD data.')
69 | tf.flags.DEFINE_string('to_dir', '', 'Directory for preprocessed SQuAD data.')
70 | tf.flags.DEFINE_string('glove_dir', '', 'Directory for GloVe files.')
71 | tf.flags.DEFINE_string('indexer_dir', '', 'Directory for indexer. '
72 |                        'If specified, train is not preprocessed; this indexer is used.')
73 | tf.flags.DEFINE_integer('word_count_th', 100, 'Word count threshold for vocab.')
74 | tf.flags.DEFINE_integer('char_count_th', 100, 'Char count threshold for vocab.')
75 | tf.flags.DEFINE_integer('max_context_size', 256,
76 | 'Maximum context size. Set this to `0` for test, which '
77 | 'sets no limit to the maximum context size.')
78 | tf.flags.DEFINE_integer('max_ques_size', 32,
79 | 'Maximum question size. Set this to `0` for test.')
80 | tf.flags.DEFINE_integer('num_chars_per_word', 16,
81 | 'Fixed number of characters per word.')
82 | tf.flags.DEFINE_boolean(
83 | 'split', False,
84 | 'if `True`, each context will be sentence instead of paragraph, '
85 | 'and answer label (word index) will be `None` in case of no answer.')
86 | tf.flags.DEFINE_boolean('filter', False,
87 | 'If `True`, filters data by context and question sizes')
88 | tf.flags.DEFINE_boolean('sort', False,
89 | 'If `True`, sorts data by context length.')
90 | tf.flags.DEFINE_boolean('shuffle', True, 'If `True`, shuffle examples.')
91 | tf.flags.DEFINE_boolean('draft', False, 'If `True`, fast draft mode, '
92 | 'which only loads first few examples.')
93 | tf.flags.DEFINE_integer('max_shard_size', 1000, 'Max size of each shard.')
94 | tf.flags.DEFINE_integer('positive_augment_factor', 0, 'Augment positive '
95 | 'examples by this factor.')
96 | tf.flags.DEFINE_boolean('answerable', False, 'This flag '
97 | 'allows one to use only examples that have answers.')
98 |
99 | FLAGS = tf.flags.FLAGS
100 |
101 |
102 | def get_tf_example(example):
103 | """Get `tf.train.Example` object from example dict.
104 |
105 | Args:
106 | example: tokenized, indexed example.
107 | Returns:
108 | `tf.train.Example` object corresponding to the example.
109 | Raises:
110 | ValueError: if a key in `example` is invalid.
111 | """
112 | feature = {}
113 | for key, val in example.items():
114 | if not isinstance(val, list):
115 | val = [val]
116 | if val:
117 | if isinstance(val[0], string_types):
118 | dtype = 'bytes'
119 | elif isinstance(val[0], int):
120 | dtype = 'int64'
121 | else:
122 | raise TypeError('`%s` has an invalid type: %r' % (key, type(val[0])))
123 | else:
124 | if key == 'answers':
125 | dtype = 'bytes'
126 | elif key in ['answer_starts', 'word_answer_starts', 'word_answer_ends']:
127 | dtype = 'int64'
128 | else:
129 | raise ValueError(key)
130 |
131 | if dtype == 'bytes':
132 | # Transform unicode into bytes if necessary.
133 | val = [each.encode('utf-8') for each in val]
134 | feature[key] = tf.train.Feature(bytes_list=tf.train.BytesList(value=val))
135 | elif dtype == 'int64':
136 | feature[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=val))
137 | else:
138 | raise TypeError('`%s` has an invalid type: %r' % (key, type(val[0])))
139 | return tf.train.Example(features=tf.train.Features(feature=feature))
140 |
141 |
142 | def dump(examples, metadata, data_type):
143 | """Dumps examples as TFRecord files.
144 |
145 | Args:
146 | examples: a `list` of `dict`, where each dict is indexed example.
147 | metadata: `dict`, metadata corresponding to the examples.
148 | data_type: `str`, representing the data type of the examples (e.g. `train`).
149 | """
150 | out_dir = os.path.join(FLAGS.to_dir, data_type)
151 | metadata_path = os.path.join(out_dir, 'metadata.json')
152 | data_dir = os.path.join(out_dir, 'data')
153 | tf.gfile.MakeDirs(out_dir)
154 | tf.gfile.MakeDirs(data_dir)
155 |
156 | with tf.gfile.GFile(metadata_path, 'w') as fp:
157 | json.dump(metadata, fp)
158 |
159 |   # Write examples to sharded TFRecord files.
160 | writer = None
161 | counter = 0
162 | num_shards = 0
163 | for example in tqdm(examples):
164 | if writer is None:
165 | path = os.path.join(data_dir,
166 | 'squad_data_{}'.format(str(num_shards).zfill(4)))
167 | writer = tf.python_io.TFRecordWriter(path)
168 | tf_example = get_tf_example(example)
169 | writer.write(tf_example.SerializeToString())
170 | counter += 1
171 | if counter == FLAGS.max_shard_size:
172 | counter = 0
173 | writer.close()
174 | writer = None
175 | num_shards += 1
176 | if writer is not None:
177 | writer.close()
178 |
179 |
180 | def prepro(data_type, indexer=None):
181 | """Preprocesses the given data type."""
182 | squad_path = os.path.join(FLAGS.from_dir, '%s-v1.1.json' % data_type)
183 | tf.logging.info('Loading %s' % squad_path)
184 | examples = squad_prepro.get_examples(squad_path)
185 |
186 | if FLAGS.draft:
187 | examples = random.sample(examples, 100)
188 |
189 | if FLAGS.split:
190 | tf.logging.info('Splitting each example')
191 | tf.logging.info('Before splitting: %d %s examples' % (len(examples),
192 | data_type))
193 | examples = list(
194 | itertools.chain(* [
195 | squad_prepro.split(
196 | e, positive_augment_factor=FLAGS.positive_augment_factor)
197 | for e in tqdm(examples)
198 | ]))
199 | tf.logging.info('After splitting: %d %s examples' % (len(examples),
200 | data_type))
201 |
202 | if FLAGS.answerable:
203 | tf.logging.info('Using only answerable examples.')
204 | examples = [
205 | example for example in examples
206 | if example['answers'][0] != squad_prepro.NO_ANSWER
207 | ]
208 |
209 | if FLAGS.shuffle:
210 | tf.logging.info('Shuffling examples')
211 | random.shuffle(examples)
212 |
213 | if indexer is None:
214 | tf.logging.info('Creating indexer')
215 | indexer = squad_prepro.SquadIndexer(FLAGS.glove_dir, draft=FLAGS.draft)
216 |
217 | tf.logging.info('Indexing %s data' % data_type)
218 | indexed_examples, metadata = indexer.fit(
219 | examples,
220 | min_word_count=FLAGS.word_count_th,
221 | min_char_count=FLAGS.char_count_th,
222 | num_chars_per_word=FLAGS.num_chars_per_word)
223 | else:
224 | indexed_examples, metadata = indexer.prepro_eval(
225 | examples, num_chars_per_word=FLAGS.num_chars_per_word)
226 | tf.gfile.MakeDirs(FLAGS.to_dir)
227 | indexer_save_path = os.path.join(FLAGS.to_dir, 'indexer.json')
228 | tf.logging.info('Saving indexer')
229 | indexer.save(indexer_save_path)
230 |
231 | if FLAGS.filter:
232 | tf.logging.info('Filtering examples')
233 | tf.logging.info('Before filtering: %d %s examples' % (len(indexed_examples),
234 | data_type))
235 | indexed_examples = [
236 | e for e in indexed_examples
237 | if len(e['context_words']) <= FLAGS.max_context_size and
238 | len(e['question_words']) <= FLAGS.max_ques_size
239 | ]
240 | tf.logging.info('After filtering: %d %s examples' % (len(indexed_examples),
241 | data_type))
242 | tf.logging.info('Has answers: %d %s examples' %
243 | (sum(1 for e in indexed_examples
244 | if e['answer_starts'][0] >= 0), data_type))
245 | metadata['max_context_size'] = max(
246 | len(e['context_words']) for e in indexed_examples)
247 | metadata['max_ques_size'] = max(
248 | len(e['question_words']) for e in indexed_examples)
249 |
250 | if FLAGS.sort:
251 | tf.logging.info('Sorting examples')
252 | indexed_examples = sorted(
253 | indexed_examples, key=lambda e: len(e['context_words']))
254 |
255 | tf.logging.info('Dumping %s examples' % data_type)
256 | dump(indexed_examples, metadata, data_type)
257 |
258 | return indexer, metadata
259 |
260 |
261 | def main(argv):
262 | del argv
263 | assert not tf.gfile.Exists(FLAGS.to_dir), '%s already exists.' % FLAGS.to_dir
264 |
265 | if FLAGS.indexer_dir:
266 | indexer_path = os.path.join(FLAGS.indexer_dir, 'indexer.json')
267 | tf.logging.info('Loading indexer from %s' % indexer_path)
268 | indexer = squad_prepro.SquadIndexer(FLAGS.glove_dir, draft=FLAGS.draft)
269 | indexer.load(indexer_path)
270 | else:
271 | indexer, train_metadata = prepro('train')
272 | _, dev_metadata = prepro('dev', indexer=indexer)
273 |
274 | if FLAGS.indexer_dir:
275 | exp_metadata = dev_metadata
276 | else:
277 | exp_metadata = {
278 | 'max_context_size':
279 | max(train_metadata['max_context_size'],
280 | dev_metadata['max_context_size']),
281 | 'max_ques_size':
282 | max(train_metadata['max_ques_size'], dev_metadata['max_ques_size']),
283 | 'num_chars_per_word':
284 | max(train_metadata['num_chars_per_word'],
285 | dev_metadata['num_chars_per_word'])
286 | }
287 |
288 | tf.logging.info('Dumping experiment metadata')
289 | exp_metadata_path = os.path.join(FLAGS.to_dir, 'metadata.json')
290 | with tf.gfile.GFile(exp_metadata_path, 'w') as fp:
291 | json.dump(exp_metadata, fp)
292 |
293 |
294 | if __name__ == '__main__':
295 | tf.app.run(main)
296 |
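# A minimal sketch of how `get_tf_example` serializes one example (field
# values are illustrative): string fields become utf-8 `bytes_list` features
# and integer fields become `int64_list` features.
#
#   example = {
#       'id': u'example-id 0',
#       'context': u'The quick brown fox.',
#       'answers': [u'quick brown'],
#       'answer_starts': [4],
#       'num_answers': 1,
#   }
#   tf_example = get_tf_example(example)
#   serialized = tf_example.SerializeToString()  # what `dump` writes per record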
--------------------------------------------------------------------------------
/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """TensorFlow utilities for extractive question answering models.
15 | """
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | VERY_LARGE_NEGATIVE_VALUE = -1e12
23 | VERY_SMALL_POSITIVE_VALUE = 1e-12
24 |
25 |
26 | def bi_rnn(hidden_size,
27 | inputs_list,
28 | sequence_length_list=None,
29 | scope=None,
30 | dropout_rate=0.0,
31 | training=False,
32 | stack=False,
33 | cells=None,
34 | postprocess='concat',
35 | return_outputs=True,
36 | out_dim=None,
37 | reuse=False):
38 | """Bidirectional RNN with `BasicLSTMCell`.
39 |
40 | Args:
41 | hidden_size: `int` value, the hidden state size of the LSTM.
42 | inputs_list: A list of `inputs` tensors, where each `inputs` is
43 | single sequence tensor with shape [batch_size, seq_len, hidden_size].
44 | Can be single element instead of list.
45 | sequence_length_list: A list of `sequence_length` tensors.
46 |       The size of the list should equal that of `inputs_list`.
47 | Can be a single element instead of a list.
48 | scope: `str` value, variable scope for this function.
49 | dropout_rate: `float` value, dropout rate of LSTM, applied at the inputs.
50 | training: `bool` value, whether current run is training.
51 | stack: `bool` value, whether to stack instead of simultaneous bi-LSTM.
52 | cells: two `RNNCell` instances. If provided, `hidden_size` is ignored.
53 | postprocess: `str` value: `raw` or `concat` or `add`.
54 | Postprocessing on forward and backward outputs of LSTM.
55 | return_outputs: `bool` value, whether to return sequence outputs.
56 | Otherwise, return the last state.
57 |     out_dim: `int` value. If `postprocess` is `linear`, then this indicates
58 |       the output dim of the linear layer.
59 | reuse: `bool` value, whether to reuse variables.
60 | Returns:
61 | A list `return_list` where each element corresponds to each element of
62 |     `inputs_list`. If `inputs_list` is a single tensor, also returns a tensor.
63 | Raises:
64 | ValueError: If argument `postprocess` is an invalid value.
65 | """
66 | if not isinstance(inputs_list, list):
67 | inputs_list = [inputs_list]
68 | if sequence_length_list is None:
69 | sequence_length_list = [None] * len(inputs_list)
70 | elif not isinstance(sequence_length_list, list):
71 | sequence_length_list = [sequence_length_list]
72 | assert len(inputs_list) == len(
73 | sequence_length_list
74 | ), '`inputs_list` and `sequence_length_list` must have same lengths.'
75 | with tf.variable_scope(scope or 'bi_rnn', reuse=reuse) as vs:
76 | if cells is not None:
77 | cell_fw = cells[0]
78 | cell_bw = cells[1]
79 | else:
80 | cell_fw = tf.contrib.rnn.BasicLSTMCell(hidden_size, reuse=reuse)
81 | cell_bw = tf.contrib.rnn.BasicLSTMCell(hidden_size, reuse=reuse)
82 | return_list = []
83 | for inputs, sequence_length in zip(inputs_list, sequence_length_list):
84 | if return_list:
85 | vs.reuse_variables()
86 | if dropout_rate > 0.0:
87 | inputs = tf.layers.dropout(inputs, rate=dropout_rate, training=training)
88 | if stack:
89 | o_bw, state_bw = tf.nn.dynamic_rnn(
90 | cell_bw,
91 | tf.reverse_sequence(inputs, sequence_length, seq_dim=1),
92 | sequence_length=sequence_length,
93 | dtype='float',
94 | scope='rnn_bw')
95 | o_bw = tf.reverse_sequence(o_bw, sequence_length, seq_dim=1)
96 | if dropout_rate > 0.0:
97 | o_bw = tf.layers.dropout(o_bw, rate=dropout_rate, training=training)
98 | o_fw, state_fw = tf.nn.dynamic_rnn(
99 | cell_fw,
100 | o_bw,
101 | sequence_length=sequence_length,
102 | dtype='float',
103 | scope='rnn_fw')
104 | else:
105 | (o_fw, o_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
106 | cell_fw,
107 | cell_bw,
108 | inputs,
109 | sequence_length=sequence_length,
110 | dtype='float')
111 | return_fw = o_fw if return_outputs else state_fw[-1]
112 | return_bw = o_bw if return_outputs else state_bw[-1]
113 | if postprocess == 'raw':
114 | return_ = return_fw, return_bw
115 | elif postprocess == 'concat':
116 | return_ = tf.concat([return_fw, return_bw], 2 if return_outputs else 1)
117 | elif postprocess == 'add':
118 | return_ = return_fw + return_bw
119 | elif postprocess == 'max':
120 | return_ = tf.maximum(return_fw, return_bw)
121 | elif postprocess == 'linear':
122 | if out_dim is None:
123 | out_dim = 2 * hidden_size
124 | return_ = tf.concat([return_fw, return_bw], 2 if return_outputs else 1)
125 | return_ = tf.layers.dense(return_, out_dim)
126 | else:
127 | return_ = postprocess(return_fw, return_bw)
128 | return_list.append(return_)
129 | if len(return_list) == 1:
130 | return return_list[0]
131 | return return_list
132 |
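# A minimal usage sketch of `bi_rnn` (shapes are illustrative):
#
#   x = tf.random_normal([8, 20, 100])  # [batch, seq_len, dim]
#   x_len = tf.fill([8], 20)            # per-example sequence lengths
#   h = bi_rnn(50, x, sequence_length_list=x_len)
#   # With the default postprocess='concat', `h` has shape [8, 20, 100]:
#   # forward and backward outputs (50 each) are concatenated on dim 2.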
133 |
134 | def exp_mask(logits, mask, mask_is_length=True):
135 | """Exponential mask for logits.
136 |
137 | Logits cannot be masked with 0 (i.e. multiplying boolean mask)
138 |   because exponentiating 0 yields 1. `exp_mask` adds a very large negative
139 |   value to the `False` portion of `mask` so that the portion is effectively
140 |   ignored when exponentiated, e.g. softmaxed.
141 |
142 | Args:
143 | logits: Arbitrary-rank logits tensor to be masked.
144 |     mask: Mask tensor. Either a `boolean` tensor of the same shape as
145 |       `logits` (`mask_is_length=False`) or an `int` length tensor for the
146 |       last axis of `logits` (`mask_is_length=True`).
147 |     mask_is_length: `bool` value; `True` iff `mask` is a length mask.
148 | Returns:
149 |     Masked logits with the same shape as `logits`.
150 | """
151 | if mask_is_length:
152 | mask = tf.sequence_mask(mask, maxlen=tf.shape(logits)[-1])
153 | return logits + (1.0 - tf.cast(mask, 'float')) * VERY_LARGE_NEGATIVE_VALUE
154 |
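# Example (a minimal sketch): softmax over variable-length rows. Positions at
# or beyond each row's length receive a very large negative logit and hence
# (near-)zero probability mass.
#   logits = tf.random_normal([2, 5])
#   lengths = tf.constant([3, 5])
#   probs = tf.nn.softmax(exp_mask(logits, lengths, mask_is_length=True))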
155 |
156 | def self_att(tensor,
157 | tensor_val=None,
158 | mask=None,
159 | mask_is_length=True,
160 | logit_fn=None,
161 | scale_dot=False,
162 | normalizer=tf.nn.softmax,
163 | tensors=None,
164 | scope=None,
165 | reuse=False):
166 | """Performs self attention.
167 |
168 | Performs self attention to obtain single vector representation for a sequence
169 | of vectors.
170 |
171 | Args:
172 | tensor: [batch_size, sequence_length, hidden_size]-shaped tensor
173 | tensor_val: If specified, attention is applied on `tensor_val`, i.e.
174 | `tensor` is key.
175 | mask: Length mask (shape of [batch_size]) or
176 | boolean mask ([batch_size, sequence_length])
177 | mask_is_length: `True` if `mask` is length mask, `False` if it is boolean
178 | mask
179 | logit_fn: `logit_fn(tensor)` to obtain logits.
180 | scale_dot: `bool`, whether to scale the dot product by dividing by
181 | sqrt(hidden_size).
182 | normalizer: function to normalize logits.
183 | tensors: `dict`. If specified, add useful tensors (e.g. attention weights)
184 | to the `dict` with their (scope) names.
185 | scope: `string` for defining variable scope
186 | reuse: Reuse if `True`.
187 | Returns:
188 | [batch_size, hidden_size]-shaped tensor.
189 | """
190 | assert len(tensor.get_shape()
191 | ) == 3, 'The rank of `tensor` must be 3 but got {}.'.format(
192 | len(tensor.get_shape()))
193 | with tf.variable_scope(scope or 'self_att', reuse=reuse):
194 | hidden_size = tensor.get_shape().as_list()[-1]
195 | if logit_fn is None:
196 | logits = tf.layers.dense(tensor, hidden_size, activation=tf.tanh)
197 | logits = tf.squeeze(tf.layers.dense(logits, 1), 2)
198 | else:
199 | logits = logit_fn(tensor)
200 | if scale_dot:
201 |       logits /= tf.sqrt(tf.cast(hidden_size, 'float'))
202 | if mask is not None:
203 | logits = exp_mask(logits, mask, mask_is_length=mask_is_length)
204 | weights = normalizer(logits)
205 | if tensors is not None:
206 | weights = tf.identity(weights, name='attention')
207 | tensors[weights.op.name] = weights
208 | out = tf.reduce_sum(
209 | tf.expand_dims(weights, -1) * (tensor
210 | if tensor_val is None else tensor_val),
211 | 1)
212 | return out
213 |
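# Usage sketch: pooling a padded batch of sequences into one vector per
# example; shapes are illustrative.
#   seq = tf.random_normal([4, 12, 200])     # [batch, time, hidden]
#   lengths = tf.constant([12, 7, 3, 12])
#   pooled = self_att(seq, mask=lengths)     # -> [4, 200]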
214 |
215 | def highway(inputs,
216 | outputs=None,
217 | dropout_rate=0.0,
218 | batch_norm=False,
219 | training=False,
220 | scope=None,
221 | reuse=False):
222 | """Single-layer highway networks (https://arxiv.org/abs/1505.00387).
223 |
224 | Args:
225 | inputs: Arbitrary-rank `float` tensor, where the first dim is batch size
226 | and the last dim is where the highway network is applied.
227 |     outputs: If provided, replaces the perceptron layer (i.e. gating only).
228 | dropout_rate: `float` value, input dropout rate.
229 | batch_norm: `bool` value, whether to use batch normalization.
230 | training: `bool` value, whether the current run is training.
231 |     scope: `str` value, variable scope; defaults to `highway`.
232 | reuse: `bool` value, whether to reuse variables.
233 | Returns:
234 | The output of the highway network, same shape as `inputs`.
235 | """
236 | with tf.variable_scope(scope or 'highway', reuse=reuse):
237 | if dropout_rate > 0.0:
238 | inputs = tf.layers.dropout(inputs, rate=dropout_rate, training=training)
239 | dim = inputs.get_shape()[-1]
240 | if outputs is None:
241 | outputs = tf.layers.dense(inputs, dim, name='outputs')
242 | if batch_norm:
243 | outputs = tf.layers.batch_normalization(outputs, training=training)
244 | outputs = tf.nn.relu(outputs)
245 | gate = tf.layers.dense(inputs, dim, activation=tf.nn.sigmoid, name='gate')
246 | return gate * inputs + (1 - gate) * outputs
247 |
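# Note: with gate g = sigmoid(W_g x + b_g) and transform H(x) = relu(W_h x +
# b_h), this layer computes y = g * x + (1 - g) * H(x); the sigmoid gate here
# plays the carry-gate role of Srivastava et al. (2015), so g -> 1 copies the
# input and g -> 0 applies the transform.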
248 |
249 | def highway_net(inputs,
250 | num_layers,
251 | dropout_rate=0.0,
252 | batch_norm=False,
253 | training=False,
254 | scope=None,
255 | reuse=False):
256 | """Multi-layer highway networks (https://arxiv.org/abs/1505.00387).
257 |
258 | Args:
259 | inputs: `float` input tensor to the highway networks.
260 | num_layers: `int` value, indicating the number of highway layers to build.
261 | dropout_rate: `float` value for the input dropout rate.
262 | batch_norm: `bool` value, indicating whether to use batch normalization
263 | or not.
264 | training: `bool` value, indicating whether the current run is training
265 | or not (e.g. eval or inference).
266 | scope: `str` value, variable scope. Default is `highway_net`.
267 | reuse: `bool` value, indicating whether the variables in this function
268 | are reused.
269 | Returns:
270 | The output of the highway networks, which is the same shape as `inputs`.
271 | """
272 | with tf.variable_scope(scope or 'highway_net', reuse=reuse):
273 | outputs = inputs
274 | for i in range(num_layers):
275 | outputs = highway(
276 | outputs,
277 | dropout_rate=dropout_rate,
278 | batch_norm=batch_norm,
279 | training=training,
280 | scope='layer_{}'.format(i))
281 | return outputs
282 |
283 |
284 | def char_cnn(inputs,
285 | out_dim,
286 | kernel_size,
287 | dropout_rate=0.0,
288 | name=None,
289 | reuse=False,
290 | activation=None,
291 | batch_norm=False,
292 | training=False):
293 | """Character-level CNN.
294 |
295 | Args:
296 | inputs: Input tensor of shape [batch_size, num_words, num_chars, in_dim].
297 | out_dim: `int` value, output dimension of CNN.
298 | kernel_size: `int` value, the width of the kernel for CNN.
299 | dropout_rate: `float` value, input dropout rate.
300 | name: `str` value, variable scope for variables in this function.
301 | reuse: `bool` value, indicating whether to reuse CNN variables.
302 | activation: function for activation. Default is `tf.nn.relu`.
303 | batch_norm: `bool` value, whether to perform batch normalization.
304 | training: `bool` value, whether the current run is training or not.
305 | Returns:
306 | Output tensor of shape [batch_size, num_words, out_dim].
307 |
308 | """
309 | with tf.variable_scope(name or 'char_cnn', reuse=reuse):
310 | batch_size = tf.shape(inputs)[0]
311 | num_words = tf.shape(inputs)[1]
312 | num_chars = tf.shape(inputs)[2]
313 | in_dim = inputs.get_shape().as_list()[3]
314 | if dropout_rate > 0.0:
315 | inputs = tf.layers.dropout(inputs, rate=dropout_rate, training=training)
316 | outputs = tf.reshape(
317 | tf.layers.conv1d(
318 | tf.reshape(inputs, [batch_size * num_words, num_chars, in_dim]),
319 | out_dim,
320 | kernel_size,
321 | name=name), [batch_size, num_words, -1, out_dim])
322 | if batch_norm:
323 | outputs = tf.layers.batch_normalization(outputs, training=training)
324 | if activation is None:
325 | activation = tf.nn.relu
326 | outputs = activation(outputs)
327 | outputs = tf.reduce_max(outputs, 2)
328 | return outputs
329 |
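# Usage sketch: max-over-time pooling of character convolutions into one
# vector per word. Shapes are illustrative, and `char_emb` would normally
# come from an embedding lookup.
#   char_emb = tf.random_normal([4, 20, 16, 8])  # [batch, words, chars, dim]
#   word_vecs = char_cnn(char_emb, 100, 5)       # -> [4, 20, 100]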
330 |
331 | def att2d(a,
332 | b,
333 | a_val=None,
334 | mask=None,
335 | b_mask=None,
336 | a_null=None,
337 | logit_fn='bahdanau',
338 | return_weights=False,
339 | normalizer=tf.nn.softmax,
340 | transpose=False,
341 | scale_logits=False,
342 | reduce_fn=None,
343 | tensors=None,
344 | scope=None):
345 | """2D-attention on a pair of sequences.
346 |
347 | Obtain the most similar (most attended) vector among the vectors in `a` to
348 | each vector in `b`. That is, `b` can be considered as key and `a` is value.
349 | Or in other words, `b` is attender and `a` is attendee.
350 |
351 | Args:
352 | a: [batch_size, a_len, hidden_size] shaped tensor.
353 | b: [batch_size, b_len, hidden_size] shaped tensor.
354 | a_val: If specified, attention is performed on `a_val` instead of `a`.
355 | mask: length mask tensor, boolean mask tensor for `a`, or 2d mask of size
356 | [b_len, a_len].
357 |     b_mask: Length mask or boolean mask for `b`, analogous to `mask`.
358 | a_null: If specified, this becomes a possible vector to be attended.
359 | logit_fn: `logit_fn(a, b)` computes logits for attention. By default,
360 |       uses the attention function of Bahdanau et al. (2014). Can also be
361 |       one of the strings `dot`, `wdot`, `bilinear`, or `l2`; `dot` computes
362 |       the dot product memory-efficiently via `tf.matmul`.
363 | return_weights: `bool` value, whether to return weights instead of
364 | the attended vector.
365 | normalizer: function that normalizes the weights.
366 |     transpose: `bool`, if `True`, normalizes over `b`'s axis instead of `a`'s.
367 |     scale_logits: `bool`, whether to scale the logits by dividing by
368 |       sqrt(hidden_size), as in https://arxiv.org/abs/1706.03762.
369 |     reduce_fn: Python fn that, if specified, reduces the logit matrix
370 |       along `b`'s axis.
371 | tensors: `dict`. If specified, add useful tensors (e.g. attention weights)
372 | to the `dict` with their (scope) names.
373 | scope: `str` value, indicating the variable scope of this function.
374 | Returns:
375 | [batch_size, b_len, hidden_size] shaped tensor, where each vector
376 | represents the most similar vector among `a` for each vector in `b`.
377 | If `return_weights` is `True`, return tensor is
378 | [batch_size, b_len, a_len] shape. If `reduce_fn` is specified, the return
379 |     tensor shape is [batch_size, 1, a_len].
380 | """
381 | with tf.variable_scope(scope or 'att_2d'):
382 | batch_size = tf.shape(a)[0]
383 | hidden_size = a.get_shape().as_list()[-1]
384 | if a_null is not None:
385 | a_null = tf.tile(
386 | tf.expand_dims(tf.expand_dims(a_null, 0), 0), [batch_size, 1, 1])
387 | a = tf.concat([a_null, a], 1)
388 | if mask is not None:
389 |       # TODO(seominjoon): Support other shapes of masks.
390 | assert len(mask.get_shape()) == 1
391 | mask += 1
392 |
393 | # Memory-efficient operation for dot product logit function.
394 | if logit_fn == 'wdot':
395 | weights = tf.get_variable('weights', shape=[hidden_size], dtype='float')
396 | bias = tf.get_variable('bias', shape=[], dtype='float')
397 | logits = tf.matmul(b * weights, a, transpose_b=True) + bias
398 | elif logit_fn == 'dot':
399 | logits = tf.matmul(b, a, transpose_b=True)
400 | elif logit_fn == 'bilinear':
401 | logits = tf.matmul(
402 | tf.layers.dense(b, hidden_size, use_bias=False), a, transpose_b=True)
403 | elif logit_fn == 'l2':
404 | ba = tf.matmul(b, a, transpose_b=True)
405 | aa = tf.expand_dims(tf.reduce_sum(a * a, 2), 1)
406 | bb = tf.expand_dims(tf.reduce_sum(b * b, 2), 2)
407 | logits = 2 * ba - aa - bb
408 | else:
409 | # WARNING : This is memory-intensive!
410 | aa = tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[1], 1, 1])
411 | bb = tf.tile(tf.expand_dims(b, 2), [1, 1, tf.shape(a)[1], 1])
412 | if logit_fn == 'bahdanau':
413 | logits = tf.layers.dense(
414 | tf.concat([aa, bb], 3), hidden_size, activation=tf.tanh)
415 | logits = tf.squeeze(tf.layers.dense(logits, 1), 3)
416 | else:
417 |         logits = logit_fn(aa, bb)  # [batch_size, b_len, a_len]-shaped tensor.
418 | if scale_logits:
419 | logits /= tf.sqrt(tf.cast(hidden_size, 'float'))
420 | if mask is not None:
421 | if len(mask.get_shape()) == 1:
422 | mask = tf.sequence_mask(mask, tf.shape(a)[1])
423 | if len(mask.get_shape()) == 2:
424 | mask = tf.expand_dims(mask, 1)
425 | if b_mask is not None:
426 | if len(b_mask.get_shape()) == 1:
427 | b_mask = tf.sequence_mask(b_mask, tf.shape(b)[1])
428 | mask &= tf.expand_dims(b_mask, -1)
429 | logits = exp_mask(logits, mask, mask_is_length=False)
430 | if reduce_fn:
431 | logits = tf.expand_dims(reduce_fn(logits, 1), 1)
432 | if transpose:
433 | logits = tf.transpose(logits, [0, 2, 1])
434 | p = logits if normalizer is None else normalizer(logits)
435 | if transpose:
436 | p = tf.transpose(p, [0, 2, 1])
437 | logits = tf.transpose(logits, [0, 2, 1])
438 | if tensors is not None:
439 | p = tf.identity(p, name='attention')
440 | tensors[p.op.name] = p
441 |
442 | if return_weights:
443 | return p
444 |
445 | # Memory-efficient application of attention weights.
446 | # [batch_size, b_len, hidden_size]
447 | a_b = tf.matmul(p, a if a_val is None else a_val)
448 | return a_b
449 |
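# Usage sketch: context-to-question attention, where each context vector in
# `b` attends over the question vectors in `a`; shapes are illustrative.
#   q = tf.random_normal([4, 15, 200])     # question: [batch, a_len, hidden]
#   c = tf.random_normal([4, 120, 200])    # context:  [batch, b_len, hidden]
#   q_lengths = tf.fill([4], 15)
#   attended = att2d(q, c, mask=q_lengths, logit_fn='dot')  # [4, 120, 200]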
450 |
451 | def mlp(a,
452 | hidden_sizes,
453 | activation=tf.nn.relu,
454 | activate_last=True,
455 | dropout_rate=0.0,
456 | training=False,
457 | scope=None):
458 | """Multi-layer perceptron.
459 |
460 | Args:
461 | a: input tensor.
462 | hidden_sizes: `list` of `int`, hidden state sizes for perceptron layers.
463 | activation: function handler for activation.
464 | activate_last: `bool`, whether to activate the last layer or not.
465 | dropout_rate: `float`, dropout rate at the input of each layer.
466 | training: `bool`, whether the current run is training or not.
467 | scope: `str`, variable scope of all tensors and weights in this function.
468 | Returns:
469 | Tensor with same shape as `a` except for the last dim, whose size is equal
470 | to `hidden_sizes[-1]`.
471 | """
472 | with tf.variable_scope(scope or 'mlp'):
473 | for idx, hidden_size in enumerate(hidden_sizes):
474 | with tf.variable_scope('layer_%d' % idx):
475 | if dropout_rate > 0.0:
476 | a = tf.layers.dropout(a, rate=dropout_rate, training=training)
477 | activate = idx < len(hidden_sizes) - 1 or activate_last
478 | a = tf.layers.dense(
479 | a, hidden_size, activation=activation if activate else None)
480 |
481 | return a
482 |
483 |
484 | def split_concat(a, b, num, axis=None):
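  """Splits `a` and `b` into `num` pieces and concatenates them pairwise.

  Args:
    a: tensor to split along `axis`.
    b: tensor to split along `axis`; must split into the same `num` pieces.
    num: `int`, the number of pieces to split into.
    axis: `int`, the axis to split and concatenate along; defaults to the
      last axis of `a`.
  Returns:
    A tuple of `num` tensors, the i-th being the concatenation of the i-th
    pieces of `a` and `b` along `axis`.
  """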
485 | if axis is None:
486 | axis = len(a.get_shape()) - 1
487 | a_list = tf.split(a, num, axis=axis)
488 | b_list = tf.split(b, num, axis=axis)
489 | t_list = tuple(
490 | tf.concat([aa, bb], axis=axis) for aa, bb in zip(a_list, b_list))
491 | return t_list
492 |
493 |
494 | def concat_seq_and_tok(sequence, token, position, sequence_length=None):
495 | """Concatenates a token to the given sequence, either at the start or end.
496 |
497 | The token's dimension should match the last dimension of the sequence.
498 |
499 | Args:
500 | sequence: [batch_size, sequence_length] shaped tensor or
501 | [batch_size, sequence_length, hidden_size] shaped tensor.
502 | token: scalar tensor or [hidden_size] shaped tensor.
503 | position: `str`, either 'start' or 'end'.
504 | sequence_length: [batch_size] shaped `int64` tensor. Must be specified
505 | if `position` is 'end'.
506 | Returns:
507 | [batch_size, sequence_length+1] or
508 | [batch_size, sequence_length+1, hidden_size] shaped tensor.
509 | Raises:
510 | ValueError: If `position` is not 'start' or 'end'.
511 | """
512 | batch_size = tf.shape(sequence)[0]
513 | if len(sequence.get_shape()) == 3:
514 | token = tf.tile(
515 | tf.expand_dims(tf.expand_dims(token, 0), 0), [batch_size, 1, 1])
516 | elif len(sequence.get_shape()) == 2:
517 | token = tf.tile(tf.reshape(token, [1, 1]), [batch_size, 1])
518 |
519 | if position == 'start':
520 | sequence = tf.concat([token, sequence], 1)
521 | elif position == 'end':
522 | assert sequence_length is not None
523 | sequence = tf.reverse_sequence(sequence, sequence_length, seq_axis=1)
524 | sequence = tf.concat([token, sequence], 1)
525 | sequence = tf.reverse_sequence(sequence, sequence_length + 1, seq_axis=1)
526 | else:
527 | raise ValueError('%r is an invalid argument for `position`.' % position)
528 | return sequence
529 |
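# Example sketch: appending an end-of-sequence embedding to each row of a
# padded batch. The reverse/concat/reverse trick places `eos` right after
# position `length` of every row; names are illustrative.
#   seq = tf.random_normal([4, 10, 200])
#   eos = tf.get_variable('eos', shape=[200])
#   lengths = tf.constant([10, 6, 3, 10], dtype=tf.int64)
#   out = concat_seq_and_tok(seq, eos, 'end', sequence_length=lengths)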
530 |
531 | class ExternalInputWrapper(tf.contrib.rnn.RNNCell):
532 | """Wrapper for `RNNCell`, concatenates an external tensor to the input."""
533 |
534 | def __init__(self, cell, external_input, reuse=False):
535 | super(ExternalInputWrapper, self).__init__(_reuse=reuse)
536 | self._cell = cell
537 | self._external = external_input
538 |
539 | @property
540 | def state_size(self):
541 | return self._cell.state_size
542 |
543 | @property
544 | def output_size(self):
545 | return self._cell.output_size
546 |
547 | def zero_state(self, batch_size, dtype):
548 | with tf.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]):
549 | return self._cell.zero_state(batch_size, dtype)
550 |
551 | def call(self, inputs, state):
552 | inputs = tf.concat([self._external_input, inputs], 1)
553 | return self._cell(inputs, state)
554 |
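# Usage sketch: conditioning a decoder on a fixed context vector by
# concatenating it to the cell input at every step; names and shapes are
# illustrative.
#   context = tf.random_normal([4, 200])                # [batch, hidden]
#   cell = ExternalInputWrapper(tf.contrib.rnn.BasicLSTMCell(200), context)
#   outputs, state = tf.nn.dynamic_rnn(
#       cell, tf.random_normal([4, 10, 50]), dtype='float')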
555 |
556 | class DeembedWrapper(tf.contrib.seq2seq.Helper):
557 | """Wrapper for `Helper`, applies given deembed function to the output.
558 |
559 | The deembed function has single input and single output.
560 | It is applied on the output of the previous RNN before feeding it into the
561 | next time step.
562 | """
563 |
564 | def __init__(self, helper, deembed_fn):
565 | self._helper = helper
566 | self._deembed_fn = deembed_fn
567 |
568 | @property
569 | def batch_size(self):
570 | return self._helper.batch_size
571 |
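  # Note (hedged): on TF versions whose `Helper` interface also declares
  # `sample_ids_shape` and `sample_ids_dtype` as abstract properties, they
  # must be forwarded as well for this wrapper to be instantiable.
  @property
  def sample_ids_shape(self):
    return self._helper.sample_ids_shape

  @property
  def sample_ids_dtype(self):
    return self._helper.sample_ids_dtype
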
572 | def initialize(self, name=None):
573 | return self._helper.initialize(name=name)
574 |
575 | def next_inputs(self, time, outputs, state, sample_ids, name=None):
576 | outputs = self._deembed_fn(outputs)
577 | return self._helper.next_inputs(time, outputs, state, sample_ids, name=name)
578 |
579 | def sample(self, time, outputs, state, name=None):
580 | outputs = self._deembed_fn(outputs)
581 | return self._helper.sample(time, outputs, state, name=name)
582 |
--------------------------------------------------------------------------------
/train_and_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Experiments for kernel vs feature map in SQuAD.
15 |
16 | `feature` model does not allow any interaction between question and context
17 | except at the end, where the dot product (or L1/L2 distance) is used to get the
18 | answer.
19 | `kernel` model allows any interaction between question and context
20 | (e.g. cross attention).
21 | This script is for establishing baseline for both feature and kernel models.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from collections import defaultdict
29 | import json
30 | import os
31 |
32 | import tensorflow as tf
33 | from tqdm import tqdm
34 | import tensorflow.contrib.learn as learn
35 |
36 | # In the google-internal environment, this import also registers the flags
37 | # `output_dir` and `schedule` (importing `learn` above is not sufficient).
38 | # For the open-source build, these flags are instead defined explicitly
39 | # below ("Below are added for third party").
40 | from tensorflow.contrib.learn import learn_runner
41 |
42 | import squad_data
43 | from common_model import get_loss
44 | from common_model import get_pred_ops
45 | from common_model import get_train_op
46 | from feature_model import feature_model
47 | from kernel_model import kernel_model
48 |
49 | tf.flags.DEFINE_integer('emb_size', 200, 'embedding size')
50 | tf.flags.DEFINE_integer('glove_size', 200, 'GloVe size')
51 | tf.flags.DEFINE_integer('hidden_size', 200, 'hidden state size')
52 | tf.flags.DEFINE_integer('num_train_steps', 20000, 'num train steps')
53 | tf.flags.DEFINE_integer('num_eval_steps', 500, 'num eval steps')
54 | tf.flags.DEFINE_boolean('draft', False, 'draft?')
55 | tf.flags.DEFINE_integer('batch_size', 64, 'batch size')
56 | tf.flags.DEFINE_float('dropout_rate', 0.2,
57 | 'dropout rate, applied to the input of LSTMs.')
58 | tf.flags.DEFINE_string(
59 | 'root_data_dir',
60 | '/cns/ok-d/home/neon-core/ker2vec/squad/prepro/sort_filter',
61 | 'root data dir')
62 | tf.flags.DEFINE_integer('save_checkpoints_steps', 500, 'Steps between checkpoint saves.')
63 | tf.flags.DEFINE_integer('num_eval_delay_secs', 1, 'eval delay secs')
64 | tf.flags.DEFINE_boolean('shuffle_examples', False, 'Use shuffle example queue?')
65 | tf.flags.DEFINE_boolean('shuffle_files', True, 'Use shuffle file queue?')
66 | tf.flags.DEFINE_string('model', 'feature', '`feature` or `kernel`.')
67 | tf.flags.DEFINE_boolean('oom_test', False, 'Performs out-of-memory test')
68 | tf.flags.DEFINE_string(
69 | 'dist', 'dot', 'Distance function for feature model. `dot`, `l1` or `l2`.')
70 | tf.flags.DEFINE_float('learning_rate', 0.001,
71 | '(Initial) learning rate for optimizer')
72 | tf.flags.DEFINE_boolean(
73 | 'infer', False,
74 | 'If `True`, obtains and saves predictions for the test dataset '
75 | 'at `answers_path`.')
76 | tf.flags.DEFINE_string('answers_path', '',
77 | 'The path for saving predictions on test dataset. '
78 | 'If not specified, saves in `restore_dir` directory.')
79 | tf.flags.DEFINE_float('clip_norm', 0, 'Clip norm threshold, 0 for no clip.')
80 | tf.flags.DEFINE_integer(
81 | 'restore_step', 0,
82 | 'The global step for which the model is restored in the beginning. '
83 | '`0` for the most recent save file.')
84 | tf.flags.DEFINE_float(
85 | 'restore_decay', 1.0,
86 | 'The decay rate for exponential moving average of variables that '
87 | 'will be restored upon eval or infer. '
88 | '`1.0` for restoring variables without decay.')
89 | tf.flags.DEFINE_string(
90 | 'ema_decays', '',
91 | 'List of exponential moving average (EMA) decay rates (float) '
92 | 'to track for variables during training. Values are separated by commas.')
93 | tf.flags.DEFINE_string(
94 | 'restore_dir', '',
95 |     'Directory from which variables are restored. If not specified, '
96 |     '`output_dir` will be used instead. Must be specified for inference.')
97 | tf.flags.DEFINE_string('model_id', 'm00', 'Model id.')
98 | tf.flags.DEFINE_string('glove_dir', '/cns/ok-d/home/neon-core/ker2vec/glove',
99 | 'GloVe dir.')
100 | tf.flags.DEFINE_boolean('merge', False, 'If `True`, merges answers from same '
101 | 'paragraph that were split in preprocessing step.')
102 | tf.flags.DEFINE_integer('queue_capacity', 5000, 'Input queue capacity.')
103 | tf.flags.DEFINE_integer('min_after_dequeue', 1000, 'Minimum number of examples '
104 | 'after queue dequeue.')
105 | tf.flags.DEFINE_integer('max_answer_size', 7, 'Max number of answer words.')
106 | tf.flags.DEFINE_string('restore_scopes', '', 'Restore scopes, separated by ,.')
107 | tf.flags.DEFINE_boolean('reg_gen', True, 'Whether to regularize training '
108 | 'with question generation (reconstruction) loss.')
109 | tf.flags.DEFINE_float('reg_cf', 3.0, 'Regularization initial coefficient.')
110 | tf.flags.DEFINE_float('reg_half_life', 6000, 'Regularization decay half life. '
111 | 'Set it to very high value to effectively disable decay.')
112 | tf.flags.DEFINE_integer('max_gen_length', 32, 'During inference, maximum '
113 | 'length of generated question.')
114 |
115 | # Below are added for third party.
116 | tf.flags.DEFINE_string('schedule', 'train_and_evaluate',
117 | 'schedule for learn_runner.')
118 | tf.flags.DEFINE_string('output_dir', '/tmp/squad_ckpts',
119 | 'Output directory for saving model.')
120 |
121 | FLAGS = tf.flags.FLAGS
122 | tf.logging.set_verbosity(tf.logging.INFO)
123 |
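# Example invocation (a sketch; flag values below are placeholders, not the
# paths used internally):
#   python train_and_eval.py --model=kernel \
#     --root_data_dir=/path/to/squad/prepro --glove_dir=/path/to/glove \
#     --output_dir=/tmp/squad_ckpts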
124 |
125 | def model_fn(features, targets, mode, params):
126 | """Model function to be used for `Experiment` object.
127 |
128 | Should not access `flags.FLAGS`.
129 |
130 | Args:
131 | features: a dictionary of feature tensors.
132 | targets: a dictionary of target tensors.
133 | mode: `learn.ModeKeys.TRAIN` or `learn.ModeKeys.EVAL`.
134 | params: `HParams` object.
135 | Returns:
136 | `ModelFnOps` object.
137 | Raises:
138 |     ValueError: raised if `params.model` is not an appropriate value.
139 | """
140 | with tf.variable_scope('model'):
141 | if params.model == 'feature':
142 | logits_start, logits_end, tensors = feature_model(
143 | features, mode, params)
144 | elif params.model == 'kernel':
145 | logits_start, logits_end, tensors = kernel_model(
146 | features, mode, params)
147 | else:
148 | raise ValueError(
149 | '`%s` is an invalid argument for `model` parameter.' % params.model)
150 | no_answer_bias = tf.get_variable('no_answer_bias', shape=[], dtype='float')
151 | no_answer_bias = tf.tile(
152 | tf.reshape(no_answer_bias, [1, 1]),
153 | [tf.shape(features['context_words'])[0], 1])
154 |
155 | predictions = get_pred_ops(features, params, logits_start, logits_end,
156 | no_answer_bias)
157 | predictions.update(tensors)
158 | predictions.update(features)
159 |
160 | if mode == learn.ModeKeys.INFER:
161 | eval_metric_ops, loss = None, None
162 | else:
163 | eval_metric_ops = squad_data.get_eval_metric_ops(targets, predictions)
164 | loss = get_loss(targets['word_answer_starts'], targets['word_answer_ends'],
165 | logits_start, logits_end, no_answer_bias)
166 |
167 | emas = {
168 | decay: tf.train.ExponentialMovingAverage(
169 | decay=decay, name='EMA_%f' % decay)
170 | for decay in params.ema_decays
171 | }
172 |
173 | ema_ops = [ema.apply() for ema in emas.values()]
174 | if mode == learn.ModeKeys.TRAIN:
175 | train_op = get_train_op(
176 | loss,
177 | learning_rate=params.learning_rate,
178 | clip_norm=params.clip_norm,
179 | post_ops=ema_ops)
180 | # TODO(seominjoon): Checking `Exists` is not the best way to do this.
181 | if params.restore_dir and not tf.gfile.Exists(params.output_dir):
182 | assert params.restore_scopes
183 | checkpoint_dir = params.restore_dir
184 | if params.restore_step:
185 | checkpoint_dir = os.path.join(params.restore_dir,
186 | 'model.ckpt-%d' % params.restore_step)
187 | restore_vars = []
188 | for restore_scope in params.restore_scopes:
189 | restore_vars.extend(
190 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, restore_scope))
191 | assignment_map = {var.op.name: var for var in restore_vars}
192 | tf.contrib.framework.init_from_checkpoint(checkpoint_dir, assignment_map)
193 | else:
194 | if params.restore_decay < 1.0:
195 | ema = emas[params.restore_decay]
196 | assign_ops = []
197 | for var in tf.trainable_variables():
198 | assign_op = tf.assign(var, ema.average(var))
199 | assign_ops.append(assign_op)
200 | with tf.control_dependencies(assign_ops):
201 | for key, val in predictions.items():
202 | predictions[key] = tf.identity(val)
203 | train_op = None
204 |
205 | return learn.ModelFnOps(
206 | mode=mode,
207 | predictions=predictions,
208 | loss=loss,
209 | train_op=train_op,
210 | eval_metric_ops=eval_metric_ops)
211 |
212 |
213 | def _experiment_fn(run_config, hparams):
214 | """Outputs `Experiment` object given `output_dir`.
215 |
216 | Args:
217 |     run_config: `RunConfig` object for run configuration.
218 | hparams: `HParams` object that contains hyperparameters.
219 |
220 | Returns:
221 | `Experiment` object
222 | """
223 | estimator = learn.Estimator(
224 | model_fn=model_fn, config=run_config, params=hparams)
225 |
226 | num_train_steps = 1 if FLAGS.oom_test else FLAGS.num_train_steps
227 | num_eval_steps = 1 if FLAGS.oom_test else FLAGS.num_eval_steps
228 |
229 | return learn.Experiment(
230 | estimator=estimator,
231 | train_input_fn=_get_train_input_fn(),
232 | eval_input_fn=_get_eval_input_fn(),
233 | train_steps=num_train_steps,
234 | eval_steps=num_eval_steps,
235 | eval_delay_secs=FLAGS.num_eval_delay_secs)
236 |
237 |
238 | def _get_train_input_fn():
239 | """Get train input function."""
240 | train_input_fn = squad_data.get_input_fn(
241 | FLAGS.root_data_dir,
242 | FLAGS.glove_dir,
243 | 'train',
244 | FLAGS.batch_size,
245 | FLAGS.glove_size,
246 | shuffle_files=FLAGS.shuffle_files,
247 | shuffle_examples=FLAGS.shuffle_examples,
248 | queue_capacity=FLAGS.queue_capacity,
249 | min_after_dequeue=FLAGS.min_after_dequeue,
250 | oom_test=FLAGS.oom_test)
251 | return train_input_fn
252 |
253 |
254 | def _get_eval_input_fn():
255 | """Get eval input function."""
256 | eval_input_fn = squad_data.get_input_fn(
257 | FLAGS.root_data_dir,
258 | FLAGS.glove_dir,
259 | 'dev',
260 | FLAGS.batch_size,
261 | FLAGS.glove_size,
262 | shuffle_files=True,
263 | shuffle_examples=True,
264 | queue_capacity=FLAGS.queue_capacity,
265 | min_after_dequeue=FLAGS.min_after_dequeue,
266 | num_epochs=1,
267 | oom_test=FLAGS.oom_test)
268 | return eval_input_fn
269 |
270 |
271 | def _get_test_input_fn():
272 | """Get test input function."""
273 | # TODO(seominjoon) For now, test input is same as eval input (dev).
274 | test_input_fn = squad_data.get_input_fn(
275 | FLAGS.root_data_dir,
276 | FLAGS.glove_dir,
277 | 'dev',
278 | FLAGS.batch_size,
279 | FLAGS.glove_size,
280 | shuffle_files=FLAGS.shuffle_files,
281 | shuffle_examples=FLAGS.shuffle_examples,
282 | queue_capacity=FLAGS.queue_capacity,
283 | min_after_dequeue=FLAGS.min_after_dequeue,
284 | num_epochs=1,
285 | oom_test=FLAGS.oom_test)
286 | return test_input_fn
287 |
288 |
289 | def _get_config():
290 | """Get configuration object for `Estimator` object.
291 |
292 |   For open-sourcing, `EstimatorConfig` has been replaced with `RunConfig`.
293 | Depends on `flags.FLAGS`, and should not be used outside of this main script.
294 |
295 | Returns:
296 | `EstimatorConfig` object.
297 | """
298 | config = learn.RunConfig(
299 | model_dir=FLAGS.restore_dir if FLAGS.infer else FLAGS.output_dir,
300 | keep_checkpoint_max=0, # Keep all checkpoints.
301 | save_checkpoints_steps=FLAGS.save_checkpoints_steps)
302 | return config
303 |
304 |
305 | def _get_hparams():
306 | """Model-specific hyperparameters go here.
307 |
308 | All model parameters go here, since `model_fn()` should not access
309 | `flags.FLAGS`.
310 | Depends on `flags.FLAGS`, and should not be used outside of this main script.
311 |
312 | Returns:
313 | `HParams` object.
314 | """
315 | hparams = tf.contrib.training.HParams()
316 | data_hparams = squad_data.get_params(FLAGS.root_data_dir)
317 | hparams.vocab_size = data_hparams['vocab_size']
318 | hparams.char_vocab_size = data_hparams['char_vocab_size']
319 | hparams.batch_size = FLAGS.batch_size
320 | hparams.hidden_size = FLAGS.hidden_size
321 | hparams.emb_size = FLAGS.emb_size
322 | hparams.dropout_rate = FLAGS.dropout_rate
323 | hparams.dist = FLAGS.dist
324 | hparams.learning_rate = FLAGS.learning_rate
325 | hparams.model = FLAGS.model
326 | hparams.restore_dir = FLAGS.restore_dir
327 | hparams.output_dir = FLAGS.output_dir
328 | hparams.clip_norm = FLAGS.clip_norm
329 | hparams.restore_decay = FLAGS.restore_decay
330 | if FLAGS.ema_decays:
331 | hparams.ema_decays = list(map(float, FLAGS.ema_decays.split(',')))
332 | else:
333 | hparams.ema_decays = []
334 | hparams.restore_step = FLAGS.restore_step
335 | hparams.model_id = FLAGS.model_id
336 | hparams.max_answer_size = FLAGS.max_answer_size
337 | hparams.restore_scopes = FLAGS.restore_scopes.split(',')
338 | hparams.glove_size = FLAGS.glove_size
339 |
340 | # Regularization by Query Generation (reconstruction) parameters.
341 | hparams.reg_gen = FLAGS.reg_gen
342 | hparams.reg_cf = FLAGS.reg_cf
343 | hparams.reg_half_life = FLAGS.reg_half_life
344 |
345 | return hparams
346 |
347 |
348 | def train_and_eval():
349 | """Train and eval routine."""
350 | learn_runner.run(
351 | experiment_fn=_experiment_fn,
352 | schedule=FLAGS.schedule,
353 | run_config=_get_config(),
354 | hparams=_get_hparams())
355 |
356 |
357 | def _set_ckpt():
358 | # TODO(seominjoon): This is adhoc. Need better ckpt loading during inf.
359 | if FLAGS.restore_step:
360 | path = os.path.join(FLAGS.restore_dir, 'checkpoint')
361 | with tf.gfile.GFile(path, 'w') as fp:
362 | fp.write('model_checkpoint_path: "model.ckpt-%d"\n' % FLAGS.restore_step)
363 |
364 |
365 | def infer():
366 | """Inference routine, outputting answers to `FLAGS.answers_path`."""
367 | _set_ckpt()
368 | estimator = learn.Estimator(
369 | model_fn=model_fn, config=_get_config(), params=_get_hparams())
370 | predictions = estimator.predict(
371 | input_fn=_get_test_input_fn(), as_iterable=True)
372 | global_step = estimator.get_variable_value('global_step')
373 | path = FLAGS.answers_path or os.path.join(FLAGS.restore_dir,
374 | 'answers-%d.json' % global_step)
375 | answer_dict = {'no_answer_prob': {}, 'answer_prob': {}}
376 | for prediction in tqdm(predictions):
377 | id_ = prediction['id'].decode('utf-8')
378 | answer_dict[id_] = prediction['a'].decode('utf-8')
379 | answer_dict['answer_prob'][id_] = prediction['answer_prob'].tolist()
380 | answer_dict['no_answer_prob'][id_] = prediction['no_answer_prob'].tolist()
381 | if FLAGS.oom_test:
382 | break
383 |
384 | # TODO(seominjoon): use sum of logits instead of normalized prob.
385 | if FLAGS.merge:
386 | new_answer_dict = defaultdict(list)
387 | for id_, answer_prob in answer_dict['answer_prob'].items():
388 | answer = answer_dict[id_]
389 | id_ = id_.split(' ')[0] # retrieve true id
390 | new_answer_dict[id_].append([answer_prob, answer])
391 | answer_dict = {
392 | id_: max(each, key=lambda pair: pair[0])[1]
393 | for id_, each in new_answer_dict.items()
394 | }
395 |
396 | with tf.gfile.GFile(path, 'w') as fp:
397 | json.dump(answer_dict, fp)
398 | tf.logging.info('Dumped predictions at: %s' % path)
399 |
400 |
401 | def main(_):
402 | if FLAGS.infer:
403 | infer()
404 | else:
405 | train_and_eval()
406 |
407 |
408 | if __name__ == '__main__':
409 | tf.app.run()
410 |
--------------------------------------------------------------------------------