├── .gitignore
├── LICENSE
├── README.md
├── data_util
│   ├── __init__.py
│   ├── batcher.py
│   ├── config.py
│   ├── data.py
│   └── utils.py
├── learning_curve.png
├── learning_curve_coverage.png
├── start_decode.sh
├── start_eval.sh
├── start_train.sh
└── training_ptr_gen
    ├── __init__.py
    ├── decode.py
    ├── eval.py
    ├── model.py
    ├── train.py
    ├── train_util.py
    └── transformer_encoder.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.bin
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner.
      For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PyTorch implementation of *[Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)*

1. [Train with pointer generation and coverage loss enabled](#train-with-pointer-generation-and-coverage-loss-enabled)
2. [Training with pointer generation enabled](#training-with-pointer-generation-enabled)
3. [How to run training](#how-to-run-training)
4. [Papers using this code](#papers-using-this-code)


## Train with pointer generation and coverage loss enabled
After training for 100k iterations with coverage loss enabled (batch size 8)

```
ROUGE-1:
rouge_1_f_score: 0.3907 with confidence interval (0.3885, 0.3928)
rouge_1_recall: 0.4434 with confidence interval (0.4410, 0.4460)
rouge_1_precision: 0.3698 with confidence interval (0.3672, 0.3721)

ROUGE-2:
rouge_2_f_score: 0.1697 with confidence interval (0.1674, 0.1720)
rouge_2_recall: 0.1920 with confidence interval (0.1894, 0.1945)
rouge_2_precision: 0.1614 with confidence interval (0.1590, 0.1636)

ROUGE-l:
rouge_l_f_score: 0.3587 with confidence interval (0.3565, 0.3608)
rouge_l_recall: 0.4067 with confidence interval (0.4042, 0.4092)
rouge_l_precision: 0.3397 with confidence interval (0.3371, 0.3420)
```

![Alt text](learning_curve_coverage.png?raw=true "Learning Curve with coverage loss")

## Training with pointer generation enabled
After training for 500k iterations (batch size 8)

```
ROUGE-1:
rouge_1_f_score: 0.3500 with confidence interval (0.3477, 0.3523)
rouge_1_recall: 0.3718 with confidence interval (0.3693, 0.3745)
rouge_1_precision: 0.3529 with confidence interval (0.3501, 0.3555)

ROUGE-2:
rouge_2_f_score: 0.1486 with confidence interval (0.1465, 0.1508)
rouge_2_recall: 0.1573 with confidence interval (0.1551, 0.1597)
rouge_2_precision: 0.1506 with confidence interval (0.1483, 0.1529)

ROUGE-l:
rouge_l_f_score: 0.3202 with confidence interval (0.3179, 0.3225)
rouge_l_recall: 0.3399 with confidence interval (0.3374, 0.3426)
rouge_l_precision: 0.3231 with confidence interval (0.3205, 0.3256)
```
![Alt text](learning_curve.png?raw=true "Learning Curve with pointer generation")


## How to run training:
1) Follow the data generation instructions at https://github.com/abisee/cnn-dailymail
2) Adjust the paths and hyperparameters in data_util/config.py as needed
3) For training run start_train.sh, for decoding run start_decode.sh, and for evaluation run start_eval.sh

Note:

* In decode mode, the beam search batch contains a single example replicated to batch size (see the sketch below):
https://github.com/atulkum/pointer_summarizer/blob/master/training_ptr_gen/decode.py#L109
https://github.com/atulkum/pointer_summarizer/blob/master/data_util/batcher.py#L226

* Tested on PyTorch 0.4 with Python 2.7
* You need to set up [pyrouge](https://github.com/andersjo/pyrouge) to get the ROUGE scores
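The minimal sketch below (not part of the repo) illustrates that decode-mode batching convention using the real `Example` and `Batch` classes from `data_util/batcher.py`; the `article` and `abstract_sentences` inputs are assumed placeholders:

```python
from data_util import config, data
from data_util.batcher import Batch, Example

vocab = data.Vocab(config.vocab_path, config.vocab_size)
# article / abstract_sentences: placeholder strings for one test example
ex = Example(article, abstract_sentences, vocab)
# decode mode: one example repeated beam_size times, so the decoder's
# batch dimension can carry one row per live beam hypothesis
beam_batch = Batch([ex] * config.beam_size, vocab, config.beam_size)
# beam_batch.enc_batch has shape (beam_size, enc_len) with identical rows
```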
## Papers using this code:
1) [Automatic Program Synthesis of Long Programs with a Learned Garbage Collector](http://papers.nips.cc/paper/7479-automatic-program-synthesis-of-long-programs-with-a-learned-garbage-collector) ***NeurIPS 2018*** https://github.com/amitz25/PCCoder
2) [Automatic Fact-guided Sentence Modification](https://arxiv.org/abs/1909.13838) ***AAAI 2020*** https://github.com/darsh10/split_encoder_pointer_summarizer
3) [Resurrecting Submodularity in Neural Abstractive Summarization](https://arxiv.org/abs/1911.03014v1)
4) [StructSum: Summarization via Structured Representations](https://aclanthology.org/2021.eacl-main.220) ***EACL 2021***
5) [Concept Pointer Network for Abstractive Summarization](https://arxiv.org/abs/1910.08486) ***EMNLP'2019*** https://github.com/wprojectsn/codes
6) [PaddlePaddle version](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_summarization/pointer_summarizer)
7) [VAE-PGN based Abstractive Model in Multi-stage Architecture for Text Summarization](https://www.aclweb.org/anthology/W19-8664/) ***INLG 2019***
8) [Clickbait? Sensational Headline Generation with Auto-tuned Reinforcement Learning](https://arxiv.org/abs/1909.03582) ***EMNLP'2019*** https://github.com/HLTCHKUST/sensational_headline
9) [Abstractive Spoken Document Summarization using Hierarchical Model with Multi-stage Attention Diversity Optimization](http://www.interspeech2020.org/index.php?m=content&c=index&a=show&catid=354&id=1173) ***INTERSPEECH 2020***
10) [Nutribullets Hybrid: Multi-document Health Summarization](https://arxiv.org/abs/2104.03465) ***NAACL 2021***
11) [A Corpus of Very Short Scientific Summaries](https://aclanthology.org/2020.conll-1.12) ***CoNLL 2020***
12) [Towards Faithfulness in Open Domain Table-to-text Generation from an Entity-centric View](https://arxiv.org/abs/2102.08585) ***AAAI 2021***
13) [CDEvalSumm: An Empirical Study of Cross-Dataset Evaluation for Neural Summarization Systems](https://aclanthology.org/2020.findings-emnlp.329) ***Findings of EMNLP 2020***
14) [A Study on Seq2seq for Sentence Compression in Vietnamese](https://aclanthology.org/2020.paclic-1.56) ***PACLIC 2020***
15) [Other Roles Matter! Enhancing Role-Oriented Dialogue Summarization via Role Interactions](https://aclanthology.org/2022.acl-long.182/) ***ACL 2022***

--------------------------------------------------------------------------------
/data_util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/data_util/__init__.py
--------------------------------------------------------------------------------
/data_util/batcher.py:
--------------------------------------------------------------------------------
# Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/batcher.py

import Queue
import time
from random import shuffle
from threading import Thread

import numpy as np
import tensorflow as tf

import config
import data

import random
random.seed(1234)
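# Note (added): this module targets Python 2, as documented in the README --
# it relies on the py2 stdlib `Queue` module, `xrange`, and generator
# `.next()`; under Python 3 these become `queue`, `range`, and `next(gen)`.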

class Example(object):

    def __init__(self, article, abstract_sentences, vocab):
        # Get ids of special tokens
        start_decoding = vocab.word2id(data.START_DECODING)
        stop_decoding = vocab.word2id(data.STOP_DECODING)

        # Process the article
        article_words = article.split()
        if len(article_words) > config.max_enc_steps:
            article_words = article_words[:config.max_enc_steps]
        self.enc_len = len(article_words)  # store the length after truncation but before padding
        self.enc_input = [vocab.word2id(w) for w in article_words]  # list of word ids; OOVs are represented by the id for UNK token

        # Process the abstract
        abstract = ' '.join(abstract_sentences)  # string
        abstract_words = abstract.split()  # list of strings
        abs_ids = [vocab.word2id(w) for w in abstract_words]  # list of word ids; OOVs are represented by the id for UNK token

        # Get the decoder input sequence and target sequence
        self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
        self.dec_len = len(self.dec_input)

        # If using pointer-generator mode, we need to store some extra info
        if config.pointer_gen:
            # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOV words themselves
            self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)

            # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
            abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)

            # Overwrite decoder target sequence so it uses the temp article OOV ids
            _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)

        # Store the original strings
        self.original_article = article
        self.original_abstract = abstract
        self.original_abstract_sents = abstract_sentences

    def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
        inp = [start_id] + sequence[:]
        target = sequence[:]
        if len(inp) > max_len:  # truncate
            inp = inp[:max_len]
            target = target[:max_len]  # no end_token
        else:  # no truncation
            target.append(stop_id)  # end token
        assert len(inp) == len(target)
        return inp, target

    def pad_decoder_inp_targ(self, max_len, pad_id):
        while len(self.dec_input) < max_len:
            self.dec_input.append(pad_id)
        while len(self.target) < max_len:
            self.target.append(pad_id)

    def pad_encoder_input(self, max_len, pad_id):
        while len(self.enc_input) < max_len:
            self.enc_input.append(pad_id)
        if config.pointer_gen:
            while len(self.enc_input_extend_vocab) < max_len:
                self.enc_input_extend_vocab.append(pad_id)
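
# Worked example (added for clarity; token ids are hypothetical): with
# start_id=2, stop_id=3, max_len=4 and abstract ids [11, 12, 13],
# get_dec_inp_targ_seqs returns
#     inp    = [2, 11, 12, 13]   # decoder input starts with [START]
#     target = [11, 12, 13, 3]   # target ends with [STOP]
# With a longer abstract [11, 12, 13, 14, 15] both are truncated to length 4
# and no [STOP] is appended:
#     inp    = [2, 11, 12, 13]
#     target = [11, 12, 13, 14]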

class Batch(object):
    def __init__(self, example_list, vocab, batch_size):
        self.batch_size = batch_size
        self.pad_id = vocab.word2id(data.PAD_TOKEN)  # id of the PAD token used to pad sequences
        self.init_encoder_seq(example_list)  # initialize the input to the encoder
        self.init_decoder_seq(example_list)  # initialize the input and targets for the decoder
        self.store_orig_strings(example_list)  # store the original strings

    def init_encoder_seq(self, example_list):
        # Determine the maximum length of the encoder input sequence in this batch
        max_enc_seq_len = max([ex.enc_len for ex in example_list])

        # Pad the encoder input sequences up to the length of the longest sequence
        for ex in example_list:
            ex.pad_encoder_input(max_enc_seq_len, self.pad_id)

        # Initialize the numpy arrays
        # Note: our enc_batch can have a different length (second dimension) for each batch, because the encoder consumes dynamically sized (packed) sequences.
        self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
        self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
        self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)

        # Fill in the numpy arrays
        for i, ex in enumerate(example_list):
            self.enc_batch[i, :] = ex.enc_input[:]
            self.enc_lens[i] = ex.enc_len
            for j in xrange(ex.enc_len):
                self.enc_padding_mask[i][j] = 1

        # For pointer-generator mode, need to store some extra info
        if config.pointer_gen:
            # Determine the max number of in-article OOVs in this batch
            self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
            # Store the in-article OOVs themselves
            self.art_oovs = [ex.article_oovs for ex in example_list]
            # Store the version of the enc_batch that uses the article OOV ids
            self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
            for i, ex in enumerate(example_list):
                self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]

    def init_decoder_seq(self, example_list):
        # Pad the inputs and targets
        for ex in example_list:
            ex.pad_decoder_inp_targ(config.max_dec_steps, self.pad_id)

        # Initialize the numpy arrays.
        self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
        self.target_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
        self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
        self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)

        # Fill in the numpy arrays
        for i, ex in enumerate(example_list):
            self.dec_batch[i, :] = ex.dec_input[:]
            self.target_batch[i, :] = ex.target[:]
            self.dec_lens[i] = ex.dec_len
            for j in xrange(ex.dec_len):
                self.dec_padding_mask[i][j] = 1

    def store_orig_strings(self, example_list):
        self.original_articles = [ex.original_article for ex in example_list]  # list of strings
        self.original_abstracts = [ex.original_abstract for ex in example_list]  # list of strings
        self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list]  # list of lists of strings
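
# Pipeline overview (added): Batcher wires together two producer/consumer
# stages backed by daemon threads --
#   data.example_generator --> fill_example_queue() --> _example_queue
#   _example_queue --> fill_batch_queue() (bucketing) --> _batch_queue
# next_batch() simply pops from _batch_queue, and watch_threads() restarts
# any worker thread that dies (non-single_pass mode only).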

class Batcher(object):
    BATCH_QUEUE_MAX = 100  # max number of batches the batch_queue can hold

    def __init__(self, data_path, vocab, mode, batch_size, single_pass):
        self._data_path = data_path
        self._vocab = vocab
        self._single_pass = single_pass
        self.mode = mode
        self.batch_size = batch_size
        # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
        self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
        self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)

        # Different settings depending on whether we're in single_pass mode or not
        if single_pass:
            self._num_example_q_threads = 1  # just one thread, so we read through the dataset just once
            self._num_batch_q_threads = 1  # just one thread to batch examples
            self._bucketing_cache_size = 1  # only load one batch's worth of examples before bucketing; this essentially means no bucketing
            self._finished_reading = False  # this will tell us when we're finished reading the dataset
        else:
            self._num_example_q_threads = 1  # 16 # num threads to fill example queue
            self._num_batch_q_threads = 1  # 4 # num threads to fill batch queue
            self._bucketing_cache_size = 1  # 100 # how many batches' worth of examples to load into cache before bucketing

        # Start the threads that load the queues
        self._example_q_threads = []
        for _ in xrange(self._num_example_q_threads):
            self._example_q_threads.append(Thread(target=self.fill_example_queue))
            self._example_q_threads[-1].daemon = True
            self._example_q_threads[-1].start()
        self._batch_q_threads = []
        for _ in xrange(self._num_batch_q_threads):
            self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
            self._batch_q_threads[-1].daemon = True
            self._batch_q_threads[-1].start()

        # Start a thread that watches the other threads and restarts them if they're dead
        if not single_pass:  # We don't want a watcher in single_pass mode because the threads shouldn't run forever
            self._watch_thread = Thread(target=self.watch_threads)
            self._watch_thread.daemon = True
            self._watch_thread.start()

    def next_batch(self):
        # If the batch queue is empty, print a warning
        if self._batch_queue.qsize() == 0:
            tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
            if self._single_pass and self._finished_reading:
                tf.logging.info("Finished reading dataset in single_pass mode.")
                return None

        batch = self._batch_queue.get()  # get the next Batch
        return batch

    def fill_example_queue(self):
        input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))

        while True:
            try:
                (article, abstract) = input_gen.next()  # read the next example from file. article and abstract are both strings.
            except StopIteration:  # if there are no more examples:
                tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
                if self._single_pass:
                    tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
                    self._finished_reading = True
                    break
                else:
                    raise Exception("single_pass mode is off but the example generator is out of data; error.")

            abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)]  # Use the <s> and </s> tags in abstract to get a list of sentences.
            example = Example(article, abstract_sentences, self._vocab)  # Process into an Example.
            self._example_queue.put(example)  # place the Example in the example queue.
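
# Bucketing note (added): in non-decode modes fill_batch_queue() pulls
# batch_size * _bucketing_cache_size examples, sorts them by encoder length
# (longest first), and only then slices them into batches. Grouping
# similar-length articles keeps per-batch padding (and wasted computation)
# low; shuffling the resulting batches restores some randomness to the
# training order.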
    def fill_batch_queue(self):
        while True:
            if self.mode == 'decode':
                # beam search decode mode: a single example repeated in the batch
                ex = self._example_queue.get()
                b = [ex for _ in xrange(self.batch_size)]
                self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
            else:
                # Get bucketing_cache_size-many batches of Examples into a list, then sort
                inputs = []
                for _ in xrange(self.batch_size * self._bucketing_cache_size):
                    inputs.append(self._example_queue.get())
                inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True)  # sort by length of encoder sequence

                # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
                batches = []
                for i in xrange(0, len(inputs), self.batch_size):
                    batches.append(inputs[i:i + self.batch_size])
                if not self._single_pass:
                    shuffle(batches)
                for b in batches:  # each b is a list of Example objects
                    self._batch_queue.put(Batch(b, self._vocab, self.batch_size))

    def watch_threads(self):
        while True:
            tf.logging.info(
                'Bucket queue size: %i, Input queue size: %i',
                self._batch_queue.qsize(), self._example_queue.qsize())

            time.sleep(60)
            for idx, t in enumerate(self._example_q_threads):
                if not t.is_alive():  # if the thread is dead
                    tf.logging.error('Found example queue thread dead. Restarting.')
                    new_t = Thread(target=self.fill_example_queue)
                    self._example_q_threads[idx] = new_t
                    new_t.daemon = True
                    new_t.start()
            for idx, t in enumerate(self._batch_q_threads):
                if not t.is_alive():  # if the thread is dead
                    tf.logging.error('Found batch queue thread dead. Restarting.')
                    new_t = Thread(target=self.fill_batch_queue)
                    self._batch_q_threads[idx] = new_t
                    new_t.daemon = True
                    new_t.start()
    def text_generator(self, example_generator):
        while True:
            e = example_generator.next()  # e is a tf.Example
            try:
                article_text = e.features.feature['article'].bytes_list.value[0]  # the article text was saved under the key 'article' in the data files
                abstract_text = e.features.feature['abstract'].bytes_list.value[0]  # the abstract text was saved under the key 'abstract' in the data files
            except ValueError:
                tf.logging.error('Failed to get article or abstract from example')
                continue
            if len(article_text) == 0:  # See https://github.com/abisee/pointer-generator/issues/1
                # tf.logging.warning('Found an example with empty article text. Skipping it.')
                continue
            else:
                yield (article_text, abstract_text)
--------------------------------------------------------------------------------
/data_util/config.py:
--------------------------------------------------------------------------------
import os

root_dir = os.path.expanduser("~")

# train_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/train.bin")
train_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/chunked/train_*")
eval_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/val.bin")
decode_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/test.bin")
vocab_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/vocab")
log_root = os.path.join(root_dir, "ptr_nw/log")

# Hyperparameters
hidden_dim = 256
emb_dim = 128
batch_size = 8
max_enc_steps = 400
max_dec_steps = 100
beam_size = 4
min_dec_steps = 35
vocab_size = 50000

lr = 0.15
adagrad_init_acc = 0.1
rand_unif_init_mag = 0.02
trunc_norm_init_std = 1e-4
max_grad_norm = 2.0

pointer_gen = True
is_coverage = False
cov_loss_wt = 1.0

eps = 1e-12
max_iterations = 500000

use_gpu = True

lr_coverage = 0.15
--------------------------------------------------------------------------------
/data_util/data.py:
--------------------------------------------------------------------------------
# Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/data.py

import glob
import random
import struct
import csv
from tensorflow.core.example import example_pb2

# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

PAD_TOKEN = '[PAD]'  # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNKNOWN_TOKEN = '[UNK]'  # This has a vocab id, which is used to represent out-of-vocabulary words
START_DECODING = '[START]'  # This has a vocab id, which is used at the start of every decoder input sequence
STOP_DECODING = '[STOP]'  # This has a vocab id, which is used at the end of untruncated target sequences

# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
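
# Resulting id layout (added; assumes the defaults in config.py): the four
# special tokens occupy ids 0-3 and vocab-file words follow, so with
# vocab_size = 50000 the in-vocab ids are
#     0 [UNK]   1 [PAD]   2 [START]   3 [STOP]   4..49999 vocab words
# Ids >= vocab.size() are reserved at run time for per-article OOVs
# (see article2ids below).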

class Vocab(object):

    def __init__(self, vocab_file, max_size):
        self._word_to_id = {}
        self._id_to_word = {}
        self._count = 0  # keeps track of total number of words in the Vocab

        # [UNK], [PAD], [START] and [STOP] get the ids 0, 1, 2, 3.
        for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
            self._word_to_id[w] = self._count
            self._id_to_word[self._count] = w
            self._count += 1

        # Read the vocab file and add words up to max_size
        with open(vocab_file, 'r') as vocab_f:
            for line in vocab_f:
                pieces = line.split()
                if len(pieces) != 2:
                    print 'Warning: incorrectly formatted line in vocabulary file: %s\n' % line
                    continue
                w = pieces[0]
                if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
                    raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
                if w in self._word_to_id:
                    raise Exception('Duplicated word in vocabulary file: %s' % w)
                self._word_to_id[w] = self._count
                self._id_to_word[self._count] = w
                self._count += 1
                if max_size != 0 and self._count >= max_size:
                    print "max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)
                    break

        print "Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count - 1])

    def word2id(self, word):
        if word not in self._word_to_id:
            return self._word_to_id[UNKNOWN_TOKEN]
        return self._word_to_id[word]

    def id2word(self, word_id):
        if word_id not in self._id_to_word:
            raise ValueError('Id not found in vocab: %d' % word_id)
        return self._id_to_word[word_id]

    def size(self):
        return self._count

    def write_metadata(self, fpath):
        print "Writing word embedding metadata file to %s..." % (fpath)
        with open(fpath, "w") as f:
            fieldnames = ['word']
            writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
            for i in xrange(self.size()):
                writer.writerow({"word": self._id_to_word[i]})


def example_generator(data_path, single_pass):
    while True:
        filelist = glob.glob(data_path)  # get the list of datafiles
        assert filelist, ('Error: Empty filelist at %s' % data_path)  # check filelist isn't empty
        if single_pass:
            filelist = sorted(filelist)
        else:
            random.shuffle(filelist)
        for f in filelist:
            reader = open(f, 'rb')
            while True:
                len_bytes = reader.read(8)
                if not len_bytes: break  # finished reading this file
                str_len = struct.unpack('q', len_bytes)[0]
                example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
                yield example_pb2.Example.FromString(example_str)
        if single_pass:
            print "example_generator completed reading all datafiles. No more data."
            break
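
# Worked example (added; words are hypothetical): with vocab.size() == 50000
# and the article "the quick zyzzyva ate a xylograph", where "zyzzyva" and
# "xylograph" are OOV, article2ids returns
#     ids  = [id(the), id(quick), 50000, id(ate), id(a), 50001]
#     oovs = ['zyzzyva', 'xylograph']
# abstract2ids maps abstract OOVs that also appear in the article to the same
# temporary ids (50000, 50001) and everything else to [UNK]; these extended
# ids are what let the pointer mechanism copy article-specific words.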

def article2ids(article_words, vocab):
    ids = []
    oovs = []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    for w in article_words:
        i = vocab.word2id(w)
        if i == unk_id:  # If w is OOV
            if w not in oovs:  # Add to list of OOVs
                oovs.append(w)
            oov_num = oovs.index(w)  # This is 0 for the first article OOV, 1 for the second article OOV...
            ids.append(vocab.size() + oov_num)  # This is e.g. 50000 for the first article OOV, 50001 for the second...
        else:
            ids.append(i)
    return ids, oovs


def abstract2ids(abstract_words, vocab, article_oovs):
    ids = []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    for w in abstract_words:
        i = vocab.word2id(w)
        if i == unk_id:  # If w is an OOV word
            if w in article_oovs:  # If w is an in-article OOV
                vocab_idx = vocab.size() + article_oovs.index(w)  # Map to its temporary article OOV number
                ids.append(vocab_idx)
            else:  # If w is an out-of-article OOV
                ids.append(unk_id)  # Map to the UNK token id
        else:
            ids.append(i)
    return ids


def outputids2words(id_list, vocab, article_oovs):
    words = []
    for i in id_list:
        try:
            w = vocab.id2word(i)  # might be [UNK]
        except ValueError as e:  # w is OOV
            assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
            article_oov_idx = i - vocab.size()
            try:
                w = article_oovs[article_oov_idx]
            except IndexError as e:  # i doesn't correspond to an article oov (list indexing raises IndexError, not ValueError)
                raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
        words.append(w)
    return words


def abstract2sents(abstract):
    cur = 0
    sents = []
    while True:
        try:
            start_p = abstract.index(SENTENCE_START, cur)
            end_p = abstract.index(SENTENCE_END, start_p + 1)
            cur = end_p + len(SENTENCE_END)
            sents.append(abstract[start_p + len(SENTENCE_START):end_p])
        except ValueError as e:  # no more sentences
            return sents
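
# Example (added): abstract2sents("<s> hello . </s> <s> world . </s>")
# returns [' hello . ', ' world . '] -- callers strip the surrounding
# whitespace, as fill_example_queue() does in batcher.py.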

def show_art_oovs(article, vocab):
    unk_token = vocab.word2id(UNKNOWN_TOKEN)
    words = article.split(' ')
    words = [("__%s__" % w) if vocab.word2id(w) == unk_token else w for w in words]
    out_str = ' '.join(words)
    return out_str


def show_abs_oovs(abstract, vocab, article_oovs):
    unk_token = vocab.word2id(UNKNOWN_TOKEN)
    words = abstract.split(' ')
    new_words = []
    for w in words:
        if vocab.word2id(w) == unk_token:  # w is oov
            if article_oovs is None:  # baseline mode
                new_words.append("__%s__" % w)
            else:  # pointer-generator mode
                if w in article_oovs:
                    new_words.append("__%s__" % w)
                else:
                    new_words.append("!!__%s__!!" % w)
        else:  # w is in-vocab word
            new_words.append(w)
    out_str = ' '.join(new_words)
    return out_str
--------------------------------------------------------------------------------
/data_util/utils.py:
--------------------------------------------------------------------------------
# Content of this file is copied from https://github.com/abisee/pointer-generator/blob/master/
import os
import pyrouge
import logging
import tensorflow as tf

def print_results(article, abstract, decoded_output):
    print("")
    print('ARTICLE: %s' % article)
    print('REFERENCE SUMMARY: %s' % abstract)
    print('GENERATED SUMMARY: %s' % decoded_output)
    print("")


def make_html_safe(s):
    # str.replace returns a new string, so the result must be reassigned
    s = s.replace("<", "&lt;")
    s = s.replace(">", "&gt;")
    return s


def rouge_eval(ref_dir, dec_dir):
    r = pyrouge.Rouge155()
    r.model_filename_pattern = '#ID#_reference.txt'
    r.system_filename_pattern = '(\d+)_decoded.txt'
    r.model_dir = ref_dir
    r.system_dir = dec_dir
    logging.getLogger('global').setLevel(logging.WARNING)  # silence pyrouge logging
    rouge_results = r.convert_and_evaluate()
    return r.output_to_dict(rouge_results)


def rouge_log(results_dict, dir_to_write):
    log_str = ""
    for x in ["1", "2", "l"]:
        log_str += "\nROUGE-%s:\n" % x
        for y in ["f_score", "recall", "precision"]:
            key = "rouge_%s_%s" % (x, y)
            key_cb = key + "_cb"
            key_ce = key + "_ce"
            val = results_dict[key]
            val_cb = results_dict[key_cb]
            val_ce = results_dict[key_ce]
            log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
    print(log_str)
    results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
    print("Writing final ROUGE results to %s..." % (results_file))
    with open(results_file, "w") as f:
        f.write(log_str)


def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
    if running_avg_loss == 0:  # on the first iteration just take the loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip
    loss_sum = tf.Summary()
    tag_name = 'running_avg_loss/decay=%f' % (decay)
    loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
    summary_writer.add_summary(loss_sum, step)
    return running_avg_loss
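
# Note (added): calc_running_avg_loss is an exponential moving average,
#     avg_t = 0.99 * avg_{t-1} + 0.01 * loss_t
# e.g. avg = 5.0 and loss = 4.0 gives 0.99 * 5.0 + 0.01 * 4.0 = 4.99; the
# min(..., 12) clip only tames the curve early in training, when a single
# huge loss would otherwise dominate the plot.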

def write_for_rouge(reference_sents, decoded_words, ex_index,
                    _rouge_ref_dir, _rouge_dec_dir):
    decoded_sents = []
    while len(decoded_words) > 0:
        try:
            fst_period_idx = decoded_words.index(".")
        except ValueError:
            fst_period_idx = len(decoded_words)
        sent = decoded_words[:fst_period_idx + 1]
        decoded_words = decoded_words[fst_period_idx + 1:]
        decoded_sents.append(' '.join(sent))

    # pyrouge calls a perl script that puts the data into HTML files.
    # Therefore we need to make our output HTML safe.
    decoded_sents = [make_html_safe(w) for w in decoded_sents]
    reference_sents = [make_html_safe(w) for w in reference_sents]

    ref_file = os.path.join(_rouge_ref_dir, "%06d_reference.txt" % ex_index)
    decoded_file = os.path.join(_rouge_dec_dir, "%06d_decoded.txt" % ex_index)

    with open(ref_file, "w") as f:
        for idx, sent in enumerate(reference_sents):
            f.write(sent) if idx == len(reference_sents) - 1 else f.write(sent + "\n")
    with open(decoded_file, "w") as f:
        for idx, sent in enumerate(decoded_sents):
            f.write(sent) if idx == len(decoded_sents) - 1 else f.write(sent + "\n")

    # print("Wrote example %i to file" % ex_index)
--------------------------------------------------------------------------------
/learning_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/learning_curve.png
--------------------------------------------------------------------------------
/learning_curve_coverage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/learning_curve_coverage.png
--------------------------------------------------------------------------------
/start_decode.sh:
--------------------------------------------------------------------------------
export PYTHONPATH=`pwd`
MODEL=$1
python training_ptr_gen/decode.py $MODEL >& ../log/decode_log &
--------------------------------------------------------------------------------
/start_eval.sh:
--------------------------------------------------------------------------------
export PYTHONPATH=`pwd`
MODEL_PATH=$1
MODEL_NAME=$(basename $MODEL_PATH)
python training_ptr_gen/eval.py $MODEL_PATH >& ../log/eval_log.$MODEL_NAME &
--------------------------------------------------------------------------------
/start_train.sh:
--------------------------------------------------------------------------------
export PYTHONPATH=`pwd`
python training_ptr_gen/train.py >& ../log/training_log &
--------------------------------------------------------------------------------
/training_ptr_gen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/training_ptr_gen/__init__.py
--------------------------------------------------------------------------------
/training_ptr_gen/decode.py:
--------------------------------------------------------------------------------
# Except for the pytorch part, content of this file is copied from https://github.com/abisee/pointer-generator/blob/master/

from __future__ import unicode_literals, print_function, division

import sys

reload(sys)
sys.setdefaultencoding('utf8')

import os
import time

import torch
from torch.autograd import Variable

from data_util.batcher import Batcher
from data_util.data import Vocab
from data_util import data, config
from model import Model
from data_util.utils import write_for_rouge, rouge_eval, rouge_log
from train_util import get_input_from_batch


use_cuda = config.use_gpu and torch.cuda.is_available()
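
# Note (added): beams are ranked by avg_log_prob -- the sum of token
# log-probabilities divided by hypothesis length -- rather than by the raw
# sum. Log-probabilities are negative, so an unnormalized sum would always
# favor shorter hypotheses; dividing by length removes that bias.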

class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens
        self.log_probs = log_probs
        self.state = state
        self.context = context
        self.coverage = coverage

    def extend(self, token, log_prob, state, context, coverage):
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state,
                    context=context,
                    coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        return sum(self.log_probs) / len(self.tokens)


class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
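
    # beam_search walkthrough (added): each step
    #   1. feeds the latest token of every live beam through the decoder in
    #      one batched call (the batch dimension IS the beam dimension),
    #   2. takes the top 2 * beam_size continuations per beam,
    #   3. re-ranks all candidates by avg_log_prob, moving hypotheses that
    #      emit [STOP] after min_dec_steps into `results`,
    #   4. keeps the best beam_size unfinished hypotheses and repeats until
    #      max_dec_steps or until beam_size results are collected.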
    def beam_search(self, batch):
        # batch should have only one example, replicated beam_size times
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: it has beam_size examples; initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context=c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in xrange(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in xrange(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in xrange(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]


if __name__ == '__main__':
    model_filename = sys.argv[1]
    beam_search_processor = BeamSearch(model_filename)
    beam_search_processor.decode()
--------------------------------------------------------------------------------
/training_ptr_gen/eval.py:
--------------------------------------------------------------------------------
from __future__ import unicode_literals, print_function, division

import os
import time
import sys

import tensorflow as tf
import torch

from data_util import config
from data_util.batcher import Batcher
from data_util.data import Vocab

from data_util.utils import calc_running_avg_loss
from train_util import get_input_from_batch, get_output_from_batch
from model import Model

use_cuda = config.use_gpu and torch.cuda.is_available()
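
# Note (added): evaluation below mirrors the training objective rather than
# ROUGE -- the decoder is teacher-forced with the gold abstract and the
# per-step loss is the negative log-likelihood of the gold token,
#     loss_t = -log(p(y_t*) + eps),
# plus cov_loss_wt * sum(min(attention_t, coverage_t)) when coverage is on;
# padding steps are masked out and the sum is normalized by abstract length.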

class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
                               batch_size=config.batch_size, single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_file_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()  # loss.data[0] is deprecated for 0-dim tensors

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batches: %.2f, loss: %f' % (
                    iter, print_interval, time.time() - start, running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()


if __name__ == '__main__':
    model_filename = sys.argv[1]
    eval_processor = Evaluate(model_filename)
    eval_processor.run_eval()
--------------------------------------------------------------------------------
/training_ptr_gen/model.py:
--------------------------------------------------------------------------------
from __future__ import unicode_literals, print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from data_util import config
from numpy import random

use_cuda = config.use_gpu and torch.cuda.is_available()

random.seed(123)
torch.manual_seed(123)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(123)
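
# Note (added): PyTorch packs each LSTM bias as the concatenation
# [b_i | b_f | b_g | b_o], four gate blocks of hidden_dim each, so in
# init_lstm_wt below the slice bias[n//4:n//2] is exactly the forget-gate
# block. Initializing it to 1 keeps the forget gate open early in training,
# a standard LSTM initialization trick.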
/training_ptr_gen/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 | from data_util import config
8 | from numpy import random
9 | 
10 | use_cuda = config.use_gpu and torch.cuda.is_available()
11 | 
12 | random.seed(123)
13 | torch.manual_seed(123)
14 | if torch.cuda.is_available():
15 |     torch.cuda.manual_seed_all(123)
16 | 
17 | def init_lstm_wt(lstm):
18 |     for names in lstm._all_weights:
19 |         for name in names:
20 |             if name.startswith('weight_'):
21 |                 wt = getattr(lstm, name)
22 |                 wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
23 |             elif name.startswith('bias_'):
24 |                 # set forget bias to 1
25 |                 bias = getattr(lstm, name)
26 |                 n = bias.size(0)
27 |                 start, end = n // 4, n // 2
28 |                 bias.data.fill_(0.)
29 |                 bias.data[start:end].fill_(1.)
30 | 
31 | def init_linear_wt(linear):
32 |     linear.weight.data.normal_(std=config.trunc_norm_init_std)
33 |     if linear.bias is not None:
34 |         linear.bias.data.normal_(std=config.trunc_norm_init_std)
35 | 
36 | def init_wt_normal(wt):
37 |     wt.data.normal_(std=config.trunc_norm_init_std)
38 | 
39 | def init_wt_unif(wt):
40 |     wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
41 | 
42 | class Encoder(nn.Module):
43 |     def __init__(self):
44 |         super(Encoder, self).__init__()
45 |         self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
46 |         init_wt_normal(self.embedding.weight)
47 | 
48 |         self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
49 |         init_lstm_wt(self.lstm)
50 | 
51 |         self.W_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)
52 | 
53 |     # seq_lens should be in descending order
54 |     def forward(self, input, seq_lens):
55 |         embedded = self.embedding(input)
56 | 
57 |         packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
58 |         output, hidden = self.lstm(packed)
59 | 
60 |         encoder_outputs, _ = pad_packed_sequence(output, batch_first=True)  # h dim = B x t_k x n
61 |         encoder_outputs = encoder_outputs.contiguous()
62 | 
63 |         encoder_feature = encoder_outputs.view(-1, 2 * config.hidden_dim)  # B * t_k x 2*hidden_dim
64 |         encoder_feature = self.W_h(encoder_feature)
65 | 
66 |         return encoder_outputs, encoder_feature, hidden
67 | 
68 | class ReduceState(nn.Module):
69 |     def __init__(self):
70 |         super(ReduceState, self).__init__()
71 | 
72 |         self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
73 |         init_linear_wt(self.reduce_h)
74 |         self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
75 |         init_linear_wt(self.reduce_c)
76 | 
77 |     def forward(self, hidden):
78 |         h, c = hidden  # h, c dim = 2 x b x hidden_dim
79 |         h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
80 |         hidden_reduced_h = F.relu(self.reduce_h(h_in))
81 |         c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
82 |         hidden_reduced_c = F.relu(self.reduce_c(c_in))
83 | 
84 |         return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0))  # h, c dim = 1 x b x hidden_dim
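# --- Editor's sketch, not part of model.py: why the transpose/view in ReduceState
# works. The bidirectional LSTM returns h of shape (2, B, hidden_dim), one slice
# per direction; transposing to (B, 2, hidden_dim) and flattening yields
# (B, 2*hidden_dim), i.e. the forward and backward final states concatenated per
# example. A standalone check with hypothetical sizes B=4, hidden_dim=8:
#
#   import torch
#   h = torch.randn(2, 4, 8)
#   h_in = h.transpose(0, 1).contiguous().view(-1, 2 * 8)  # -> (4, 16)
#   assert torch.equal(h_in[0, :8], h[0, 0]) and torch.equal(h_in[0, 8:], h[1, 0])
#
# Relatedly, the slice bias.data[n // 4:n // 2] in init_lstm_wt targets the forget
# gate because PyTorch packs LSTM biases as [input | forget | cell | output] chunks.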
85 | 
86 | class Attention(nn.Module):
87 |     def __init__(self):
88 |         super(Attention, self).__init__()
89 |         # attention
90 |         if config.is_coverage:
91 |             self.W_c = nn.Linear(1, config.hidden_dim * 2, bias=False)
92 |         self.decode_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
93 |         self.v = nn.Linear(config.hidden_dim * 2, 1, bias=False)
94 | 
95 |     def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
96 |         b, t_k, n = list(encoder_outputs.size())
97 | 
98 |         dec_fea = self.decode_proj(s_t_hat)  # B x 2*hidden_dim
99 |         dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous()  # B x t_k x 2*hidden_dim
100 |         dec_fea_expanded = dec_fea_expanded.view(-1, n)  # B * t_k x 2*hidden_dim
101 | 
102 |         att_features = encoder_feature + dec_fea_expanded  # B * t_k x 2*hidden_dim
103 |         if config.is_coverage:
104 |             coverage_input = coverage.view(-1, 1)  # B * t_k x 1
105 |             coverage_feature = self.W_c(coverage_input)  # B * t_k x 2*hidden_dim
106 |             att_features = att_features + coverage_feature
107 | 
108 |         e = torch.tanh(att_features)  # B * t_k x 2*hidden_dim
109 |         scores = self.v(e)  # B * t_k x 1
110 |         scores = scores.view(-1, t_k)  # B x t_k
111 | 
112 |         attn_dist_ = F.softmax(scores, dim=1) * enc_padding_mask  # B x t_k
113 |         normalization_factor = attn_dist_.sum(1, keepdim=True)
114 |         attn_dist = attn_dist_ / normalization_factor
115 | 
116 |         attn_dist = attn_dist.unsqueeze(1)  # B x 1 x t_k
117 |         c_t = torch.bmm(attn_dist, encoder_outputs)  # B x 1 x n
118 |         c_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dim
119 | 
120 |         attn_dist = attn_dist.view(-1, t_k)  # B x t_k
121 | 
122 |         if config.is_coverage:
123 |             coverage = coverage.view(-1, t_k)
124 |             coverage = coverage + attn_dist
125 | 
126 |         return c_t, attn_dist, coverage
127 | 
128 | class Decoder(nn.Module):
129 |     def __init__(self):
130 |         super(Decoder, self).__init__()
131 |         self.attention_network = Attention()
132 |         # decoder
133 |         self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
134 |         init_wt_normal(self.embedding.weight)
135 | 
136 |         self.x_context = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)
137 | 
138 |         self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=False)
139 |         init_lstm_wt(self.lstm)
140 | 
141 |         if config.pointer_gen:
142 |             self.p_gen_linear = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
143 | 
144 |         # p_vocab
145 |         self.out1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
146 |         self.out2 = nn.Linear(config.hidden_dim, config.vocab_size)
147 |         init_linear_wt(self.out2)
148 | 
149 |     def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
150 |                 c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, step):
151 | 
152 |         if not self.training and step == 0:
153 |             h_decoder, c_decoder = s_t_1
154 |             s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
155 |                                  c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
156 |             c_t, _, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,
157 |                                                            enc_padding_mask, coverage)
158 |             coverage = coverage_next
159 | 
160 |         y_t_1_embd = self.embedding(y_t_1)
161 |         x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
162 |         lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)
163 | 
164 |         h_decoder, c_decoder = s_t
165 |         s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
166 |                              c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
167 |         c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,
168 |                                                                enc_padding_mask, coverage)
169 | 
170 |         if self.training or step > 0:
171 |             coverage = coverage_next
172 | 
173 |         p_gen = None
174 |         if config.pointer_gen:
175 |             p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
176 |             p_gen = self.p_gen_linear(p_gen_input)
177 |             p_gen = torch.sigmoid(p_gen)
178 | 
179 |         output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1)  # B x hidden_dim * 3
180 |         output = self.out1(output)  # B x hidden_dim
181 | 
182 |         # output = F.relu(output)
183 | 
184 |         output = self.out2(output)  # B x vocab_size
185 |         vocab_dist = F.softmax(output, dim=1)
186 | 
187 |         if config.pointer_gen:
188 |             vocab_dist_ = p_gen * vocab_dist
189 |             attn_dist_ = (1 - p_gen) * attn_dist
190 | 
191 |             if extra_zeros is not None:
192 |                 vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
193 | 
194 |             final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
195 |         else:
196 |             final_dist = vocab_dist
197 | 
198 |         return final_dist, s_t, c_t, attn_dist, p_gen, coverage
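# --- Editor's sketch, not part of model.py: the pointer-generator mixture above in
# miniature. With a 4-word vocabulary, one source OOV (extended id 4) and p_gen = 0.8,
# scatter_add adds each source position's copy probability onto the id that position
# holds in the extended vocabulary:
#
#   import torch
#   vocab_dist = torch.tensor([[0.5, 0.2, 0.2, 0.1]])   # P_vocab, batch of 1
#   attn_dist = torch.tensor([[0.6, 0.4]])              # attention over 2 source words
#   src_ids = torch.tensor([[2, 4]])                    # word 2 in-vocab, word 4 an OOV
#   p_gen = 0.8
#   extended = torch.cat([p_gen * vocab_dist, torch.zeros(1, 1)], 1)
#   final = extended.scatter_add(1, src_ids, (1 - p_gen) * attn_dist)
#   # final == [[0.40, 0.16, 0.28, 0.08, 0.08]], which still sums to 1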
199 | 
200 | class Model(object):
201 |     def __init__(self, model_file_path=None, is_eval=False):
202 |         encoder = Encoder()
203 |         decoder = Decoder()
204 |         reduce_state = ReduceState()
205 | 
206 |         # share the embedding between encoder and decoder
207 |         decoder.embedding.weight = encoder.embedding.weight
208 |         if is_eval:
209 |             encoder = encoder.eval()
210 |             decoder = decoder.eval()
211 |             reduce_state = reduce_state.eval()
212 | 
213 |         if use_cuda:
214 |             encoder = encoder.cuda()
215 |             decoder = decoder.cuda()
216 |             reduce_state = reduce_state.cuda()
217 | 
218 |         self.encoder = encoder
219 |         self.decoder = decoder
220 |         self.reduce_state = reduce_state
221 | 
222 |         if model_file_path is not None:
223 |             state = torch.load(model_file_path, map_location=lambda storage, location: storage)
224 |             self.encoder.load_state_dict(state['encoder_state_dict'])
225 |             self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
226 |             self.reduce_state.load_state_dict(state['reduce_state_dict'])
227 | 
--------------------------------------------------------------------------------
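Editor's sketch (not part of the repository): when config.is_coverage is on, train.py and eval.py add See et al.'s (2017) coverage penalty, sum_i min(a_i^t, c_i^t), which grows exactly where the current attention re-visits already-covered source positions. A minimal numeric illustration:

import torch

# Attention at the current step, and coverage (the sum of all past attention
# distributions), for a single example with three source positions.
attn_dist = torch.tensor([[0.7, 0.2, 0.1]])
coverage = torch.tensor([[0.9, 0.1, 0.0]])

step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
print(step_coverage_loss)  # tensor([0.8000]): 0.7 + 0.1 + 0.0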
/training_ptr_gen/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 | 
3 | import os
4 | import time
5 | import argparse
6 | 
7 | import tensorflow as tf
8 | import torch
9 | from model import Model
10 | from torch.nn.utils import clip_grad_norm_
11 | 
12 | from torch.optim import Adagrad
13 | 
14 | from data_util import config
15 | from data_util.batcher import Batcher
16 | from data_util.data import Vocab
17 | from data_util.utils import calc_running_avg_loss
18 | from train_util import get_input_from_batch, get_output_from_batch
19 | 
20 | use_cuda = config.use_gpu and torch.cuda.is_available()
21 | 
22 | class Train(object):
23 |     def __init__(self):
24 |         self.vocab = Vocab(config.vocab_path, config.vocab_size)
25 |         self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
26 |                                batch_size=config.batch_size, single_pass=False)
27 |         time.sleep(15)
28 | 
29 |         train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
30 |         if not os.path.exists(train_dir):
31 |             os.mkdir(train_dir)
32 | 
33 |         self.model_dir = os.path.join(train_dir, 'model')
34 |         if not os.path.exists(self.model_dir):
35 |             os.mkdir(self.model_dir)
36 | 
37 |         self.summary_writer = tf.summary.FileWriter(train_dir)
38 | 
39 |     def save_model(self, running_avg_loss, iter):
40 |         state = {
41 |             'iter': iter,
42 |             'encoder_state_dict': self.model.encoder.state_dict(),
43 |             'decoder_state_dict': self.model.decoder.state_dict(),
44 |             'reduce_state_dict': self.model.reduce_state.state_dict(),
45 |             'optimizer': self.optimizer.state_dict(),
46 |             'current_loss': running_avg_loss
47 |         }
48 |         model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
49 |         torch.save(state, model_save_path)
50 | 
51 |     def setup_train(self, model_file_path=None):
52 |         self.model = Model(model_file_path)
53 | 
54 |         params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
55 |                  list(self.model.reduce_state.parameters())
56 |         initial_lr = config.lr_coverage if config.is_coverage else config.lr
57 |         self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
58 | 
59 |         start_iter, start_loss = 0, 0
60 | 
61 |         if model_file_path is not None:
62 |             state = torch.load(model_file_path, map_location=lambda storage, location: storage)
63 |             start_iter = state['iter']
64 |             start_loss = state['current_loss']
65 | 
66 |             if not config.is_coverage:
67 |                 self.optimizer.load_state_dict(state['optimizer'])
68 |                 if use_cuda:
69 |                     for opt_state in self.optimizer.state.values():
70 |                         for k, v in opt_state.items():
71 |                             if torch.is_tensor(v):
72 |                                 opt_state[k] = v.cuda()
73 | 
74 |         return start_iter, start_loss
75 | 
76 |     def train_one_batch(self, batch):
77 |         enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
78 |             get_input_from_batch(batch, use_cuda)
79 |         dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
80 |             get_output_from_batch(batch, use_cuda)
81 | 
82 |         self.optimizer.zero_grad()
83 | 
84 |         encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
85 |         s_t_1 = self.model.reduce_state(encoder_hidden)
86 | 
87 |         step_losses = []
88 |         for di in range(min(max_dec_len, config.max_dec_steps)):
89 |             y_t_1 = dec_batch[:, di]  # Teacher forcing
90 |             final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
91 |                 encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
92 |                 extra_zeros, enc_batch_extend_vocab,
93 |                 coverage, di)
94 |             target = target_batch[:, di]
95 |             gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
96 |             step_loss = -torch.log(gold_probs + config.eps)
97 |             if config.is_coverage:
98 |                 step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
99 |                 step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
100 |                 coverage = next_coverage
101 | 
102 |             step_mask = dec_padding_mask[:, di]
103 |             step_loss = step_loss * step_mask
104 |             step_losses.append(step_loss)
105 | 
106 |         sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
107 |         batch_avg_loss = sum_losses / dec_lens_var
108 |         loss = torch.mean(batch_avg_loss)
109 | 
110 |         loss.backward()
111 | 
112 |         self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
113 |         clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
114 |         clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)
115 | 
116 |         self.optimizer.step()
117 | 
118 |         return loss.item()
119 | 
120 |     def trainIters(self, n_iters, model_file_path=None):
121 |         iter, running_avg_loss = self.setup_train(model_file_path)
122 |         start = time.time()
123 |         while iter < n_iters:
124 |             batch = self.batcher.next_batch()
125 |             loss = self.train_one_batch(batch)
126 | 
127 |             running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
128 |             iter += 1
129 | 
130 |             if iter % 100 == 0:
131 |                 self.summary_writer.flush()
132 |             print_interval = 1000
133 |             if iter % print_interval == 0:
134 |                 print('steps %d, seconds for %d batch: %.2f, loss: %f' % (iter, print_interval,
135 |                     time.time() - start, loss))
136 |                 start = time.time()
137 |             if iter % 5000 == 0:
138 |                 self.save_model(running_avg_loss, iter)
139 | 
140 | if __name__ == '__main__':
141 |     parser = argparse.ArgumentParser(description="Train script")
142 |     parser.add_argument("-m",
143 |                         dest="model_file_path",
144 |                         required=False,
145 |                         default=None,
146 |                         help="Model file for retraining (default: None).")
147 |     args = parser.parse_args()
148 | 
149 |     train_processor = Train()
150 |     train_processor.trainIters(config.max_iterations, args.model_file_path)
151 | 
--------------------------------------------------------------------------------
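Editor's sketch (not part of the repository): train_one_batch above computes a length-normalized negative log-likelihood; each step's -log P(gold) is zeroed where the target is padding, the survivors are summed, and the sum is divided by the true decoder length. The same masking with hypothetical numbers:

import torch

# -log P(gold) for 3 decoder steps of one example, of which 2 are real tokens.
step_losses = [torch.tensor([1.2]), torch.tensor([0.8]), torch.tensor([5.0])]
dec_padding_mask = torch.tensor([[1., 1., 0.]])  # third step is padding
dec_lens_var = torch.tensor([2.])

masked = [l * dec_padding_mask[:, i] for i, l in enumerate(step_losses)]
sum_losses = torch.sum(torch.stack(masked, 1), 1)  # 1.2 + 0.8 = 2.0
batch_avg_loss = sum_losses / dec_lens_var         # 1.0 per real token
print(batch_avg_loss)  # tensor([1.])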
/training_ptr_gen/train_util.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | import torch
4 | from data_util import config
5 | 
6 | def get_input_from_batch(batch, use_cuda):
7 |     batch_size = len(batch.enc_lens)
8 | 
9 |     enc_batch = Variable(torch.from_numpy(batch.enc_batch).long())
10 |     enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float()
11 |     enc_lens = batch.enc_lens
12 |     extra_zeros = None
13 |     enc_batch_extend_vocab = None
14 | 
15 |     if config.pointer_gen:
16 |         enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long())
17 |         # max_art_oovs is the max over all the article OOV lists in the batch
18 |         if batch.max_art_oovs > 0:
19 |             extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs)))
20 | 
21 |     c_t_1 = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
22 | 
23 |     coverage = None
24 |     if config.is_coverage:
25 |         coverage = Variable(torch.zeros(enc_batch.size()))
26 | 
27 |     if use_cuda:
28 |         enc_batch = enc_batch.cuda()
29 |         enc_padding_mask = enc_padding_mask.cuda()
30 | 
31 |         if enc_batch_extend_vocab is not None:
32 |             enc_batch_extend_vocab = enc_batch_extend_vocab.cuda()
33 |         if extra_zeros is not None:
34 |             extra_zeros = extra_zeros.cuda()
35 |         c_t_1 = c_t_1.cuda()
36 | 
37 |         if coverage is not None:
38 |             coverage = coverage.cuda()
39 | 
40 |     return enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage
41 | 
42 | def get_output_from_batch(batch, use_cuda):
43 |     dec_batch = Variable(torch.from_numpy(batch.dec_batch).long())
44 |     dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float()
45 |     dec_lens = batch.dec_lens
46 |     max_dec_len = np.max(dec_lens)
47 |     dec_lens_var = Variable(torch.from_numpy(dec_lens)).float()
48 | 
49 |     target_batch = Variable(torch.from_numpy(batch.target_batch)).long()
50 | 
51 |     if use_cuda:
52 |         dec_batch = dec_batch.cuda()
53 |         dec_padding_mask = dec_padding_mask.cuda()
54 |         dec_lens_var = dec_lens_var.cuda()
55 |         target_batch = target_batch.cuda()
56 | 
57 | 
58 |     return dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch
59 | 
60 | 
--------------------------------------------------------------------------------
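Editor's sketch (not part of the repository): enc_batch_extend_vocab and extra_zeros exist so the decoder can copy source words that fall outside the fixed vocabulary. The batcher (data_util/batcher.py, not shown here) assigns each in-article OOV a temporary id of vocab_size + k, and extra_zeros widens the vocabulary distribution so scatter_add has a slot for each such id. The idea with a hypothetical 5-word vocabulary:

import torch

vocab_size = 5
word2id = {'the': 0, '[UNK]': 1}
src_words = ['the', 'rocket', 'exploded']  # 'rocket' and 'exploded' are OOV

article_oovs = []
enc_extend = []
for w in src_words:
    if w in word2id:
        enc_extend.append(word2id[w])
    else:
        if w not in article_oovs:
            article_oovs.append(w)  # first OOV -> id 5, second -> id 6, ...
        enc_extend.append(vocab_size + article_oovs.index(w))

enc_batch_extend_vocab = torch.tensor([enc_extend])  # [[0, 5, 6]]
extra_zeros = torch.zeros(1, len(article_oovs))      # widens P_vocab by 2 slots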
/training_ptr_gen/transformer_encoder.py:
--------------------------------------------------------------------------------
1 | # This is still a work in progress; I will finish it once I get some free time.
2 | 
3 | from __future__ import unicode_literals, print_function, division
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import logging
9 | import math
10 | 
11 | logging.basicConfig(level=logging.INFO)
12 | 
13 | class PositionalEncoding(nn.Module):
14 |     def __init__(self, d_model, dropout, max_len=5000):
15 |         super(PositionalEncoding, self).__init__()
16 |         self.dropout = nn.Dropout(p=dropout)
17 | 
18 |         pe = torch.zeros(max_len, d_model)
19 |         position = torch.arange(0, max_len).float().unsqueeze(1)
20 |         div_term = torch.exp(torch.arange(0, d_model, 2).float() *
21 |                              -(math.log(10000.0) / d_model))
22 |         pe[:, 0::2] = torch.sin(position * div_term)
23 |         pe[:, 1::2] = torch.cos(position * div_term)
24 |         pe = pe.unsqueeze(0)
25 |         self.register_buffer('pe', pe)
26 | 
27 |     def forward(self, x):
28 |         x = x + self.pe[:, :x.size(1)]
29 |         return self.dropout(x)
30 | 
31 | class MultiHeadedAttention(nn.Module):
32 |     def __init__(self, num_head, d_model, dropout=0.1):
33 |         super(MultiHeadedAttention, self).__init__()
34 |         assert d_model % num_head == 0
35 |         self.d_k = d_model // num_head  # d_k == d_v
36 |         self.h = num_head
37 | 
38 |         self.linear_key = nn.Linear(d_model, d_model)
39 |         self.linear_value = nn.Linear(d_model, d_model)
40 |         self.linear_query = nn.Linear(d_model, d_model)
41 |         self.linear_out = nn.Linear(d_model, d_model)
42 | 
43 |         self.dropout = nn.Dropout(p=dropout)
44 | 
45 |     def attention(self, query, key, value, mask, dropout=None):
46 |         d_k = query.size(-1)
47 |         scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
48 |         scores = scores.masked_fill(mask == 0, -1e9)
49 | 
50 |         p_attn = F.softmax(scores, dim=-1)
51 |         if dropout is not None:
52 |             p_attn = dropout(p_attn)
53 |         return torch.matmul(p_attn, value), p_attn
54 | 
55 |     def forward(self, query, key, value, mask):
56 |         nbatches = query.size(0)
57 |         query = self.linear_query(query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
58 |         key = self.linear_key(key).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
59 |         value = self.linear_value(value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
60 | 
61 |         mask = mask.unsqueeze(1)
62 |         x, attn = self.attention(query, key, value, mask, dropout=self.dropout)
63 |         x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
64 |         return self.linear_out(x)
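# --- Editor's sketch, not part of this file: the head split/merge above traced with
# hypothetical sizes B=2, t=7, d_model=16, num_head=4 (so d_k=4):
#
#   linear_query(query)      -> (2, 7, 16)
#   .view(2, -1, 4, 4)       -> (2, 7, 4, 4)   d_model split into 4 heads
#   .transpose(1, 2)         -> (2, 4, 7, 4)   one (t, d_k) matrix per head
#   attention(...)           -> (2, 4, 7, 4)   softmax(QK^T / sqrt(d_k)) V per head
#   .transpose(1, 2).contiguous().view(2, 7, 16)   heads merged back
#
# mask.unsqueeze(1) inserts the head axis so a single padding mask broadcasts
# across all four heads inside masked_fill.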
65 | 
66 | class AffineLayer(nn.Module):
67 |     def __init__(self, dropout, d_model, d_ff):
68 |         super(AffineLayer, self).__init__()
69 |         self.w_1 = nn.Linear(d_model, d_ff)
70 |         self.w_2 = nn.Linear(d_ff, d_model)
71 |         self.dropout = nn.Dropout(dropout)
72 | 
73 |     def forward(self, x):
74 |         return self.w_2(self.dropout(F.relu(self.w_1(x))))
75 | 
76 | class EncoderLayer(nn.Module):
77 |     def __init__(self, num_head, dropout, d_model, d_ff):
78 |         super(EncoderLayer, self).__init__()
79 | 
80 |         self.att_layer = MultiHeadedAttention(num_head, d_model, dropout)
81 |         self.norm_att = nn.LayerNorm(d_model)
82 |         self.dropout_att = nn.Dropout(dropout)
83 | 
84 |         self.affine_layer = AffineLayer(dropout, d_model, d_ff)
85 |         self.norm_affine = nn.LayerNorm(d_model)
86 |         self.dropout_affine = nn.Dropout(dropout)
87 | 
88 |     def forward(self, x, mask):
89 |         x_att = self.norm_att(x * mask)
90 |         x_att = self.att_layer(x_att, x_att, x_att, mask)
91 |         x = x + self.dropout_att(x_att)
92 | 
93 |         x_affine = self.norm_affine(x * mask)
94 |         x_affine = self.affine_layer(x_affine)
95 |         return x + self.dropout_affine(x_affine)
96 | 
97 | class Encoder(nn.Module):
98 |     def __init__(self, N, num_head, dropout, d_model, d_ff):
99 |         super(Encoder, self).__init__()
100 |         self.position = PositionalEncoding(d_model, dropout)
101 |         self.layers = nn.ModuleList()
102 |         for _ in range(N):
103 |             self.layers.append(EncoderLayer(num_head, dropout, d_model, d_ff))
104 |         self.norm = nn.LayerNorm(d_model)
105 | 
106 |     def forward(self, word_embed, mask):
107 |         x = self.position(word_embed)
108 |         for layer in self.layers:
109 |             x = layer(x, mask)
110 |         return self.norm(x * mask)
111 | 
--------------------------------------------------------------------------------
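Editor's sketch (not part of the repository): a minimal smoke test for the work-in-progress encoder above, run from the repository root. The hyperparameters are hypothetical, and the (B, t, 1) mask shape is an assumption read off the elementwise x * mask in Encoder.forward:

import torch
from training_ptr_gen.transformer_encoder import Encoder

# 2 layers, 4 heads, model width 64, feed-forward width 256 (all assumed values).
encoder = Encoder(N=2, num_head=4, dropout=0.1, d_model=64, d_ff=256)

word_embed = torch.randn(3, 10, 64)  # batch of 3 sequences, 10 tokens each
mask = torch.ones(3, 10, 1)          # 1 = real token, 0 = padding
mask[2, 7:] = 0                      # third example has only 7 real tokens

out = encoder(word_embed, mask)
print(out.shape)                     # torch.Size([3, 10, 64])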