├── data_util
    ├── __init__.py
    ├── config.py
    ├── utils.py
    ├── data.py
    └── batcher.py
├── .gitignore
├── training_ptr_gen
    ├── __init__.py
    ├── train_util.py
    ├── eval.py
    ├── train.py
    └── model.py
└── README.md

/data_util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.bin
3 |
--------------------------------------------------------------------------------
/training_ptr_gen/__init__.py:
--------------------------------------------------------------------------------
1 | # Empty File
2 |
--------------------------------------------------------------------------------
/data_util/config.py:
--------------------------------------------------------------------------------
1 | """Paths and hyper-parameters shared by training and evaluation.
2 |
3 | Point the *_path values at the downloaded dataset before running.
4 | """
5 | import os
6 |
7 | root_dir = os.path.expanduser("~")
8 |
9 | # --- Data locations ---------------------------------------------------------
10 | train_data_path = "/dataset_path/chunked/train_*"  # glob pattern (expanded with glob.glob)
11 | eval_data_path = "/dataset_path/finished_files/val.bin"
12 | decode_data_path = "/dataset_path/finished_files/test.bin"
13 | vocab_path = "/dataset_path/finished_files/vocab"
14 | log_root = "/log_path/pointer_summarizer"  # train_<time>/ and eval_<model>/ run dirs are created here
15 |
16 | # --- Model size and sequence limits -----------------------------------------
17 | hidden_dim = 256  # per-direction size of the bidirectional encoder LSTM
18 | emb_dim = 128
19 | batch_size = 64
20 | max_enc_steps = 400
21 | max_dec_steps = 100  # upper bound on decoder steps per batch
22 | beam_size = 4
23 | min_dec_steps = 35
24 | vocab_size = 50000  # vocabulary cap, special tokens included
25 |
26 | # --- Optimisation and initialisation ----------------------------------------
27 | lr = 0.001
28 | adagrad_init_acc = 0.1  # NOTE(review): in the code shown, only the commented-out Adagrad call in train.py uses this
29 | rand_unif_init_mag = 0.02  # range of the uniform LSTM weight init
30 | trunc_norm_init_std = 1e-4  # std of the normal init of embeddings / linear layers
31 | max_grad_norm = 2.0  # clip norm, applied to encoder, decoder and reduce_state separately
32 |
33 | # --- Pointer-generator and coverage -----------------------------------------
34 | pointer_gen = True
35 | is_coverage = False
36 | cov_loss_wt = 1.0
37 | lr_coverage = 0.15  # NOTE(review): Adagrad-era value, large for the Adam optimizer train.py uses -- confirm
38 |
39 | # --- Misc --------------------------------------------------------------------
40 | eps = 1e-12  # added to the gold probability before taking its log
41 | max_iterations = 50000
42 |
43 | use_gpu = True  # CUDA is used only if torch.cuda.is_available() as well
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | PyTorch implementation of *[Automatic Fact-guided Sentence
Modification](https://arxiv.org/pdf/1909.13838.pdf)* (AAAI 2020). 2 | 3 | This repository implements the split-encoder pointer-generator module of our method. 4 | 5 | 6 | The code for the Masker is [here](https://github.com/TalSchuster/TokenMasker). 7 | 8 | 9 | ______________________________________________________________________________________________________________________________ 10 | 11 | This repository was cloned from https://github.com/atulkum/pointer_summarizer and then updated. 12 | 13 | 14 | 15 | Note: 16 | * The code has been tested with PyTorch 0.4 and Python 2.7. 17 | 18 | ______________________________________________________________________________________________________________________________ 19 | 20 | Training the Model: 21 | 22 |

export PYTHONPATH=`pwd` && 23 | python training_ptr_gen/train.py 24 |

25 | 26 | ______________________________________________________________________________________________________________________________ 27 | 28 | 29 | Dataset: 30 | 31 | The dataset for training this model can be downloaded from https://drive.google.com/open?id=1aOMEUksFpZwJDtQcgsrJ0rjC7nO2J1kr. 32 | 33 | (After downloading it, edit data_util/config.py so that the train, val, test and vocab paths point to the downloaded files.) 34 | 35 | 36 | ______________________________________________________________________________________________________________________________ 37 | 38 | Evaluation: 39 | 40 |

41 | export PYTHONPATH=`pwd` && 42 | python training_ptr_gen/eval.py _path_of_model_checkpoint 43 |

44 |
45 | This appends the model's greedy predictions to output.txt and the corresponding reference outputs to input.txt (both in the current working directory), one line per example of the evaluation file set by eval_data_path in data_util/config.py.
46 |
47 | ______________________________________________________________________________________________________________________________
48 |
49 | If you find this repository helpful, please cite our paper:
50 | ```
51 | @inproceedings{shah2020automatic,
52 |   title={Automatic Fact-guided Sentence Modification},
53 |   author={Darsh J Shah and Tal Schuster and Regina Barzilay},
54 |   booktitle={Association for the Advancement of Artificial Intelligence ({AAAI})},
55 |   year={2020},
56 |   url={https://arxiv.org/pdf/1909.13838.pdf}
57 | }
58 | ```
59 |
--------------------------------------------------------------------------------
/training_ptr_gen/train_util.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | import torch
4 | from data_util import config
5 |
6 |
7 | def get_input_from_batch(batch, use_cuda):
8 |     """Build the encoder-side tensors for both inputs of a batch.
9 |
10 |     Every returned value is a two-element list [first, second] built from
11 |     the batch fields enc_* and enc_*_2 respectively.
12 |
13 |     Returns:
14 |         enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab,
15 |         extra_zeros, c_t_1 and coverage, in that order.
16 |         enc_batch_extend_vocab is None unless config.pointer_gen is set;
17 |         extra_zeros additionally needs an in-article OOV in either input;
18 |         coverage is None unless config.is_coverage is set.
19 |     """
20 |     batch_size = len(batch.enc_lens)
21 |
22 |     enc_batch = Variable(torch.from_numpy(batch.enc_batch).long())
23 |     enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float()
24 |     enc_lens = batch.enc_lens
25 |
26 |     enc_batch_2 = Variable(torch.from_numpy(batch.enc_batch_2).long())
27 |     enc_padding_mask_2 = Variable(torch.from_numpy(batch.enc_padding_mask_2).float())
28 |     enc_lens_2 = batch.enc_lens_2
29 |
30 |     extra_zeros = None
31 |     enc_batch_extend_vocab = None
32 |
33 |     extra_zeros_2 = None
34 |     enc_batch_extend_vocab_2 = None
35 |
36 |     if config.pointer_gen:
37 |         enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long())
38 |         enc_batch_extend_vocab_2 = Variable(torch.from_numpy(batch.enc_batch_extend_vocab_2).long())
39 |         # max_art_oovs is the max over all the article oov list in the batch.
40 |         # Both extra_zeros tensors are created as soon as either input has an
41 |         # OOV, so the pair is always (None, None) or (tensor, tensor).
42 |         if batch.max_art_oovs > 0 or batch.max_art_oovs_2 > 0:
43 |             extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs)))
44 |             extra_zeros_2 = Variable(torch.zeros((batch_size, batch.max_art_oovs_2)))
45 |
46 |     # Initial context vectors: 2 * hidden_dim because the encoder LSTM is bidirectional.
47 |     c_t_1 = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
48 |     c_t_1_2 = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
49 |
50 |     coverage = None
51 |     coverage_2 = None
52 |     if config.is_coverage:
53 |         coverage = Variable(torch.zeros(enc_batch.size()))
54 |         coverage_2 = Variable(torch.zeros(enc_batch_2.size()))
55 |
56 |     if use_cuda:
57 |         enc_batch = enc_batch.cuda()
58 |         enc_padding_mask = enc_padding_mask.cuda()
59 |         enc_batch_2 = enc_batch_2.cuda()
60 |         enc_padding_mask_2 = enc_padding_mask_2.cuda()
61 |
62 |         if enc_batch_extend_vocab is not None:
63 |             enc_batch_extend_vocab = enc_batch_extend_vocab.cuda()
64 |             enc_batch_extend_vocab_2 = enc_batch_extend_vocab_2.cuda()
65 |         if extra_zeros is not None:
66 |             extra_zeros = extra_zeros.cuda()
67 |             extra_zeros_2 = extra_zeros_2.cuda()
68 |         c_t_1 = c_t_1.cuda()
69 |         c_t_1_2 = c_t_1_2.cuda()
70 |
71 |         if coverage is not None:
72 |             coverage = coverage.cuda()
73 |             coverage_2 = coverage_2.cuda()
74 |
75 |     return ([enc_batch, enc_batch_2], [enc_padding_mask, enc_padding_mask_2],
76 |             [enc_lens, enc_lens_2], [enc_batch_extend_vocab, enc_batch_extend_vocab_2],
77 |             [extra_zeros, extra_zeros_2], [c_t_1, c_t_1_2], [coverage, coverage_2])
78 |
79 |
80 | def get_output_from_batch(batch, use_cuda):
81 |     """Build the decoder-side tensors for a batch.
82 |
83 |     Returns:
84 |         dec_batch (teacher-forcing decoder inputs), dec_padding_mask (float),
85 |         max_dec_len (np.max of batch.dec_lens), dec_lens_var (the decoder
86 |         lengths as a float tensor) and target_batch (gold output ids).
87 |     """
88 |     dec_batch = Variable(torch.from_numpy(batch.dec_batch).long())
89 |     dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float()
90 |     dec_lens = batch.dec_lens
91 |     max_dec_len = np.max(dec_lens)
92 |     dec_lens_var = Variable(torch.from_numpy(dec_lens)).float()
93 |
94 |     target_batch = Variable(torch.from_numpy(batch.target_batch)).long()
95 |
96 |     if use_cuda:
97 |         dec_batch = dec_batch.cuda()
98 |         dec_padding_mask = dec_padding_mask.cuda()
99 |         dec_lens_var = dec_lens_var.cuda()
100 |         target_batch = target_batch.cuda()
101 |
102 |     return dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch
103 |
--------------------------------------------------------------------------------
/data_util/utils.py:
--------------------------------------------------------------------------------
1 | #Content of this file is copied from https://github.com/abisee/pointer-generator/blob/master/
2 | import os
3 | import pyrouge
4 | import logging
5 | import tensorflow as tf
6 |
7 |
8 | def print_results(article, abstract, decoded_output):
9 |     """Print an article, its reference summary and the generated summary."""
10 |     print("")
11 |     # Interpolate with % here: print(fmt, arg) would show the raw format
12 |     # string next to the argument (as a tuple under Python 2).
13 |     print('ARTICLE: %s' % (article,))
14 |     print('REFERENCE SUMMARY: %s' % (abstract,))
15 |     print('GENERATED SUMMARY: %s' % (decoded_output,))
16 |     print("")
17 |
18 |
19 | def make_html_safe(s):
20 |     """Return a copy of `s` with '<' and '>' replaced by HTML entities.
21 |
22 |     str.replace returns a new string rather than modifying `s`, so the
23 |     replacements are chained and their result is returned.
24 |     """
25 |     return s.replace("<", "&lt;").replace(">", "&gt;")
26 |
27 |
28 | def rouge_eval(ref_dir, dec_dir):
29 |     """Score the decoded files in dec_dir against the references in ref_dir.
30 |
31 |     Files are paired through their numeric prefix, matching the names that
32 |     write_for_rouge produces. Returns pyrouge's output_to_dict result.
33 |     """
34 |     rouge = pyrouge.Rouge155()
35 |     rouge.model_filename_pattern = '#ID#_reference.txt'
36 |     rouge.system_filename_pattern = r'(\d+)_decoded.txt'
37 |     rouge.model_dir = ref_dir
38 |     rouge.system_dir = dec_dir
39 |     logging.getLogger('global').setLevel(logging.WARNING)  # silence pyrouge logging
40 |     rouge_results = rouge.convert_and_evaluate()
41 |     return rouge.output_to_dict(rouge_results)
42 |
43 |
44 | def rouge_log(results_dict, dir_to_write):
45 |     """Print ROUGE-1/2/L F-score, recall and precision with confidence
46 |     intervals, and write the same text to dir_to_write/ROUGE_results.txt."""
47 |     log_str = ""
48 |     for x in ["1", "2", "l"]:
49 |         log_str += "\nROUGE-%s:\n" % x
50 |         for y in ["f_score", "recall", "precision"]:
51 |             key = "rouge_%s_%s" % (x, y)
52 |             val = results_dict[key]
53 |             val_cb = results_dict[key + "_cb"]
54 |             val_ce = results_dict[key + "_ce"]
55 |             log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
56 |     print(log_str)
57 |     results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
58 |     print("Writing final ROUGE results to %s..." % (results_file))
59 |     with open(results_file, "w") as f:
60 |         f.write(log_str)
61 |
62 |
63 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
64 |     """Update the exponential moving average of the loss and log it.
65 |
66 |     A running_avg_loss of 0 means "no history yet" and is replaced by loss.
67 |     The new average is clipped at 12, written to summary_writer as a
68 |     TensorFlow scalar summary at `step`, and returned.
69 |     """
70 |     if running_avg_loss == 0:  # on the first iteration just take the loss
71 |         running_avg_loss = loss
72 |     else:
73 |         running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
74 |     running_avg_loss = min(running_avg_loss, 12)  # clip
75 |     loss_sum = tf.Summary()
76 |     tag_name = 'running_avg_loss/decay=%f' % (decay)
77 |     loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
78 |     summary_writer.add_summary(loss_sum, step)
79 |     return running_avg_loss
80 |
81 |
82 | def write_for_rouge(reference_sents, decoded_words, ex_index,
83 |                     _rouge_ref_dir, _rouge_dec_dir):
84 |     """Write one example in the file layout rouge_eval expects.
85 |
86 |     decoded_words (a list of tokens) is split into sentences after every "."
87 |     token. The reference and decoded sentences are HTML-escaped and written
88 |     one per line, without a trailing newline, to
89 |     <_rouge_ref_dir>/%06d_reference.txt and <_rouge_dec_dir>/%06d_decoded.txt.
90 |     """
91 |     decoded_sents = []
92 |     while decoded_words:
93 |         try:
94 |             fst_period_idx = decoded_words.index(".")
95 |         except ValueError:  # no period left: the rest is the last sentence
96 |             fst_period_idx = len(decoded_words)
97 |         sent = decoded_words[:fst_period_idx + 1]
98 |         decoded_words = decoded_words[fst_period_idx + 1:]
99 |         decoded_sents.append(' '.join(sent))
100 |
101 |     # pyrouge calls a perl script that puts the data into HTML files.
102 |     # Therefore we need to make our output HTML safe.
103 |     decoded_sents = [make_html_safe(w) for w in decoded_sents]
104 |     reference_sents = [make_html_safe(w) for w in reference_sents]
105 |
106 |     ref_file = os.path.join(_rouge_ref_dir, "%06d_reference.txt" % ex_index)
107 |     decoded_file = os.path.join(_rouge_dec_dir, "%06d_decoded.txt" % ex_index)
108 |
109 |     with open(ref_file, "w") as f:
110 |         f.write("\n".join(reference_sents))
111 |     with open(decoded_file, "w") as f:
112 |         f.write("\n".join(decoded_sents))
113 |
--------------------------------------------------------------------------------
/data_util/data.py:
--------------------------------------------------------------------------------
1 | #Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/data.py
2 |
3 | import glob
4 | import random
5 | import struct
6 |
import csv
7 | from tensorflow.core.example import example_pb2
8 |
9 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
10 | SENTENCE_START = '<s>'
11 | SENTENCE_END = '</s>'
12 |
13 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
14 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
15 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
16 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
17 |
18 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
19 |
20 |
21 | class Vocab(object):
22 |     """Word <-> id mapping built from a vocab file.
23 |
24 |     Ids 0-3 are reserved for [UNK], [PAD], [START] and [STOP]; the words of
25 |     the vocab file follow in file order.
26 |     """
27 |
28 |     def __init__(self, vocab_file, max_size):
29 |         """Read vocab_file, whose lines hold two whitespace-separated fields
30 |         (only the first, the word, is used). Reading stops once the total
31 |         count, special tokens included, reaches max_size (0 means no limit).
32 |
33 |         Raises:
34 |             Exception: if the file contains a reserved token or a duplicate word.
35 |         """
36 |         self.word_to_id = {}
37 |         self._word_to_id = {}
38 |         self._id_to_word = {}
39 |         self._count = 0  # keeps track of total number of words in the Vocab
40 |
41 |         # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
42 |         for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
43 |             self._word_to_id[w] = self._count
44 |             self._id_to_word[self._count] = w
45 |             self._count += 1
46 |
47 |         # Read the vocab file and add words up to max_size
48 |         with open(vocab_file, 'r') as vocab_f:
49 |             for line in vocab_f:
50 |                 pieces = line.split()
51 |                 if len(pieces) != 2:
52 |                     print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
53 |                     continue
54 |                 w = pieces[0]
55 |                 if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
56 |                     raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
57 |                 if w in self._word_to_id:
58 |                     raise Exception('Duplicated word in vocabulary file: %s' % w)
59 |                 self._word_to_id[w] = self._count
60 |                 self._id_to_word[self._count] = w
61 |                 self._count += 1
62 |                 if max_size != 0 and self._count >= max_size:
63 |                     print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
64 |                     break
65 |
66 |         print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count - 1]))
67 |         # Public alias of the word -> id table (eval.py reads it directly).
68 |         self.word_to_id = self._word_to_id
69 |
70 |     def word2id(self, word):
71 |         """Return the id of `word`, or the [UNK] id if it is out of vocabulary."""
72 |         if word not in self._word_to_id:
73 |             return self._word_to_id[UNKNOWN_TOKEN]
74 |         return self._word_to_id[word]
75 |
76 |     def id2word(self, word_id):
77 |         """Return the word for `word_id`; raise ValueError for unknown ids."""
78 |         if word_id not in self._id_to_word:
79 |             raise ValueError('Id not found in vocab: %d' % word_id)
80 |         return self._id_to_word[word_id]
81 |
82 |     def size(self):
83 |         """Return the number of ids, special tokens included."""
84 |         return self._count
85 |
86 |     def write_metadata(self, fpath):
87 |         """Write one word per line, in id order, as a one-column TSV to fpath."""
88 |         print("Writing word embedding metadata file to %s..." % (fpath))
89 |         with open(fpath, "w") as f:
90 |             fieldnames = ['word']
91 |             writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
92 |             for i in range(self.size()):
93 |                 writer.writerow({"word": self._id_to_word[i]})
94 |
95 |
96 | def example_generator(data_path, single_pass):
97 |     """Yield tf.Example protos from the binary files matching data_path.
98 |
99 |     Each record is an 8-byte length (struct format 'q') followed by that
100 |     many bytes of serialized example. With single_pass the files are read
101 |     once, in sorted order; otherwise they are shuffled and re-read forever.
102 |     """
103 |     while True:
104 |         filelist = glob.glob(data_path)  # get the list of datafiles
105 |         assert filelist, ('Error: Empty filelist at %s' % data_path)  # check filelist isn't empty
106 |         if single_pass:
107 |             filelist = sorted(filelist)
108 |         else:
109 |             random.shuffle(filelist)
110 |         for f in filelist:
111 |             # `with` closes each file once it has been read to the end.
112 |             with open(f, 'rb') as reader:
113 |                 while True:
114 |                     len_bytes = reader.read(8)
115 |                     if not len_bytes:
116 |                         break  # finished reading this file
117 |                     str_len = struct.unpack('q', len_bytes)[0]
118 |                     example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
119 |                     yield example_pb2.Example.FromString(example_str)
120 |         if single_pass:
121 |             print("example_generator completed reading all datafiles. No more data.")
122 |             break
123 |
124 |
125 | def article2ids(article_words, vocab):
126 |     """Map article words to ids, giving in-article OOVs temporary ids.
127 |
128 |     Returns:
129 |         (ids, oovs): the n-th distinct OOV gets id vocab.size() + n and is
130 |         stored at oovs[n].
131 |     """
132 |     ids = []
133 |     oovs = []
134 |     unk_id = vocab.word2id(UNKNOWN_TOKEN)
135 |     for w in article_words:
136 |         i = vocab.word2id(w)
137 |         if i == unk_id:  # If w is OOV
138 |             if w not in oovs:  # Add to list of OOVs
139 |                 oovs.append(w)
140 |             oov_num = oovs.index(w)  # This is 0 for the first article OOV, 1 for the second article OOV...
141 |             ids.append(vocab.size() + oov_num)  # This is e.g. 50000 for the first article OOV, 50001 for the second...
142 |         else:
143 |             ids.append(i)
144 |     return ids, oovs
145 |
146 |
147 | def abstract2ids(abstract_words, vocab, article_oovs):
148 |     """Map abstract words to ids; OOVs that also occur in the article reuse
149 |     the article's temporary ids, other OOVs map to [UNK]."""
150 |     ids = []
151 |     unk_id = vocab.word2id(UNKNOWN_TOKEN)
152 |     for w in abstract_words:
153 |         i = vocab.word2id(w)
154 |         if i == unk_id:  # If w is an OOV word
155 |             if w in article_oovs:  # If w is an in-article OOV
156 |                 vocab_idx = vocab.size() + article_oovs.index(w)  # Map to its temporary article OOV number
157 |                 ids.append(vocab_idx)
158 |             else:  # If w is an out-of-article OOV
159 |                 ids.append(unk_id)  # Map to the UNK token id
160 |         else:
161 |             ids.append(i)
162 |     return ids
163 |
164 |
165 | def outputids2words(id_list, vocab, article_oovs):
166 |     """Map output ids back to words, resolving temporary ids via article_oovs.
167 |
168 |     Raises:
169 |         ValueError: if an id beyond the vocabulary has no matching article OOV.
170 |     """
171 |     words = []
172 |     for i in id_list:
173 |         try:
174 |             w = vocab.id2word(i)  # might be [UNK]
175 |         except ValueError:  # w is OOV
176 |             assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
177 |             article_oov_idx = i - vocab.size()
178 |             try:
179 |                 w = article_oovs[article_oov_idx]
180 |             except IndexError:  # i doesn't correspond to an article oov (out-of-range list index)
181 |                 raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
182 |         words.append(w)
183 |     return words
184 |
185 |
186 | def abstract2sents(abstract):
187 |     """Return the sentences of `abstract` that are wrapped in <s> ... </s>."""
188 |     cur = 0
189 |     sents = []
190 |     while True:
191 |         try:
192 |             start_p = abstract.index(SENTENCE_START, cur)
193 |             end_p = abstract.index(SENTENCE_END, start_p + 1)
194 |             cur = end_p + len(SENTENCE_END)
195 |             sents.append(abstract[start_p + len(SENTENCE_START):end_p])
196 |         except ValueError:  # no more sentences
197 |             return sents
198 |
199 |
200 | def show_art_oovs(article, vocab):
201 |     """Return `article` with out-of-vocabulary words marked as __word__."""
202 |     unk_token = vocab.word2id(UNKNOWN_TOKEN)
203 |     words = article.split(' ')
204 |     words = [("__%s__" % w) if vocab.word2id(w) == unk_token else w for w in words]
205 |     out_str = ' '.join(words)
206 |     return out_str
207 |
208 |
209 | def show_abs_oovs(abstract, vocab, article_oovs):
210 |     """Return `abstract` with OOVs marked: __word__ for OOVs that can be
211 |     copied from the article (or any OOV in baseline mode, article_oovs=None)
212 |     and !!__word__!! for OOVs absent from the article."""
213 |     unk_token = vocab.word2id(UNKNOWN_TOKEN)
214 |     words = abstract.split(' ')
215 |     new_words = []
216 |     for w in words:
217 |         if vocab.word2id(w) == unk_token:  # w is oov
218 |             if article_oovs is None:  # baseline mode
219 |                 new_words.append("__%s__" % w)
220 |             else:  # pointer-generator mode
221 |                 if w in article_oovs:
222 |                     new_words.append("__%s__" % w)
223 |                 else:
224 |                     new_words.append("!!__%s__!!" % w)
225 |         else:  # w is in-vocab word
226 |             new_words.append(w)
227 |     out_str = ' '.join(new_words)
228 |     return out_str
229 |
--------------------------------------------------------------------------------
/training_ptr_gen/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import os
4 | import time
5 | import sys
6 |
7 | import tensorflow as tf
8 | import torch
9 |
10 | from data_util import config
11 | from data_util.batcher import Batcher
12 | from data_util.data import Vocab
13 |
14 | from data_util.utils import calc_running_avg_loss
15 | from train_util import get_input_from_batch, get_output_from_batch
16 | from model import Model
17 | import sys
18 | # Python 2 only: switch the default codec from ASCII to UTF-8, presumably so
19 | # that writing unicode words to plain files does not raise UnicodeEncodeError.
20 | reload(sys)
21 | sys.setdefaultencoding('utf8')
22 |
23 | use_cuda = config.use_gpu and torch.cuda.is_available()
24 |
25 |
26 | class Evaluate(object):
27 |     """Teacher-forced evaluation of a saved checkpoint on config.eval_data_path.
28 |
29 |     Each batch's loss is logged, and the greedy predictions and gold targets
30 |     are appended to output.txt and input.txt respectively.
31 |     """
32 |
33 |     def __init__(self, model_file_path):
34 |         self.vocab = Vocab(config.vocab_path, config.vocab_size)
35 |         self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
36 |                                batch_size=config.batch_size, single_pass=True)
37 |         time.sleep(15)  # presumably gives the Batcher time to fill its queues -- TODO confirm
38 |         model_name = os.path.basename(model_file_path)
39 |
40 |         eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
41 |         if not os.path.exists(eval_dir):
42 |             os.mkdir(eval_dir)
43 |         self.summary_writer = tf.summary.FileWriter(eval_dir)
44 |
45 |         self.model = Model(model_file_path, is_eval=True)
46 |
47 |     def _run_encoders(self, enc_batch_list, enc_lens_list):
48 |         """Encode every input and merge the resulting decoder start states.
49 |
50 |         The encoder packs its input, which requires rows sorted by decreasing
51 |         length, so each input is sorted before encoding and the outputs,
52 |         features and hidden state are put back in batch order afterwards.
53 |         The reduced states of all inputs are summed element-wise.
54 |
55 |         Returns:
56 |             (encoder_outputs_list, encoder_feature_list, s_t_1)
57 |         """
58 |         encoder_outputs_list = []
59 |         encoder_feature_list = []
60 |         s_t_1 = None
61 |         for enc_batch, enc_lens in zip(enc_batch_list, enc_lens_list):
62 |             sorted_indices = sorted(range(len(enc_lens)), key=enc_lens.__getitem__)
63 |             sorted_indices.reverse()
64 |             # inverse_sorted_indices[row] = position of batch row `row` after sorting
65 |             inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
66 |             for index, position in enumerate(sorted_indices):
67 |                 inverse_sorted_indices[position] = index
68 |             sort_idx = torch.LongTensor(sorted_indices)
69 |             unsort_idx = torch.LongTensor(inverse_sorted_indices)
70 |             if use_cuda:
71 |                 sort_idx = sort_idx.cuda()
72 |                 unsort_idx = unsort_idx.cuda()
73 |
74 |             sorted_enc_batch = torch.index_select(enc_batch, 0, sort_idx)
75 |             sorted_enc_lens = enc_lens[sorted_indices]
76 |             sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = \
77 |                 self.model.encoder(sorted_enc_batch, sorted_enc_lens)
78 |             encoder_outputs = torch.index_select(sorted_encoder_outputs, 0, unsort_idx)
79 |             # The feature is flat (B * t_k, 2H): view it as (B, t_k, 2H) to unsort
80 |             # the batch axis, then flatten it back.
81 |             encoder_feature = torch.index_select(
82 |                 sorted_encoder_feature.view(encoder_outputs.shape), 0, unsort_idx
83 |             ).view(sorted_encoder_feature.shape)
84 |             # LSTM (h, c) are (num_directions, B, H): the batch axis is 1.
85 |             encoder_hidden = (torch.index_select(sorted_encoder_hidden[0], 1, unsort_idx),
86 |                               torch.index_select(sorted_encoder_hidden[1], 1, unsort_idx))
87 |             encoder_outputs_list.append(encoder_outputs)
88 |             encoder_feature_list.append(encoder_feature)
89 |
90 |             reduced = self.model.reduce_state(encoder_hidden)
91 |             if s_t_1 is None:
92 |                 s_t_1 = reduced
93 |             else:
94 |                 s_t_1 = (s_t_1[0] + reduced[0], s_t_1[1] + reduced[1])
95 |         return encoder_outputs_list, encoder_feature_list, s_t_1
96 |
97 |     def eval_one_batch(self, batch):
98 |         """Return the teacher-forced loss of `batch` and log its predictions.
99 |
100 |         The loss is the masked sum of per-step losses divided by each
101 |         example's decoder length, averaged over the batch (a Python float).
102 |         """
103 |         enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
104 |             get_input_from_batch(batch, use_cuda)
105 |         dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
106 |             get_output_from_batch(batch, use_cuda)
107 |
108 |         encoder_outputs_list, encoder_feature_list, s_t_1 = \
109 |             self._run_encoders(enc_batch_list, enc_lens_list)
110 |
111 |         step_losses = []
112 |         target_words = []
113 |         output_words = []
114 |         id_to_words = {v: k for k, v in self.vocab.word_to_id.iteritems()}
115 |         # Id of the mask token 'X'; None when it is not in the vocabulary, in
116 |         # which case the model can never predict it.
117 |         mask_id = self.vocab.word_to_id.get('X')
118 |         for di in range(min(max_dec_len, config.max_dec_steps)):
119 |             y_t_1 = dec_batch[:, di]  # Teacher forcing
120 |             final_dist, s_t_1, c_t_1_list, attn_dist_list, p_gen, next_coverage_list = self.model.decoder(y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list, enc_padding_mask_list, c_t_1_list, extra_zeros_list, enc_batch_extend_vocab_list, coverage_list, di)
121 |             target = target_batch[:, di]
122 |             gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
123 |
124 |             # Greedy prediction that never emits the mask token: when 'X' is
125 |             # the argmax, use the second most probable id instead.
126 |             output_ids = final_dist.max(1)[1]
127 |             output_2_candidates = final_dist.topk(2, 1)[1]
128 |             if mask_id is not None:
129 |                 for ind in range(output_ids.shape[0]):
130 |                     if mask_id == output_ids[ind].item():
131 |                         output_ids[ind] = output_2_candidates[ind][1]
132 |
133 |             # Ids outside the fixed vocabulary (copied OOVs) and padded
134 |             # positions are written as id 0, i.e. [UNK].
135 |             target_step = []
136 |             output_step = []
137 |             step_mask = dec_padding_mask[:, di]
138 |             for i in range(target.shape[0]):
139 |                 if target[i].item() >= len(id_to_words) or step_mask[i].item() == 0:
140 |                     target[i] = 0
141 |                 target_step.append(id_to_words[target[i].item()])
142 |                 if output_ids[i].item() >= len(id_to_words) or step_mask[i].item() == 0:
143 |                     output_ids[i] = 0
144 |                 output_step.append(id_to_words[output_ids[i].item()])
145 |             target_words.append(target_step)
146 |             output_words.append(output_step)
147 |
148 |             step_loss = -torch.log(gold_probs + config.eps)
149 |             if config.is_coverage:
150 |                 step_coverage_loss = 0.0
151 |                 for ind in range(len(coverage_list)):
152 |                     step_coverage_loss += torch.sum(torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
153 |                     coverage_list[ind] = next_coverage_list[ind]
154 |                 step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
155 |
156 |             step_loss = step_loss * step_mask
157 |             step_losses.append(step_loss)
158 |
159 |         self.write_words(output_words, "output.txt")
160 |         self.write_words(target_words, "input.txt")
161 |
162 |         sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
163 |         batch_avg_loss = sum_step_losses / dec_lens_var
164 |         loss = torch.mean(batch_avg_loss)
165 |
166 |         return loss.item()
167 |
168 |     def write_words(self, len_batch_words, output_file):
169 |         """Append one line per example to output_file.
170 |
171 |         len_batch_words is step-major (len_batch_words[step][example] is a
172 |         word); each example's words are joined with spaces up to, but not
173 |         including, its first [STOP].
174 |         """
175 |         if not len_batch_words:
176 |             return  # zero decoding steps: nothing to write
177 |         batch_sentences = ["" for _ in range(len(len_batch_words[0]))]
178 |         batch_sentences_done = [False for _ in range(len(len_batch_words[0]))]
179 |         for step_words in len_batch_words:
180 |             for j, word in enumerate(step_words):
181 |                 if word == "[STOP]":
182 |                     batch_sentences_done[j] = True
183 |                 if not batch_sentences_done[j]:
184 |                     batch_sentences[j] += word + " "
185 |         # `with` guarantees the file is closed even if a write fails.
186 |         with open(output_file, "a") as f:
187 |             for sentence in batch_sentences:
188 |                 f.write(sentence.strip() + "\n")
189 |
190 |     def run_eval(self):
191 |         """Evaluate every batch of the eval set once, logging a running loss."""
192 |         running_avg_loss, step = 0, 0
193 |         print_interval = 1
194 |         start = time.time()
195 |         batch = self.batcher.next_batch()
196 |         while batch is not None:
197 |             loss = self.eval_one_batch(batch)
198 |
199 |             running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, step)
200 |             step += 1
201 |
202 |             if step % 100 == 0:
203 |                 self.summary_writer.flush()
204 |             if step % print_interval == 0:
205 |                 print('steps %d, seconds for %d batch: %.2f , loss: %f' % (
206 |                     step, print_interval, time.time() - start, running_avg_loss))
207 |                 start = time.time()
208 |             batch = self.batcher.next_batch()
209 |         # Flush the summaries written since the last multiple of 100 steps.
210 |         self.summary_writer.flush()
211 |
212 |
213 | if __name__ == '__main__':
214 |     model_filename = sys.argv[1]
215 |     eval_processor = Evaluate(model_filename)
216 |     eval_processor.run_eval()
217 |
--------------------------------------------------------------------------------
/training_ptr_gen/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import os
4 | import time
5 | import argparse
6 |
7 | import tensorflow as tf
8 | import torch
9 | from model import Model
10 | from torch.nn.utils import clip_grad_norm_
11 |
12 | from torch.optim import Adagrad, Adam
13 |
14 | from data_util import config
15 | from data_util.batcher import Batcher
16 | from data_util.data import Vocab
17 | from data_util.utils import calc_running_avg_loss
18 | from train_util import get_input_from_batch, get_output_from_batch
19 |
20 | use_cuda = config.use_gpu and torch.cuda.is_available()
21 |
22 |
23 | class Train(object):
24 |     """Trains the model on config.train_data_path.
25 |
26 |     Summaries and checkpoints are written under a new
27 |     config.log_root/train_<unix time> directory (checkpoints in its model/).
28 |     """
29 |
30 |     def __init__(self):
31 |         self.vocab = Vocab(config.vocab_path, config.vocab_size)
32 |         self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
33 |                                batch_size=config.batch_size, single_pass=False)
34 |         time.sleep(15)  # presumably gives the Batcher time to fill its queues -- TODO confirm
35 |
36 |         train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
37 |         if not os.path.exists(train_dir):
38 |             os.mkdir(train_dir)
39 |
40 |         self.model_dir = os.path.join(train_dir, 'model')
41 |         if not os.path.exists(self.model_dir):
42 |             os.mkdir(self.model_dir)
43 |
44 |         self.summary_writer = tf.summary.FileWriter(train_dir)
45 |
46 |     def save_model(self, running_avg_loss, iter):
47 |         """Save the model, optimizer state, step and running loss to
48 |         model_dir/model_<iter>_<unix time>."""
49 |         state = {
50 |             'iter': iter,
51 |             'encoder_state_dict': self.model.encoder.state_dict(),
52 |             'decoder_state_dict': self.model.decoder.state_dict(),
53 |             'reduce_state_dict': self.model.reduce_state.state_dict(),
54 |             'optimizer': self.optimizer.state_dict(),
55 |             'current_loss': running_avg_loss
56 |         }
57 |         model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
58 |         torch.save(state, model_save_path)
59 |
60 |     def setup_train(self, model_file_path=None):
61 |         """Create the model and its Adam optimizer, optionally resuming.
62 |
63 |         Returns:
64 |             (start_iter, start_loss): (0, 0) for a fresh run, otherwise the
65 |             'iter' and 'current_loss' stored in the checkpoint. The optimizer
66 |             state is restored only when not training with coverage.
67 |         """
68 |         self.model = Model(model_file_path)
69 |
70 |         params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
71 |             list(self.model.reduce_state.parameters())
72 |         initial_lr = config.lr_coverage if config.is_coverage else config.lr
73 |         # Adam replaces the original Adagrad(params, lr=initial_lr,
74 |         # initial_accumulator_value=config.adagrad_init_acc).
75 |         self.optimizer = Adam(params, lr=initial_lr)
76 |
77 |         start_iter, start_loss = 0, 0
78 |
79 |         if model_file_path is not None:
80 |             # map_location keeps every tensor on the CPU while loading.
81 |             state = torch.load(model_file_path, map_location=lambda storage, location: storage)
82 |             start_iter = state['iter']
83 |             start_loss = state['current_loss']
84 |
85 |             if not config.is_coverage:
86 |                 self.optimizer.load_state_dict(state['optimizer'])
87 |                 if use_cuda:
88 |                     # Move the restored optimizer buffers to the GPU.
89 |                     for opt_state in self.optimizer.state.values():
90 |                         for k, v in opt_state.items():
91 |                             if torch.is_tensor(v):
92 |                                 opt_state[k] = v.cuda()
93 |
94 |         return start_iter, start_loss
95 |
96 |     def _run_encoders(self, enc_batch_list, enc_lens_list):
97 |         """Encode every input and merge the resulting decoder start states.
98 |
99 |         The encoder packs its input, which requires rows sorted by decreasing
100 |         length, so each input is sorted before encoding and the outputs,
101 |         features and hidden state are put back in batch order afterwards.
102 |         The reduced states of all inputs are summed element-wise.
103 |
104 |         Returns:
105 |             (encoder_outputs_list, encoder_feature_list, s_t_1)
106 |         """
107 |         encoder_outputs_list = []
108 |         encoder_feature_list = []
109 |         s_t_1 = None
110 |         for enc_batch, enc_lens in zip(enc_batch_list, enc_lens_list):
111 |             sorted_indices = sorted(range(len(enc_lens)), key=enc_lens.__getitem__)
112 |             sorted_indices.reverse()
113 |             # inverse_sorted_indices[row] = position of batch row `row` after sorting
114 |             inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
115 |             for index, position in enumerate(sorted_indices):
116 |                 inverse_sorted_indices[position] = index
117 |             sort_idx = torch.LongTensor(sorted_indices)
118 |             unsort_idx = torch.LongTensor(inverse_sorted_indices)
119 |             if use_cuda:
120 |                 sort_idx = sort_idx.cuda()
121 |                 unsort_idx = unsort_idx.cuda()
122 |
123 |             sorted_enc_batch = torch.index_select(enc_batch, 0, sort_idx)
124 |             sorted_enc_lens = enc_lens[sorted_indices]
125 |             sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = \
126 |                 self.model.encoder(sorted_enc_batch, sorted_enc_lens)
127 |             encoder_outputs = torch.index_select(sorted_encoder_outputs, 0, unsort_idx)
128 |             # The feature is flat (B * t_k, 2H): view it as (B, t_k, 2H) to unsort
129 |             # the batch axis, then flatten it back.
130 |             encoder_feature = torch.index_select(
131 |                 sorted_encoder_feature.view(encoder_outputs.shape), 0, unsort_idx
132 |             ).view(sorted_encoder_feature.shape)
133 |             # LSTM (h, c) are (num_directions, B, H): the batch axis is 1.
134 |             encoder_hidden = (torch.index_select(sorted_encoder_hidden[0], 1, unsort_idx),
135 |                               torch.index_select(sorted_encoder_hidden[1], 1, unsort_idx))
136 |             encoder_outputs_list.append(encoder_outputs)
137 |             encoder_feature_list.append(encoder_feature)
138 |
139 |             reduced = self.model.reduce_state(encoder_hidden)
140 |             if s_t_1 is None:
141 |                 s_t_1 = reduced
142 |             else:
143 |                 s_t_1 = (s_t_1[0] + reduced[0], s_t_1[1] + reduced[1])
144 |         return encoder_outputs_list, encoder_feature_list, s_t_1
145 |
146 |     def train_one_batch(self, batch):
147 |         """Run one optimisation step on `batch` and return its loss (a float).
148 |
149 |         The loss is the masked sum of per-step losses divided by each
150 |         example's decoder length, averaged over the batch. Gradients are
151 |         clipped to config.max_grad_norm separately for the encoder, decoder
152 |         and reduce_state modules.
153 |         """
154 |         enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
155 |             get_input_from_batch(batch, use_cuda)
156 |         dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
157 |             get_output_from_batch(batch, use_cuda)
158 |
159 |         self.optimizer.zero_grad()
160 |
161 |         encoder_outputs_list, encoder_feature_list, s_t_1 = \
162 |             self._run_encoders(enc_batch_list, enc_lens_list)
163 |
164 |         step_losses = []
165 |         for di in range(min(max_dec_len, config.max_dec_steps)):
166 |             y_t_1 = dec_batch[:, di]  # Teacher forcing
167 |             final_dist, s_t_1, c_t_1_list, attn_dist_list, p_gen, next_coverage_list = self.model.decoder(y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list, enc_padding_mask_list, c_t_1_list, extra_zeros_list, enc_batch_extend_vocab_list, coverage_list, di)
168 |             target = target_batch[:, di]
169 |             gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
170 |             step_loss = -torch.log(gold_probs + config.eps)
171 |             if config.is_coverage:
172 |                 step_coverage_loss = 0.0
173 |                 for ind in range(len(coverage_list)):
174 |                     step_coverage_loss += torch.sum(torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
175 |                     coverage_list[ind] = next_coverage_list[ind]
176 |                 step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
177 |
178 |             step_mask = dec_padding_mask[:, di]
179 |             step_loss = step_loss * step_mask
180 |             step_losses.append(step_loss)
181 |
182 |         sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
183 |         batch_avg_loss = sum_losses / dec_lens_var
184 |         loss = torch.mean(batch_avg_loss)
185 |
186 |         loss.backward()
187 |
188 |         self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
189 |         clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
190 |         clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)
191 |
192 |         self.optimizer.step()
193 |
194 |         return loss.item()
195 |
196 |     def trainIters(self, n_iters, model_file_path=None):
197 |         """Train until n_iters total steps, printing the last batch loss and
198 |         saving a checkpoint every 500 steps."""
199 |         step, running_avg_loss = self.setup_train(model_file_path)
200 |         print_interval = 500
201 |         start = time.time()
202 |         while step < n_iters:
203 |             batch = self.batcher.next_batch()
204 |             loss = self.train_one_batch(batch)
205 |
206 |             running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, step)
207 |             step += 1
208 |
209 |             if step % 100 == 0:
210 |                 self.summary_writer.flush()
211 |             if step % print_interval == 0:
212 |                 print('steps %d, seconds for %d batch: %.2f , loss: %f' % (step, print_interval,
213 |                                                                           time.time() - start, loss))
214 |                 start = time.time()
215 |             if step % 500 == 0:
216 |                 self.save_model(running_avg_loss, step)
217 |
218 |
219 | if __name__ == '__main__':
220 |     parser = argparse.ArgumentParser(description="Train script")
221 |     parser.add_argument("-m",
222 |                         dest="model_file_path",
223 |                         required=False,
224 |                         default=None,
225 |                         help="Model file for retraining (default: None).")
226 |     args = parser.parse_args()
227 |
228 |     train_processor = Train()
229 |     train_processor.trainIters(config.max_iterations, args.model_file_path)
230 |
--------------------------------------------------------------------------------
/training_ptr_gen/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 | from data_util import config
8 | from numpy import random
9 | import sys
10 |
11 | use_cuda = config.use_gpu and torch.cuda.is_available()
12 |
13 | random.seed(123)  # fixed seeds (numpy, torch, CUDA) for reproducible runs
14 | torch.manual_seed(123)
15 | if torch.cuda.is_available():
16 |     torch.cuda.manual_seed_all(123)
17 |
18 | def init_lstm_wt(lstm):  # uniform weights; biases 0 except the second quarter (forget gate), set to 1
19 |     for names in lstm._all_weights:
20 |         for name in names:
21 |             if name.startswith('weight_'):
22 |                 wt = getattr(lstm, name)
23 |                 wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
24 |             elif name.startswith('bias_'):
25 |                 # set forget bias to 1
26 |                 bias = getattr(lstm, name)
27 |                 n = bias.size(0)
28 |                 start, end = n // 4, n // 2
29 |                 bias.data.fill_(0.)
30 |                 bias.data[start:end].fill_(1.)
def init_linear_wt(linear):
    """Initialise a Linear layer's weight (and bias, if any) in place.

    Values are drawn from N(0, config.trunc_norm_init_std); note that
    ``normal_`` does not actually truncate despite the config name.
    """
    linear.weight.data.normal_(std=config.trunc_norm_init_std)
    if linear.bias is not None:
        linear.bias.data.normal_(std=config.trunc_norm_init_std)


