├── README.md
├── dataset
│   └── url_lists
│       ├── all_test.txt
│       ├── all_train.txt
│       ├── all_val.txt
│       ├── cnn_wayback_test_urls.txt
│       ├── cnn_wayback_training_urls.txt
│       ├── cnn_wayback_validation_urls.txt
│       ├── dailymail_wayback_test_urls.txt
│       ├── dailymail_wayback_training_urls.txt
│       ├── dailymail_wayback_validation_urls.txt
│       └── readme
├── eval.py
├── make_datafiles.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── attention.cpython-35.pyc
│   │   ├── basic.cpython-35.pyc
│   │   ├── layers.cpython-35.pyc
│   │   └── model.cpython-35.pyc
│   ├── attention.py
│   ├── basic.py
│   ├── layers.py
│   └── model.py
├── requirements.txt
├── test.py
├── train.py
├── transformer
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── layers.cpython-35.pyc
│   │   ├── model.cpython-35.pyc
│   │   ├── optim.cpython-35.pyc
│   │   └── sublayers.cpython-35.pyc
│   ├── layers.py
│   ├── model.py
│   ├── optim.py
│   ├── sublayers.py
│   └── tran_train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── config.cpython-35.pyc
    │   ├── dataset.cpython-35.pyc
    │   └── utils.cpython-35.pyc
    ├── config.py
    ├── dataset.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Pointer-Generator-Pytorch
2 |
3 | ## About
4 | A PyTorch implementation of [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368).
5 | This implementation can also use a Transformer as the encoder.
6 | The project borrows heavily from [atulkum-pointer_summarizer](https://github.com/atulkum/pointer_summarizer.git) and
7 | [jadore801120-attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch).
8 |
9 | ## Requirements
10 | * python==3.7.4
11 | * pytorch==1.4.0
12 | * pyrouge==0.1.3
13 | * tensorflow>=1.13.1
14 |
15 | ## Quick start
16 | * Project paths and parameters:
17 | you may need to change some paths and parameters in utils/config.py according to your setup.
18 | * Dataset:
19 | you can download the CNN/DailyMail dataset from https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail,
20 | then run make_datafiles.py to process the data. For the detailed steps, see https://github.com/abisee/cnn-dailymail.
21 | * Run:
22 | run train.py, eval.py, and test.py for training, evaluation, and testing, respectively (see the example commands below).
23 |
24 | ### Note:
25 | * In the decode mode of beam search, only a single example is repeated across the batch.
26 |
27 |
28 |
--------------------------------------------------------------------------------
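A minimal invocation sketch for the three entry points above, assuming the default paths in utils/config.py; the checkpoint path below is hypothetical (train.py writes checkpoints under config.log_root/train_<timestamp>/models/):

```bash
# Train from scratch, or resume from a checkpoint with -m.
python train.py
python train.py -m log/train_1234567890/models/model_5000_1234569999

# eval.py and test.py take the checkpoint path as their single positional argument.
python eval.py log/train_1234567890/models/model_5000_1234569999
python test.py log/train_1234567890/models/model_5000_1234569999
```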
/dataset/url_lists/readme:
--------------------------------------------------------------------------------
1 | Note: all_train.txt is simply the concatenation of cnn_wayback_training_urls.txt and dailymail_wayback_training_urls.txt.
2 |
3 | Similarly for validation and test sets.
4 |
--------------------------------------------------------------------------------
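The concatenation the note describes can be reproduced directly from this directory; the CNN-first ordering is an assumption:

```bash
cat cnn_wayback_training_urls.txt dailymail_wayback_training_urls.txt > all_train.txt
cat cnn_wayback_validation_urls.txt dailymail_wayback_validation_urls.txt > all_val.txt
cat cnn_wayback_test_urls.txt dailymail_wayback_test_urls.txt > all_test.txt
```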
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import unicode_literals, print_function, division
4 |
5 | import os
6 | import time
7 | import sys
8 |
9 | import tensorflow as tf
10 | import torch
11 |
12 | from models.model import Model
13 | from utils import config
14 | from utils.dataset import Vocab
15 | from utils.dataset import Batcher
16 | from utils.utils import get_input_from_batch
17 | from utils.utils import get_output_from_batch
18 | from utils.utils import calc_running_avg_loss
19 |
20 | use_cuda = config.use_gpu and torch.cuda.is_available()
21 |
22 | class Evaluate(object):
23 | def __init__(self, model_path):
24 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
25 | self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
26 | batch_size=config.batch_size, single_pass=True)
27 | time.sleep(15)
28 | model_name = os.path.basename(model_path)
29 |
30 | eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
31 | if not os.path.exists(eval_dir):
32 | os.mkdir(eval_dir)
33 | self.summary_writer = tf.summary.FileWriter(eval_dir)
34 |
35 | self.model = Model(model_path, is_eval=True)
36 |
37 | def eval_one_batch(self, batch):
38 | enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
39 | get_input_from_batch(batch, use_cuda)
40 | dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, tgt_batch = \
41 | get_output_from_batch(batch, use_cuda)
42 |
43 | enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
44 | s_t = self.model.reduce_state(enc_h)
45 |
46 | step_losses = []
47 | for di in range(min(max_dec_len, config.max_dec_steps)):
48 | y_t = dec_batch[:, di] # Teacher forcing
49 |             final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = self.model.decoder(
50 |                 y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t,
51 |                 extra_zeros, enc_batch_extend_vocab, coverage, di)
52 | tgt = tgt_batch[:, di]
53 | gold_probs = torch.gather(final_dist, 1, tgt.unsqueeze(1)).squeeze()
54 | step_loss = -torch.log(gold_probs + config.eps)
55 | if config.is_coverage:
56 | step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
57 | step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
58 | coverage = next_coverage
59 |
60 | step_mask = dec_padding_mask[:, di]
61 | step_loss = step_loss * step_mask
62 | step_losses.append(step_loss)
63 |
64 | sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
65 | batch_avg_loss = sum_step_losses / dec_lens_var
66 | loss = torch.mean(batch_avg_loss)
67 |
68 |         return loss.item()  # .data[0] fails on 0-dim tensors in modern PyTorch
69 |
70 | def run(self):
71 | start = time.time()
72 | running_avg_loss, iter = 0, 0
73 | batch = self.batcher.next_batch()
74 | print_interval = 100
75 | while batch is not None:
76 | loss = self.eval_one_batch(batch)
77 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
78 | iter += 1
79 |
80 | if iter % print_interval == 0:
81 | self.summary_writer.flush()
82 |                 print('step: %d, seconds: %.2f, loss: %f' % (iter, time.time() - start, running_avg_loss))
83 | start = time.time()
84 | batch = self.batcher.next_batch()
85 |
86 | return running_avg_loss
87 |
88 |
89 | if __name__ == '__main__':
90 | model_filename = sys.argv[1]
91 | eval_processor = Evaluate(model_filename)
92 | eval_processor.run()
93 |
94 |
95 |
--------------------------------------------------------------------------------
/make_datafiles.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # This code is from https://github.com/abisee/cnn-dailymail.git
4 |
5 | import sys
6 | import os
7 | import hashlib
8 | import struct
9 | import subprocess
10 | import collections
11 | from tensorflow.core.example import example_pb2
12 |
13 | dm_single_close_quote = u'\u2019' # unicode
14 | dm_double_close_quote = u'\u201d'
15 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote,
16 | ")"] # acceptable ways to end a sentence
17 |
18 | # We use these to separate the summary sentences in the .bin datafiles
19 | SENTENCE_START = '<s>'
20 | SENTENCE_END = '</s>'
21 |
22 | all_train_urls = "dataset/url_lists/all_train.txt"
23 | all_val_urls = "dataset/url_lists/all_val.txt"
24 | all_test_urls = "dataset/url_lists/all_test.txt"
25 |
26 | cnn_tokenized_stories_dir = "dataset/cnn_stories_tokenized"
27 | dm_tokenized_stories_dir = "dataset/dm_stories_tokenized"
28 | finished_files_dir = "dataset/finished_files"
29 | chunks_dir = os.path.join(finished_files_dir, "chunked")
30 |
31 | # These are the number of .story files we expect there to be in cnn_stories_dir and dm_stories_dir
32 | num_expected_cnn_stories = 92579
33 | num_expected_dm_stories = 219506
34 |
35 | VOCAB_SIZE = 200000
36 | CHUNK_SIZE = 1000 # num examples per chunk, for the chunked data
37 |
38 |
39 | def chunk_file(set_name):
40 | in_file = 'dataset/finished_files/%s.bin' % set_name
41 | reader = open(in_file, "rb")
42 | chunk = 0
43 | finished = False
44 | while not finished:
45 | chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (set_name, chunk)) # new chunk
46 | with open(chunk_fname, 'wb') as writer:
47 | for _ in range(CHUNK_SIZE):
48 | len_bytes = reader.read(8)
49 | if not len_bytes:
50 | finished = True
51 | break
52 | str_len = struct.unpack('q', len_bytes)[0]
53 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
54 | writer.write(struct.pack('q', str_len))
55 | writer.write(struct.pack('%ds' % str_len, example_str))
56 | chunk += 1
57 |
58 |
59 | def chunk_all():
60 | # Make a dir to hold the chunks
61 | if not os.path.isdir(chunks_dir):
62 | os.mkdir(chunks_dir)
63 | # Chunk the data
64 | for set_name in ['train', 'val', 'test']:
65 | print("Splitting %s data into chunks..." % set_name)
66 | chunk_file(set_name)
67 | print("Saved chunked data in %s" % chunks_dir)
68 |
69 |
70 | def tokenize_stories(stories_dir, tokenized_stories_dir):
71 | """Maps a whole directory of .story files to a tokenized version using Stanford CoreNLP Tokenizer"""
72 | print("Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir))
73 | stories = os.listdir(stories_dir)
74 | # make IO list file
75 | print("Making list of files to tokenize...")
76 | with open("mapping.txt", "w") as f:
77 | for s in stories:
78 | f.write("%s \t %s\n" % (os.path.join(stories_dir, s), os.path.join(tokenized_stories_dir, s)))
79 | command = ['java', 'edu.stanford.nlp.process.PTBTokenizer', '-ioFileList', '-preserveLines', 'mapping.txt']
80 | print("Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir))
81 | subprocess.call(command)
82 | print("Stanford CoreNLP Tokenizer has finished.")
83 | os.remove("mapping.txt")
84 |
85 | # Check that the tokenized stories directory contains the same number of files as the original directory
86 | num_orig = len(os.listdir(stories_dir))
87 | num_tokenized = len(os.listdir(tokenized_stories_dir))
88 | if num_orig != num_tokenized:
89 | raise Exception(
90 | "The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" % (
91 | tokenized_stories_dir, num_tokenized, stories_dir, num_orig))
92 | print("Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir))
93 |
94 |
95 | def read_text_file(text_file):
96 | lines = []
97 | with open(text_file, "r") as f:
98 | for line in f:
99 | lines.append(line.strip())
100 | return lines
101 |
102 |
103 | def hashhex(s):
104 | """Returns a heximal formated SHA1 hash of the input string."""
105 | h = hashlib.sha1()
106 | h.update(s.encode('utf-8'))
107 | return h.hexdigest()
108 |
109 |
110 | def get_url_hashes(url_list):
111 | return [hashhex(url) for url in url_list]
112 |
113 |
114 | def fix_missing_period(line):
115 | """Adds a period to a line that is missing a period"""
116 | if "@highlight" in line: return line
117 | if line == "": return line
118 | if line[-1] in END_TOKENS: return line
119 | # print line[-1]
120 | return line + " ."
121 |
122 |
123 | def get_art_abs(story_file):
124 | lines = read_text_file(story_file)
125 |
126 | # Lowercase everything
127 | lines = [line.lower() for line in lines]
128 |
129 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences)
130 | lines = [fix_missing_period(line) for line in lines]
131 |
132 | # Separate out article and abstract sentences
133 | article_lines = []
134 | highlights = []
135 | next_is_highlight = False
136 | for idx, line in enumerate(lines):
137 | if line == "":
138 | continue # empty line
139 | elif line.startswith("@highlight"):
140 | next_is_highlight = True
141 | elif next_is_highlight:
142 | highlights.append(line)
143 | else:
144 | article_lines.append(line)
145 |
146 | # Make article into a single string
147 | article = ' '.join(article_lines)
148 |
149 |     # Make abstract into a single string, putting <s> and </s> tags around the sentences
150 | abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights])
151 |
152 | return article, abstract
153 |
154 |
155 | def write_to_bin(url_file, out_file, makevocab=False):
156 |     """Reads the tokenized .story files corresponding to the urls listed in url_file and writes them to out_file."""
157 | print("Making bin file for URLs listed in %s..." % url_file)
158 | url_list = read_text_file(url_file)
159 | url_hashes = get_url_hashes(url_list)
160 | story_fnames = [s + ".story" for s in url_hashes]
161 | num_stories = len(story_fnames)
162 |
163 | if makevocab:
164 | vocab_counter = collections.Counter()
165 |
166 | with open(out_file, 'wb') as writer:
167 | for idx, s in enumerate(story_fnames):
168 | if idx % 1000 == 0:
169 | print("Writing story %i of %i; %.2f percent done" % (
170 | idx, num_stories, float(idx) * 100.0 / float(num_stories)))
171 |
172 | # Look in the tokenized story dirs to find the .story file corresponding to this url
173 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)):
174 | story_file = os.path.join(cnn_tokenized_stories_dir, s)
175 | elif os.path.isfile(os.path.join(dm_tokenized_stories_dir, s)):
176 | story_file = os.path.join(dm_tokenized_stories_dir, s)
177 | else:
178 | print("Error: Couldn't find tokenized story file %s in either tokenized story directories %s and %s. Was there an error during tokenization?" % (
179 | s, cnn_tokenized_stories_dir, dm_tokenized_stories_dir))
180 | # Check again if tokenized stories directories contain correct number of files
181 | print("Checking that the tokenized stories directories %s and %s contain correct number of files..." % (
182 | cnn_tokenized_stories_dir, dm_tokenized_stories_dir))
183 | check_num_stories(cnn_tokenized_stories_dir, num_expected_cnn_stories)
184 | check_num_stories(dm_tokenized_stories_dir, num_expected_dm_stories)
185 | raise Exception(
186 | "Tokenized stories directories %s and %s contain correct number of files but story file %s found in neither." % (
187 | cnn_tokenized_stories_dir, dm_tokenized_stories_dir, s))
188 |
189 | # Get the strings to write to .bin file
190 | article, abstract = get_art_abs(story_file)
191 |
192 | # Write to tf.Example
193 | tf_example = example_pb2.Example()
194 | tf_example.features.feature['article'].bytes_list.value.extend([bytes(article, encoding='utf-8')])
195 | tf_example.features.feature['abstract'].bytes_list.value.extend([bytes(abstract, encoding='utf-8')])
196 | tf_example_str = tf_example.SerializeToString()
197 | str_len = len(tf_example_str)
198 | writer.write(struct.pack('q', str_len))
199 | writer.write(struct.pack('%ds' % str_len, tf_example_str))
200 |
201 | # Write the vocab to file, if applicable
202 | if makevocab:
203 | art_tokens = article.split(' ')
204 | abs_tokens = abstract.split(' ')
205 | abs_tokens = [t for t in abs_tokens if
206 | t not in [SENTENCE_START, SENTENCE_END]] # remove these tags from vocab
207 | tokens = art_tokens + abs_tokens
208 | tokens = [t.strip() for t in tokens] # strip
209 | tokens = [t for t in tokens if t != ""] # remove empty
210 | vocab_counter.update(tokens)
211 |
212 | print("Finished writing file %s\n" % out_file)
213 |
214 | # write vocab to file
215 | if makevocab:
216 |         print("Writing vocab file...")
217 | with open(os.path.join(finished_files_dir, "vocab"), 'w') as writer:
218 | for word, count in vocab_counter.most_common(VOCAB_SIZE):
219 | writer.write(word + ' ' + str(count) + '\n')
220 | print("Finished writing vocab file")
221 |
222 |
223 | def check_num_stories(stories_dir, num_expected):
224 | num_stories = len(os.listdir(stories_dir))
225 | if num_stories != num_expected:
226 | raise Exception(
227 | "stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected))
228 |
229 |
230 | if __name__ == '__main__':
231 | if len(sys.argv) != 3:
232 |         print("USAGE: python make_datafiles.py <cnn_stories_dir> <dailymail_stories_dir>")
233 | sys.exit()
234 | cnn_stories_dir = sys.argv[1]
235 | dm_stories_dir = sys.argv[2]
236 |
237 | # Check the stories directories contain the correct number of .story files
238 | check_num_stories(cnn_stories_dir, num_expected_cnn_stories)
239 | check_num_stories(dm_stories_dir, num_expected_dm_stories)
240 |
241 | # Create some new directories
242 | if not os.path.exists(cnn_tokenized_stories_dir): os.makedirs(cnn_tokenized_stories_dir)
243 | if not os.path.exists(dm_tokenized_stories_dir): os.makedirs(dm_tokenized_stories_dir)
244 | if not os.path.exists(finished_files_dir): os.makedirs(finished_files_dir)
245 |
246 | # Run stanford tokenizer on both stories dirs, outputting to tokenized stories directories
247 | tokenize_stories(cnn_stories_dir, cnn_tokenized_stories_dir)
248 | tokenize_stories(dm_stories_dir, dm_tokenized_stories_dir)
249 |
250 | # Read the tokenized stories, do a little postprocessing then write to bin files
251 | write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test.bin"))
252 | write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val.bin"))
253 | write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train.bin"), makevocab=True)
254 |
255 | # Chunk the data. This splits each of train.bin, val.bin and test.bin into smaller chunks, each containing e.g. 1000 examples, and saves them in finished_files/chunks
256 | chunk_all()
257 |
--------------------------------------------------------------------------------
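write_to_bin serializes each example as an 8-byte length (struct format 'q') followed by that many bytes of a tf.Example with 'article' and 'abstract' byte features; chunk_file reads the stream back the same way. A minimal reader sketch for inspecting a finished .bin file (the path is an example):

```python
import struct

from tensorflow.core.example import example_pb2


def read_bin(path):
    """Yield (article, abstract) pairs from a length-prefixed tf.Example file."""
    with open(path, 'rb') as reader:
        while True:
            len_bytes = reader.read(8)
            if not len_bytes:  # end of file
                break
            str_len = struct.unpack('q', len_bytes)[0]
            example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
            example = example_pb2.Example.FromString(example_str)
            features = example.features.feature
            yield (features['article'].bytes_list.value[0].decode('utf-8'),
                   features['abstract'].bytes_list.value[0].decode('utf-8'))


for article, abstract in read_bin('dataset/finished_files/val.bin'):
    print(abstract)  # e.g. "<s> first highlight . </s> <s> second highlight . </s>"
    break
```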
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/attention.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/attention.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/basic.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/basic.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/layers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/layers.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/model.cpython-35.pyc
--------------------------------------------------------------------------------
/models/attention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from utils import config
8 | from models.basic import BasicModule
9 |
10 |
11 | class Attention(BasicModule):
12 | def __init__(self):
13 | super(Attention, self).__init__()
14 |
15 | self.fc = nn.Linear(config.hidden_dim * 2, 1, bias=False)
16 | self.dec_fc = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
17 | if config.is_coverage:
18 | self.con_fc = nn.Linear(1, config.hidden_dim * 2, bias=False)
19 |
20 | self.init_params()
21 |
22 | def forward(self, s_t, enc_out, enc_fea, enc_padding_mask, coverage):
23 | b, l, n = list(enc_out.size())
24 |
25 | dec_fea = self.dec_fc(s_t) # B x 2*hidden_dim
26 | dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, l, n).contiguous() # B x l x 2*hidden_dim
27 | dec_fea_expanded = dec_fea_expanded.view(-1, n) # B*l x 2*hidden_dim
28 |
29 | att_features = enc_fea + dec_fea_expanded # B*l x 2*hidden_dim
30 | if config.is_coverage:
31 | coverage_inp = coverage.view(-1, 1) # B*l x 1
32 | coverage_fea = self.con_fc(coverage_inp) # B*l x 2*hidden_dim
33 | att_features = att_features + coverage_fea
34 |
35 | e = torch.tanh(att_features) # B*l x 2*hidden_dim
36 | scores = self.fc(e) # B*l x 1
37 | scores = scores.view(-1, l) # B x l
38 |
39 | attn_dist_ = F.softmax(scores, dim=1) * enc_padding_mask # B x l
40 | normalization_factor = attn_dist_.sum(1, keepdim=True)
41 | attn_dist = attn_dist_ / normalization_factor
42 |
43 | attn_dist = attn_dist.unsqueeze(1) # B x 1 x l
44 | c_t = torch.bmm(attn_dist, enc_out) # B x 1 x n
45 | c_t = c_t.view(-1, config.hidden_dim * 2) # B x 2*hidden_dim
46 |
47 | attn_dist = attn_dist.view(-1, l) # B x l
48 |
49 | if config.is_coverage:
50 | coverage = coverage.view(-1, l)
51 | coverage = coverage + attn_dist
52 |
53 | return c_t, attn_dist, coverage
54 |
55 |
--------------------------------------------------------------------------------
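The mask-and-renormalize step above (attn_dist_ and normalization_factor) is worth seeing in isolation: softmax assigns probability mass to padded positions, multiplying by enc_padding_mask zeroes it, and dividing by the new sum restores a proper distribution. A self-contained toy check with made-up scores:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1]])             # B=1, l=4 attention scores
enc_padding_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])   # last position is padding

attn_dist_ = F.softmax(scores, dim=1) * enc_padding_mask  # padded position zeroed out
attn_dist = attn_dist_ / attn_dist_.sum(1, keepdim=True)  # renormalize over real tokens

print(attn_dist)        # tensor([[0.6285, 0.2312, 0.1403, 0.0000]])
print(attn_dist.sum())  # tensor(1.)
```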
/models/basic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 |
7 | class BasicModule(nn.Module):
8 | def __init__(self, init='uniform'):
9 | super(BasicModule, self).__init__()
10 | self.init = init
11 |
12 | def init_params(self):
13 | for param in self.parameters():
14 | if param.requires_grad and len(param.shape) > 0:
15 | stddev = 1 / math.sqrt(param.shape[0])
16 | if self.init == 'uniform':
17 | torch.nn.init.uniform_(param, a=-0.05, b=0.05)
18 | elif self.init == 'normal':
19 | torch.nn.init.normal_(param, std=stddev)
20 | elif self.init == 'truncated_normal':
21 |                 self.truncated_normal_(param, mean=0, std=stddev)
22 |
23 | def truncated_normal_(self, tensor, mean=0, std=1.):
24 | """
25 | Implemented by @ruotianluo
26 | See https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
27 | """
28 | size = tensor.shape
29 | tmp = tensor.new_empty(size + (4,)).normal_()
30 | valid = (tmp < 2) & (tmp > -2)
31 | ind = valid.max(-1, keepdim=True)[1]
32 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
33 | tensor.data.mul_(std).add_(mean)
34 | return tensor
--------------------------------------------------------------------------------
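truncated_normal_ draws four standard-normal candidates per element and keeps the first that lands inside (-2, 2) before scaling by std, so values virtually never fall outside two standard deviations. A quick sanity-check sketch (the std value is arbitrary):

```python
import torch

from models.basic import BasicModule

m = BasicModule(init='truncated_normal')
t = torch.empty(10000)
m.truncated_normal_(t, mean=0., std=0.1)

print(t.abs().max())      # < 0.2, i.e. within 2 * std (each element gets 4 retry draws)
print(t.mean(), t.std())  # roughly 0 and slightly below 0.1, as expected after truncation
```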
/models/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pad_packed_sequence
7 | from torch.nn.utils.rnn import pack_padded_sequence
8 |
9 | from utils import config
10 | from models.basic import BasicModule
11 | from models.attention import Attention
12 |
13 |
14 | class Encoder(BasicModule):
15 | def __init__(self):
16 | super(Encoder, self).__init__()
17 | self.src_word_emb = nn.Embedding(config.vocab_size, config.emb_dim)
18 | self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, batch_first=True, bidirectional=True)
19 | self.fc = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)
20 |
21 | self.init_params()
22 |
23 | # seq_lens should be in descending order
24 | def forward(self, input, seq_lens):
25 | embedded = self.src_word_emb(input)
26 |
27 | packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
28 | output, hidden = self.lstm(packed)
29 |
30 | encoder_outputs, _ = pad_packed_sequence(output, batch_first=True) # h dim = B x l x n
31 | encoder_outputs = encoder_outputs.contiguous()
32 |
33 | encoder_feature = encoder_outputs.view(-1, 2 * config.hidden_dim) # B*l x 2*hidden_dim
34 | encoder_feature = self.fc(encoder_feature)
35 |
36 | return encoder_outputs, encoder_feature, hidden
37 |
38 |
39 | class ReduceState(BasicModule):
40 | def __init__(self):
41 | super(ReduceState, self).__init__()
42 |
43 | self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
44 | self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
45 | self.init_params()
46 |
47 |
48 | def forward(self, hidden):
49 | h, c = hidden # h, c dim = 2 x b x hidden_dim
50 | h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
51 | hidden_reduced_h = F.relu(self.reduce_h(h_in))
52 | c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
53 | hidden_reduced_c = F.relu(self.reduce_c(c_in))
54 |
55 | return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0)) # h, c dim = 1 x b x hidden_dim
56 |
57 | class Decoder(BasicModule):
58 | def __init__(self):
59 | super(Decoder, self).__init__()
60 | self.attention_network = Attention()
61 | # decoder
62 | self.tgt_word_emb = nn.Embedding(config.vocab_size, config.emb_dim)
63 | self.con_fc = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)
64 | self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, batch_first=True, bidirectional=False)
65 |
66 | if config.pointer_gen:
67 | self.p_gen_fc = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
68 |
69 | # p_vocab
70 | self.fc1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
71 | self.fc2 = nn.Linear(config.hidden_dim, config.vocab_size)
72 |
73 | self.init_params()
74 |
75 | def forward(self, y_t, s_t, enc_out, enc_fea, enc_padding_mask,
76 | c_t, extra_zeros, enc_batch_extend_vocab, coverage, step):
77 |
78 | if not self.training and step == 0:
79 | dec_h, dec_c = s_t
80 | s_t_hat = torch.cat((dec_h.view(-1, config.hidden_dim),
81 | dec_c.view(-1, config.hidden_dim)), 1) # B x 2*hidden_dim
82 | c_t, _, coverage_next = self.attention_network(s_t_hat, enc_out, enc_fea,
83 | enc_padding_mask, coverage)
84 | coverage = coverage_next
85 |
86 | y_t_embd = self.tgt_word_emb(y_t)
87 | x = self.con_fc(torch.cat((c_t, y_t_embd), 1))
88 | lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t)
89 |
90 | dec_h, dec_c = s_t
91 | s_t_hat = torch.cat((dec_h.view(-1, config.hidden_dim),
92 | dec_c.view(-1, config.hidden_dim)), 1) # B x 2*hidden_dim
93 | c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, enc_out, enc_fea,
94 | enc_padding_mask, coverage)
95 |
96 | if self.training or step > 0:
97 | coverage = coverage_next
98 |
99 | p_gen = None
100 | if config.pointer_gen:
101 | p_gen_inp = torch.cat((c_t, s_t_hat, x), 1) # B x (2*2*hidden_dim + emb_dim)
102 | p_gen = self.p_gen_fc(p_gen_inp)
103 | p_gen = torch.sigmoid(p_gen)
104 |
105 | output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1) # B x hidden_dim * 3
106 | output = self.fc1(output) # B x hidden_dim
107 | # output = F.relu(output)
108 |
109 | output = self.fc2(output) # B x vocab_size
110 | vocab_dist = F.softmax(output, dim=1)
111 |
112 | if config.pointer_gen:
113 | vocab_dist_ = p_gen * vocab_dist
114 | attn_dist_ = (1 - p_gen) * attn_dist
115 |
116 | if extra_zeros is not None:
117 | vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
118 |
119 | final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
120 | else:
121 | final_dist = vocab_dist
122 |
123 | return final_dist, s_t, c_t, attn_dist, p_gen, coverage
124 |
--------------------------------------------------------------------------------
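The pointer-generator mixing at the end of Decoder.forward is the heart of the model: p_gen scales the vocabulary distribution, (1 - p_gen) scales the copy (attention) distribution, and scatter_add routes attention mass onto the extended-vocabulary ids of the source words, including article OOVs. A self-contained toy with a 5-word vocabulary and one article OOV (id 5):

```python
import torch

p_gen = torch.tensor([[0.7]])                            # generation vs. copy gate
vocab_dist = torch.tensor([[0.1, 0.4, 0.2, 0.2, 0.1]])   # B x vocab_size
attn_dist = torch.tensor([[0.5, 0.3, 0.2]])              # B x src_len
enc_batch_extend_vocab = torch.tensor([[1, 5, 1]])       # source token ids; 5 is an article OOV
extra_zeros = torch.zeros(1, 1)                          # one extra slot for the OOV

vocab_dist_ = p_gen * vocab_dist
attn_dist_ = (1 - p_gen) * attn_dist
vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)   # B x (vocab_size + n_oov)
final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)

print(final_dist)        # tensor([[0.0700, 0.4900, 0.1400, 0.1400, 0.0700, 0.0900]])
print(final_dist.sum())  # tensor(1.) -- still a valid distribution
```

Repeated source ids (id 1 above) accumulate their attention mass, which is why scatter_add rather than scatter is required.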
/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 | from utils import config
6 | from numpy import random
7 | from models.layers import Encoder
8 | from models.layers import Decoder
9 | from models.layers import ReduceState
10 | from transformer.model import TranEncoder
11 |
12 | use_cuda = config.use_gpu and torch.cuda.is_available()
13 |
14 | random.seed(123)
15 | torch.manual_seed(123)
16 | if torch.cuda.is_available():
17 | torch.cuda.manual_seed_all(123)
18 |
19 |
20 | class Model(object):
21 |     def __init__(self, model_path=None, is_eval=False, is_tran=False):
22 | encoder = Encoder()
23 | decoder = Decoder()
24 | reduce_state = ReduceState()
25 | if is_tran:
26 | encoder = TranEncoder(config.vocab_size, config.max_enc_steps, config.emb_dim,
27 | config.n_layers, config.n_head, config.d_k, config.d_v, config.d_model, config.d_inner)
28 |
29 |         # share the embedding between encoder and decoder
30 | decoder.tgt_word_emb.weight = encoder.src_word_emb.weight
31 |
32 | if is_eval:
33 | encoder = encoder.eval()
34 | decoder = decoder.eval()
35 | reduce_state = reduce_state.eval()
36 |
37 | if use_cuda:
38 | encoder = encoder.cuda()
39 | decoder = decoder.cuda()
40 | reduce_state = reduce_state.cuda()
41 |
42 | self.encoder = encoder
43 | self.decoder = decoder
44 | self.reduce_state = reduce_state
45 |
46 | if model_path is not None:
47 | state = torch.load(model_path, map_location=lambda storage, location: storage)
48 | self.encoder.load_state_dict(state['encoder_state_dict'])
49 | self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
50 | self.reduce_state.load_state_dict(state['reduce_state_dict'])
51 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.7.4
2 | torch==1.4.0
3 | pyrouge==0.1.3
4 | tensorflow==1.13.1
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import unicode_literals, print_function, division
4 |
5 | import os
6 | import sys
7 | import time
8 | import torch
9 | from torch.autograd import Variable
10 |
11 | from models.model import Model
12 | from utils import utils
13 | from utils.dataset import Batcher
14 | from utils.dataset import Vocab
15 | from utils import dataset, config
16 | from utils.utils import get_input_from_batch
17 | from utils.utils import write_for_rouge, rouge_eval, rouge_log
18 |
19 | use_cuda = config.use_gpu and torch.cuda.is_available()
20 |
21 |
22 | class Beam(object):
23 | def __init__(self, tokens, log_probs, state, context, coverage):
24 | self.tokens = tokens
25 | self.state = state
26 | self.context = context
27 | self.coverage = coverage
28 | self.log_probs = log_probs
29 |
30 | def extend(self, token, log_prob, state, context, coverage):
31 | return Beam(tokens=self.tokens + [token],
32 | log_probs=self.log_probs + [log_prob],
33 | state=state,
34 | context=context,
35 | coverage=coverage)
36 |
37 | @property
38 | def latest_token(self):
39 | return self.tokens[-1]
40 |
41 | @property
42 | def avg_log_prob(self):
43 | return sum(self.log_probs) / len(self.tokens)
44 |
45 |
46 | class BeamSearch(object):
47 | def __init__(self, model_file_path):
48 |
49 | model_name = os.path.basename(model_file_path)
50 | self._test_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
51 | self._rouge_ref_dir = os.path.join(self._test_dir, 'rouge_ref')
52 | self._rouge_dec_dir = os.path.join(self._test_dir, 'rouge_dec')
53 | for p in [self._test_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
54 | if not os.path.exists(p):
55 | os.mkdir(p)
56 |
57 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
58 | self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
59 | batch_size=config.beam_size, single_pass=True)
60 | time.sleep(15)
61 |
62 | self.model = Model(model_file_path, is_eval=True)
63 |
64 | def sort_beams(self, beams):
65 | return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)
66 |
67 |
68 | def beam_search(self, batch):
69 | # single example repeated across the batch
70 | enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
71 | get_input_from_batch(batch, use_cuda)
72 |
73 | enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
74 | s_t = self.model.reduce_state(enc_h)
75 |
76 | dec_h, dec_c = s_t # b x hidden_dim
77 | dec_h = dec_h.squeeze()
78 | dec_c = dec_c.squeeze()
79 |
80 |         # decoder batch preparation: it starts with beam_size hypotheses, initially all identical
81 | beams = [Beam(tokens=[self.vocab.word2id(config.BOS_TOKEN)],
82 | log_probs=[0.0],
83 | state=(dec_h[0], dec_c[0]),
84 | context=c_t[0],
85 | coverage=(coverage[0] if config.is_coverage else None))
86 | for _ in range(config.beam_size)]
87 |
88 | steps = 0
89 | results = []
90 | while steps < config.max_dec_steps and len(results) < config.beam_size:
91 | latest_tokens = [h.latest_token for h in beams]
92 | latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(config.UNK_TOKEN) \
93 | for t in latest_tokens]
94 | y_t = Variable(torch.LongTensor(latest_tokens))
95 | if use_cuda:
96 | y_t = y_t.cuda()
97 | all_state_h = [h.state[0] for h in beams]
98 | all_state_c = [h.state[1] for h in beams]
99 | all_context = [h.context for h in beams]
100 |
101 | s_t = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
102 | c_t = torch.stack(all_context, 0)
103 |
104 | coverage_t = None
105 | if config.is_coverage:
106 | all_coverage = [h.coverage for h in beams]
107 | coverage_t = torch.stack(all_coverage, 0)
108 |
109 | final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t, s_t,
110 | enc_out, enc_fea,
111 | enc_padding_mask, c_t,
112 | extra_zeros, enc_batch_extend_vocab,
113 | coverage_t, steps)
114 | log_probs = torch.log(final_dist)
115 | topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)
116 |
117 | dec_h, dec_c = s_t
118 | dec_h = dec_h.squeeze()
119 | dec_c = dec_c.squeeze()
120 |
121 | all_beams = []
122 | # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
123 | num_orig_beams = 1 if steps == 0 else len(beams)
124 | for i in range(num_orig_beams):
125 | h = beams[i]
126 | state_i = (dec_h[i], dec_c[i])
127 | context_i = c_t[i]
128 |             coverage_i = (coverage_t[i] if config.is_coverage else None)  # use the coverage updated by this decode step
129 |
130 | for j in range(config.beam_size * 2): # for each of the top 2*beam_size hyps:
131 | new_beam = h.extend(token=topk_ids[i, j].item(),
132 | log_prob=topk_log_probs[i, j].item(),
133 | state=state_i,
134 | context=context_i,
135 | coverage=coverage_i)
136 | all_beams.append(new_beam)
137 |
138 | beams = []
139 | for h in self.sort_beams(all_beams):
140 | if h.latest_token == self.vocab.word2id(config.EOS_TOKEN):
141 | if steps >= config.min_dec_steps:
142 | results.append(h)
143 | else:
144 | beams.append(h)
145 | if len(beams) == config.beam_size or len(results) == config.beam_size:
146 | break
147 |
148 | steps += 1
149 |
150 | if len(results) == 0:
151 | results = beams
152 |
153 | beams_sorted = self.sort_beams(results)
154 |
155 | return beams_sorted[0]
156 |
157 | def run(self):
158 |
159 | counter = 0
160 | start = time.time()
161 | batch = self.batcher.next_batch()
162 | while batch is not None:
163 | # Run beam search to get best Hypothesis
164 | best_summary = self.beam_search(batch)
165 |
166 | # Extract the output ids from the hypothesis and convert back to words
167 | output_ids = [int(t) for t in best_summary.tokens[1:]]
168 | decoded_words = utils.outputids2words(output_ids, self.vocab,
169 | (batch.art_oovs[0] if config.pointer_gen else None))
170 |
171 | # Remove the [STOP] token from decoded_words, if necessary
172 | try:
173 | fst_stop_idx = decoded_words.index(dataset.EOS_TOKEN)
174 | decoded_words = decoded_words[:fst_stop_idx]
175 | except ValueError:
176 | decoded_words = decoded_words
177 |
178 | original_abstract_sents = batch.original_abstracts_sents[0]
179 |
180 | write_for_rouge(original_abstract_sents, decoded_words, counter,
181 | self._rouge_ref_dir, self._rouge_dec_dir)
182 | counter += 1
183 | if counter % 1000 == 0:
184 |                 print('%d examples in %d sec' % (counter, time.time() - start))
185 | start = time.time()
186 |
187 | batch = self.batcher.next_batch()
188 |
189 | print("Decoder has finished reading dataset for single_pass.")
190 | print("Now starting ROUGE eval...")
191 | results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
192 | rouge_log(results_dict, self._test_dir)
193 |
194 |
195 | if __name__ == '__main__':
196 | model_filename = sys.argv[1]
197 | test_processor = BeamSearch(model_filename)
198 | test_processor.run()
199 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import unicode_literals, print_function, division
4 |
5 | import os
6 | import time
7 | import argparse
8 | import tensorflow as tf
9 |
10 | import torch
11 | import torch.optim as optim
12 | from torch.nn.utils import clip_grad_norm_
13 |
14 | from models.model import Model
15 | from utils import config
16 | from utils.dataset import Vocab
17 | from utils.dataset import Batcher
18 | from utils.utils import get_input_from_batch
19 | from utils.utils import get_output_from_batch
20 | from utils.utils import calc_running_avg_loss
21 |
22 | use_cuda = config.use_gpu and torch.cuda.is_available()
23 |
24 |
25 | class Train(object):
26 | def __init__(self):
27 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
28 |         self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
29 |                                batch_size=config.batch_size, single_pass=False)
30 | time.sleep(10)
31 |
32 | train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
33 | if not os.path.exists(train_dir):
34 | os.mkdir(train_dir)
35 |
36 | self.model_dir = os.path.join(train_dir, 'models')
37 | if not os.path.exists(self.model_dir):
38 | os.mkdir(self.model_dir)
39 |
40 | self.summary_writer = tf.summary.FileWriter(train_dir)
41 |
42 | def save_model(self, running_avg_loss, iter):
43 | state = {
44 | 'iter': iter,
45 | 'encoder_state_dict': self.model.encoder.state_dict(),
46 | 'decoder_state_dict': self.model.decoder.state_dict(),
47 | 'reduce_state_dict': self.model.reduce_state.state_dict(),
48 | 'optimizer': self.optimizer.state_dict(),
49 | 'current_loss': running_avg_loss
50 | }
51 | model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
52 | torch.save(state, model_save_path)
53 |
54 | def setup_train(self, model_path=None):
55 |         self.model = Model(model_path, is_tran=config.tran)
56 | initial_lr = config.lr_coverage if config.is_coverage else config.lr
57 |
58 | params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
59 | list(self.model.reduce_state.parameters())
60 |         total_params = sum(param.nelement() for param in params)
61 |         print('Number of model parameters: %.3f million' % (total_params / 1e6))
62 | self.optimizer = optim.Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
63 |
64 | start_iter, start_loss = 0, 0
65 |
66 | if model_path is not None:
67 | state = torch.load(model_path, map_location=lambda storage, location: storage)
68 | start_iter = state['iter']
69 | start_loss = state['current_loss']
70 |
71 | if not config.is_coverage:
72 | self.optimizer.load_state_dict(state['optimizer'])
73 | if use_cuda:
74 | for state in self.optimizer.state.values():
75 | for k, v in state.items():
76 | if torch.is_tensor(v):
77 | state[k] = v.cuda()
78 |
79 | return start_iter, start_loss
80 |
81 | def train_one_batch(self, batch):
82 | enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
83 | extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda)
84 | dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
85 | get_output_from_batch(batch, use_cuda)
86 |
87 | self.optimizer.zero_grad()
88 |
89 | if not config.tran:
90 | enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
91 | else:
92 | enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_pos)
93 |
94 | s_t = self.model.reduce_state(enc_h)
95 |
96 | step_losses, cove_losses = [], []
97 | for di in range(min(max_dec_len, config.max_dec_steps)):
98 | y_t = dec_batch[:, di] # Teacher forcing
99 | final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = \
100 | self.model.decoder(y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t,
101 | extra_zeros, enc_batch_extend_vocab, coverage, di)
102 | tgt = tgt_batch[:, di]
103 | step_mask = dec_padding_mask[:, di]
104 | gold_probs = torch.gather(final_dist, 1, tgt.unsqueeze(1)).squeeze()
105 | step_loss = -torch.log(gold_probs + config.eps)
106 | if config.is_coverage:
107 | step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
108 | step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
109 | cove_losses.append(step_coverage_loss * step_mask)
110 | coverage = next_coverage
111 |
112 | step_loss = step_loss * step_mask
113 | step_losses.append(step_loss)
114 |
115 | sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
116 | batch_avg_loss = sum_losses / dec_lens
117 | loss = torch.mean(batch_avg_loss)
118 |
119 | loss.backward()
120 |
121 | clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
122 | clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
123 | clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)
124 |
125 | self.optimizer.step()
126 |
127 | if config.is_coverage:
128 | cove_losses = torch.sum(torch.stack(cove_losses, 1), 1)
129 | batch_cove_loss = cove_losses / dec_lens
130 | batch_cove_loss = torch.mean(batch_cove_loss)
131 | return loss.item(), batch_cove_loss.item()
132 |
133 | return loss.item(), 0.
134 |
135 | def run(self, n_iters, model_path=None):
136 | iter, running_avg_loss = self.setup_train(model_path)
137 | start = time.time()
138 | interval = 100
139 |
140 | while iter < n_iters:
141 | batch = self.batcher.next_batch()
142 | loss, cove_loss = self.train_one_batch(batch)
143 |
144 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
145 | iter += 1
146 |
147 | if iter % interval == 0:
148 | self.summary_writer.flush()
149 |                 print('step: %d, seconds: %.2f, loss: %f, cover_loss: %f'
150 |                       % (iter, time.time() - start, loss, cove_loss))
151 | start = time.time()
152 | if iter % 5000 == 0:
153 | self.save_model(running_avg_loss, iter)
154 |
155 |
156 | if __name__ == '__main__':
157 | parser = argparse.ArgumentParser(description="Train script")
158 | parser.add_argument("-m",
159 | dest="model_path",
160 | required=False,
161 | default=None,
162 | help="Model file for retraining (default: None).")
163 | args = parser.parse_args()
164 |
165 | train_processor = Train()
166 | train_processor.run(config.max_iterations, args.model_path)
167 |
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__init__.py
--------------------------------------------------------------------------------
/transformer/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/layers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/layers.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/model.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/optim.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/optim.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/sublayers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/sublayers.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/layers.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 | from transformer.sublayers import MultiHeadAttention
5 | from transformer.sublayers import PositionwiseFeedForward
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | ''' Compose with two layers '''
10 |
11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
12 | super(EncoderLayer, self).__init__()
13 | self.slf_attn = MultiHeadAttention(
14 | n_head, d_model, d_k, d_v, dropout=dropout)
15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
16 |
17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
18 | enc_output, enc_slf_attn = self.slf_attn(
19 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
20 | enc_output *= non_pad_mask
21 |
22 | enc_output = self.pos_ffn(enc_output)
23 | enc_output *= non_pad_mask
24 |
25 | return enc_output, enc_slf_attn
26 |
27 |
28 | class DecoderLayer(nn.Module):
29 | ''' Compose with three layers '''
30 |
31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
32 | super(DecoderLayer, self).__init__()
33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
36 |
37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
38 | dec_output, dec_slf_attn = self.slf_attn(
39 | dec_input, dec_input, dec_input, mask=slf_attn_mask)
40 | dec_output *= non_pad_mask
41 |
42 | dec_output, dec_enc_attn = self.enc_attn(
43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
44 | dec_output *= non_pad_mask
45 |
46 | dec_output = self.pos_ffn(dec_output)
47 | dec_output *= non_pad_mask
48 |
49 | return dec_output, dec_slf_attn, dec_enc_attn
--------------------------------------------------------------------------------
/transformer/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from utils import config
9 | from transformer.layers import EncoderLayer
10 | from transformer.layers import DecoderLayer
11 |
12 | use_cuda = config.use_gpu and torch.cuda.is_available()
13 |
14 | def get_non_pad_mask(seq):
15 | assert seq.dim() == 2
16 | return seq.ne(config.PAD).type(torch.float).unsqueeze(-1)
17 |
18 |
19 | def positional_encoding(max_len, d_model, padding_idx=None):
20 | ''' Sinusoid position encoding table '''
21 |
22 | pe = torch.zeros(max_len, d_model)
23 | position = torch.arange(0., max_len).unsqueeze(1)
24 | div_term = torch.exp(torch.arange(0., d_model, 2) *
25 | -(math.log(10000.0) / d_model))
26 | pe[:, 0::2] = torch.sin(position * div_term)
27 | pe[:, 1::2] = torch.cos(position * div_term)
28 |
29 | if padding_idx is not None:
30 | # zero vector for padding dimension
31 | pe[padding_idx] = 0.
32 |
33 | return torch.FloatTensor(pe)
34 |
35 |
36 | def get_attn_key_pad_mask(seq_k, seq_q):
37 | ''' For masking out the padding part of key sequence. '''
38 |
39 | # Expand to fit the shape of key query attention matrix.
40 | len_q = seq_q.size(1)
41 | padding_mask = seq_k.eq(config.PAD)
42 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
43 |
44 | return padding_mask
45 |
46 |
47 | def get_subsequent_mask(seq):
48 | ''' For masking out the subsequent info. '''
49 |
50 | sz_b, len_s = seq.size()
51 | subsequent_mask = torch.triu(
52 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
53 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
54 |
55 | return subsequent_mask
56 |
57 |
58 | class TranEncoder(nn.Module):
59 |     ''' An encoder model with self-attention mechanism. '''
60 |
61 | def __init__(self, n_src_vocab, len_max_seq, d_word_vec,
62 | n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
63 | super().__init__()
64 |
65 | n_position = len_max_seq + 1
66 |
67 | self.d_model = d_model
68 |
69 | self.src_word_emb = nn.Embedding(
70 | n_src_vocab, d_word_vec, padding_idx=config.PAD)
71 |
72 | self.position_enc = nn.Embedding.from_pretrained(
73 | positional_encoding(n_position, d_word_vec, padding_idx=1),
74 | freeze=True)
75 |
76 | self.layer_stack = nn.ModuleList([
77 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
78 | for _ in range(n_layers)])
79 |
80 | self.fc = nn.Linear(d_model, 2 * config.hidden_dim, bias=False)
81 |
82 | def forward(self, src_seq, src_pos, return_attns=False):
83 |
84 | enc_slf_attn_list = []
85 |
86 | # -- Prepare masks
87 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
88 | non_pad_mask = get_non_pad_mask(src_seq)
89 |
90 | # -- Forward
91 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
92 |
93 | for enc_layer in self.layer_stack:
94 | enc_output, enc_slf_attn = enc_layer(
95 | enc_output,
96 | non_pad_mask=non_pad_mask,
97 | slf_attn_mask=slf_attn_mask)
98 | if return_attns:
99 | enc_slf_attn_list += [enc_slf_attn]
100 |
101 | if return_attns:
102 | return enc_output, enc_slf_attn_list
103 |
104 | enc_output = self.fc(enc_output)
105 | enc_feature = enc_output.view(-1, 2*config.hidden_dim) # B*l x 2*hidden_dim
106 | b, l, n = list(enc_output.size())
107 | enc_h = enc_output[:,-1,:].reshape(2, b, config.hidden_dim).contiguous()
108 |
109 | return enc_output, enc_feature, (enc_h, enc_h)
110 |
111 |
112 | class TranDecoder(nn.Module):
113 |     ''' A decoder model with self-attention mechanism. '''
114 |
115 | def __init__(self, n_tgt_vocab, len_max_seq, d_word_vec,
116 | n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
117 |
118 | super().__init__()
119 | n_position = len_max_seq + 1
120 |
121 | self.tgt_word_emb = nn.Embedding(
122 | n_tgt_vocab, d_word_vec, padding_idx=config.PAD)
123 |
124 | self.position_enc = nn.Embedding.from_pretrained(
125 | positional_encoding(n_position, d_word_vec, padding_idx=config.PAD),
126 | freeze=True)
127 |
128 | self.layer_stack = nn.ModuleList([
129 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
130 | for _ in range(n_layers)])
131 |
132 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
133 |         nn.init.xavier_normal_(self.tgt_word_prj.weight)
134 |         self.x_logit_scale = 1.  # referenced in forward(); without it the projection step raises AttributeError
135 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):
136 |
137 | dec_slf_attn_list, dec_enc_attn_list = [], []
138 |
139 | # -- Prepare masks
140 | non_pad_mask = get_non_pad_mask(tgt_seq)
141 |
142 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
143 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
144 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
145 |
146 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)
147 |
148 | # -- Forward
149 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)
150 |
151 | for dec_layer in self.layer_stack:
152 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
153 | dec_output, enc_output,
154 | non_pad_mask=non_pad_mask,
155 | slf_attn_mask=slf_attn_mask,
156 | dec_enc_attn_mask=dec_enc_attn_mask)
157 |
158 | if return_attns:
159 | dec_slf_attn_list += [dec_slf_attn]
160 | dec_enc_attn_list += [dec_enc_attn]
161 |
162 | if return_attns:
163 | return dec_output, dec_slf_attn_list, dec_enc_attn_list
164 |
165 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
166 | seq_logit = F.softmax(seq_logit, dim=-1)
167 | return seq_logit,
168 |
169 |
170 | class Transformer(nn.Module):
171 | ''' A sequence to sequence model with attention mechanism. '''
172 |
173 | def __init__(
174 | self,
175 | n_src_vocab, n_tgt_vocab, len_max_seq,
176 | d_word_vec=512, d_model=512, d_inner=2048,
177 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
178 | tgt_emb_prj_weight_sharing=True,
179 | emb_src_tgt_weight_sharing=True):
180 |
181 | super().__init__()
182 |
183 | self.encoder = TranEncoder(
184 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
185 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
186 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
187 | dropout=dropout)
188 |
189 | self.decoder = TranDecoder(
190 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
191 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
192 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
193 | dropout=dropout)
194 |
195 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
196 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
197 |
198 | assert d_model == d_word_vec, \
199 | 'To facilitate the residual connections, \
200 | the dimensions of all module outputs shall be the same.'
201 |
202 | if tgt_emb_prj_weight_sharing:
203 | # Share the weight matrix between target word embedding & the final logit dense layer
204 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
205 | self.x_logit_scale = (d_model ** -0.5)
206 | else:
207 | self.x_logit_scale = 1.
208 |
209 | if emb_src_tgt_weight_sharing:
210 | # Share the weight matrix between source & target word embeddings
211 | assert n_src_vocab == n_tgt_vocab, \
212 | "To share word embedding table, the vocabulary size of src/tgt shall be the same."
213 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
214 |
215 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
216 |
217 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]
218 |
219 | enc_output, *_ = self.encoder(src_seq, src_pos)
220 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
221 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
222 |
223 | return seq_logit.view(-1, seq_logit.size(2))
--------------------------------------------------------------------------------
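get_subsequent_mask produces the causal mask that TranDecoder combines with the key-padding mask via .gt(0): position i may attend only to positions <= i. A quick look at its output, with the function copied here so the snippet runs without utils.config:

```python
import torch


def get_subsequent_mask(seq):
    """Upper-triangular mask: 1 marks positions a query must not attend to."""
    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    return subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls


seq = torch.zeros(1, 4, dtype=torch.long)  # one sequence of length 4
print(get_subsequent_mask(seq)[0])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.uint8)
```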
/transformer/optim.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 |
5 | class ScheduledOptim():
6 | '''A simple wrapper class for learning rate scheduling'''
7 |
8 | def __init__(self, optimizer, d_model, n_warmup_steps):
9 | self._optimizer = optimizer
10 | self.n_warmup_steps = n_warmup_steps
11 | self.n_current_steps = 0
12 | self.init_lr = np.power(d_model, -0.5)
13 |
14 | def step_and_update_lr(self):
15 | "Step with the inner optimizer"
16 | self._update_learning_rate()
17 | self._optimizer.step()
18 |
19 | def zero_grad(self):
20 | "Zero out the gradients by the inner optimizer"
21 | self._optimizer.zero_grad()
22 |
23 | def _get_lr_scale(self):
24 | return np.min([
25 | np.power(self.n_current_steps, -0.5),
26 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
27 |
28 | def _update_learning_rate(self):
29 | ''' Learning rate scheduling per step '''
30 |
31 | self.n_current_steps += 1
32 | lr = self.init_lr * self._get_lr_scale()
33 |
34 | for param_group in self._optimizer.param_groups:
35 | param_group['lr'] = lr
--------------------------------------------------------------------------------
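The wrapper above implements the Noam schedule: lr(step) = d_model^(-0.5) * min(step^(-0.5), step * n_warmup_steps^(-1.5)), i.e. a linear warmup followed by inverse-square-root decay, peaking at step = n_warmup_steps. A quick sketch of the curve, assuming it is run from the repo root with a throwaway parameter:

```python
import torch
import torch.optim as optim
from transformer.optim import ScheduledOptim

p = torch.nn.Parameter(torch.zeros(1))
sched = ScheduledOptim(optim.Adam([p], betas=(0.9, 0.98), eps=1e-9),
                       d_model=128, n_warmup_steps=4000)
for step in (1, 1000, 4000, 16000):
    sched.n_current_steps = step - 1   # jump the counter, for inspection only
    sched.step_and_update_lr()
    print(step, sched._optimizer.param_groups[0]['lr'])
```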
/transformer/sublayers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class PositionalEncoding(nn.Module):
12 | '''Positional Encoding module'''
13 |
14 | def __init__(self, d_model, max_len=5000):
15 | super().__init__()
16 |
17 | # Compute the positional encodings once in log space.
18 | pe = torch.zeros(max_len, d_model)
19 | position = torch.arange(0, max_len).unsqueeze(1)
20 | div_term = torch.exp(torch.arange(0, d_model, 2) *
21 | -(math.log(10000.0) / d_model))
22 | pe[:, 0::2] = torch.sin(position * div_term)
23 | pe[:, 1::2] = torch.cos(position * div_term)
24 | # register as a buffer so the encoding table moves with the module across devices
25 | self.register_buffer('pe', pe.unsqueeze(0))
26 |
27 | def forward(self, x):
28 | return self.pe[:, :x.size(1)]
29 |
30 | class ScaledDotProductAttention(nn.Module):
31 | ''' Scaled Dot-Product Attention module'''
32 |
33 | def __init__(self, temperature, attn_dropout=0.1):
34 | super().__init__()
35 | self.temperature = temperature
36 | self.dropout = nn.Dropout(attn_dropout)
37 | self.softmax = nn.Softmax(dim=2)
38 |
39 | def forward(self, q, k, v, mask=None):
40 |
41 | attn = torch.bmm(q, k.transpose(1, 2))
42 | attn = attn / self.temperature
43 |
44 | if mask is not None:
45 | attn = attn.masked_fill(mask, -np.inf)
46 |
47 | attn = self.softmax(attn)
48 | attn = self.dropout(attn)
49 | output = torch.bmm(attn, v)
50 |
51 | return output, attn
52 |
53 | class MultiHeadAttention(nn.Module):
54 | ''' Multi-Head Attention module '''
55 |
56 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
57 | super().__init__()
58 |
59 | self.d_k = d_k
60 | self.d_v = d_v
61 | self.n_head = n_head
62 |
63 | self.w_qs = nn.Linear(d_model, n_head * d_k)
64 | self.w_ks = nn.Linear(d_model, n_head * d_k)
65 | self.w_vs = nn.Linear(d_model, n_head * d_v)
66 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
67 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
68 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
69 |
70 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
71 | self.layer_norm = nn.LayerNorm(d_model)
72 |
73 | self.fc = nn.Linear(n_head * d_v, d_model)
74 | nn.init.xavier_normal_(self.fc.weight)
75 |
76 | self.dropout = nn.Dropout(dropout)
77 |
78 |
79 | def forward(self, q, k, v, mask=None):
80 |
81 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
82 |
83 | sz_b, len_q, _ = q.size()
84 | sz_b, len_k, _ = k.size()
85 | sz_b, len_v, _ = v.size()
86 |
87 | residual = q
88 |
89 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
90 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
91 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
92 |
93 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
94 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
95 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
96 |
97 | mask = mask.repeat(n_head, 1, 1) if mask is not None else None  # (n*b) x .. x ..
98 | output, attn = self.attention(q, k, v, mask=mask)
99 |
100 | output = output.view(n_head, sz_b, len_q, d_v)
101 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
102 |
103 | output = self.dropout(self.fc(output))
104 | output = self.layer_norm(output + residual)
105 |
106 | return output, attn
107 |
108 | class PositionwiseFeedForward(nn.Module):
109 | ''' A position-wise two-layer feed-forward module '''
110 |
111 | def __init__(self, d_in, d_hid, dropout=0.1):
112 | super().__init__()
113 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
114 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
115 | self.layer_norm = nn.LayerNorm(d_in)
116 | self.dropout = nn.Dropout(dropout)
117 |
118 | def forward(self, x):
119 | residual = x
120 | output = x.transpose(1, 2)
121 | output = self.w_2(F.relu(self.w_1(output)))
122 | output = output.transpose(1, 2)
123 | output = self.dropout(output)
124 | output = self.layer_norm(output + residual)
125 | return output
--------------------------------------------------------------------------------
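To make the head reshaping in `MultiHeadAttention` concrete: the n_head heads are folded into the batch dimension before `ScaledDotProductAttention`, so the returned attention map has `n_head * batch` rows. An illustrative shape check, assuming it is run from the repo root:

```python
import torch
from transformer.sublayers import MultiHeadAttention

mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
q = torch.randn(2, 10, 512)                      # (batch, len_q, d_model)
kv = torch.randn(2, 12, 512)                     # (batch, len_k, d_model)
mask = torch.zeros(2, 10, 12, dtype=torch.bool)  # nothing masked out
out, attn = mha(q, kv, kv, mask=mask)
print(out.shape)   # torch.Size([2, 10, 512]) -- back to (batch, len_q, d_model)
print(attn.shape)  # torch.Size([16, 10, 12]) -- heads folded into the batch dim
```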
/transformer/tran_train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import unicode_literals, print_function, division
4 |
5 | import os
6 | import time
7 | import argparse
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | import torch
12 | import torch.optim as optim
13 |
14 | from transformer.model import Model
15 | from transformer.optim import ScheduledOptim
16 | from utils import config
17 | from utils.dataset import Vocab
18 | from utils.dataset import Batcher
19 | from utils.utils import get_input_from_batch
20 | from utils.utils import get_output_from_batch
21 | from utils.utils import calc_running_avg_loss
22 |
23 | use_cuda = config.use_gpu and torch.cuda.is_available()
24 |
25 | class Train(object):
26 | def __init__(self):
27 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
28 | self.batcher = Batcher(self.vocab, config.train_data_path,
29 | config.batch_size, single_pass=False, mode='train')
30 | time.sleep(10)
31 |
32 | train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
33 | if not os.path.exists(train_dir):
34 | os.mkdir(train_dir)
35 |
36 | self.model_dir = os.path.join(train_dir, 'models')
37 | if not os.path.exists(self.model_dir):
38 | os.mkdir(self.model_dir)
39 |
40 | self.summary_writer = tf.summary.FileWriter(train_dir)
41 |
42 | def save_model(self, running_avg_loss, iter):
43 | model_state_dict = self.model.state_dict()
44 |
45 | state = {
46 | 'iter': iter,
47 | 'current_loss': running_avg_loss,
48 | 'optimizer': self.optimizer._optimizer.state_dict(),
49 | "model": model_state_dict
50 | }
51 | model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
52 | torch.save(state, model_save_path)
53 |
54 | def setup_train(self, model_path):
55 |
56 | device = torch.device('cuda' if use_cuda else 'cpu')
57 |
58 | self.model = Model(
59 | config.vocab_size,
60 | config.vocab_size,
61 | config.max_enc_steps,
62 | config.max_dec_steps,
63 | d_k=config.d_k,
64 | d_v=config.d_v,
65 | d_model=config.d_model,
66 | d_word_vec=config.emb_dim,
67 | d_inner=config.d_inner,  # config defines d_inner (there is no d_inner_hid)
68 | n_layers=config.n_layers,
69 | n_head=config.n_head,
70 | dropout=config.dropout).to(device)
71 |
72 | self.optimizer = ScheduledOptim(
73 | optim.Adam(
74 | filter(lambda x: x.requires_grad, self.model.parameters()),
75 | betas=(0.9, 0.98), eps=1e-09),
76 | config.d_model, config.n_warmup_steps)
77 |
78 |
79 | params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters())
80 | total_params = sum(p.nelement() for p in params)  # count every element of every parameter
81 | print('Number of model parameters: %.3f million' % (total_params / 1e6))
82 |
83 | start_iter, start_loss = 0, 0
84 |
85 | if model_path is not None:
86 | state = torch.load(model_path, map_location=lambda storage, location: storage)
87 | start_iter = state['iter']
88 | start_loss = state['current_loss']
89 |
90 | if not config.is_coverage:
91 | self.optimizer._optimizer.load_state_dict(state['optimizer'])
92 | if use_cuda:
93 | for state in self.optimizer._optimizer.state.values():
94 | for k, v in state.items():
95 | if torch.is_tensor(v):
96 | state[k] = v.cuda()
97 |
98 | return start_iter, start_loss
99 |
100 | def train_one_batch(self, batch):
101 | enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
102 | extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda, transformer=True)
103 | dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
104 | get_output_from_batch(batch, use_cuda, transformer=True)
105 |
106 | self.optimizer.zero_grad()
107 |
108 | pred = self.model(enc_batch, enc_pos, dec_batch, dec_pos)
109 | gold_probs = torch.gather(pred, -1, tgt_batch.unsqueeze(-1)).squeeze(-1)  # squeeze only the gathered dim (safe for batch_size == 1)
110 | batch_loss = -torch.log(gold_probs + config.eps)
111 | batch_loss = batch_loss * dec_padding_mask
112 |
113 | sum_losses = torch.sum(batch_loss, 1)
114 | batch_avg_loss = sum_losses / dec_lens
115 | loss = torch.mean(batch_avg_loss)
116 |
117 | loss.backward()
118 |
119 | # update parameters
120 | self.optimizer.step_and_update_lr()
121 |
122 | return loss.item(), 0.
123 |
124 | def run(self, n_iters, model_path=None):
125 | iter, running_avg_loss = self.setup_train(model_path)
126 | start = time.time()
127 | interval = 100
128 |
129 | while iter < n_iters:
130 | batch = self.batcher.next_batch()
131 | loss, cove_loss = self.train_one_batch(batch)
132 |
133 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
134 | iter += 1
135 |
136 | if iter % interval == 0:
137 | self.summary_writer.flush()
138 | print(
139 | 'step: %d, seconds: %.2f, loss: %f, cover_loss: %f' % (iter, time.time() - start, loss, cove_loss))
140 | start = time.time()
141 | if iter % 5000 == 0:
142 | self.save_model(running_avg_loss, iter)
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser(description="Train script")
146 | parser.add_argument("-m",
147 | dest="model_path",
148 | required=False,
149 | default=None,
150 | help="Model file for retraining (default: None).")
151 | args = parser.parse_args()
152 |
153 | train_processor = Train()
154 | train_processor.run(config.max_iterations, args.model_path)
155 |
--------------------------------------------------------------------------------
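The loss in `train_one_batch` is a per-token negative log-likelihood, masked so padding contributes nothing and normalized by each sequence's true length. A standalone sketch on toy tensors (`pred` stands in for the model's per-token distribution):

```python
import torch

eps = 1e-12
pred = torch.softmax(torch.randn(2, 4, 10), dim=-1)  # (batch, steps, vocab)
tgt = torch.randint(0, 10, (2, 4))                   # gold token ids
dec_padding_mask = torch.tensor([[1., 1., 1., 0.],   # 1 = real token, 0 = pad
                                 [1., 1., 0., 0.]])
dec_lens = dec_padding_mask.sum(1)                   # true lengths: 3 and 2

gold_probs = torch.gather(pred, -1, tgt.unsqueeze(-1)).squeeze(-1)
batch_loss = -torch.log(gold_probs + eps) * dec_padding_mask
loss = torch.mean(torch.sum(batch_loss, 1) / dec_lens)
print(loss.item())
```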
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/utils/__init__.py
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 |
5 | SENTENCE_STA = '<s>'
6 | SENTENCE_END = '</s>'
7 |
8 | UNK = 0
9 | PAD = 1
10 | BOS = 2
11 | EOS = 3
12 |
13 | PAD_TOKEN = '[PAD]'
14 | UNK_TOKEN = '[UNK]'
15 | BOS_TOKEN = '[BOS]'
16 | EOS_TOKEN = '[EOS]'
17 |
18 | beam_size = 4
19 | emb_dim = 128
20 | batch_size = 16
21 | hidden_dim = 256
22 | max_enc_steps = 400
23 | max_dec_steps = 100
24 | max_tes_steps = 100
25 | min_dec_steps = 35
26 | vocab_size = 50000
27 |
28 | lr = 0.15
29 | cov_loss_wt = 1.0
30 | pointer_gen = True
31 | is_coverage = False
32 |
33 | max_grad_norm = 2.0
34 | adagrad_init_acc = 0.1
35 | rand_unif_init_mag = 0.02
36 | trunc_norm_init_std = 1e-4
37 |
38 | eps = 1e-12
39 | use_gpu = True
40 | lr_coverage = 0.15
41 | max_iterations = 500000
42 |
43 | # transformer
44 | d_k = 64
45 | d_v = 64
46 | n_head = 6
47 | tran = True
48 | dropout = 0.1
49 | n_layers = 6
50 | d_model = 128
51 | d_inner = 512
52 | n_warmup_steps = 4000
53 |
54 | root_dir = os.path.expanduser("./")
55 | log_root = os.path.join(root_dir, "dataset/log/")
56 |
57 | #train_data_path = os.path.join(root_dir, "pointer_generator/dataset/finished_files/train.bin")
58 | train_data_path = os.path.join(root_dir, "dataset/finished_files/chunked/train_*")
59 | eval_data_path = os.path.join(root_dir, "dataset/finished_files/val.bin")
60 | decode_data_path = os.path.join(root_dir, "dataset/finished_files/test.bin")
61 | vocab_path = os.path.join(root_dir, "dataset/finished_files/vocab")
62 |
--------------------------------------------------------------------------------
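`train_data_path` is a glob pattern rather than a single file; the `Batcher` expands it every time it walks the data. A small sanity check before launching training (assumes the repo layout from config.py and that preprocessing has produced the chunk files):

```python
import glob
from utils import config

files = glob.glob(config.train_data_path)
print('%d train chunks match %s' % (len(files), config.train_data_path))
assert files, 'no data found -- check root_dir and the finished_files layout'
```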
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import csv
4 | import glob
5 | import time
6 | import queue
7 | import struct
8 | import numpy as np
9 | import tensorflow as tf
10 | from random import shuffle
11 | from threading import Thread
12 | from tensorflow.core.example import example_pb2
13 |
14 | from utils import utils
15 | from utils import config
16 |
17 | import random
18 | random.seed(1234)
19 |
20 |
21 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
22 | SENTENCE_STA = '<s>'
23 | SENTENCE_END = '</s>'
24 |
25 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
26 | UNK_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
27 | BOS_TOKEN = '[BOS]' # This has a vocab id, which is used at the start of every decoder input sequence
28 | EOS_TOKEN = '[EOS]' # This has a vocab id, which is used at the end of untruncated target sequences
29 | # Note: none of <s>, </s>, [PAD], [UNK], [BOS], [EOS] should appear in the vocab file.
30 |
31 |
32 | class Vocab(object):
33 |
34 | def __init__(self, file, max_size):
35 | self.word2idx = {}
36 | self.idx2word = {}
37 | self.count = 0 # keeps track of total number of words in the Vocab
38 |
39 | # [UNK], [PAD], [BOS] and [EOS] get the ids 0,1,2,3.
40 | for w in [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]:
41 | self.word2idx[w] = self.count
42 | self.idx2word[self.count] = w
43 | self.count += 1
44 |
45 | # Read the vocab file and add words up to max_size
46 | with open(file, 'r') as fin:
47 | for line in fin:
48 | items = line.split()
49 | if len(items) != 2:
50 | print('Warning: incorrectly formatted line in vocabulary file: %s' % line.strip())
51 | continue
52 | w = items[0]
53 | if w in [SENTENCE_STA, SENTENCE_END, UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]:
54 | raise Exception(
55 | '<s>, </s>, [UNK], [PAD], [BOS] and [EOS] shouldn\'t be in the vocab file, but %s is' % w)
56 | if w in self.word2idx:
57 | raise Exception('Duplicated word in vocabulary file: %s' % w)
58 | self.word2idx[w] = self.count
59 | self.idx2word[self.count] = w
60 | self.count += 1
61 | if max_size != 0 and self.count >= max_size:
62 | break
63 | print("Finished constructing vocabulary of %i total words. Last word added: %s" % (
64 | self.count, self.idx2word[self.count - 1]))
65 |
66 | def word2id(self, word):
67 | if word not in self.word2idx:
68 | return self.word2idx[UNK_TOKEN]
69 | return self.word2idx[word]
70 |
71 | def id2word(self, word_id):
72 | if word_id not in self.idx2word:
73 | raise ValueError('Id not found in vocab: %d' % word_id)
74 | return self.idx2word[word_id]
75 |
76 | def size(self):
77 | return self.count
78 |
79 | def write_metadata(self, path):
80 | print( "Writing word embedding metadata file to %s..." % (path))
81 | with open(path, "w") as f:
82 | fieldnames = ['word']
83 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
84 | for i in range(self.size()):
85 | writer.writerow({"word": self.idx2word[i]})
86 |
87 | class Example(object):
88 |
89 | def __init__(self, article, abstract_sentences, vocab):
90 | # Get ids of special tokens
91 | bos_decoding = vocab.word2id(BOS_TOKEN)
92 | eos_decoding = vocab.word2id(EOS_TOKEN)
93 |
94 | # Process the article
95 | article_words = article.decode().split()
96 | if len(article_words) > config.max_enc_steps:
97 | article_words = article_words[:config.max_enc_steps]
98 | self.enc_len = len(article_words) # store the length after truncation but before padding
99 | self.enc_inp = [vocab.word2id(w) for w in
100 | article_words] # list of word ids; OOVs are represented by the id for UNK token
101 |
102 | # Process the abstract
103 | abstract = ' '.encode().join(abstract_sentences).decode()
104 | abstract_words = abstract.split() # list of strings
105 | abs_ids = [vocab.word2id(w) for w in
106 | abstract_words] # list of word ids; OOVs are represented by the id for UNK token
107 |
108 | # Get the decoder input sequence and target sequence
109 | self.dec_inp, self.tgt = self.get_dec_seq(abs_ids, config.max_dec_steps, bos_decoding, eos_decoding)
110 | self.dec_len = len(self.dec_inp)
111 |
112 | # If using pointer-generator mode, we need to store some extra info
113 | if config.pointer_gen:
114 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id;
115 | # also store the in-article OOVs words themselves
116 | self.enc_inp_extend_vocab, self.article_oovs = utils.article2ids(article_words, vocab)
117 |
118 | # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
119 | abs_ids_extend_vocab = utils.abstract2ids(abstract_words, vocab, self.article_oovs)
120 |
121 | # Overwrite decoder target sequence so it uses the temp article OOV ids
122 | _, self.tgt = self.get_dec_seq(abs_ids_extend_vocab, config.max_dec_steps, bos_decoding, eos_decoding)
123 |
124 | # Store the original strings
125 | self.original_article = article
126 | self.original_abstract = abstract
127 | self.original_abstract_sents = abstract_sentences
128 |
129 | def get_dec_seq(self, sequence, max_len, start_id, stop_id):
130 | src = [start_id] + sequence[:]
131 | tgt = sequence[:]
132 | if len(src) > max_len: # truncate
133 | src = src[:max_len]
134 | tgt = tgt[:max_len] # no end_token
135 | else: # no truncation
136 | tgt.append(stop_id) # end token
137 | assert len(src) == len(tgt)
138 | return src, tgt
139 |
140 | def pad_enc_seq(self, max_len, pad_id):
141 | while len(self.enc_inp) < max_len:
142 | self.enc_inp.append(pad_id)
143 | if config.pointer_gen:
144 | while len(self.enc_inp_extend_vocab) < max_len:
145 | self.enc_inp_extend_vocab.append(pad_id)
146 |
147 | def pad_dec_seq(self, max_len, pad_id):
148 | while len(self.dec_inp) < max_len:
149 | self.dec_inp.append(pad_id)
150 | while len(self.tgt) < max_len:
151 | self.tgt.append(pad_id)
152 |
153 |
154 | class Batch(object):
155 | def __init__(self, example_list, vocab, batch_size):
156 | self.batch_size = batch_size
157 | self.pad_id = vocab.word2id(PAD_TOKEN) # id of the PAD token used to pad sequences
158 | self.init_encoder_seq(example_list) # initialize the input to the encoder
159 | self.init_decoder_seq(example_list) # initialize the input and targets for the decoder
160 | self.store_orig_strings(example_list) # store the original strings
161 |
162 | def init_encoder_seq(self, example_list):
163 | # Determine the maximum length of the encoder input sequence in this batch
164 | max_enc_seq_len = max([ex.enc_len for ex in example_list])
165 |
166 | # Pad the encoder input sequences up to the length of the longest sequence
167 | for ex in example_list:
168 | ex.pad_enc_seq(max_enc_seq_len, self.pad_id)
169 |
170 | # Initialize the numpy arrays
171 | # Note: enc_batch can have a different second dimension for each batch because sequences are padded to the batch's longest example.
172 | self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
173 | self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
174 | self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)
175 |
176 | # Fill in the numpy arrays
177 | for i, ex in enumerate(example_list):
178 | self.enc_batch[i, :] = ex.enc_inp[:]
179 | self.enc_lens[i] = ex.enc_len
180 | for j in range(ex.enc_len):
181 | self.enc_padding_mask[i][j] = 1
182 |
183 | # For pointer-generator mode, need to store some extra info
184 | if config.pointer_gen:
185 | # Determine the max number of in-article OOVs in this batch
186 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
187 | # Store the in-article OOVs themselves
188 | self.art_oovs = [ex.article_oovs for ex in example_list]
189 | # Store the version of the enc_batch that uses the article OOV ids
190 | self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
191 | for i, ex in enumerate(example_list):
192 | self.enc_batch_extend_vocab[i, :] = ex.enc_inp_extend_vocab[:]
193 |
194 | def init_decoder_seq(self, example_list):
195 | # Pad the inputs and targets
196 | for ex in example_list:
197 | ex.pad_dec_seq(config.max_dec_steps, self.pad_id)
198 |
199 | # Initialize the numpy arrays.
200 | self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
201 | self.tgt_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
202 | self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
203 | self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)
204 |
205 | # Fill in the numpy arrays
206 | for i, ex in enumerate(example_list):
207 | self.dec_batch[i, :] = ex.dec_inp[:]
208 | self.tgt_batch[i, :] = ex.tgt[:]
209 | self.dec_lens[i] = ex.dec_len
210 | for j in range(ex.dec_len):
211 | self.dec_padding_mask[i][j] = 1
212 |
213 | def store_orig_strings(self, example_list):
214 | self.original_articles = [ex.original_article for ex in example_list]  # list of strings
215 | self.original_abstracts = [ex.original_abstract for ex in example_list]  # list of strings
216 | self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list]  # list of lists of strings
217 |
218 |
219 | class Batcher(object):
220 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
221 |
222 | def __init__(self, vocab, data_path, batch_size, single_pass, mode):
223 | self._vocab = vocab
224 | self._data_path = data_path
225 | self.batch_size = batch_size
226 | self.single_pass = single_pass
227 | self.mode = mode
228 |
229 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
230 | self._batch_queue = queue.Queue(self.BATCH_QUEUE_MAX)
231 | self._example_queue = queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)
232 |
233 | # Different settings depending on whether we're in single_pass mode or not
234 | if single_pass:
235 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
236 | self._num_batch_q_threads = 1 # just one thread to batch examples
237 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing
238 | self._finished_reading = False # this will tell us when we're finished reading the dataset
239 | else:
240 | self._num_example_q_threads = 1 # num threads to fill example queue
241 | self._num_batch_q_threads = 1 # num threads to fill batch queue
242 | self._bucketing_cache_size = 1 # how many batches-worth of examples to load into cache before bucketing
243 |
244 | # Start the threads that load the queues
245 | self._example_q_threads = []
246 | for _ in range(self._num_example_q_threads):
247 | self._example_q_threads.append(Thread(target=self.fill_example_queue))
248 | self._example_q_threads[-1].daemon = True
249 | self._example_q_threads[-1].start()
250 | self._batch_q_threads = []
251 | for _ in range(self._num_batch_q_threads):
252 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
253 | self._batch_q_threads[-1].daemon = True
254 | self._batch_q_threads[-1].start()
255 |
256 | # Start a thread that watches the other threads and restarts them if they're dead
257 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
258 | self._watch_thread = Thread(target=self.watch_threads)
259 | self._watch_thread.daemon = True
260 | self._watch_thread.start()
261 |
262 | def next_batch(self):
263 | # If the batch queue is empty, print a warning
264 | if self._batch_queue.qsize() == 0:
265 | tf.logging.warning(
266 | 'Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i',
267 | self._batch_queue.qsize(), self._example_queue.qsize())
268 | if self.single_pass and self._finished_reading:
269 | tf.logging.info("Finished reading dataset in single_pass mode.")
270 | return None
271 |
272 | batch = self._batch_queue.get() # get the next Batch
273 | return batch
274 |
275 | def fill_example_queue(self):
276 | example_generator = self.example_generator(self._data_path, self.single_pass)
277 | input_gen = self.pair_generator(example_generator)
278 |
279 | while True:
280 | try:
281 | (article,
282 | abstract) = input_gen.__next__()  # read the next example from file. article and abstract are both byte strings.
283 | except StopIteration: # if there are no more examples:
284 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
285 | if self.single_pass:
286 | tf.logging.info(
287 | "single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
288 | self._finished_reading = True
289 | break
290 | else:
291 | raise Exception("single_pass mode is off but the example generator is out of data; error.")
292 |
293 | abstract_sentences = [sent.strip() for sent in utils.abstract2sents(
294 | abstract)]  # Use the <s> and </s> tags in abstract to get a list of sentences.
295 | example = Example(article, abstract_sentences, self._vocab)
296 | self._example_queue.put(example)
297 |
298 | def fill_batch_queue(self):
299 | while True:
300 | if self.mode == 'decode':
301 | # In beam-search decode mode, a single example is repeated across the batch
302 | ex = self._example_queue.get()
303 | b = [ex for _ in range(self.batch_size)]
304 | self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
305 | else:
306 | # Get bucketing_cache_size-many batches of Examples into a list, then sort
307 | inputs = []
308 | for _ in range(self.batch_size * self._bucketing_cache_size):
309 | inputs.append(self._example_queue.get())
310 | inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True) # sort by length of encoder sequence
311 |
312 | # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
313 | batches = []
314 | for i in range(0, len(inputs), self.batch_size):
315 | batches.append(inputs[i:i + self.batch_size])
316 | if not self.single_pass:
317 | shuffle(batches)
318 | for b in batches: # each b is a list of Example objects
319 | self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
320 |
321 | def watch_threads(self):
322 | while True:
323 | tf.logging.info(
324 | 'Bucket queue size: %i, Input queue size: %i',
325 | self._batch_queue.qsize(), self._example_queue.qsize())
326 |
327 | time.sleep(60)
328 | for idx, t in enumerate(self._example_q_threads):
329 | if not t.is_alive(): # if the thread is dead
330 | tf.logging.error('Found example queue thread dead. Restarting.')
331 | new_t = Thread(target=self.fill_example_queue)
332 | self._example_q_threads[idx] = new_t
333 | new_t.daemon = True
334 | new_t.start()
335 | for idx, t in enumerate(self._batch_q_threads):
336 | if not t.is_alive(): # if the thread is dead
337 | tf.logging.error('Found batch queue thread dead. Restarting.')
338 | new_t = Thread(target=self.fill_batch_queue)
339 | self._batch_q_threads[idx] = new_t
340 | new_t.daemon = True
341 | new_t.start()
342 |
343 | def pair_generator(self, example_generator):
344 | while True:
345 | e = example_generator.__next__() # e is a tf.Example
346 | try:
347 | article_text = e.features.feature['article'].bytes_list.value[
348 | 0] # the article text was saved under the key 'article' in the data files
349 | abstract_text = e.features.feature['abstract'].bytes_list.value[
350 | 0] # the abstract text was saved under the key 'abstract' in the data files
351 | except ValueError:
352 | tf.logging.error('Failed to get article or abstract from example')
353 | continue
354 | if len(article_text) == 0: # See https://github.com/abisee/pointer-generator/issues/1
355 | # tf.logging.warning('Found an example with empty article text. Skipping it.')
356 | continue
357 | else:
358 | yield (article_text, abstract_text)
359 |
360 | def example_generator(self, data_path, single_pass):
361 | while True:
362 | filelist = glob.glob(data_path) # get the list of datafiles
363 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
364 | if single_pass:
365 | filelist = sorted(filelist)
366 | else:
367 | random.shuffle(filelist)
368 | for f in filelist:
369 | reader = open(f, 'rb')
370 | while True:
371 | len_bytes = reader.read(8)
372 | if not len_bytes: break # finished reading this file
373 | str_len = struct.unpack('q', len_bytes)[0]
374 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
375 | yield example_pb2.Example.FromString(example_str)
376 | if single_pass:
377 | print("example_generator completed reading all datafiles. No more data.")
378 | break
--------------------------------------------------------------------------------
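`example_generator` above expects the binary format produced by make_datafiles.py: each record is an 8-byte native-order length prefix followed by a serialized tf.Example holding `article` and `abstract` byte features. An illustrative writer for that format (the file name and toy texts are placeholders):

```python
import struct
from tensorflow.core.example import example_pb2

def write_example(writer, article, abstract):
    ex = example_pb2.Example()
    ex.features.feature['article'].bytes_list.value.append(article.encode())
    ex.features.feature['abstract'].bytes_list.value.append(abstract.encode())
    s = ex.SerializeToString()
    writer.write(struct.pack('q', len(s)))   # 8-byte length prefix, as read above
    writer.write(struct.pack('%ds' % len(s), s))

with open('toy_train.bin', 'wb') as w:
    write_example(w, 'a tiny article .', '<s> a tiny summary . </s>')
```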
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import pyrouge
5 | import logging
6 | import numpy as np
7 |
8 | import torch
9 | import tensorflow as tf
10 | from torch.autograd import Variable
11 |
12 | from utils import config
13 |
14 |
15 | def article2ids(article_words, vocab):
16 | ids = []
17 | oov = []
18 | unk_id = vocab.word2id(config.UNK_TOKEN)
19 | for w in article_words:
20 | i = vocab.word2id(w)
21 | if i == unk_id: # If w is OOV
22 | if w not in oov: # Add to list of OOVs
23 | oov.append(w)
24 | oov_num = oov.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
25 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
26 | else:
27 | ids.append(i)
28 | return ids, oov
29 |
30 |
31 | def abstract2ids(abstract_words, vocab, article_oovs):
32 | ids = []
33 | unk_id = vocab.word2id(config.UNK_TOKEN)
34 | for w in abstract_words:
35 | i = vocab.word2id(w)
36 | if i == unk_id: # If w is an OOV word
37 | if w in article_oovs: # If w is an in-article OOV
38 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
39 | ids.append(vocab_idx)
40 | else: # If w is an out-of-article OOV
41 | ids.append(unk_id) # Map to the UNK token id
42 | else:
43 | ids.append(i)
44 | return ids
45 |
46 |
47 | def outputids2words(id_list, vocab, article_oovs):
48 | words = []
49 | for i in id_list:
50 | try:
51 | w = vocab.id2word(i) # might be [UNK]
52 | except ValueError:  # w is OOV
53 | assert article_oovs is not None, \
54 | "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
55 | article_oov_idx = i - vocab.size()
56 | try:
57 | w = article_oovs[article_oov_idx]
58 | except IndexError:  # i doesn't correspond to an article oov (list indexing raises IndexError)
59 | raise ValueError(
60 | 'Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (
61 | i, article_oov_idx, len(article_oovs)))
62 | words.append(w)
63 | return words
64 |
65 |
66 | def abstract2sents(abstract):
67 | cur_p = 0
68 | sents = []
69 | while True:
70 | try:
71 | sta_p = abstract.index(config.SENTENCE_STA.encode(), cur_p)
72 | end_p = abstract.index(config.SENTENCE_END.encode(), sta_p + 1)
73 | cur_p = end_p + len(config.SENTENCE_END.encode())
74 | sents.append(abstract[sta_p + len(config.SENTENCE_STA.encode()):end_p])
75 | except ValueError:  # no more sentences
76 | return sents
77 |
78 |
79 | def show_art_oovs(article, vocab):
80 | unk_token = vocab.word2id(config.UNK_TOKEN)
81 | words = article.split(' ')
82 | words = [("__%s__" % w) if vocab.word2id(w) == unk_token else w for w in words]
83 | out_str = ' '.join(words)
84 | return out_str
85 |
86 |
87 | def show_abs_oovs(abstract, vocab, article_oovs):
88 | unk_token = vocab.word2id(config.UNK_TOKEN)
89 | words = abstract.split(' ')
90 | new_words = []
91 | for w in words:
92 | if vocab.word2id(w) == unk_token: # w is oov
93 | if article_oovs is None: # baseline mode
94 | new_words.append("__%s__" % w)
95 | else: # pointer-generator mode
96 | if w in article_oovs:
97 | new_words.append("__%s__" % w)
98 | else:
99 | new_words.append("!!__%s__!!" % w)
100 | else: # w is in-vocab word
101 | new_words.append(w)
102 | out_str = ' '.join(new_words)
103 | return out_str
104 |
105 |
106 | def print_results(article, abstract, decoded_output):
107 | print("")
108 | print('ARTICLE: %s' % article)
109 | print('REFERENCE SUMMARY: %s' % abstract)
110 | print('GENERATED SUMMARY: %s' % decoded_output)
111 | print("")
112 |
113 |
114 | def make_html_safe(s):
115 | s.replace("<", "<")
116 | s.replace(">", ">")
117 | return s
118 |
119 |
120 | def rouge_eval(ref_dir, dec_dir):
121 | r = pyrouge.Rouge155()
122 | r.model_filename_pattern = '#ID#_reference.txt'
123 | r.system_filename_pattern = r'(\d+)_decoded.txt'
124 | r.model_dir = ref_dir
125 | r.system_dir = dec_dir
126 | logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
127 | rouge_results = r.convert_and_evaluate()
128 | return r.output_to_dict(rouge_results)
129 |
130 |
131 | def rouge_log(results_dict, dir_to_write):
132 | log_str = ""
133 | for x in ["1", "2", "l"]:
134 | log_str += "\nROUGE-%s:\n" % x
135 | for y in ["f_score", "recall", "precision"]:
136 | key = "rouge_%s_%s" % (x, y)
137 | key_cb = key + "_cb"
138 | key_ce = key + "_ce"
139 | val = results_dict[key]
140 | val_cb = results_dict[key_cb]
141 | val_ce = results_dict[key_ce]
142 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
143 | print(log_str)
144 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
145 | print("Writing final ROUGE results to %s..." % (results_file))
146 | with open(results_file, "w") as f:
147 | f.write(log_str)
148 |
149 |
150 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
151 | if running_avg_loss == 0: # on the first iteration just take the loss
152 | running_avg_loss = loss
153 | else:
154 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
155 | running_avg_loss = min(running_avg_loss, 12) # clip
156 | loss_sum = tf.Summary()
157 | tag_name = 'running_avg_loss/decay=%f' % (decay)
158 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
159 | summary_writer.add_summary(loss_sum, step)
160 | return running_avg_loss
161 |
162 |
163 | def write_for_rouge(reference_sents, decoded_words, ex_index,
164 | _rouge_ref_dir, _rouge_dec_dir):
165 | decoded_sents = []
166 | while len(decoded_words) > 0:
167 | try:
168 | fst_period_idx = decoded_words.index(".")
169 | except ValueError:
170 | fst_period_idx = len(decoded_words)
171 | sent = decoded_words[:fst_period_idx + 1]
172 | decoded_words = decoded_words[fst_period_idx + 1:]
173 | decoded_sents.append(' '.join(sent))
174 |
175 | # pyrouge calls a perl script that puts the data into HTML files.
176 | # Therefore we need to make our output HTML safe.
177 | decoded_sents = [make_html_safe(w) for w in decoded_sents]
178 | reference_sents = [make_html_safe(w) for w in reference_sents]
179 |
180 | ref_file = os.path.join(_rouge_ref_dir, "%06d_reference.txt" % ex_index)
181 | decoded_file = os.path.join(_rouge_dec_dir, "%06d_decoded.txt" % ex_index)
182 |
183 | with open(ref_file, "w") as f:
184 | for idx, sent in enumerate(reference_sents):
185 | f.write(sent if idx == len(reference_sents) - 1 else sent + "\n")
186 | with open(decoded_file, "w") as f:
187 | for idx, sent in enumerate(decoded_sents):
188 | f.write(sent if idx == len(decoded_sents) - 1 else sent + "\n")
189 |
190 | # print("Wrote example %i to file" % ex_index)
191 |
192 |
193 | def get_input_from_batch(batch, use_cuda, transformer=False):  # transformer flag is passed by tran_train; positions are returned either way
194 | extra_zeros = None
195 | enc_lens = batch.enc_lens
196 | max_enc_len = np.max(enc_lens)
197 | enc_batch_extend_vocab = None
198 | batch_size = len(batch.enc_lens)
199 | enc_batch = Variable(torch.from_numpy(batch.enc_batch).long())
200 | enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float()
201 |
202 | if config.pointer_gen:
203 | enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long())
204 | # max_art_oovs is the max over all the article oov list in the batch
205 | if batch.max_art_oovs > 0:
206 | extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs)))
207 |
208 | c_t = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
209 |
210 | coverage = None
211 | if config.is_coverage:
212 | coverage = Variable(torch.zeros(enc_batch.size()))
213 |
214 | enc_pos = np.zeros((batch_size, max_enc_len))
215 | for i, inst in enumerate(batch.enc_batch):
216 | for j, w_i in enumerate(inst):
217 | if w_i != config.PAD:
218 | enc_pos[i, j] = (j + 1)
219 | else:
220 | break
221 | enc_pos = Variable(torch.from_numpy(enc_pos).long())
222 |
223 | if use_cuda:
224 | c_t = c_t.cuda()
225 | enc_pos = enc_pos.cuda()
226 | enc_batch = enc_batch.cuda()
227 | enc_padding_mask = enc_padding_mask.cuda()
228 |
229 | if coverage is not None:
230 | coverage = coverage.cuda()
231 |
232 | if extra_zeros is not None:
233 | extra_zeros = extra_zeros.cuda()
234 |
235 | if enc_batch_extend_vocab is not None:
236 | enc_batch_extend_vocab = enc_batch_extend_vocab.cuda()
237 |
238 |
239 | return enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t, coverage
240 |
241 |
242 | def get_output_from_batch(batch, use_cuda, transformer=False):  # accepts the transformer flag used by tran_train
243 | dec_lens = batch.dec_lens
244 | max_dec_len = np.max(dec_lens)
245 | batch_size = len(batch.dec_lens)
246 | dec_lens = Variable(torch.from_numpy(dec_lens)).float()
247 | tgt_batch = Variable(torch.from_numpy(batch.tgt_batch)).long()
248 | dec_batch = Variable(torch.from_numpy(batch.dec_batch).long())
249 | dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float()
250 |
251 | dec_pos = np.zeros((batch_size, config.max_dec_steps))
252 | for i, inst in enumerate(batch.dec_batch):
253 | for j, w_i in enumerate(inst):
254 | if w_i != config.PAD:
255 | dec_pos[i, j] = (j + 1)
256 | else:
257 | break
258 | dec_pos = Variable(torch.from_numpy(dec_pos).long())
259 |
260 | if use_cuda:
261 | dec_lens = dec_lens.cuda()
262 | tgt_batch = tgt_batch.cuda()
263 | dec_batch = dec_batch.cuda()
264 | dec_padding_mask = dec_padding_mask.cuda()
265 | dec_pos = dec_pos.cuda()
266 |
267 | return dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch
268 |
--------------------------------------------------------------------------------
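The three id helpers above implement the pointer mechanism's extended vocabulary: in-article OOVs get temporary ids starting at `vocab.size()`, and `outputids2words` maps them back to the article's OOV strings. A round-trip sketch, assuming the vocab file from config.py exists and that 'quokka' is not in it:

```python
from utils import config
from utils.dataset import Vocab
from utils.utils import article2ids, abstract2ids, outputids2words

vocab = Vocab(config.vocab_path, config.vocab_size)
article_words = 'the quokka smiled'.split()    # assume 'quokka' is OOV
ids, oovs = article2ids(article_words, vocab)  # 'quokka' -> vocab.size() + 0
abs_ids = abstract2ids('the quokka'.split(), vocab, oovs)
print(oovs)                                    # ['quokka']
print(outputids2words(abs_ids, vocab, oovs))   # ['the', 'quokka']
```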