├── README.md
├── dataset
│   └── url_lists
│       ├── all_test.txt
│       ├── all_train.txt
│       ├── all_val.txt
│       ├── cnn_wayback_test_urls.txt
│       ├── cnn_wayback_training_urls.txt
│       ├── cnn_wayback_validation_urls.txt
│       ├── dailymail_wayback_test_urls.txt
│       ├── dailymail_wayback_training_urls.txt
│       ├── dailymail_wayback_validation_urls.txt
│       └── readme
├── eval.py
├── make_datafiles.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── attention.cpython-35.pyc
│   │   ├── basic.cpython-35.pyc
│   │   ├── layers.cpython-35.pyc
│   │   └── model.cpython-35.pyc
│   ├── attention.py
│   ├── basic.py
│   ├── layers.py
│   └── model.py
├── requirements.txt
├── test.py
├── train.py
├── transformer
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── layers.cpython-35.pyc
│   │   ├── model.cpython-35.pyc
│   │   ├── optim.cpython-35.pyc
│   │   └── sublayers.cpython-35.pyc
│   ├── layers.py
│   ├── model.py
│   ├── optim.py
│   ├── sublayers.py
│   └── tran_train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── config.cpython-35.pyc
    │   ├── dataset.cpython-35.pyc
    │   └── utils.cpython-35.pyc
    ├── config.py
    ├── dataset.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Pointer-Generator-Pytorch
2 | 
3 | ## About
4 | A PyTorch implementation of [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368).
5 | This implementation can also use a Transformer as the encoder.
6 | The project borrows heavily from [atulkum-pointer_summarizer](https://github.com/atulkum/pointer_summarizer.git) and
7 | [jadore801120-attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch).
8 | 
9 | ## Requirements
10 | * python==3.7.4
11 | * pytorch==1.4.0
12 | * pyrouge==0.1.3
13 | * tensorflow>=1.13.1
14 | 
15 | ## Quick start
16 | * Project paths and parameters:
17 | you may need to change some paths and parameters in utils/config.py according to your setup.
18 | * Dataset:
19 | you can download the CNN/Daily Mail dataset from https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail,
20 | then run make_datafiles.py to process the data. For details of the preprocessing, see https://github.com/abisee/cnn-dailymail.
21 | * Run:
22 | you can run train.py, eval.py, and test.py for training, evaluation, and testing, respectively.
23 | 
24 | ### Note:
25 | * In the decode (beam search) mode, a single example is repeated across the batch.
26 | 
27 | 
28 | 
--------------------------------------------------------------------------------
/dataset/url_lists/readme:
--------------------------------------------------------------------------------
1 | Note: all_train.txt is simply the concatenation of cnn_wayback_training_urls.txt and dailymail_wayback_training_urls.txt.
2 | 
3 | Similarly for validation and test sets.
4 | 
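For example, all_train.txt can be rebuilt from the two per-source lists with a few lines of Python (a minimal sketch; this script is not part of the repository):

with open("dataset/url_lists/all_train.txt", "w") as out:
    # all_train.txt = CNN training URLs followed by Daily Mail training URLs (see note above)
    for name in ("cnn_wayback_training_urls.txt",
                 "dailymail_wayback_training_urls.txt"):
        with open("dataset/url_lists/" + name) as f:
            out.write(f.read())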
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from __future__ import unicode_literals, print_function, division
4 | 
5 | import os
6 | import time
7 | import sys
8 | 
9 | import tensorflow as tf
10 | import torch
11 | 
12 | from models.model import Model
13 | from utils import config
14 | from utils.dataset import Vocab
15 | from utils.dataset import Batcher
16 | from utils.utils import get_input_from_batch
17 | from utils.utils import get_output_from_batch
18 | from utils.utils import calc_running_avg_loss
19 | 
20 | use_cuda = config.use_gpu and torch.cuda.is_available()
21 | 
22 | class Evaluate(object):
23 |     def __init__(self, model_path):
24 |         self.vocab = Vocab(config.vocab_path, config.vocab_size)
25 |         self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
26 |                                batch_size=config.batch_size, single_pass=True)
27 |         time.sleep(15)
28 |         model_name = os.path.basename(model_path)
29 | 
30 |         eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
31 |         if not os.path.exists(eval_dir):
32 |             os.mkdir(eval_dir)
33 |         self.summary_writer = tf.summary.FileWriter(eval_dir)
34 | 
35 |         self.model = Model(model_path, is_eval=True)
36 | 
37 |     def eval_one_batch(self, batch):
38 |         enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
39 |             get_input_from_batch(batch, use_cuda)
40 |         dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, tgt_batch = \
41 |             get_output_from_batch(batch, use_cuda)
42 | 
43 |         enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
44 |         s_t = self.model.reduce_state(enc_h)
45 | 
46 |         step_losses = []
47 |         for di in range(min(max_dec_len, config.max_dec_steps)):
48 |             y_t = dec_batch[:, di]  # Teacher forcing
49 |             final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = self.model.decoder(y_t, s_t,
50 |                                                         enc_out, enc_fea, enc_padding_mask, c_t,
51 |                                                         extra_zeros, enc_batch_extend_vocab, coverage, di)
52 |             tgt = tgt_batch[:, di]
53 |             gold_probs = torch.gather(final_dist, 1, tgt.unsqueeze(1)).squeeze()
54 |             step_loss = -torch.log(gold_probs + config.eps)
55 |             if config.is_coverage:
56 |                 step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
57 |                 step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
58 |                 coverage = next_coverage
59 | 
60 |             step_mask = dec_padding_mask[:, di]
61 |             step_loss = step_loss * step_mask
62 |             step_losses.append(step_loss)
63 | 
64 |         sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
65 |         batch_avg_loss = sum_step_losses / dec_lens_var
66 |         loss = torch.mean(batch_avg_loss)
67 | 
68 |         return loss.item()
69 | 
70 |     def run(self):
71 |         start = time.time()
72 |         running_avg_loss, iter = 0, 0
73 |         batch = self.batcher.next_batch()
74 |         print_interval = 100
75 |         while batch is not None:
76 |             loss = self.eval_one_batch(batch)
77 |             running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
78 |             iter += 1
79 | 
80 |             if iter % print_interval == 0:
81 |                 self.summary_writer.flush()
82 |                 print('step: %d, seconds: %.2f, loss: %f' % (iter, time.time() - start, running_avg_loss))
83 |                 start = time.time()
84 |             batch = self.batcher.next_batch()
85 | 
86 |         return running_avg_loss
87 | 
88 | 
89 | if __name__ == '__main__':
90 |     model_filename = sys.argv[1]
91 |     eval_processor = Evaluate(model_filename)
92 |     eval_processor.run()
93 | 
94 | 
95 | 
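eval.py (and train.py below) track an exponentially decayed running average of the loss via calc_running_avg_loss from utils/utils.py, which is not included in this excerpt. A minimal sketch of that helper, assuming the TF1-style summary logging used elsewhere in this project and the decay behavior of the upstream pointer_summarizer code:

import tensorflow as tf

def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
    # Exponential moving average of the per-batch loss.
    if running_avg_loss == 0:  # on the first iteration, just take the loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip for sensible TensorBoard plots
    summary = tf.Summary()
    summary.value.add(tag='running_avg_loss/decay=%f' % decay, simple_value=running_avg_loss)
    summary_writer.add_summary(summary, step)
    return running_avg_loss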
--------------------------------------------------------------------------------
/make_datafiles.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | # This code is from https://github.com/abisee/cnn-dailymail.git
4 | 
5 | import sys
6 | import os
7 | import hashlib
8 | import struct
9 | import subprocess
10 | import collections
11 | from tensorflow.core.example import example_pb2
12 | 
13 | dm_single_close_quote = u'\u2019'  # unicode
14 | dm_double_close_quote = u'\u201d'
15 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote,
16 |               ")"]  # acceptable ways to end a sentence
17 | 
18 | # We use these to separate the summary sentences in the .bin datafiles
19 | SENTENCE_START = '<s>'
20 | SENTENCE_END = '</s>'
21 | 
22 | all_train_urls = "dataset/url_lists/all_train.txt"
23 | all_val_urls = "dataset/url_lists/all_val.txt"
24 | all_test_urls = "dataset/url_lists/all_test.txt"
25 | 
26 | cnn_tokenized_stories_dir = "dataset/cnn_stories_tokenized"
27 | dm_tokenized_stories_dir = "dataset/dm_stories_tokenized"
28 | finished_files_dir = "dataset/finished_files"
29 | chunks_dir = os.path.join(finished_files_dir, "chunked")
30 | 
31 | # These are the number of .story files we expect there to be in cnn_stories_dir and dm_stories_dir
32 | num_expected_cnn_stories = 92579
33 | num_expected_dm_stories = 219506
34 | 
35 | VOCAB_SIZE = 200000
36 | CHUNK_SIZE = 1000  # num examples per chunk, for the chunked data
37 | 
38 | 
39 | def chunk_file(set_name):
40 |     in_file = 'dataset/finished_files/%s.bin' % set_name
41 |     reader = open(in_file, "rb")
42 |     chunk = 0
43 |     finished = False
44 |     while not finished:
45 |         chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (set_name, chunk))  # new chunk
46 |         with open(chunk_fname, 'wb') as writer:
47 |             for _ in range(CHUNK_SIZE):
48 |                 len_bytes = reader.read(8)
49 |                 if not len_bytes:
50 |                     finished = True
51 |                     break
52 |                 str_len = struct.unpack('q', len_bytes)[0]
53 |                 example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
54 |                 writer.write(struct.pack('q', str_len))
55 |                 writer.write(struct.pack('%ds' % str_len, example_str))
56 |         chunk += 1
57 | 
58 | 
59 | def chunk_all():
60 |     # Make a dir to hold the chunks
61 |     if not os.path.isdir(chunks_dir):
62 |         os.mkdir(chunks_dir)
63 |     # Chunk the data
64 |     for set_name in ['train', 'val', 'test']:
65 |         print("Splitting %s data into chunks..." % set_name)
66 |         chunk_file(set_name)
67 |     print("Saved chunked data in %s" % chunks_dir)
68 | 
69 | 
70 | def tokenize_stories(stories_dir, tokenized_stories_dir):
71 |     """Maps a whole directory of .story files to a tokenized version using Stanford CoreNLP Tokenizer"""
72 |     print("Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir))
73 |     stories = os.listdir(stories_dir)
74 |     # make IO list file
75 |     print("Making list of files to tokenize...")
76 |     with open("mapping.txt", "w") as f:
77 |         for s in stories:
78 |             f.write("%s \t %s\n" % (os.path.join(stories_dir, s), os.path.join(tokenized_stories_dir, s)))
79 |     command = ['java', 'edu.stanford.nlp.process.PTBTokenizer', '-ioFileList', '-preserveLines', 'mapping.txt']
80 |     print("Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir))
81 |     subprocess.call(command)
82 |     print("Stanford CoreNLP Tokenizer has finished.")
83 |     os.remove("mapping.txt")
84 | 
85 |     # Check that the tokenized stories directory contains the same number of files as the original directory
86 |     num_orig = len(os.listdir(stories_dir))
87 |     num_tokenized = len(os.listdir(tokenized_stories_dir))
88 |     if num_orig != num_tokenized:
89 |         raise Exception(
90 |             "The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" % (
91 |                 tokenized_stories_dir, num_tokenized, stories_dir, num_orig))
92 |     print("Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir))
93 | 
94 | 
95 | def read_text_file(text_file):
96 |     lines = []
97 |     with open(text_file, "r") as f:
98 |         for line in f:
99 |             lines.append(line.strip())
100 |     return lines
101 | 
102 | 
103 | def hashhex(s):
104 |     """Returns a hexadecimal formatted SHA1 hash of the input string."""
105 |     h = hashlib.sha1()
106 |     h.update(s.encode('utf-8'))
107 |     return h.hexdigest()
108 | 
109 | 
110 | def get_url_hashes(url_list):
111 |     return [hashhex(url) for url in url_list]
112 | 
113 | 
114 | def fix_missing_period(line):
115 |     """Adds a period to a line that is missing a period"""
116 |     if "@highlight" in line: return line
117 |     if line == "": return line
118 |     if line[-1] in END_TOKENS: return line
119 |     # print line[-1]
120 |     return line + " ."
121 | 
122 | 
123 | def get_art_abs(story_file):
124 |     lines = read_text_file(story_file)
125 | 
126 |     # Lowercase everything
127 |     lines = [line.lower() for line in lines]
128 | 
129 |     # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences)
130 |     lines = [fix_missing_period(line) for line in lines]
131 | 
132 |     # Separate out article and abstract sentences
133 |     article_lines = []
134 |     highlights = []
135 |     next_is_highlight = False
136 |     for idx, line in enumerate(lines):
137 |         if line == "":
138 |             continue  # empty line
139 |         elif line.startswith("@highlight"):
140 |             next_is_highlight = True
141 |         elif next_is_highlight:
142 |             highlights.append(line)
143 |         else:
144 |             article_lines.append(line)
145 | 
146 |     # Make article into a single string
147 |     article = ' '.join(article_lines)
148 | 
149 |     # Make abstract into a single string, putting <s> and </s> tags around the sentences
150 |     abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights])
151 | 
152 |     return article, abstract
153 | 
154 | 
155 | def write_to_bin(url_file, out_file, makevocab=False):
156 |     """Reads the tokenized .story files corresponding to the urls listed in the url_file and writes them to a out_file."""
157 |     print("Making bin file for URLs listed in %s..." % url_file)
158 |     url_list = read_text_file(url_file)
159 |     url_hashes = get_url_hashes(url_list)
160 |     story_fnames = [s + ".story" for s in url_hashes]
161 |     num_stories = len(story_fnames)
162 | 
163 |     if makevocab:
164 |         vocab_counter = collections.Counter()
165 | 
166 |     with open(out_file, 'wb') as writer:
167 |         for idx, s in enumerate(story_fnames):
168 |             if idx % 1000 == 0:
169 |                 print("Writing story %i of %i; %.2f percent done" % (
170 |                     idx, num_stories, float(idx) * 100.0 / float(num_stories)))
171 | 
172 |             # Look in the tokenized story dirs to find the .story file corresponding to this url
173 |             if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)):
174 |                 story_file = os.path.join(cnn_tokenized_stories_dir, s)
175 |             elif os.path.isfile(os.path.join(dm_tokenized_stories_dir, s)):
176 |                 story_file = os.path.join(dm_tokenized_stories_dir, s)
177 |             else:
178 |                 print("Error: Couldn't find tokenized story file %s in either tokenized story directories %s and %s. Was there an error during tokenization?" % (
179 |                     s, cnn_tokenized_stories_dir, dm_tokenized_stories_dir))
180 |                 # Check again if tokenized stories directories contain correct number of files
181 |                 print("Checking that the tokenized stories directories %s and %s contain correct number of files..." % (
182 |                     cnn_tokenized_stories_dir, dm_tokenized_stories_dir))
183 |                 check_num_stories(cnn_tokenized_stories_dir, num_expected_cnn_stories)
184 |                 check_num_stories(dm_tokenized_stories_dir, num_expected_dm_stories)
185 |                 raise Exception(
186 |                     "Tokenized stories directories %s and %s contain correct number of files but story file %s found in neither." % (
187 |                         cnn_tokenized_stories_dir, dm_tokenized_stories_dir, s))
188 | 
189 |             # Get the strings to write to .bin file
190 |             article, abstract = get_art_abs(story_file)
191 | 
192 |             # Write to tf.Example
193 |             tf_example = example_pb2.Example()
194 |             tf_example.features.feature['article'].bytes_list.value.extend([bytes(article, encoding='utf-8')])
195 |             tf_example.features.feature['abstract'].bytes_list.value.extend([bytes(abstract, encoding='utf-8')])
196 |             tf_example_str = tf_example.SerializeToString()
197 |             str_len = len(tf_example_str)
198 |             writer.write(struct.pack('q', str_len))
199 |             writer.write(struct.pack('%ds' % str_len, tf_example_str))
200 | 
201 |             # Write the vocab to file, if applicable
202 |             if makevocab:
203 |                 art_tokens = article.split(' ')
204 |                 abs_tokens = abstract.split(' ')
205 |                 abs_tokens = [t for t in abs_tokens if
206 |                               t not in [SENTENCE_START, SENTENCE_END]]  # remove these tags from vocab
207 |                 tokens = art_tokens + abs_tokens
208 |                 tokens = [t.strip() for t in tokens]  # strip
209 |                 tokens = [t for t in tokens if t != ""]  # remove empty
210 |                 vocab_counter.update(tokens)
211 | 
212 |     print("Finished writing file %s\n" % out_file)
213 | 
214 |     # write vocab to file
215 |     if makevocab:
216 |         print("Writing vocab file...")
217 |         with open(os.path.join(finished_files_dir, "vocab"), 'w') as writer:
218 |             for word, count in vocab_counter.most_common(VOCAB_SIZE):
219 |                 writer.write(word + ' ' + str(count) + '\n')
220 |         print("Finished writing vocab file")
221 | 
222 | 
223 | def check_num_stories(stories_dir, num_expected):
224 |     num_stories = len(os.listdir(stories_dir))
225 |     if num_stories != num_expected:
226 |         raise Exception(
227 |             "stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected))
228 | 
229 | 
230 | if __name__ == '__main__':
231 |     if len(sys.argv) != 3:
232 |         print("USAGE: python make_datafiles.py <cnn_stories_dir> <dailymail_stories_dir>")
233 |         sys.exit()
234 |     cnn_stories_dir = sys.argv[1]
235 |     dm_stories_dir = sys.argv[2]
236 | 
237 |     # Check the stories directories contain the correct number of .story files
238 |     check_num_stories(cnn_stories_dir, num_expected_cnn_stories)
239 |     check_num_stories(dm_stories_dir, num_expected_dm_stories)
240 | 
241 |     # Create some new directories
242 |     if not os.path.exists(cnn_tokenized_stories_dir): os.makedirs(cnn_tokenized_stories_dir)
243 |     if not os.path.exists(dm_tokenized_stories_dir): os.makedirs(dm_tokenized_stories_dir)
244 |     if not os.path.exists(finished_files_dir): os.makedirs(finished_files_dir)
245 | 
246 |     # Run stanford tokenizer on both stories dirs, outputting to tokenized stories directories
247 |     tokenize_stories(cnn_stories_dir, cnn_tokenized_stories_dir)
248 |     tokenize_stories(dm_stories_dir, dm_tokenized_stories_dir)
249 | 
250 |     # Read the tokenized stories, do a little postprocessing then write to bin files
251 |     write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test.bin"))
252 |     write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val.bin"))
253 |     write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train.bin"), makevocab=True)
254 | 
255 |     # Chunk the data. This splits each of train.bin, val.bin and test.bin into smaller chunks, each containing e.g. 1000 examples, and saves them in finished_files/chunks
256 |     chunk_all()
257 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/attention.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/attention.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/basic.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/basic.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/layers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/layers.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/models/__pycache__/model.cpython-35.pyc
--------------------------------------------------------------------------------
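The .bin files written above are sequences of length-prefixed records: an 8-byte native-endian length ('q'), then that many bytes of a serialized tf.Example (see chunk_file and write_to_bin). A minimal sketch of a reader for this format (example_generator is an illustrative helper, not part of the repository):

import glob
import struct
from tensorflow.core.example import example_pb2

def example_generator(pattern):
    # Iterate over every length-prefixed tf.Example in the chunked .bin files.
    for fname in sorted(glob.glob(pattern)):
        with open(fname, 'rb') as reader:
            while True:
                len_bytes = reader.read(8)
                if not len_bytes:
                    break  # end of this chunk file
                str_len = struct.unpack('q', len_bytes)[0]
                example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
                yield example_pb2.Example.FromString(example_str)

for ex in example_generator("dataset/finished_files/chunked/train_*.bin"):
    article = ex.features.feature['article'].bytes_list.value[0].decode('utf-8')
    abstract = ex.features.feature['abstract'].bytes_list.value[0].decode('utf-8')
    break  # just peek at the first example
--------------------------------------------------------------------------------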
/models/attention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from utils import config
8 | from models.basic import BasicModule
9 | 
10 | 
11 | class Attention(BasicModule):
12 |     def __init__(self):
13 |         super(Attention, self).__init__()
14 | 
15 |         self.fc = nn.Linear(config.hidden_dim * 2, 1, bias=False)
16 |         self.dec_fc = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
17 |         if config.is_coverage:
18 |             self.con_fc = nn.Linear(1, config.hidden_dim * 2, bias=False)
19 | 
20 |         self.init_params()
21 | 
22 |     def forward(self, s_t, enc_out, enc_fea, enc_padding_mask, coverage):
23 |         b, l, n = list(enc_out.size())
24 | 
25 |         dec_fea = self.dec_fc(s_t)  # B x 2*hidden_dim
26 |         dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, l, n).contiguous()  # B x l x 2*hidden_dim
27 |         dec_fea_expanded = dec_fea_expanded.view(-1, n)  # B*l x 2*hidden_dim
28 | 
29 |         att_features = enc_fea + dec_fea_expanded  # B*l x 2*hidden_dim
30 |         if config.is_coverage:
31 |             coverage_inp = coverage.view(-1, 1)  # B*l x 1
32 |             coverage_fea = self.con_fc(coverage_inp)  # B*l x 2*hidden_dim
33 |             att_features = att_features + coverage_fea
34 | 
35 |         e = torch.tanh(att_features)  # B*l x 2*hidden_dim
36 |         scores = self.fc(e)  # B*l x 1
37 |         scores = scores.view(-1, l)  # B x l
38 | 
39 |         attn_dist_ = F.softmax(scores, dim=1) * enc_padding_mask  # B x l
40 |         normalization_factor = attn_dist_.sum(1, keepdim=True)
41 |         attn_dist = attn_dist_ / normalization_factor
42 | 
43 |         attn_dist = attn_dist.unsqueeze(1)  # B x 1 x l
44 |         c_t = torch.bmm(attn_dist, enc_out)  # B x 1 x n
45 |         c_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dim
46 | 
47 |         attn_dist = attn_dist.view(-1, l)  # B x l
48 | 
49 |         if config.is_coverage:
50 |             coverage = coverage.view(-1, l)
51 |             coverage = coverage + attn_dist
52 | 
53 |         return c_t, attn_dist, coverage
54 | 
55 | 
--------------------------------------------------------------------------------
/models/basic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | 
7 | class BasicModule(nn.Module):
8 |     def __init__(self, init='uniform'):
9 |         super(BasicModule, self).__init__()
10 |         self.init = init
11 | 
12 |     def init_params(self):
13 |         for param in self.parameters():
14 |             if param.requires_grad and len(param.shape) > 0:
15 |                 stddev = 1 / math.sqrt(param.shape[0])
16 |                 if self.init == 'uniform':
17 |                     torch.nn.init.uniform_(param, a=-0.05, b=0.05)
18 |                 elif self.init == 'normal':
19 |                     torch.nn.init.normal_(param, std=stddev)
20 |                 elif self.init == 'truncated_normal':
21 |                     self.truncated_normal_(param, mean=0, std=stddev)
22 | 
23 |     def truncated_normal_(self, tensor, mean=0, std=1.):
24 |         """
25 |         Implemented by @ruotianluo
26 |         See https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
27 |         """
28 |         size = tensor.shape
29 |         tmp = tensor.new_empty(size + (4,)).normal_()
30 |         valid = (tmp < 2) & (tmp > -2)
31 |         ind = valid.max(-1, keepdim=True)[1]
32 |         tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
33 |         tensor.data.mul_(std).add_(mean)
34 |         return tensor
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pad_packed_sequence
7 | from torch.nn.utils.rnn import pack_padded_sequence
8 | 
9 | from utils import config
10 | from models.basic import BasicModule
11 | from models.attention import Attention
12 | 
13 | 
14 | class Encoder(BasicModule):
15 |     def __init__(self):
16 |         super(Encoder, self).__init__()
17 |         self.src_word_emb = nn.Embedding(config.vocab_size, config.emb_dim)
18 |         self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, batch_first=True, bidirectional=True)
19 |         self.fc = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)
20 | 
21 |         self.init_params()
22 | 
23 |     # seq_lens should be in descending order
24 |     def forward(self, input, seq_lens):
25 |         embedded = self.src_word_emb(input)
26 | 
27 |         packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
28 |         output, hidden = self.lstm(packed)
29 | 
30 |         encoder_outputs, _ = pad_packed_sequence(output, batch_first=True)  # h dim = B x l x n
31 |         encoder_outputs = encoder_outputs.contiguous()
32 | 
33 |         encoder_feature = encoder_outputs.view(-1, 2 * config.hidden_dim)  # B*l x 2*hidden_dim
34 |         encoder_feature = self.fc(encoder_feature)
35 | 
36 |         return encoder_outputs, encoder_feature, hidden
37 | 
38 | 
39 | class ReduceState(BasicModule):
40 |     def __init__(self):
41 |         super(ReduceState, self).__init__()
42 | 
43 |         self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
44 |         self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
45 |         self.init_params()
46 | 
47 | 
48 |     def forward(self, hidden):
49 |         h, c = hidden  # h, c dim = 2 x b x hidden_dim
50 |         h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
51 |         hidden_reduced_h = F.relu(self.reduce_h(h_in))
52 |         c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
53 |         hidden_reduced_c = F.relu(self.reduce_c(c_in))
54 | 
55 |         return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0))  # h, c dim = 1 x b x hidden_dim
56 | 
57 | class Decoder(BasicModule):
58 |     def __init__(self):
59 |         super(Decoder, self).__init__()
60 |         self.attention_network = Attention()
61 |         # decoder
62 |         self.tgt_word_emb = nn.Embedding(config.vocab_size, config.emb_dim)
63 |         self.con_fc = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)
64 |         self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, batch_first=True, bidirectional=False)
65 | 
66 |         if config.pointer_gen:
67 |             self.p_gen_fc = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
68 | 
69 |         # p_vocab
70 |         self.fc1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
71 |         self.fc2 = nn.Linear(config.hidden_dim, config.vocab_size)
72 | 
73 |         self.init_params()
74 | 
75 |     def forward(self, y_t, s_t, enc_out, enc_fea, enc_padding_mask,
76 |                 c_t, extra_zeros, enc_batch_extend_vocab, coverage, step):
77 | 
78 |         if not self.training and step == 0:
79 |             dec_h, dec_c = s_t
80 |             s_t_hat = torch.cat((dec_h.view(-1, config.hidden_dim),
81 |                                  dec_c.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
82 |             c_t, _, coverage_next = self.attention_network(s_t_hat, enc_out, enc_fea,
83 |                                                            enc_padding_mask, coverage)
84 |             coverage = coverage_next
85 | 
86 |         y_t_embd = self.tgt_word_emb(y_t)
87 |         x = self.con_fc(torch.cat((c_t, y_t_embd), 1))
88 |         lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t)
89 | 
90 |         dec_h, dec_c = s_t
91 |         s_t_hat = torch.cat((dec_h.view(-1, config.hidden_dim),
92 |                              dec_c.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
93 |         c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, enc_out, enc_fea,
94 |                                                                enc_padding_mask, coverage)
95 | 
96 |         if self.training or step > 0:
97 |             coverage = coverage_next
98 | 
99 |         p_gen = None
100 |         if config.pointer_gen:
101 |             p_gen_inp = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
102 |             p_gen = self.p_gen_fc(p_gen_inp)
103 |             p_gen = torch.sigmoid(p_gen)
104 | 
105 |         output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1)  # B x hidden_dim * 3
106 |         output = self.fc1(output)  # B x hidden_dim
107 |         # output = F.relu(output)
108 | 
109 |         output = self.fc2(output)  # B x vocab_size
110 |         vocab_dist = F.softmax(output, dim=1)
111 | 
112 |         if config.pointer_gen:
113 |             vocab_dist_ = p_gen * vocab_dist
114 |             attn_dist_ = (1 - p_gen) * attn_dist
115 | 
116 |             if extra_zeros is not None:
117 |                 vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
118 | 
119 |             final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
120 |         else:
121 |             final_dist = vocab_dist
122 | 
123 |         return final_dist, s_t, c_t, attn_dist, p_gen, coverage
124 | 
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | 
5 | from utils import config
6 | from numpy import random
7 | from models.layers import Encoder
8 | from models.layers import Decoder
9 | from models.layers import ReduceState
10 | from transformer.model import TranEncoder
11 | 
12 | use_cuda = config.use_gpu and torch.cuda.is_available()
13 | 
14 | random.seed(123)
15 | torch.manual_seed(123)
16 | if torch.cuda.is_available():
17 |     torch.cuda.manual_seed_all(123)
18 | 
19 | 
20 | class Model(object):
21 |     def __init__(self, model_path=None, is_eval=False, is_tran=False):
22 |         encoder = Encoder()
23 |         decoder = Decoder()
24 |         reduce_state = ReduceState()
25 |         if is_tran:
26 |             encoder = TranEncoder(config.vocab_size, config.max_enc_steps, config.emb_dim,
27 |                                   config.n_layers, config.n_head, config.d_k, config.d_v, config.d_model, config.d_inner)
28 | 
29 |             # share the embedding between encoder and decoder
30 |             decoder.tgt_word_emb.weight = encoder.src_word_emb.weight
31 | 
32 |         if is_eval:
33 |             encoder = encoder.eval()
34 |             decoder = decoder.eval()
35 |             reduce_state = reduce_state.eval()
36 | 
37 |         if use_cuda:
38 |             encoder = encoder.cuda()
39 |             decoder = decoder.cuda()
40 |             reduce_state = reduce_state.cuda()
41 | 
42 |         self.encoder = encoder
43 |         self.decoder = decoder
44 |         self.reduce_state = reduce_state
45 | 
46 |         if model_path is not None:
47 |             state = torch.load(model_path, map_location=lambda storage, location: storage)
48 |             self.encoder.load_state_dict(state['encoder_state_dict'])
49 |             self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
50 |             self.reduce_state.load_state_dict(state['reduce_state_dict'])
51 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.7.4
2 | torch==1.4.0
3 | pyrouge==0.1.3
4 | tensorflow==1.13.1
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from __future__ import unicode_literals, print_function, division
4 | 
5 | import os
6 | import sys
7 | import time
8 | import torch
9 | from torch.autograd import Variable
10 | 
11 | from models.model import Model
12 | from utils import utils
13 | from utils.dataset import Batcher
14 | from utils.dataset import Vocab
15 | from utils import dataset, config
16 | from utils.utils import get_input_from_batch
17 | from utils.utils import write_for_rouge, rouge_eval, rouge_log
18 | 
19 | use_cuda = config.use_gpu and torch.cuda.is_available()
20 | 
21 | 
22 | class Beam(object):
23 |     def __init__(self, tokens, log_probs, state, context, coverage):
24 |         self.tokens = tokens
25 |         self.state = state
26 |         self.context = context
27 |         self.coverage = coverage
28 |         self.log_probs = log_probs
29 | 
30 |     def extend(self, token, log_prob, state, context, coverage):
31 |         return Beam(tokens=self.tokens + [token],
32 |                     log_probs=self.log_probs + [log_prob],
33 |                     state=state,
34 |                     context=context,
35 |                     coverage=coverage)
36 | 
37 |     @property
38 |     def latest_token(self):
39 |         return self.tokens[-1]
40 | 
41 |     @property
42 |     def avg_log_prob(self):
43 |         return sum(self.log_probs) / len(self.tokens)
44 | 
45 | 
46 | class BeamSearch(object):
47 |     def __init__(self, model_file_path):
48 | 
49 |         model_name = os.path.basename(model_file_path)
50 |         self._test_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
51 |         self._rouge_ref_dir = os.path.join(self._test_dir, 'rouge_ref')
52 |         self._rouge_dec_dir = os.path.join(self._test_dir, 'rouge_dec')
53 |         for p in [self._test_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
54 |             if not os.path.exists(p):
55 |                 os.mkdir(p)
56 | 
57 |         self.vocab = Vocab(config.vocab_path, config.vocab_size)
58 |         self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
59 |                                batch_size=config.beam_size, single_pass=True)
60 |         time.sleep(15)
61 | 
62 |         self.model = Model(model_file_path, is_eval=True)
63 | 
64 |     def sort_beams(self, beams):
65 |         return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)
66 | 
67 | 
68 |     def beam_search(self, batch):
69 |         # single example repeated across the batch
70 |         enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
71 |             get_input_from_batch(batch, use_cuda)
72 | 
73 |         enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
74 |         s_t = self.model.reduce_state(enc_h)
75 | 
76 |         dec_h, dec_c = s_t  # b x hidden_dim
77 |         dec_h = dec_h.squeeze()
78 |         dec_c = dec_c.squeeze()
79 | 
80 |         # decoder batch preparation: it has beam_size examples; initially everything is repeated
81 |         beams = [Beam(tokens=[self.vocab.word2id(config.BOS_TOKEN)],
82 |                       log_probs=[0.0],
83 |                       state=(dec_h[0], dec_c[0]),
84 |                       context=c_t[0],
85 |                       coverage=(coverage[0] if config.is_coverage else None))
86 |                  for _ in range(config.beam_size)]
87 | 
88 |         steps = 0
89 |         results = []
90 |         while steps < config.max_dec_steps and len(results) < config.beam_size:
91 |             latest_tokens = [h.latest_token for h in beams]
92 |             latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(config.UNK_TOKEN) \
93 |                              for t in latest_tokens]
94 |             y_t = Variable(torch.LongTensor(latest_tokens))
95 |             if use_cuda:
96 |                 y_t = y_t.cuda()
97 |             all_state_h = [h.state[0] for h in beams]
98 |             all_state_c = [h.state[1] for h in beams]
99 |             all_context = [h.context for h in beams]
100 | 
101 |             s_t = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
102 |             c_t = torch.stack(all_context, 0)
103 | 
104 |             coverage_t = None
105 |             if config.is_coverage:
106 |                 all_coverage = [h.coverage for h in beams]
107 |                 coverage_t = torch.stack(all_coverage, 0)
108 | 
109 |             final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t, s_t,
110 |                                                                                  enc_out, enc_fea,
111 |                                                                                  enc_padding_mask, c_t,
112 |                                                                                  extra_zeros, enc_batch_extend_vocab,
113 |                                                                                  coverage_t, steps)
114 |             log_probs = torch.log(final_dist)
115 |             topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)
116 | 
117 |             dec_h, dec_c = s_t
118 |             dec_h = dec_h.squeeze()
119 |             dec_c = dec_c.squeeze()
120 | 
121 |             all_beams = []
122 |             # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
123 |             num_orig_beams = 1 if steps == 0 else len(beams)
124 |             for i in range(num_orig_beams):
125 |                 h = beams[i]
126 |                 state_i = (dec_h[i], dec_c[i])
127 |                 context_i = c_t[i]
128 |                 coverage_i = (coverage_t[i] if config.is_coverage else None)  # carry the updated coverage vector
129 | 
130 |                 for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
131 |                     new_beam = h.extend(token=topk_ids[i, j].item(),
132 |                                         log_prob=topk_log_probs[i, j].item(),
133 |                                         state=state_i,
134 |                                         context=context_i,
135 |                                         coverage=coverage_i)
136 |                     all_beams.append(new_beam)
137 | 
138 |             beams = []
139 |             for h in self.sort_beams(all_beams):
140 |                 if h.latest_token == self.vocab.word2id(config.EOS_TOKEN):
141 |                     if steps >= config.min_dec_steps:
142 |                         results.append(h)
143 |                 else:
144 |                     beams.append(h)
145 |                 if len(beams) == config.beam_size or len(results) == config.beam_size:
146 |                     break
147 | 
148 |             steps += 1
149 | 
150 |         if len(results) == 0:
151 |             results = beams
152 | 
153 |         beams_sorted = self.sort_beams(results)
154 | 
155 |         return beams_sorted[0]
156 | 
157 |     def run(self):
158 | 
159 |         counter = 0
160 |         start = time.time()
161 |         batch = self.batcher.next_batch()
162 |         while batch is not None:
163 |             # Run beam search to get best Hypothesis
164 |             best_summary = self.beam_search(batch)
165 | 
166 |             # Extract the output ids from the hypothesis and convert back to words
167 |             output_ids = [int(t) for t in best_summary.tokens[1:]]
168 |             decoded_words = utils.outputids2words(output_ids, self.vocab,
169 |                                                   (batch.art_oovs[0] if config.pointer_gen else None))
170 | 
171 |             # Remove the [STOP] token from decoded_words, if necessary
172 |             try:
173 |                 fst_stop_idx = decoded_words.index(dataset.EOS_TOKEN)
174 |                 decoded_words = decoded_words[:fst_stop_idx]
175 |             except ValueError:
176 |                 decoded_words = decoded_words
177 | 
178 |             original_abstract_sents = batch.original_abstracts_sents[0]
179 | 
180 |             write_for_rouge(original_abstract_sents, decoded_words, counter,
181 |                             self._rouge_ref_dir, self._rouge_dec_dir)
182 |             counter += 1
183 |             if counter % 1000 == 0:
184 |                 print('%d examples in %.2f sec' % (counter, time.time() - start))
185 |                 start = time.time()
186 | 
187 |             batch = self.batcher.next_batch()
188 | 
189 |         print("Decoder has finished reading dataset for single_pass.")
190 |         print("Now starting ROUGE eval...")
191 |         results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
192 |         rouge_log(results_dict, self._test_dir)
193 | 
194 | 
195 | if __name__ == '__main__':
196 |     model_filename = sys.argv[1]
197 |     test_processor = BeamSearch(model_filename)
198 |     test_processor.run()
199 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from __future__ import unicode_literals, print_function, division
4 | 
5 | import os
6 | import time
7 | import argparse
8 | import tensorflow as tf
9 | 
10 | import torch
11 | import torch.optim as optim
12 | from torch.nn.utils import clip_grad_norm_
13 | 
14 | from models.model import Model
15 | from utils import config
16 | from utils.dataset import Vocab
17 | from utils.dataset import Batcher
18 | from utils.utils import get_input_from_batch
19 | from utils.utils import get_output_from_batch
20 | from utils.utils import calc_running_avg_loss
21 | 
22 | use_cuda = config.use_gpu and torch.cuda.is_available()
23 | 
24 | 
25 | class Train(object):
26 |     def __init__(self):
27 |         self.vocab = Vocab(config.vocab_path, config.vocab_size)
28 |         self.batcher = Batcher(self.vocab, config.train_data_path,
29 |                                config.batch_size, single_pass=False, mode='train')
30 |         time.sleep(10)
31 | 
32 |         train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
33 |         if not os.path.exists(train_dir):
34 |             os.mkdir(train_dir)
35 | 
36 |         self.model_dir = os.path.join(train_dir, 'models')
37 |         if not os.path.exists(self.model_dir):
38 |             os.mkdir(self.model_dir)
39 | 
40 |         self.summary_writer = tf.summary.FileWriter(train_dir)
41 | 
42 |     def save_model(self, running_avg_loss, iter):
43 |         state = {
44 |             'iter': iter,
45 |             'encoder_state_dict': self.model.encoder.state_dict(),
46 |             'decoder_state_dict': self.model.decoder.state_dict(),
47 |             'reduce_state_dict': self.model.reduce_state.state_dict(),
48 |             'optimizer': self.optimizer.state_dict(),
49 |             'current_loss': running_avg_loss
50 |         }
51 |         model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
52 |         torch.save(state, model_save_path)
53 | 
54 |     def setup_train(self, model_path=None):
55 |         self.model = Model(model_path, is_tran=config.tran)
56 |         initial_lr = config.lr_coverage if config.is_coverage else config.lr
57 | 
58 |         params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
59 |                  list(self.model.reduce_state.parameters())
60 |         total_params = sum([param.nelement() for param in params])
61 |         print('The number of params of model: %.3f million' % (total_params / 1e6))  # million
62 |         self.optimizer = optim.Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
63 | 
64 |         start_iter, start_loss = 0, 0
65 | 
66 |         if model_path is not None:
67 |             state = torch.load(model_path, map_location=lambda storage, location: storage)
68 |             start_iter = state['iter']
69 |             start_loss = state['current_loss']
70 | 
71 |             if not config.is_coverage:
72 |                 self.optimizer.load_state_dict(state['optimizer'])
73 |                 if use_cuda:
74 |                     for state in self.optimizer.state.values():
75 |                         for k, v in state.items():
76 |                             if torch.is_tensor(v):
77 |                                 state[k] = v.cuda()
78 | 
79 |         return start_iter, start_loss
80 | 
81 |     def train_one_batch(self, batch):
82 |         enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
83 |             extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda)
84 |         dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
85 |             get_output_from_batch(batch, use_cuda)
86 | 
87 |         self.optimizer.zero_grad()
88 | 
89 |         if not config.tran:
90 |             enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
91 |         else:
92 |             enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_pos)
93 | 
94 |         s_t = self.model.reduce_state(enc_h)
95 | 
96 |         step_losses, cove_losses = [], []
97 |         for di in range(min(max_dec_len, config.max_dec_steps)):
98 |             y_t = dec_batch[:, di]  # Teacher forcing
99 |             final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = \
100 |                 self.model.decoder(y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t,
101 |                                    extra_zeros, enc_batch_extend_vocab, coverage, di)
102 |             tgt = tgt_batch[:, di]
103 | 
            step_mask = dec_padding_mask[:, di]
104 |             gold_probs = torch.gather(final_dist, 1, tgt.unsqueeze(1)).squeeze()
105 |             step_loss = -torch.log(gold_probs + config.eps)
106 |             if config.is_coverage:
107 |                 step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
108 |                 step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
109 |                 cove_losses.append(step_coverage_loss * step_mask)
110 |                 coverage = next_coverage
111 | 
112 |             step_loss = step_loss * step_mask
113 |             step_losses.append(step_loss)
114 | 
115 |         sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
116 |         batch_avg_loss = sum_losses / dec_lens
117 |         loss = torch.mean(batch_avg_loss)
118 | 
119 |         loss.backward()
120 | 
121 |         clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
122 |         clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
123 |         clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)
124 | 
125 |         self.optimizer.step()
126 | 
127 |         if config.is_coverage:
128 |             cove_losses = torch.sum(torch.stack(cove_losses, 1), 1)
129 |             batch_cove_loss = cove_losses / dec_lens
130 |             batch_cove_loss = torch.mean(batch_cove_loss)
131 |             return loss.item(), batch_cove_loss.item()
132 | 
133 |         return loss.item(), 0.
134 | 
135 |     def run(self, n_iters, model_path=None):
136 |         iter, running_avg_loss = self.setup_train(model_path)
137 |         start = time.time()
138 |         interval = 100
139 | 
140 |         while iter < n_iters:
141 |             batch = self.batcher.next_batch()
142 |             loss, cove_loss = self.train_one_batch(batch)
143 | 
144 |             running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
145 |             iter += 1
146 | 
147 |             if iter % interval == 0:
148 |                 self.summary_writer.flush()
149 |                 print(
150 |                     'step: %d, seconds: %.2f, loss: %f, cover_loss: %f' % (iter, time.time() - start, loss, cove_loss))
151 |                 start = time.time()
152 |             if iter % 5000 == 0:
153 |                 self.save_model(running_avg_loss, iter)
154 | 
155 | 
156 | if __name__ == '__main__':
157 |     parser = argparse.ArgumentParser(description="Train script")
158 |     parser.add_argument("-m",
159 |                         dest="model_path",
160 |                         required=False,
161 |                         default=None,
162 |                         help="Model file for retraining (default: None).")
163 |     args = parser.parse_args()
164 | 
165 |     train_processor = Train()
166 |     train_processor.run(config.max_iterations, args.model_path)
167 | 
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__init__.py
--------------------------------------------------------------------------------
/transformer/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/layers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/layers.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/model.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/optim.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/optim.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/__pycache__/sublayers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/transformer/__pycache__/sublayers.cpython-35.pyc
--------------------------------------------------------------------------------
/transformer/layers.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | 
3 | import torch.nn as nn
4 | from transformer.sublayers import MultiHeadAttention
5 | from transformer.sublayers import PositionwiseFeedForward
6 | 
7 | 
8 | class EncoderLayer(nn.Module):
9 |     ''' Compose with two layers '''
10 | 
11 |     def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
12 |         super(EncoderLayer, self).__init__()
13 |         self.slf_attn = MultiHeadAttention(
14 |             n_head, d_model, d_k, d_v, dropout=dropout)
15 |         self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
16 | 
17 |     def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
18 |         enc_output, enc_slf_attn = self.slf_attn(
19 |             enc_input, enc_input, enc_input, mask=slf_attn_mask)
20 |         enc_output *= non_pad_mask
21 | 
22 |         enc_output = self.pos_ffn(enc_output)
23 |         enc_output *= non_pad_mask
24 | 
25 |         return enc_output, enc_slf_attn
26 | 
27 | 
28 | class DecoderLayer(nn.Module):
29 |     ''' Compose with three layers '''
30 | 
31 |     def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
32 |         super(DecoderLayer, self).__init__()
33 |         self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
34 |         self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
35 |         self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
36 | 
37 |     def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
38 |         dec_output, dec_slf_attn = self.slf_attn(
39 |             dec_input, dec_input, dec_input, mask=slf_attn_mask)
40 |         dec_output *= non_pad_mask
41 | 
42 |         dec_output, dec_enc_attn = self.enc_attn(
43 |             dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
44 |         dec_output *= non_pad_mask
45 | 
46 |         dec_output = self.pos_ffn(dec_output)
47 |         dec_output *= non_pad_mask
48 | 
49 |         return dec_output, dec_slf_attn, dec_enc_attn
--------------------------------------------------------------------------------
/transformer/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | from utils import config
9 | from transformer.layers import EncoderLayer
10 | from transformer.layers import DecoderLayer
11 | 
12 | use_cuda = config.use_gpu and torch.cuda.is_available()
13 | 
14 | def get_non_pad_mask(seq):
15 |     assert seq.dim() == 2
16 |     return seq.ne(config.PAD).type(torch.float).unsqueeze(-1)
17 | 
18 | 
19 | def positional_encoding(max_len, d_model, padding_idx=None):
20 |     ''' Sinusoid position encoding table '''
21 | 
22 |     pe = torch.zeros(max_len, d_model)
23 |     position = torch.arange(0., max_len).unsqueeze(1)
24 |     div_term = torch.exp(torch.arange(0., d_model, 2) *
25 |                          -(math.log(10000.0) / d_model))
26 |     pe[:, 0::2] = torch.sin(position * div_term)
27 |     pe[:, 1::2] = torch.cos(position * div_term)
28 | 
29 |     if padding_idx is not None:
30 |         # zero vector for padding dimension
31 |         pe[padding_idx] = 0.
32 | 
33 |     return torch.FloatTensor(pe)
34 | 
35 | 
36 | def get_attn_key_pad_mask(seq_k, seq_q):
37 |     ''' For masking out the padding part of key sequence. '''
38 | 
39 |     # Expand to fit the shape of key query attention matrix.
40 |     len_q = seq_q.size(1)
41 |     padding_mask = seq_k.eq(config.PAD)
42 |     padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk
43 | 
44 |     return padding_mask
45 | 
46 | 
47 | def get_subsequent_mask(seq):
48 |     ''' For masking out the subsequent info. '''
49 | 
50 |     sz_b, len_s = seq.size()
51 |     subsequent_mask = torch.triu(
52 |         torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
53 |     subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls
54 | 
55 |     return subsequent_mask
56 | 
57 | 
58 | class TranEncoder(nn.Module):
59 |     ''' An encoder model with an attention mechanism. '''
60 | 
61 |     def __init__(self, n_src_vocab, len_max_seq, d_word_vec,
62 |                  n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
63 |         super().__init__()
64 | 
65 |         n_position = len_max_seq + 1
66 | 
67 |         self.d_model = d_model
68 | 
69 |         self.src_word_emb = nn.Embedding(
70 |             n_src_vocab, d_word_vec, padding_idx=config.PAD)
71 | 
72 |         self.position_enc = nn.Embedding.from_pretrained(
73 |             positional_encoding(n_position, d_word_vec, padding_idx=1),
74 |             freeze=True)
75 | 
76 |         self.layer_stack = nn.ModuleList([
77 |             EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
78 |             for _ in range(n_layers)])
79 | 
80 |         self.fc = nn.Linear(d_model, 2 * config.hidden_dim, bias=False)
81 | 
82 |     def forward(self, src_seq, src_pos, return_attns=False):
83 | 
84 |         enc_slf_attn_list = []
85 | 
86 |         # -- Prepare masks
87 |         slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
88 |         non_pad_mask = get_non_pad_mask(src_seq)
89 | 
90 |         # -- Forward
91 |         enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
92 | 
93 |         for enc_layer in self.layer_stack:
94 |             enc_output, enc_slf_attn = enc_layer(
95 |                 enc_output,
96 |                 non_pad_mask=non_pad_mask,
97 |                 slf_attn_mask=slf_attn_mask)
98 |             if return_attns:
99 |                 enc_slf_attn_list += [enc_slf_attn]
100 | 
101 |         if return_attns:
102 |             return enc_output, enc_slf_attn_list
103 | 
104 |         enc_output = self.fc(enc_output)
105 |         enc_feature = enc_output.view(-1, 2 * config.hidden_dim)  # B*l x 2*hidden_dim
106 |         b, l, n = list(enc_output.size())
107 |         enc_h = enc_output[:, -1, :].view(b, 2, config.hidden_dim).transpose(0, 1).contiguous()  # 2 x b x hidden_dim, keeping each example's two halves together
108 | 
109 |         return enc_output, enc_feature, (enc_h, enc_h)
110 | 
111 | 
112 | class TranDecoder(nn.Module):
113 |     ''' A decoder model with a self-attention mechanism. '''
114 | 
115 |     def __init__(self, n_tgt_vocab, len_max_seq, d_word_vec,
116 |                  n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
117 | 
118 |         super().__init__()
119 |         n_position = len_max_seq + 1
120 | 
121 |         self.tgt_word_emb = nn.Embedding(
122 |             n_tgt_vocab, d_word_vec, padding_idx=config.PAD)
123 | 
124 |         self.position_enc = nn.Embedding.from_pretrained(
125 |             positional_encoding(n_position, d_word_vec, padding_idx=config.PAD),
126 |             freeze=True)
127 | 
128 |         self.layer_stack = nn.ModuleList([
129 |             DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
130 |             for _ in range(n_layers)])
131 | 
132 |         self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
133 |         nn.init.xavier_normal_(self.tgt_word_prj.weight)
134 | 
135 |     def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):
136 | 
137 |         dec_slf_attn_list, dec_enc_attn_list = [], []
138 | 
139 |         # -- Prepare masks
140 |         non_pad_mask = get_non_pad_mask(tgt_seq)
141 | 
142 |         slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
143 |         slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
144 |         slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
145 | 
146 |         dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)
147 | 
148 |         # -- Forward
149 |         dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)
150 | 
151 |         for dec_layer in self.layer_stack:
152 |             dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
153 |                 dec_output, enc_output,
154 |                 non_pad_mask=non_pad_mask,
155 |                 slf_attn_mask=slf_attn_mask,
156 |                 dec_enc_attn_mask=dec_enc_attn_mask)
157 | 
158 |             if return_attns:
159 |                 dec_slf_attn_list += [dec_slf_attn]
160 |                 dec_enc_attn_list += [dec_enc_attn]
161 | 
162 |         if return_attns:
163 |             return dec_output, dec_slf_attn_list, dec_enc_attn_list
164 | 
165 |         # Return the d_model-sized features; Transformer.forward applies the
166 |         # (optionally weight-tied) tgt_word_prj projection and x_logit_scale itself.
167 |         return dec_output,
168 | 
169 | 
170 | class Transformer(nn.Module):
171 |     ''' A sequence to sequence model with attention mechanism. '''
172 | 
173 |     def __init__(
174 |             self,
175 |             n_src_vocab, n_tgt_vocab, len_max_seq,
176 |             d_word_vec=512, d_model=512, d_inner=2048,
177 |             n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
178 |             tgt_emb_prj_weight_sharing=True,
179 |             emb_src_tgt_weight_sharing=True):
180 | 
181 |         super().__init__()
182 | 
183 |         self.encoder = TranEncoder(
184 |             n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
185 |             d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
186 |             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
187 |             dropout=dropout)
188 | 
189 |         self.decoder = TranDecoder(
190 |             n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
191 |             d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
192 |             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
193 |             dropout=dropout)
194 | 
195 |         self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
196 |         nn.init.xavier_normal_(self.tgt_word_prj.weight)
197 | 
198 |         assert d_model == d_word_vec, \
199 |             'To facilitate the residual connections, \
200 |              the dimensions of all module outputs shall be the same.'
201 | 
202 |         if tgt_emb_prj_weight_sharing:
203 |             # Share the weight matrix between target word embedding & the final logit dense layer
204 |             self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
205 |             self.x_logit_scale = (d_model ** -0.5)
206 |         else:
207 |             self.x_logit_scale = 1.
208 | 
209 |         if emb_src_tgt_weight_sharing:
210 |             # Share the weight matrix between source & target word embeddings
211 |             assert n_src_vocab == n_tgt_vocab, \
212 |                 "To share word embedding table, the vocabulary size of src/tgt shall be the same."
213 |             self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
214 | 
215 |     def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
216 | 
217 |         tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]
218 | 
219 |         enc_output, *_ = self.encoder(src_seq, src_pos)
220 |         dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
221 |         seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
222 | 
223 |         return seq_logit.view(-1, seq_logit.size(2))
--------------------------------------------------------------------------------
/transformer/optim.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | 
3 | import numpy as np
4 | 
5 | class ScheduledOptim():
6 |     '''A simple wrapper class for learning rate scheduling'''
7 | 
8 |     def __init__(self, optimizer, d_model, n_warmup_steps):
9 |         self._optimizer = optimizer
10 |         self.n_warmup_steps = n_warmup_steps
11 |         self.n_current_steps = 0
12 |         self.init_lr = np.power(d_model, -0.5)
13 | 
14 |     def step_and_update_lr(self):
15 |         "Step with the inner optimizer"
16 |         self._update_learning_rate()
17 |         self._optimizer.step()
18 | 
19 |     def zero_grad(self):
20 |         "Zero out the gradients by the inner optimizer"
21 |         self._optimizer.zero_grad()
22 | 
23 |     def _get_lr_scale(self):
24 |         return np.min([
25 |             np.power(self.n_current_steps, -0.5),
26 |             np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
27 | 
28 |     def _update_learning_rate(self):
29 |         ''' Learning rate scheduling per step '''
30 | 
31 |         self.n_current_steps += 1
32 |         lr = self.init_lr * self._get_lr_scale()
33 | 
34 |         for param_group in self._optimizer.param_groups:
35 |             param_group['lr'] = lr
--------------------------------------------------------------------------------
/transformer/sublayers.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | 
3 | import math
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class PositionalEncoding(nn.Module):
12 |     '''Positional Encoding module'''
13 | 
14 |     def __init__(self, d_model, max_len=5000):
15 |         super().__init__()
16 | 
17 |         # Compute the positional encodings once in log space.
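# i.e. PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).
# Below, div_term = exp(-(2i/d_model) * ln(10000)) = 10000^(-2i/d_model): the frequency term is
# computed via exp/log rather than a direct power for numerical stability.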
18 |         pe = torch.zeros(max_len, d_model)
19 |         position = torch.arange(0., max_len).unsqueeze(1)
20 |         div_term = torch.exp(torch.arange(0., d_model, 2) *
21 |                              -(math.log(10000.0) / d_model))
22 |         pe[:, 0::2] = torch.sin(position * div_term)
23 |         pe[:, 1::2] = torch.cos(position * div_term)
24 |         # Register the table as a buffer (not a plain attribute) so it moves with .to(device)/.cuda()
25 |         self.register_buffer('pe', pe.unsqueeze(0))
26 | 
27 |     def forward(self, x):
28 |         return self.pe[:, :x.size(1)]
29 | 
30 | class ScaledDotProductAttention(nn.Module):
31 |     ''' Scaled Dot-Product Attention module '''
32 | 
33 |     def __init__(self, temperature, attn_dropout=0.1):
34 |         super().__init__()
35 |         self.temperature = temperature
36 |         self.dropout = nn.Dropout(attn_dropout)
37 |         self.softmax = nn.Softmax(dim=2)
38 | 
39 |     def forward(self, q, k, v, mask=None):
40 | 
41 |         attn = torch.bmm(q, k.transpose(1, 2))
42 |         attn = attn / self.temperature
43 | 
44 |         if mask is not None:
45 |             attn = attn.masked_fill(mask, -np.inf)
46 | 
47 |         attn = self.softmax(attn)
48 |         attn = self.dropout(attn)
49 |         output = torch.bmm(attn, v)
50 | 
51 |         return output, attn
52 | 
53 | class MultiHeadAttention(nn.Module):
54 |     ''' Multi-Head Attention module '''
55 | 
56 |     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
57 |         super().__init__()
58 | 
59 |         self.d_k = d_k
60 |         self.d_v = d_v
61 |         self.n_head = n_head
62 | 
63 |         self.w_qs = nn.Linear(d_model, n_head * d_k)
64 |         self.w_ks = nn.Linear(d_model, n_head * d_k)
65 |         self.w_vs = nn.Linear(d_model, n_head * d_v)
66 |         nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
67 |         nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
68 |         nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
69 | 
70 |         self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
71 |         self.layer_norm = nn.LayerNorm(d_model)
72 | 
73 |         self.fc = nn.Linear(n_head * d_v, d_model)
74 |         nn.init.xavier_normal_(self.fc.weight)
75 | 
76 |         self.dropout = nn.Dropout(dropout)
77 | 
78 | 
79 |     def forward(self, q, k, v, mask=None):
80 | 
81 |         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
82 | 
83 |         sz_b, len_q, _ = q.size()
84 |         sz_b, len_k, _ = k.size()
85 |         sz_b, len_v, _ = v.size()
86 | 
87 |         residual = q
88 | 
89 |         q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
90 |         k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
91 |         v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
92 | 
93 |         q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
94 |         k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
95 |         v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
96 | 
97 |         if mask is not None: mask = mask.repeat(n_head, 1, 1) # (n*b) x lq x lk; the mask is optional
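        # Heads are folded into the batch dimension above, so the single bmm-based
        # attention call below scores all n_head heads at once.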
98 | output, attn = self.attention(q, k, v, mask=mask) 99 | 100 | output = output.view(n_head, sz_b, len_q, d_v) 101 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 102 | 103 | output = self.dropout(self.fc(output)) 104 | output = self.layer_norm(output + residual) 105 | 106 | return output, attn 107 | 108 | class PositionwiseFeedForward(nn.Module): 109 | ''' A two-feed-forward-layer module ''' 110 | 111 | def __init__(self, d_in, d_hid, dropout=0.1): 112 | super().__init__() 113 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 114 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 115 | self.layer_norm = nn.LayerNorm(d_in) 116 | self.dropout = nn.Dropout(dropout) 117 | 118 | def forward(self, x): 119 | residual = x 120 | output = x.transpose(1, 2) 121 | output = self.w_2(F.relu(self.w_1(output))) 122 | output = output.transpose(1, 2) 123 | output = self.dropout(output) 124 | output = self.layer_norm(output + residual) 125 | return output -------------------------------------------------------------------------------- /transformer/tran_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import unicode_literals, print_function, division 4 | 5 | import os 6 | import time 7 | import argparse 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | import torch 12 | import torch.optim as optim 13 | 14 | from transformer.model import Model 15 | from transformer.optim import ScheduledOptim 16 | from utils import config 17 | from utils.dataset import Vocab 18 | from utils.dataset import Batcher 19 | from utils.utils import get_input_from_batch 20 | from utils.utils import get_output_from_batch 21 | from utils.utils import calc_running_avg_loss 22 | 23 | use_cuda = config.use_gpu and torch.cuda.is_available() 24 | 25 | class Train(object): 26 | def __init__(self): 27 | self.vocab = Vocab(config.vocab_path, config.vocab_size) 28 | self.batcher = Batcher(self.vocab, config.train_data_path, 29 | config.batch_size, single_pass=False, mode='train') 30 | time.sleep(10) 31 | 32 | train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time()))) 33 | if not os.path.exists(train_dir): 34 | os.mkdir(train_dir) 35 | 36 | self.model_dir = os.path.join(train_dir, 'models') 37 | if not os.path.exists(self.model_dir): 38 | os.mkdir(self.model_dir) 39 | 40 | self.summary_writer = tf.summary.FileWriter(train_dir) 41 | 42 | def save_model(self, running_avg_loss, iter): 43 | model_state_dict = self.model.state_dict() 44 | 45 | state = { 46 | 'iter': iter, 47 | 'current_loss': running_avg_loss, 48 | 'optimizer': self.optimizer._optimizer.state_dict(), 49 | "model": model_state_dict 50 | } 51 | model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time()))) 52 | torch.save(state, model_save_path) 53 | 54 | def setup_train(self, model_path): 55 | 56 | device = torch.device('cuda' if use_cuda else 'cpu') 57 | 58 | self.model = Model( 59 | config.vocab_size, 60 | config.vocab_size, 61 | config.max_enc_steps, 62 | config.max_dec_steps, 63 | d_k=config.d_k, 64 | d_v=config.d_v, 65 | d_model=config.d_model, 66 | d_word_vec=config.emb_dim, 67 | d_inner=config.d_inner_hid, 68 | n_layers=config.n_layers, 69 | n_head=config.n_head, 70 | dropout=config.dropout).to(device) 71 | 72 | self.optimizer = ScheduledOptim( 73 | optim.Adam( 74 | filter(lambda x: x.requires_grad, self.model.parameters()), 75 | betas=(0.9, 0.98), eps=1e-09), 76 | config.d_model, 
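            # these two arguments drive the warmup schedule in transformer/optim.py:
            # lr = d_model**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5)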
            config.n_warmup_steps)
77 | 
78 | 
79 |         params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters())
80 |         total_params = sum([param.nelement() for param in params])
81 |         print('Number of model parameters: %.3f million' % (total_params / 1e6))
82 | 
83 |         start_iter, start_loss = 0, 0
84 | 
85 |         if model_path is not None:
86 |             state = torch.load(model_path, map_location=lambda storage, location: storage)
87 |             start_iter = state['iter']
88 |             start_loss = state['current_loss']
89 | 
90 |             if not config.is_coverage:
91 |                 self.optimizer._optimizer.load_state_dict(state['optimizer'])
92 |                 if use_cuda:
93 |                     for opt_state in self.optimizer._optimizer.state.values():
94 |                         for k, v in opt_state.items():
95 |                             if torch.is_tensor(v):
96 |                                 opt_state[k] = v.cuda()
97 | 
98 |         return start_iter, start_loss
99 | 
100 |     def train_one_batch(self, batch):
101 |         enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
102 |             extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda)
103 |         dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
104 |             get_output_from_batch(batch, use_cuda)
105 | 
106 |         self.optimizer.zero_grad()
107 | 
108 |         pred = self.model(enc_batch, enc_pos, dec_batch, dec_pos)
109 |         gold_probs = torch.gather(pred, -1, tgt_batch.unsqueeze(-1)).squeeze()
110 |         batch_loss = -torch.log(gold_probs + config.eps)
111 |         batch_loss = batch_loss * dec_padding_mask
112 | 
113 |         sum_losses = torch.sum(batch_loss, 1)
114 |         batch_avg_loss = sum_losses / dec_lens
115 |         loss = torch.mean(batch_avg_loss)
116 | 
117 |         loss.backward()
118 | 
119 |         # update parameters
120 |         self.optimizer.step_and_update_lr()
121 | 
122 |         return loss.item(), 0.
123 | 
124 |     def run(self, n_iters, model_path=None):
125 |         iter, running_avg_loss = self.setup_train(model_path)
126 |         start = time.time()
127 |         interval = 100
128 | 
129 |         while iter < n_iters:
130 |             batch = self.batcher.next_batch()
131 |             loss, cove_loss = self.train_one_batch(batch)
132 | 
133 |             running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
134 |             iter += 1
135 | 
136 |             if iter % interval == 0:
137 |                 self.summary_writer.flush()
138 |                 print(
139 |                     'step: %d, second: %.2f , loss: %f, cover_loss: %f' % (iter, time.time() - start, loss, cove_loss))
140 |                 start = time.time()
141 |             if iter % 5000 == 0:
142 |                 self.save_model(running_avg_loss, iter)
143 | 
144 | if __name__ == '__main__':
145 |     parser = argparse.ArgumentParser(description="Train script")
146 |     parser.add_argument("-m",
147 |                         dest="model_path",
148 |                         required=False,
149 |                         default=None,
150 |                         help="Model file for retraining (default: None).")
151 |     args = parser.parse_args()
152 | 
153 |     train_processor = Train()
154 |     train_processor.run(config.max_iterations, args.model_path)
155 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/utils/__pycache__/config.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/utils/__pycache__/dataset.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laihuiyuan/pointer-generator/6a727f4a2f314c2b47df9ce8838dca0de61bfcd4/utils/__pycache__/utils.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import os
4 | 
5 | SENTENCE_STA = '<s>'
6 | SENTENCE_END = '</s>'
7 | 
8 | UNK = 0
9 | PAD = 1
10 | BOS = 2
11 | EOS = 3
12 | 
13 | PAD_TOKEN = '[PAD]'
14 | UNK_TOKEN = '[UNK]'
15 | BOS_TOKEN = '[BOS]'
16 | EOS_TOKEN = '[EOS]'
17 | 
18 | beam_size = 4
19 | emb_dim = 128
20 | batch_size = 16
21 | hidden_dim = 256
22 | max_enc_steps = 400
23 | max_dec_steps = 100
24 | max_tes_steps = 100
25 | min_dec_steps = 35
26 | vocab_size = 50000
27 | 
28 | lr = 0.15
29 | cov_loss_wt = 1.0
30 | pointer_gen = True
31 | is_coverage = False
32 | 
33 | max_grad_norm = 2.0
34 | adagrad_init_acc = 0.1
35 | rand_unif_init_mag = 0.02
36 | trunc_norm_init_std = 1e-4
37 | 
38 | eps = 1e-12
39 | use_gpu = True
40 | lr_coverage = 0.15
41 | max_iterations = 500000
42 | 
43 | # transformer
44 | d_k = 64
45 | d_v = 64
46 | n_head = 6
47 | tran = True
48 | dropout = 0.1
49 | n_layers = 6
50 | d_model = 128
51 | d_inner_hid = 512 # inner dim of the position-wise feed-forward layer (read as config.d_inner_hid in tran_train.py)
52 | n_warmup_steps = 4000
53 | 
54 | root_dir = os.path.expanduser("./")
55 | log_root = os.path.join(root_dir, "dataset/log/")
56 | 
57 | #train_data_path = os.path.join(root_dir, "pointer_generator/dataset/finished_files/train.bin")
58 | train_data_path = os.path.join(root_dir, "dataset/finished_files/chunked/train_*")
59 | eval_data_path = os.path.join(root_dir, "dataset/finished_files/val.bin")
60 | decode_data_path = os.path.join(root_dir, "dataset/finished_files/test.bin")
61 | vocab_path = os.path.join(root_dir, "dataset/finished_files/vocab")
62 | 
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import csv
4 | import glob
5 | import time
6 | import queue
7 | import struct
8 | import numpy as np
9 | import tensorflow as tf
10 | from random import shuffle
11 | from threading import Thread
12 | from tensorflow.core.example import example_pb2
13 | 
14 | from utils import utils
15 | from utils import config
16 | 
17 | import random
18 | random.seed(1234)
19 | 
20 | 
21 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
22 | SENTENCE_STA = '<s>'
23 | SENTENCE_END = '</s>'
24 | 
25 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
26 | UNK_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
27 | BOS_TOKEN = '[BOS]' # This has a vocab id, which is used at the start of every decoder input sequence
28 | EOS_TOKEN = '[EOS]' # This has a vocab id, which is used at the end of untruncated target sequences
29 | # Note: none of <s>, </s>, [PAD], [UNK], [BOS], [EOS] should appear in the vocab file.
30 | 
31 | 
32 | class Vocab(object):
33 | 
34 |     def __init__(self, file, max_size):
35 |         self.word2idx = {}
36 |         self.idx2word = {}
37 |         self.count = 0 # keeps track of the total number of words in the Vocab
38 | 
39 |         # [UNK], [PAD], [BOS] and [EOS] get the ids 0, 1, 2, 3.
40 |         for w in [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]:
41 |             self.word2idx[w] = self.count
42 |             self.idx2word[self.count] = w
43 |             self.count += 1
44 | 
45 |         # Read the vocab file and add words up to max_size
46 |         with open(file, 'r') as fin:
47 |             for line in fin:
48 |                 items = line.split()
49 |                 if len(items) != 2:
50 |                     print('Warning: incorrectly formatted line in vocabulary file: %s' % line.strip())
51 |                     continue
52 |                 w = items[0]
53 |                 if w in [SENTENCE_STA, SENTENCE_END, UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]:
54 |                     raise Exception(
55 |                         '<s>, </s>, [UNK], [PAD], [BOS] and [EOS] shouldn\'t be in the vocab file, but %s is' % w)
56 |                 if w in self.word2idx:
57 |                     raise Exception('Duplicated word in vocabulary file: %s' % w)
58 |                 self.word2idx[w] = self.count
59 |                 self.idx2word[self.count] = w
60 |                 self.count += 1
61 |                 if max_size != 0 and self.count >= max_size:
62 |                     break
63 |         print("Finished constructing vocabulary of %i total words. Last word added: %s" % (
64 |             self.count, self.idx2word[self.count - 1]))
65 | 
66 |     def word2id(self, word):
67 |         if word not in self.word2idx:
68 |             return self.word2idx[UNK_TOKEN]
69 |         return self.word2idx[word]
70 | 
71 |     def id2word(self, word_id):
72 |         if word_id not in self.idx2word:
73 |             raise ValueError('Id not found in vocab: %d' % word_id)
74 |         return self.idx2word[word_id]
75 | 
76 |     def size(self):
77 |         return self.count
78 | 
79 |     def write_metadata(self, path):
80 |         print("Writing word embedding metadata file to %s..." % path)
81 |         with open(path, "w") as f:
82 |             fieldnames = ['word']
83 |             writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
84 |             for i in range(self.size()):
85 |                 writer.writerow({"word": self.idx2word[i]})
86 | 
87 | class Example(object):
88 | 
89 |     def __init__(self, article, abstract_sentences, vocab):
90 |         # Get ids of special tokens
91 |         bos_decoding = vocab.word2id(BOS_TOKEN)
92 |         eos_decoding = vocab.word2id(EOS_TOKEN)
93 | 
94 |         # Process the article
95 |         article_words = article.decode().split()
96 |         if len(article_words) > config.max_enc_steps:
97 |             article_words = article_words[:config.max_enc_steps]
98 |         self.enc_len = len(article_words) # store the length after truncation but before padding
99 |         self.enc_inp = [vocab.word2id(w) for w in
100 |                         article_words] # list of word ids; OOVs are represented by the id for the UNK token
101 | 
102 |         # Process the abstract
103 |         abstract = ' '.encode().join(abstract_sentences).decode()
104 |         abstract_words = abstract.split() # list of strings
105 |         abs_ids = [vocab.word2id(w) for w in
106 |                    abstract_words] # list of word ids; OOVs are represented by the id for the UNK token
107 | 
108 |         # Get the decoder input sequence and target sequence
109 |         self.dec_inp, self.tgt = self.get_dec_seq(abs_ids, config.max_dec_steps, bos_decoding, eos_decoding)
110 |         self.dec_len = len(self.dec_inp)
111 | 
112 |         # If using pointer-generator mode, we need to store some extra info
113 |         if config.pointer_gen:
114 |             # Store a version of the enc_inp where in-article OOVs are represented by their temporary OOV id;
115 |             # also store the in-article OOV words themselves
116 |             self.enc_inp_extend_vocab, self.article_oovs = utils.article2ids(article_words, vocab)
117 | 
118 |             # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
119 |             abs_ids_extend_vocab = utils.abstract2ids(abstract_words, vocab, self.article_oovs)
120 | 
121 |             # Overwrite decoder target sequence so it uses the temp article OOV ids
122 |             _, self.tgt = self.get_dec_seq(abs_ids_extend_vocab, config.max_dec_steps, bos_decoding, eos_decoding)
123 | 
124 |         # Store the original strings
125 |         self.original_article = article
126 |         self.original_abstract = abstract
127 |         self.original_abstract_sents = abstract_sentences
128 | 
129 |     def get_dec_seq(self, sequence, max_len, start_id, stop_id):
130 |         src = [start_id] + sequence[:]
131 |         tgt = sequence[:]
132 |         if len(src) > max_len: # truncate
133 |             src = src[:max_len]
134 |             tgt = tgt[:max_len] # no end token
135 |         else: # no truncation
136 |             tgt.append(stop_id) # end token
137 |         assert len(src) == len(tgt)
138 |         return src, tgt
139 | 
140 |     def pad_enc_seq(self, max_len, pad_id):
141 |         while len(self.enc_inp) < max_len:
142 |             self.enc_inp.append(pad_id)
143 |         if config.pointer_gen:
144 |             while len(self.enc_inp_extend_vocab) < max_len:
145 |                 self.enc_inp_extend_vocab.append(pad_id)
146 | 
147 |     def pad_dec_seq(self, max_len, pad_id):
148 |         while len(self.dec_inp) < max_len:
149 |             self.dec_inp.append(pad_id)
150 |         while len(self.tgt) < max_len:
151 |             self.tgt.append(pad_id)
152 | 
153 | 
154 | class Batch(object):
155 |     def __init__(self, example_list, vocab, batch_size):
156 |         self.batch_size = batch_size
157 |         self.pad_id = vocab.word2id(PAD_TOKEN) # id of the PAD token used to pad sequences
158 |         self.init_encoder_seq(example_list) # initialize the input to the encoder
159 |         self.init_decoder_seq(example_list) # initialize the input and targets for the decoder
160 |         self.store_orig_strings(example_list) # store the original strings
161 | 
162 |     def init_encoder_seq(self, example_list):
163 |         # Determine the maximum length of the encoder input sequence in this batch
164 |         max_enc_seq_len = max([ex.enc_len for ex in example_list])
165 | 
166 |         # Pad the encoder input sequences up to the length of the longest sequence
167 |         for ex in example_list:
168 |             ex.pad_enc_seq(max_enc_seq_len, self.pad_id)
169 | 
170 |         # Initialize the numpy arrays
171 |         # Note: enc_batch can have a different length (second dimension) for each batch, because we pad to the longest sequence in the batch rather than to a fixed length.
172 |         self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
173 |         self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
174 |         self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)
175 | 
176 |         # Fill in the numpy arrays
177 |         for i, ex in enumerate(example_list):
178 |             self.enc_batch[i, :] = ex.enc_inp[:]
179 |             self.enc_lens[i] = ex.enc_len
180 |             for j in range(ex.enc_len):
181 |                 self.enc_padding_mask[i][j] = 1
182 | 
183 |         # For pointer-generator mode, we need to store some extra info
184 |         if config.pointer_gen:
185 |             # Determine the max number of in-article OOVs in this batch
186 |             self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
187 |             # Store the in-article OOVs themselves
188 |             self.art_oovs = [ex.article_oovs for ex in example_list]
189 |             # Store the version of the enc_batch that uses the article OOV ids
190 |             self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
191 |             for i, ex in enumerate(example_list):
192 |                 self.enc_batch_extend_vocab[i, :] = ex.enc_inp_extend_vocab[:]
193 | 
194 |     def init_decoder_seq(self, example_list):
195 |         # Pad the inputs and targets
196 |         for ex in example_list:
197 |             ex.pad_dec_seq(config.max_dec_steps, self.pad_id)
198 | 
199 |         # Initialize the numpy arrays.
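        # dec_batch rows hold [BOS] + summary (the teacher-forcing inputs) and tgt_batch rows
        # hold summary + [EOS] (the shifted targets) -- see Example.get_dec_seq above.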
200 |         self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
201 |         self.tgt_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
202 |         self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
203 |         self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)
204 | 
205 |         # Fill in the numpy arrays
206 |         for i, ex in enumerate(example_list):
207 |             self.dec_batch[i, :] = ex.dec_inp[:]
208 |             self.tgt_batch[i, :] = ex.tgt[:]
209 |             self.dec_lens[i] = ex.dec_len
210 |             for j in range(ex.dec_len):
211 |                 self.dec_padding_mask[i][j] = 1
212 | 
213 |     def store_orig_strings(self, example_list):
214 |         self.original_articles = [ex.original_article for ex in example_list] # one original article per example
215 |         self.original_abstracts = [ex.original_abstract for ex in example_list] # one original abstract per example
216 |         self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of sentence lists, one per example
217 | 
218 | 
219 | class Batcher(object):
220 |     BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
221 | 
222 |     def __init__(self, vocab, data_path, batch_size, single_pass, mode):
223 |         self._vocab = vocab
224 |         self._data_path = data_path
225 |         self.batch_size = batch_size
226 |         self.single_pass = single_pass
227 |         self.mode = mode
228 | 
229 |         # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
230 |         self._batch_queue = queue.Queue(self.BATCH_QUEUE_MAX)
231 |         self._example_queue = queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)
232 | 
233 |         # Different settings depending on whether we're in single_pass mode or not
234 |         if single_pass:
235 |             self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
236 |             self._num_batch_q_threads = 1 # just one thread to batch examples
237 |             self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing
238 |             self._finished_reading = False # this will tell us when we're finished reading the dataset
239 |         else:
240 |             self._num_example_q_threads = 1 # num threads to fill example queue
241 |             self._num_batch_q_threads = 1 # num threads to fill batch queue
242 |             self._bucketing_cache_size = 1 # how many batches' worth of examples to load into cache before bucketing
243 | 
244 |         # Start the threads that load the queues
245 |         self._example_q_threads = []
246 |         for _ in range(self._num_example_q_threads):
247 |             self._example_q_threads.append(Thread(target=self.fill_example_queue))
248 |             self._example_q_threads[-1].daemon = True
249 |             self._example_q_threads[-1].start()
250 |         self._batch_q_threads = []
251 |         for _ in range(self._num_batch_q_threads):
252 |             self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
253 |             self._batch_q_threads[-1].daemon = True
254 |             self._batch_q_threads[-1].start()
255 | 
256 |         # Start a thread that watches the other threads and restarts them if they're dead
257 |         if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
258 |             self._watch_thread = Thread(target=self.watch_threads)
259 |             self._watch_thread.daemon = True
260 |             self._watch_thread.start()
261 | 
262 |     def next_batch(self):
263 |         # If the batch queue is empty, print a warning
264 |         if self._batch_queue.qsize() == 0:
265 |             tf.logging.warning(
266 |                 'Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i',
267 |                 self._batch_queue.qsize(), self._example_queue.qsize())
268 |             if self.single_pass and self._finished_reading:
269 |                 tf.logging.info("Finished reading dataset in single_pass mode.")
270 |                 return None
271 | 
272 |         batch = self._batch_queue.get() # get the next Batch
273 |         return batch
274 | 
275 |     def fill_example_queue(self):
276 |         example_generator = self.example_generator(self._data_path, self.single_pass)
277 |         input_gen = self.pair_generator(example_generator)
278 | 
279 |         while True:
280 |             try:
281 |                 (article,
282 |                  abstract) = input_gen.__next__() # read the next example from the file; article and abstract are raw byte strings
283 |             except StopIteration: # if there are no more examples:
284 |                 tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
285 |                 if self.single_pass:
286 |                     tf.logging.info(
287 |                         "single_pass mode is on, so we've finished reading the dataset. This thread is stopping.")
288 |                     self._finished_reading = True
289 |                     break
290 |                 else:
291 |                     raise Exception("single_pass mode is off but the example generator ran out of data; error.")
292 | 
293 |             abstract_sentences = [sent.strip() for sent in utils.abstract2sents(
294 |                 abstract)] # Use the <s> and </s> tags in the abstract to get a list of sentences.
295 |             example = Example(article, abstract_sentences, self._vocab)
296 |             self._example_queue.put(example)
297 | 
298 |     def fill_batch_queue(self):
299 |         while True:
300 |             if self.mode == 'decode':
301 |                 # beam search decode mode: a single example repeated across the batch
302 |                 ex = self._example_queue.get()
303 |                 b = [ex for _ in range(self.batch_size)]
304 |                 self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
305 |             else:
306 |                 # Get bucketing_cache_size-many batches of Examples into a list, then sort
307 |                 inputs = []
308 |                 for _ in range(self.batch_size * self._bucketing_cache_size):
309 |                     inputs.append(self._example_queue.get())
310 |                 inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True) # sort by length of encoder sequence
311 | 
312 |                 # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
313 |                 batches = []
314 |                 for i in range(0, len(inputs), self.batch_size):
315 |                     batches.append(inputs[i:i + self.batch_size])
316 |                 if not self.single_pass:
317 |                     shuffle(batches)
318 |                 for b in batches: # each b is a list of Example objects
319 |                     self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
320 | 
321 |     def watch_threads(self):
322 |         while True:
323 |             tf.logging.info(
324 |                 'Bucket queue size: %i, Input queue size: %i',
325 |                 self._batch_queue.qsize(), self._example_queue.qsize())
326 | 
327 |             time.sleep(60)
328 |             for idx, t in enumerate(self._example_q_threads):
329 |                 if not t.is_alive(): # if the thread is dead
330 |                     tf.logging.error('Found example queue thread dead. Restarting.')
331 |                     new_t = Thread(target=self.fill_example_queue)
332 |                     self._example_q_threads[idx] = new_t
333 |                     new_t.daemon = True
334 |                     new_t.start()
335 |             for idx, t in enumerate(self._batch_q_threads):
336 |                 if not t.is_alive(): # if the thread is dead
337 |                     tf.logging.error('Found batch queue thread dead. 
Restarting.') 338 | new_t = Thread(target=self.fill_batch_queue) 339 | self._batch_q_threads[idx] = new_t 340 | new_t.daemon = True 341 | new_t.start() 342 | 343 | def pair_generator(self, example_generator): 344 | while True: 345 | e = example_generator.__next__() # e is a tf.Example 346 | try: 347 | article_text = e.features.feature['article'].bytes_list.value[ 348 | 0] # the article text was saved under the key 'article' in the data files 349 | abstract_text = e.features.feature['abstract'].bytes_list.value[ 350 | 0] # the abstract text was saved under the key 'abstract' in the data files 351 | except ValueError: 352 | tf.logging.error('Failed to get article or abstract from example') 353 | continue 354 | if len(article_text) == 0: # See https://github.com/abisee/pointer-generator/issues/1 355 | # tf.logging.warning('Found an example with empty article text. Skipping it.') 356 | continue 357 | else: 358 | yield (article_text, abstract_text) 359 | 360 | def example_generator(self, data_path, single_pass): 361 | while True: 362 | filelist = glob.glob(data_path) # get the list of datafiles 363 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty 364 | if single_pass: 365 | filelist = sorted(filelist) 366 | else: 367 | random.shuffle(filelist) 368 | for f in filelist: 369 | reader = open(f, 'rb') 370 | while True: 371 | len_bytes = reader.read(8) 372 | if not len_bytes: break # finished reading this file 373 | str_len = struct.unpack('q', len_bytes)[0] 374 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 375 | yield example_pb2.Example.FromString(example_str) 376 | if single_pass: 377 | print("example_generator completed reading all datafiles. No more data.") 378 | break -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import pyrouge 5 | import logging 6 | import numpy as np 7 | 8 | import torch 9 | import tensorflow as tf 10 | from torch.autograd import Variable 11 | 12 | from utils import config 13 | 14 | 15 | def article2ids(article_words, vocab): 16 | ids = [] 17 | oov = [] 18 | unk_id = vocab.word2id(config.UNK_TOKEN) 19 | for w in article_words: 20 | i = vocab.word2id(w) 21 | if i == unk_id: # If w is OOV 22 | if w not in oov: # Add to list of OOVs 23 | oov.append(w) 24 | oov_num = oov.index(w) # This is 0 for the first article OOV, 1 for the second article OOV... 25 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second... 
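            # These temporary ids are only meaningful together with extra_zeros in
            # get_input_from_batch below, which widens the output distribution so it
            # has a slot for every in-article OOV.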
26 |         else:
27 |             ids.append(i)
28 |     return ids, oov
29 | 
30 | 
31 | def abstract2ids(abstract_words, vocab, article_oovs):
32 |     ids = []
33 |     unk_id = vocab.word2id(config.UNK_TOKEN)
34 |     for w in abstract_words:
35 |         i = vocab.word2id(w)
36 |         if i == unk_id: # If w is an OOV word
37 |             if w in article_oovs: # If w is an in-article OOV
38 |                 vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
39 |                 ids.append(vocab_idx)
40 |             else: # If w is an out-of-article OOV
41 |                 ids.append(unk_id) # Map to the UNK token id
42 |         else:
43 |             ids.append(i)
44 |     return ids
45 | 
46 | 
47 | def outputids2words(id_list, vocab, article_oovs):
48 |     words = []
49 |     for i in id_list:
50 |         try:
51 |             w = vocab.id2word(i) # might be [UNK]
52 |         except ValueError: # w is OOV
53 |             assert article_oovs is not None, \
54 |                 "Error: the model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
55 |             article_oov_idx = i - vocab.size()
56 |             try:
57 |                 w = article_oovs[article_oov_idx]
58 |             except IndexError: # i doesn't correspond to an article oov (list indexing raises IndexError, not ValueError)
59 |                 raise ValueError(
60 |                     'Error: the model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (
61 |                         i, article_oov_idx, len(article_oovs)))
62 |         words.append(w)
63 |     return words
64 | 
65 | 
66 | def abstract2sents(abstract):
67 |     cur_p = 0
68 |     sents = []
69 |     while True:
70 |         try:
71 |             sta_p = abstract.index(config.SENTENCE_STA.encode(), cur_p)
72 |             end_p = abstract.index(config.SENTENCE_END.encode(), sta_p + 1)
73 |             cur_p = end_p + len(config.SENTENCE_END.encode())
74 |             sents.append(abstract[sta_p + len(config.SENTENCE_STA.encode()):end_p])
75 |         except ValueError: # no more sentences
76 |             return sents
77 | 
78 | 
79 | def show_art_oovs(article, vocab):
80 |     unk_token = vocab.word2id(config.UNK_TOKEN)
81 |     words = article.split(' ')
82 |     words = [("__%s__" % w) if vocab.word2id(w) == unk_token else w for w in words]
83 |     out_str = ' '.join(words)
84 |     return out_str
85 | 
86 | 
87 | def show_abs_oovs(abstract, vocab, article_oovs):
88 |     unk_token = vocab.word2id(config.UNK_TOKEN)
89 |     words = abstract.split(' ')
90 |     new_words = []
91 |     for w in words:
92 |         if vocab.word2id(w) == unk_token: # w is oov
93 |             if article_oovs is None: # baseline mode
94 |                 new_words.append("__%s__" % w)
95 |             else: # pointer-generator mode
96 |                 if w in article_oovs:
97 |                     new_words.append("__%s__" % w)
98 |                 else:
99 |                     new_words.append("!!__%s__!!" % w)
100 |         else: # w is an in-vocab word
101 |             new_words.append(w)
102 |     out_str = ' '.join(new_words)
103 |     return out_str
104 | 
105 | 
106 | def print_results(article, abstract, decoded_output):
107 |     print("")
108 |     print('ARTICLE: %s' % article)
109 |     print('REFERENCE SUMMARY: %s' % abstract)
110 |     print('GENERATED SUMMARY: %s' % decoded_output)
111 |     print("")
112 | 
113 | 
114 | def make_html_safe(s):
115 |     s = s.replace("<", "&lt;") # escape angle brackets: pyrouge feeds these files through HTML
116 |     s = s.replace(">", "&gt;")
117 |     return s
118 | 
119 | 
120 | def rouge_eval(ref_dir, dec_dir):
121 |     r = pyrouge.Rouge155()
122 |     r.model_filename_pattern = '#ID#_reference.txt'
123 |     r.system_filename_pattern = r'(\d+)_decoded.txt'
124 |     r.model_dir = ref_dir
125 |     r.system_dir = dec_dir
126 |     logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
127 |     rouge_results = r.convert_and_evaluate()
128 |     return r.output_to_dict(rouge_results)
129 | 
130 | 
131 | def rouge_log(results_dict, dir_to_write):
132 |     log_str = ""
133 |     for x in ["1", "2", "l"]:
134 |         log_str += "\nROUGE-%s:\n" % x
135 |         for y in ["f_score", "recall", "precision"]:
136 |             key = "rouge_%s_%s" % (x, y)
137 |             key_cb = key + "_cb"
138 |             key_ce = key + "_ce"
139 |             val = results_dict[key]
140 |             val_cb = results_dict[key_cb]
141 |             val_ce = results_dict[key_ce]
142 |             log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
143 |     print(log_str)
144 |     results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
145 |     print("Writing final ROUGE results to %s..." % (results_file))
146 |     with open(results_file, "w") as f:
147 |         f.write(log_str)
148 | 
149 | 
150 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
151 |     if running_avg_loss == 0: # on the first iteration just take the loss
152 |         running_avg_loss = loss
153 |     else:
154 |         running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
155 |     running_avg_loss = min(running_avg_loss, 12) # clip
156 |     loss_sum = tf.Summary()
157 |     tag_name = 'running_avg_loss/decay=%f' % (decay)
158 |     loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
159 |     summary_writer.add_summary(loss_sum, step)
160 |     return running_avg_loss
161 | 
162 | 
163 | def write_for_rouge(reference_sents, decoded_words, ex_index,
164 |                     _rouge_ref_dir, _rouge_dec_dir):
165 |     decoded_sents = []
166 |     while len(decoded_words) > 0:
167 |         try:
168 |             fst_period_idx = decoded_words.index(".")
169 |         except ValueError:
170 |             fst_period_idx = len(decoded_words)
171 |         sent = decoded_words[:fst_period_idx + 1]
172 |         decoded_words = decoded_words[fst_period_idx + 1:]
173 |         decoded_sents.append(' '.join(sent))
174 | 
175 |     # pyrouge calls a perl script that puts the data into HTML files.
176 |     # Therefore we need to make our output HTML safe.
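    # Note: the %06d_reference.txt / %06d_decoded.txt names written below must keep
    # matching the filename patterns configured on the pyrouge object in rouge_eval() above.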
177 | decoded_sents = [make_html_safe(w) for w in decoded_sents] 178 | reference_sents = [make_html_safe(w) for w in reference_sents] 179 | 180 | ref_file = os.path.join(_rouge_ref_dir, "%06d_reference.txt" % ex_index) 181 | decoded_file = os.path.join(_rouge_dec_dir, "%06d_decoded.txt" % ex_index) 182 | 183 | with open(ref_file, "w") as f: 184 | for idx, sent in enumerate(reference_sents): 185 | f.write(sent) if idx == len(reference_sents) - 1 else f.write(sent + "\n") 186 | with open(decoded_file, "w") as f: 187 | for idx, sent in enumerate(decoded_sents): 188 | f.write(sent) if idx == len(decoded_sents) - 1 else f.write(sent + "\n") 189 | 190 | # print("Wrote example %i to file" % ex_index) 191 | 192 | 193 | def get_input_from_batch(batch, use_cuda): 194 | extra_zeros = None 195 | enc_lens = batch.enc_lens 196 | max_enc_len = np.max(enc_lens) 197 | enc_batch_extend_vocab = None 198 | batch_size = len(batch.enc_lens) 199 | enc_batch = Variable(torch.from_numpy(batch.enc_batch).long()) 200 | enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float() 201 | 202 | if config.pointer_gen: 203 | enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long()) 204 | # max_art_oovs is the max over all the article oov list in the batch 205 | if batch.max_art_oovs > 0: 206 | extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs))) 207 | 208 | c_t = Variable(torch.zeros((batch_size, 2 * config.hidden_dim))) 209 | 210 | coverage = None 211 | if config.is_coverage: 212 | coverage = Variable(torch.zeros(enc_batch.size())) 213 | 214 | enc_pos = np.zeros((batch_size, max_enc_len)) 215 | for i, inst in enumerate(batch.enc_batch): 216 | for j, w_i in enumerate(inst): 217 | if w_i != config.PAD: 218 | enc_pos[i, j] = (j + 1) 219 | else: 220 | break 221 | enc_pos = Variable(torch.from_numpy(enc_pos).long()) 222 | 223 | if use_cuda: 224 | c_t = c_t.cuda() 225 | enc_pos = enc_pos.cuda() 226 | enc_batch = enc_batch.cuda() 227 | enc_padding_mask = enc_padding_mask.cuda() 228 | 229 | if coverage is not None: 230 | coverage = coverage.cuda() 231 | 232 | if extra_zeros is not None: 233 | extra_zeros = extra_zeros.cuda() 234 | 235 | if enc_batch_extend_vocab is not None: 236 | enc_batch_extend_vocab = enc_batch_extend_vocab.cuda() 237 | 238 | 239 | return enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t, coverage 240 | 241 | 242 | def get_output_from_batch(batch, use_cuda): 243 | dec_lens = batch.dec_lens 244 | max_dec_len = np.max(dec_lens) 245 | batch_size = len(batch.dec_lens) 246 | dec_lens = Variable(torch.from_numpy(dec_lens)).float() 247 | tgt_batch = Variable(torch.from_numpy(batch.tgt_batch)).long() 248 | dec_batch = Variable(torch.from_numpy(batch.dec_batch).long()) 249 | dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float() 250 | 251 | dec_pos = np.zeros((batch_size, config.max_dec_steps)) 252 | for i, inst in enumerate(batch.dec_batch): 253 | for j, w_i in enumerate(inst): 254 | if w_i != config.PAD: 255 | dec_pos[i, j] = (j + 1) 256 | else: 257 | break 258 | dec_pos = Variable(torch.from_numpy(dec_pos).long()) 259 | 260 | if use_cuda: 261 | dec_lens = dec_lens.cuda() 262 | tgt_batch = tgt_batch.cuda() 263 | dec_batch = dec_batch.cuda() 264 | dec_padding_mask = dec_padding_mask.cuda() 265 | dec_pos = dec_pos.cuda() 266 | 267 | return dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch 268 | 
--------------------------------------------------------------------------------
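
# ------------------------------------------------------------------------------
# Appendix (not part of the repository): a minimal sketch of how the pointer
# mechanism's extended vocabulary behaves, exercising utils.article2ids and
# utils.abstract2ids. TinyVocab and the toy sentences are illustrative stand-ins
# for utils.dataset.Vocab and real articles; running this assumes the repo root
# is on PYTHONPATH and the requirements are installed, since utils/utils.py
# imports torch, tensorflow and pyrouge at module load.
from utils import utils

class TinyVocab:
    """Stand-in for utils.dataset.Vocab, just enough for this demo."""
    def __init__(self, words):
        # mirror the real Vocab: [UNK], [PAD], [BOS], [EOS] get ids 0-3
        self._w2i = {'[UNK]': 0, '[PAD]': 1, '[BOS]': 2, '[EOS]': 3}
        for w in words:
            self._w2i.setdefault(w, len(self._w2i))
    def word2id(self, w):
        return self._w2i.get(w, self._w2i['[UNK]'])
    def size(self):
        return len(self._w2i)

vocab = TinyVocab(['the', 'cat', 'sat'])
article_words = 'the cat sat on zorp'.split()   # 'on' and 'zorp' are OOV here
ids, oovs = utils.article2ids(article_words, vocab)
print(ids, oovs)   # OOVs receive temporary ids vocab.size()+0, vocab.size()+1
print(utils.abstract2ids('the zorp sat'.split(), vocab, oovs))  # 'zorp' reuses its temporary id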