def init_wt_normal(wt):
    """Fill ``wt`` in place from N(0, config.trunc_norm_init_std)."""
    wt.data.normal_(std=config.trunc_norm_init_std)


def init_wt_unif(wt):
    """Fill ``wt`` in place from U(-rand_unif_init_mag, rand_unif_init_mag)."""
    wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)


class Encoder(nn.Module):
    """Embedding + single-layer bidirectional LSTM encoder."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
        init_wt_normal(self.embedding.weight)

        self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)
        init_lstm_wt(self.lstm)

        # Projects encoder states into attention-feature space (used by Attention).
        self.W_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)

    # seq_lens should be in descending order (required by pack_padded_sequence)
    def forward(self, input, seq_lens):
        """Encode a padded batch of token ids.

        Returns ``(encoder_outputs, encoder_feature, hidden)`` where
        ``encoder_outputs`` is B x t_k x 2*hidden_dim, ``encoder_feature`` is
        the W_h projection flattened to (B*t_k) x 2*hidden_dim, and ``hidden``
        is the final LSTM ``(h, c)`` state.
        """
        embedded = self.embedding(input)

        packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
        output, hidden = self.lstm(packed)

        encoder_outputs, _ = pad_packed_sequence(output, batch_first=True)  # h dim = B x t_k x n
        encoder_outputs = encoder_outputs.contiguous()

        encoder_feature = encoder_outputs.view(-1, 2 * config.hidden_dim)  # B * t_k x 2*hidden_dim
        encoder_feature = self.W_h(encoder_feature)

        return encoder_outputs, encoder_feature, hidden


class ReduceState(nn.Module):
    """Map the bidirectional encoder's final (h, c) to the unidirectional decoder size."""

    def __init__(self):
        super(ReduceState, self).__init__()

        self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
        init_linear_wt(self.reduce_h)
        self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
        init_linear_wt(self.reduce_c)

    def forward(self, hidden):
        h, c = hidden  # h, c dim = 2 x b x hidden_dim
        # Concatenate the two directions per example, then project with a ReLU.
        h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
        hidden_reduced_h = F.relu(self.reduce_h(h_in))
        c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
        hidden_reduced_c = F.relu(self.reduce_c(c_in))

        return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0))  # h, c dim = 1 x b x hidden_dim


