├── .gitignore
├── LICENSE.txt
├── README.md
├── __init__.py
├── attention_decoder.py
├── batcher.py
├── beam_search.py
├── data.py
├── decode.py
├── inspect_checkpoint.py
├── model.py
├── run_summarization.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | log/
4 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright 2017 The TensorFlow Authors. All rights reserved.
2 | Modifications Copyright 2017 Abigail See
3 |
4 |
5 | Apache License
6 | Version 2.0, January 2004
7 | http://www.apache.org/licenses/
8 |
9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10 |
11 | 1. Definitions.
12 |
13 | "License" shall mean the terms and conditions for use, reproduction,
14 | and distribution as defined by Sections 1 through 9 of this document.
15 |
16 | "Licensor" shall mean the copyright owner or entity authorized by
17 | the copyright owner that is granting the License.
18 |
19 | "Legal Entity" shall mean the union of the acting entity and all
20 | other entities that control, are controlled by, or are under common
21 | control with that entity. For the purposes of this definition,
22 | "control" means (i) the power, direct or indirect, to cause the
23 | direction or management of such entity, whether by contract or
24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
25 | outstanding shares, or (iii) beneficial ownership of such entity.
26 |
27 | "You" (or "Your") shall mean an individual or Legal Entity
28 | exercising permissions granted by this License.
29 |
30 | "Source" form shall mean the preferred form for making modifications,
31 | including but not limited to software source code, documentation
32 | source, and configuration files.
33 |
34 | "Object" form shall mean any form resulting from mechanical
35 | transformation or translation of a Source form, including but
36 | not limited to compiled object code, generated documentation,
37 | and conversions to other media types.
38 |
39 | "Work" shall mean the work of authorship, whether in Source or
40 | Object form, made available under the License, as indicated by a
41 | copyright notice that is included in or attached to the work
42 | (an example is provided in the Appendix below).
43 |
44 | "Derivative Works" shall mean any work, whether in Source or Object
45 | form, that is based on (or derived from) the Work and for which the
46 | editorial revisions, annotations, elaborations, or other modifications
47 | represent, as a whole, an original work of authorship. For the purposes
48 | of this License, Derivative Works shall not include works that remain
49 | separable from, or merely link (or bind by name) to the interfaces of,
50 | the Work and Derivative Works thereof.
51 |
52 | "Contribution" shall mean any work of authorship, including
53 | the original version of the Work and any modifications or additions
54 | to that Work or Derivative Works thereof, that is intentionally
55 | submitted to Licensor for inclusion in the Work by the copyright owner
56 | or by an individual or Legal Entity authorized to submit on behalf of
57 | the copyright owner. For the purposes of this definition, "submitted"
58 | means any form of electronic, verbal, or written communication sent
59 | to the Licensor or its representatives, including but not limited to
60 | communication on electronic mailing lists, source code control systems,
61 | and issue tracking systems that are managed by, or on behalf of, the
62 | Licensor for the purpose of discussing and improving the Work, but
63 | excluding communication that is conspicuously marked or otherwise
64 | designated in writing by the copyright owner as "Not a Contribution."
65 |
66 | "Contributor" shall mean Licensor and any individual or Legal Entity
67 | on behalf of whom a Contribution has been received by Licensor and
68 | subsequently incorporated within the Work.
69 |
70 | 2. Grant of Copyright License. Subject to the terms and conditions of
71 | this License, each Contributor hereby grants to You a perpetual,
72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73 | copyright license to reproduce, prepare Derivative Works of,
74 | publicly display, publicly perform, sublicense, and distribute the
75 | Work and such Derivative Works in Source or Object form.
76 |
77 | 3. Grant of Patent License. Subject to the terms and conditions of
78 | this License, each Contributor hereby grants to You a perpetual,
79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80 | (except as stated in this section) patent license to make, have made,
81 | use, offer to sell, sell, import, and otherwise transfer the Work,
82 | where such license applies only to those patent claims licensable
83 | by such Contributor that are necessarily infringed by their
84 | Contribution(s) alone or by combination of their Contribution(s)
85 | with the Work to which such Contribution(s) was submitted. If You
86 | institute patent litigation against any entity (including a
87 | cross-claim or counterclaim in a lawsuit) alleging that the Work
88 | or a Contribution incorporated within the Work constitutes direct
89 | or contributory patent infringement, then any patent licenses
90 | granted to You under this License for that Work shall terminate
91 | as of the date such litigation is filed.
92 |
93 | 4. Redistribution. You may reproduce and distribute copies of the
94 | Work or Derivative Works thereof in any medium, with or without
95 | modifications, and in Source or Object form, provided that You
96 | meet the following conditions:
97 |
98 | (a) You must give any other recipients of the Work or
99 | Derivative Works a copy of this License; and
100 |
101 | (b) You must cause any modified files to carry prominent notices
102 | stating that You changed the files; and
103 |
104 | (c) You must retain, in the Source form of any Derivative Works
105 | that You distribute, all copyright, patent, trademark, and
106 | attribution notices from the Source form of the Work,
107 | excluding those notices that do not pertain to any part of
108 | the Derivative Works; and
109 |
110 | (d) If the Work includes a "NOTICE" text file as part of its
111 | distribution, then any Derivative Works that You distribute must
112 | include a readable copy of the attribution notices contained
113 | within such NOTICE file, excluding those notices that do not
114 | pertain to any part of the Derivative Works, in at least one
115 | of the following places: within a NOTICE text file distributed
116 | as part of the Derivative Works; within the Source form or
117 | documentation, if provided along with the Derivative Works; or,
118 | within a display generated by the Derivative Works, if and
119 | wherever such third-party notices normally appear. The contents
120 | of the NOTICE file are for informational purposes only and
121 | do not modify the License. You may add Your own attribution
122 | notices within Derivative Works that You distribute, alongside
123 | or as an addendum to the NOTICE text from the Work, provided
124 | that such additional attribution notices cannot be construed
125 | as modifying the License.
126 |
127 | You may add Your own copyright statement to Your modifications and
128 | may provide additional or different license terms and conditions
129 | for use, reproduction, or distribution of Your modifications, or
130 | for any such Derivative Works as a whole, provided Your use,
131 | reproduction, and distribution of the Work otherwise complies with
132 | the conditions stated in this License.
133 |
134 | 5. Submission of Contributions. Unless You explicitly state otherwise,
135 | any Contribution intentionally submitted for inclusion in the Work
136 | by You to the Licensor shall be under the terms and conditions of
137 | this License, without any additional terms or conditions.
138 | Notwithstanding the above, nothing herein shall supersede or modify
139 | the terms of any separate license agreement you may have executed
140 | with Licensor regarding such Contributions.
141 |
142 | 6. Trademarks. This License does not grant permission to use the trade
143 | names, trademarks, service marks, or product names of the Licensor,
144 | except as required for reasonable and customary use in describing the
145 | origin of the Work and reproducing the content of the NOTICE file.
146 |
147 | 7. Disclaimer of Warranty. Unless required by applicable law or
148 | agreed to in writing, Licensor provides the Work (and each
149 | Contributor provides its Contributions) on an "AS IS" BASIS,
150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151 | implied, including, without limitation, any warranties or conditions
152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153 | PARTICULAR PURPOSE. You are solely responsible for determining the
154 | appropriateness of using or redistributing the Work and assume any
155 | risks associated with Your exercise of permissions under this License.
156 |
157 | 8. Limitation of Liability. In no event and under no legal theory,
158 | whether in tort (including negligence), contract, or otherwise,
159 | unless required by applicable law (such as deliberate and grossly
160 | negligent acts) or agreed to in writing, shall any Contributor be
161 | liable to You for damages, including any direct, indirect, special,
162 | incidental, or consequential damages of any character arising as a
163 | result of this License or out of the use or inability to use the
164 | Work (including but not limited to damages for loss of goodwill,
165 | work stoppage, computer failure or malfunction, or any and all
166 | other commercial damages or losses), even if such Contributor
167 | has been advised of the possibility of such damages.
168 |
169 | 9. Accepting Warranty or Additional Liability. While redistributing
170 | the Work or Derivative Works thereof, You may choose to offer,
171 | and charge a fee for, acceptance of support, warranty, indemnity,
172 | or other liability obligations and/or rights consistent with this
173 | License. However, in accepting such obligations, You may act only
174 | on Your own behalf and on Your sole responsibility, not on behalf
175 | of any other Contributor, and only if You agree to indemnify,
176 | defend, and hold each Contributor harmless for any liability
177 | incurred by, or claims asserted against, such Contributor by reason
178 | of your accepting any such warranty or additional liability.
179 |
180 | END OF TERMS AND CONDITIONS
181 |
182 | APPENDIX: How to apply the Apache License to your work.
183 |
184 | To apply the Apache License to your work, attach the following
185 | boilerplate notice, with the fields enclosed by brackets "[]"
186 | replaced with your own identifying information. (Don't include
187 | the brackets!) The text should be enclosed in the appropriate
188 | comment syntax for the file format. We also recommend that a
189 | file or class name and description of purpose be included on the
190 | same "printed page" as the copyright notice for easier
191 | identification within third-party archives.
192 |
193 | Copyright 2017, The TensorFlow Authors.
194 | Modifications Copyright 2017 Abigail See
195 |
196 | Licensed under the Apache License, Version 2.0 (the "License");
197 | you may not use this file except in compliance with the License.
198 | You may obtain a copy of the License at
199 |
200 | http://www.apache.org/licenses/LICENSE-2.0
201 |
202 | Unless required by applicable law or agreed to in writing, software
203 | distributed under the License is distributed on an "AS IS" BASIS,
204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
205 | See the License for the specific language governing permissions and
206 | limitations under the License.
207 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains code for the ACL 2017 paper *[Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)*.
2 |
3 | ## Looking for test set output?
4 | The test set output of the models described in the paper can be found [here](https://drive.google.com/file/d/0B7pQmm-OfDv7MEtMVU5sOHc5LTg/view?usp=sharing).
5 |
6 | ## Looking for pretrained model?
7 | A pretrained model is available here:
8 | * [Version for Tensorflow 1.0](https://drive.google.com/file/d/0B7pQmm-OfDv7SHFadHR4RllfR1E/view?usp=sharing)
9 | * [Version for Tensorflow 1.2.1](https://drive.google.com/file/d/0B7pQmm-OfDv7ZUhHZm9ZWEZidDg/view?usp=sharing)
10 |
11 | (The only difference between these two is the naming of some of the variables in the checkpoint. Tensorflow 1.0 uses `lstm_cell/biases` and `lstm_cell/weights` whereas Tensorflow 1.2.1 uses `lstm_cell/bias` and `lstm_cell/kernel`).
12 |
13 | ## Looking for CNN / Daily Mail data?
14 | Instructions are [here](https://github.com/abisee/cnn-dailymail).
15 |
16 | ## About this code
17 | This code is based on the [TextSum code](https://github.com/tensorflow/models/tree/master/textsum) from Google Brain.
18 |
19 | This code was developed for Tensorflow 0.12, but has been updated to run with Tensorflow 1.0.
20 | In particular, the code in attention_decoder.py is based on [tf.contrib.legacy_seq2seq.attention_decoder](https://www.tensorflow.org/api_docs/python/tf/contrib/legacy_seq2seq/attention_decoder), which is now outdated.
21 | In the future, Tensorflow 1.0's [new seq2seq library](https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention) will likely provide a more elegant and efficient way to do this (as well as beam search).
22 |
23 | ## How to run
24 |
25 | ### Get the dataset
26 | To obtain the CNN / Daily Mail dataset, follow the instructions [here](https://github.com/abisee/cnn-dailymail). Once finished, you should have [chunked](https://github.com/abisee/cnn-dailymail/issues/3) datafiles `train_000.bin`, ..., `train_287.bin`, `val_000.bin`, ..., `val_013.bin`, `test_000.bin`, ..., `test_011.bin` (each contains 1000 examples) and a vocabulary file `vocab`.
27 |
28 | **Note**: If you did this before 7th May 2017, follow the instructions [here](https://github.com/abisee/cnn-dailymail/issues/2) to correct a bug in the process.
29 |
30 | ### Run training
31 | To train your model, run:
32 |
33 | ```
34 | python run_summarization.py --mode=train --data_path=/path/to/chunked/train_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment
35 | ```
36 |
37 | This will create a subdirectory of your specified `log_root` called `myexperiment` where all checkpoints and other data will be saved. Then the model will start training using the `train_*.bin` files as training data.
38 |
39 | **Warning**: Using default settings as in the above command, both initializing the model and running training iterations will probably be quite slow. To make things faster, try setting the following flags (especially `max_enc_steps` and `max_dec_steps`) to something smaller than the defaults specified in `run_summarization.py`: `hidden_dim`, `emb_dim`, `batch_size`, `max_enc_steps`, `max_dec_steps`, `vocab_size`.
40 |
41 | **Increasing sequence length during training**: Note that to obtain the results described in the paper, we increase the values of `max_enc_steps` and `max_dec_steps` in stages throughout training (mostly so we can perform quicker iterations during early stages of training). If you wish to do the same, start with small values of `max_enc_steps` and `max_dec_steps`, then interrupt and restart the job with larger values when you want to increase them.
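For example (flag values here are purely illustrative), you might begin with:

```
python run_summarization.py --mode=train --data_path=/path/to/chunked/train_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment --max_enc_steps=100 --max_dec_steps=50
```

and later interrupt the job and restart it with larger values such as `--max_enc_steps=400 --max_dec_steps=100`.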
42 |
43 | ### Run (concurrent) eval
44 | You may want to run a concurrent evaluation job that runs your model on the validation set and logs the loss. To do this, run:
45 |
46 | ```
47 | python run_summarization.py --mode=eval --data_path=/path/to/chunked/val_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment
48 | ```
49 |
50 | Note: you should run the above command with the same settings you used for your training job.
51 |
52 | **Restoring snapshots**: The eval job saves a snapshot of the model that scored the lowest loss on the validation data so far. You may want to restore one of these "best models", e.g. if your training job has overfit, or if the training checkpoint has become corrupted by NaN values. To do this, run your train command plus the `--restore_best_model=1` flag. This will copy the best model in the eval directory to the train directory. Then run the usual train command again.
53 |
54 | ### Run beam search decoding
55 | To run beam search decoding:
56 |
57 | ```
58 | python run_summarization.py --mode=decode --data_path=/path/to/chunked/val_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment
59 | ```
60 |
61 | Note: you should run the above command with the same settings you used for your training job (plus any decode-mode-specific flags like `beam_size`).
62 |
63 | This will repeatedly load random examples from your specified datafile and generate a summary using beam search. The results will be printed to screen.
64 |
65 | **Visualize your output**: Additionally, the decode job produces a file called `attn_vis_data.json`. This file provides the data necessary for an in-browser visualization tool that allows you to view the attention distributions projected onto the text. To use the visualizer, follow the instructions [here](https://github.com/abisee/attn_vis).
66 |
67 | If you want to run evaluation on the entire validation or test set and get ROUGE scores, set the flag `single_pass=1`. This will go through the entire dataset in order, writing the generated summaries to file, and then run evaluation using [pyrouge](https://pypi.python.org/pypi/pyrouge). (Note this will *not* produce the `attn_vis_data.json` files for the attention visualizer).
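For example, to decode the entire test set once (paths illustrative):

```
python run_summarization.py --mode=decode --data_path=/path/to/chunked/test_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment --single_pass=1
```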
68 |
69 | ### Evaluate with ROUGE
70 | `decode.py` uses the Python package [`pyrouge`](https://pypi.python.org/pypi/pyrouge) to run ROUGE evaluation. `pyrouge` provides an easier-to-use interface for the official Perl ROUGE package, which you must install for `pyrouge` to work. Here are some useful instructions on how to do this:
71 | * [How to setup Perl ROUGE](http://kavita-ganesan.com/rouge-howto)
72 | * [More details about plugins for Perl ROUGE](http://www.summarizerman.com/post/42675198985/figuring-out-rouge)
73 |
74 | **Note:** As of 18th May 2017 the [website](http://berouge.com/) for the official Perl package appears to be down. Unfortunately you need to download a directory called `ROUGE-1.5.5` from there. As an alternative, it seems that you can get that directory from [here](https://github.com/andersjo/pyrouge) (however, the version of `pyrouge` in that repo appears to be outdated, so best to install `pyrouge` from the [official source](https://pypi.python.org/pypi/pyrouge)).
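If you want to invoke `pyrouge` from Python yourself, a minimal sketch looks like the following (directory paths and filename patterns are illustrative; `decode.py` does the equivalent internally):

```
from pyrouge import Rouge155

r = Rouge155()  # assumes the Perl ROUGE-1.5.5 package is installed and configured
r.system_dir = '/path/to/decoded'                # summaries generated by the model
r.model_dir = '/path/to/reference'               # gold reference summaries
r.system_filename_pattern = r'(\d+)_decoded.txt'
r.model_filename_pattern = '#ID#_reference.txt'
output = r.convert_and_evaluate()                # runs the Perl ROUGE script
print(r.output_to_dict(output))                  # dict of ROUGE-1/2/L scores
```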
75 |
76 | ### Tensorboard
77 | Run Tensorboard from the experiment directory (in the example above, `myexperiment`). You should be able to see data from the train and eval runs. If you select "embeddings", you should also see your word embeddings visualized.
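For example:

```
tensorboard --logdir=/path/to/a/log/directory/myexperiment
```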
78 |
79 | ### Help, I've got NaNs!
80 | For reasons that are [difficult to diagnose](https://github.com/abisee/pointer-generator/issues/4), NaNs sometimes occur during training, causing the loss to become NaN and sometimes also corrupting the model checkpoint with NaN values, rendering it unusable. Here are some suggestions:
81 |
82 | * If training stopped with the `Loss is not finite. Stopping.` exception, you can just try restarting. It may be that the checkpoint is not corrupted.
83 | * You can check whether your checkpoint is corrupted using the `inspect_checkpoint.py` script (a minimal version of this check is sketched after this list). If it says that all values are finite, then your checkpoint is OK and you can try resuming training with it.
84 | * The training job is set to keep 3 checkpoints at any one time (see the `max_to_keep` variable in `run_summarization.py`). If your newer checkpoint is corrupted, it may be that one of the older ones is not. You can switch to that checkpoint by editing the `checkpoint` file inside the `train` directory.
85 | * Alternatively, you can restore a "best model" from the `eval` directory. See the note **Restoring snapshots** above.
86 | * If you want to try to diagnose the cause of the NaNs, you can run with the `--debug=1` flag turned on. This will run [Tensorflow Debugger](https://www.tensorflow.org/versions/master/programmers_guide/debugger), which checks for NaNs and diagnoses their causes during training.
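As a rough sketch of the kind of finiteness check `inspect_checkpoint.py` performs (assuming a Tensorflow 1.x environment; the script's actual interface may differ):

```
import numpy as np
import tensorflow as tf

def checkpoint_is_finite(ckpt_path):
  """Return True if every variable in the checkpoint contains only finite values."""
  reader = tf.train.NewCheckpointReader(ckpt_path)
  for name in reader.get_variable_to_shape_map():
    if not np.all(np.isfinite(reader.get_tensor(name))):
      print('Variable %s contains NaN/inf values' % name)
      return False
  return True

# e.g. checkpoint_is_finite('/path/to/a/log/directory/myexperiment/train/model.ckpt-12345')
```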
87 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/becxer/pointer-generator/3b480d56437867b2b021467a664084915f74144b/__init__.py
--------------------------------------------------------------------------------
/attention_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file defines the decoder"""
18 |
19 | import tensorflow as tf
20 | from tensorflow.python.ops import variable_scope
21 | from tensorflow.python.ops import array_ops
22 | from tensorflow.python.ops import nn_ops
23 | from tensorflow.python.ops import math_ops
24 |
25 | # Note: this function is based on tf.contrib.legacy_seq2seq.attention_decoder, which is now outdated.
26 | # In the future, it would make more sense to write variants on the attention mechanism using the new seq2seq library for tensorflow 1.0: https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention
27 | def attention_decoder(decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None):
28 | """
29 | Args:
30 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
31 | initial_state: 2D Tensor [batch_size x cell.state_size].
32 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size].
33 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1).
34 | cell: rnn_cell.RNNCell defining the cell function and size.
35 | initial_state_attention:
36 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step).
37 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step.
38 | use_coverage: boolean. If True, use coverage mechanism.
39 | prev_coverage:
40 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage.
41 |
42 | Returns:
43 | outputs: A list of the same length as decoder_inputs of 2D Tensors of
44 | shape [batch_size x cell.output_size]. The output vectors.
45 | state: The final state of the decoder. A tensor shape [batch_size x cell.state_size].
46 | attn_dists: A list containing tensors of shape (batch_size,attn_length).
47 | The attention distributions for each decoder step.
48 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False.
49 | coverage: Coverage vector on the last step computed. None if use_coverage=False.
50 | """
51 | with variable_scope.variable_scope("attention_decoder") as scope:
52 | batch_size = encoder_states.get_shape()[0].value # if this line fails, it's because the batch size isn't defined
53 | attn_size = encoder_states.get_shape()[2].value # if this line fails, it's because the attention length isn't defined
54 |
55 | # Reshape encoder_states (need to insert a dim)
56 | encoder_states = tf.expand_dims(encoder_states, axis=2) # now is shape (batch_size, attn_len, 1, attn_size)
57 |
58 | # To calculate attention, we calculate
59 | # v^T tanh(W_h h_i + W_s s_t + b_attn)
60 | # where h_i is an encoder state, and s_t a decoder state.
61 |     # attention_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t).
62 | # We set it to be equal to the size of the encoder states.
63 | attention_vec_size = attn_size
64 |
65 | # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features
66 | W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size])
67 | encoder_features = nn_ops.conv2d(encoder_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,attn_length,1,attention_vec_size)
68 |
69 | # Get the weight vectors v and w_c (w_c is for coverage)
70 | v = variable_scope.get_variable("v", [attention_vec_size])
71 | if use_coverage:
72 | with variable_scope.variable_scope("coverage"):
73 | w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size])
74 |
75 | if prev_coverage is not None: # for beam search mode with coverage
76 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 1, 1)
77 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3)
78 |
79 | def attention(decoder_state, coverage=None):
80 | """Calculate the context vector and attention distribution from the decoder state.
81 |
82 | Args:
83 | decoder_state: state of the decoder
84 | coverage: Optional. Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1).
85 |
86 | Returns:
87 | context_vector: weighted sum of encoder_states
88 | attn_dist: attention distribution
89 | coverage: new coverage vector. shape (batch_size, attn_len, 1, 1)
90 | """
91 | with variable_scope.variable_scope("Attention"):
92 | # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
93 | decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size)
94 | decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size)
95 |
96 | def masked_attention(e):
97 | """Take softmax of e then apply enc_padding_mask and re-normalize"""
98 | attn_dist = nn_ops.softmax(e) # take softmax. shape (batch_size, attn_length)
99 | attn_dist *= enc_padding_mask # apply mask
100 | masked_sums = tf.reduce_sum(attn_dist, axis=1) # shape (batch_size)
101 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) # re-normalize
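      # Illustrative numbers (not computed by the code): with e = [2.0, 1.0, 0.0] and
      # enc_padding_mask = [1, 1, 0], softmax gives ~[0.665, 0.245, 0.090]; masking
      # zeroes the last entry and re-normalizing by the masked sum (~0.910) yields
      # ~[0.731, 0.269, 0.0].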
102 |
103 | if use_coverage and coverage is not None: # non-first step of coverage
104 | # Multiply coverage vector by w_c to get coverage_features.
105 |             coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # coverage_features has shape (batch_size, attn_length, 1, attention_vec_size)
106 |
107 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
108 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) # shape (batch_size,attn_length)
109 |
110 | # Calculate attention distribution
111 | attn_dist = masked_attention(e)
112 |
113 | # Update coverage vector
114 | coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1])
115 | else:
116 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
117 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e
118 |
119 | # Calculate attention distribution
120 | attn_dist = masked_attention(e)
121 |
122 |         if use_coverage: # first decoder step, when using coverage
123 | coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage
124 |
125 | # Calculate the context vector from attn_dist and encoder_states
126 | context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) # shape (batch_size, attn_size).
127 | context_vector = array_ops.reshape(context_vector, [-1, attn_size])
128 |
129 | return context_vector, attn_dist, coverage
130 |
131 | outputs = []
132 | attn_dists = []
133 | p_gens = []
134 | state = initial_state
135 | coverage = prev_coverage # initialize coverage to None or whatever was passed in
136 | context_vector = array_ops.zeros([batch_size, attn_size])
137 | context_vector.set_shape([None, attn_size]) # Ensure the second shape of attention vectors is set.
138 | if initial_state_attention: # true in decode mode
139 | # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input
140 | context_vector, _, coverage = attention(initial_state, coverage) # in decode mode, this is what updates the coverage vector
141 | for i, inp in enumerate(decoder_inputs):
142 | tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(decoder_inputs))
143 | if i > 0:
144 | variable_scope.get_variable_scope().reuse_variables()
145 |
146 | # Merge input and previous attentions into one vector x of the same size as inp
147 | input_size = inp.get_shape().with_rank(2)[1]
148 | if input_size.value is None:
149 | raise ValueError("Could not infer input size from input: %s" % inp.name)
150 | x = linear([inp] + [context_vector], input_size, True)
151 |
152 | # Run the decoder RNN cell. cell_output = decoder state
153 | cell_output, state = cell(x, state)
154 |
155 | # Run the attention mechanism.
156 | if i == 0 and initial_state_attention: # always true in decode mode
157 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) call
158 | context_vector, attn_dist, _ = attention(state, coverage) # don't allow coverage to update
159 | else:
160 | context_vector, attn_dist, coverage = attention(state, coverage)
161 | attn_dists.append(attn_dist)
162 |
163 | # Calculate p_gen
164 | if pointer_gen:
165 | with tf.variable_scope('calculate_pgen'):
166 | p_gen = linear([context_vector, state.c, state.h, x], 1, True) # a scalar
167 | p_gen = tf.sigmoid(p_gen)
168 | p_gens.append(p_gen)
169 |
170 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer
171 | # This is V[s_t, h*_t] + b in the paper
172 | with variable_scope.variable_scope("AttnOutputProjection"):
173 | output = linear([cell_output] + [context_vector], cell.output_size, True)
174 | outputs.append(output)
175 |
176 | # If using coverage, reshape it
177 | if coverage is not None:
178 | coverage = array_ops.reshape(coverage, [batch_size, -1])
179 |
180 | return outputs, state, attn_dists, p_gens, coverage
181 |
182 |
183 |
184 | def linear(args, output_size, bias, bias_start=0.0, scope=None):
185 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
186 |
187 | Args:
188 | args: a 2D Tensor or a list of 2D, batch x n, Tensors.
189 | output_size: int, second dimension of W[i].
190 | bias: boolean, whether to add a bias term or not.
191 | bias_start: starting value to initialize the bias; 0 by default.
192 | scope: VariableScope for the created subgraph; defaults to "Linear".
193 |
194 | Returns:
195 | A 2D Tensor with shape [batch x output_size] equal to
196 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
197 |
198 | Raises:
199 | ValueError: if some of the arguments has unspecified or wrong shape.
200 | """
201 | if args is None or (isinstance(args, (list, tuple)) and not args):
202 | raise ValueError("`args` must be specified")
203 | if not isinstance(args, (list, tuple)):
204 | args = [args]
205 |
206 | # Calculate the total size of arguments on dimension 1.
207 | total_arg_size = 0
208 | shapes = [a.get_shape().as_list() for a in args]
209 | for shape in shapes:
210 | if len(shape) != 2:
211 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
212 | if not shape[1]:
213 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
214 | else:
215 | total_arg_size += shape[1]
216 |
217 | # Now the computation.
218 | with tf.variable_scope(scope or "Linear"):
219 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
220 | if len(args) == 1:
221 | res = tf.matmul(args[0], matrix)
222 | else:
223 | res = tf.matmul(tf.concat(axis=1, values=args), matrix)
224 | if not bias:
225 | return res
226 | bias_term = tf.get_variable(
227 | "Bias", [output_size], initializer=tf.constant_initializer(bias_start))
228 | return res + bias_term
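# Illustrative usage (shapes assumed): project a (batch, 256) input together with a
# (batch, 512) context vector down to (batch, 128), with a bias term:
#   x = linear([inp, context_vector], 128, True)   # x has shape (batch, 128)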
229 |
--------------------------------------------------------------------------------
/batcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to process data into batches"""
18 |
19 | import queue as Queue
20 | from random import shuffle
21 | from threading import Thread
22 | import time
23 | import numpy as np
24 | import tensorflow as tf
25 | import data
26 |
27 |
28 | class Example(object):
29 | """Class representing a train/val/test example for text summarization."""
30 |
31 | def __init__(self, article, abstract_sentences, vocab, hps):
32 | """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.
33 |
34 | Args:
35 | article: source text; a string. each token is separated by a single space.
36 | abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space.
37 | vocab: Vocabulary object
38 | hps: hyperparameters
39 | """
40 | self.hps = hps
41 |
42 | # Get ids of special tokens
43 | start_decoding = vocab.word2id(data.START_DECODING)
44 | stop_decoding = vocab.word2id(data.STOP_DECODING)
45 |
46 | # Process the article
47 | article_words = article.split()
48 | if len(article_words) > hps.max_enc_steps:
49 | article_words = article_words[:hps.max_enc_steps]
50 | self.enc_len = len(article_words) # store the length after truncation but before padding
51 | self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token
52 |
53 | # Process the abstract
54 | abstract = ' '.join(abstract_sentences) # string
55 | abstract_words = abstract.split() # list of strings
56 | abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token
57 |
58 | # Get the decoder input sequence and target sequence
59 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding, stop_decoding)
60 | self.dec_len = len(self.dec_input)
61 |
62 | # If using pointer-generator mode, we need to store some extra info
63 | if hps.pointer_gen:
64 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
65 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
66 |
67 |       # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
68 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)
69 |
70 | # Overwrite decoder target sequence so it uses the temp article OOV ids
71 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding, stop_decoding)
72 |
73 | # Store the original strings
74 | self.original_article = article
75 | self.original_abstract = abstract
76 | self.original_abstract_sents = abstract_sentences
77 |
78 |
79 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
80 | """Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated).
81 |
82 | Args:
83 | sequence: List of ids (integers)
84 | max_len: integer
85 | start_id: integer
86 | stop_id: integer
87 |
88 | Returns:
89 | inp: sequence length <=max_len starting with start_id
90 | target: sequence same length as input, ending with stop_id only if there was no truncation
91 | """
92 | inp = [start_id] + sequence[:]
93 | target = sequence[:]
94 | if len(inp) > max_len: # truncate
95 | inp = inp[:max_len]
96 | target = target[:max_len] # no end_token
97 | else: # no truncation
98 | target.append(stop_id) # end token
99 | assert len(inp) == len(target)
100 | return inp, target
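    # Worked example (values illustrative): with sequence=[8, 9, 10], max_len=4,
    # start_id=1 and stop_id=2:
    #   inp    = [1, 8, 9, 10]   (starts with start_id; no truncation needed)
    #   target = [8, 9, 10, 2]   (ends with stop_id)
    # With sequence=[8, 9, 10, 11] and the same max_len, both are truncated to
    # length 4 and target = [8, 9, 10, 11] gets no stop_id.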
101 |
102 |
103 | def pad_decoder_inp_targ(self, max_len, pad_id):
104 | """Pad decoder input and target sequences with pad_id up to max_len."""
105 | while len(self.dec_input) < max_len:
106 | self.dec_input.append(pad_id)
107 | while len(self.target) < max_len:
108 | self.target.append(pad_id)
109 |
110 |
111 | def pad_encoder_input(self, max_len, pad_id):
112 | """Pad the encoder input sequence with pad_id up to max_len."""
113 | while len(self.enc_input) < max_len:
114 | self.enc_input.append(pad_id)
115 | if self.hps.pointer_gen:
116 | while len(self.enc_input_extend_vocab) < max_len:
117 | self.enc_input_extend_vocab.append(pad_id)
118 |
119 |
120 | class Batch(object):
121 | """Class representing a minibatch of train/val/test examples for text summarization."""
122 |
123 | def __init__(self, example_list, hps, vocab):
124 | """Turns the example_list into a Batch object.
125 |
126 | Args:
127 | example_list: List of Example objects
128 | hps: hyperparameters
129 | vocab: Vocabulary object
130 | """
131 | self.pad_id = vocab.word2id(data.PAD_TOKEN) # id of the PAD token used to pad sequences
132 | self.init_encoder_seq(example_list, hps) # initialize the input to the encoder
133 | self.init_decoder_seq(example_list, hps) # initialize the input and targets for the decoder
134 | self.store_orig_strings(example_list) # store the original strings
135 |
136 | def init_encoder_seq(self, example_list, hps):
137 | """Initializes the following:
138 | self.enc_batch:
139 | numpy array of shape (batch_size, <=max_enc_steps) containing integer ids (all OOVs represented by UNK id), padded to length of longest sequence in the batch
140 | self.enc_lens:
141 | numpy array of shape (batch_size) containing integers. The (truncated) length of each encoder input sequence (pre-padding).
142 | self.enc_padding_mask:
143 |         numpy array of shape (batch_size, <=max_enc_steps), containing 1s and 0s. 1s correspond to real tokens in enc_batch; 0s correspond to padding.
144 |
145 | If hps.pointer_gen, additionally initializes the following:
146 | self.max_art_oovs:
147 | maximum number of in-article OOVs in the batch
148 | self.art_oovs:
149 | list of list of in-article OOVs (strings), for each example in the batch
150 | self.enc_batch_extend_vocab:
151 | Same as self.enc_batch, but in-article OOVs are represented by their temporary article OOV number.
152 | """
153 | # Determine the maximum length of the encoder input sequence in this batch
154 | max_enc_seq_len = max([ex.enc_len for ex in example_list])
155 |
156 | # Pad the encoder input sequences up to the length of the longest sequence
157 | for ex in example_list:
158 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id)
159 |
160 | # Initialize the numpy arrays
161 | # Note: our enc_batch can have different length (second dimension) for each batch because we use dynamic_rnn for the encoder.
162 | self.enc_batch = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32)
163 | self.enc_lens = np.zeros((hps.batch_size), dtype=np.int32)
164 | self.enc_padding_mask = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.float32)
165 |
166 | # Fill in the numpy arrays
167 | for i, ex in enumerate(example_list):
168 | self.enc_batch[i, :] = ex.enc_input[:]
169 | self.enc_lens[i] = ex.enc_len
170 | for j in range(ex.enc_len):
171 | self.enc_padding_mask[i][j] = 1
172 |
173 | # For pointer-generator mode, need to store some extra info
174 | if hps.pointer_gen:
175 | # Determine the max number of in-article OOVs in this batch
176 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
177 | # Store the in-article OOVs themselves
178 | self.art_oovs = [ex.article_oovs for ex in example_list]
179 | # Store the version of the enc_batch that uses the article OOV ids
180 | self.enc_batch_extend_vocab = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32)
181 | for i, ex in enumerate(example_list):
182 | self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]
183 |
184 | def init_decoder_seq(self, example_list, hps):
185 | """Initializes the following:
186 | self.dec_batch:
187 | numpy array of shape (batch_size, max_dec_steps), containing integer ids as input for the decoder, padded to max_dec_steps length.
188 | self.target_batch:
189 | numpy array of shape (batch_size, max_dec_steps), containing integer ids for the target sequence, padded to max_dec_steps length.
190 | self.dec_padding_mask:
191 | numpy array of shape (batch_size, max_dec_steps), containing 1s and 0s. 1s correspond to real tokens in dec_batch and target_batch; 0s correspond to padding.
192 | """
193 | # Pad the inputs and targets
194 | for ex in example_list:
195 | ex.pad_decoder_inp_targ(hps.max_dec_steps, self.pad_id)
196 |
197 | # Initialize the numpy arrays.
198 |     # Note: our decoder inputs and targets must be the same length for each batch (second dimension = max_dec_steps) because we do not use a dynamic_rnn for decoding. However, I believe this is possible, or will soon be possible, with Tensorflow 1.0, in which case it may be best to upgrade to that.
199 | self.dec_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32)
200 | self.target_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32)
201 | self.dec_padding_mask = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.float32)
202 |
203 | # Fill in the numpy arrays
204 | for i, ex in enumerate(example_list):
205 | self.dec_batch[i, :] = ex.dec_input[:]
206 | self.target_batch[i, :] = ex.target[:]
207 | for j in range(ex.dec_len):
208 | self.dec_padding_mask[i][j] = 1
209 |
210 | def store_orig_strings(self, example_list):
211 | """Store the original article and abstract strings in the Batch object"""
212 |     self.original_articles = [ex.original_article for ex in example_list] # list of strings
213 |     self.original_abstracts = [ex.original_abstract for ex in example_list] # list of strings
214 |     self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of lists of strings
215 |
216 |
217 | class Batcher(object):
218 | """A class to generate minibatches of data. Buckets examples together based on length of the encoder sequence."""
219 |
220 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
221 |
222 | def __init__(self, data_path, vocab, hps, single_pass):
223 | """Initialize the batcher. Start threads that process the data into batches.
224 |
225 | Args:
226 | data_path: tf.Example filepattern.
227 | vocab: Vocabulary object
228 | hps: hyperparameters
229 | single_pass: If True, run through the dataset exactly once (useful for when you want to run evaluation on the dev or test set). Otherwise generate random batches indefinitely (useful for training).
230 | """
231 | self._data_path = data_path
232 | self._vocab = vocab
233 | self._hps = hps
234 | self._single_pass = single_pass
235 |
236 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
237 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
238 | self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size)
239 |
240 | # Different settings depending on whether we're in single_pass mode or not
241 | if single_pass:
242 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
243 | self._num_batch_q_threads = 1 # just one thread to batch examples
244 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing
245 | self._finished_reading = False # this will tell us when we're finished reading the dataset
246 | else:
247 | self._num_example_q_threads = 16 # num threads to fill example queue
248 | self._num_batch_q_threads = 4 # num threads to fill batch queue
249 | self._bucketing_cache_size = 100 # how many batches-worth of examples to load into cache before bucketing
250 |
251 | # Start the threads that load the queues
252 | self._example_q_threads = []
253 | for _ in range(self._num_example_q_threads):
254 | self._example_q_threads.append(Thread(target=self.fill_example_queue))
255 | self._example_q_threads[-1].daemon = True
256 | self._example_q_threads[-1].start()
257 | self._batch_q_threads = []
258 | for _ in range(self._num_batch_q_threads):
259 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
260 | self._batch_q_threads[-1].daemon = True
261 | self._batch_q_threads[-1].start()
262 |
263 | # Start a thread that watches the other threads and restarts them if they're dead
264 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
265 | self._watch_thread = Thread(target=self.watch_threads)
266 | self._watch_thread.daemon = True
267 | self._watch_thread.start()
268 |
269 |
270 | def next_batch(self):
271 | """Return a Batch from the batch queue.
272 |
273 | If mode='decode' then each batch contains a single example repeated beam_size-many times; this is necessary for beam search.
274 |
275 | Returns:
276 | batch: a Batch object, or None if we're in single_pass mode and we've exhausted the dataset.
277 | """
278 | # If the batch queue is empty, print a warning
279 | if self._batch_queue.qsize() == 0:
280 | tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
281 | if self._single_pass and self._finished_reading:
282 | tf.logging.info("Finished reading dataset in single_pass mode.")
283 | return None
284 |
285 | batch = self._batch_queue.get() # get the next Batch
286 | return batch
287 |
288 | def fill_example_queue(self):
289 | """Reads data from file and processes into Examples which are then placed into the example queue."""
290 |
291 | input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))
292 |
293 | while True:
294 | try:
295 | (article, abstract) = next(input_gen) # read the next example from file. article and abstract are both strings.
296 | except StopIteration: # if there are no more examples:
297 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
298 | if self._single_pass:
299 | tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
300 | self._finished_reading = True
301 | break
302 | else:
303 | raise Exception("single_pass mode is off but the example generator is out of data; error.")
304 |
305 |       abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)] # Use the <s> and </s> tags in abstract to get a list of sentences.
306 | example = Example(article, abstract_sentences, self._vocab, self._hps) # Process into an Example.
307 | self._example_queue.put(example) # place the Example in the example queue.
308 |
309 |
310 | def fill_batch_queue(self):
311 | """Takes Examples out of example queue, sorts them by encoder sequence length, processes into Batches and places them in the batch queue.
312 |
313 | In decode mode, makes batches that each contain a single example repeated.
314 | """
315 | while True:
316 | if self._hps.mode != 'decode':
317 | # Get bucketing_cache_size-many batches of Examples into a list, then sort
318 | inputs = []
319 | for _ in range(self._hps.batch_size * self._bucketing_cache_size):
320 | inputs.append(self._example_queue.get())
321 | inputs = sorted(inputs, key=lambda inp: inp.enc_len) # sort by length of encoder sequence
322 |
323 | # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
324 | batches = []
325 | for i in range(0, len(inputs), self._hps.batch_size):
326 | batches.append(inputs[i:i + self._hps.batch_size])
327 | if not self._single_pass:
328 | shuffle(batches)
329 | for b in batches: # each b is a list of Example objects
330 | self._batch_queue.put(Batch(b, self._hps, self._vocab))
331 |
332 | else: # beam search decode mode
333 | ex = self._example_queue.get()
334 | b = [ex for _ in range(self._hps.batch_size)]
335 | self._batch_queue.put(Batch(b, self._hps, self._vocab))
336 |
337 |
338 | def watch_threads(self):
339 | """Watch example queue and batch queue threads and restart if dead."""
340 | while True:
341 | time.sleep(60)
342 | for idx,t in enumerate(self._example_q_threads):
343 | if not t.is_alive(): # if the thread is dead
344 | tf.logging.error('Found example queue thread dead. Restarting.')
345 | new_t = Thread(target=self.fill_example_queue)
346 | self._example_q_threads[idx] = new_t
347 | new_t.daemon = True
348 | new_t.start()
349 | for idx,t in enumerate(self._batch_q_threads):
350 | if not t.is_alive(): # if the thread is dead
351 | tf.logging.error('Found batch queue thread dead. Restarting.')
352 | new_t = Thread(target=self.fill_batch_queue)
353 | self._batch_q_threads[idx] = new_t
354 | new_t.daemon = True
355 | new_t.start()
356 |
357 |
358 | def text_generator(self, example_generator):
359 | """Generates article and abstract text from tf.Example.
360 |
361 | Args:
362 | example_generator: a generator of tf.Examples from file. See data.example_generator"""
363 | while True:
364 | e = next(example_generator) # e is a tf.Example
365 | try:
366 | article_text = e.features.feature['article'].bytes_list.value[0].decode() # the article text was saved under the key 'article' in the data files
367 | abstract_text = e.features.feature['abstract'].bytes_list.value[0].decode() # the abstract text was saved under the key 'abstract' in the data files
368 | except ValueError:
369 | tf.logging.error('Failed to get article or abstract from example')
370 | continue
371 | if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1
372 | tf.logging.warning('Found an example with empty article text. Skipping it.')
373 | else:
374 | yield (article_text, abstract_text)
375 |
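# Illustrative usage (the Vocab constructor signature is assumed from data.py):
#   vocab = data.Vocab('/path/to/vocab', 50000)
#   batcher = Batcher('/path/to/chunked/train_*', vocab, hps, single_pass=False)
#   batch = batcher.next_batch()  # a Batch object ready to feed to the model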
--------------------------------------------------------------------------------
/beam_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to run beam search decoding"""
18 |
19 | import tensorflow as tf
20 | import numpy as np
21 | import data
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 | class Hypothesis(object):
26 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis."""
27 |
28 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage):
29 | """Hypothesis constructor.
30 |
31 | Args:
32 | tokens: List of integers. The ids of the tokens that form the summary so far.
33 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far.
34 | state: Current state of the decoder, a LSTMStateTuple.
35 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far.
36 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far.
37 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector.
38 | """
39 | self.tokens = tokens
40 | self.log_probs = log_probs
41 | self.state = state
42 | self.attn_dists = attn_dists
43 | self.p_gens = p_gens
44 | self.coverage = coverage
45 |
46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage):
47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search.
48 |
49 | Args:
50 | token: Integer. Latest token produced by beam search.
51 | log_prob: Float. Log prob of the latest token.
52 | state: Current decoder state, a LSTMStateTuple.
53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length).
54 | p_gen: Generation probability on latest step. Float.
55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage.
56 | Returns:
57 | New Hypothesis for next step.
58 | """
59 | return Hypothesis(tokens = self.tokens + [token],
60 | log_probs = self.log_probs + [log_prob],
61 | state = state,
62 | attn_dists = self.attn_dists + [attn_dist],
63 | p_gens = self.p_gens + [p_gen],
64 | coverage = coverage)
65 |
66 | @property
67 | def latest_token(self):
68 | return self.tokens[-1]
69 |
70 | @property
71 | def log_prob(self):
72 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far
73 | return sum(self.log_probs)
74 |
75 | @property
76 | def avg_log_prob(self):
77 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability)
78 | return self.log_prob / len(self.tokens)
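    # e.g. a hypothesis whose log_probs sum to -2.4 over 3 tokens has
    # avg_log_prob -0.8, which is comparable across hypotheses of different lengths.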
79 |
80 |
81 | def run_beam_search(sess, model, vocab, batch):
82 | """Performs beam search decoding on the given example.
83 |
84 | Args:
85 | sess: a tf.Session
86 | model: a seq2seq model
87 | vocab: Vocabulary object
88 | batch: Batch object that is the same example repeated across the batch
89 |
90 | Returns:
91 | best_hyp: Hypothesis object; the best hypothesis found by beam search.
92 | """
93 | # Run the encoder to get the encoder hidden states and decoder initial state
94 | enc_states, dec_in_state = model.run_encoder(sess, batch)
95 | # dec_in_state is a LSTMStateTuple
96 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim].
97 |
98 |   # Initialize beam_size-many hypotheses
99 | hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
100 | log_probs=[0.0],
101 | state=dec_in_state,
102 | attn_dists=[],
103 | p_gens=[],
104 | coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
105 | ) for _ in range(FLAGS.beam_size)]
106 | results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)
107 |
108 | steps = 0
109 | while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
110 | latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
111 | latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
112 | states = [h.state for h in hyps] # list of current decoder states of the hypotheses
113 | prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)
114 |
115 | # Run one step of the decoder to get the new info
116 | (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess,
117 | batch=batch,
118 | latest_tokens=latest_tokens,
119 | enc_states=enc_states,
120 | dec_init_states=states,
121 | prev_coverage=prev_coverage)
122 |
123 | # Extend each hypothesis and collect them all in all_hyps
124 | all_hyps = []
125 | num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
126 | for i in range(num_orig_hyps):
127 | h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info
128 | for j in range(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
129 | # Extend the ith hypothesis with the jth option
130 | new_hyp = h.extend(token=topk_ids[i, j],
131 | log_prob=topk_log_probs[i, j],
132 | state=new_state,
133 | attn_dist=attn_dist,
134 | p_gen=p_gen,
135 | coverage=new_coverage_i)
136 | all_hyps.append(new_hyp)
137 |
138 | # Filter and collect any hypotheses that have produced the end token.
139 | hyps = [] # will contain hypotheses for the next step
140 | for h in sort_hyps(all_hyps): # in order of most likely h
141 | if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
142 | # If this hypothesis is sufficiently long, put in results. Otherwise discard.
143 | if steps >= FLAGS.min_dec_steps:
144 | results.append(h)
145 | else: # hasn't reached stop token, so continue to extend this hypothesis
146 | hyps.append(h)
147 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
148 | # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
149 | break
150 |
151 | steps += 1
152 |
153 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps
154 |
155 | if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
156 | results = hyps
157 |
158 | # Sort hypotheses by average log probability
159 | hyps_sorted = sort_hyps(results)
160 |
161 | # Return the hypothesis with highest average log prob
162 | return hyps_sorted[0]
163 |
164 | def sort_hyps(hyps):
165 | """Return a list of Hypothesis objects, sorted by descending average log probability"""
166 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True)
167 |
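168 | # ----------------------------------------------------------------------
169 | # Editor's note: a minimal, hedged usage sketch (not part of the original
170 | # module) showing how Hypothesis scoring works. The decoder state and
171 | # attention arguments don't affect scoring, so they are stubbed with None;
172 | # the token ids are made up for illustration.
173 | if __name__ == '__main__':
174 |   h = Hypothesis(tokens=[2], log_probs=[0.0], state=None, attn_dists=[], p_gens=[], coverage=None)
175 |   h = h.extend(token=17, log_prob=-1.2, state=None, attn_dist=None, p_gen=None, coverage=None)
176 |   h = h.extend(token=42, log_prob=-0.8, state=None, attn_dist=None, p_gen=None, coverage=None)
177 |   print(h.log_prob)      # -2.0: sum of the per-token log probs
178 |   print(h.avg_log_prob)  # -2.0 / 3 tokens ~= -0.667: the length-normalized score sort_hyps ranks by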
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
18 |
19 | import glob
20 | import random
21 | import struct
22 | import csv
23 | from tensorflow.core.example import example_pb2
24 |
25 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
26 | SENTENCE_START = '<s>'
27 | SENTENCE_END = '</s>'
28 |
29 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
30 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
31 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
32 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
33 |
34 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
35 |
36 |
37 | class Vocab(object):
38 | """Vocabulary class for mapping between words and ids (integers)"""
39 |
40 | def __init__(self, vocab_file, max_size):
41 | """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.
42 |
43 | Args:
44 | vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
45 | max_size: integer. The maximum size of the resulting Vocabulary."""
46 | self._word_to_id = {}
47 | self._id_to_word = {}
48 | self._count = 0 # keeps track of total number of words in the Vocab
49 |
50 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
51 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
52 | self._word_to_id[w] = self._count
53 | self._id_to_word[self._count] = w
54 | self._count += 1
55 |
56 | # Read the vocab file and add words up to max_size
57 | with open(vocab_file, 'r') as vocab_f:
58 | for line in vocab_f:
59 | pieces = line.split()
60 | if len(pieces) != 2:
61 | print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
62 | continue
63 | w = pieces[0]
64 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
65 | raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
66 | if w in self._word_to_id:
67 | raise Exception('Duplicated word in vocabulary file: %s' % w)
68 | self._word_to_id[w] = self._count
69 | self._id_to_word[self._count] = w
70 | self._count += 1
71 | if max_size != 0 and self._count >= max_size:
72 | print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
73 | break
74 |
75 | print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))
76 |
77 | def word2id(self, word):
78 | """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
79 | if word not in self._word_to_id:
80 | return self._word_to_id[UNKNOWN_TOKEN]
81 | return self._word_to_id[word]
82 |
83 | def id2word(self, word_id):
84 | """Returns the word (string) corresponding to an id (integer)."""
85 | if word_id not in self._id_to_word:
86 | raise ValueError('Id not found in vocab: %d' % word_id)
87 | return self._id_to_word[word_id]
88 |
89 | def size(self):
90 | """Returns the total size of the vocabulary"""
91 | return self._count
92 |
93 | def write_metadata(self, fpath):
94 | """Writes metadata file for Tensorboard word embedding visualizer as described here:
95 | https://www.tensorflow.org/get_started/embedding_viz
96 |
97 | Args:
98 | fpath: place to write the metadata file
99 | """
100 | print("Writing word embedding metadata file to %s..." % (fpath))
101 | with open(fpath, "w") as f:
102 | fieldnames = ['word']
103 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
104 | for i in range(self.size()):
105 | writer.writerow({"word": self._id_to_word[i]})
106 |
107 |
108 | def example_generator(data_path, single_pass):
109 | """Generates tf.Examples from data files.
110 |
111 | Binary data format: <length><blob>. <length> represents the byte size
112 | of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains
113 | the tokenized article text and summary.
114 |
115 | Args:
116 | data_path:
117 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
118 | single_pass:
119 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely.
120 |
121 | Yields:
122 | Deserialized tf.Example.
123 | """
124 | while True:
125 | filelist = glob.glob(data_path) # get the list of datafiles
126 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
127 | if single_pass:
128 | filelist = sorted(filelist)
129 | else:
130 | random.shuffle(filelist)
131 | for f in filelist:
132 | reader = open(f, 'rb')
133 | while True:
134 | len_bytes = reader.read(8)
135 | if not len_bytes: break # finished reading this file
136 | str_len = struct.unpack('q', len_bytes)[0]
137 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
138 | yield example_pb2.Example.FromString(example_str)
139 | if single_pass:
140 | print("example_generator completed reading all datafiles. No more data.")
141 | break
142 |
143 |
144 | def article2ids(article_words, vocab):
145 | """Map the article words to their ids. Also return a list of OOVs in the article.
146 |
147 | Args:
148 | article_words: list of words (strings)
149 | vocab: Vocabulary object
150 |
151 | Returns:
152 | ids:
153 | A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002.
154 | oovs:
155 | A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers."""
156 | ids = []
157 | oovs = []
158 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
159 | for w in article_words:
160 | i = vocab.word2id(w)
161 | if i == unk_id: # If w is OOV
162 | if w not in oovs: # Add to list of OOVs
163 | oovs.append(w)
164 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
165 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
166 | else:
167 | ids.append(i)
168 | return ids, oovs
169 |
170 |
171 | def abstract2ids(abstract_words, vocab, article_oovs):
172 | """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers.
173 |
174 | Args:
175 | abstract_words: list of words (strings)
176 | vocab: Vocabulary object
177 | article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers
178 |
179 | Returns:
180 | ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id."""
181 | ids = []
182 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
183 | for w in abstract_words:
184 | i = vocab.word2id(w)
185 | if i == unk_id: # If w is an OOV word
186 | if w in article_oovs: # If w is an in-article OOV
187 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
188 | ids.append(vocab_idx)
189 | else: # If w is an out-of-article OOV
190 | ids.append(unk_id) # Map to the UNK token id
191 | else:
192 | ids.append(i)
193 | return ids
194 |
195 |
196 | def outputids2words(id_list, vocab, article_oovs):
197 | """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode).
198 |
199 | Args:
200 | id_list: list of ids (integers)
201 | vocab: Vocabulary object
202 | article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode)
203 |
204 | Returns:
205 | words: list of words (strings)
206 | """
207 | words = []
208 | for i in id_list:
209 | try:
210 | w = vocab.id2word(i) # might be [UNK]
211 | except ValueError as e: # w is OOV
212 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
213 | article_oov_idx = i - vocab.size()
214 | try:
215 | w = article_oovs[article_oov_idx]
216 | except IndexError as e: # i doesn't correspond to an article oov (out-of-range list indexing raises IndexError)
217 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
218 | words.append(w)
219 | return words
220 |
221 |
222 | def abstract2sents(abstract):
223 | """Splits abstract text from datafile into list of sentences.
224 |
225 | Args:
226 | abstract: string containing <s> and </s> tags for starts and ends of sentences
227 |
228 | Returns:
229 | sents: List of sentence strings (no tags)"""
230 | cur = 0
231 | sents = []
232 | while True:
233 | try:
234 | start_p = abstract.index(SENTENCE_START, cur)
235 | end_p = abstract.index(SENTENCE_END, start_p + 1)
236 | cur = end_p + len(SENTENCE_END)
237 | sents.append(abstract[start_p+len(SENTENCE_START):end_p])
238 | except ValueError as e: # no more sentences
239 | return sents
240 |
241 |
242 | def show_art_oovs(article, vocab):
243 | """Returns the article string, highlighting the OOVs by placing __underscores__ around them"""
244 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
245 | words = article.split(' ')
246 | words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
247 | out_str = ' '.join(words)
248 | return out_str
249 |
250 |
251 | def show_abs_oovs(abstract, vocab, article_oovs):
252 | """Returns the abstract string, highlighting the article OOVs with __underscores__.
253 |
254 | If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!.
255 |
256 | Args:
257 | abstract: string
258 | vocab: Vocabulary object
259 | article_oovs: list of words (strings), or None (in baseline mode)
260 | """
261 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
262 | words = abstract.split(' ')
263 | new_words = []
264 | for w in words:
265 | if vocab.word2id(w) == unk_token: # w is oov
266 | if article_oovs is None: # baseline mode
267 | new_words.append("__%s__" % w)
268 | else: # pointer-generator mode
269 | if w in article_oovs:
270 | new_words.append("__%s__" % w)
271 | else:
272 | new_words.append("!!__%s__!!" % w)
273 | else: # w is in-vocab word
274 | new_words.append(w)
275 | out_str = ' '.join(new_words)
276 | return out_str
277 |
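278 | # ----------------------------------------------------------------------
279 | # Editor's note: a minimal, hedged sketch (not part of the original module)
280 | # of the temporary-OOV id scheme, using a throwaway two-word vocab file;
281 | # the words and counts below are made up for illustration.
282 | if __name__ == '__main__':
283 |   import tempfile
284 |   with tempfile.NamedTemporaryFile(mode='w', suffix='.vocab', delete=False) as f:
285 |     f.write("the 100\ncat 50\n")
286 |   v = Vocab(f.name, max_size=0)  # ids 0-3 are [UNK]/[PAD]/[START]/[STOP], then the=4, cat=5
287 |   ids, oovs = article2ids("the cat sat".split(), v)
288 |   print(ids, oovs)  # [4, 5, 6] ['sat']: 'sat' is OOV, so it gets temporary id vocab.size() + 0
289 |   print(abstract2ids("the sat dog".split(), v, oovs))  # [4, 6, 0]: 'dog' is an out-of-article OOV -> [UNK] id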
--------------------------------------------------------------------------------
/decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to run beam search decoding, including running ROUGE evaluation and producing JSON datafiles for the in-browser attention visualizer, which can be found here https://github.com/abisee/attn_vis"""
18 |
19 | import os
20 | import time
21 | import tensorflow as tf
22 | import beam_search
23 | import data
24 | import json
25 | import pyrouge
26 | import util
27 | import logging
28 | import numpy as np
29 |
30 | FLAGS = tf.app.flags.FLAGS
31 |
32 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint
33 |
34 |
35 | class BeamSearchDecoder(object):
36 | """Beam search decoder."""
37 |
38 | def __init__(self, model, batcher, vocab):
39 | """Initialize decoder.
40 |
41 | Args:
42 | model: a Seq2SeqAttentionModel object.
43 | batcher: a Batcher object.
44 | vocab: Vocabulary object
45 | """
46 | self._model = model
47 | self._model.build_graph()
48 | self._batcher = batcher
49 | self._vocab = vocab
50 | self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
51 | self._sess = tf.Session(config=util.get_config())
52 |
53 | # Load an initial checkpoint to use for decoding
54 | ckpt_path = util.load_ckpt(self._saver, self._sess)
55 |
56 | if FLAGS.single_pass:
57 | # Make a descriptive decode directory name
58 | ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456"
59 | self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
60 | if os.path.exists(self._decode_dir):
61 | raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)
62 |
63 | else: # Generic decode dir name
64 | self._decode_dir = os.path.join(FLAGS.log_root, "decode")
65 |
66 | # Make the decode dir if necessary
67 | if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)
68 |
69 | if FLAGS.single_pass:
70 | # Make the dirs to contain output written in the correct format for pyrouge
71 | self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
72 | if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
73 | self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
74 | if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
75 |
76 |
77 | def decode(self):
78 | """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
79 | t0 = time.time()
80 | counter = 0
81 | while True:
82 | batch = self._batcher.next_batch() # 1 example repeated across batch
83 | if batch is None: # finished decoding dataset in single_pass mode
84 | assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
85 | tf.logging.info("Decoder has finished reading dataset for single_pass.")
86 | tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
87 | results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
88 | rouge_log(results_dict, self._decode_dir)
89 | return
90 |
91 | original_article = batch.original_articles[0] # string
92 | original_abstract = batch.original_abstracts[0] # string
93 | original_abstract_sents = batch.original_abstracts_sents[0] # list of strings
94 |
95 | article_withunks = data.show_art_oovs(original_article, self._vocab) # string
96 | abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string
97 |
98 | # Run beam search to get best Hypothesis
99 | best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
100 |
101 | # Extract the output ids from the hypothesis and convert back to words
102 | output_ids = [int(t) for t in best_hyp.tokens[1:]]
103 | decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))
104 |
105 | # Remove the [STOP] token from decoded_words, if necessary
106 | try:
107 | fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
108 | decoded_words = decoded_words[:fst_stop_idx]
109 | except ValueError:
110 | pass # no [STOP] token produced; keep all decoded words
111 | decoded_output = ' '.join(decoded_words) # single string
112 |
113 | if FLAGS.single_pass:
114 | self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
115 | counter += 1 # this is how many examples we've decoded
116 | else:
117 | print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
118 | self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool
119 |
120 | # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
121 | t1 = time.time()
122 | if t1-t0 > SECS_UNTIL_NEW_CKPT:
123 | tf.logging.info('We\'ve been decoding with the same checkpoint for %i seconds. Time to load a new checkpoint', t1-t0)
124 | _ = util.load_ckpt(self._saver, self._sess)
125 | t0 = time.time()
126 |
127 | def write_for_rouge(self, reference_sents, decoded_words, ex_index):
128 | """Write output to file in correct format for eval with pyrouge. This is called in single_pass mode.
129 |
130 | Args:
131 | reference_sents: list of strings
132 | decoded_words: list of strings
133 | ex_index: int, the index with which to label the files
134 | """
135 | # First, divide decoded output into sentences
136 | decoded_sents = []
137 | while len(decoded_words) > 0:
138 | try:
139 | fst_period_idx = decoded_words.index(".")
140 | except ValueError: # there is text remaining that doesn't end in "."
141 | fst_period_idx = len(decoded_words)
142 | sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period
143 | decoded_words = decoded_words[fst_period_idx+1:] # everything else
144 | decoded_sents.append(' '.join(sent))
145 |
146 | # pyrouge calls a perl script that puts the data into HTML files.
147 | # Therefore we need to make our output HTML safe.
148 | decoded_sents = [make_html_safe(w) for w in decoded_sents]
149 | reference_sents = [make_html_safe(w) for w in reference_sents]
150 |
151 | # Write to file
152 | ref_file = os.path.join(self._rouge_ref_dir, "%06d_reference.txt" % ex_index)
153 | decoded_file = os.path.join(self._rouge_dec_dir, "%06d_decoded.txt" % ex_index)
154 |
155 | with open(ref_file, "w") as f:
156 | for idx,sent in enumerate(reference_sents):
157 | f.write(sent if idx==len(reference_sents)-1 else sent+"\n")
158 | with open(decoded_file, "w") as f:
159 | for idx,sent in enumerate(decoded_sents):
160 | f.write(sent if idx==len(decoded_sents)-1 else sent+"\n")
161 |
162 | tf.logging.info("Wrote example %i to file" % ex_index)
163 |
164 |
165 | def write_for_attnvis(self, article, abstract, decoded_words, attn_dists, p_gens):
166 | """Write some data to json file, which can be read into the in-browser attention visualizer tool:
167 | https://github.com/abisee/attn_vis
168 |
169 | Args:
170 | article: The original article string.
171 | abstract: The human (correct) abstract string.
172 | decoded_words: List of strings; the words of the generated summary.
173 | attn_dists: List of arrays; the attention distributions.
174 | p_gens: List of scalars; the p_gen values. If not running in pointer-generator mode, list of None.
175 | """
176 | article_lst = article.split() # list of words
177 | decoded_lst = decoded_words # list of decoded words
178 | to_write = {
179 | 'article_lst': [make_html_safe(t) for t in article_lst],
180 | 'decoded_lst': [make_html_safe(t) for t in decoded_lst],
181 | 'abstract_str': make_html_safe(abstract),
182 | 'attn_dists': attn_dists
183 | }
184 | if FLAGS.pointer_gen:
185 | to_write['p_gens'] = p_gens
186 | output_fname = os.path.join(self._decode_dir, 'attn_vis_data.json')
187 | with open(output_fname, 'w') as output_file:
188 | json.dump(to_write, output_file)
189 | tf.logging.info('Wrote visualization data to %s', output_fname)
190 |
191 |
192 | def print_results(article, abstract, decoded_output):
193 | """Prints the article, the reference summmary and the decoded summary to screen"""
194 | print("---------------------------------------------------------------------------")
195 | tf.logging.info('ARTICLE: %s', article)
196 | tf.logging.info('REFERENCE SUMMARY: %s', abstract)
197 | tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
198 | print("---------------------------------------------------------------------------")
199 |
200 |
201 | def make_html_safe(s):
202 | """Replace any angled brackets in string s to avoid interfering with HTML attention visualizer."""
203 | s.replace("<", "<")
204 | s.replace(">", ">")
205 | return s
206 |
207 |
208 | def rouge_eval(ref_dir, dec_dir):
209 | """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict"""
210 | r = pyrouge.Rouge155()
211 | r.model_filename_pattern = '#ID#_reference.txt'
212 | r.system_filename_pattern = r'(\d+)_decoded.txt'
213 | r.model_dir = ref_dir
214 | r.system_dir = dec_dir
215 | logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
216 | rouge_results = r.convert_and_evaluate()
217 | return r.output_to_dict(rouge_results)
218 |
219 |
220 | def rouge_log(results_dict, dir_to_write):
221 | """Log ROUGE results to screen and write to file.
222 |
223 | Args:
224 | results_dict: the dictionary returned by pyrouge
225 | dir_to_write: the directory where we will write the results to"""
226 | log_str = ""
227 | for x in ["1","2","l"]:
228 | log_str += "\nROUGE-%s:\n" % x
229 | for y in ["f_score", "recall", "precision"]:
230 | key = "rouge_%s_%s" % (x,y)
231 | key_cb = key + "_cb"
232 | key_ce = key + "_ce"
233 | val = results_dict[key]
234 | val_cb = results_dict[key_cb]
235 | val_ce = results_dict[key_ce]
236 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
237 | tf.logging.info(log_str) # log to screen
238 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
239 | tf.logging.info("Writing final ROUGE results to %s...", results_file)
240 | with open(results_file, "w") as f:
241 | f.write(log_str)
242 |
243 | def get_decode_dir_name(ckpt_name):
244 | """Make a descriptive name for the decode dir, including the name of the checkpoint we use to decode. This is called in single_pass mode."""
245 |
246 | if "train" in FLAGS.data_path: dataset = "train"
247 | elif "val" in FLAGS.data_path: dataset = "val"
248 | elif "test" in FLAGS.data_path: dataset = "test"
249 | else: raise ValueError("FLAGS.data_path %s should contain one of train, val or test" % (FLAGS.data_path))
250 | dirname = "decode_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (dataset, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
251 | if ckpt_name is not None:
252 | dirname += "_%s" % ckpt_name
253 | return dirname
254 |
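255 | # ----------------------------------------------------------------------
256 | # Editor's note: a tiny, hedged check (not part of the original module) of
257 | # make_html_safe; pyrouge writes the summaries into HTML files, so angled
258 | # brackets in the text must be escaped first.
259 | if __name__ == '__main__':
260 |   assert make_html_safe("a <unk> b") == "a &lt;unk&gt; b"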
--------------------------------------------------------------------------------
/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
3 | python inspect_checkpoint.py model.12345
4 | """
5 |
6 | import tensorflow as tf
7 | import sys
8 | import numpy as np
9 |
10 |
11 | if __name__ == '__main__':
12 | if len(sys.argv) != 2:
13 | raise Exception("Usage: python inspect_checkpoint.py \nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.")
14 | file_name = sys.argv[1]
15 | reader = tf.train.NewCheckpointReader(file_name)
16 | var_to_shape_map = reader.get_variable_to_shape_map()
17 |
18 | finite = []
19 | all_infnan = []
20 | some_infnan = []
21 |
22 | for key in sorted(var_to_shape_map.keys()):
23 | tensor = reader.get_tensor(key)
24 | if np.all(np.isfinite(tensor)):
25 | finite.append(key)
26 | else:
27 | if not np.any(np.isfinite(tensor)):
28 | all_infnan.append(key)
29 | else:
30 | some_infnan.append(key)
31 |
32 | print("\nFINITE VARIABLES:")
33 | for key in finite: print(key)
34 |
35 | print("\nVARIABLES THAT ARE ALL INF/NAN:")
36 | for key in all_infnan: print(key)
37 |
38 | print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:")
39 | for key in some_infnan: print(key)
40 |
41 | if not all_infnan and not some_infnan:
42 | print("CHECK PASSED: checkpoint contains no inf/NaN values")
43 | else:
44 | print("CHECK FAILED: checkpoint contains some inf/NaN values")
45 |
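46 | # ----------------------------------------------------------------------
47 | # Editor's note: example invocation (the checkpoint path and variable names
48 | # below are illustrative, not real output):
49 | #   $ python inspect_checkpoint.py log/myexperiment/train/model.ckpt-12345
50 | #   FINITE VARIABLES:
51 | #   seq2seq/embedding/embedding
52 | #   ...
53 | #   CHECK PASSED: checkpoint contains no inf/NaN values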
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to build and run the tensorflow graph for the sequence-to-sequence model"""
18 |
19 | import os
20 | import time
21 | import numpy as np
22 | import tensorflow as tf
23 | from attention_decoder import attention_decoder
24 | from tensorflow.contrib.tensorboard.plugins import projector
25 |
26 | FLAGS = tf.app.flags.FLAGS
27 |
28 | class SummarizationModel(object):
29 | """A class to represent a sequence-to-sequence model for text summarization. Supports both baseline mode, pointer-generator mode, and coverage"""
30 |
31 | def __init__(self, hps, vocab):
32 | self._hps = hps
33 | self._vocab = vocab
34 |
35 | def _add_placeholders(self):
36 | """Add placeholders to the graph. These are entry points for any input data."""
37 | hps = self._hps
38 |
39 | # encoder part
40 | self._enc_batch = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch')
41 | self._enc_lens = tf.placeholder(tf.int32, [hps.batch_size], name='enc_lens')
42 | self._enc_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, None], name='enc_padding_mask')
43 | if FLAGS.pointer_gen:
44 | self._enc_batch_extend_vocab = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch_extend_vocab')
45 | self._max_art_oovs = tf.placeholder(tf.int32, [], name='max_art_oovs')
46 |
47 | # decoder part
48 | self._dec_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='dec_batch')
49 | self._target_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='target_batch')
50 | self._dec_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, hps.max_dec_steps], name='dec_padding_mask')
51 |
52 | if hps.mode=="decode" and hps.coverage:
53 | self.prev_coverage = tf.placeholder(tf.float32, [hps.batch_size, None], name='prev_coverage')
54 |
55 |
56 | def _make_feed_dict(self, batch, just_enc=False):
57 | """Make a feed dictionary mapping parts of the batch to the appropriate placeholders.
58 |
59 | Args:
60 | batch: Batch object
61 | just_enc: Boolean. If True, only feed the parts needed for the encoder.
62 | """
63 | feed_dict = {}
64 | feed_dict[self._enc_batch] = batch.enc_batch
65 | feed_dict[self._enc_lens] = batch.enc_lens
66 | feed_dict[self._enc_padding_mask] = batch.enc_padding_mask
67 | if FLAGS.pointer_gen:
68 | feed_dict[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
69 | feed_dict[self._max_art_oovs] = batch.max_art_oovs
70 | if not just_enc:
71 | feed_dict[self._dec_batch] = batch.dec_batch
72 | feed_dict[self._target_batch] = batch.target_batch
73 | feed_dict[self._dec_padding_mask] = batch.dec_padding_mask
74 | return feed_dict
75 |
76 | def _add_encoder(self, encoder_inputs, seq_len):
77 | """Add a single-layer bidirectional LSTM encoder to the graph.
78 |
79 | Args:
80 | encoder_inputs: A tensor of shape [batch_size, <=max_enc_steps, emb_size].
81 | seq_len: Lengths of encoder_inputs (before padding). A tensor of shape [batch_size].
82 |
83 | Returns:
84 | encoder_outputs:
85 | A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim]. It's 2*hidden_dim because it's the concatenation of the forwards and backwards states.
86 | fw_state, bw_state:
87 | Each are LSTMStateTuples of shape ([batch_size,hidden_dim],[batch_size,hidden_dim])
88 | """
89 | with tf.variable_scope('encoder'):
90 | cell_fw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init, state_is_tuple=True)
91 | cell_bw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init, state_is_tuple=True)
92 | (encoder_outputs, (fw_st, bw_st)) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, encoder_inputs, dtype=tf.float32, sequence_length=seq_len, swap_memory=True)
93 | encoder_outputs = tf.concat(axis=2, values=encoder_outputs) # concatenate the forwards and backwards states
94 | return encoder_outputs, fw_st, bw_st
95 |
96 |
97 | def _reduce_states(self, fw_st, bw_st):
98 | """Add to the graph a linear layer to reduce the encoder's final FW and BW state into a single initial state for the decoder. This is needed because the encoder is bidirectional but the decoder is not.
99 |
100 | Args:
101 | fw_st: LSTMStateTuple with hidden_dim units.
102 | bw_st: LSTMStateTuple with hidden_dim units.
103 |
104 | Returns:
105 | state: LSTMStateTuple with hidden_dim units.
106 | """
107 | hidden_dim = self._hps.hidden_dim
108 | with tf.variable_scope('reduce_final_st'):
109 |
110 | # Define weights and biases to reduce the cell and reduce the state
111 | w_reduce_c = tf.get_variable('w_reduce_c', [hidden_dim * 2, hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
112 | w_reduce_h = tf.get_variable('w_reduce_h', [hidden_dim * 2, hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
113 | bias_reduce_c = tf.get_variable('bias_reduce_c', [hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
114 | bias_reduce_h = tf.get_variable('bias_reduce_h', [hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
115 |
116 | # Apply linear layer
117 | old_c = tf.concat(axis=1, values=[fw_st.c, bw_st.c]) # Concatenation of fw and bw cell
118 | old_h = tf.concat(axis=1, values=[fw_st.h, bw_st.h]) # Concatenation of fw and bw state
119 | new_c = tf.nn.relu(tf.matmul(old_c, w_reduce_c) + bias_reduce_c) # Get new cell from old cell
120 | new_h = tf.nn.relu(tf.matmul(old_h, w_reduce_h) + bias_reduce_h) # Get new state from old state
121 | return tf.contrib.rnn.LSTMStateTuple(new_c, new_h) # Return new cell and state
122 |
123 |
124 | def _add_decoder(self, inputs):
125 | """Add attention decoder to the graph. In train or eval mode, you call this once to get output on ALL steps. In decode (beam search) mode, you call this once for EACH decoder step.
126 |
127 | Args:
128 | inputs: inputs to the decoder (word embeddings). A list of tensors shape (batch_size, emb_dim)
129 |
130 | Returns:
131 | outputs: List of tensors; the outputs of the decoder
132 | out_state: The final state of the decoder
133 | attn_dists: A list of tensors; the attention distributions
134 | p_gens: A list of scalar tensors; the generation probabilities
135 | coverage: A tensor, the current coverage vector
136 | """
137 | hps = self._hps
138 | cell = tf.contrib.rnn.LSTMCell(hps.hidden_dim, state_is_tuple=True, initializer=self.rand_unif_init)
139 |
140 | prev_coverage = self.prev_coverage if hps.mode=="decode" and hps.coverage else None # In decode mode, we run attention_decoder one step at a time and so need to pass in the previous step's coverage vector each time
141 |
142 | outputs, out_state, attn_dists, p_gens, coverage = attention_decoder(inputs, self._dec_in_state, self._enc_states, self._enc_padding_mask, cell, initial_state_attention=(hps.mode=="decode"), pointer_gen=hps.pointer_gen, use_coverage=hps.coverage, prev_coverage=prev_coverage)
143 |
144 | return outputs, out_state, attn_dists, p_gens, coverage
145 |
146 | def _calc_final_dist(self, vocab_dists, attn_dists):
147 | """Calculate the final distribution, for the pointer-generator model
148 |
149 | Args:
150 | vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.
151 | attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays
152 |
153 | Returns:
154 | final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays.
155 | """
156 | with tf.variable_scope('final_distribution'):
157 | # Multiply vocab dists by p_gen and attention dists by (1-p_gen)
158 | vocab_dists = [p_gen * dist for (p_gen,dist) in zip(self.p_gens, vocab_dists)]
159 | attn_dists = [(1-p_gen) * dist for (p_gen,dist) in zip(self.p_gens, attn_dists)]
160 |
161 | # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words
162 | extended_vsize = self._vocab.size() + self._max_art_oovs # the maximum (over the batch) size of the extended vocabulary
163 | extra_zeros = tf.zeros((self._hps.batch_size, self._max_art_oovs))
164 | vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists] # list length max_dec_steps of shape (batch_size, extended_vsize)
165 |
166 | # Project the values in the attention distributions onto the appropriate entries in the final distributions
167 | # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution
168 | # This is done for each decoder timestep.
169 | # This is fiddly; we use tf.scatter_nd to do the projection
170 | batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
171 | batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1)
172 | attn_len = tf.shape(self._enc_batch_extend_vocab)[1] # number of states we attend over
173 | batch_nums = tf.tile(batch_nums, [1, attn_len]) # shape (batch_size, attn_len)
174 | indices = tf.stack( (batch_nums, self._enc_batch_extend_vocab), axis=2) # shape (batch_size, enc_t, 2)
175 | shape = [self._hps.batch_size, extended_vsize]
176 | attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists] # list length max_dec_steps (batch_size, extended_vsize)
177 |
178 | # Add the vocab distributions and the copy distributions together to get the final distributions
179 | # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep
180 | # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore.
181 | final_dists = [vocab_dist + copy_dist for (vocab_dist,copy_dist) in zip(vocab_dists_extended, attn_dists_projected)]
182 |
183 | return final_dists
184 |
185 | def _add_emb_vis(self, embedding_var):
186 | """Do setup so that we can view word embedding visualization in Tensorboard, as described here:
187 | https://www.tensorflow.org/get_started/embedding_viz
188 | Make the vocab metadata file, then make the projector config file pointing to it."""
189 | train_dir = os.path.join(FLAGS.log_root, "train")
190 | vocab_metadata_path = os.path.join(train_dir, "vocab_metadata.tsv")
191 | self._vocab.write_metadata(vocab_metadata_path) # write metadata file
192 | summary_writer = tf.summary.FileWriter(train_dir)
193 | config = projector.ProjectorConfig()
194 | embedding = config.embeddings.add()
195 | embedding.tensor_name = embedding_var.name
196 | embedding.metadata_path = vocab_metadata_path
197 | projector.visualize_embeddings(summary_writer, config)
198 |
199 | def _add_seq2seq(self):
200 | """Add the whole sequence-to-sequence model to the graph."""
201 | hps = self._hps
202 | vsize = self._vocab.size() # size of the vocabulary
203 |
204 | with tf.variable_scope('seq2seq'):
205 | # Some initializers
206 | self.rand_unif_init = tf.random_uniform_initializer(-hps.rand_unif_init_mag, hps.rand_unif_init_mag, seed=123)
207 | self.trunc_norm_init = tf.truncated_normal_initializer(stddev=hps.trunc_norm_init_std)
208 |
209 | # Add embedding matrix (shared by the encoder and decoder inputs)
210 | with tf.variable_scope('embedding'):
211 | embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
212 | if hps.mode=="train": self._add_emb_vis(embedding) # add to tensorboard
213 | emb_enc_inputs = tf.nn.embedding_lookup(embedding, self._enc_batch) # tensor with shape (batch_size, max_enc_steps, emb_size)
214 | emb_dec_inputs = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch, axis=1)] # list length max_dec_steps containing shape (batch_size, emb_size)
215 |
216 | # Add the encoder.
217 | enc_outputs, fw_st, bw_st = self._add_encoder(emb_enc_inputs, self._enc_lens)
218 | self._enc_states = enc_outputs
219 |
220 | # Our encoder is bidirectional and our decoder is unidirectional so we need to reduce the final encoder hidden state to the right size to be the initial decoder hidden state
221 | self._dec_in_state = self._reduce_states(fw_st, bw_st)
222 |
223 | # Add the decoder.
224 | with tf.variable_scope('decoder'):
225 | decoder_outputs, self._dec_out_state, self.attn_dists, self.p_gens, self.coverage = self._add_decoder(emb_dec_inputs)
226 |
227 | # Add the output projection to obtain the vocabulary distribution
228 | with tf.variable_scope('output_projection'):
229 | w = tf.get_variable('w', [hps.hidden_dim, vsize], dtype=tf.float32, initializer=self.trunc_norm_init)
230 | w_t = tf.transpose(w)
231 | v = tf.get_variable('v', [vsize], dtype=tf.float32, initializer=self.trunc_norm_init)
232 | vocab_scores = [] # vocab_scores is the vocabulary distribution before applying softmax. Each entry on the list corresponds to one decoder step
233 | for i,output in enumerate(decoder_outputs):
234 | if i > 0:
235 | tf.get_variable_scope().reuse_variables()
236 | vocab_scores.append(tf.nn.xw_plus_b(output, w, v)) # apply the linear layer
237 |
238 | vocab_dists = [tf.nn.softmax(s) for s in vocab_scores] # The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.
239 |
240 |
241 | # For pointer-generator model, calc final distribution from copy distribution and vocabulary distribution
242 | if FLAGS.pointer_gen:
243 | final_dists = self._calc_final_dist(vocab_dists, self.attn_dists)
244 | else: # final distribution is just vocabulary distribution
245 | final_dists = vocab_dists
246 |
247 |
248 |
249 | if hps.mode in ['train', 'eval']:
250 | # Calculate the loss
251 | with tf.variable_scope('loss'):
252 | if FLAGS.pointer_gen:
253 | # Calculate the loss per step
254 | # This is fiddly; we use tf.gather_nd to pick out the probabilities of the gold target words
255 | loss_per_step = [] # will be list length max_dec_steps containing shape (batch_size)
256 | batch_nums = tf.range(0, limit=hps.batch_size) # shape (batch_size)
257 | for dec_step, dist in enumerate(final_dists):
258 | targets = self._target_batch[:,dec_step] # The indices of the target words. shape (batch_size)
259 | indices = tf.stack( (batch_nums, targets), axis=1) # shape (batch_size, 2)
260 | gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of correct words on this step
261 | losses = -tf.log(gold_probs)
262 | loss_per_step.append(losses)
263 |
264 | # Apply dec_padding_mask and get loss
265 | self._loss = _mask_and_avg(loss_per_step, self._dec_padding_mask)
266 |
267 | else: # baseline model
268 | self._loss = tf.contrib.seq2seq.sequence_loss(tf.stack(vocab_scores, axis=1), self._target_batch, self._dec_padding_mask) # this applies softmax internally
269 |
270 | tf.summary.scalar('loss', self._loss)
271 |
272 | # Calculate coverage loss from the attention distributions
273 | if hps.coverage:
274 | with tf.variable_scope('coverage_loss'):
275 | self._coverage_loss = _coverage_loss(self.attn_dists, self._dec_padding_mask)
276 | tf.summary.scalar('coverage_loss', self._coverage_loss)
277 | self._total_loss = self._loss + hps.cov_loss_wt * self._coverage_loss
278 | tf.summary.scalar('total_loss', self._total_loss)
279 |
280 | if hps.mode == "decode":
281 | # We run decode beam search mode one decoder step at a time
282 | assert len(final_dists)==1 # final_dists is a singleton list containing shape (batch_size, extended_vsize)
283 | final_dists = final_dists[0]
284 | topk_probs, self._topk_ids = tf.nn.top_k(final_dists, hps.batch_size*2) # take the k largest probs. note batch_size=beam_size in decode mode
285 | self._topk_log_probs = tf.log(topk_probs)
286 |
287 |
288 | def _add_train_op(self):
289 | """Sets self._train_op, the op to run for training."""
290 | # Take gradients of the trainable variables w.r.t. the loss function to minimize
291 | loss_to_minimize = self._total_loss if self._hps.coverage else self._loss
292 | tvars = tf.trainable_variables()
293 | gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
294 |
295 | # Clip the gradients
296 | with tf.device("/gpu:0"):
297 | grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)
298 |
299 | # Add a summary
300 | tf.summary.scalar('global_norm', global_norm)
301 |
302 | # Apply adagrad optimizer
303 | optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
304 | with tf.device("/gpu:0"):
305 | self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step')
306 |
307 |
308 | def build_graph(self):
309 | """Add the placeholders, model, global step, train_op and summaries to the graph"""
310 | tf.logging.info('Building graph...')
311 | t0 = time.time()
312 | self._add_placeholders()
313 | with tf.device("/gpu:0"):
314 | self._add_seq2seq()
315 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
316 | if self._hps.mode == 'train':
317 | self._add_train_op()
318 | self._summaries = tf.summary.merge_all()
319 | t1 = time.time()
320 | tf.logging.info('Time to build graph: %i seconds', t1 - t0)
321 |
322 | def run_train_step(self, sess, batch):
323 | """Runs one training iteration. Returns a dictionary containing train op, summaries, loss, global_step and (optionally) coverage loss."""
324 | feed_dict = self._make_feed_dict(batch)
325 | to_return = {
326 | 'train_op': self._train_op,
327 | 'summaries': self._summaries,
328 | 'loss': self._loss,
329 | 'global_step': self.global_step,
330 | }
331 | if self._hps.coverage:
332 | to_return['coverage_loss'] = self._coverage_loss
333 | return sess.run(to_return, feed_dict)
334 |
335 | def run_eval_step(self, sess, batch):
336 | """Runs one evaluation iteration. Returns a dictionary containing summaries, loss, global_step and (optionally) coverage loss."""
337 | feed_dict = self._make_feed_dict(batch)
338 | to_return = {
339 | 'summaries': self._summaries,
340 | 'loss': self._loss,
341 | 'global_step': self.global_step,
342 | }
343 | if self._hps.coverage:
344 | to_return['coverage_loss'] = self._coverage_loss
345 | return sess.run(to_return, feed_dict)
346 |
347 | def run_encoder(self, sess, batch):
348 | """For beam search decoding. Run the encoder on the batch and return the encoder states and decoder initial state.
349 |
350 | Args:
351 | sess: Tensorflow session.
352 | batch: Batch object that is the same example repeated across the batch (for beam search)
353 |
354 | Returns:
355 | enc_states: The encoder states. A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim].
356 | dec_in_state: A LSTMStateTuple of shape ([1,hidden_dim],[1,hidden_dim])
357 | """
358 | feed_dict = self._make_feed_dict(batch, just_enc=True) # feed the batch into the placeholders
359 | (enc_states, dec_in_state, global_step) = sess.run([self._enc_states, self._dec_in_state, self.global_step], feed_dict) # run the encoder
360 |
361 | # dec_in_state is LSTMStateTuple shape ([batch_size,hidden_dim],[batch_size,hidden_dim])
362 | # Given that the batch is a single example repeated, dec_in_state is identical across the batch so we just take the top row.
363 | dec_in_state = tf.contrib.rnn.LSTMStateTuple(dec_in_state.c[0], dec_in_state.h[0])
364 | return enc_states, dec_in_state
365 |
366 |
367 | def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states, prev_coverage):
368 | """For beam search decoding. Run the decoder for one step.
369 |
370 | Args:
371 | sess: Tensorflow session.
372 | batch: Batch object containing single example repeated across the batch
373 | latest_tokens: Tokens to be fed as input into the decoder for this timestep
374 | enc_states: The encoder states.
375 | dec_init_states: List of beam_size LSTMStateTuples; the decoder states from the previous timestep
376 | prev_coverage: List of np arrays. The coverage vectors from the previous timestep. List of None if not using coverage.
377 |
378 | Returns:
379 | ids: top 2k ids. shape [beam_size, 2*beam_size]
380 | probs: top 2k log probabilities. shape [beam_size, 2*beam_size]
381 | new_states: new states of the decoder. a list length beam_size containing
382 | LSTMStateTuples each of shape ([hidden_dim,],[hidden_dim,])
383 | attn_dists: List length beam_size containing lists length attn_length.
384 | p_gens: Generation probabilities for this step. A list length beam_size. List of None if in baseline mode.
385 | new_coverage: Coverage vectors for this step. A list of arrays. List of None if coverage is not turned on.
386 | """
387 |
388 | beam_size = len(dec_init_states)
389 |
390 | # Turn dec_init_states (a list of LSTMStateTuples) into a single LSTMStateTuple for the batch
391 | cells = [np.expand_dims(state.c, axis=0) for state in dec_init_states]
392 | hiddens = [np.expand_dims(state.h, axis=0) for state in dec_init_states]
393 | new_c = np.concatenate(cells, axis=0) # shape [batch_size,hidden_dim]
394 | new_h = np.concatenate(hiddens, axis=0) # shape [batch_size,hidden_dim]
395 | new_dec_in_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
396 |
397 | feed = {
398 | self._enc_states: enc_states,
399 | self._enc_padding_mask: batch.enc_padding_mask,
400 | self._dec_in_state: new_dec_in_state,
401 | self._dec_batch: np.transpose(np.array([latest_tokens])),
402 | }
403 |
404 | to_return = {
405 | "ids": self._topk_ids,
406 | "probs": self._topk_log_probs,
407 | "states": self._dec_out_state,
408 | "attn_dists": self.attn_dists
409 | }
410 |
411 | if FLAGS.pointer_gen:
412 | feed[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
413 | feed[self._max_art_oovs] = batch.max_art_oovs
414 | to_return['p_gens'] = self.p_gens
415 |
416 | if self._hps.coverage:
417 | feed[self.prev_coverage] = np.stack(prev_coverage, axis=0)
418 | to_return['coverage'] = self.coverage
419 |
420 | results = sess.run(to_return, feed_dict=feed) # run the decoder step
421 |
422 | # Convert results['states'] (a single LSTMStateTuple) into a list of LSTMStateTuple -- one for each hypothesis
423 | new_states = [tf.contrib.rnn.LSTMStateTuple(results['states'].c[i, :], results['states'].h[i, :]) for i in range(beam_size)]
424 |
425 | # Convert singleton list containing a tensor to a list of k arrays
426 | assert len(results['attn_dists'])==1
427 | attn_dists = results['attn_dists'][0].tolist()
428 |
429 | if FLAGS.pointer_gen:
430 | # Convert singleton list containing a tensor to a list of k arrays
431 | assert len(results['p_gens'])==1
432 | p_gens = results['p_gens'][0].tolist()
433 | else:
434 | p_gens = [None for _ in range(beam_size)]
435 |
436 | # Convert the coverage tensor to a list length k containing the coverage vector for each hypothesis
437 | if FLAGS.coverage:
438 | new_coverage = results['coverage'].tolist()
439 | assert len(new_coverage) == beam_size
440 | else:
441 | new_coverage = [None for _ in range(beam_size)]
442 |
443 | return results['ids'], results['probs'], new_states, attn_dists, p_gens, new_coverage
444 |
445 |
446 | def _mask_and_avg(values, padding_mask):
447 | """Applies mask to values then returns overall average (a scalar)
448 |
449 | Args:
450 | values: a list length max_dec_steps containing arrays shape (batch_size).
451 | padding_mask: tensor shape (batch_size, max_dec_steps) containing 1s and 0s.
452 |
453 | Returns:
454 | a scalar
455 | """
456 |
457 | dec_lens = tf.reduce_sum(padding_mask, axis=1) # shape batch_size. float32
458 | values_per_step = [v * padding_mask[:,dec_step] for dec_step,v in enumerate(values)]
459 | values_per_ex = sum(values_per_step)/dec_lens # shape (batch_size); normalized value for each batch member
460 | return tf.reduce_mean(values_per_ex) # overall average
461 |
462 |
463 | def _coverage_loss(attn_dists, padding_mask):
464 | """Calculates the coverage loss from the attention distributions.
465 |
466 | Args:
467 | attn_dists: The attention distributions for each decoder timestep. A list length max_dec_steps containing shape (batch_size, attn_length)
468 | padding_mask: shape (batch_size, max_dec_steps).
469 |
470 | Returns:
471 | coverage_loss: scalar
472 | """
473 | coverage = tf.zeros_like(attn_dists[0]) # shape (batch_size, attn_length). Initial coverage is zero.
474 | covlosses = [] # Coverage loss per decoder timestep. Will be list length max_dec_steps containing shape (batch_size).
475 | for a in attn_dists:
476 | covloss = tf.reduce_sum(tf.minimum(a, coverage), [1]) # calculate the coverage loss for this step
477 | covlosses.append(covloss)
478 | coverage += a # update the coverage vector
479 | coverage_loss = _mask_and_avg(covlosses, padding_mask)
480 | return coverage_loss
481 |
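482 | # ----------------------------------------------------------------------
483 | # Editor's note: a hedged numpy-only sketch (not part of the original
484 | # module) of the tf.scatter_nd projection in _calc_final_dist: attention
485 | # mass on each source position is added onto that word's slot in the
486 | # extended vocabulary, and duplicate source words accumulate. Toy sizes.
487 | if __name__ == '__main__':
488 |   vsize, max_art_oovs = 5, 1                    # extended_vsize = 6
489 |   enc_batch_extend_vocab = np.array([2, 5, 2])  # id 5 = first article OOV
490 |   attn_dist = np.array([0.2, 0.5, 0.3])         # one decoder step, one example
491 |   projected = np.zeros(vsize + max_art_oovs)
492 |   for pos, word_id in enumerate(enc_batch_extend_vocab):
493 |     projected[word_id] += attn_dist[pos]
494 |   print(projected)  # [0. 0. 0.5 0. 0. 0.5] -- id 2 receives 0.2 + 0.3
495 |   # final dist = p_gen * vocab_dist (zero-padded to extended_vsize) + (1 - p_gen) * projected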
--------------------------------------------------------------------------------
/run_summarization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This is the top-level file to train, evaluate or test your summarization model"""
18 |
19 | import sys
20 | import time
21 | import os
22 | import tensorflow as tf
23 | import numpy as np
24 | from collections import namedtuple
25 | from data import Vocab
26 | from batcher import Batcher
27 | from model import SummarizationModel
28 | from decode import BeamSearchDecoder
29 | import util
30 | from tensorflow.python import debug as tf_debug
31 |
32 | FLAGS = tf.app.flags.FLAGS
33 |
34 | # Where to find data
35 | tf.app.flags.DEFINE_string('data_path', '', 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.')
36 | tf.app.flags.DEFINE_string('vocab_path', '', 'Path expression to text vocabulary file.')
37 |
38 | # Important settings
39 | tf.app.flags.DEFINE_string('mode', 'train', 'must be one of train/eval/decode')
40 | tf.app.flags.DEFINE_boolean('single_pass', False, 'For decode mode only. If True, run eval on the full dataset using a fixed checkpoint, i.e. take the current checkpoint, and use it to produce one summary for each example in the dataset, write the summaries to file and then get ROUGE scores for the whole dataset. If False (default), run concurrent decoding, i.e. repeatedly load latest checkpoint, use it to produce summaries for randomly-chosen examples and log the results to screen, indefinitely.')
41 |
42 | # Where to save output
43 | tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.')
44 | tf.app.flags.DEFINE_string('exp_name', '', 'Name for experiment. Logs will be saved in a directory with this name, under log_root.')
45 |
46 | # Hyperparameters
47 | tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states')
48 | tf.app.flags.DEFINE_integer('emb_dim', 128, 'dimension of word embeddings')
49 | tf.app.flags.DEFINE_integer('batch_size', 16, 'minibatch size')
50 | tf.app.flags.DEFINE_integer('max_enc_steps', 400, 'max timesteps of encoder (max source text tokens)')
51 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'max timesteps of decoder (max summary tokens)')
52 | tf.app.flags.DEFINE_integer('beam_size', 4, 'beam size for beam search decoding.')
53 | tf.app.flags.DEFINE_integer('min_dec_steps', 35, 'Minimum sequence length of generated summary. Applies only for beam search decoding mode')
54 | tf.app.flags.DEFINE_integer('vocab_size', 50000, 'Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
55 | tf.app.flags.DEFINE_float('lr', 0.15, 'learning rate')
56 | tf.app.flags.DEFINE_float('adagrad_init_acc', 0.1, 'initial accumulator value for Adagrad')
57 | tf.app.flags.DEFINE_float('rand_unif_init_mag', 0.02, 'magnitude for random uniform initialization of LSTM cells')
58 | tf.app.flags.DEFINE_float('trunc_norm_init_std', 1e-4, 'std of trunc norm init, used for initializing everything else')
59 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'for gradient clipping')
60 |
61 | # Pointer-generator or baseline model
62 | tf.app.flags.DEFINE_boolean('pointer_gen', True, 'If True, use pointer-generator model. If False, use baseline model.')
63 |
64 | # Coverage hyperparameters
65 | tf.app.flags.DEFINE_boolean('coverage', False, 'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.')
66 | tf.app.flags.DEFINE_float('cov_loss_wt', 1.0, 'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.')
67 |
68 | # Utility flags, for restoring and changing checkpoints
69 | tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False, 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. Your current training model will be copied to a new version (same name with _cov_init appended) that will be ready to run with coverage flag turned on, for the coverage training stage.')
70 | tf.app.flags.DEFINE_boolean('restore_best_model', False, 'Restore the best model in the eval/ dir and save it in the train/ dir, ready to be used for further training. Useful for early stopping, or if your training checkpoint has become corrupted with e.g. NaN values.')
71 |
72 | # Debugging. See https://www.tensorflow.org/programmers_guide/debugger
73 | tf.app.flags.DEFINE_boolean('debug', False, "Run in tensorflow's debug mode (watches for NaN/inf values)")
74 |
75 |
76 |
77 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
78 | """Calculate the running average loss via exponential decay.
79 | This is used to implement early stopping w.r.t. a more smooth loss curve than the raw loss curve.
80 |
81 | Args:
82 | loss: loss on the most recent eval step
83 | running_avg_loss: running_avg_loss so far
84 | summary_writer: FileWriter object to write for tensorboard
85 | step: training iteration step
86 | decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.
87 |
88 | Returns:
89 | running_avg_loss: new running average loss
90 | """
91 | if running_avg_loss == 0: # on the first iteration just take the loss
92 | running_avg_loss = loss
93 | else:
94 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
95 | running_avg_loss = min(running_avg_loss, 12) # clip
96 | loss_sum = tf.Summary()
97 | tag_name = 'running_avg_loss/decay=%f' % (decay)
98 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
99 | summary_writer.add_summary(loss_sum, step)
100 | tf.logging.info('running_avg_loss: %f', running_avg_loss)
101 | return running_avg_loss
102 |
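# Editor's sketch of the exponential moving average above (an illustration,
# not part of the original script): with decay=0.99, each new loss moves the
# average by only 1%, which is what makes the curve smooth enough for early
# stopping.
def _running_avg_demo():
  decay = 0.99
  running = 0
  for loss in [4.0, 3.5, 3.8]:
    running = loss if running == 0 else running * decay + (1 - decay) * loss
  return running  # 4.0, then 3.995, then 3.99305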
103 |
104 | def restore_best_model():
105 | """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
106 | tf.logging.info("Restoring bestmodel for training...")
107 |
108 | # Initialize all vars in the model
109 | sess = tf.Session(config=util.get_config())
110 | print("Initializing all variables...")
111 |   sess.run(tf.global_variables_initializer())
112 |
113 | # Restore the best model from eval dir
114 |   saver = tf.train.Saver([v for v in tf.global_variables() if "Adagrad" not in v.name])
115 | print("Restoring all non-adagrad variables from best model in eval dir...")
116 | curr_ckpt = util.load_ckpt(saver, sess, "eval")
117 |   print("Restored %s." % curr_ckpt)
118 |
119 | # Save this model to train dir and quit
120 | new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
121 | new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
122 |   print("Saving model to %s..." % new_fname)
123 | new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
124 | new_saver.save(sess, new_fname)
125 |   print("Saved.")
126 |   sys.exit()
127 |
128 |
129 | def convert_to_coverage_model():
130 | """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
131 | tf.logging.info("converting non-coverage model to coverage model..")
132 |
133 | # initialize an entire coverage model from scratch
134 | sess = tf.Session(config=util.get_config())
135 | print("initializing everything...")
136 | sess.run(tf.global_variables_initializer())
137 |
138 | # load all non-coverage weights from checkpoint
139 | saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
140 | print("restoring non-coverage variables...")
141 | curr_ckpt = util.load_ckpt(saver, sess)
142 | print("restored.")
143 |
144 | # save this model and quit
145 | new_fname = curr_ckpt + '_cov_init'
146 | print("saving model to %s..." % (new_fname))
147 | new_saver = tf.train.Saver() # this one will save all variables that now exist
148 | new_saver.save(sess, new_fname)
149 | print("saved.")
150 |   sys.exit()
151 |
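# Editor's usage sketch for the conversion above (the path flags are
# placeholders, not values from the original repo): the two-stage recipe from
# the coverage flag's help text is roughly
#   python run_summarization.py --mode=train --data_path=... --vocab_path=... --log_root=... --exp_name=myexp
#   python run_summarization.py --mode=train --convert_to_coverage_model=True --coverage=True --data_path=... --vocab_path=... --log_root=... --exp_name=myexp
#   python run_summarization.py --mode=train --coverage=True --data_path=... --vocab_path=... --log_root=... --exp_name=myexp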
152 |
153 | def setup_training(model, batcher):
154 | """Does setup before starting training (run_training)"""
155 | train_dir = os.path.join(FLAGS.log_root, "train")
156 | if not os.path.exists(train_dir): os.makedirs(train_dir)
157 |
158 | model.build_graph() # build the graph
159 | if FLAGS.convert_to_coverage_model:
160 | assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
161 | convert_to_coverage_model()
162 | if FLAGS.restore_best_model:
163 | restore_best_model()
164 | saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
165 |
166 | sv = tf.train.Supervisor(logdir=train_dir,
167 | is_chief=True,
168 | saver=saver,
169 | summary_op=None,
170 | save_summaries_secs=60, # save summaries for tensorboard every 60 secs
171 | save_model_secs=60, # checkpoint every 60 secs
172 | global_step=model.global_step)
173 | summary_writer = sv.summary_writer
174 | tf.logging.info("Preparing or waiting for session...")
175 | sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config())
176 | tf.logging.info("Created session.")
177 | try:
178 | run_training(model, batcher, sess_context_manager, sv, summary_writer) # this is an infinite loop until interrupted
179 | except KeyboardInterrupt:
180 | tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
181 | sv.stop()
182 |
183 |
184 | def run_training(model, batcher, sess_context_manager, sv, summary_writer):
185 | """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
186 | tf.logging.info("starting run_training")
187 | with sess_context_manager as sess:
188 | if FLAGS.debug: # start the tensorflow debugger
189 | sess = tf_debug.LocalCLIDebugWrapperSession(sess)
190 | sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
191 | while True: # repeats until interrupted
192 | batch = batcher.next_batch()
193 |
194 | tf.logging.info('running training step...')
195 |       t0 = time.time()
196 |       results = model.run_train_step(sess, batch)
197 |       t1 = time.time()
198 |       tf.logging.info('seconds for training step: %.3f', t1 - t0)
199 |
200 | loss = results['loss']
201 | tf.logging.info('loss: %f', loss) # print the loss to screen
202 |
203 | if not np.isfinite(loss):
204 | raise Exception("Loss is not finite. Stopping.")
205 |
206 | if FLAGS.coverage:
207 | coverage_loss = results['coverage_loss']
208 | tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen
209 |
210 | # get the summaries and iteration number so we can write summaries to tensorboard
211 | summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
212 | train_step = results['global_step'] # we need this to update our running average loss
213 |
214 | summary_writer.add_summary(summaries, train_step) # write the summaries
215 | if train_step % 100 == 0: # flush the summary writer every so often
216 | summary_writer.flush()
217 |
218 |
219 | def run_eval(model, batcher, vocab):
220 | """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
221 | model.build_graph() # build the graph
222 | saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
223 | sess = tf.Session(config=util.get_config())
224 | eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
225 | bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
226 | summary_writer = tf.summary.FileWriter(eval_dir)
227 | running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
228 | best_loss = None # will hold the best loss achieved so far
229 |
230 | while True:
231 | _ = util.load_ckpt(saver, sess) # load a new checkpoint
232 | batch = batcher.next_batch() # get the next batch
233 |
234 | # run eval on the batch
235 |     t0 = time.time()
236 |     results = model.run_eval_step(sess, batch)
237 |     t1 = time.time()
238 |     tf.logging.info('seconds for batch: %.2f', t1 - t0)
239 |
240 | # print the loss and coverage loss to screen
241 | loss = results['loss']
242 | tf.logging.info('loss: %f', loss)
243 | if FLAGS.coverage:
244 | coverage_loss = results['coverage_loss']
245 | tf.logging.info("coverage_loss: %f", coverage_loss)
246 |
247 | # add summaries
248 | summaries = results['summaries']
249 | train_step = results['global_step']
250 | summary_writer.add_summary(summaries, train_step)
251 |
252 | # calculate running avg loss
253 |     running_avg_loss = calc_running_avg_loss(float(loss), running_avg_loss, summary_writer, train_step)
254 |
255 | # If running_avg_loss is best so far, save this checkpoint (early stopping).
256 | # These checkpoints will appear as bestmodel- in the eval dir
257 | if best_loss is None or running_avg_loss < best_loss:
258 | tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
259 | saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
260 | best_loss = running_avg_loss
261 |
262 | # flush the summary writer every so often
263 | if train_step % 100 == 0:
264 | summary_writer.flush()
265 |
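# Editor's sketch of the early-stopping bookkeeping in run_eval (an
# illustration with TensorFlow stripped out): save a checkpoint whenever the
# smoothed loss improves on the best value seen so far.
def _early_stopping_demo(smoothed_losses):
  best_loss = None
  saved_at = []
  for step, loss in enumerate(smoothed_losses):
    if best_loss is None or loss < best_loss:
      best_loss = loss
      saved_at.append(step)  # run_eval calls saver.save(...) at this point
  return saved_at  # e.g. [4.0, 3.9, 3.95, 3.8] -> saves at steps [0, 1, 3]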
266 |
267 | def main(unused_argv):
268 |   if len(unused_argv) != 1: # raises an error if you've entered flags incorrectly
269 | raise Exception("Problem with flags: %s" % unused_argv)
270 |
271 | tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
272 | tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))
273 |
274 | # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
275 | FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
276 | if not os.path.exists(FLAGS.log_root):
277 | if FLAGS.mode=="train":
278 | os.makedirs(FLAGS.log_root)
279 | else:
280 | raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))
281 |
282 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary
283 |
284 | # If in decode mode, set batch_size = beam_size
285 | # Reason: in decode mode, we decode one example at a time.
286 | # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
287 | if FLAGS.mode == 'decode':
288 | FLAGS.batch_size = FLAGS.beam_size
289 |
290 | # If single_pass=True, check we're in decode mode
291 | if FLAGS.single_pass and FLAGS.mode!='decode':
292 | raise Exception("The single_pass flag should only be True in decode mode")
293 |
294 | # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
295 | hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen']
296 | hps_dict = {}
297 | for key,val in FLAGS.__flags.items(): # for each flag
298 | if key in hparam_list: # if it's in the list
299 |       hps_dict[key] = val.value if hasattr(val, 'value') else val # add it to the dict (newer TF versions wrap flag values in Flag objects)
300 | hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
301 |
302 | # Create a batcher object that will create minibatches of data
303 | batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)
304 |
305 | tf.set_random_seed(111) # a seed value for randomness
306 |
307 | if hps.mode == 'train':
308 | print("creating model...")
309 | model = SummarizationModel(hps, vocab)
310 | setup_training(model, batcher)
311 | elif hps.mode == 'eval':
312 | model = SummarizationModel(hps, vocab)
313 | run_eval(model, batcher, vocab)
314 | elif hps.mode == 'decode':
315 |     # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
316 |     decode_model_hps = hps._replace(max_dec_steps=1) # the hyperparameters for the decoder model
317 | model = SummarizationModel(decode_model_hps, vocab)
318 | decoder = BeamSearchDecoder(model, batcher, vocab)
319 |     decoder.decode() # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
320 | else:
321 | raise ValueError("The 'mode' flag must be one of train/eval/decode")
322 |
323 | if __name__ == '__main__':
324 | tf.app.run()
325 |
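# Editor's sketch of the immutable-hyperparameter pattern in main() (an
# illustration, not part of the original script): namedtuples can't be
# mutated, so decode mode derives a one-field variant with _replace instead
# of editing hps in place.
def _hps_demo():
  d = {'mode': 'train', 'max_dec_steps': 100}
  hps = namedtuple("HParams", d.keys())(**d)
  decode_hps = hps._replace(max_dec_steps=1)
  return hps.max_dec_steps, decode_hps.max_dec_steps  # -> (100, 1)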
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains some utility functions"""
18 |
19 | import tensorflow as tf
20 | import time
21 | import os
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | def get_config():
25 | """Returns config for tf.session"""
26 | config = tf.ConfigProto(allow_soft_placement=True)
27 | config.gpu_options.allow_growth=True
28 | return config
29 |
30 | def load_ckpt(saver, sess, ckpt_dir="train"):
31 | """Load checkpoint from the ckpt_dir (if unspecified, this is train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name."""
32 | while True:
33 | try:
34 | latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None
35 |       ckpt_path = os.path.join(FLAGS.log_root, ckpt_dir) # don't overwrite ckpt_dir, so retries re-join the path correctly
36 |       ckpt_state = tf.train.get_checkpoint_state(ckpt_path, latest_filename=latest_filename)
37 |       tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
38 |       saver.restore(sess, ckpt_state.model_checkpoint_path)
39 |       return ckpt_state.model_checkpoint_path
40 |     except Exception: # the checkpoint may not exist yet, or may be mid-write
41 |       tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_path, 10)
42 |       time.sleep(10)
42 | time.sleep(10)
43 |
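# Editor's sketch of the poll-and-retry pattern that load_ckpt relies on (an
# illustration, not part of the original module): keep attempting an
# operation that can fail transiently, sleeping between tries, exactly as
# the eval job does while waiting for the train job's first checkpoint.
def retry_forever(fn, sleep_secs=10):
  while True:
    try:
      return fn()
    except Exception as e:
      tf.logging.info("Attempt failed (%s). Sleeping for %i secs...", e, sleep_secs)
      time.sleep(sleep_secs)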
--------------------------------------------------------------------------------