├── .gitignore ├── LICENSE.txt ├── README.md ├── __init__.py ├── attention_decoder.py ├── batcher.py ├── beam_search.py ├── data.py ├── decode.py ├── inspect_checkpoint.py ├── model.py ├── run_summarization.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | log/ 4 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 The TensorFlow Authors. All rights reserved. 2 | Modifications Copyright 2017 Abigail See 3 | 4 | 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. 
For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. 
You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright 2017, The TensorFlow Authors. 194 | Modifications Copyright 2017 Abigail See 195 | 196 | Licensed under the Apache License, Version 2.0 (the "License"); 197 | you may not use this file except in compliance with the License. 198 | You may obtain a copy of the License at 199 | 200 | http://www.apache.org/licenses/LICENSE-2.0 201 | 202 | Unless required by applicable law or agreed to in writing, software 203 | distributed under the License is distributed on an "AS IS" BASIS, 204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 205 | See the License for the specific language governing permissions and 206 | limitations under the License. 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains code for the ACL 2017 paper *[Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)*. 2 | 3 | ## Looking for test set output? 4 | The test set output of the models described in the paper can be found [here](https://drive.google.com/file/d/0B7pQmm-OfDv7MEtMVU5sOHc5LTg/view?usp=sharing). 
5 | 6 | ## Looking for pretrained model? 7 | A pretrained model is available here: 8 | * [Version for Tensorflow 1.0](https://drive.google.com/file/d/0B7pQmm-OfDv7SHFadHR4RllfR1E/view?usp=sharing) 9 | * [Version for Tensorflow 1.2.1](https://drive.google.com/file/d/0B7pQmm-OfDv7ZUhHZm9ZWEZidDg/view?usp=sharing) 10 | 11 | (The only difference between these two is the naming of some of the variables in the checkpoint. Tensorflow 1.0 uses `lstm_cell/biases` and `lstm_cell/weights` whereas Tensorflow 1.2.1 uses `lstm_cell/bias` and `lstm_cell/kernel`). 12 | 13 | ## Looking for CNN / Daily Mail data? 14 | Instructions are [here](https://github.com/abisee/cnn-dailymail). 15 | 16 | ## About this code 17 | This code is based on the [TextSum code](https://github.com/tensorflow/models/tree/master/textsum) from Google Brain. 18 | 19 | This code was developed for Tensorflow 0.12, but has been updated to run with Tensorflow 1.0. 20 | In particular, the code in attention_decoder.py is based on [tf.contrib.legacy_seq2seq_attention_decoder](https://www.tensorflow.org/api_docs/python/tf/contrib/legacy_seq2seq/attention_decoder), which is now outdated. 21 | Tensorflow 1.0's [new seq2seq library](https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention) probably provides a way to do this (as well as beam search) more elegantly and efficiently in the future. 22 | 23 | ## How to run 24 | 25 | ### Get the dataset 26 | To obtain the CNN / Daily Mail dataset, follow the instructions [here](https://github.com/abisee/cnn-dailymail). Once finished, you should have [chunked](https://github.com/abisee/cnn-dailymail/issues/3) datafiles `train_000.bin`, ..., `train_287.bin`, `val_000.bin`, ..., `val_013.bin`, `test_000.bin`, ..., `test_011.bin` (each contains 1000 examples) and a vocabulary file `vocab`. 27 | 28 | **Note**: If you did this before 7th May 2017, follow the instructions [here](https://github.com/abisee/cnn-dailymail/issues/2) to correct a bug in the process. 29 | 30 | ### Run training 31 | To train your model, run: 32 | 33 | ``` 34 | python run_summarization.py --mode=train --data_path=/path/to/chunked/train_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment 35 | ``` 36 | 37 | This will create a subdirectory of your specified `log_root` called `myexperiment` where all checkpoints and other data will be saved. Then the model will start training using the `train_*.bin` files as training data. 38 | 39 | **Warning**: Using default settings as in the above command, both initializing the model and running training iterations will probably be quite slow. To make things faster, try setting the following flags (especially `max_enc_steps` and `max_dec_steps`) to something smaller than the defaults specified in `run_summarization.py`: `hidden_dim`, `emb_dim`, `batch_size`, `max_enc_steps`, `max_dec_steps`, `vocab_size`. 40 | 41 | **Increasing sequence length during training**: Note that to obtain the results described in the paper, we increase the values of `max_enc_steps` and `max_dec_steps` in stages throughout training (mostly so we can perform quicker iterations during early stages of training). If you wish to do the same, start with small values of `max_enc_steps` and `max_dec_steps`, then interrupt and restart the job with larger values when you want to increase them. 42 | 43 | ### Run (concurrent) eval 44 | You may want to run a concurrent evaluation job that runs your model on the validation set and logs the loss.
To do this, run: 45 | 46 | ``` 47 | python run_summarization.py --mode=eval --data_path=/path/to/chunked/val_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment 48 | ``` 49 | 50 | Note: you want to run the above command using the same settings you entered for your training job. 51 | 52 | **Restoring snapshots**: The eval job saves a snapshot of the model that scored the lowest loss on the validation data so far. You may want to restore one of these "best models", e.g. if your training job has overfit, or if the training checkpoint has become corrupted by NaN values. To do this, run your train command plus the `--restore_best_model=1` flag. This will copy the best model in the eval directory to the train directory. Then run the usual train command again. 53 | 54 | ### Run beam search decoding 55 | To run beam search decoding: 56 | 57 | ``` 58 | python run_summarization.py --mode=decode --data_path=/path/to/chunked/val_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment 59 | ``` 60 | 61 | Note: you want to run the above command using the same settings you entered for your training job (plus any decode mode specific flags like `beam_size`). 62 | 63 | This will repeatedly load random examples from your specified datafile and generate a summary using beam search. The results will be printed to screen. 64 | 65 | **Visualize your output**: Additionally, the decode job produces a file called `attn_vis_data.json`. This file provides the data necessary for an in-browser visualization tool that allows you to view the attention distributions projected onto the text. To use the visualizer, follow the instructions [here](https://github.com/abisee/attn_vis). 66 | 67 | If you want to run evaluation on the entire validation or test set and get ROUGE scores, set the flag `single_pass=1`. This will go through the entire dataset in order, writing the generated summaries to file, and then run evaluation using [pyrouge](https://pypi.python.org/pypi/pyrouge). (Note this will *not* produce the `attn_vis_data.json` files for the attention visualizer). 68 | 69 | ### Evaluate with ROUGE 70 | `decode.py` uses the Python package [`pyrouge`](https://pypi.python.org/pypi/pyrouge) to run ROUGE evaluation. `pyrouge` provides an easier-to-use interface for the official Perl ROUGE package, which you must install for `pyrouge` to work. Here are some useful instructions on how to do this: 71 | * [How to setup Perl ROUGE](http://kavita-ganesan.com/rouge-howto) 72 | * [More details about plugins for Perl ROUGE](http://www.summarizerman.com/post/42675198985/figuring-out-rouge) 73 | 74 | **Note:** As of 18th May 2017 the [website](http://berouge.com/) for the official Perl package appears to be down. Unfortunately you need to download a directory called `ROUGE-1.5.5` from there. As an alternative, it seems that you can get that directory from [here](https://github.com/andersjo/pyrouge) (however, the version of `pyrouge` in that repo appears to be outdated, so best to install `pyrouge` from the [official source](https://pypi.python.org/pypi/pyrouge)). 75 | 76 | ### Tensorboard 77 | Run Tensorboard from the experiment directory (in the example above, `myexperiment`). You should be able to see data from the train and eval runs. If you select "embeddings", you should also see your word embeddings visualized. 78 | 79 | ### Help, I've got NaNs! 
80 | For reasons that are [difficult to diagnose](https://github.com/abisee/pointer-generator/issues/4), NaNs sometimes occur during training, making the loss=NaN and sometimes also corrupting the model checkpoint with NaN values, making it unusable. Here are some suggestions: 81 | 82 | * If training stopped with the `Loss is not finite. Stopping.` exception, you can just try restarting. It may be that the checkpoint is not corrupted. 83 | * You can check if your checkpoint is corrupted by using the `inspect_checkpoint.py` script. If it says that all values are finite, then your checkpoint is OK and you can try resuming training with it. 84 | * The training job is set to keep 3 checkpoints at any one time (see the `max_to_keep` variable in `run_summarization.py`). If your newer checkpoint is corrupted, it may be that one of the older ones is not. You can switch to that checkpoint by editing the `checkpoint` file inside the `train` directory. 85 | * Alternatively, you can restore a "best model" from the `eval` directory. See the note **Restoring snapshots** above. 86 | * If you want to try to diagnose the cause of the NaNs, you can run with the `--debug=1` flag turned on. This will run [Tensorflow Debugger](https://www.tensorflow.org/versions/master/programmers_guide/debugger), which checks for NaNs and diagnoses their causes during training. 87 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/becxer/pointer-generator/3b480d56437867b2b021467a664084915f74144b/__init__.py -------------------------------------------------------------------------------- /attention_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file defines the decoder""" 18 | 19 | import tensorflow as tf 20 | from tensorflow.python.ops import variable_scope 21 | from tensorflow.python.ops import array_ops 22 | from tensorflow.python.ops import nn_ops 23 | from tensorflow.python.ops import math_ops 24 | 25 | # Note: this function is based on tf.contrib.legacy_seq2seq_attention_decoder, which is now outdated. 26 | # In the future, it would make more sense to write variants on the attention mechanism using the new seq2seq library for tensorflow 1.0: https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention 27 | def attention_decoder(decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None): 28 | """ 29 | Args: 30 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 
31 | initial_state: 2D Tensor [batch_size x cell.state_size]. 32 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size]. 33 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1). 34 | cell: rnn_cell.RNNCell defining the cell function and size. 35 | initial_state_attention: 36 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step). 37 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step. 38 | use_coverage: boolean. If True, use coverage mechanism. 39 | prev_coverage: 40 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage. 41 | 42 | Returns: 43 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 44 | shape [batch_size x cell.output_size]. The output vectors. 45 | state: The final state of the decoder. A tensor of shape [batch_size x cell.state_size]. 46 | attn_dists: A list containing tensors of shape (batch_size,attn_length). 47 | The attention distributions for each decoder step. 48 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False. 49 | coverage: Coverage vector on the last step computed. None if use_coverage=False. 50 | """ 51 | with variable_scope.variable_scope("attention_decoder") as scope: 52 | batch_size = encoder_states.get_shape()[0].value # if this line fails, it's because the batch size isn't defined 53 | attn_size = encoder_states.get_shape()[2].value # if this line fails, it's because the attention size isn't defined 54 | 55 | # Reshape encoder_states (need to insert a dim) 56 | encoder_states = tf.expand_dims(encoder_states, axis=2) # now has shape (batch_size, attn_len, 1, attn_size) 57 | 58 | # To calculate attention, we calculate 59 | # v^T tanh(W_h h_i + W_s s_t + b_attn) 60 | # where h_i is an encoder state, and s_t a decoder state. 61 | # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t). 62 | # We set it to be equal to the size of the encoder states.
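# Illustrative shape walk-through (hypothetical values, not defaults): with batch_size=16, attn_length=400 and attn_size=512, encoder_states is (16, 400, 1, 512) after the expand_dims above; encoder_features below is (16, 400, 1, 512); decoder_features broadcasts as (16, 1, 1, 512); and the scores e = reduce_sum(v * tanh(...), [2, 3]) come out as (16, 400), one score per encoder position.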
63 | attention_vec_size = attn_size 64 | 65 | # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features 66 | W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size]) 67 | encoder_features = nn_ops.conv2d(encoder_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,attn_length,1,attention_vec_size) 68 | 69 | # Get the weight vectors v and w_c (w_c is for coverage) 70 | v = variable_scope.get_variable("v", [attention_vec_size]) 71 | if use_coverage: 72 | with variable_scope.variable_scope("coverage"): 73 | w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size]) 74 | 75 | if prev_coverage is not None: # for beam search mode with coverage 76 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 1, 1) 77 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3) 78 | 79 | def attention(decoder_state, coverage=None): 80 | """Calculate the context vector and attention distribution from the decoder state. 81 | 82 | Args: 83 | decoder_state: state of the decoder 84 | coverage: Optional. Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1). 85 | 86 | Returns: 87 | context_vector: weighted sum of encoder_states 88 | attn_dist: attention distribution 89 | coverage: new coverage vector. shape (batch_size, attn_len, 1, 1) 90 | """ 91 | with variable_scope.variable_scope("Attention"): 92 | # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper) 93 | decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size) 94 | decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size) 95 | 96 | def masked_attention(e): 97 | """Take softmax of e then apply enc_padding_mask and re-normalize""" 98 | attn_dist = nn_ops.softmax(e) # take softmax. shape (batch_size, attn_length) 99 | attn_dist *= enc_padding_mask # apply mask 100 | masked_sums = tf.reduce_sum(attn_dist, axis=1) # shape (batch_size) 101 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) # re-normalize 102 | 103 | if use_coverage and coverage is not None: # non-first step of coverage 104 | # Multiply coverage vector by w_c to get coverage_features. 105 | coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, attn_length, 1, attention_vec_size) 106 | 107 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn) 108 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) # shape (batch_size,attn_length) 109 | 110 | # Calculate attention distribution 111 | attn_dist = masked_attention(e) 112 | 113 | # Update coverage vector 114 | coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) 115 | else: 116 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn) 117 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e 118 | 119 | # Calculate attention distribution 120 | attn_dist = masked_attention(e) 121 | 122 | if use_coverage: # first step of training 123 | coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage 124 | 125 | # Calculate the context vector from attn_dist and encoder_states 126 | context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) # shape (batch_size, attn_size). 
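# In other words, the context vector is sum_i attn_dist_i * h_i: attn_dist is reshaped to (batch_size, attn_length, 1, 1), broadcast-multiplied with encoder_states (batch_size, attn_length, 1, attn_size), then summed over the attn_length and dummy axes.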
127 | context_vector = array_ops.reshape(context_vector, [-1, attn_size]) 128 | 129 | return context_vector, attn_dist, coverage 130 | 131 | outputs = [] 132 | attn_dists = [] 133 | p_gens = [] 134 | state = initial_state 135 | coverage = prev_coverage # initialize coverage to None or whatever was passed in 136 | context_vector = array_ops.zeros([batch_size, attn_size]) 137 | context_vector.set_shape([None, attn_size]) # Ensure the second dimension of attention vectors is set. 138 | if initial_state_attention: # true in decode mode 139 | # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input 140 | context_vector, _, coverage = attention(initial_state, coverage) # in decode mode, this is what updates the coverage vector 141 | for i, inp in enumerate(decoder_inputs): 142 | tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(decoder_inputs)) 143 | if i > 0: 144 | variable_scope.get_variable_scope().reuse_variables() 145 | 146 | # Merge input and previous attentions into one vector x of the same size as inp 147 | input_size = inp.get_shape().with_rank(2)[1] 148 | if input_size.value is None: 149 | raise ValueError("Could not infer input size from input: %s" % inp.name) 150 | x = linear([inp] + [context_vector], input_size, True) 151 | 152 | # Run the decoder RNN cell. cell_output = decoder state 153 | cell_output, state = cell(x, state) 154 | 155 | # Run the attention mechanism. 156 | if i == 0 and initial_state_attention: # always true in decode mode 157 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) call 158 | context_vector, attn_dist, _ = attention(state, coverage) # don't allow coverage to update 159 | else: 160 | context_vector, attn_dist, coverage = attention(state, coverage) 161 | attn_dists.append(attn_dist) 162 | 163 | # Calculate p_gen 164 | if pointer_gen: 165 | with tf.variable_scope('calculate_pgen'): 166 | p_gen = linear([context_vector, state.c, state.h, x], 1, True) # a scalar 167 | p_gen = tf.sigmoid(p_gen) 168 | p_gens.append(p_gen) 169 | 170 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer 171 | # This is V[s_t, h*_t] + b in the paper 172 | with variable_scope.variable_scope("AttnOutputProjection"): 173 | output = linear([cell_output] + [context_vector], cell.output_size, True) 174 | outputs.append(output) 175 | 176 | # If using coverage, reshape it 177 | if coverage is not None: 178 | coverage = array_ops.reshape(coverage, [batch_size, -1]) 179 | 180 | return outputs, state, attn_dists, p_gens, coverage 181 | 182 | 183 | 184 | def linear(args, output_size, bias, bias_start=0.0, scope=None): 185 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 186 | 187 | Args: 188 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 189 | output_size: int, second dimension of W[i]. 190 | bias: boolean, whether to add a bias term or not. 191 | bias_start: starting value to initialize the bias; 0 by default. 192 | scope: VariableScope for the created subgraph; defaults to "Linear". 193 | 194 | Returns: 195 | A 2D Tensor with shape [batch x output_size] equal to 196 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 197 | 198 | Raises: 199 | ValueError: if some of the arguments have an unspecified or wrong shape.
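Example (illustrative values, not from the original code): if a has shape [batch x 100] and b has shape [batch x 200], then linear([a, b], 128, True) concatenates them along dimension 1, creates a weight matrix of shape [300, 128] and a bias of shape [128], and returns a tensor of shape [batch x 128].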
200 | """ 201 | if args is None or (isinstance(args, (list, tuple)) and not args): 202 | raise ValueError("`args` must be specified") 203 | if not isinstance(args, (list, tuple)): 204 | args = [args] 205 | 206 | # Calculate the total size of arguments on dimension 1. 207 | total_arg_size = 0 208 | shapes = [a.get_shape().as_list() for a in args] 209 | for shape in shapes: 210 | if len(shape) != 2: 211 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) 212 | if not shape[1]: 213 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) 214 | else: 215 | total_arg_size += shape[1] 216 | 217 | # Now the computation. 218 | with tf.variable_scope(scope or "Linear"): 219 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size]) 220 | if len(args) == 1: 221 | res = tf.matmul(args[0], matrix) 222 | else: 223 | res = tf.matmul(tf.concat(axis=1, values=args), matrix) 224 | if not bias: 225 | return res 226 | bias_term = tf.get_variable( 227 | "Bias", [output_size], initializer=tf.constant_initializer(bias_start)) 228 | return res + bias_term 229 | -------------------------------------------------------------------------------- /batcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to process data into batches""" 18 | 19 | import queue as Queue 20 | from random import shuffle 21 | from threading import Thread 22 | import time 23 | import numpy as np 24 | import tensorflow as tf 25 | import data 26 | 27 | 28 | class Example(object): 29 | """Class representing a train/val/test example for text summarization.""" 30 | 31 | def __init__(self, article, abstract_sentences, vocab, hps): 32 | """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self. 33 | 34 | Args: 35 | article: source text; a string. each token is separated by a single space. 36 | abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space. 
37 | vocab: Vocabulary object 38 | hps: hyperparameters 39 | """ 40 | self.hps = hps 41 | 42 | # Get ids of special tokens 43 | start_decoding = vocab.word2id(data.START_DECODING) 44 | stop_decoding = vocab.word2id(data.STOP_DECODING) 45 | 46 | # Process the article 47 | article_words = article.split() 48 | if len(article_words) > hps.max_enc_steps: 49 | article_words = article_words[:hps.max_enc_steps] 50 | self.enc_len = len(article_words) # store the length after truncation but before padding 51 | self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token 52 | 53 | # Process the abstract 54 | abstract = ' '.join(abstract_sentences) # string 55 | abstract_words = abstract.split() # list of strings 56 | abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token 57 | 58 | # Get the decoder input sequence and target sequence 59 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding, stop_decoding) 60 | self.dec_len = len(self.dec_input) 61 | 62 | # If using pointer-generator mode, we need to store some extra info 63 | if hps.pointer_gen: 64 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOV words themselves 65 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab) 66 | 67 | # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id 68 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs) 69 | 70 | # Overwrite decoder target sequence so it uses the temp article OOV ids 71 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding, stop_decoding) 72 | 73 | # Store the original strings 74 | self.original_article = article 75 | self.original_abstract = abstract 76 | self.original_abstract_sents = abstract_sentences 77 | 78 | 79 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id): 80 | """Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated).
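Worked example (illustrative ids): with start_id=2, stop_id=3 and sequence=[8, 9, 10], max_len=5 gives inp=[2, 8, 9, 10] and target=[8, 9, 10, 3]; max_len=3 truncates to inp=[2, 8, 9] and target=[8, 9, 10], with no stop token.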
81 | 82 | Args: 83 | sequence: List of ids (integers) 84 | max_len: integer 85 | start_id: integer 86 | stop_id: integer 87 | 88 | Returns: 89 | inp: sequence length <=max_len starting with start_id 90 | target: sequence same length as input, ending with stop_id only if there was no truncation 91 | """ 92 | inp = [start_id] + sequence[:] 93 | target = sequence[:] 94 | if len(inp) > max_len: # truncate 95 | inp = inp[:max_len] 96 | target = target[:max_len] # no end_token 97 | else: # no truncation 98 | target.append(stop_id) # end token 99 | assert len(inp) == len(target) 100 | return inp, target 101 | 102 | 103 | def pad_decoder_inp_targ(self, max_len, pad_id): 104 | """Pad decoder input and target sequences with pad_id up to max_len.""" 105 | while len(self.dec_input) < max_len: 106 | self.dec_input.append(pad_id) 107 | while len(self.target) < max_len: 108 | self.target.append(pad_id) 109 | 110 | 111 | def pad_encoder_input(self, max_len, pad_id): 112 | """Pad the encoder input sequence with pad_id up to max_len.""" 113 | while len(self.enc_input) < max_len: 114 | self.enc_input.append(pad_id) 115 | if self.hps.pointer_gen: 116 | while len(self.enc_input_extend_vocab) < max_len: 117 | self.enc_input_extend_vocab.append(pad_id) 118 | 119 | 120 | class Batch(object): 121 | """Class representing a minibatch of train/val/test examples for text summarization.""" 122 | 123 | def __init__(self, example_list, hps, vocab): 124 | """Turns the example_list into a Batch object. 125 | 126 | Args: 127 | example_list: List of Example objects 128 | hps: hyperparameters 129 | vocab: Vocabulary object 130 | """ 131 | self.pad_id = vocab.word2id(data.PAD_TOKEN) # id of the PAD token used to pad sequences 132 | self.init_encoder_seq(example_list, hps) # initialize the input to the encoder 133 | self.init_decoder_seq(example_list, hps) # initialize the input and targets for the decoder 134 | self.store_orig_strings(example_list) # store the original strings 135 | 136 | def init_encoder_seq(self, example_list, hps): 137 | """Initializes the following: 138 | self.enc_batch: 139 | numpy array of shape (batch_size, <=max_enc_steps) containing integer ids (all OOVs represented by UNK id), padded to length of longest sequence in the batch 140 | self.enc_lens: 141 | numpy array of shape (batch_size) containing integers. The (truncated) length of each encoder input sequence (pre-padding). 142 | self.enc_padding_mask: 143 | numpy array of shape (batch_size, <=max_enc_steps), containing 1s and 0s. 1s correspond to real tokens in enc_batch and target_batch; 0s correspond to padding. 144 | 145 | If hps.pointer_gen, additionally initializes the following: 146 | self.max_art_oovs: 147 | maximum number of in-article OOVs in the batch 148 | self.art_oovs: 149 | list of list of in-article OOVs (strings), for each example in the batch 150 | self.enc_batch_extend_vocab: 151 | Same as self.enc_batch, but in-article OOVs are represented by their temporary article OOV number. 152 | """ 153 | # Determine the maximum length of the encoder input sequence in this batch 154 | max_enc_seq_len = max([ex.enc_len for ex in example_list]) 155 | 156 | # Pad the encoder input sequences up to the length of the longest sequence 157 | for ex in example_list: 158 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id) 159 | 160 | # Initialize the numpy arrays 161 | # Note: our enc_batch can have different length (second dimension) for each batch because we use dynamic_rnn for the encoder. 
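# Illustrative example (hypothetical values): if max_enc_seq_len is 5 and an example has enc_len 3 with enc_input [12, 7, 4], its row in enc_batch becomes [12, 7, 4, pad_id, pad_id] and its row in enc_padding_mask becomes [1, 1, 1, 0, 0].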
162 | self.enc_batch = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32) 163 | self.enc_lens = np.zeros((hps.batch_size), dtype=np.int32) 164 | self.enc_padding_mask = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.float32) 165 | 166 | # Fill in the numpy arrays 167 | for i, ex in enumerate(example_list): 168 | self.enc_batch[i, :] = ex.enc_input[:] 169 | self.enc_lens[i] = ex.enc_len 170 | for j in range(ex.enc_len): 171 | self.enc_padding_mask[i][j] = 1 172 | 173 | # For pointer-generator mode, need to store some extra info 174 | if hps.pointer_gen: 175 | # Determine the max number of in-article OOVs in this batch 176 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list]) 177 | # Store the in-article OOVs themselves 178 | self.art_oovs = [ex.article_oovs for ex in example_list] 179 | # Store the version of the enc_batch that uses the article OOV ids 180 | self.enc_batch_extend_vocab = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32) 181 | for i, ex in enumerate(example_list): 182 | self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:] 183 | 184 | def init_decoder_seq(self, example_list, hps): 185 | """Initializes the following: 186 | self.dec_batch: 187 | numpy array of shape (batch_size, max_dec_steps), containing integer ids as input for the decoder, padded to max_dec_steps length. 188 | self.target_batch: 189 | numpy array of shape (batch_size, max_dec_steps), containing integer ids for the target sequence, padded to max_dec_steps length. 190 | self.dec_padding_mask: 191 | numpy array of shape (batch_size, max_dec_steps), containing 1s and 0s. 1s correspond to real tokens in dec_batch and target_batch; 0s correspond to padding. 192 | """ 193 | # Pad the inputs and targets 194 | for ex in example_list: 195 | ex.pad_decoder_inp_targ(hps.max_dec_steps, self.pad_id) 196 | 197 | # Initialize the numpy arrays. 198 | # Note: our decoder inputs and targets must be the same length for each batch (second dimension = max_dec_steps) because we do not use a dynamic_rnn for decoding. However I believe this is possible, or will soon be possible, with Tensorflow 1.0, in which case it may be best to upgrade to that. 199 | self.dec_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32) 200 | self.target_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32) 201 | self.dec_padding_mask = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.float32) 202 | 203 | # Fill in the numpy arrays 204 | for i, ex in enumerate(example_list): 205 | self.dec_batch[i, :] = ex.dec_input[:] 206 | self.target_batch[i, :] = ex.target[:] 207 | for j in range(ex.dec_len): 208 | self.dec_padding_mask[i][j] = 1 209 | 210 | def store_orig_strings(self, example_list): 211 | """Store the original article and abstract strings in the Batch object""" 212 | self.original_articles = [ex.original_article for ex in example_list] # list of strings 213 | self.original_abstracts = [ex.original_abstract for ex in example_list] # list of strings 214 | self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of lists of strings 215 | 216 | 217 | class Batcher(object): 218 | """A class to generate minibatches of data. Buckets examples together based on length of the encoder sequence.""" 219 | 220 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold 221 | 222 | def __init__(self, data_path, vocab, hps, single_pass): 223 | """Initialize the batcher. Start threads that process the data into batches.
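Example usage (illustrative; the path is hypothetical): batcher = Batcher('/path/to/chunked/train_*', vocab, hps, single_pass=False), then call batch = batcher.next_batch() in the training loop.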
224 | 225 | Args: 226 | data_path: tf.Example filepattern. 227 | vocab: Vocabulary object 228 | hps: hyperparameters 229 | single_pass: If True, run through the dataset exactly once (useful for when you want to run evaluation on the dev or test set). Otherwise generate random batches indefinitely (useful for training). 230 | """ 231 | self._data_path = data_path 232 | self._vocab = vocab 233 | self._hps = hps 234 | self._single_pass = single_pass 235 | 236 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched 237 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX) 238 | self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size) 239 | 240 | # Different settings depending on whether we're in single_pass mode or not 241 | if single_pass: 242 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once 243 | self._num_batch_q_threads = 1 # just one thread to batch examples 244 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing 245 | self._finished_reading = False # this will tell us when we're finished reading the dataset 246 | else: 247 | self._num_example_q_threads = 16 # num threads to fill example queue 248 | self._num_batch_q_threads = 4 # num threads to fill batch queue 249 | self._bucketing_cache_size = 100 # how many batches-worth of examples to load into cache before bucketing 250 | 251 | # Start the threads that load the queues 252 | self._example_q_threads = [] 253 | for _ in range(self._num_example_q_threads): 254 | self._example_q_threads.append(Thread(target=self.fill_example_queue)) 255 | self._example_q_threads[-1].daemon = True 256 | self._example_q_threads[-1].start() 257 | self._batch_q_threads = [] 258 | for _ in range(self._num_batch_q_threads): 259 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue)) 260 | self._batch_q_threads[-1].daemon = True 261 | self._batch_q_threads[-1].start() 262 | 263 | # Start a thread that watches the other threads and restarts them if they're dead 264 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever 265 | self._watch_thread = Thread(target=self.watch_threads) 266 | self._watch_thread.daemon = True 267 | self._watch_thread.start() 268 | 269 | 270 | def next_batch(self): 271 | """Return a Batch from the batch queue. 272 | 273 | If mode='decode' then each batch contains a single example repeated beam_size-many times; this is necessary for beam search. 274 | 275 | Returns: 276 | batch: a Batch object, or None if we're in single_pass mode and we've exhausted the dataset. 277 | """ 278 | # If the batch queue is empty, print a warning 279 | if self._batch_queue.qsize() == 0: 280 | tf.logging.warning('Bucket input queue is empty when calling next_batch. 
Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize()) 281 | if self._single_pass and self._finished_reading: 282 | tf.logging.info("Finished reading dataset in single_pass mode.") 283 | return None 284 | 285 | batch = self._batch_queue.get() # get the next Batch 286 | return batch 287 | 288 | def fill_example_queue(self): 289 | """Reads data from file and processes into Examples which are then placed into the example queue.""" 290 | 291 | input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass)) 292 | 293 | while True: 294 | try: 295 | (article, abstract) = next(input_gen) # read the next example from file. article and abstract are both strings. 296 | except StopIteration: # if there are no more examples: 297 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.") 298 | if self._single_pass: 299 | tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.") 300 | self._finished_reading = True 301 | break 302 | else: 303 | raise Exception("single_pass mode is off but the example generator is out of data; error.") 304 | 305 | abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)] # Use the <s> and </s> tags in abstract to get a list of sentences. 306 | example = Example(article, abstract_sentences, self._vocab, self._hps) # Process into an Example. 307 | self._example_queue.put(example) # place the Example in the example queue. 308 | 309 | 310 | def fill_batch_queue(self): 311 | """Takes Examples out of example queue, sorts them by encoder sequence length, processes into Batches and places them in the batch queue. 312 | 313 | In decode mode, makes batches that each contain a single example repeated. 314 | """ 315 | while True: 316 | if self._hps.mode != 'decode': 317 | # Get bucketing_cache_size-many batches of Examples into a list, then sort 318 | inputs = [] 319 | for _ in range(self._hps.batch_size * self._bucketing_cache_size): 320 | inputs.append(self._example_queue.get()) 321 | inputs = sorted(inputs, key=lambda inp: inp.enc_len) # sort by length of encoder sequence 322 | 323 | # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue. 324 | batches = [] 325 | for i in range(0, len(inputs), self._hps.batch_size): 326 | batches.append(inputs[i:i + self._hps.batch_size]) 327 | if not self._single_pass: 328 | shuffle(batches) 329 | for b in batches: # each b is a list of Example objects 330 | self._batch_queue.put(Batch(b, self._hps, self._vocab)) 331 | 332 | else: # beam search decode mode 333 | ex = self._example_queue.get() 334 | b = [ex for _ in range(self._hps.batch_size)] 335 | self._batch_queue.put(Batch(b, self._hps, self._vocab)) 336 | 337 | 338 | def watch_threads(self): 339 | """Watch example queue and batch queue threads and restart if dead.""" 340 | while True: 341 | time.sleep(60) 342 | for idx,t in enumerate(self._example_q_threads): 343 | if not t.is_alive(): # if the thread is dead 344 | tf.logging.error('Found example queue thread dead. Restarting.') 345 | new_t = Thread(target=self.fill_example_queue) 346 | self._example_q_threads[idx] = new_t 347 | new_t.daemon = True 348 | new_t.start() 349 | for idx,t in enumerate(self._batch_q_threads): 350 | if not t.is_alive(): # if the thread is dead 351 | tf.logging.error('Found batch queue thread dead.
Restarting.') 352 | new_t = Thread(target=self.fill_batch_queue) 353 | self._batch_q_threads[idx] = new_t 354 | new_t.daemon = True 355 | new_t.start() 356 | 357 | 358 | def text_generator(self, example_generator): 359 | """Generates article and abstract text from tf.Example. 360 | 361 | Args: 362 | example_generator: a generator of tf.Examples from file. See data.example_generator""" 363 | while True: 364 | e = next(example_generator) # e is a tf.Example 365 | try: 366 | article_text = e.features.feature['article'].bytes_list.value[0].decode() # the article text was saved under the key 'article' in the data files 367 | abstract_text = e.features.feature['abstract'].bytes_list.value[0].decode() # the abstract text was saved under the key 'abstract' in the data files 368 | except ValueError: 369 | tf.logging.error('Failed to get article or abstract from example') 370 | continue 371 | if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1 372 | tf.logging.warning('Found an example with empty article text. Skipping it.') 373 | else: 374 | yield (article_text, abstract_text) 375 | -------------------------------------------------------------------------------- /beam_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding""" 18 | 19 | import tensorflow as tf 20 | import numpy as np 21 | import data 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | class Hypothesis(object): 26 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis.""" 27 | 28 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage): 29 | """Hypothesis constructor. 30 | 31 | Args: 32 | tokens: List of integers. The ids of the tokens that form the summary so far. 33 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far. 34 | state: Current state of the decoder, a LSTMStateTuple. 35 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far. 36 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far. 37 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector. 
38 | """ 39 | self.tokens = tokens 40 | self.log_probs = log_probs 41 | self.state = state 42 | self.attn_dists = attn_dists 43 | self.p_gens = p_gens 44 | self.coverage = coverage 45 | 46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage): 47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search. 48 | 49 | Args: 50 | token: Integer. Latest token produced by beam search. 51 | log_prob: Float. Log prob of the latest token. 52 | state: Current decoder state, a LSTMStateTuple. 53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length). 54 | p_gen: Generation probability on latest step. Float. 55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage. 56 | Returns: 57 | New Hypothesis for next step. 58 | """ 59 | return Hypothesis(tokens = self.tokens + [token], 60 | log_probs = self.log_probs + [log_prob], 61 | state = state, 62 | attn_dists = self.attn_dists + [attn_dist], 63 | p_gens = self.p_gens + [p_gen], 64 | coverage = coverage) 65 | 66 | @property 67 | def latest_token(self): 68 | return self.tokens[-1] 69 | 70 | @property 71 | def log_prob(self): 72 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far 73 | return sum(self.log_probs) 74 | 75 | @property 76 | def avg_log_prob(self): 77 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability) 78 | return self.log_prob / len(self.tokens) 79 | 80 | 81 | def run_beam_search(sess, model, vocab, batch): 82 | """Performs beam search decoding on the given example. 83 | 84 | Args: 85 | sess: a tf.Session 86 | model: a seq2seq model 87 | vocab: Vocabulary object 88 | batch: Batch object that is the same example repeated across the batch 89 | 90 | Returns: 91 | best_hyp: Hypothesis object; the best hypothesis found by beam search. 92 | """ 93 | # Run the encoder to get the encoder hidden states and decoder initial state 94 | enc_states, dec_in_state = model.run_encoder(sess, batch) 95 | # dec_in_state is a LSTMStateTuple 96 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 
97 | 
98 |   # Initialize beam_size-many hypotheses
99 |   hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
100 |                      log_probs=[0.0],
101 |                      state=dec_in_state,
102 |                      attn_dists=[],
103 |                      p_gens=[],
104 |                      coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
105 |                      ) for _ in range(FLAGS.beam_size)]
106 |   results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)
107 | 
108 |   steps = 0
109 |   while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
110 |     latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
111 |     latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
112 |     states = [h.state for h in hyps] # list of current decoder states of the hypotheses
113 |     prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)
114 | 
115 |     # Run one step of the decoder to get the new info
116 |     (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess,
117 |                         batch=batch,
118 |                         latest_tokens=latest_tokens,
119 |                         enc_states=enc_states,
120 |                         dec_init_states=states,
121 |                         prev_coverage=prev_coverage)
122 | 
123 |     # Extend each hypothesis and collect them all in all_hyps
124 |     all_hyps = []
125 |     num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
126 |     for i in range(num_orig_hyps):
127 |       h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i]  # take the ith hypothesis and new decoder state info
128 |       for j in range(FLAGS.beam_size * 2):  # for each of the top 2*beam_size hyps:
129 |         # Extend the ith hypothesis with the jth option
130 |         new_hyp = h.extend(token=topk_ids[i, j],
131 |                            log_prob=topk_log_probs[i, j],
132 |                            state=new_state,
133 |                            attn_dist=attn_dist,
134 |                            p_gen=p_gen,
135 |                            coverage=new_coverage_i)
136 |         all_hyps.append(new_hyp)
137 | 
138 |     # Filter and collect any hypotheses that have produced the end token.
139 |     hyps = [] # will contain hypotheses for the next step
140 |     for h in sort_hyps(all_hyps): # in order of most likely h
141 |       if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
142 |         # If this hypothesis is sufficiently long, put in results. Otherwise discard.
143 |         if steps >= FLAGS.min_dec_steps:
144 |           results.append(h)
145 |       else: # hasn't reached stop token, so continue to extend this hypothesis
146 |         hyps.append(h)
147 |       if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
148 |         # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
149 |         break
150 | 
151 |     steps += 1
152 | 
153 |   # At this point, either we've got beam_size results, or we've reached maximum decoder steps
154 | 
155 |   if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
156 |     results = hyps
157 | 
158 |   # Sort hypotheses by average log probability
159 |   hyps_sorted = sort_hyps(results)
160 | 
161 |   # Return the hypothesis with highest average log prob
162 |   return hyps_sorted[0]
163 | 
164 | def sort_hyps(hyps):
165 |   """Return a list of Hypothesis objects, sorted by descending average log probability"""
166 |   return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True)
167 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
18 | 
19 | import glob
20 | import random
21 | import struct
22 | import csv
23 | from tensorflow.core.example import example_pb2
24 | 
25 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
26 | SENTENCE_START = '<s>'
27 | SENTENCE_END = '</s>'
28 | 
29 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
30 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
31 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
32 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
33 | 
34 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
35 | 
36 | 
37 | class Vocab(object):
38 |   """Vocabulary class for mapping between words and ids (integers)"""
39 | 
40 |   def __init__(self, vocab_file, max_size):
41 |     """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.
42 | 
43 |     Args:
44 |       vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
45 |       max_size: integer. The maximum size of the resulting Vocabulary."""
46 |     self._word_to_id = {}
47 |     self._id_to_word = {}
48 |     self._count = 0 # keeps track of total number of words in the Vocab
49 | 
50 |     # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
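    # (Editorial example, not in the original file: a concrete consequence of
    # the loop below, shown with the special-token constants defined above:
    #   vocab.word2id('[UNK]') == 0, vocab.word2id('[PAD]') == 1,
    #   vocab.word2id('[START]') == 2, vocab.word2id('[STOP]') == 3
    # article2ids and abstract2ids later rely on the [UNK] id to detect OOV words.)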
51 |     for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
52 |       self._word_to_id[w] = self._count
53 |       self._id_to_word[self._count] = w
54 |       self._count += 1
55 | 
56 |     # Read the vocab file and add words up to max_size
57 |     with open(vocab_file, 'r') as vocab_f:
58 |       for line in vocab_f:
59 |         pieces = line.split()
60 |         if len(pieces) != 2:
61 |           print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
62 |           continue
63 |         w = pieces[0]
64 |         if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
65 |           raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
66 |         if w in self._word_to_id:
67 |           raise Exception('Duplicated word in vocabulary file: %s' % w)
68 |         self._word_to_id[w] = self._count
69 |         self._id_to_word[self._count] = w
70 |         self._count += 1
71 |         if max_size != 0 and self._count >= max_size:
72 |           print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
73 |           break
74 | 
75 |     print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))
76 | 
77 |   def word2id(self, word):
78 |     """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
79 |     if word not in self._word_to_id:
80 |       return self._word_to_id[UNKNOWN_TOKEN]
81 |     return self._word_to_id[word]
82 | 
83 |   def id2word(self, word_id):
84 |     """Returns the word (string) corresponding to an id (integer)."""
85 |     if word_id not in self._id_to_word:
86 |       raise ValueError('Id not found in vocab: %d' % word_id)
87 |     return self._id_to_word[word_id]
88 | 
89 |   def size(self):
90 |     """Returns the total size of the vocabulary"""
91 |     return self._count
92 | 
93 |   def write_metadata(self, fpath):
94 |     """Writes metadata file for Tensorboard word embedding visualizer as described here:
95 |     https://www.tensorflow.org/get_started/embedding_viz
96 | 
97 |     Args:
98 |       fpath: place to write the metadata file
99 |     """
100 |     print("Writing word embedding metadata file to %s..." % (fpath))
101 |     with open(fpath, "w") as f:
102 |       fieldnames = ['word']
103 |       writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
104 |       for i in range(self.size()):
105 |         writer.writerow({"word": self._id_to_word[i]})
106 | 
107 | 
108 | def example_generator(data_path, single_pass):
109 |   """Generates tf.Examples from data files.
110 | 
111 |   Binary data format: <length><blob>. <length> represents the byte size
112 |   of <blob>. <blob> is a serialized tf.Example proto. The tf.Example contains
113 |   the tokenized article text and summary.
114 | 
115 |   Args:
116 |     data_path:
117 |       Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
118 |     single_pass:
119 |       Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely.
120 | 
121 |   Yields:
122 |     Deserialized tf.Example.
123 |   """
124 |   while True:
125 |     filelist = glob.glob(data_path) # get the list of datafiles
126 |     assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
127 |     if single_pass:
128 |       filelist = sorted(filelist)
129 |     else:
130 |       random.shuffle(filelist)
131 |     for f in filelist:
132 |       with open(f, 'rb') as reader: # context manager ensures the file is closed after reading
133 |         while True:
134 |           len_bytes = reader.read(8)
135 |           if not len_bytes: break # finished reading this file
136 |           str_len = struct.unpack('q', len_bytes)[0]
137 |           example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
138 |           yield example_pb2.Example.FromString(example_str)
139 |     if single_pass:
140 |       print("example_generator completed reading all datafiles. No more data.")
141 |       break
142 | 
143 | 
144 | def article2ids(article_words, vocab):
145 |   """Map the article words to their ids. Also return a list of OOVs in the article.
146 | 
147 |   Args:
148 |     article_words: list of words (strings)
149 |     vocab: Vocabulary object
150 | 
151 |   Returns:
152 |     ids:
153 |       A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002.
154 |     oovs:
155 |       A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers."""
156 |   ids = []
157 |   oovs = []
158 |   unk_id = vocab.word2id(UNKNOWN_TOKEN)
159 |   for w in article_words:
160 |     i = vocab.word2id(w)
161 |     if i == unk_id: # If w is OOV
162 |       if w not in oovs: # Add to list of OOVs
163 |         oovs.append(w)
164 |       oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
165 |       ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
166 |     else:
167 |       ids.append(i)
168 |   return ids, oovs
169 | 
170 | 
171 | def abstract2ids(abstract_words, vocab, article_oovs):
172 |   """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers.
173 | 
174 |   Args:
175 |     abstract_words: list of words (strings)
176 |     vocab: Vocabulary object
177 |     article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers
178 | 
179 |   Returns:
180 |     ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id."""
181 |   ids = []
182 |   unk_id = vocab.word2id(UNKNOWN_TOKEN)
183 |   for w in abstract_words:
184 |     i = vocab.word2id(w)
185 |     if i == unk_id: # If w is an OOV word
186 |       if w in article_oovs: # If w is an in-article OOV
187 |         vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
188 |         ids.append(vocab_idx)
189 |       else: # If w is an out-of-article OOV
190 |         ids.append(unk_id) # Map to the UNK token id
191 |     else:
192 |       ids.append(i)
193 |   return ids
194 | 
195 | 
196 | def outputids2words(id_list, vocab, article_oovs):
197 |   """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode).
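  For example (an editorial illustration, not in the original docstring): with
  a 50,000-word vocab and article_oovs == ['zika'], the id 50000 maps back to
  the string 'zika', mirroring the numbering scheme described in article2ids.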
198 | 
199 |   Args:
200 |     id_list: list of ids (integers)
201 |     vocab: Vocabulary object
202 |     article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode)
203 | 
204 |   Returns:
205 |     words: list of words (strings)
206 |   """
207 |   words = []
208 |   for i in id_list:
209 |     try:
210 |       w = vocab.id2word(i) # might be [UNK]
211 |     except ValueError: # w is OOV
212 |       assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
213 |       article_oov_idx = i - vocab.size()
214 |       try:
215 |         w = article_oovs[article_oov_idx]
216 |       except IndexError: # i doesn't correspond to an article oov (out-of-range list indexing raises IndexError, not ValueError)
217 |         raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
218 |     words.append(w)
219 |   return words
220 | 
221 | 
222 | def abstract2sents(abstract):
223 |   """Splits abstract text from datafile into list of sentences.
224 | 
225 |   Args:
226 |     abstract: string containing <s> and </s> tags for starts and ends of sentences
227 | 
228 |   Returns:
229 |     sents: List of sentence strings (no tags)"""
230 |   cur = 0
231 |   sents = []
232 |   while True:
233 |     try:
234 |       start_p = abstract.index(SENTENCE_START, cur)
235 |       end_p = abstract.index(SENTENCE_END, start_p + 1)
236 |       cur = end_p + len(SENTENCE_END)
237 |       sents.append(abstract[start_p+len(SENTENCE_START):end_p])
238 |     except ValueError: # no more sentences
239 |       return sents
240 | 
241 | 
242 | def show_art_oovs(article, vocab):
243 |   """Returns the article string, highlighting the OOVs by placing __underscores__ around them"""
244 |   unk_token = vocab.word2id(UNKNOWN_TOKEN)
245 |   words = article.split(' ')
246 |   words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
247 |   out_str = ' '.join(words)
248 |   return out_str
249 | 
250 | 
251 | def show_abs_oovs(abstract, vocab, article_oovs):
252 |   """Returns the abstract string, highlighting the article OOVs with __underscores__.
253 | 
254 |   If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!.
255 | 
256 |   Args:
257 |     abstract: string
258 |     vocab: Vocabulary object
259 |     article_oovs: list of words (strings), or None (in baseline mode)
260 |   """
261 |   unk_token = vocab.word2id(UNKNOWN_TOKEN)
262 |   words = abstract.split(' ')
263 |   new_words = []
264 |   for w in words:
265 |     if vocab.word2id(w) == unk_token: # w is oov
266 |       if article_oovs is None: # baseline mode
267 |         new_words.append("__%s__" % w)
268 |       else: # pointer-generator mode
269 |         if w in article_oovs:
270 |           new_words.append("__%s__" % w)
271 |         else:
272 |           new_words.append("!!__%s__!!" % w)
273 |     else: # w is in-vocab word
274 |       new_words.append(w)
275 |   out_str = ' '.join(new_words)
276 |   return out_str
277 | 
--------------------------------------------------------------------------------
/decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding, including running ROUGE evaluation and producing JSON datafiles for the in-browser attention visualizer, which can be found here https://github.com/abisee/attn_vis""" 18 | 19 | import os 20 | import time 21 | import tensorflow as tf 22 | import beam_search 23 | import data 24 | import json 25 | import pyrouge 26 | import util 27 | import logging 28 | import numpy as np 29 | 30 | FLAGS = tf.app.flags.FLAGS 31 | 32 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint 33 | 34 | 35 | class BeamSearchDecoder(object): 36 | """Beam search decoder.""" 37 | 38 | def __init__(self, model, batcher, vocab): 39 | """Initialize decoder. 40 | 41 | Args: 42 | model: a Seq2SeqAttentionModel object. 43 | batcher: a Batcher object. 44 | vocab: Vocabulary object 45 | """ 46 | self._model = model 47 | self._model.build_graph() 48 | self._batcher = batcher 49 | self._vocab = vocab 50 | self._saver = tf.train.Saver() # we use this to load checkpoints for decoding 51 | self._sess = tf.Session(config=util.get_config()) 52 | 53 | # Load an initial checkpoint to use for decoding 54 | ckpt_path = util.load_ckpt(self._saver, self._sess) 55 | 56 | if FLAGS.single_pass: 57 | # Make a descriptive decode directory name 58 | ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456" 59 | self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name)) 60 | if os.path.exists(self._decode_dir): 61 | raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir) 62 | 63 | else: # Generic decode dir name 64 | self._decode_dir = os.path.join(FLAGS.log_root, "decode") 65 | 66 | # Make the decode dir if necessary 67 | if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir) 68 | 69 | if FLAGS.single_pass: 70 | # Make the dirs to contain output written in the correct format for pyrouge 71 | self._rouge_ref_dir = os.path.join(self._decode_dir, "reference") 72 | if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir) 73 | self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded") 74 | if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir) 75 | 76 | 77 | def decode(self): 78 | """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals""" 79 | t0 = time.time() 80 | counter = 0 81 | while True: 82 | batch = self._batcher.next_batch() # 1 example repeated across batch 83 | if batch is None: # finished decoding dataset in single_pass mode 84 | assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode" 85 | tf.logging.info("Decoder has finished reading dataset for single_pass.") 86 | tf.logging.info("Output has been saved in %s and %s. 
Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
87 |         results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
88 |         rouge_log(results_dict, self._decode_dir)
89 |         return
90 | 
91 |       original_article = batch.original_articles[0] # string
92 |       original_abstract = batch.original_abstracts[0] # string
93 |       original_abstract_sents = batch.original_abstracts_sents[0] # list of strings
94 | 
95 |       article_withunks = data.show_art_oovs(original_article, self._vocab) # string
96 |       abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string
97 | 
98 |       # Run beam search to get best Hypothesis
99 |       best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
100 | 
101 |       # Extract the output ids from the hypothesis and convert back to words
102 |       output_ids = [int(t) for t in best_hyp.tokens[1:]]
103 |       decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))
104 | 
105 |       # Remove the [STOP] token from decoded_words, if necessary
106 |       try:
107 |         fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
108 |         decoded_words = decoded_words[:fst_stop_idx]
109 |       except ValueError: # no [STOP] token was produced; keep the full sequence
110 |         pass
111 |       decoded_output = ' '.join(decoded_words) # single string
112 | 
113 |       if FLAGS.single_pass:
114 |         self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
115 |         counter += 1 # this is how many examples we've decoded
116 |       else:
117 |         print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
118 |         self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool
119 | 
120 |         # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
121 |         t1 = time.time()
122 |         if t1-t0 > SECS_UNTIL_NEW_CKPT:
123 |           tf.logging.info('We\'ve been decoding with the same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
124 |           _ = util.load_ckpt(self._saver, self._sess)
125 |           t0 = time.time()
126 | 
127 |   def write_for_rouge(self, reference_sents, decoded_words, ex_index):
128 |     """Write output to file in correct format for eval with pyrouge. This is called in single_pass mode.
129 | 
130 |     Args:
131 |       reference_sents: list of strings
132 |       decoded_words: list of strings
133 |       ex_index: int, the index with which to label the files
134 |     """
135 |     # First, divide decoded output into sentences
136 |     decoded_sents = []
137 |     while len(decoded_words) > 0:
138 |       try:
139 |         fst_period_idx = decoded_words.index(".")
140 |       except ValueError: # there is text remaining that doesn't end in "."
141 |         fst_period_idx = len(decoded_words)
142 |       sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period
143 |       decoded_words = decoded_words[fst_period_idx+1:] # everything else
144 |       decoded_sents.append(' '.join(sent))
145 | 
146 |     # pyrouge calls a perl script that puts the data into HTML files.
147 |     # Therefore we need to make our output HTML safe.
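    # (Editorial example, not in the original file: make_html_safe, defined
    # below, escapes angle brackets, e.g. 'prices fell <unk> percent' becomes
    # 'prices fell &lt;unk&gt; percent', so a stray token can't be parsed as an
    # HTML tag by pyrouge's perl script.)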
148 |     decoded_sents = [make_html_safe(w) for w in decoded_sents]
149 |     reference_sents = [make_html_safe(w) for w in reference_sents]
150 | 
151 |     # Write to file
152 |     ref_file = os.path.join(self._rouge_ref_dir, "%06d_reference.txt" % ex_index)
153 |     decoded_file = os.path.join(self._rouge_dec_dir, "%06d_decoded.txt" % ex_index)
154 | 
155 |     with open(ref_file, "w") as f:
156 |       for idx,sent in enumerate(reference_sents):
157 |         f.write(sent) if idx==len(reference_sents)-1 else f.write(sent+"\n")
158 |     with open(decoded_file, "w") as f:
159 |       for idx,sent in enumerate(decoded_sents):
160 |         f.write(sent) if idx==len(decoded_sents)-1 else f.write(sent+"\n")
161 | 
162 |     tf.logging.info("Wrote example %i to file" % ex_index)
163 | 
164 | 
165 |   def write_for_attnvis(self, article, abstract, decoded_words, attn_dists, p_gens):
166 |     """Write some data to json file, which can be read into the in-browser attention visualizer tool:
167 |       https://github.com/abisee/attn_vis
168 | 
169 |     Args:
170 |       article: The original article string.
171 |       abstract: The human (correct) abstract string.
172 |       attn_dists: List of arrays; the attention distributions.
173 |       decoded_words: List of strings; the words of the generated summary.
174 |       p_gens: List of scalars; the p_gen values. If not running in pointer-generator mode, list of None.
175 |     """
176 |     article_lst = article.split() # list of words
177 |     decoded_lst = decoded_words # list of decoded words
178 |     to_write = {
179 |         'article_lst': [make_html_safe(t) for t in article_lst],
180 |         'decoded_lst': [make_html_safe(t) for t in decoded_lst],
181 |         'abstract_str': make_html_safe(abstract),
182 |         'attn_dists': attn_dists
183 |     }
184 |     if FLAGS.pointer_gen:
185 |       to_write['p_gens'] = p_gens
186 |     output_fname = os.path.join(self._decode_dir, 'attn_vis_data.json')
187 |     with open(output_fname, 'w') as output_file:
188 |       json.dump(to_write, output_file)
189 |     tf.logging.info('Wrote visualization data to %s', output_fname)
190 | 
191 | 
192 | def print_results(article, abstract, decoded_output):
193 |   """Prints the article, the reference summary and the decoded summary to screen"""
194 |   print("---------------------------------------------------------------------------")
195 |   tf.logging.info('ARTICLE: %s', article)
196 |   tf.logging.info('REFERENCE SUMMARY: %s', abstract)
197 |   tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
198 |   print("---------------------------------------------------------------------------")
199 | 
200 | 
201 | def make_html_safe(s):
202 |   """Replace any angled brackets in string s to avoid interfering with HTML attention visualizer."""
203 |   s = s.replace("<", "&lt;") # str.replace returns a new string, so the result must be reassigned
204 |   s = s.replace(">", "&gt;")
205 |   return s
206 | 
207 | 
208 | def rouge_eval(ref_dir, dec_dir):
209 |   """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict"""
210 |   r = pyrouge.Rouge155()
211 |   r.model_filename_pattern = '#ID#_reference.txt'
212 |   r.system_filename_pattern = r'(\d+)_decoded.txt' # raw string so \d isn't treated as an escape sequence
213 |   r.model_dir = ref_dir
214 |   r.system_dir = dec_dir
215 |   logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
216 |   rouge_results = r.convert_and_evaluate()
217 |   return r.output_to_dict(rouge_results)
218 | 
219 | 
220 | def rouge_log(results_dict, dir_to_write):
221 |   """Log ROUGE results to screen and write to file.
222 | 
223 |   Args:
224 |     results_dict: the dictionary returned by pyrouge
225 |     dir_to_write: the directory where we will write the results to"""
226 |   log_str = ""
227 |   for x in ["1","2","l"]:
228 |     log_str += "\nROUGE-%s:\n" % x
229 |     for y in ["f_score", "recall", "precision"]:
230 |       key = "rouge_%s_%s" % (x,y)
231 |       key_cb = key + "_cb"
232 |       key_ce = key + "_ce"
233 |       val = results_dict[key]
234 |       val_cb = results_dict[key_cb]
235 |       val_ce = results_dict[key_ce]
236 |       log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
237 |   tf.logging.info(log_str) # log to screen
238 |   results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
239 |   tf.logging.info("Writing final ROUGE results to %s...", results_file)
240 |   with open(results_file, "w") as f:
241 |     f.write(log_str)
242 | 
243 | def get_decode_dir_name(ckpt_name):
244 |   """Make a descriptive name for the decode dir, including the name of the checkpoint we use to decode. This is called in single_pass mode."""
245 | 
246 |   if "train" in FLAGS.data_path: dataset = "train"
247 |   elif "val" in FLAGS.data_path: dataset = "val"
248 |   elif "test" in FLAGS.data_path: dataset = "test"
249 |   else: raise ValueError("FLAGS.data_path %s should contain one of train, val or test" % (FLAGS.data_path))
250 |   dirname = "decode_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (dataset, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
251 |   if ckpt_name is not None:
252 |     dirname += "_%s" % ckpt_name
253 |   return dirname
254 | 
--------------------------------------------------------------------------------
/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
3 |     python inspect_checkpoint.py model.12345
4 | """
5 | 
6 | import tensorflow as tf
7 | import sys
8 | import numpy as np
9 | 
10 | 
11 | if __name__ == '__main__':
12 |   if len(sys.argv) != 2:
13 |     raise Exception("Usage: python inspect_checkpoint.py <file_name>\nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.")
14 |   file_name = sys.argv[1]
15 |   reader = tf.train.NewCheckpointReader(file_name)
16 |   var_to_shape_map = reader.get_variable_to_shape_map()
17 | 
18 |   finite = []
19 |   all_infnan = []
20 |   some_infnan = []
21 | 
22 |   for key in sorted(var_to_shape_map.keys()):
23 |     tensor = reader.get_tensor(key)
24 |     if np.all(np.isfinite(tensor)):
25 |       finite.append(key)
26 |     else:
27 |       if not np.any(np.isfinite(tensor)):
28 |         all_infnan.append(key)
29 |       else:
30 |         some_infnan.append(key)
31 | 
32 |   print("\nFINITE VARIABLES:")
33 |   for key in finite: print(key)
34 | 
35 |   print("\nVARIABLES THAT ARE ALL INF/NAN:")
36 |   for key in all_infnan: print(key)
37 | 
38 |   print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:")
39 |   for key in some_infnan: print(key)
40 | 
41 |   if not all_infnan and not some_infnan:
42 |     print("CHECK PASSED: checkpoint contains no inf/NaN values")
43 |   else:
44 |     print("CHECK FAILED: checkpoint contains some inf/NaN values")
45 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """This file contains code to build and run the tensorflow graph for the sequence-to-sequence model"""
18 | 
19 | import os
20 | import time
21 | import numpy as np
22 | import tensorflow as tf
23 | from attention_decoder import attention_decoder
24 | from tensorflow.contrib.tensorboard.plugins import projector
25 | 
26 | FLAGS = tf.app.flags.FLAGS
27 | 
28 | class SummarizationModel(object):
29 |   """A class to represent a sequence-to-sequence model for text summarization. Supports baseline mode, pointer-generator mode, and coverage."""
30 | 
31 |   def __init__(self, hps, vocab):
32 |     self._hps = hps
33 |     self._vocab = vocab
34 | 
35 |   def _add_placeholders(self):
36 |     """Add placeholders to the graph. These are entry points for any input data."""
37 |     hps = self._hps
38 | 
39 |     # encoder part
40 |     self._enc_batch = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch')
41 |     self._enc_lens = tf.placeholder(tf.int32, [hps.batch_size], name='enc_lens')
42 |     self._enc_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, None], name='enc_padding_mask')
43 |     if FLAGS.pointer_gen:
44 |       self._enc_batch_extend_vocab = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch_extend_vocab')
45 |       self._max_art_oovs = tf.placeholder(tf.int32, [], name='max_art_oovs')
46 | 
47 |     # decoder part
48 |     self._dec_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='dec_batch')
49 |     self._target_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='target_batch')
50 |     self._dec_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, hps.max_dec_steps], name='dec_padding_mask')
51 | 
52 |     if hps.mode=="decode" and hps.coverage:
53 |       self.prev_coverage = tf.placeholder(tf.float32, [hps.batch_size, None], name='prev_coverage')
54 | 
55 | 
56 |   def _make_feed_dict(self, batch, just_enc=False):
57 |     """Make a feed dictionary mapping parts of the batch to the appropriate placeholders.
58 | 
59 |     Args:
60 |       batch: Batch object
61 |       just_enc: Boolean. If True, only feed the parts needed for the encoder.
62 |     """
63 |     feed_dict = {}
64 |     feed_dict[self._enc_batch] = batch.enc_batch
65 |     feed_dict[self._enc_lens] = batch.enc_lens
66 |     feed_dict[self._enc_padding_mask] = batch.enc_padding_mask
67 |     if FLAGS.pointer_gen:
68 |       feed_dict[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
69 |       feed_dict[self._max_art_oovs] = batch.max_art_oovs
70 |     if not just_enc:
71 |       feed_dict[self._dec_batch] = batch.dec_batch
72 |       feed_dict[self._target_batch] = batch.target_batch
73 |       feed_dict[self._dec_padding_mask] = batch.dec_padding_mask
74 |     return feed_dict
75 | 
76 |   def _add_encoder(self, encoder_inputs, seq_len):
77 |     """Add a single-layer bidirectional LSTM encoder to the graph.
78 | 79 | Args: 80 | encoder_inputs: A tensor of shape [batch_size, <=max_enc_steps, emb_size]. 81 | seq_len: Lengths of encoder_inputs (before padding). A tensor of shape [batch_size]. 82 | 83 | Returns: 84 | encoder_outputs: 85 | A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim]. It's 2*hidden_dim because it's the concatenation of the forwards and backwards states. 86 | fw_state, bw_state: 87 | Each are LSTMStateTuples of shape ([batch_size,hidden_dim],[batch_size,hidden_dim]) 88 | """ 89 | with tf.variable_scope('encoder'): 90 | cell_fw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init, state_is_tuple=True) 91 | cell_bw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init, state_is_tuple=True) 92 | (encoder_outputs, (fw_st, bw_st)) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, encoder_inputs, dtype=tf.float32, sequence_length=seq_len, swap_memory=True) 93 | encoder_outputs = tf.concat(axis=2, values=encoder_outputs) # concatenate the forwards and backwards states 94 | return encoder_outputs, fw_st, bw_st 95 | 96 | 97 | def _reduce_states(self, fw_st, bw_st): 98 | """Add to the graph a linear layer to reduce the encoder's final FW and BW state into a single initial state for the decoder. This is needed because the encoder is bidirectional but the decoder is not. 99 | 100 | Args: 101 | fw_st: LSTMStateTuple with hidden_dim units. 102 | bw_st: LSTMStateTuple with hidden_dim units. 103 | 104 | Returns: 105 | state: LSTMStateTuple with hidden_dim units. 106 | """ 107 | hidden_dim = self._hps.hidden_dim 108 | with tf.variable_scope('reduce_final_st'): 109 | 110 | # Define weights and biases to reduce the cell and reduce the state 111 | w_reduce_c = tf.get_variable('w_reduce_c', [hidden_dim * 2, hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init) 112 | w_reduce_h = tf.get_variable('w_reduce_h', [hidden_dim * 2, hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init) 113 | bias_reduce_c = tf.get_variable('bias_reduce_c', [hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init) 114 | bias_reduce_h = tf.get_variable('bias_reduce_h', [hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init) 115 | 116 | # Apply linear layer 117 | old_c = tf.concat(axis=1, values=[fw_st.c, bw_st.c]) # Concatenation of fw and bw cell 118 | old_h = tf.concat(axis=1, values=[fw_st.h, bw_st.h]) # Concatenation of fw and bw state 119 | new_c = tf.nn.relu(tf.matmul(old_c, w_reduce_c) + bias_reduce_c) # Get new cell from old cell 120 | new_h = tf.nn.relu(tf.matmul(old_h, w_reduce_h) + bias_reduce_h) # Get new state from old state 121 | return tf.contrib.rnn.LSTMStateTuple(new_c, new_h) # Return new cell and state 122 | 123 | 124 | def _add_decoder(self, inputs): 125 | """Add attention decoder to the graph. In train or eval mode, you call this once to get output on ALL steps. In decode (beam search) mode, you call this once for EACH decoder step. 126 | 127 | Args: 128 | inputs: inputs to the decoder (word embeddings). 
A list of tensors shape (batch_size, emb_dim) 129 | 130 | Returns: 131 | outputs: List of tensors; the outputs of the decoder 132 | out_state: The final state of the decoder 133 | attn_dists: A list of tensors; the attention distributions 134 | p_gens: A list of scalar tensors; the generation probabilities 135 | coverage: A tensor, the current coverage vector 136 | """ 137 | hps = self._hps 138 | cell = tf.contrib.rnn.LSTMCell(hps.hidden_dim, state_is_tuple=True, initializer=self.rand_unif_init) 139 | 140 | prev_coverage = self.prev_coverage if hps.mode=="decode" and hps.coverage else None # In decode mode, we run attention_decoder one step at a time and so need to pass in the previous step's coverage vector each time 141 | 142 | outputs, out_state, attn_dists, p_gens, coverage = attention_decoder(inputs, self._dec_in_state, self._enc_states, self._enc_padding_mask, cell, initial_state_attention=(hps.mode=="decode"), pointer_gen=hps.pointer_gen, use_coverage=hps.coverage, prev_coverage=prev_coverage) 143 | 144 | return outputs, out_state, attn_dists, p_gens, coverage 145 | 146 | def _calc_final_dist(self, vocab_dists, attn_dists): 147 | """Calculate the final distribution, for the pointer-generator model 148 | 149 | Args: 150 | vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file. 151 | attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays 152 | 153 | Returns: 154 | final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays. 155 | """ 156 | with tf.variable_scope('final_distribution'): 157 | # Multiply vocab dists by p_gen and attention dists by (1-p_gen) 158 | vocab_dists = [p_gen * dist for (p_gen,dist) in zip(self.p_gens, vocab_dists)] 159 | attn_dists = [(1-p_gen) * dist for (p_gen,dist) in zip(self.p_gens, attn_dists)] 160 | 161 | # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words 162 | extended_vsize = self._vocab.size() + self._max_art_oovs # the maximum (over the batch) size of the extended vocabulary 163 | extra_zeros = tf.zeros((self._hps.batch_size, self._max_art_oovs)) 164 | vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists] # list length max_dec_steps of shape (batch_size, extended_vsize) 165 | 166 | # Project the values in the attention distributions onto the appropriate entries in the final distributions 167 | # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution 168 | # This is done for each decoder timestep. 
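      # (Editorial sketch, not in the original file: the tf.scatter_nd call below,
      # with toy numbers for a single example in the batch:
      #   indices = [[[0, 500], [0, 3]]]   # batch row 0 attends to extended-vocab ids 500 and 3
      #   updates = [[0.1, 0.9]]           # the corresponding attention weights
      #   tf.scatter_nd(indices, updates, [1, extended_vsize])
      # yields a [1, extended_vsize] tensor that is 0.1 at column 500, 0.9 at
      # column 3 and zero elsewhere. scatter_nd sums duplicate indices, which is
      # what we want when the same word appears several times in the article.)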
169 | # This is fiddly; we use tf.scatter_nd to do the projection 170 | batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size) 171 | batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1) 172 | attn_len = tf.shape(self._enc_batch_extend_vocab)[1] # number of states we attend over 173 | batch_nums = tf.tile(batch_nums, [1, attn_len]) # shape (batch_size, attn_len) 174 | indices = tf.stack( (batch_nums, self._enc_batch_extend_vocab), axis=2) # shape (batch_size, enc_t, 2) 175 | shape = [self._hps.batch_size, extended_vsize] 176 | attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists] # list length max_dec_steps (batch_size, extended_vsize) 177 | 178 | # Add the vocab distributions and the copy distributions together to get the final distributions 179 | # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep 180 | # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore. 181 | final_dists = [vocab_dist + copy_dist for (vocab_dist,copy_dist) in zip(vocab_dists_extended, attn_dists_projected)] 182 | 183 | return final_dists 184 | 185 | def _add_emb_vis(self, embedding_var): 186 | """Do setup so that we can view word embedding visualization in Tensorboard, as described here: 187 | https://www.tensorflow.org/get_started/embedding_viz 188 | Make the vocab metadata file, then make the projector config file pointing to it.""" 189 | train_dir = os.path.join(FLAGS.log_root, "train") 190 | vocab_metadata_path = os.path.join(train_dir, "vocab_metadata.tsv") 191 | self._vocab.write_metadata(vocab_metadata_path) # write metadata file 192 | summary_writer = tf.summary.FileWriter(train_dir) 193 | config = projector.ProjectorConfig() 194 | embedding = config.embeddings.add() 195 | embedding.tensor_name = embedding_var.name 196 | embedding.metadata_path = vocab_metadata_path 197 | projector.visualize_embeddings(summary_writer, config) 198 | 199 | def _add_seq2seq(self): 200 | """Add the whole sequence-to-sequence model to the graph.""" 201 | hps = self._hps 202 | vsize = self._vocab.size() # size of the vocabulary 203 | 204 | with tf.variable_scope('seq2seq'): 205 | # Some initializers 206 | self.rand_unif_init = tf.random_uniform_initializer(-hps.rand_unif_init_mag, hps.rand_unif_init_mag, seed=123) 207 | self.trunc_norm_init = tf.truncated_normal_initializer(stddev=hps.trunc_norm_init_std) 208 | 209 | # Add embedding matrix (shared by the encoder and decoder inputs) 210 | with tf.variable_scope('embedding'): 211 | embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype=tf.float32, initializer=self.trunc_norm_init) 212 | if hps.mode=="train": self._add_emb_vis(embedding) # add to tensorboard 213 | emb_enc_inputs = tf.nn.embedding_lookup(embedding, self._enc_batch) # tensor with shape (batch_size, max_enc_steps, emb_size) 214 | emb_dec_inputs = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch, axis=1)] # list length max_dec_steps containing shape (batch_size, emb_size) 215 | 216 | # Add the encoder. 
217 | enc_outputs, fw_st, bw_st = self._add_encoder(emb_enc_inputs, self._enc_lens) 218 | self._enc_states = enc_outputs 219 | 220 | # Our encoder is bidirectional and our decoder is unidirectional so we need to reduce the final encoder hidden state to the right size to be the initial decoder hidden state 221 | self._dec_in_state = self._reduce_states(fw_st, bw_st) 222 | 223 | # Add the decoder. 224 | with tf.variable_scope('decoder'): 225 | decoder_outputs, self._dec_out_state, self.attn_dists, self.p_gens, self.coverage = self._add_decoder(emb_dec_inputs) 226 | 227 | # Add the output projection to obtain the vocabulary distribution 228 | with tf.variable_scope('output_projection'): 229 | w = tf.get_variable('w', [hps.hidden_dim, vsize], dtype=tf.float32, initializer=self.trunc_norm_init) 230 | w_t = tf.transpose(w) 231 | v = tf.get_variable('v', [vsize], dtype=tf.float32, initializer=self.trunc_norm_init) 232 | vocab_scores = [] # vocab_scores is the vocabulary distribution before applying softmax. Each entry on the list corresponds to one decoder step 233 | for i,output in enumerate(decoder_outputs): 234 | if i > 0: 235 | tf.get_variable_scope().reuse_variables() 236 | vocab_scores.append(tf.nn.xw_plus_b(output, w, v)) # apply the linear layer 237 | 238 | vocab_dists = [tf.nn.softmax(s) for s in vocab_scores] # The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file. 239 | 240 | 241 | # For pointer-generator model, calc final distribution from copy distribution and vocabulary distribution 242 | if FLAGS.pointer_gen: 243 | final_dists = self._calc_final_dist(vocab_dists, self.attn_dists) 244 | else: # final distribution is just vocabulary distribution 245 | final_dists = vocab_dists 246 | 247 | 248 | 249 | if hps.mode in ['train', 'eval']: 250 | # Calculate the loss 251 | with tf.variable_scope('loss'): 252 | if FLAGS.pointer_gen: 253 | # Calculate the loss per step 254 | # This is fiddly; we use tf.gather_nd to pick out the probabilities of the gold target words 255 | loss_per_step = [] # will be list length max_dec_steps containing shape (batch_size) 256 | batch_nums = tf.range(0, limit=hps.batch_size) # shape (batch_size) 257 | for dec_step, dist in enumerate(final_dists): 258 | targets = self._target_batch[:,dec_step] # The indices of the target words. shape (batch_size) 259 | indices = tf.stack( (batch_nums, targets), axis=1) # shape (batch_size, 2) 260 | gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). 
prob of correct words on this step 261 | losses = -tf.log(gold_probs) 262 | loss_per_step.append(losses) 263 | 264 | # Apply dec_padding_mask and get loss 265 | self._loss = _mask_and_avg(loss_per_step, self._dec_padding_mask) 266 | 267 | else: # baseline model 268 | self._loss = tf.contrib.seq2seq.sequence_loss(tf.stack(vocab_scores, axis=1), self._target_batch, self._dec_padding_mask) # this applies softmax internally 269 | 270 | tf.summary.scalar('loss', self._loss) 271 | 272 | # Calculate coverage loss from the attention distributions 273 | if hps.coverage: 274 | with tf.variable_scope('coverage_loss'): 275 | self._coverage_loss = _coverage_loss(self.attn_dists, self._dec_padding_mask) 276 | tf.summary.scalar('coverage_loss', self._coverage_loss) 277 | self._total_loss = self._loss + hps.cov_loss_wt * self._coverage_loss 278 | tf.summary.scalar('total_loss', self._total_loss) 279 | 280 | if hps.mode == "decode": 281 | # We run decode beam search mode one decoder step at a time 282 | assert len(final_dists)==1 # final_dists is a singleton list containing shape (batch_size, extended_vsize) 283 | final_dists = final_dists[0] 284 | topk_probs, self._topk_ids = tf.nn.top_k(final_dists, hps.batch_size*2) # take the k largest probs. note batch_size=beam_size in decode mode 285 | self._topk_log_probs = tf.log(topk_probs) 286 | 287 | 288 | def _add_train_op(self): 289 | """Sets self._train_op, the op to run for training.""" 290 | # Take gradients of the trainable variables w.r.t. the loss function to minimize 291 | loss_to_minimize = self._total_loss if self._hps.coverage else self._loss 292 | tvars = tf.trainable_variables() 293 | gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE) 294 | 295 | # Clip the gradients 296 | with tf.device("/gpu:0"): 297 | grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm) 298 | 299 | # Add a summary 300 | tf.summary.scalar('global_norm', global_norm) 301 | 302 | # Apply adagrad optimizer 303 | optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc) 304 | with tf.device("/gpu:0"): 305 | self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step') 306 | 307 | 308 | def build_graph(self): 309 | """Add the placeholders, model, global step, train_op and summaries to the graph""" 310 | tf.logging.info('Building graph...') 311 | t0 = time.time() 312 | self._add_placeholders() 313 | with tf.device("/gpu:0"): 314 | self._add_seq2seq() 315 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 316 | if self._hps.mode == 'train': 317 | self._add_train_op() 318 | self._summaries = tf.summary.merge_all() 319 | t1 = time.time() 320 | tf.logging.info('Time to build graph: %i seconds', t1 - t0) 321 | 322 | def run_train_step(self, sess, batch): 323 | """Runs one training iteration. Returns a dictionary containing train op, summaries, loss, global_step and (optionally) coverage loss.""" 324 | feed_dict = self._make_feed_dict(batch) 325 | to_return = { 326 | 'train_op': self._train_op, 327 | 'summaries': self._summaries, 328 | 'loss': self._loss, 329 | 'global_step': self.global_step, 330 | } 331 | if self._hps.coverage: 332 | to_return['coverage_loss'] = self._coverage_loss 333 | return sess.run(to_return, feed_dict) 334 | 335 | def run_eval_step(self, sess, batch): 336 | """Runs one evaluation iteration. 
Returns a dictionary containing summaries, loss, global_step and (optionally) coverage loss.""" 337 | feed_dict = self._make_feed_dict(batch) 338 | to_return = { 339 | 'summaries': self._summaries, 340 | 'loss': self._loss, 341 | 'global_step': self.global_step, 342 | } 343 | if self._hps.coverage: 344 | to_return['coverage_loss'] = self._coverage_loss 345 | return sess.run(to_return, feed_dict) 346 | 347 | def run_encoder(self, sess, batch): 348 | """For beam search decoding. Run the encoder on the batch and return the encoder states and decoder initial state. 349 | 350 | Args: 351 | sess: Tensorflow session. 352 | batch: Batch object that is the same example repeated across the batch (for beam search) 353 | 354 | Returns: 355 | enc_states: The encoder states. A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 356 | dec_in_state: A LSTMStateTuple of shape ([1,hidden_dim],[1,hidden_dim]) 357 | """ 358 | feed_dict = self._make_feed_dict(batch, just_enc=True) # feed the batch into the placeholders 359 | (enc_states, dec_in_state, global_step) = sess.run([self._enc_states, self._dec_in_state, self.global_step], feed_dict) # run the encoder 360 | 361 | # dec_in_state is LSTMStateTuple shape ([batch_size,hidden_dim],[batch_size,hidden_dim]) 362 | # Given that the batch is a single example repeated, dec_in_state is identical across the batch so we just take the top row. 363 | dec_in_state = tf.contrib.rnn.LSTMStateTuple(dec_in_state.c[0], dec_in_state.h[0]) 364 | return enc_states, dec_in_state 365 | 366 | 367 | def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states, prev_coverage): 368 | """For beam search decoding. Run the decoder for one step. 369 | 370 | Args: 371 | sess: Tensorflow session. 372 | batch: Batch object containing single example repeated across the batch 373 | latest_tokens: Tokens to be fed as input into the decoder for this timestep 374 | enc_states: The encoder states. 375 | dec_init_states: List of beam_size LSTMStateTuples; the decoder states from the previous timestep 376 | prev_coverage: List of np arrays. The coverage vectors from the previous timestep. List of None if not using coverage. 377 | 378 | Returns: 379 | ids: top 2k ids. shape [beam_size, 2*beam_size] 380 | probs: top 2k log probabilities. shape [beam_size, 2*beam_size] 381 | new_states: new states of the decoder. a list length beam_size containing 382 | LSTMStateTuples each of shape ([hidden_dim,],[hidden_dim,]) 383 | attn_dists: List length beam_size containing lists length attn_length. 384 | p_gens: Generation probabilities for this step. A list length beam_size. List of None if in baseline mode. 385 | new_coverage: Coverage vectors for this step. A list of arrays. List of None if coverage is not turned on. 
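    (Editorial shape sketch, not part of the original docstring: with the
    default beam_size=4, ids and probs are [4, 8] arrays, and new_states,
    attn_dists, p_gens and new_coverage are each lists of length 4.)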
386 | """ 387 | 388 | beam_size = len(dec_init_states) 389 | 390 | # Turn dec_init_states (a list of LSTMStateTuples) into a single LSTMStateTuple for the batch 391 | cells = [np.expand_dims(state.c, axis=0) for state in dec_init_states] 392 | hiddens = [np.expand_dims(state.h, axis=0) for state in dec_init_states] 393 | new_c = np.concatenate(cells, axis=0) # shape [batch_size,hidden_dim] 394 | new_h = np.concatenate(hiddens, axis=0) # shape [batch_size,hidden_dim] 395 | new_dec_in_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 396 | 397 | feed = { 398 | self._enc_states: enc_states, 399 | self._enc_padding_mask: batch.enc_padding_mask, 400 | self._dec_in_state: new_dec_in_state, 401 | self._dec_batch: np.transpose(np.array([latest_tokens])), 402 | } 403 | 404 | to_return = { 405 | "ids": self._topk_ids, 406 | "probs": self._topk_log_probs, 407 | "states": self._dec_out_state, 408 | "attn_dists": self.attn_dists 409 | } 410 | 411 | if FLAGS.pointer_gen: 412 | feed[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab 413 | feed[self._max_art_oovs] = batch.max_art_oovs 414 | to_return['p_gens'] = self.p_gens 415 | 416 | if self._hps.coverage: 417 | feed[self.prev_coverage] = np.stack(prev_coverage, axis=0) 418 | to_return['coverage'] = self.coverage 419 | 420 | results = sess.run(to_return, feed_dict=feed) # run the decoder step 421 | 422 | # Convert results['states'] (a single LSTMStateTuple) into a list of LSTMStateTuple -- one for each hypothesis 423 | new_states = [tf.contrib.rnn.LSTMStateTuple(results['states'].c[i, :], results['states'].h[i, :]) for i in range(beam_size)] 424 | 425 | # Convert singleton list containing a tensor to a list of k arrays 426 | assert len(results['attn_dists'])==1 427 | attn_dists = results['attn_dists'][0].tolist() 428 | 429 | if FLAGS.pointer_gen: 430 | # Convert singleton list containing a tensor to a list of k arrays 431 | assert len(results['p_gens'])==1 432 | p_gens = results['p_gens'][0].tolist() 433 | else: 434 | p_gens = [None for _ in range(beam_size)] 435 | 436 | # Convert the coverage tensor to a list length k containing the coverage vector for each hypothesis 437 | if FLAGS.coverage: 438 | new_coverage = results['coverage'].tolist() 439 | assert len(new_coverage) == beam_size 440 | else: 441 | new_coverage = [None for _ in range(beam_size)] 442 | 443 | return results['ids'], results['probs'], new_states, attn_dists, p_gens, new_coverage 444 | 445 | 446 | def _mask_and_avg(values, padding_mask): 447 | """Applies mask to values then returns overall average (a scalar) 448 | 449 | Args: 450 | values: a list length max_dec_steps containing arrays shape (batch_size). 451 | padding_mask: tensor shape (batch_size, max_dec_steps) containing 1s and 0s. 452 | 453 | Returns: 454 | a scalar 455 | """ 456 | 457 | dec_lens = tf.reduce_sum(padding_mask, axis=1) # shape batch_size. float32 458 | values_per_step = [v * padding_mask[:,dec_step] for dec_step,v in enumerate(values)] 459 | values_per_ex = sum(values_per_step)/dec_lens # shape (batch_size); normalized value for each batch member 460 | return tf.reduce_mean(values_per_ex) # overall average 461 | 462 | 463 | def _coverage_loss(attn_dists, padding_mask): 464 | """Calculates the coverage loss from the attention distributions. 465 | 466 | Args: 467 | attn_dists: The attention distributions for each decoder timestep. A list length max_dec_steps containing shape (batch_size, attn_length) 468 | padding_mask: shape (batch_size, max_dec_steps). 
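    (Editorial note, not part of the original docstring: the loop below computes
    the paper's coverage loss, covloss_t = sum_i min(a_i^t, c_i^t), where the
    coverage vector c^t is the running sum of the attention distributions
    a^1 ... a^{t-1}.)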
469 | 470 | Returns: 471 | coverage_loss: scalar 472 | """ 473 | coverage = tf.zeros_like(attn_dists[0]) # shape (batch_size, attn_length). Initial coverage is zero. 474 | covlosses = [] # Coverage loss per decoder timestep. Will be list length max_dec_steps containing shape (batch_size). 475 | for a in attn_dists: 476 | covloss = tf.reduce_sum(tf.minimum(a, coverage), [1]) # calculate the coverage loss for this step 477 | covlosses.append(covloss) 478 | coverage += a # update the coverage vector 479 | coverage_loss = _mask_and_avg(covlosses, padding_mask) 480 | return coverage_loss 481 | -------------------------------------------------------------------------------- /run_summarization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This is the top-level file to train, evaluate or test your summarization model""" 18 | 19 | import sys 20 | import time 21 | import os 22 | import tensorflow as tf 23 | import numpy as np 24 | from collections import namedtuple 25 | from data import Vocab 26 | from batcher import Batcher 27 | from model import SummarizationModel 28 | from decode import BeamSearchDecoder 29 | import util 30 | from tensorflow.python import debug as tf_debug 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | # Where to find data 35 | tf.app.flags.DEFINE_string('data_path', '', 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.') 36 | tf.app.flags.DEFINE_string('vocab_path', '', 'Path expression to text vocabulary file.') 37 | 38 | # Important settings 39 | tf.app.flags.DEFINE_string('mode', 'train', 'must be one of train/eval/decode') 40 | tf.app.flags.DEFINE_boolean('single_pass', False, 'For decode mode only. If True, run eval on the full dataset using a fixed checkpoint, i.e. take the current checkpoint, and use it to produce one summary for each example in the dataset, write the summaries to file and then get ROUGE scores for the whole dataset. If False (default), run concurrent decoding, i.e. repeatedly load latest checkpoint, use it to produce summaries for randomly-chosen examples and log the results to screen, indefinitely.') 41 | 42 | # Where to save output 43 | tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.') 44 | tf.app.flags.DEFINE_string('exp_name', '', 'Name for experiment. 
45 | 
46 | # Hyperparameters
47 | tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states')
48 | tf.app.flags.DEFINE_integer('emb_dim', 128, 'dimension of word embeddings')
49 | tf.app.flags.DEFINE_integer('batch_size', 16, 'minibatch size')
50 | tf.app.flags.DEFINE_integer('max_enc_steps', 400, 'max timesteps of encoder (max source text tokens)')
51 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'max timesteps of decoder (max summary tokens)')
52 | tf.app.flags.DEFINE_integer('beam_size', 4, 'beam size for beam search decoding.')
53 | tf.app.flags.DEFINE_integer('min_dec_steps', 35, 'Minimum sequence length of generated summary. Applies only for beam search decoding mode')
54 | tf.app.flags.DEFINE_integer('vocab_size', 50000, 'Size of vocabulary. Words will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, all words in the vocabulary file will be used.')
55 | tf.app.flags.DEFINE_float('lr', 0.15, 'learning rate')
56 | tf.app.flags.DEFINE_float('adagrad_init_acc', 0.1, 'initial accumulator value for Adagrad')
57 | tf.app.flags.DEFINE_float('rand_unif_init_mag', 0.02, 'magnitude for random uniform initialization of LSTM cells')
58 | tf.app.flags.DEFINE_float('trunc_norm_init_std', 1e-4, 'std of truncated normal init, used for initializing everything else')
59 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'for gradient clipping')
60 | 
61 | # Pointer-generator or baseline model
62 | tf.app.flags.DEFINE_boolean('pointer_gen', True, 'If True, use pointer-generator model. If False, use baseline model.')
63 | 
64 | # Coverage hyperparameters
65 | tf.app.flags.DEFINE_boolean('coverage', False, 'Use coverage mechanism. Note: the experiments reported in the ACL paper train WITHOUT coverage until convergence, and then train for a short phase WITH coverage afterwards. I.e. to reproduce the results in the ACL paper, turn this off for most of training, then turn it on for a short phase at the end.')
66 | tf.app.flags.DEFINE_float('cov_loss_wt', 1.0, 'Weight of coverage loss (lambda in the paper). If zero, there is no incentive to minimize the coverage loss.')
67 | 
68 | # Utility flags, for restoring and changing checkpoints
69 | tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False, 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. Your current training model will be copied to a new version (same name with _cov_init appended) that will be ready to run with the coverage flag turned on, for the coverage training stage.')
70 | tf.app.flags.DEFINE_boolean('restore_best_model', False, 'Restore the best model in the eval/ dir and save it in the train/ dir, ready to be used for further training. Useful for early stopping, or if your training checkpoint has become corrupted with e.g. NaN values.')
71 | 
72 | # Debugging. See https://www.tensorflow.org/programmers_guide/debugger
73 | tf.app.flags.DEFINE_boolean('debug', False, "Run in tensorflow's debug mode (watches for NaN/inf values)")
74 | 
75 | 
76 | 
77 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
78 |   """Calculate the running average loss via exponential decay.
79 |   This is used to implement early stopping w.r.t. a smoother loss curve than the raw loss curve.
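  For example, with the default decay=0.99 each new loss moves the average only
  1% of the way toward that loss (0.99 * avg + 0.01 * loss), so a single noisy
  batch barely perturbs the curve.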
80 | 
81 |   Args:
82 |     loss: loss on the most recent eval step
83 |     running_avg_loss: the running average loss so far
84 |     summary_writer: FileWriter object to write for tensorboard
85 |     step: training iteration step
86 |     decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.
87 | 
88 |   Returns:
89 |     running_avg_loss: new running average loss
90 |   """
91 |   if running_avg_loss == 0: # on the first iteration just take the loss
92 |     running_avg_loss = loss
93 |   else:
94 |     running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
95 |   running_avg_loss = min(running_avg_loss, 12) # clip
96 |   loss_sum = tf.Summary()
97 |   tag_name = 'running_avg_loss/decay=%f' % (decay)
98 |   loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
99 |   summary_writer.add_summary(loss_sum, step)
100 |   tf.logging.info('running_avg_loss: %f', running_avg_loss)
101 |   return running_avg_loss
102 | 
103 | 
104 | def restore_best_model():
105 |   """Load bestmodel file from eval directory, add variables for Adagrad, and save to train directory"""
106 |   tf.logging.info("Restoring bestmodel for training...")
107 | 
108 |   # Initialize all vars in the model
109 |   sess = tf.Session(config=util.get_config())
110 |   print("Initializing all variables...")
111 |   sess.run(tf.global_variables_initializer())
112 | 
113 |   # Restore the best model from eval dir
114 |   saver = tf.train.Saver([v for v in tf.global_variables() if "Adagrad" not in v.name])
115 |   print("Restoring all non-adagrad variables from best model in eval dir...")
116 |   curr_ckpt = util.load_ckpt(saver, sess, "eval")
117 |   print("Restored %s." % curr_ckpt)
118 | 
119 |   # Save this model to train dir and quit
120 |   new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
121 |   new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
122 |   print("Saving model to %s..." % (new_fname))
123 |   new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
124 |   new_saver.save(sess, new_fname)
125 |   print("Saved.")
126 |   exit()
127 | 
128 | 
129 | def convert_to_coverage_model():
130 |   """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
131 |   tf.logging.info("converting non-coverage model to coverage model..")
132 | 
133 |   # initialize an entire coverage model from scratch
134 |   sess = tf.Session(config=util.get_config())
135 |   print("initializing everything...")
136 |   sess.run(tf.global_variables_initializer())
137 | 
138 |   # load all non-coverage weights from checkpoint
139 |   saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
140 |   print("restoring non-coverage variables...")
141 |   curr_ckpt = util.load_ckpt(saver, sess)
142 |   print("restored.")
143 | 
144 |   # save this model and quit
145 |   new_fname = curr_ckpt + '_cov_init'
146 |   print("saving model to %s..." % (new_fname))
147 |   new_saver = tf.train.Saver() # this one will save all variables that now exist
148 |   new_saver.save(sess, new_fname)
149 |   print("saved.")
150 |   exit()
151 | 
152 | 
153 | def setup_training(model, batcher):
154 |   """Does setup before starting training (run_training)"""
155 |   train_dir = os.path.join(FLAGS.log_root, "train")
156 |   if not os.path.exists(train_dir): os.makedirs(train_dir)
157 | 
158 |   model.build_graph() # build the graph
159 |   if FLAGS.convert_to_coverage_model:
160 |     assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
161 |     convert_to_coverage_model()
162 |   if FLAGS.restore_best_model:
163 |     restore_best_model()
164 |   saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
165 | 
166 |   sv = tf.train.Supervisor(logdir=train_dir,
167 |                            is_chief=True,
168 |                            saver=saver,
169 |                            summary_op=None,
170 |                            save_summaries_secs=60, # save summaries for tensorboard every 60 secs
171 |                            save_model_secs=60, # checkpoint every 60 secs
172 |                            global_step=model.global_step)
173 |   summary_writer = sv.summary_writer
174 |   tf.logging.info("Preparing or waiting for session...")
175 |   sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config())
176 |   tf.logging.info("Created session.")
177 |   try:
178 |     run_training(model, batcher, sess_context_manager, sv, summary_writer) # this is an infinite loop until interrupted
179 |   except KeyboardInterrupt:
180 |     tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
181 |     sv.stop()
182 | 
183 | 
184 | def run_training(model, batcher, sess_context_manager, sv, summary_writer):
185 |   """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
186 |   tf.logging.info("starting run_training")
187 |   with sess_context_manager as sess:
188 |     if FLAGS.debug: # start the tensorflow debugger
189 |       sess = tf_debug.LocalCLIDebugWrapperSession(sess)
190 |       sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
191 |     while True: # repeats until interrupted
192 |       batch = batcher.next_batch()
193 | 
194 |       tf.logging.info('running training step...')
195 |       t0=time.time()
196 |       results = model.run_train_step(sess, batch)
197 |       t1=time.time()
198 |       tf.logging.info('seconds for training step: %.3f', t1-t0)
199 | 
200 |       loss = results['loss']
201 |       tf.logging.info('loss: %f', loss) # print the loss to screen
202 | 
203 |       if not np.isfinite(loss):
204 |         raise Exception("Loss is not finite. Stopping.")
205 | 
206 |       if FLAGS.coverage:
207 |         coverage_loss = results['coverage_loss']
208 |         tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen
209 | 
210 |       # get the summaries and iteration number so we can write summaries to tensorboard
211 |       summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
212 |       train_step = results['global_step'] # we need this to update our running average loss
213 | 
214 |       summary_writer.add_summary(summaries, train_step) # write the summaries
215 |       if train_step % 100 == 0: # flush the summary writer every so often
216 |         summary_writer.flush()
217 | 
218 | 
219 | def run_eval(model, batcher, vocab):
220 |   """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
221 |   model.build_graph() # build the graph
222 |   saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
223 |   sess = tf.Session(config=util.get_config())
224 |   eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
225 |   bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
226 |   summary_writer = tf.summary.FileWriter(eval_dir)
227 |   running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
228 |   best_loss = None # will hold the best loss achieved so far
229 | 
230 |   while True:
231 |     _ = util.load_ckpt(saver, sess) # load a new checkpoint
232 |     batch = batcher.next_batch() # get the next batch
233 | 
234 |     # run eval on the batch
235 |     t0=time.time()
236 |     results = model.run_eval_step(sess, batch)
237 |     t1=time.time()
238 |     tf.logging.info('seconds for batch: %.2f', t1-t0)
239 | 
240 |     # print the loss and coverage loss to screen
241 |     loss = results['loss']
242 |     tf.logging.info('loss: %f', loss)
243 |     if FLAGS.coverage:
244 |       coverage_loss = results['coverage_loss']
245 |       tf.logging.info("coverage_loss: %f", coverage_loss)
246 | 
247 |     # add summaries
248 |     summaries = results['summaries']
249 |     train_step = results['global_step']
250 |     summary_writer.add_summary(summaries, train_step)
251 | 
252 |     # calculate running avg loss
253 |     running_avg_loss = calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step)
254 | 
255 |     # If running_avg_loss is best so far, save this checkpoint (early stopping).
256 |     # These checkpoints will appear as bestmodel-<step> in the eval dir
257 |     if best_loss is None or running_avg_loss < best_loss:
258 |       tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
259 |       saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
260 |       best_loss = running_avg_loss
261 | 
262 |     # flush the summary writer every so often
263 |     if train_step % 100 == 0:
264 |       summary_writer.flush()
265 | 
266 | 
267 | def main(unused_argv):
268 |   if len(unused_argv) != 1: # raise an exception if extra, unparsed arguments were passed
269 |     raise Exception("Problem with flags: %s" % unused_argv)
270 | 
271 |   tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
272 |   tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))
273 | 
274 |   # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
275 |   FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
276 |   if not os.path.exists(FLAGS.log_root):
277 |     if FLAGS.mode=="train":
278 |       os.makedirs(FLAGS.log_root)
279 |     else:
280 |       raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))
281 | 
282 |   vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary
283 | 
284 |   # If in decode mode, set batch_size = beam_size
285 |   # Reason: in decode mode, we decode one example at a time.
286 |   # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
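  # (Illustration: with the default beam_size=4, each call to decode_onestep in
  # model.py is fed a "batch" consisting of the 4 current hypotheses and advances
  # them all in parallel, which is why batch_size must equal beam_size here.)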
287 |   if FLAGS.mode == 'decode':
288 |     FLAGS.batch_size = FLAGS.beam_size
289 | 
290 |   # If single_pass=True, check we're in decode mode
291 |   if FLAGS.single_pass and FLAGS.mode!='decode':
292 |     raise Exception("The single_pass flag should only be True in decode mode")
293 | 
294 |   # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
295 |   hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen']
296 |   hps_dict = {}
297 |   for key,val in FLAGS.__flags.items(): # for each flag
298 |     if key in hparam_list: # if it's in the list
299 |       hps_dict[key] = val # add it to the dict (note: on TF 1.5+, where __flags maps names to Flag objects, use val.value instead)
300 |   hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
301 | 
302 |   # Create a batcher object that will create minibatches of data
303 |   batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)
304 | 
305 |   tf.set_random_seed(111) # a seed value for randomness
306 | 
307 |   if hps.mode == 'train':
308 |     print("creating model...")
309 |     model = SummarizationModel(hps, vocab)
310 |     setup_training(model, batcher)
311 |   elif hps.mode == 'eval':
312 |     model = SummarizationModel(hps, vocab)
313 |     run_eval(model, batcher, vocab)
314 |   elif hps.mode == 'decode':
315 |     decode_model_hps = hps # This will be the hyperparameters for the decoder model
316 |     decode_model_hps = hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
317 |     model = SummarizationModel(decode_model_hps, vocab)
318 |     decoder = BeamSearchDecoder(model, batcher, vocab)
319 |     decoder.decode() # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
320 |   else:
321 |     raise ValueError("The 'mode' flag must be one of train/eval/decode")
322 | 
323 | if __name__ == '__main__':
324 |   tf.app.run()
325 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """This file contains some utility functions"""
18 | 
19 | import tensorflow as tf
20 | import time
21 | import os
22 | FLAGS = tf.app.flags.FLAGS
23 | 
24 | def get_config():
25 |   """Returns config for tf.session"""
26 |   config = tf.ConfigProto(allow_soft_placement=True)
27 |   config.gpu_options.allow_growth=True
28 |   return config
29 | 
30 | def load_ckpt(saver, sess, ckpt_dir="train"):
31 |   """Load checkpoint from the ckpt_dir (if unspecified, this is the train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns the checkpoint name."""
32 |   latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None
33 |   ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir) # join once, before the retry loop, so repeated failures don't re-prepend log_root
34 |   while True:
35 |     try:
36 |       ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename)
37 |       tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
38 |       saver.restore(sess, ckpt_state.model_checkpoint_path)
39 |       return ckpt_state.model_checkpoint_path
40 |     except Exception: # e.g. no checkpoint written yet; don't swallow KeyboardInterrupt with a bare except
41 |       tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10)
42 |       time.sleep(10)
43 | 
--------------------------------------------------------------------------------
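As a usage sketch (the commands below are not part of the repository, and all paths are placeholders for your own data and log locations), the flags defined in run_summarization.py are typically combined as follows: train first without coverage, run a concurrent eval job, then convert the checkpoint and finish with a short coverage phase before decoding.

python run_summarization.py --mode=train --data_path=/path/to/train_* --vocab_path=/path/to/vocab --log_root=/path/to/log --exp_name=my_experiment
python run_summarization.py --mode=eval --data_path=/path/to/val_* --vocab_path=/path/to/vocab --log_root=/path/to/log --exp_name=my_experiment
python run_summarization.py --mode=train --convert_to_coverage_model=True --coverage=True --data_path=/path/to/train_* --vocab_path=/path/to/vocab --log_root=/path/to/log --exp_name=my_experiment
python run_summarization.py --mode=decode --single_pass=True --coverage=True --data_path=/path/to/test_* --vocab_path=/path/to/vocab --log_root=/path/to/log --exp_name=my_experiment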