class Attention(nn.Module):
    """Additive attention over one encoder's outputs, with optional coverage."""

    def __init__(self):
        super(Attention, self).__init__()
        if config.is_coverage:
            self.W_c = nn.Linear(1, config.hidden_dim * 2, bias=False)
        self.decode_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
        self.v = nn.Linear(config.hidden_dim * 2, 1, bias=False)

    def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
        """Return ``(c_t, attn_dist, coverage)`` for one decoder step.

        ``attn_dist`` is a softmax over source positions, zeroed on padding and
        renormalised; ``coverage`` gains ``attn_dist`` only when
        config.is_coverage is set, otherwise it is returned unchanged.
        """
        b, t_k, n = list(encoder_outputs.size())

        dec_fea = self.decode_proj(s_t_hat)  # B x 2*hidden_dim
        dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous()  # B x t_k x 2*hidden_dim
        dec_fea_expanded = dec_fea_expanded.view(-1, n)  # B * t_k x 2*hidden_dim

        att_features = encoder_feature + dec_fea_expanded  # B * t_k x 2*hidden_dim
        if config.is_coverage:
            coverage_input = coverage.view(-1, 1)  # B * t_k x 1
            coverage_feature = self.W_c(coverage_input)  # B * t_k x 2*hidden_dim
            att_features = att_features + coverage_feature

        e = torch.tanh(att_features)  # B * t_k x 2*hidden_dim
        scores = self.v(e)  # B * t_k x 1
        scores = scores.view(-1, t_k)  # B x t_k

        # Mask out padding and renormalise. An all-zero mask row would give 0/0.
        attn_dist_ = F.softmax(scores, dim=1) * enc_padding_mask  # B x t_k
        normalization_factor = attn_dist_.sum(1, keepdim=True)
        attn_dist = attn_dist_ / normalization_factor

        attn_dist = attn_dist.unsqueeze(1)  # B x 1 x t_k
        c_t = torch.bmm(attn_dist, encoder_outputs)  # B x 1 x n
        c_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dim

        attn_dist = attn_dist.view(-1, t_k)  # B x t_k

        if config.is_coverage:
            coverage = coverage.view(-1, t_k)
            coverage = coverage + attn_dist

        return c_t, attn_dist, coverage


