├── LICENSE.md
├── README.md
├── __init__.py
├── _config.yml
├── batcher.py
├── layers.py
├── main.py
├── model.py
├── test_helper.py
├── train_test_eval.py
├── training_helper.py
└── utils.py
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------

Copyright (C) 2018 David Stephane Belemkoabga

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pointer_Generator_Summarizer Tensorflow 2.0.0 (V3)


The pointer-generator is a deep neural network built for abstractive summarization.
For more information on this model, see https://arxiv.org/pdf/1704.04368

With my collaborator Kevin Sylla, we reimplemented this model in TensorFlow for our research project. This neural net will be our baseline model.
We will run some experiments with this model and propose a new architecture based on it.

In this project, you can:
- train models
- test models
- evaluate models

This project reads files in the tfrecords format. For our experiments, we work on the CNN and DailyMail datasets.
You can download the preprocessed files from this link:
https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail

Or do the pre-processing yourself with this link:
https://github.com/abisee/cnn-dailymail


You may launch the program with the following command (have a look at the main.py script for more information about the attributes):

__python main.py \
--max_enc_len=400 \
--max_dec_len=100 \
--max_dec_steps=120 \
--min_dec_steps=30 \
--batch_size=4 \
--beam_size=4 \
--vocab_size=50000 \
--embed_size=128 \
--enc_units=256 \
--dec_units=256 \
--attn_units=512 \
--learning_rate=0.15 \
--adagrad_init_acc=0.1 \
--max_grad_norm=0.8 \
--mode="eval" \
--checkpoints_save_steps=5000 \
--max_steps=38000 \
--num_to_test=5 \
--max_num_to_eval=100 \
--vocab_path="../../Datasets/tfrecords_folder/tfrecords_folder/vocab" \
--data_dir="../../Datasets/tfrecords_folder/tfrecords_folder/val" \
--model_path="../pgn_model_dir/checkpoint/ckpt-37000" \
--checkpoint_dir="../pgn_model_dir/checkpoint" \
--test_save_dir="../pgn_model_dir/test_dir/"__
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steph1793/Pointer_Generator_Summarizer/e03adf559e0e7727e95de60d143c1d46c0dd747d/__init__.py
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-time-machine
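For reference, `_parse_function` in batcher.py (below) expects each serialized `tf.Example` to carry exactly two string features named `article` and `abstract`, and `batcher` globs for `*.tfrecords` files. A minimal sketch of how a compatible record file could be written — the file name and the two example strings are invented for illustration:

```python
import tensorflow as tf

def _bytes_feature(text):
    # Wrap a UTF-8 string in the bytes feature that _parse_function expects.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode()]))

# Hypothetical pair; real data comes from the CNN/DailyMail preprocessing.
article = "some tokenized article text ..."
abstract = "<s> some tokenized abstract sentence . </s>"

example = tf.train.Example(features=tf.train.Features(feature={
    "article": _bytes_feature(article),
    "abstract": _bytes_feature(abstract),
}))

with tf.io.TFRecordWriter("train_000.tfrecords") as writer:
    writer.write(example.SerializeToString())
```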
Last word added: %s" % (self.count, self.id2word[self.count-1])) 45 | 46 | 47 | def word_to_id(self, word): 48 | if word not in self.word2id: 49 | return self.word2id[Vocab.UNKNOWN_TOKEN] 50 | return self.word2id[word] 51 | 52 | def id_to_word(self, word_id): 53 | if word_id not in self.id2word: 54 | raise ValueError('Id not found in vocab: %d' % word_id) 55 | return self.id2word[word_id] 56 | 57 | def size(self): 58 | return self.count 59 | class Data_Helper: 60 | def article_to_ids(article_words, vocab): 61 | ids = [] 62 | oovs = [] 63 | unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN) 64 | for w in article_words: 65 | i = vocab.word_to_id(w) 66 | if i == unk_id: # If w is OOV 67 | if w not in oovs: # Add to list of OOVs 68 | oovs.append(w) 69 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV... 70 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second... 71 | else: 72 | ids.append(i) 73 | return ids, oovs 74 | 75 | 76 | def abstract_to_ids(abstract_words, vocab, article_oovs): 77 | ids = [] 78 | unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN) 79 | for w in abstract_words: 80 | i = vocab.word_to_id(w) 81 | if i == unk_id: # If w is an OOV word 82 | if w in article_oovs: # If w is an in-article OOV 83 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number 84 | ids.append(vocab_idx) 85 | else: # If w is an out-of-article OOV 86 | ids.append(unk_id) # Map to the UNK token id 87 | else: 88 | ids.append(i) 89 | return ids 90 | 91 | 92 | 93 | def output_to_words(id_list, vocab, article_oovs): 94 | words = [] 95 | for i in id_list: 96 | try: 97 | w = vocab.id_to_word(i) # might be [UNK] 98 | except ValueError as e: # w is OOV 99 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode" 100 | article_oov_idx = i - vocab.size() 101 | try: 102 | w = article_oovs[article_oov_idx] 103 | except ValueError as e: # i doesn't correspond to an article oov 104 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs))) 105 | words.append(w) 106 | return words 107 | 108 | 109 | 110 | def abstract_to_sents(abstract): 111 | """Splits abstract text from datafile into list of sentences. 112 | Args: 113 | abstract: string containing and tags for starts and ends of sentences 114 | Returns: 115 | sents: List of sentence strings (no tags)""" 116 | cur = 0 117 | sents = [] 118 | while True: 119 | try: 120 | start_p = abstract.index(Vocab.SENTENCE_START, cur) 121 | end_p = abstract.index(Vocab.SENTENCE_END, start_p + 1) 122 | cur = end_p + len(Vocab.SENTENCE_END) 123 | sents.append(abstract[start_p+len(Vocab.SENTENCE_START):end_p]) 124 | except ValueError as e: # no more sentences 125 | return sents 126 | 127 | def get_dec_inp_targ_seqs( sequence, max_len, start_id, stop_id): 128 | """Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated). 
class Data_Helper:
    def article_to_ids(article_words, vocab):
        ids = []
        oovs = []
        unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN)
        for w in article_words:
            i = vocab.word_to_id(w)
            if i == unk_id:  # If w is OOV
                if w not in oovs:  # Add to list of OOVs
                    oovs.append(w)
                oov_num = oovs.index(w)  # This is 0 for the first article OOV, 1 for the second article OOV...
                ids.append(vocab.size() + oov_num)  # This is e.g. 50000 for the first article OOV, 50001 for the second...
            else:
                ids.append(i)
        return ids, oovs


    def abstract_to_ids(abstract_words, vocab, article_oovs):
        ids = []
        unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN)
        for w in abstract_words:
            i = vocab.word_to_id(w)
            if i == unk_id:  # If w is an OOV word
                if w in article_oovs:  # If w is an in-article OOV
                    vocab_idx = vocab.size() + article_oovs.index(w)  # Map to its temporary article OOV number
                    ids.append(vocab_idx)
                else:  # If w is an out-of-article OOV
                    ids.append(unk_id)  # Map to the UNK token id
            else:
                ids.append(i)
        return ids


    def output_to_words(id_list, vocab, article_oovs):
        words = []
        for i in id_list:
            try:
                w = vocab.id_to_word(i)  # might be [UNK]
            except ValueError as e:  # w is OOV
                assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
                article_oov_idx = i - vocab.size()
                try:
                    w = article_oovs[article_oov_idx]
                except IndexError as e:  # i doesn't correspond to an article oov (list indexing raises IndexError, not ValueError)
                    raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
            words.append(w)
        return words


    def abstract_to_sents(abstract):
        """Splits abstract text from datafile into list of sentences.
        Args:
            abstract: string containing <s> and </s> tags for starts and ends of sentences
        Returns:
            sents: List of sentence strings (no tags)"""
        cur = 0
        sents = []
        while True:
            try:
                start_p = abstract.index(Vocab.SENTENCE_START, cur)
                end_p = abstract.index(Vocab.SENTENCE_END, start_p + 1)
                cur = end_p + len(Vocab.SENTENCE_END)
                sents.append(abstract[start_p + len(Vocab.SENTENCE_START):end_p])
            except ValueError as e:  # no more sentences
                return sents
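To make the extended-vocabulary bookkeeping above concrete, here is a hedged walkthrough; the words, the `id(...)` shorthand and a vocab size of 50000 are invented for illustration:

```python
# Assume "crustacean" and "molting" are not in a 50000-word vocab.
article_words = ["the", "crustacean", "was", "molting"]
ids, oovs = Data_Helper.article_to_ids(article_words, vocab)
# oovs == ["crustacean", "molting"]
# ids  == [id("the"), 50000, id("was"), 50001]   # OOVs get temporary ids after vocab.size()

# Suppose "grew" is OOV too, but never appears in the article.
abstract_words = ["the", "crustacean", "grew"]
abs_ids = Data_Helper.abstract_to_ids(abstract_words, vocab, oovs)
# abs_ids == [id("the"), 50000, unk_id]  # in-article OOVs keep their temporary id, the rest collapse to [UNK]
```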
"article" : tf.string, 220 | "abstract" : tf.string, 221 | "abstract_sents" : tf.string 222 | }, output_shapes={ 223 | "enc_len":[], 224 | "enc_input" : [None], 225 | "enc_input_extend_vocab" : [None], 226 | "article_oovs" : [None], 227 | "dec_input" : [None], 228 | "target" : [None], 229 | "dec_len" : [], 230 | "article" : [], 231 | "abstract" : [], 232 | "abstract_sents" : [None] 233 | }) 234 | dataset = dataset.padded_batch(batch_size, padded_shapes=({"enc_len":[], 235 | "enc_input" : [None], 236 | "enc_input_extend_vocab" : [None], 237 | "article_oovs" : [None], 238 | "dec_input" : [max_dec_len], 239 | "target" : [max_dec_len], 240 | "dec_len" : [], 241 | "article" : [], 242 | "abstract" : [], 243 | "abstract_sents" : [None]}), 244 | padding_values={"enc_len":-1, 245 | "enc_input" : 1, 246 | "enc_input_extend_vocab" : 1, 247 | "article_oovs" : b'', 248 | "dec_input" : 1, 249 | "target" : 1, 250 | "dec_len" : -1, 251 | "article" : b"", 252 | "abstract" : b"", 253 | "abstract_sents" : b''}, 254 | drop_remainder=True) 255 | def update(entry): 256 | return ({"enc_input" : entry["enc_input"], 257 | "extended_enc_input" : entry["enc_input_extend_vocab"], 258 | "article_oovs" : entry["article_oovs"], 259 | "enc_len" : entry["enc_len"], 260 | "article" : entry["article"], 261 | "max_oov_len" : tf.shape(entry["article_oovs"])[1] }, 262 | 263 | {"dec_input" : entry["dec_input"], 264 | "dec_target" : entry["target"], 265 | "dec_len" : entry["dec_len"], 266 | "abstract" : entry["abstract"]}) 267 | 268 | 269 | dataset = dataset.map(update) 270 | return dataset 271 | 272 | 273 | def batcher(data_path, vocab, hpm): 274 | 275 | filenames = glob.glob("{}/*.tfrecords".format(data_path)) 276 | dataset = batch_generator(example_generator, filenames, vocab, hpm["max_enc_len"], hpm["max_dec_len"], hpm["batch_size"], hpm["mode"] ) 277 | 278 | return dataset 279 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class Encoder(tf.keras.layers.Layer): 4 | def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz): 5 | super(Encoder, self).__init__() 6 | self.batch_sz = batch_sz 7 | self.enc_units = enc_units 8 | self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim) 9 | self.gru = tf.keras.layers.GRU(self.enc_units, 10 | return_sequences=True, 11 | return_state=True, 12 | recurrent_initializer='glorot_uniform') 13 | 14 | def call(self, x, hidden): 15 | x = self.embedding(x) 16 | output, state = self.gru(x, initial_state = hidden) 17 | return output, state 18 | 19 | def initialize_hidden_state(self): 20 | return tf.zeros((self.batch_sz, self.enc_units)) 21 | 22 | 23 | class BahdanauAttention(tf.keras.layers.Layer): 24 | def __init__(self, units): 25 | super(BahdanauAttention, self).__init__() 26 | self.W1 = tf.keras.layers.Dense(units) 27 | self.W2 = tf.keras.layers.Dense(units) 28 | self.V = tf.keras.layers.Dense(1) 29 | 30 | def call(self, query, values): 31 | # hidden shape == (batch_size, hidden size) 32 | # hidden_with_time_axis shape == (batch_size, 1, hidden size) 33 | # we are doing this to perform addition to calculate the score 34 | hidden_with_time_axis = tf.expand_dims(query, 1) 35 | 36 | # score shape == (batch_size, max_length, 1) 37 | # we get 1 at the last axis because we are applying score to self.V 38 | # the shape of the tensor before applying self.V is (batch_size, max_length, units) 39 | score = 
def _parse_function(example_proto):
    # Create a description of the features.
    feature_description = {
        'article': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'abstract': tf.io.FixedLenFeature([], tf.string, default_value='')
    }
    # Parse the input `tf.Example` proto using the dictionary above.
    parsed_example = tf.io.parse_single_example(example_proto, feature_description)
    return parsed_example


def example_generator(filenames, vocab, max_enc_len, max_dec_len, mode, batch_size):

    raw_dataset = tf.data.TFRecordDataset(filenames)
    parsed_dataset = raw_dataset.map(_parse_function)
    if mode == "train":
        parsed_dataset = parsed_dataset.shuffle(1000, reshuffle_each_iteration=True).repeat()

    for raw_record in parsed_dataset:

        article = raw_record["article"].numpy().decode()
        abstract = raw_record["abstract"].numpy().decode()

        start_decoding = vocab.word_to_id(vocab.START_DECODING)
        stop_decoding = vocab.word_to_id(vocab.STOP_DECODING)

        article_words = article.split()[:max_enc_len]
        enc_len = len(article_words)
        enc_input = [vocab.word_to_id(w) for w in article_words]
        enc_input_extend_vocab, article_oovs = Data_Helper.article_to_ids(article_words, vocab)

        abstract_sentences = [sent.strip() for sent in Data_Helper.abstract_to_sents(abstract)]
        abstract = ' '.join(abstract_sentences)
        abstract_words = abstract.split()
        abs_ids = [vocab.word_to_id(w) for w in abstract_words]
        abs_ids_extend_vocab = Data_Helper.abstract_to_ids(abstract_words, vocab, article_oovs)
        dec_input, target = Data_Helper.get_dec_inp_targ_seqs(abs_ids, max_dec_len, start_decoding, stop_decoding)
        _, target = Data_Helper.get_dec_inp_targ_seqs(abs_ids_extend_vocab, max_dec_len, start_decoding, stop_decoding)
        dec_len = len(dec_input)

        output = {
            "enc_len": enc_len,
            "enc_input": enc_input,
            "enc_input_extend_vocab": enc_input_extend_vocab,
            "article_oovs": article_oovs,
            "dec_input": dec_input,
            "target": target,
            "dec_len": dec_len,
            "article": article,
            "abstract": abstract,
            "abstract_sents": abstract_sentences
        }
        if mode == "test" or mode == "eval":
            for _ in range(batch_size):
                yield output
        else:
            yield output


def batch_generator(generator, filenames, vocab, max_enc_len, max_dec_len, batch_size, mode):

    dataset = tf.data.Dataset.from_generator(lambda: generator(filenames, vocab, max_enc_len, max_dec_len, mode, batch_size),
                                             output_types={
                                                 "enc_len": tf.int32,
                                                 "enc_input": tf.int32,
                                                 "enc_input_extend_vocab": tf.int32,
                                                 "article_oovs": tf.string,
                                                 "dec_input": tf.int32,
                                                 "target": tf.int32,
                                                 "dec_len": tf.int32,
                                                 "article": tf.string,
                                                 "abstract": tf.string,
                                                 "abstract_sents": tf.string
                                             }, output_shapes={
                                                 "enc_len": [],
                                                 "enc_input": [None],
                                                 "enc_input_extend_vocab": [None],
                                                 "article_oovs": [None],
                                                 "dec_input": [None],
                                                 "target": [None],
                                                 "dec_len": [],
                                                 "article": [],
                                                 "abstract": [],
                                                 "abstract_sents": [None]
                                             })
    dataset = dataset.padded_batch(batch_size, padded_shapes=({"enc_len": [],
                                                               "enc_input": [None],
                                                               "enc_input_extend_vocab": [None],
                                                               "article_oovs": [None],
                                                               "dec_input": [max_dec_len],
                                                               "target": [max_dec_len],
                                                               "dec_len": [],
                                                               "article": [],
                                                               "abstract": [],
                                                               "abstract_sents": [None]}),
                                   padding_values={"enc_len": -1,
                                                   "enc_input": 1,
                                                   "enc_input_extend_vocab": 1,
                                                   "article_oovs": b'',
                                                   "dec_input": 1,
                                                   "target": 1,
                                                   "dec_len": -1,
                                                   "article": b"",
                                                   "abstract": b"",
                                                   "abstract_sents": b''},
                                   drop_remainder=True)

    def update(entry):
        return ({"enc_input": entry["enc_input"],
                 "extended_enc_input": entry["enc_input_extend_vocab"],
                 "article_oovs": entry["article_oovs"],
                 "enc_len": entry["enc_len"],
                 "article": entry["article"],
                 "max_oov_len": tf.shape(entry["article_oovs"])[1]},

                {"dec_input": entry["dec_input"],
                 "dec_target": entry["target"],
                 "dec_len": entry["dec_len"],
                 "abstract": entry["abstract"]})

    dataset = dataset.map(update)
    return dataset


def batcher(data_path, vocab, hpm):

    filenames = glob.glob("{}/*.tfrecords".format(data_path))
    dataset = batch_generator(example_generator, filenames, vocab, hpm["max_enc_len"], hpm["max_dec_len"], hpm["batch_size"], hpm["mode"])

    return dataset
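With batcher.py complete, a hedged sketch of how the pieces assemble outside of main.py; the paths are placeholders and the hparams dict mirrors the keys `batcher` actually reads:

```python
from batcher import Vocab, batcher

hpm = {"max_enc_len": 400, "max_dec_len": 100, "batch_size": 4, "mode": "train"}
vocab = Vocab("path/to/vocab", 50000)                   # placeholder path
dataset = batcher("path/to/tfrecords_dir", vocab, hpm)  # globs path/to/tfrecords_dir/*.tfrecords

enc_data, dec_data = next(iter(dataset))
print(enc_data["enc_input"].shape)   # (4, length of the longest article in the batch)
print(dec_data["dec_target"].shape)  # (4, 100) - decoder tensors are padded to max_dec_len
print(enc_data["max_oov_len"])       # largest article OOV count in the batch
```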
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import tensorflow as tf

class Encoder(tf.keras.layers.Layer):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))


class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(query, 1)

        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying score to self.V
        # the shape of the tensor before applying self.V is (batch_size, max_length, units)
        score = self.V(tf.nn.tanh(
            self.W1(values) + self.W2(hidden_with_time_axis)))

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, tf.squeeze(attention_weights, -1)


class Decoder(tf.keras.layers.Layer):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size, activation=tf.keras.activations.softmax)


    def call(self, x, hidden, enc_output, context_vector):
        # enc_output shape == (batch_size, max_length, hidden_size)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # passing the concatenated vector to the GRU
        output, state = self.gru(x)

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # output shape == (batch_size, vocab)
        out = self.fc(output)

        return x, out, state


class Pointer(tf.keras.layers.Layer):

    def __init__(self):
        super(Pointer, self).__init__()
        self.w_s_reduce = tf.keras.layers.Dense(1)
        self.w_i_reduce = tf.keras.layers.Dense(1)
        self.w_c_reduce = tf.keras.layers.Dense(1)

    def call(self, context_vector, state, dec_inp):
        return tf.nn.sigmoid(self.w_s_reduce(state) + self.w_c_reduce(context_vector) + self.w_i_reduce(dec_inp))
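For reference, these layers implement the additive attention and generation-probability equations of the pointer-generator paper (See et al., 2017). With encoder states h_i, decoder state s_t, decoder input x_t and context vector c_t, the code above computes:

```
e_ti  = v^T tanh(W1 h_i + W2 s_t)                   # score, before the softmax
a_t   = softmax(e_t)                                # attention_weights
c_t   = sum_i a_ti h_i                              # context_vector
p_gen = sigmoid(w_c^T c_t + w_s^T s_t + w_i^T x_t)  # Pointer.call
```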
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import argparse
from train_test_eval import train, test_and_save, evaluate
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_enc_len", default=400, help="Encoder input max sequence length", type=int)
    parser.add_argument("--max_dec_len", default=100, help="Decoder input max sequence length", type=int)
    parser.add_argument("--max_dec_steps", default=120, help="Maximum number of words of the predicted abstract", type=int)
    parser.add_argument("--min_dec_steps", default=30, help="Minimum number of words of the predicted abstract", type=int)
    parser.add_argument("--batch_size", default=16, help="Batch size", type=int)
    parser.add_argument("--beam_size", default=4, help="Beam size for beam search decoding (must be equal to batch size in decode mode)", type=int)
    parser.add_argument("--vocab_size", default=50000, help="Vocabulary size", type=int)
    parser.add_argument("--embed_size", default=128, help="Word embeddings dimension", type=int)
    parser.add_argument("--enc_units", default=256, help="Encoder GRU cell units number", type=int)
    parser.add_argument("--dec_units", default=256, help="Decoder GRU cell units number", type=int)
    parser.add_argument("--attn_units", default=512, help="[context vector, decoder state, decoder input] feedforward result dimension - this result is used to compute the attention weights", type=int)
    parser.add_argument("--learning_rate", default=0.15, help="Learning rate", type=float)
    parser.add_argument("--adagrad_init_acc", default=0.1, help="Adagrad optimizer initial accumulator value. Please refer to the Adagrad optimizer API documentation on the TensorFlow site for more details.", type=float)
    parser.add_argument("--max_grad_norm", default=0.8, help="Gradient norm above which gradients must be clipped", type=float)
    parser.add_argument("--checkpoints_save_steps", default=10000, help="Save checkpoints every N steps", type=int)
    parser.add_argument("--max_steps", default=10000, help="Max number of iterations", type=int)
    parser.add_argument("--num_to_test", default=5, help="Number of examples to test", type=int)
    parser.add_argument("--max_num_to_eval", default=5, help="Max number of examples to evaluate", type=int)
    parser.add_argument("--mode", help="train, test or eval option", default="", type=str)
    parser.add_argument("--model_path", help="Path to a specific model", default="", type=str)
    parser.add_argument("--checkpoint_dir", help="Checkpoint directory", default="", type=str)
    parser.add_argument("--test_save_dir", help="Directory in which we store the decoding results", default="", type=str)
    parser.add_argument("--data_dir", help="Data folder", default="", type=str)
    parser.add_argument("--vocab_path", help="Vocab path", default="", type=str)
    parser.add_argument("--log_file", help="File in which to redirect console outputs", default="", type=str)


    args = parser.parse_args()
    params = vars(args)
    print(params)

    assert params["mode"], "mode is required: train, test or eval"
    assert params["mode"] in ["train", "test", "eval"], "The mode must be train, test or eval"
    assert os.path.exists(params["data_dir"]), "data_dir doesn't exist"
    assert os.path.isfile(params["vocab_path"]), "vocab_path doesn't exist"


    if params["mode"] == "train":
        train(params)
    elif params["mode"] == "test":
        test_and_save(params)
    elif params["mode"] == "eval":
        evaluate(params)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from utils import _calc_final_dist
from layers import Encoder, BahdanauAttention, Decoder, Pointer

class PGN(tf.keras.Model):

    def __init__(self, params):
        super(PGN, self).__init__()
        self.params = params
        self.encoder = Encoder(params["vocab_size"], params["embed_size"], params["enc_units"], params["batch_size"])
        self.attention = BahdanauAttention(params["attn_units"])
        self.decoder = Decoder(params["vocab_size"], params["embed_size"], params["dec_units"], params["batch_size"])
        self.pointer = Pointer()

    def call_encoder(self, enc_inp):
        enc_hidden = self.encoder.initialize_hidden_state()
        enc_output, enc_hidden = self.encoder(enc_inp, enc_hidden)
        return enc_hidden, enc_output

    def call(self, enc_output, dec_hidden, enc_inp, enc_extended_inp, dec_inp, batch_oov_len):

        predictions = []
        attentions = []
        p_gens = []
        context_vector, _ = self.attention(dec_hidden, enc_output)
        for t in range(dec_inp.shape[1]):
            dec_x, pred, dec_hidden = self.decoder(tf.expand_dims(dec_inp[:, t], 1), dec_hidden, enc_output, context_vector)
            context_vector, attn = self.attention(dec_hidden, enc_output)
            p_gen = self.pointer(context_vector, dec_hidden, tf.squeeze(dec_x, axis=1))

            predictions.append(pred)
            attentions.append(attn)
            p_gens.append(p_gen)
        final_dists = _calc_final_dist(enc_extended_inp, predictions, attentions, p_gens, batch_oov_len, self.params["vocab_size"], self.params["batch_size"])
        if self.params["mode"] == "train":
            return tf.stack(final_dists, 1), dec_hidden  # predictions shape = (batch_size, dec_len, vocab_size) with dec_len = 1 in pred mode
        else:
            return tf.stack(final_dists, 1), dec_hidden, context_vector, tf.stack(attentions, 1), tf.stack(p_gens, 1)
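A hedged sketch of one forward pass through PGN, mirroring how `train_step` in training_helper.py calls it; the params dict is abbreviated, and `enc_data`/`dec_data` are assumed to come from the batcher sketch above:

```python
params = {"vocab_size": 50000, "embed_size": 128, "enc_units": 256,
          "dec_units": 256, "attn_units": 512, "batch_size": 4, "mode": "train"}
model = PGN(params)

enc_hidden, enc_output = model.call_encoder(enc_data["enc_input"])
final_dists, _ = model(enc_output, enc_hidden,
                       enc_data["enc_input"], enc_data["extended_enc_input"],
                       dec_data["dec_input"], enc_data["max_oov_len"])
# final_dists shape: (batch_size, max_dec_len, vocab_size + max_oov_len)
```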
--------------------------------------------------------------------------------
/test_helper.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
from batcher import Data_Helper

def beam_decode(model, batch, vocab, params):

    def decode_onestep(batch, enc_outputs, dec_state, dec_input):
        """
        Method to decode the output step by step (used for beamSearch decoding)
        Args:
            batch : current batch (for the beam search decoding, batch_size = beam_size)
            enc_outputs : hidden outputs computed by the encoder GRU
            dec_state : decoder states from the previous time step, shape = [beam_size, hidden_size]
            dec_input : decoder input, the previously decoded beam_size-many words, shape = [beam_size, 1]
        Returns: A dictionary of the results of all the ops computations (see below for more details)
        """
        # dictionary of all the ops that will be computed
        final_dists, dec_hidden, context_vector, attentions, p_gens = model(enc_outputs, dec_state, batch[0]["enc_input"], batch[0]["extended_enc_input"], dec_input, batch[0]["max_oov_len"])
        top_k_probs, top_k_ids = tf.nn.top_k(tf.squeeze(final_dists), k=params["beam_size"]*2)
        top_k_log_probs = tf.math.log(top_k_probs)
        results = {"last_context_vector": context_vector,
                   "dec_state": dec_hidden,
                   "attention_vec": attentions,
                   "top_k_ids": top_k_ids,
                   "top_k_log_probs": top_k_log_probs,
                   "p_gen": p_gens}
        return results


    # nested class
    class Hypothesis:
        """Class designed to hold hypotheses throughout the beamSearch decoding"""
        def __init__(self, tokens, log_probs, state, attn_dists, p_gens):
            self.tokens = tokens  # list of all the tokens from time 0 to the current time step t
            self.log_probs = log_probs  # list of the log probabilities of the tokens
            self.state = state  # decoder state after the last token decoding
            self.attn_dists = attn_dists  # attention dists of all the tokens
            self.p_gens = p_gens  # generation probability of all the tokens
            self.abstract = ""
            self.text = ""
            self.real_abstract = ""

        def extend(self, token, log_prob, state, attn_dist, p_gen):
            """Method to extend the current hypothesis by adding the next decoded token and all the information associated with it"""
            return Hypothesis(tokens=self.tokens + [token],  # we add the decoded token
                              log_probs=self.log_probs + [log_prob],  # we add the log prob of the decoded token
                              state=state,  # we update the state
                              attn_dists=self.attn_dists + [attn_dist],  # we add the attention dist of the decoded token
                              p_gens=self.p_gens + [p_gen]  # we add the p_gen
                              )

        @property
        def latest_token(self):
            return self.tokens[-1]

        @property
        def tot_log_prob(self):
            return sum(self.log_probs)

        @property
        def avg_log_prob(self):
            return self.tot_log_prob / len(self.tokens)

    # end of the nested class

    # We run the encoder once and then we use the results to decode each time step token
    state, enc_outputs = model.call_encoder(batch[0]["enc_input"])

    # Initial hypotheses (beam_size-many list)
    hyps = [Hypothesis(tokens=[vocab.word_to_id('[START]')],  # we initialize all the beam_size hypotheses with the start token
                       log_probs=[0.0],  # initial log prob = 0
                       state=state[0],  # initial dec_state (we use only the first dec_state because they're initially all the same)
                       attn_dists=[],
                       p_gens=[],  # no tokens decoded yet
                       ) for _ in range(params['batch_size'])]  # batch_size == beam_size

    results = []  # list to hold the top beam_size hypotheses
    steps = 0  # initial step

    while steps < params['max_dec_steps'] and len(results) < params['beam_size']:
        latest_tokens = [h.latest_token for h in hyps]  # latest token for each hypothesis, shape : [beam_size]
        latest_tokens = [t if t in range(params['vocab_size']) else vocab.word_to_id('[UNK]') for t in latest_tokens]  # we replace all the OOV ids with the unknown token id
        states = [h.state for h in hyps]  # we collect the last states for each hypothesis

        # we decode the top likely 2 x beam_size tokens at time step t for each hypothesis
        returns = decode_onestep(batch, enc_outputs, tf.stack(states, axis=0), tf.expand_dims(latest_tokens, axis=1))
        topk_ids, topk_log_probs, new_states, attn_dists, p_gens = returns['top_k_ids'], returns['top_k_log_probs'], returns['dec_state'], returns['attention_vec'], np.squeeze(returns["p_gen"])
        all_hyps = []
        num_orig_hyps = 1 if steps == 0 else len(hyps)
        for i in range(num_orig_hyps):
            h, new_state, attn_dist, p_gen = hyps[i], new_states[i], attn_dists[i], p_gens[i]

            for j in range(params['beam_size']*2):
                # we extend each hypothesis with each of the top k tokens (this gives 2 x beam_size new hypotheses for each of the beam_size old hypotheses)
                new_hyp = h.extend(token=topk_ids[i, j].numpy(),
                                   log_prob=topk_log_probs[i, j],
                                   state=new_state,
                                   attn_dist=attn_dist,
                                   p_gen=p_gen)
                all_hyps.append(new_hyp)

        # in the following lines, we sort all the hypotheses and keep only the beam_size most likely ones
        hyps = []
        sorted_hyps = sorted(all_hyps, key=lambda h: h.avg_log_prob, reverse=True)
        for h in sorted_hyps:
            if h.latest_token == vocab.word_to_id('[STOP]'):
                if steps >= params['min_dec_steps']:
                    results.append(h)
            else:
                hyps.append(h)
            if len(hyps) == params['beam_size'] or len(results) == params['beam_size']:
                break

        steps += 1

    if len(results) == 0:
        results = hyps

    # At the end of the loop we return the most likely hypothesis, which holds the most likely output sequence, given the input fed to the model
    hyps_sorted = sorted(results, key=lambda h: h.avg_log_prob, reverse=True)
    best_hyp = hyps_sorted[0]
    best_hyp.abstract = " ".join(Data_Helper.output_to_words(best_hyp.tokens, vocab, batch[0]["article_oovs"][0])[1:-1])
    best_hyp.text = batch[0]["article"].numpy()[0].decode()
    if params["mode"] == "eval":
        best_hyp.real_abstract = batch[1]["abstract"].numpy()[0].decode()
    return best_hyp
--------------------------------------------------------------------------------
/train_test_eval.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from model import PGN
from training_helper import train_model
from test_helper import beam_decode
from batcher import batcher, Vocab, Data_Helper
from tqdm import tqdm
from rouge import Rouge
import pprint

def train(params):
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    tf.compat.v1.logging.info("Building the model ...")
    model = PGN(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    print("Creating the batcher ...")
    b = batcher(params["data_dir"], vocab, params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}".format(params["checkpoint_dir"])
    ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=11)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    tf.compat.v1.logging.info("Starting the training ...")
    train_model(model, b, params, ckpt, ckpt_manager, "output.txt")


def test(params):
    assert params["mode"].lower() in ["test", "eval"], "change training mode to 'test' or 'eval'"
    assert params["beam_size"] == params["batch_size"], "Beam size must be equal to batch_size, change the params"

    tf.compat.v1.logging.info("Building the model ...")
    model = PGN(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    print("Creating the batcher ...")
    b = batcher(params["data_dir"], vocab, params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}".format(params["checkpoint_dir"])
    ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=11)

    path = params["model_path"] if params["model_path"] else ckpt_manager.latest_checkpoint
    ckpt.restore(path)
    print("Model restored")

    for batch in b:
        yield beam_decode(model, batch, vocab, params)


def test_and_save(params):
    assert params["test_save_dir"], "provide a dir where to save the results"
    gen = test(params)
    with tqdm(total=params["num_to_test"], position=0, leave=True) as pbar:
        for i in range(params["num_to_test"]):
            trial = next(gen)
            with open(params["test_save_dir"] + "/article_" + str(i) + ".txt", "w") as f:
                f.write("article:\n")
                f.write(trial.text)
                f.write("\n\nabstract:\n")
                f.write(trial.abstract)
            pbar.update(1)

def evaluate(params):
    gen = test(params)
    reals = []
    preds = []
    with tqdm(total=params["max_num_to_eval"], position=0, leave=True) as pbar:
        for i in range(params["max_num_to_eval"]):
            trial = next(gen)
            reals.append(trial.real_abstract)
            preds.append(trial.abstract)
            pbar.update(1)
    r = Rouge()
    scores = r.get_scores(preds, reals, avg=True)
    print("\n\n")
    pprint.pprint(scores)
--------------------------------------------------------------------------------
/training_helper.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import time


def train_model(model, dataset, params, ckpt, ckpt_manager, out_file):

    optimizer = tf.keras.optimizers.Adagrad(params['learning_rate'], initial_accumulator_value=params['adagrad_init_acc'], clipnorm=params['max_grad_norm'])
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction='none')

    def loss_function(real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 1))
        dec_lens = tf.reduce_sum(tf.cast(mask, dtype=tf.float32), axis=-1)
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        loss_ = tf.reduce_sum(loss_, axis=-1) / dec_lens  # we have to make sure no empty abstract is being used, otherwise dec_lens may contain null values
        return tf.reduce_mean(loss_)

    @tf.function(input_signature=(tf.TensorSpec(shape=[params["batch_size"], None], dtype=tf.int32),
                                  tf.TensorSpec(shape=[params["batch_size"], None], dtype=tf.int32),
                                  tf.TensorSpec(shape=[params["batch_size"], params["max_dec_len"]], dtype=tf.int32),
                                  tf.TensorSpec(shape=[params["batch_size"], params["max_dec_len"]], dtype=tf.int32),
                                  tf.TensorSpec(shape=[], dtype=tf.int32)))
    def train_step(enc_inp, enc_extended_inp, dec_inp, dec_tar, batch_oov_len):
        loss = 0

        with tf.GradientTape() as tape:
            enc_hidden, enc_output = model.call_encoder(enc_inp)
            predictions, _ = model(enc_output, enc_hidden, enc_inp, enc_extended_inp, dec_inp, batch_oov_len)
            loss = loss_function(dec_tar, predictions)
        variables = model.encoder.trainable_variables + model.attention.trainable_variables + model.decoder.trainable_variables + model.pointer.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss


    try:
        f = open(out_file, "w+")
        for batch in dataset:
            t0 = time.time()
            loss = train_step(batch[0]["enc_input"], batch[0]["extended_enc_input"], batch[1]["dec_input"], batch[1]["dec_target"], batch[0]["max_oov_len"])
            print('Step {}, time {:.4f}, Loss {:.4f}'.format(int(ckpt.step),
                                                             time.time() - t0,
                                                             loss.numpy()))
            f.write('Step {}, time {:.4f}, Loss {:.4f}\n'.format(int(ckpt.step),
                                                                 time.time() - t0,
                                                                 loss.numpy()))
            if int(ckpt.step) == params["max_steps"]:
                ckpt_manager.save(checkpoint_number=int(ckpt.step))
                print("Saved checkpoint for step {}".format(int(ckpt.step)))
                f.close()
                break
            if int(ckpt.step) % params["checkpoints_save_steps"] == 0:
                ckpt_manager.save(checkpoint_number=int(ckpt.step))
                print("Saved checkpoint for step {}".format(int(ckpt.step)))
            ckpt.step.assign_add(1)
        f.close()


    except KeyboardInterrupt:
        ckpt_manager.save(int(ckpt.step))
        print("Saved checkpoint for step {}".format(int(ckpt.step)))
        f.close()
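To see what the masking in `loss_function` does, a small hedged numeric example (the PAD id is 1, matching the batcher's padding_values):

```python
import tensorflow as tf

real = tf.constant([[5, 9, 3, 1, 1]])  # one target row, padded with two PAD (id 1) tokens
mask = tf.math.logical_not(tf.math.equal(real, 1))
# mask == [[True, True, True, False, False]], so dec_lens == [3.]
# The per-token losses at the two PAD positions are zeroed out, and the row's
# summed loss is divided by 3 rather than 5, so padding never dilutes the loss.
```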
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
import logging

def define_logger(log_file):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    # get TF logger
    log = logging.getLogger('tensorflow')
    log.setLevel(logging.DEBUG)

    # create formatter and add it to the handlers
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # create file handler which logs even debug messages
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)

def _calc_final_dist(_enc_batch_extend_vocab, vocab_dists, attn_dists, p_gens, batch_oov_len, vocab_size, batch_size):
    """Calculate the final distribution, for the pointer-generator model
    Args:
        vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.
        attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays
    Returns:
        final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays.
    """
    # Multiply vocab dists by p_gen and attention dists by (1-p_gen)
    vocab_dists = [p_gen * dist for (p_gen, dist) in zip(p_gens, vocab_dists)]
    attn_dists = [(1 - p_gen) * dist for (p_gen, dist) in zip(p_gens, attn_dists)]

    # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words
    extended_vsize = vocab_size + batch_oov_len  # the maximum (over the batch) size of the extended vocabulary
    extra_zeros = tf.zeros((batch_size, batch_oov_len))
    vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists]  # list length max_dec_steps of shape (batch_size, extended_vsize)

    # Project the values in the attention distributions onto the appropriate entries in the final distributions
    # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution
    # This is done for each decoder timestep.
    # This is fiddly; we use tf.scatter_nd to do the projection
    batch_nums = tf.range(0, limit=batch_size)  # shape (batch_size)
    batch_nums = tf.expand_dims(batch_nums, 1)  # shape (batch_size, 1)
    attn_len = tf.shape(_enc_batch_extend_vocab)[1]  # number of states we attend over
    batch_nums = tf.tile(batch_nums, [1, attn_len])  # shape (batch_size, attn_len)
    indices = tf.stack((batch_nums, _enc_batch_extend_vocab), axis=2)  # shape (batch_size, enc_t, 2)
    shape = [batch_size, extended_vsize]
    attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists]  # list length max_dec_steps (batch_size, extended_vsize)

    # Add the vocab distributions and the copy distributions together to get the final distributions
    # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep
    # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore.
    final_dists = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in zip(vocab_dists_extended, attn_dists_projected)]

    return final_dists
--------------------------------------------------------------------------------
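Since the scatter_nd projection is the fiddly step of `_calc_final_dist`, a hedged toy example with batch_size=1, a 3-token article and an extended vocabulary of 4 ids:

```python
import tensorflow as tf

enc_extended = tf.constant([[2, 0, 2]])  # article token ids; word 2 appears twice
attn = tf.constant([[0.5, 0.3, 0.2]])    # one attention distribution over the article
batch_nums = tf.constant([[0, 0, 0]])    # every position belongs to batch row 0
indices = tf.stack((batch_nums, enc_extended), axis=2)  # shape (1, 3, 2)
copy_dist = tf.scatter_nd(indices, attn, [1, 4])
# copy_dist == [[0.3, 0.0, 0.7, 0.0]] -- attention mass on repeated words is summed
```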