├── .gitignore
├── LICENSE
├── README.md
├── data_util
│   ├── __init__.py
│   ├── batcher.py
│   ├── config.py
│   ├── data.py
│   └── utils.py
├── learning_curve.png
├── learning_curve_coverage.png
├── start_decode.sh
├── start_eval.sh
├── start_train.sh
└── training_ptr_gen
    ├── __init__.py
    ├── decode.py
    ├── eval.py
    ├── model.py
    ├── train.py
    ├── train_util.py
    └── transformer_encoder.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.bin
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | PyTorch implementation of *[Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)*
2 |
3 | 1. [Train with pointer generation and coverage loss enabled](#train-with-pointer-generation-and-coverage-loss-enabled)
4 | 2. [Training with pointer generation enabled](#training-with-pointer-generation-enabled)
5 | 3. [How to run training](#how-to-run-training)
6 | 4. [Papers using this code](#papers-using-this-code)
7 |
8 |
9 | ## Train with pointer generation and coverage loss enabled
10 | After training for 100k iterations with coverage loss enabled (batch size 8)
11 |
12 | ```
13 | ROUGE-1:
14 | rouge_1_f_score: 0.3907 with confidence interval (0.3885, 0.3928)
15 | rouge_1_recall: 0.4434 with confidence interval (0.4410, 0.4460)
16 | rouge_1_precision: 0.3698 with confidence interval (0.3672, 0.3721)
17 |
18 | ROUGE-2:
19 | rouge_2_f_score: 0.1697 with confidence interval (0.1674, 0.1720)
20 | rouge_2_recall: 0.1920 with confidence interval (0.1894, 0.1945)
21 | rouge_2_precision: 0.1614 with confidence interval (0.1590, 0.1636)
22 |
23 | ROUGE-l:
24 | rouge_l_f_score: 0.3587 with confidence interval (0.3565, 0.3608)
25 | rouge_l_recall: 0.4067 with confidence interval (0.4042, 0.4092)
26 | rouge_l_precision: 0.3397 with confidence interval (0.3371, 0.3420)
27 | ```
28 |
29 | 
30 |
31 | ## Training with pointer generation enabled
32 | After training for 500k iterations (batch size 8)
33 |
34 | ```
35 | ROUGE-1:
36 | rouge_1_f_score: 0.3500 with confidence interval (0.3477, 0.3523)
37 | rouge_1_recall: 0.3718 with confidence interval (0.3693, 0.3745)
38 | rouge_1_precision: 0.3529 with confidence interval (0.3501, 0.3555)
39 |
40 | ROUGE-2:
41 | rouge_2_f_score: 0.1486 with confidence interval (0.1465, 0.1508)
42 | rouge_2_recall: 0.1573 with confidence interval (0.1551, 0.1597)
43 | rouge_2_precision: 0.1506 with confidence interval (0.1483, 0.1529)
44 |
45 | ROUGE-l:
46 | rouge_l_f_score: 0.3202 with confidence interval (0.3179, 0.3225)
47 | rouge_l_recall: 0.3399 with confidence interval (0.3374, 0.3426)
48 | rouge_l_precision: 0.3231 with confidence interval (0.3205, 0.3256)
49 | ```
50 | 
51 |
52 |
53 | ## How to run training:
54 | 1) Follow the data generation instructions at https://github.com/abisee/cnn-dailymail
55 | 2) Adjust the data paths and hyperparameters in data_util/config.py to match your setup
56 | 3) For training run start_train.sh, for decoding run start_decode.sh, and for evaluating run start_eval.sh (see the example invocation below)
57 |
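A minimal end-to-end run might look like the following sketch. The checkpoint path passed to the decode and eval scripts is a placeholder for whatever checkpoint training writes under the log_root configured in data_util/config.py; note also that all three scripts redirect their console output to ../log/, so that directory should exist.

```
# train (checkpoints and logs go under the log_root set in data_util/config.py)
sh start_train.sh

# decode the test set with beam search from a saved checkpoint (placeholder path)
sh start_decode.sh <path/to/saved/model_checkpoint>

# compute the running evaluation loss on the validation set for the same checkpoint
sh start_eval.sh <path/to/saved/model_checkpoint>
```
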
58 | Note:
59 |
60 | * In decode mode the beam search batch should contain a single example replicated to the batch size (see the minimal sketch after this list):
61 | https://github.com/atulkum/pointer_summarizer/blob/master/training_ptr_gen/decode.py#L109
62 | https://github.com/atulkum/pointer_summarizer/blob/master/data_util/batcher.py#L226
63 |
64 | * It is tested with PyTorch 0.4 and Python 2.7
65 | * You need to setup [pyrouge](https://github.com/andersjo/pyrouge) to get the rouge score
66 |
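For reference, here is a minimal sketch of how the decode-mode batcher builds such a batch (illustrative, not the exact code; the real logic lives in Batcher.fill_batch_queue in data_util/batcher.py):

```
# one article is pulled off the example queue and duplicated beam_size times,
# so each row of the "batch" corresponds to one beam hypothesis
ex = example_queue.get()
beam_batch = [ex for _ in range(beam_size)]
batch_queue.put(Batch(beam_batch, vocab, beam_size))
```
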
67 | ## Papers using this code:
68 | 1) [Automatic Program Synthesis of Long Programs with a Learned Garbage Collector](http://papers.nips.cc/paper/7479-automatic-program-synthesis-of-long-programs-with-a-learned-garbage-collector) ***NeurIPS 2018*** https://github.com/amitz25/PCCoder
69 | 2) [Automatic Fact-guided Sentence Modification](https://arxiv.org/abs/1909.13838) ***AAAI 2020*** https://github.com/darsh10/split_encoder_pointer_summarizer
70 | 3) [Resurrecting Submodularity in Neural Abstractive Summarization](https://arxiv.org/abs/1911.03014v1)
71 | 4) [StructSum: Summarization via Structured Representations](https://aclanthology.org/2021.eacl-main.220) ***EACL 2021***
72 | 5) [Concept Pointer Network for Abstractive Summarization](https://arxiv.org/abs/1910.08486) ***EMNLP'2019*** https://github.com/wprojectsn/codes
73 | 6) [PaddlePaddle version](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_summarization/pointer_summarizer)
74 | 7) [VAE-PGN based Abstractive Model in Multi-stage Architecture for Text Summarization](https://www.aclweb.org/anthology/W19-8664/) ***INLG2019***
75 | 8) [Clickbait? Sensational Headline Generation with Auto-tuned Reinforcement Learning](https://arxiv.org/abs/1909.03582) ***EMNLP'2019*** https://github.com/HLTCHKUST/sensational_headline
76 | 9) [Abstractive Spoken Document Summarization using Hierarchical Model with Multi-stage Attention Diversity Optimization](http://www.interspeech2020.org/index.php?m=content&c=index&a=show&catid=354&id=1173) ***INTERSPEECH 2020***
77 | 10) [Nutribullets Hybrid: Multi-document Health Summarization](https://arxiv.org/abs/2104.03465) ***NAACL 2021***
78 | 11) [A Corpus of Very Short Scientific Summaries](https://aclanthology.org/2020.conll-1.12) ***CoNLL 2020***
79 | 12) [Towards Faithfulness in Open Domain Table-to-text Generation from an Entity-centric View](https://arxiv.org/abs/2102.08585) ***AAAI 2021***
80 | 13) [CDEvalSumm: An Empirical Study of Cross-Dataset Evaluation for Neural Summarization Systems](https://aclanthology.org/2020.findings-emnlp.329) ***Findings of EMNLP2020***
81 | 14) [A Study on Seq2seq for Sentence Compression in Vietnamese](https://aclanthology.org/2020.paclic-1.56) ***PACLIC 2020***
82 | 15) [Other Roles Matter! Enhancing Role-Oriented Dialogue Summarization via Role Interactions](https://aclanthology.org/2022.acl-long.182/) ***ACL 2022***
83 |
84 |
--------------------------------------------------------------------------------
/data_util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/data_util/__init__.py
--------------------------------------------------------------------------------
/data_util/batcher.py:
--------------------------------------------------------------------------------
1 | #Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/batcher.py
2 |
3 | import Queue
4 | import time
5 | from random import shuffle
6 | from threading import Thread
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | import config
12 | import data
13 |
14 | import random
15 | random.seed(1234)
16 |
17 |
18 | class Example(object):
19 |
20 | def __init__(self, article, abstract_sentences, vocab):
21 | # Get ids of special tokens
22 | start_decoding = vocab.word2id(data.START_DECODING)
23 | stop_decoding = vocab.word2id(data.STOP_DECODING)
24 |
25 | # Process the article
26 | article_words = article.split()
27 | if len(article_words) > config.max_enc_steps:
28 | article_words = article_words[:config.max_enc_steps]
29 | self.enc_len = len(article_words) # store the length after truncation but before padding
30 | self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token
31 |
32 | # Process the abstract
33 | abstract = ' '.join(abstract_sentences) # string
34 | abstract_words = abstract.split() # list of strings
35 | abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token
36 |
37 | # Get the decoder input sequence and target sequence
38 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
39 | self.dec_len = len(self.dec_input)
40 |
41 | # If using pointer-generator mode, we need to store some extra info
42 | if config.pointer_gen:
43 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
44 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
45 |
46 | # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
47 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)
48 |
49 | # Overwrite decoder target sequence so it uses the temp article OOV ids
50 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)
51 |
52 | # Store the original strings
53 | self.original_article = article
54 | self.original_abstract = abstract
55 | self.original_abstract_sents = abstract_sentences
56 |
57 |
58 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
59 | inp = [start_id] + sequence[:]
60 | target = sequence[:]
61 | if len(inp) > max_len: # truncate
62 | inp = inp[:max_len]
63 | target = target[:max_len] # no end_token
64 | else: # no truncation
65 | target.append(stop_id) # end token
66 | assert len(inp) == len(target)
67 | return inp, target
68 |
69 |
70 | def pad_decoder_inp_targ(self, max_len, pad_id):
71 | while len(self.dec_input) < max_len:
72 | self.dec_input.append(pad_id)
73 | while len(self.target) < max_len:
74 | self.target.append(pad_id)
75 |
76 |
77 | def pad_encoder_input(self, max_len, pad_id):
78 | while len(self.enc_input) < max_len:
79 | self.enc_input.append(pad_id)
80 | if config.pointer_gen:
81 | while len(self.enc_input_extend_vocab) < max_len:
82 | self.enc_input_extend_vocab.append(pad_id)
83 |
84 |
85 | class Batch(object):
86 | def __init__(self, example_list, vocab, batch_size):
87 | self.batch_size = batch_size
88 | self.pad_id = vocab.word2id(data.PAD_TOKEN) # id of the PAD token used to pad sequences
89 | self.init_encoder_seq(example_list) # initialize the input to the encoder
90 | self.init_decoder_seq(example_list) # initialize the input and targets for the decoder
91 | self.store_orig_strings(example_list) # store the original strings
92 |
93 |
94 | def init_encoder_seq(self, example_list):
95 | # Determine the maximum length of the encoder input sequence in this batch
96 | max_enc_seq_len = max([ex.enc_len for ex in example_list])
97 |
98 | # Pad the encoder input sequences up to the length of the longest sequence
99 | for ex in example_list:
100 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id)
101 |
102 | # Initialize the numpy arrays
103 | # Note: our enc_batch can have different length (second dimension) for each batch because we use dynamic_rnn for the encoder.
104 | self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
105 | self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
106 | self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)
107 |
108 | # Fill in the numpy arrays
109 | for i, ex in enumerate(example_list):
110 | self.enc_batch[i, :] = ex.enc_input[:]
111 | self.enc_lens[i] = ex.enc_len
112 | for j in xrange(ex.enc_len):
113 | self.enc_padding_mask[i][j] = 1
114 |
115 | # For pointer-generator mode, need to store some extra info
116 | if config.pointer_gen:
117 | # Determine the max number of in-article OOVs in this batch
118 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
119 | # Store the in-article OOVs themselves
120 | self.art_oovs = [ex.article_oovs for ex in example_list]
121 | # Store the version of the enc_batch that uses the article OOV ids
122 | self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
123 | for i, ex in enumerate(example_list):
124 | self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]
125 |
126 | def init_decoder_seq(self, example_list):
127 | # Pad the inputs and targets
128 | for ex in example_list:
129 | ex.pad_decoder_inp_targ(config.max_dec_steps, self.pad_id)
130 |
131 | # Initialize the numpy arrays.
132 | self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
133 | self.target_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
134 | self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
135 | self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)
136 |
137 | # Fill in the numpy arrays
138 | for i, ex in enumerate(example_list):
139 | self.dec_batch[i, :] = ex.dec_input[:]
140 | self.target_batch[i, :] = ex.target[:]
141 | self.dec_lens[i] = ex.dec_len
142 | for j in xrange(ex.dec_len):
143 | self.dec_padding_mask[i][j] = 1
144 |
145 | def store_orig_strings(self, example_list):
146 | self.original_articles = [ex.original_article for ex in example_list] # list of lists
147 | self.original_abstracts = [ex.original_abstract for ex in example_list] # list of lists
148 | self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of list of lists
149 |
150 |
151 | class Batcher(object):
152 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
153 |
154 | def __init__(self, data_path, vocab, mode, batch_size, single_pass):
155 | self._data_path = data_path
156 | self._vocab = vocab
157 | self._single_pass = single_pass
158 | self.mode = mode
159 | self.batch_size = batch_size
160 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
161 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
162 | self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)
163 |
164 | # Different settings depending on whether we're in single_pass mode or not
165 | if single_pass:
166 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
167 | self._num_batch_q_threads = 1 # just one thread to batch examples
168 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing
169 | self._finished_reading = False # this will tell us when we're finished reading the dataset
170 | else:
171 | self._num_example_q_threads = 1 #16 # num threads to fill example queue
172 | self._num_batch_q_threads = 1 #4 # num threads to fill batch queue
173 | self._bucketing_cache_size = 1 #100 # how many batches-worth of examples to load into cache before bucketing
174 |
175 | # Start the threads that load the queues
176 | self._example_q_threads = []
177 | for _ in xrange(self._num_example_q_threads):
178 | self._example_q_threads.append(Thread(target=self.fill_example_queue))
179 | self._example_q_threads[-1].daemon = True
180 | self._example_q_threads[-1].start()
181 | self._batch_q_threads = []
182 | for _ in xrange(self._num_batch_q_threads):
183 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
184 | self._batch_q_threads[-1].daemon = True
185 | self._batch_q_threads[-1].start()
186 |
187 | # Start a thread that watches the other threads and restarts them if they're dead
188 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
189 | self._watch_thread = Thread(target=self.watch_threads)
190 | self._watch_thread.daemon = True
191 | self._watch_thread.start()
192 |
193 | def next_batch(self):
194 | # If the batch queue is empty, print a warning
195 | if self._batch_queue.qsize() == 0:
196 | tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
197 | if self._single_pass and self._finished_reading:
198 | tf.logging.info("Finished reading dataset in single_pass mode.")
199 | return None
200 |
201 | batch = self._batch_queue.get() # get the next Batch
202 | return batch
203 |
204 | def fill_example_queue(self):
205 | input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))
206 |
207 | while True:
208 | try:
209 | (article, abstract) = input_gen.next() # read the next example from file. article and abstract are both strings.
210 | except StopIteration: # if there are no more examples:
211 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
212 | if self._single_pass:
213 | tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
214 | self._finished_reading = True
215 | break
216 | else:
217 | raise Exception("single_pass mode is off but the example generator is out of data; error.")
218 |
219 | abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)] # Use the <s> and </s> tags in abstract to get a list of sentences.
220 | example = Example(article, abstract_sentences, self._vocab) # Process into an Example.
221 | self._example_queue.put(example) # place the Example in the example queue.
222 |
223 | def fill_batch_queue(self):
224 | while True:
225 | if self.mode == 'decode':
226 | # beam search decode mode single example repeated in the batch
227 | ex = self._example_queue.get()
228 | b = [ex for _ in xrange(self.batch_size)]
229 | self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
230 | else:
231 | # Get bucketing_cache_size-many batches of Examples into a list, then sort
232 | inputs = []
233 | for _ in xrange(self.batch_size * self._bucketing_cache_size):
234 | inputs.append(self._example_queue.get())
235 | inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True) # sort by length of encoder sequence
236 |
237 | # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
238 | batches = []
239 | for i in xrange(0, len(inputs), self.batch_size):
240 | batches.append(inputs[i:i + self.batch_size])
241 | if not self._single_pass:
242 | shuffle(batches)
243 | for b in batches: # each b is a list of Example objects
244 | self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
245 |
246 | def watch_threads(self):
247 | while True:
248 | tf.logging.info(
249 | 'Bucket queue size: %i, Input queue size: %i',
250 | self._batch_queue.qsize(), self._example_queue.qsize())
251 |
252 | time.sleep(60)
253 | for idx,t in enumerate(self._example_q_threads):
254 | if not t.is_alive(): # if the thread is dead
255 | tf.logging.error('Found example queue thread dead. Restarting.')
256 | new_t = Thread(target=self.fill_example_queue)
257 | self._example_q_threads[idx] = new_t
258 | new_t.daemon = True
259 | new_t.start()
260 | for idx,t in enumerate(self._batch_q_threads):
261 | if not t.is_alive(): # if the thread is dead
262 | tf.logging.error('Found batch queue thread dead. Restarting.')
263 | new_t = Thread(target=self.fill_batch_queue)
264 | self._batch_q_threads[idx] = new_t
265 | new_t.daemon = True
266 | new_t.start()
267 |
268 |
269 | def text_generator(self, example_generator):
270 | while True:
271 | e = example_generator.next() # e is a tf.Example
272 | try:
273 | article_text = e.features.feature['article'].bytes_list.value[0] # the article text was saved under the key 'article' in the data files
274 | abstract_text = e.features.feature['abstract'].bytes_list.value[0] # the abstract text was saved under the key 'abstract' in the data files
275 | except ValueError:
276 | tf.logging.error('Failed to get article or abstract from example')
277 | continue
278 | if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1
279 | #tf.logging.warning('Found an example with empty article text. Skipping it.')
280 | continue
281 | else:
282 | yield (article_text, abstract_text)
283 |
--------------------------------------------------------------------------------
/data_util/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | root_dir = os.path.expanduser("~")
4 |
5 | #train_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/train.bin")
6 | train_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/chunked/train_*")
7 | eval_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/val.bin")
8 | decode_data_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/test.bin")
9 | vocab_path = os.path.join(root_dir, "ptr_nw/cnn-dailymail-master/finished_files/vocab")
10 | log_root = os.path.join(root_dir, "ptr_nw/log")
11 |
12 | # Hyperparameters
13 | hidden_dim= 256
14 | emb_dim= 128
15 | batch_size= 8
16 | max_enc_steps=400
17 | max_dec_steps=100
18 | beam_size=4
19 | min_dec_steps=35
20 | vocab_size=50000
21 |
22 | lr=0.15
23 | adagrad_init_acc=0.1
24 | rand_unif_init_mag=0.02
25 | trunc_norm_init_std=1e-4
26 | max_grad_norm=2.0
27 |
28 | pointer_gen = True
29 | is_coverage = False
30 | cov_loss_wt = 1.0
31 |
32 | eps = 1e-12
33 | max_iterations = 500000
34 |
35 | use_gpu=True
36 |
37 | lr_coverage=0.15
38 |
--------------------------------------------------------------------------------
/data_util/data.py:
--------------------------------------------------------------------------------
1 | #Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/data.py
2 |
3 | import glob
4 | import random
5 | import struct
6 | import csv
7 | from tensorflow.core.example import example_pb2
8 |
9 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
10 | SENTENCE_START = '<s>'
11 | SENTENCE_END = '</s>'
12 |
13 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
14 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
15 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
16 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
17 |
18 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
19 |
20 |
21 | class Vocab(object):
22 |
23 | def __init__(self, vocab_file, max_size):
24 | self._word_to_id = {}
25 | self._id_to_word = {}
26 | self._count = 0 # keeps track of total number of words in the Vocab
27 |
28 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
29 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
30 | self._word_to_id[w] = self._count
31 | self._id_to_word[self._count] = w
32 | self._count += 1
33 |
34 | # Read the vocab file and add words up to max_size
35 | with open(vocab_file, 'r') as vocab_f:
36 | for line in vocab_f:
37 | pieces = line.split()
38 | if len(pieces) != 2:
39 | print 'Warning: incorrectly formatted line in vocabulary file: %s\n' % line
40 | continue
41 | w = pieces[0]
42 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
43 | raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
44 | if w in self._word_to_id:
45 | raise Exception('Duplicated word in vocabulary file: %s' % w)
46 | self._word_to_id[w] = self._count
47 | self._id_to_word[self._count] = w
48 | self._count += 1
49 | if max_size != 0 and self._count >= max_size:
50 | print "max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)
51 | break
52 |
53 | print "Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1])
54 |
55 | def word2id(self, word):
56 | if word not in self._word_to_id:
57 | return self._word_to_id[UNKNOWN_TOKEN]
58 | return self._word_to_id[word]
59 |
60 | def id2word(self, word_id):
61 | if word_id not in self._id_to_word:
62 | raise ValueError('Id not found in vocab: %d' % word_id)
63 | return self._id_to_word[word_id]
64 |
65 | def size(self):
66 | return self._count
67 |
68 | def write_metadata(self, fpath):
69 | print "Writing word embedding metadata file to %s..." % (fpath)
70 | with open(fpath, "w") as f:
71 | fieldnames = ['word']
72 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
73 | for i in xrange(self.size()):
74 | writer.writerow({"word": self._id_to_word[i]})
75 |
76 |
77 | def example_generator(data_path, single_pass):
78 | while True:
79 | filelist = glob.glob(data_path) # get the list of datafiles
80 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
81 | if single_pass:
82 | filelist = sorted(filelist)
83 | else:
84 | random.shuffle(filelist)
85 | for f in filelist:
86 | reader = open(f, 'rb')
87 | while True:
88 | len_bytes = reader.read(8)
89 | if not len_bytes: break # finished reading this file
90 | str_len = struct.unpack('q', len_bytes)[0]
91 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
92 | yield example_pb2.Example.FromString(example_str)
93 | if single_pass:
94 | print "example_generator completed reading all datafiles. No more data."
95 | break
96 |
97 |
98 | def article2ids(article_words, vocab):
99 | ids = []
100 | oovs = []
101 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
102 | for w in article_words:
103 | i = vocab.word2id(w)
104 | if i == unk_id: # If w is OOV
105 | if w not in oovs: # Add to list of OOVs
106 | oovs.append(w)
107 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
108 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
109 | else:
110 | ids.append(i)
111 | return ids, oovs
112 |
113 |
114 | def abstract2ids(abstract_words, vocab, article_oovs):
115 | ids = []
116 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
117 | for w in abstract_words:
118 | i = vocab.word2id(w)
119 | if i == unk_id: # If w is an OOV word
120 | if w in article_oovs: # If w is an in-article OOV
121 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
122 | ids.append(vocab_idx)
123 | else: # If w is an out-of-article OOV
124 | ids.append(unk_id) # Map to the UNK token id
125 | else:
126 | ids.append(i)
127 | return ids
128 |
129 |
130 | def outputids2words(id_list, vocab, article_oovs):
131 | words = []
132 | for i in id_list:
133 | try:
134 | w = vocab.id2word(i) # might be [UNK]
135 | except ValueError as e: # w is OOV
136 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
137 | article_oov_idx = i - vocab.size()
138 | try:
139 | w = article_oovs[article_oov_idx]
140 | except ValueError as e: # i doesn't correspond to an article oov
141 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
142 | words.append(w)
143 | return words
144 |
145 |
146 | def abstract2sents(abstract):
147 | cur = 0
148 | sents = []
149 | while True:
150 | try:
151 | start_p = abstract.index(SENTENCE_START, cur)
152 | end_p = abstract.index(SENTENCE_END, start_p + 1)
153 | cur = end_p + len(SENTENCE_END)
154 | sents.append(abstract[start_p+len(SENTENCE_START):end_p])
155 | except ValueError as e: # no more sentences
156 | return sents
157 |
158 |
159 | def show_art_oovs(article, vocab):
160 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
161 | words = article.split(' ')
162 | words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
163 | out_str = ' '.join(words)
164 | return out_str
165 |
166 |
167 | def show_abs_oovs(abstract, vocab, article_oovs):
168 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
169 | words = abstract.split(' ')
170 | new_words = []
171 | for w in words:
172 | if vocab.word2id(w) == unk_token: # w is oov
173 | if article_oovs is None: # baseline mode
174 | new_words.append("__%s__" % w)
175 | else: # pointer-generator mode
176 | if w in article_oovs:
177 | new_words.append("__%s__" % w)
178 | else:
179 | new_words.append("!!__%s__!!" % w)
180 | else: # w is in-vocab word
181 | new_words.append(w)
182 | out_str = ' '.join(new_words)
183 | return out_str
184 |
--------------------------------------------------------------------------------
/data_util/utils.py:
--------------------------------------------------------------------------------
1 | #Content of this file is copied from https://github.com/abisee/pointer-generator/blob/master/
2 | import os
3 | import pyrouge
4 | import logging
5 | import tensorflow as tf
6 |
7 | def print_results(article, abstract, decoded_output):
8 | print ("")
9 | print('ARTICLE: %s' % article)
10 | print('REFERENCE SUMMARY: %s' % abstract)
11 | print('GENERATED SUMMARY: %s' % decoded_output)
12 | print( "")
13 |
14 |
15 | def make_html_safe(s):
16 | s = s.replace("<", "&lt;")
17 | s = s.replace(">", "&gt;")
18 | return s
19 |
20 |
21 | def rouge_eval(ref_dir, dec_dir):
22 | r = pyrouge.Rouge155()
23 | r.model_filename_pattern = '#ID#_reference.txt'
24 | r.system_filename_pattern = '(\d+)_decoded.txt'
25 | r.model_dir = ref_dir
26 | r.system_dir = dec_dir
27 | logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
28 | rouge_results = r.convert_and_evaluate()
29 | return r.output_to_dict(rouge_results)
30 |
31 |
32 | def rouge_log(results_dict, dir_to_write):
33 | log_str = ""
34 | for x in ["1","2","l"]:
35 | log_str += "\nROUGE-%s:\n" % x
36 | for y in ["f_score", "recall", "precision"]:
37 | key = "rouge_%s_%s" % (x,y)
38 | key_cb = key + "_cb"
39 | key_ce = key + "_ce"
40 | val = results_dict[key]
41 | val_cb = results_dict[key_cb]
42 | val_ce = results_dict[key_ce]
43 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
44 | print(log_str)
45 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
46 | print("Writing final ROUGE results to %s..."%(results_file))
47 | with open(results_file, "w") as f:
48 | f.write(log_str)
49 |
50 |
51 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
52 | if running_avg_loss == 0: # on the first iteration just take the loss
53 | running_avg_loss = loss
54 | else:
55 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
56 | running_avg_loss = min(running_avg_loss, 12) # clip
57 | loss_sum = tf.Summary()
58 | tag_name = 'running_avg_loss/decay=%f' % (decay)
59 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
60 | summary_writer.add_summary(loss_sum, step)
61 | return running_avg_loss
62 |
63 |
64 | def write_for_rouge(reference_sents, decoded_words, ex_index,
65 | _rouge_ref_dir, _rouge_dec_dir):
66 | decoded_sents = []
67 | while len(decoded_words) > 0:
68 | try:
69 | fst_period_idx = decoded_words.index(".")
70 | except ValueError:
71 | fst_period_idx = len(decoded_words)
72 | sent = decoded_words[:fst_period_idx + 1]
73 | decoded_words = decoded_words[fst_period_idx + 1:]
74 | decoded_sents.append(' '.join(sent))
75 |
76 | # pyrouge calls a perl script that puts the data into HTML files.
77 | # Therefore we need to make our output HTML safe.
78 | decoded_sents = [make_html_safe(w) for w in decoded_sents]
79 | reference_sents = [make_html_safe(w) for w in reference_sents]
80 |
81 | ref_file = os.path.join(_rouge_ref_dir, "%06d_reference.txt" % ex_index)
82 | decoded_file = os.path.join(_rouge_dec_dir, "%06d_decoded.txt" % ex_index)
83 |
84 | with open(ref_file, "w") as f:
85 | for idx, sent in enumerate(reference_sents):
86 | f.write(sent) if idx == len(reference_sents) - 1 else f.write(sent + "\n")
87 | with open(decoded_file, "w") as f:
88 | for idx, sent in enumerate(decoded_sents):
89 | f.write(sent) if idx == len(decoded_sents) - 1 else f.write(sent + "\n")
90 |
91 | #print("Wrote example %i to file" % ex_index)
92 |
--------------------------------------------------------------------------------
/learning_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/learning_curve.png
--------------------------------------------------------------------------------
/learning_curve_coverage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/learning_curve_coverage.png
--------------------------------------------------------------------------------
/start_decode.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=`pwd`
2 | MODEL=$1
3 | python training_ptr_gen/decode.py $MODEL >& ../log/decode_log &
4 |
5 |
--------------------------------------------------------------------------------
/start_eval.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=`pwd`
2 | MODEL_PATH=$1
3 | MODEL_NAME=$(basename $MODEL_PATH)
4 | python training_ptr_gen/eval.py $MODEL_PATH >& ../log/eval_log.$MODEL_NAME &
5 |
6 |
--------------------------------------------------------------------------------
/start_train.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=`pwd`
2 | python training_ptr_gen/train.py >& ../log/training_log &
3 |
4 |
--------------------------------------------------------------------------------
/training_ptr_gen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atulkum/pointer_summarizer/4967dd478ba70b1c12116ad6357dcc1d000e5dfa/training_ptr_gen/__init__.py
--------------------------------------------------------------------------------
/training_ptr_gen/decode.py:
--------------------------------------------------------------------------------
1 | #Except for the PyTorch parts, the content of this file is copied from https://github.com/abisee/pointer-generator/blob/master/
2 |
3 | from __future__ import unicode_literals, print_function, division
4 |
5 | import sys
6 |
7 | reload(sys)
8 | sys.setdefaultencoding('utf8')
9 |
10 | import os
11 | import time
12 |
13 | import torch
14 | from torch.autograd import Variable
15 |
16 | from data_util.batcher import Batcher
17 | from data_util.data import Vocab
18 | from data_util import data, config
19 | from model import Model
20 | from data_util.utils import write_for_rouge, rouge_eval, rouge_log
21 | from train_util import get_input_from_batch
22 |
23 |
24 | use_cuda = config.use_gpu and torch.cuda.is_available()
25 |
26 | class Beam(object):
27 | def __init__(self, tokens, log_probs, state, context, coverage):
28 | self.tokens = tokens
29 | self.log_probs = log_probs
30 | self.state = state
31 | self.context = context
32 | self.coverage = coverage
33 |
34 | def extend(self, token, log_prob, state, context, coverage):
35 | return Beam(tokens = self.tokens + [token],
36 | log_probs = self.log_probs + [log_prob],
37 | state = state,
38 | context = context,
39 | coverage = coverage)
40 |
41 | @property
42 | def latest_token(self):
43 | return self.tokens[-1]
44 |
45 | @property
46 | def avg_log_prob(self):
47 | return sum(self.log_probs) / len(self.tokens)
48 |
49 |
50 | class BeamSearch(object):
51 | def __init__(self, model_file_path):
52 | model_name = os.path.basename(model_file_path)
53 | self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
54 | self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
55 | self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
56 | for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
57 | if not os.path.exists(p):
58 | os.mkdir(p)
59 |
60 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
61 | self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
62 | batch_size=config.beam_size, single_pass=True)
63 | time.sleep(15)
64 |
65 | self.model = Model(model_file_path, is_eval=True)
66 |
67 | def sort_beams(self, beams):
68 | return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)
69 |
70 |
71 | def decode(self):
72 | start = time.time()
73 | counter = 0
74 | batch = self.batcher.next_batch()
75 | while batch is not None:
76 | # Run beam search to get best Hypothesis
77 | best_summary = self.beam_search(batch)
78 |
79 | # Extract the output ids from the hypothesis and convert back to words
80 | output_ids = [int(t) for t in best_summary.tokens[1:]]
81 | decoded_words = data.outputids2words(output_ids, self.vocab,
82 | (batch.art_oovs[0] if config.pointer_gen else None))
83 |
84 | # Remove the [STOP] token from decoded_words, if necessary
85 | try:
86 | fst_stop_idx = decoded_words.index(data.STOP_DECODING)
87 | decoded_words = decoded_words[:fst_stop_idx]
88 | except ValueError:
89 | decoded_words = decoded_words
90 |
91 | original_abstract_sents = batch.original_abstracts_sents[0]
92 |
93 | write_for_rouge(original_abstract_sents, decoded_words, counter,
94 | self._rouge_ref_dir, self._rouge_dec_dir)
95 | counter += 1
96 | if counter % 1000 == 0:
97 | print('%d example in %d sec'%(counter, time.time() - start))
98 | start = time.time()
99 |
100 | batch = self.batcher.next_batch()
101 |
102 | print("Decoder has finished reading dataset for single_pass.")
103 | print("Now starting ROUGE eval...")
104 | results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
105 | rouge_log(results_dict, self._decode_dir)
106 |
107 |
108 | def beam_search(self, batch):
109 | #batch should have only one example
110 | enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
111 | get_input_from_batch(batch, use_cuda)
112 |
113 | encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
114 | s_t_0 = self.model.reduce_state(encoder_hidden)
115 |
116 | dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
117 | dec_h = dec_h.squeeze()
118 | dec_c = dec_c.squeeze()
119 |
120 | # decoder batch preparation: it has beam_size examples, and initially everything is repeated
121 | beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
122 | log_probs=[0.0],
123 | state=(dec_h[0], dec_c[0]),
124 | context = c_t_0[0],
125 | coverage=(coverage_t_0[0] if config.is_coverage else None))
126 | for _ in xrange(config.beam_size)]
127 | results = []
128 | steps = 0
129 | while steps < config.max_dec_steps and len(results) < config.beam_size:
130 | latest_tokens = [h.latest_token for h in beams]
131 | latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
132 | for t in latest_tokens]
133 | y_t_1 = Variable(torch.LongTensor(latest_tokens))
134 | if use_cuda:
135 | y_t_1 = y_t_1.cuda()
136 | all_state_h =[]
137 | all_state_c = []
138 |
139 | all_context = []
140 |
141 | for h in beams:
142 | state_h, state_c = h.state
143 | all_state_h.append(state_h)
144 | all_state_c.append(state_c)
145 |
146 | all_context.append(h.context)
147 |
148 | s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
149 | c_t_1 = torch.stack(all_context, 0)
150 |
151 | coverage_t_1 = None
152 | if config.is_coverage:
153 | all_coverage = []
154 | for h in beams:
155 | all_coverage.append(h.coverage)
156 | coverage_t_1 = torch.stack(all_coverage, 0)
157 |
158 | final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
159 | encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
160 | extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)
161 | log_probs = torch.log(final_dist)
162 | topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)
163 |
164 | dec_h, dec_c = s_t
165 | dec_h = dec_h.squeeze()
166 | dec_c = dec_c.squeeze()
167 |
168 | all_beams = []
169 | num_orig_beams = 1 if steps == 0 else len(beams)
170 | for i in xrange(num_orig_beams):
171 | h = beams[i]
172 | state_i = (dec_h[i], dec_c[i])
173 | context_i = c_t[i]
174 | coverage_i = (coverage_t[i] if config.is_coverage else None)
175 |
176 | for j in xrange(config.beam_size * 2): # for each of the top 2*beam_size hyps:
177 | new_beam = h.extend(token=topk_ids[i, j].item(),
178 | log_prob=topk_log_probs[i, j].item(),
179 | state=state_i,
180 | context=context_i,
181 | coverage=coverage_i)
182 | all_beams.append(new_beam)
183 |
184 | beams = []
185 | for h in self.sort_beams(all_beams):
186 | if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
187 | if steps >= config.min_dec_steps:
188 | results.append(h)
189 | else:
190 | beams.append(h)
191 | if len(beams) == config.beam_size or len(results) == config.beam_size:
192 | break
193 |
194 | steps += 1
195 |
196 | if len(results) == 0:
197 | results = beams
198 |
199 | beams_sorted = self.sort_beams(results)
200 |
201 | return beams_sorted[0]
202 |
203 | if __name__ == '__main__':
204 | model_filename = sys.argv[1]
205 | beam_Search_processor = BeamSearch(model_filename)
206 | beam_Search_processor.decode()
207 |
208 |
209 |
--------------------------------------------------------------------------------
/training_ptr_gen/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import os
4 | import time
5 | import sys
6 |
7 | import tensorflow as tf
8 | import torch
9 |
10 | from data_util import config
11 | from data_util.batcher import Batcher
12 | from data_util.data import Vocab
13 |
14 | from data_util.utils import calc_running_avg_loss
15 | from train_util import get_input_from_batch, get_output_from_batch
16 | from model import Model
17 |
18 | use_cuda = config.use_gpu and torch.cuda.is_available()
19 |
20 | class Evaluate(object):
21 | def __init__(self, model_file_path):
22 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
23 | self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
24 | batch_size=config.batch_size, single_pass=True)
25 | time.sleep(15)
26 | model_name = os.path.basename(model_file_path)
27 |
28 | eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
29 | if not os.path.exists(eval_dir):
30 | os.mkdir(eval_dir)
31 | self.summary_writer = tf.summary.FileWriter(eval_dir)
32 |
33 | self.model = Model(model_file_path, is_eval=True)
34 |
35 | def eval_one_batch(self, batch):
36 | enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
37 | get_input_from_batch(batch, use_cuda)
38 | dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
39 | get_output_from_batch(batch, use_cuda)
40 |
41 | encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
42 | s_t_1 = self.model.reduce_state(encoder_hidden)
43 |
44 | step_losses = []
45 | for di in range(min(max_dec_len, config.max_dec_steps)):
46 | y_t_1 = dec_batch[:, di] # Teacher forcing
47 | final_dist, s_t_1, c_t_1,attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
48 | encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
49 | extra_zeros, enc_batch_extend_vocab, coverage, di)
50 | target = target_batch[:, di]
51 | gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
52 | step_loss = -torch.log(gold_probs + config.eps)
53 | if config.is_coverage:
54 | step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
55 | step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
56 | coverage = next_coverage
57 |
58 | step_mask = dec_padding_mask[:, di]
59 | step_loss = step_loss * step_mask
60 | step_losses.append(step_loss)
61 |
62 | sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
63 | batch_avg_loss = sum_step_losses / dec_lens_var
64 | loss = torch.mean(batch_avg_loss)
65 |
66 | return loss.data[0]
67 |
68 | def run_eval(self):
69 | running_avg_loss, iter = 0, 0
70 | start = time.time()
71 | batch = self.batcher.next_batch()
72 | while batch is not None:
73 | loss = self.eval_one_batch(batch)
74 |
75 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
76 | iter += 1
77 |
78 | if iter % 100 == 0:
79 | self.summary_writer.flush()
80 | print_interval = 1000
81 | if iter % print_interval == 0:
82 | print('steps %d, seconds for %d batch: %.2f , loss: %f' % (
83 | iter, print_interval, time.time() - start, running_avg_loss))
84 | start = time.time()
85 | batch = self.batcher.next_batch()
86 |
87 |
88 | if __name__ == '__main__':
89 | model_filename = sys.argv[1]
90 | eval_processor = Evaluate(model_filename)
91 | eval_processor.run_eval()
92 |
93 |
94 |
--------------------------------------------------------------------------------
/training_ptr_gen/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 | from data_util import config
8 | from numpy import random
9 |
10 | use_cuda = config.use_gpu and torch.cuda.is_available()
11 |
12 | random.seed(123)
13 | torch.manual_seed(123)
14 | if torch.cuda.is_available():
15 | torch.cuda.manual_seed_all(123)
16 |
17 | def init_lstm_wt(lstm):
18 | for names in lstm._all_weights:
19 | for name in names:
20 | if name.startswith('weight_'):
21 | wt = getattr(lstm, name)
22 | wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
23 | elif name.startswith('bias_'):
24 | # set forget bias to 1
25 | bias = getattr(lstm, name)
26 | n = bias.size(0)
27 | start, end = n // 4, n // 2
28 | bias.data.fill_(0.)
29 | bias.data[start:end].fill_(1.)
30 |
31 | def init_linear_wt(linear):
32 | linear.weight.data.normal_(std=config.trunc_norm_init_std)
33 | if linear.bias is not None:
34 | linear.bias.data.normal_(std=config.trunc_norm_init_std)
35 |
36 | def init_wt_normal(wt):
37 | wt.data.normal_(std=config.trunc_norm_init_std)
38 |
39 | def init_wt_unif(wt):
40 | wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
41 |
42 | class Encoder(nn.Module):
43 | def __init__(self):
44 | super(Encoder, self).__init__()
45 | self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
46 | init_wt_normal(self.embedding.weight)
47 |
48 | self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
49 | init_lstm_wt(self.lstm)
50 |
51 | self.W_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)
52 |
53 | #seq_lens should be in descending order
54 | def forward(self, input, seq_lens):
55 | embedded = self.embedding(input)
56 |
57 | packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
58 | output, hidden = self.lstm(packed)
59 |
60 | encoder_outputs, _ = pad_packed_sequence(output, batch_first=True) # h dim = B x t_k x n
61 | encoder_outputs = encoder_outputs.contiguous()
62 |
63 | encoder_feature = encoder_outputs.view(-1, 2*config.hidden_dim) # B * t_k x 2*hidden_dim
64 | encoder_feature = self.W_h(encoder_feature)
65 |
66 | return encoder_outputs, encoder_feature, hidden
67 |
68 | class ReduceState(nn.Module):
69 | def __init__(self):
70 | super(ReduceState, self).__init__()
71 |
72 | self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
73 | init_linear_wt(self.reduce_h)
74 | self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
75 | init_linear_wt(self.reduce_c)
76 |
77 | def forward(self, hidden):
78 | h, c = hidden # h, c dim = 2 x b x hidden_dim
79 | h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
80 | hidden_reduced_h = F.relu(self.reduce_h(h_in))
81 | c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
82 | hidden_reduced_c = F.relu(self.reduce_c(c_in))
83 |
84 | return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0)) # h, c dim = 1 x b x hidden_dim
85 |
86 | class Attention(nn.Module):
87 | def __init__(self):
88 | super(Attention, self).__init__()
89 | # attention
90 | if config.is_coverage:
91 | self.W_c = nn.Linear(1, config.hidden_dim * 2, bias=False)
92 | self.decode_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
93 | self.v = nn.Linear(config.hidden_dim * 2, 1, bias=False)
94 |
95 | def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
96 | b, t_k, n = list(encoder_outputs.size())
97 |
98 | dec_fea = self.decode_proj(s_t_hat) # B x 2*hidden_dim
99 | dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous() # B x t_k x 2*hidden_dim
100 | dec_fea_expanded = dec_fea_expanded.view(-1, n) # B * t_k x 2*hidden_dim
101 |
102 | att_features = encoder_feature + dec_fea_expanded # B * t_k x 2*hidden_dim
103 | if config.is_coverage:
104 | coverage_input = coverage.view(-1, 1) # B * t_k x 1
105 | coverage_feature = self.W_c(coverage_input) # B * t_k x 2*hidden_dim
106 | att_features = att_features + coverage_feature
107 |
108 |         e = torch.tanh(att_features) # B * t_k x 2*hidden_dim
109 | scores = self.v(e) # B * t_k x 1
110 | scores = scores.view(-1, t_k) # B x t_k
111 |
112 | attn_dist_ = F.softmax(scores, dim=1)*enc_padding_mask # B x t_k
113 | normalization_factor = attn_dist_.sum(1, keepdim=True)
114 | attn_dist = attn_dist_ / normalization_factor
115 |
116 | attn_dist = attn_dist.unsqueeze(1) # B x 1 x t_k
117 | c_t = torch.bmm(attn_dist, encoder_outputs) # B x 1 x n
118 | c_t = c_t.view(-1, config.hidden_dim * 2) # B x 2*hidden_dim
119 |
120 | attn_dist = attn_dist.view(-1, t_k) # B x t_k
121 |
122 | if config.is_coverage:
123 | coverage = coverage.view(-1, t_k)
124 | coverage = coverage + attn_dist
125 |
126 | return c_t, attn_dist, coverage
127 |
128 | class Decoder(nn.Module):
129 | def __init__(self):
130 | super(Decoder, self).__init__()
131 | self.attention_network = Attention()
132 | # decoder
133 | self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
134 | init_wt_normal(self.embedding.weight)
135 |
136 | self.x_context = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)
137 |
138 | self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=False)
139 | init_lstm_wt(self.lstm)
140 |
141 | if config.pointer_gen:
142 | self.p_gen_linear = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
143 |
144 | #p_vocab
145 | self.out1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
146 | self.out2 = nn.Linear(config.hidden_dim, config.vocab_size)
147 | init_linear_wt(self.out2)
148 |
149 | def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
150 | c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, step):
151 |
152 | if not self.training and step == 0:
153 | h_decoder, c_decoder = s_t_1
154 | s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
155 | c_decoder.view(-1, config.hidden_dim)), 1) # B x 2*hidden_dim
156 | c_t, _, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,
157 | enc_padding_mask, coverage)
158 | coverage = coverage_next
159 |
160 | y_t_1_embd = self.embedding(y_t_1)
161 | x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
162 | lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)
163 |
164 | h_decoder, c_decoder = s_t
165 | s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
166 | c_decoder.view(-1, config.hidden_dim)), 1) # B x 2*hidden_dim
167 | c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,
168 | enc_padding_mask, coverage)
169 |
170 | if self.training or step > 0:
171 | coverage = coverage_next
172 |
173 | p_gen = None
174 | if config.pointer_gen:
175 | p_gen_input = torch.cat((c_t, s_t_hat, x), 1) # B x (2*2*hidden_dim + emb_dim)
176 | p_gen = self.p_gen_linear(p_gen_input)
177 |             p_gen = torch.sigmoid(p_gen)
178 |
179 | output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1) # B x hidden_dim * 3
180 | output = self.out1(output) # B x hidden_dim
181 |
182 | #output = F.relu(output)
183 |
184 | output = self.out2(output) # B x vocab_size
185 | vocab_dist = F.softmax(output, dim=1)
186 |
187 | if config.pointer_gen:
188 | vocab_dist_ = p_gen * vocab_dist
189 | attn_dist_ = (1 - p_gen) * attn_dist
190 |
191 | if extra_zeros is not None:
192 | vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
193 |
194 | final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
195 | else:
196 | final_dist = vocab_dist
197 |
198 | return final_dist, s_t, c_t, attn_dist, p_gen, coverage
199 |
200 | class Model(object):
201 | def __init__(self, model_file_path=None, is_eval=False):
202 | encoder = Encoder()
203 | decoder = Decoder()
204 | reduce_state = ReduceState()
205 |
206 |         # share the embedding between the encoder and decoder
207 | decoder.embedding.weight = encoder.embedding.weight
208 | if is_eval:
209 | encoder = encoder.eval()
210 | decoder = decoder.eval()
211 | reduce_state = reduce_state.eval()
212 |
213 | if use_cuda:
214 | encoder = encoder.cuda()
215 | decoder = decoder.cuda()
216 | reduce_state = reduce_state.cuda()
217 |
218 | self.encoder = encoder
219 | self.decoder = decoder
220 | self.reduce_state = reduce_state
221 |
222 | if model_file_path is not None:
223 |             state = torch.load(model_file_path, map_location=lambda storage, location: storage)
224 | self.encoder.load_state_dict(state['encoder_state_dict'])
225 | self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
226 | self.reduce_state.load_state_dict(state['reduce_state_dict'])
227 |
--------------------------------------------------------------------------------
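Note: the pointer-generator mixing at the end of Decoder.forward above can be seen in isolation with toy numbers. In this made-up example the fixed vocabulary has size 5, the source has two tokens whose extended-vocabulary ids are 1 and 6 (id 6 is an in-article OOV appended after the fixed vocabulary), and p_gen gates between generating and copying:

import torch

p_gen = torch.tensor([[0.7]])                                # generation gate for one example
vocab_dist = torch.tensor([[0.1, 0.2, 0.3, 0.25, 0.15]])     # distribution over the fixed vocab
attn_dist = torch.tensor([[0.6, 0.4]])                       # attention over the 2 source positions
enc_batch_extend_vocab = torch.tensor([[1, 6]])              # extended-vocab id of each source token
extra_zeros = torch.zeros(1, 2)                              # room for the article's OOV ids (5 and 6)

vocab_dist_ = p_gen * vocab_dist
attn_dist_ = (1 - p_gen) * attn_dist
vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)       # extend to size 5 + 2 = 7
final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
print(final_dist, final_dist.sum())                          # still a proper distribution (sums to 1)

Copy probability mass lands on whatever extended id each source token has, so words outside the fixed vocabulary (here id 6) can still be produced.
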
/training_ptr_gen/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import os
4 | import time
5 | import argparse
6 |
7 | import tensorflow as tf
8 | import torch
9 | from model import Model
10 | from torch.nn.utils import clip_grad_norm_
11 |
12 | from torch.optim import Adagrad
13 |
14 | from data_util import config
15 | from data_util.batcher import Batcher
16 | from data_util.data import Vocab
17 | from data_util.utils import calc_running_avg_loss
18 | from train_util import get_input_from_batch, get_output_from_batch
19 |
20 | use_cuda = config.use_gpu and torch.cuda.is_available()
21 |
22 | class Train(object):
23 | def __init__(self):
24 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
25 | self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
26 | batch_size=config.batch_size, single_pass=False)
27 | time.sleep(15)
28 |
29 | train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
30 | if not os.path.exists(train_dir):
31 | os.mkdir(train_dir)
32 |
33 | self.model_dir = os.path.join(train_dir, 'model')
34 | if not os.path.exists(self.model_dir):
35 | os.mkdir(self.model_dir)
36 |
37 | self.summary_writer = tf.summary.FileWriter(train_dir)
38 |
39 | def save_model(self, running_avg_loss, iter):
40 | state = {
41 | 'iter': iter,
42 | 'encoder_state_dict': self.model.encoder.state_dict(),
43 | 'decoder_state_dict': self.model.decoder.state_dict(),
44 | 'reduce_state_dict': self.model.reduce_state.state_dict(),
45 | 'optimizer': self.optimizer.state_dict(),
46 | 'current_loss': running_avg_loss
47 | }
48 | model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
49 | torch.save(state, model_save_path)
50 |
51 | def setup_train(self, model_file_path=None):
52 | self.model = Model(model_file_path)
53 |
54 | params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
55 | list(self.model.reduce_state.parameters())
56 | initial_lr = config.lr_coverage if config.is_coverage else config.lr
57 | self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
58 |
59 | start_iter, start_loss = 0, 0
60 |
61 | if model_file_path is not None:
62 |             state = torch.load(model_file_path, map_location=lambda storage, location: storage)
63 | start_iter = state['iter']
64 | start_loss = state['current_loss']
65 |
66 | if not config.is_coverage:
67 | self.optimizer.load_state_dict(state['optimizer'])
68 | if use_cuda:
69 | for state in self.optimizer.state.values():
70 | for k, v in state.items():
71 | if torch.is_tensor(v):
72 | state[k] = v.cuda()
73 |
74 | return start_iter, start_loss
75 |
76 | def train_one_batch(self, batch):
77 | enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
78 | get_input_from_batch(batch, use_cuda)
79 | dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
80 | get_output_from_batch(batch, use_cuda)
81 |
82 | self.optimizer.zero_grad()
83 |
84 | encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
85 | s_t_1 = self.model.reduce_state(encoder_hidden)
86 |
87 | step_losses = []
88 | for di in range(min(max_dec_len, config.max_dec_steps)):
89 | y_t_1 = dec_batch[:, di] # Teacher forcing
90 | final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
91 | encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
92 | extra_zeros, enc_batch_extend_vocab,
93 | coverage, di)
94 | target = target_batch[:, di]
95 |             gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
96 | step_loss = -torch.log(gold_probs + config.eps)
97 | if config.is_coverage:
98 | step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
99 | step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
100 | coverage = next_coverage
101 |
102 | step_mask = dec_padding_mask[:, di]
103 | step_loss = step_loss * step_mask
104 | step_losses.append(step_loss)
105 |
106 | sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
107 | batch_avg_loss = sum_losses/dec_lens_var
108 | loss = torch.mean(batch_avg_loss)
109 |
110 | loss.backward()
111 |
112 | self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
113 | clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
114 | clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)
115 |
116 | self.optimizer.step()
117 |
118 | return loss.item()
119 |
120 | def trainIters(self, n_iters, model_file_path=None):
121 | iter, running_avg_loss = self.setup_train(model_file_path)
122 | start = time.time()
123 | while iter < n_iters:
124 | batch = self.batcher.next_batch()
125 | loss = self.train_one_batch(batch)
126 |
127 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
128 | iter += 1
129 |
130 | if iter % 100 == 0:
131 | self.summary_writer.flush()
132 | print_interval = 1000
133 | if iter % print_interval == 0:
134 |                 print('steps %d, seconds for %d batches: %.2f, loss: %f' % (iter, print_interval,
135 | time.time() - start, loss))
136 | start = time.time()
137 | if iter % 5000 == 0:
138 | self.save_model(running_avg_loss, iter)
139 |
140 | if __name__ == '__main__':
141 | parser = argparse.ArgumentParser(description="Train script")
142 | parser.add_argument("-m",
143 | dest="model_file_path",
144 | required=False,
145 | default=None,
146 | help="Model file for retraining (default: None).")
147 | args = parser.parse_args()
148 |
149 | train_processor = Train()
150 | train_processor.trainIters(config.max_iterations, args.model_file_path)
151 |
--------------------------------------------------------------------------------
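Note: when config.is_coverage is enabled, both the training and evaluation loops above add a coverage penalty to step_loss: the element-wise minimum of the current attention distribution and the accumulated coverage vector, summed over source positions and scaled by config.cov_loss_wt. A toy illustration with made-up values:

import torch

attn_dist = torch.tensor([[0.6, 0.3, 0.1]])   # attention at the current decoding step
coverage  = torch.tensor([[0.9, 0.1, 0.0]])   # attention accumulated over previous steps

# overlap between where the decoder attends now and where it has already attended
step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
print(step_coverage_loss)      # tensor([0.7000]) = 0.6 + 0.1 + 0.0
print(coverage + attn_dist)    # updated coverage carried into the next step

The penalty is largest when the decoder keeps re-attending to positions it has already covered, which is what discourages repeated phrases in the output.
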
/training_ptr_gen/train_util.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | import torch
4 | from data_util import config
5 |
6 | def get_input_from_batch(batch, use_cuda):
7 | batch_size = len(batch.enc_lens)
8 |
9 | enc_batch = Variable(torch.from_numpy(batch.enc_batch).long())
10 | enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float()
11 | enc_lens = batch.enc_lens
12 | extra_zeros = None
13 | enc_batch_extend_vocab = None
14 |
15 | if config.pointer_gen:
16 | enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long())
17 | # max_art_oovs is the max over all the article oov list in the batch
18 | if batch.max_art_oovs > 0:
19 | extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs)))
20 |
21 | c_t_1 = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
22 |
23 | coverage = None
24 | if config.is_coverage:
25 | coverage = Variable(torch.zeros(enc_batch.size()))
26 |
27 | if use_cuda:
28 | enc_batch = enc_batch.cuda()
29 | enc_padding_mask = enc_padding_mask.cuda()
30 |
31 | if enc_batch_extend_vocab is not None:
32 | enc_batch_extend_vocab = enc_batch_extend_vocab.cuda()
33 | if extra_zeros is not None:
34 | extra_zeros = extra_zeros.cuda()
35 | c_t_1 = c_t_1.cuda()
36 |
37 | if coverage is not None:
38 | coverage = coverage.cuda()
39 |
40 | return enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage
41 |
42 | def get_output_from_batch(batch, use_cuda):
43 | dec_batch = Variable(torch.from_numpy(batch.dec_batch).long())
44 | dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float()
45 | dec_lens = batch.dec_lens
46 | max_dec_len = np.max(dec_lens)
47 | dec_lens_var = Variable(torch.from_numpy(dec_lens)).float()
48 |
49 | target_batch = Variable(torch.from_numpy(batch.target_batch)).long()
50 |
51 | if use_cuda:
52 | dec_batch = dec_batch.cuda()
53 | dec_padding_mask = dec_padding_mask.cuda()
54 | dec_lens_var = dec_lens_var.cuda()
55 | target_batch = target_batch.cuda()
56 |
57 |
58 | return dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch
59 |
60 |
--------------------------------------------------------------------------------
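Note: train_util expects a Batch object exposing numpy arrays under the attribute names used above. A hypothetical stand-in (field names come from the code above; the values and shapes are made up, and it assumes data_util is importable and training_ptr_gen is the working directory):

import numpy as np
from types import SimpleNamespace
from train_util import get_input_from_batch, get_output_from_batch

batch = SimpleNamespace(
    enc_lens=np.array([4, 3]),
    enc_batch=np.zeros((2, 4), dtype=np.int64),
    enc_padding_mask=np.array([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=np.float32),
    enc_batch_extend_vocab=np.zeros((2, 4), dtype=np.int64),
    max_art_oovs=0,
    dec_batch=np.zeros((2, 5), dtype=np.int64),
    dec_padding_mask=np.ones((2, 5), dtype=np.float32),
    dec_lens=np.array([5, 4]),
    target_batch=np.zeros((2, 5), dtype=np.int64),
)

enc_batch, enc_padding_mask, enc_lens, enc_ext, extra_zeros, c_t_1, coverage = \
    get_input_from_batch(batch, use_cuda=False)
dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
    get_output_from_batch(batch, use_cuda=False)
print(c_t_1.shape, max_dec_len)   # initial context vector (B x 2*hidden_dim) and longest target length
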
/training_ptr_gen/transformer_encoder.py:
--------------------------------------------------------------------------------
1 | # This is still a work in progress; I will continue working on it once I get some free time.
2 |
3 | from __future__ import unicode_literals, print_function, division
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import logging
9 | import math
10 |
11 | logging.basicConfig(level=logging.INFO)
12 |
13 | class PositionalEncoding(nn.Module):
14 | def __init__(self, d_model, dropout, max_len=5000):
15 | super(PositionalEncoding, self).__init__()
16 | self.dropout = nn.Dropout(p=dropout)
17 |
18 | pe = torch.zeros(max_len, d_model)
19 | position = torch.arange(0, max_len).float().unsqueeze(1)
20 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
21 | -(math.log(10000.0) / d_model))
22 | pe[:, 0::2] = torch.sin(position * div_term)
23 | pe[:, 1::2] = torch.cos(position * div_term)
24 | pe = pe.unsqueeze(0)
25 | self.register_buffer('pe', pe)
26 |
27 | def forward(self, x):
28 | x = x + self.pe[:, :x.size(1)]
29 | return self.dropout(x)
30 |
31 | class MultiHeadedAttention(nn.Module):
32 | def __init__(self, num_head, d_model, dropout=0.1):
33 | super(MultiHeadedAttention, self).__init__()
34 | assert d_model % num_head == 0
35 | self.d_k = d_model // num_head #d_k == d_v
36 | self.h = num_head
37 |
38 | self.linear_key = nn.Linear(d_model, d_model)
39 | self.linear_value = nn.Linear(d_model, d_model)
40 | self.linear_query = nn.Linear(d_model, d_model)
41 | self.linear_out = nn.Linear(d_model, d_model)
42 |
43 | self.dropout = nn.Dropout(p=dropout)
44 |
45 | def attention(self, query, key, value, mask, dropout=None):
46 | d_k = query.size(-1)
47 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
48 | scores = scores.masked_fill(mask == 0, -1e9)
49 |
50 | p_attn = F.softmax(scores, dim=-1)
51 | if dropout is not None:
52 | p_attn = dropout(p_attn)
53 | return torch.matmul(p_attn, value), p_attn
54 |
55 | def forward(self, query, key, value, mask):
56 | nbatches = query.size(0)
57 | query = self.linear_query(query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
58 | key = self.linear_key(key).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
59 | value = self.linear_value(value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
60 |
61 | mask = mask.unsqueeze(1)
62 | x, attn = self.attention(query, key, value, mask, dropout=self.dropout)
63 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
64 | return self.linear_out(x)
65 |
66 | class AffineLayer(nn.Module):
67 | def __init__(self, dropout, d_model, d_ff):
68 | super(AffineLayer, self).__init__()
69 | self.w_1 = nn.Linear(d_model, d_ff)
70 | self.w_2 = nn.Linear(d_ff, d_model)
71 | self.dropout = nn.Dropout(dropout)
72 |
73 | def forward(self, x):
74 | return self.w_2(self.dropout(F.relu(self.w_1(x))))
75 |
76 | class EncoderLayer(nn.Module):
77 | def __init__(self, num_head, dropout, d_model, d_ff):
78 | super(EncoderLayer, self).__init__()
79 |
80 | self.att_layer = MultiHeadedAttention(num_head, d_model, dropout)
81 | self.norm_att = nn.LayerNorm(d_model)
82 | self.dropout_att = nn.Dropout(dropout)
83 |
84 | self.affine_layer = AffineLayer(dropout, d_model, d_ff)
85 | self.norm_affine = nn.LayerNorm(d_model)
86 | self.dropout_affine = nn.Dropout(dropout)
87 |
88 | def forward(self, x, mask):
89 | x_att = self.norm_att(x*mask)
90 | x_att = self.att_layer(x_att, x_att, x_att, mask)
91 | x = x + self.dropout_att(x_att)
92 |
93 | x_affine = self.norm_affine(x*mask)
94 | x_affine = self.affine_layer(x_affine)
95 | return x + self.dropout_affine(x_affine)
96 |
97 | class Encoder(nn.Module):
98 | def __init__(self, N, num_head, dropout, d_model, d_ff):
99 | super(Encoder, self).__init__()
100 | self.position = PositionalEncoding(d_model, dropout)
101 | self.layers = nn.ModuleList()
102 | for _ in range(N):
103 | self.layers.append(EncoderLayer(num_head, dropout, d_model, d_ff))
104 | self.norm = nn.LayerNorm(d_model)
105 |
106 | def forward(self, word_embed, mask):
107 | x = self.position(word_embed)
108 | for layer in self.layers:
109 | x = layer(x, mask)
110 | return self.norm(x*mask)
111 |
--------------------------------------------------------------------------------
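Note: transformer_encoder.py is marked as work in progress and is not imported by the training scripts above. A hypothetical forward pass through its Encoder, assuming pre-computed word embeddings and a padding mask of shape (batch, seq_len, 1) so that both the x*mask products and the masked_fill in MultiHeadedAttention broadcast cleanly:

import torch
from transformer_encoder import Encoder   # assumes training_ptr_gen is the working directory

batch_size, seq_len, d_model = 2, 7, 16
encoder = Encoder(N=2, num_head=4, dropout=0.1, d_model=d_model, d_ff=32)
encoder.eval()                                           # disable dropout for a deterministic check

word_embed = torch.randn(batch_size, seq_len, d_model)   # stand-in for embedded source tokens
mask = torch.ones(batch_size, seq_len, 1)                # 1 = real token, 0 = padding
mask[1, 5:] = 0.                                         # second example has only 5 real tokens

out = encoder(word_embed, mask)
print(out.shape)   # torch.Size([2, 7, 16])
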