def _encoder_bias_from_argv():
    """Return the integer gate bias given as ``sys.argv[2]``, or None.

    When present, the decoder's encoder-mixing gates p_dec and p_enc are
    rescaled to ``(p + bias) / 101``. A non-integer argv[2] (for example the
    checkpoint path in ``train.py -m <path>``) is ignored rather than crashing
    every forward pass with ValueError.
    """
    if len(sys.argv) <= 2:
        return None
    try:
        return int(sys.argv[2])
    except ValueError:
        return None


class Decoder(nn.Module):
    """Attention LSTM decoder over multiple encoded sources with pointer-generator copying."""

    def __init__(self):
        super(Decoder, self).__init__()
        # A single attention module is shared by all sources.
        self.attention_network = Attention()
        # decoder
        self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
        init_wt_normal(self.embedding.weight)

        self.x_context = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)

        self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=False)
        init_lstm_wt(self.lstm)

        # Gate between the first two sources' previous contexts.
        self.encoder_division = nn.Linear(config.hidden_dim * 4, 1)

        if config.pointer_gen:
            self.p_gen_linear = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
            self.p_gen_encoder = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)

        # p_vocab
        self.out1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
        self.out2 = nn.Linear(config.hidden_dim, config.vocab_size)
        init_linear_wt(self.out2)

    def forward(self, y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list, enc_padding_mask_list,
                c_t_1_list, extra_zeros_list, enc_batch_extend_vocab_list, coverage_list, step):
        """Run one decoder step over the encoded sources.

        Args:
            y_t_1: previous target token ids (looked up in ``self.embedding``).
            s_t_1: previous decoder LSTM state ``(h, c)``.
            encoder_outputs_list, encoder_feature_list, enc_padding_mask_list:
                per-source encoder outputs, projected features and padding
                masks, indexed in the same order as ``coverage_list``.
            c_t_1_list: per-source contexts from the previous step. The first
                two entries feed ``encoder_division``, so at least two sources
                are required.
            extra_zeros_list: per-source tensors (or None) concatenated onto
                the vocabulary distribution. Presumably zeros reserving slots
                for in-source OOV ids.
            enc_batch_extend_vocab_list: per-source extended-vocab ids used to
                scatter copy probabilities.
            coverage_list: per-source coverage (returned unchanged by Attention
                when config.is_coverage is False).
            step: decoder step index. At step 0 in eval mode, coverage is first
                advanced with attention over the incoming state ``s_t_1``.

        Returns:
            ``(final_dist, s_t, c_t_list, attn_dist_list, p_gen, coverage_list)``.
            ``coverage_list`` is a new list. The caller's list is no longer
            modified in place.
        """
        # Copy so the caller still holds the *previous* coverage: train.py
        # computes the coverage loss from coverage_list after this call.
        coverage_list = list(coverage_list)
        attn_dist_list = []

        if not self.training and step == 0:
            # First eval step: advance coverage using the incoming state; the
            # contexts computed here are not used.
            h_decoder, c_decoder = s_t_1
            s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                                 c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
            for ind in range(len(coverage_list)):
                _, _, coverage_list[ind] = self.attention_network(
                    s_t_hat, encoder_outputs_list[ind], encoder_feature_list[ind],
                    enc_padding_mask_list[ind], coverage_list[ind])

        bias = _encoder_bias_from_argv()

        # Mix the previous contexts: source 0 weighted by p_dec, each further
        # source by (1 - p_dec).
        y_t_1_embd = self.embedding(y_t_1)
        p_dec = torch.sigmoid(self.encoder_division(torch.cat((c_t_1_list[0], c_t_1_list[1]), -1)))
        if bias is not None:
            p_dec = (p_dec + bias) / 101
        c_t_1 = p_dec * c_t_1_list[0]
        for c_prev in c_t_1_list[1:]:
            c_t_1 = c_t_1 + (1 - p_dec) * c_prev

        x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
        lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)

        h_decoder, c_decoder = s_t
        s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                             c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
        c_t_list = []
        for ind in range(len(coverage_list)):
            c_t, attn_dist, coverage_next = self.attention_network(
                s_t_hat, encoder_outputs_list[ind], encoder_feature_list[ind],
                enc_padding_mask_list[ind], coverage_list[ind])
            c_t_list.append(c_t)
            attn_dist_list.append(attn_dist)
            if self.training or step > 0:
                coverage_list[ind] = coverage_next

        # Sum of the per-source contexts, built out of place. The previous
        # `c_t += ...` aliased c_t_list[0], corrupting the returned context and
        # double-counting later sources in the output layer. Checkpoints
        # trained with that behaviour may decode differently.
        c_t = c_t_list[0]
        for c_src in c_t_list[1:]:
            c_t = c_t + c_src

        p_gen = None
        p_enc = None
        if config.pointer_gen:
            p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
            p_gen = torch.sigmoid(self.p_gen_linear(p_gen_input))
            # p_enc splits the copy probability between source 0 and the rest.
            p_enc = torch.sigmoid(self.p_gen_encoder(p_gen_input))
            if bias is not None:
                p_enc = (p_enc + bias) / 101

        output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1)  # B x hidden_dim * 3
        output = self.out1(output)  # B x hidden_dim
        # NOTE: no nonlinearity between out1 and out2 (a ReLU here was commented out).
        output = self.out2(output)  # B x vocab_size
        vocab_dist = F.softmax(output, dim=1)

        if config.pointer_gen:
            vocab_dist_ = p_gen * vocab_dist
            attn_dist_list_ = []
            for ind, attn_dist in enumerate(attn_dist_list):
                source_gate = p_enc if ind == 0 else (1 - p_enc)
                attn_dist_list_.append(source_gate * (1 - p_gen) * attn_dist)

            for extra_zeros in extra_zeros_list:
                if extra_zeros is not None:
                    vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)

            for enc_batch_extend_vocab, attn_dist_ in zip(enc_batch_extend_vocab_list, attn_dist_list_):
                vocab_dist_ = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
            final_dist = vocab_dist_
        else:
            final_dist = vocab_dist

        return final_dist, s_t, c_t_list, attn_dist_list, p_gen, coverage_list


class Model(object):
    """Container wiring Encoder, Decoder and ReduceState, optionally loading a checkpoint."""

    def __init__(self, model_file_path=None, is_eval=False):
        encoder = Encoder()
        decoder = Decoder()
        reduce_state = ReduceState()

        # shared the embedding between encoder and decoder
        decoder.embedding.weight = encoder.embedding.weight
        if is_eval:
            encoder = encoder.eval()
            decoder = decoder.eval()
            reduce_state = reduce_state.eval()

        if use_cuda:
            encoder = encoder.cuda()
            decoder = decoder.cuda()
            reduce_state = reduce_state.cuda()

        self.encoder = encoder
        self.decoder = decoder
        self.reduce_state = reduce_state

        if model_file_path is not None:
            # torch.load unpickles the file: only load trusted checkpoints.
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            # strict=False tolerates missing/extra decoder keys. Presumably this
            # allows loading checkpoints that predate the extra gating layers;
            # TODO confirm.
            self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
            self.reduce_state.load_state_dict(state['reduce_state_dict'])

# --------------------------------------------------------------------------------
# /data_util/batcher.py:
# --------------------------------------------------------------------------------
# Most of this file is copied from
# https://github.com/abisee/pointer-generator/blob/master/batcher.py

import Queue
import time
from random import shuffle
from threading import Thread

import numpy as np
import tensorflow as tf

import config
import data

import random
random.seed(1234)


class Example2(object):
    """One example with two encoder sources and a single target summary.

    ``article`` and ``article_extra`` are whitespace-tokenised, truncated to
    ``config.max_enc_steps`` tokens and mapped to vocabulary ids. The target
    is built from ``abstract_sentences`` joined with spaces.
    """

    def __init__(self, article, article_extra, abstract_sentences, vocab):
        # Get ids of special tokens
        start_decoding = vocab.word2id(data.START_DECODING)
        stop_decoding = vocab.word2id(data.STOP_DECODING)

        # First source. enc_len is the length after truncation, before padding.
        article_words = article.split()[:config.max_enc_steps]
        self.enc_len = len(article_words)
        self.enc_input = [vocab.word2id(w) for w in article_words]  # OOVs map to the UNK id

        # Second source, processed the same way.
        article_extra_words = article_extra.split()[:config.max_enc_steps]
        self.enc_len_2 = len(article_extra_words)
        self.enc_input_2 = [vocab.word2id(w) for w in article_extra_words]

        # Process the abstract
        abstract = ' '.join(abstract_sentences)  # string
        abstract_words = abstract.split()  # list of strings
        abs_ids = [vocab.word2id(w) for w in abstract_words]  # OOVs map to the UNK id

        # Decoder input gets a leading START id; target gets a trailing STOP id unless truncated.
        self.dec_input, self.target = self.get_dec_inp_targ_seqs(
            abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
        self.dec_len = len(self.dec_input)

        # If using pointer-generator mode, we need to store some extra info
        if config.pointer_gen:
            # Versions of the inputs where in-source OOVs get temporary ids,
            # plus the in-source OOV words themselves.
            self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
            extend_ids_2, self.article_oovs_2 = data.article2ids(article_extra_words, vocab)
            combined_oovs = self.article_oovs + self.article_oovs_2

            # Each article2ids call numbers its OOVs independently (presumably
            # vocab.size() + index in its own OOV list; TODO confirm against
            # data.py). Without remapping, a source-2 OOV would share an
            # extended id with a different source-1 OOV, while the target below
            # is numbered against combined_oovs. Remap source-2 ids to the first
            # position of the word in combined_oovs: a word that is also a
            # source-1 OOV reuses that slot, the rest follow len(article_oovs).
            oov_base = vocab.size()
            self.enc_input_extend_vocab_2 = [
                i if i < oov_base
                else oov_base + combined_oovs.index(self.article_oovs_2[i - oov_base])
                for i in extend_ids_2
            ]

            # Reference summary where in-source OOVs use their temporary ids
            abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, combined_oovs)

            # Overwrite decoder target sequence so it uses the temp OOV ids
            _, self.target = self.get_dec_inp_targ_seqs(
                abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)

        # Store the original strings
        self.original_article = article
        self.original_abstract = abstract
        self.original_abstract_sents = abstract_sentences

    def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
        """Return ``(inp, target)``: ``[start_id] + sequence`` and ``sequence + [stop_id]``.

        Both are cut to ``max_len``. When truncation happens, the target has
        no stop id.
        """
        inp = [start_id] + sequence[:]
        target = sequence[:]
        if len(inp) > max_len:  # truncate
            inp = inp[:max_len]
            target = target[:max_len]  # no end_token
        else:  # no truncation
            target.append(stop_id)  # end token
        assert len(inp) == len(target)
        return inp, target

    def pad_decoder_inp_targ(self, max_len, pad_id):
        """Right-pad dec_input and target in place to ``max_len`` with ``pad_id``."""
        self.dec_input.extend([pad_id] * (max_len - len(self.dec_input)))
        self.target.extend([pad_id] * (max_len - len(self.target)))

    def pad_encoder_input(self, max_len, pad_id):
        """Right-pad the first source's id lists in place to ``max_len``."""
        self.enc_input.extend([pad_id] * (max_len - len(self.enc_input)))
        if config.pointer_gen:
            self.enc_input_extend_vocab.extend(
                [pad_id] * (max_len - len(self.enc_input_extend_vocab)))

    def pad_encoder_input_2(self, max_len, pad_id):
        """Right-pad the second source's id lists in place to ``max_len``."""
        self.enc_input_2.extend([pad_id] * (max_len - len(self.enc_input_2)))
        if config.pointer_gen:
            self.enc_input_extend_vocab_2.extend(
                [pad_id] * (max_len - len(self.enc_input_extend_vocab_2)))
class Example(object):
    """Single-source example in the original pointer-generator format.

    NOTE(review): Batcher in this file only builds Example2. This class also
    lacks the ``enc_len_2`` / ``pad_encoder_input_2`` members that
    Batch.init_encoder_seq uses, so Batch cannot batch it as written.
    """

    def __init__(self, article, abstract_sentences, vocab):
        # Get ids of special tokens
        start_decoding = vocab.word2id(data.START_DECODING)
        stop_decoding = vocab.word2id(data.STOP_DECODING)

        # Process the article. enc_len is the length after truncation, before padding.
        article_words = article.split()[:config.max_enc_steps]
        self.enc_len = len(article_words)
        self.enc_input = [vocab.word2id(w) for w in article_words]  # OOVs map to the UNK id

        # Process the abstract
        abstract = ' '.join(abstract_sentences)  # string
        abstract_words = abstract.split()  # list of strings
        abs_ids = [vocab.word2id(w) for w in abstract_words]  # OOVs map to the UNK id

        # Get the decoder input sequence and target sequence
        self.dec_input, self.target = self.get_dec_inp_targ_seqs(
            abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
        self.dec_len = len(self.dec_input)

        # If using pointer-generator mode, we need to store some extra info
        if config.pointer_gen:
            # enc_input with in-article OOVs mapped to temporary ids, plus the OOV words
            self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)

            # Reference summary where in-article OOVs use their temporary ids
            abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)

            # Overwrite decoder target sequence so it uses the temp article OOV ids
            _, self.target = self.get_dec_inp_targ_seqs(
                abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)

        # Store the original strings
        self.original_article = article
        self.original_abstract = abstract
        self.original_abstract_sents = abstract_sentences

    def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
        """Return ``(inp, target)``: ``[start_id] + sequence`` and ``sequence + [stop_id]``.

        Both are cut to ``max_len``. When truncation happens, the target has
        no stop id.
        """
        inp = [start_id] + sequence[:]
        target = sequence[:]
        if len(inp) > max_len:  # truncate
            inp = inp[:max_len]
            target = target[:max_len]  # no end_token
        else:  # no truncation
            target.append(stop_id)  # end token
        assert len(inp) == len(target)
        return inp, target

    def pad_decoder_inp_targ(self, max_len, pad_id):
        """Right-pad dec_input and target in place to ``max_len`` with ``pad_id``."""
        self.dec_input.extend([pad_id] * (max_len - len(self.dec_input)))
        self.target.extend([pad_id] * (max_len - len(self.target)))

    def pad_encoder_input(self, max_len, pad_id):
        """Right-pad the encoder id lists in place to ``max_len``."""
        self.enc_input.extend([pad_id] * (max_len - len(self.enc_input)))
        if config.pointer_gen:
            self.enc_input_extend_vocab.extend(
                [pad_id] * (max_len - len(self.enc_input_extend_vocab)))


class Batch(object):
    """Numpy arrays for a list of two-source examples, padded to common lengths."""

    def __init__(self, example_list, vocab, batch_size):
        self.batch_size = batch_size
        self.pad_id = vocab.word2id(data.PAD_TOKEN)  # id of the PAD token used to pad sequences
        self.init_encoder_seq(example_list)  # initialize the input to the encoder
        self.init_decoder_seq(example_list)  # initialize the input and targets for the decoder
        self.store_orig_strings(example_list)  # store the original strings

    def init_encoder_seq(self, example_list):
        """Pad both sources per batch and build id, length and mask arrays.

        Rows beyond ``len(example_list)`` (when it is shorter than
        batch_size) stay zero.
        """
        # Each source is padded to its own longest sequence in this batch.
        max_enc_seq_len = max(ex.enc_len for ex in example_list)
        max_enc_seq_len_2 = max(ex.enc_len_2 for ex in example_list)

        for ex in example_list:
            ex.pad_encoder_input(max_enc_seq_len, self.pad_id)
            ex.pad_encoder_input_2(max_enc_seq_len_2, self.pad_id)

        # The second dimension varies between batches (dynamic encoder lengths).
        self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
        self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
        self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)

        self.enc_batch_2 = np.zeros((self.batch_size, max_enc_seq_len_2), dtype=np.int32)
        self.enc_lens_2 = np.zeros((self.batch_size), dtype=np.int32)
        self.enc_padding_mask_2 = np.zeros((self.batch_size, max_enc_seq_len_2), dtype=np.float32)

        for i, ex in enumerate(example_list):
            self.enc_batch[i, :] = ex.enc_input[:]
            self.enc_lens[i] = ex.enc_len
            self.enc_padding_mask[i, :ex.enc_len] = 1  # 1 on real tokens, 0 on padding
            self.enc_batch_2[i, :] = ex.enc_input_2[:]
            self.enc_lens_2[i] = ex.enc_len_2
            self.enc_padding_mask_2[i, :ex.enc_len_2] = 1

        # For pointer-generator mode, need to store some extra info
        if config.pointer_gen:
            # Max number of in-source OOVs per source in this batch
            self.max_art_oovs = max(len(ex.article_oovs) for ex in example_list)
            self.max_art_oovs_2 = max(len(ex.article_oovs_2) for ex in example_list)
            # Store the in-source OOVs themselves
            self.art_oovs = [ex.article_oovs for ex in example_list]
            self.art_oovs_2 = [ex.article_oovs_2 for ex in example_list]
            # Versions of the encoder batches that use the temporary OOV ids
            self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
            self.enc_batch_extend_vocab_2 = np.zeros((self.batch_size, max_enc_seq_len_2), dtype=np.int32)
            for i, ex in enumerate(example_list):
                self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]
                self.enc_batch_extend_vocab_2[i, :] = ex.enc_input_extend_vocab_2[:]

    def init_decoder_seq(self, example_list):
        """Pad decoder inputs/targets to config.max_dec_steps and build arrays."""
        for ex in example_list:
            ex.pad_decoder_inp_targ(config.max_dec_steps, self.pad_id)

        self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
        self.target_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
        self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
        self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)

        for i, ex in enumerate(example_list):
            self.dec_batch[i, :] = ex.dec_input[:]
            self.target_batch[i, :] = ex.target[:]
            self.dec_lens[i] = ex.dec_len
            self.dec_padding_mask[i, :ex.dec_len] = 1

    def store_orig_strings(self, example_list):
        """Keep the raw first-source text and target strings for decoding/logging."""
        self.original_articles = [ex.original_article for ex in example_list]  # list of lists
        self.original_abstracts = [ex.original_abstract for ex in example_list]  # list of lists
        self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list]  # list of list of lists


class Batcher(object):
    """Background threads that read examples from disk and assemble Batches."""

    BATCH_QUEUE_MAX = 100  # max number of batches the batch_queue can hold

    def __init__(self, data_path, vocab, mode, batch_size, single_pass):
        self._data_path = data_path
        self._vocab = vocab
        self._single_pass = single_pass
        self.mode = mode
        self.batch_size = batch_size
        # Queue of Batches waiting to be used, and a queue of Examples waiting to be batched
        self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
        self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)

        if single_pass:
            self._num_example_q_threads = 1  # just one thread, so we read through the dataset just once
            self._num_batch_q_threads = 1  # just one thread to batch examples
            self._bucketing_cache_size = 1  # one batch's worth of examples before bucketing, i.e. no bucketing
            self._finished_reading = False  # set once the dataset has been fully read
        else:
            self._num_example_q_threads = 1  # 16 upstream; threads filling the example queue
            self._num_batch_q_threads = 1  # 4 upstream; threads filling the batch queue
            self._bucketing_cache_size = 1  # 100 upstream; batches-worth of examples cached before bucketing

        # Start the threads that load the queues
        self._example_q_threads = []
        for _ in xrange(self._num_example_q_threads):
            self._example_q_threads.append(Thread(target=self.fill_example_queue))
            self._example_q_threads[-1].daemon = True
            self._example_q_threads[-1].start()
        self._batch_q_threads = []
        for _ in xrange(self._num_batch_q_threads):
            self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
            self._batch_q_threads[-1].daemon = True
            self._batch_q_threads[-1].start()

        # Watcher that restarts dead threads; not used in single_pass mode
        # because those threads are meant to finish.
        if not single_pass:
            self._watch_thread = Thread(target=self.watch_threads)
            self._watch_thread.daemon = True
            self._watch_thread.start()

    def next_batch(self):
        """Return the next Batch, or None once a single-pass read has finished."""
        if self._batch_queue.qsize() == 0:
            tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
            if self._single_pass and self._finished_reading:
                tf.logging.info("Finished reading dataset in single_pass mode.")
                return None

        batch = self._batch_queue.get()  # get the next Batch
        return batch

    def fill_example_queue(self):
        """Turn (source1, source2, target) triples into Example2 objects and enqueue them.

        Every triple read is also appended, tab-separated, to ``inputs.txt``
        in the current working directory. The file is closed on every exit
        path.
        """
        input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))
        with open("inputs.txt", "a") as f:
            while True:
                try:
                    (source1, source2, target) = input_gen.next()  # read the next example from file
                except StopIteration:  # if there are no more examples:
                    tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
                    if self._single_pass:
                        tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
                        self._finished_reading = True
                        break
                    else:
                        raise Exception("single_pass mode is off but the example generator is out of data; error.")

                f.write(source1 + "\t" + source2 + "\t" + target + "\n")
                example = Example2(source1, source2, target.split(), self._vocab)
                self._example_queue.put(example)  # place the Example in the example queue.

    def fill_batch_queue(self):
        """Group queued examples into Batch objects and enqueue them (runs forever)."""
        while True:
            if self.mode == 'decode':
                # beam search decode mode: a single example repeated across the batch
                ex = self._example_queue.get()
                b = [ex for _ in xrange(self.batch_size)]
                self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
            else:
                inputs = [self._example_queue.get()
                          for _ in xrange(self.batch_size * self._bucketing_cache_size)]
                # NOTE: sorting `inputs` by enc_len (descending) is disabled here.
                # Encoder.forward packs sequences expecting descending lengths, so
                # any required sorting must happen downstream; TODO confirm.

                # Group into batches, optionally shuffle, and enqueue.
                batches = [inputs[i:i + self.batch_size]
                           for i in xrange(0, len(inputs), self.batch_size)]
                if not self._single_pass:
                    shuffle(batches)
                for b in batches:  # each b is a list of Example objects
                    self._batch_queue.put(Batch(b, self._vocab, self.batch_size))

    def watch_threads(self):
        """Every 60 s, log queue sizes and restart any dead filler thread."""
        while True:
            tf.logging.info(
                'Bucket queue size: %i, Input queue size: %i',
                self._batch_queue.qsize(), self._example_queue.qsize())

            time.sleep(60)
            for idx, t in enumerate(self._example_q_threads):
                if not t.is_alive():  # if the thread is dead
                    tf.logging.error('Found example queue thread dead. Restarting.')
                    new_t = Thread(target=self.fill_example_queue)
                    self._example_q_threads[idx] = new_t
                    new_t.daemon = True
                    new_t.start()
            for idx, t in enumerate(self._batch_q_threads):
                if not t.is_alive():  # if the thread is dead
                    tf.logging.error('Found batch queue thread dead. Restarting.')
                    new_t = Thread(target=self.fill_batch_queue)
                    self._batch_q_threads[idx] = new_t
                    new_t.daemon = True
                    new_t.start()

    def text_generator(self, example_generator):
        """Yield ``(source1, source2, target)`` strings from tf.Example records.

        Records missing one of the 'source1' / 'source2' / 'target' features,
        or with empty source1 or source2 text, are skipped.
        """
        while True:
            e = example_generator.next()  # e is a tf.Example
            try:
                article_text = e.features.feature['source1'].bytes_list.value[0]  # first source
                abstract_text = e.features.feature['target'].bytes_list.value[0]  # target summary
                abstract_text_extra = e.features.feature['source2'].bytes_list.value[0]  # second source
            except (ValueError, IndexError):
                # An absent feature yields an empty bytes_list, so value[0] raises IndexError.
                tf.logging.error('Failed to get article or abstract from example')
                continue
            if len(article_text) == 0 or len(abstract_text_extra) == 0:
                # See https://github.com/abisee/pointer-generator/issues/1. An
                # empty source would give an all-zero padding mask, and the
                # attention renormalisation would then divide 0 by 0.
                continue
            yield (article_text, abstract_text_extra, abstract_text)

# --------------------------------------------------------------------------------