├── LICENSE.md ├── README.md ├── batcher.py ├── build_eval_test.py ├── data_helper.py ├── layers.py ├── main.py ├── pointer_generator_run_on_colab.ipynb ├── predict_helper.py ├── training_helper.py ├── transformer.py └── utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2019 David Stephane 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pointer_Transformer_Generator tensorflow 2.0.0 2 | 3 | For the abstractive summarization task, I wanted to experiment the transformer model. I recreated a transformer model (thanks to tensorflow transformer tutorial) and added a pointer module (have a look at this paper for more informations on the pointer generator network : https://arxiv.org/abs/1704.04368 ). 4 | 5 | PS : I will add very soon a section explaining the integration of the pointer module in the transformer 6 | 7 | Please follow the next steps to launch the project : 8 | 9 | ## Step 1 : The data 10 | 11 | ### Option 1 : Download the data 12 | Download the data (chunk files format : tfrecords) 13 | https://drive.google.com/open?id=1uHrMWd7Pbs_-DCl0eeMxePbxgmSce5LO 14 | 15 | ### Option 2 : Download raw data and process it 16 | Use this project : 17 | https://github.com/steph1793/CNN-DailyMail-Bin-To-TFRecords 18 | 19 | ## Step 2 : launch the project : 20 | 21 | **python main.py --max_enc_len=400 \
22 | --max_dec_len=100 \
23 | --batch_size=16 \
24 | --vocab_size=50000 \
25 | --num_layers=3 \
26 | --model_depth=512 \
27 | --num_heads=8 \
28 | --dff=2048 \
29 | --seed=123 \
30 | --log_step_count_steps=1 \
31 | --max_steps=230000 \
32 | --mode=train \
33 | --save_summary_steps=10000 \
34 | --checkpoints_save_steps=10000 \
35 | --model_dir=model_folder \
36 | --data_dir=data_folder \
37 | --vocab_path=vocab \
** 38 | 39 | PS : Feel free to change any of the hyperparameters; run<br>
40 | python main.py --help , for more details on the hyperparameters 41 | 42 | 43 | 44 | ## Requirements 45 | - python >= 3.6 46 | - tensorflow 2.0.0 47 | - argparse 48 | - os 49 | - glob 50 | - numpy 51 | 52 | -------------------------------------------------------------------------------- /batcher.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import glob 3 | 4 | from data_helper import Vocab, Data_Helper 5 | 6 | 7 | def _parse_function(example_proto): 8 | # Create a description of the features. 9 | feature_description = { 10 | 'article': tf.io.FixedLenFeature([], tf.string, default_value=''), 11 | 'abstract': tf.io.FixedLenFeature([], tf.string, default_value='') 12 | } 13 | # Parse the input `tf.Example` proto using the dictionary above. 14 | parsed_example = tf.io.parse_single_example(example_proto, feature_description) 15 | 16 | return parsed_example 17 | 18 | 19 | 20 | def example_generator(filenames, vocab_path, vocab_size, max_enc_len, max_dec_len, training=False): 21 | 22 | raw_dataset = tf.data.TFRecordDataset(filenames) 23 | parsed_dataset = raw_dataset.map(_parse_function) 24 | if training: 25 | parsed_dataset = parsed_dataset.shuffle(1000, reshuffle_each_iteration=True).repeat() 26 | 27 | 28 | vocab = Vocab(vocab_path, vocab_size) 29 | 30 | for raw_record in parsed_dataset: 31 | 32 | article = raw_record["article"].numpy().decode() 33 | abstract = raw_record["abstract"].numpy().decode() 34 | 35 | start_decoding = vocab.word_to_id(vocab.START_DECODING) 36 | stop_decoding = vocab.word_to_id(vocab.STOP_DECODING) 37 | 38 | article_words = article.split()[ : max_enc_len] 39 | enc_len = len(article_words) 40 | enc_input = [vocab.word_to_id(w) for w in article_words] 41 | enc_input_extend_vocab, article_oovs = Data_Helper.article_to_ids(article_words, vocab) 42 | 43 | abstract_sentences = [sent.strip() for sent in Data_Helper.abstract_to_sents(abstract)] 44 | abstract = ' '.join(abstract_sentences) 45 | abstract_words = abstract.split() 46 | abs_ids = [vocab.word_to_id(w) for w in abstract_words] 47 | abs_ids_extend_vocab = Data_Helper.abstract_to_ids(abstract_words, vocab, article_oovs) 48 | dec_input, target = Data_Helper.get_dec_inp_targ_seqs(abs_ids, max_dec_len, start_decoding, stop_decoding) 49 | _, target = Data_Helper.get_dec_inp_targ_seqs(abs_ids_extend_vocab, max_dec_len, start_decoding, stop_decoding) 50 | dec_len = len(dec_input) 51 | 52 | output = { 53 | "enc_len":enc_len, 54 | "enc_input" : enc_input, 55 | "enc_input_extend_vocab" : enc_input_extend_vocab, 56 | "article_oovs" : article_oovs, 57 | "dec_input" : dec_input, 58 | "target" : target, 59 | "dec_len" : dec_len, 60 | "article" : article, 61 | "abstract" : abstract, 62 | "abstract_sents" : abstract_sentences 63 | } 64 | 65 | 66 | yield output 67 | 68 | 69 | def batch_generator(generator, filenames, vocab_path, vocab_size, max_enc_len, max_dec_len, batch_size, training): 70 | 71 | dataset = tf.data.Dataset.from_generator(generator, args = [filenames, vocab_path, vocab_size, max_enc_len, max_dec_len, training], 72 | output_types = { 73 | "enc_len":tf.int32, 74 | "enc_input" : tf.int32, 75 | "enc_input_extend_vocab" : tf.int32, 76 | "article_oovs" : tf.string, 77 | "dec_input" : tf.int32, 78 | "target" : tf.int32, 79 | "dec_len" : tf.int32, 80 | "article" : tf.string, 81 | "abstract" : tf.string, 82 | "abstract_sents" : tf.string 83 | }, output_shapes={ 84 | "enc_len":[], 85 | "enc_input" : [None], 86 | "enc_input_extend_vocab" : [None], 87 | 
"article_oovs" : [None], 88 | "dec_input" : [None], 89 | "target" : [None], 90 | "dec_len" : [], 91 | "article" : [], 92 | "abstract" : [], 93 | "abstract_sents" : [None] 94 | }) 95 | dataset = dataset.padded_batch(batch_size, padded_shapes=({"enc_len":[], 96 | "enc_input" : [None], 97 | "enc_input_extend_vocab" : [None], 98 | "article_oovs" : [None], 99 | "dec_input" : [max_dec_len], 100 | "target" : [max_dec_len], 101 | "dec_len" : [], 102 | "article" : [], 103 | "abstract" : [], 104 | "abstract_sents" : [None]}), 105 | padding_values={"enc_len":-1, 106 | "enc_input" : 1, 107 | "enc_input_extend_vocab" : 1, 108 | "article_oovs" : b'', 109 | "dec_input" : 1, 110 | "target" : 1, 111 | "dec_len" : -1, 112 | "article" : b"", 113 | "abstract" : b"", 114 | "abstract_sents" : b''}, 115 | drop_remainder=True) 116 | def update(entry): 117 | return ({"enc_input" : entry["enc_input"], 118 | "extended_enc_input" : entry["enc_input_extend_vocab"], 119 | "article_oovs" : entry["article_oovs"], 120 | "enc_len" : entry["enc_len"], 121 | "article" : entry["article"], 122 | "max_oov_len" : tf.shape(entry["article_oovs"])[1] }, 123 | 124 | {"dec_input" : entry["dec_input"], 125 | "dec_target" : entry["target"], 126 | "dec_len" : entry["dec_len"], 127 | "abstract" : entry["abstract"]}) 128 | 129 | 130 | dataset = dataset.map(update) 131 | 132 | return dataset 133 | 134 | 135 | def batcher(data_path, vocab_path, hpm): 136 | 137 | filenames = glob.glob("{}/*.tfrecords".format(data_path)) 138 | dataset = batch_generator(example_generator, filenames, vocab_path, hpm["vocab_size"], hpm["max_enc_len"], hpm["max_dec_len"], hpm["batch_size"], hpm["training"] ) 139 | 140 | return dataset 141 | -------------------------------------------------------------------------------- /build_eval_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.training import training_util 3 | from training_helper import train_model 4 | from predict_helper import predict 5 | from batcher import batcher 6 | from transformer import Transformer 7 | import os 8 | 9 | 10 | 11 | def my_model(features, labels, mode, params): 12 | 13 | predictions, attn_weights = predict(features, params, transformer) 14 | estimator_spec = tf.estimator.EstimatorSpec(mode, predictions={"predictions":predictions}) 15 | 16 | print(transformer.summary()) 17 | return estimator_spec 18 | 19 | 20 | def build_model(params): 21 | 22 | config = tf.estimator.RunConfig( 23 | tf_random_seed=params["seed"], 24 | log_step_count_steps=params["log_step_count_steps"], 25 | save_summary_steps=params["save_summary_steps"] 26 | ) 27 | 28 | return tf.estimator.Estimator( 29 | model_fn=my_model, 30 | params=params, config=config, model_dir=params["model_dir"] ) 31 | 32 | 33 | def train(params): 34 | assert params["training"], "change training mode to true" 35 | 36 | tf.compat.v1.logging.info("Building the model ...") 37 | transformer = Transformer( 38 | num_layers=params["num_layers"], d_model=params["model_depth"], num_heads=params["num_heads"], dff=params["dff"], 39 | vocab_size=params["vocab_size"], batch_size=params["batch_size"]) 40 | 41 | 42 | tf.compat.v1.logging.info("Creating the batcher ...") 43 | b = batcher(params["data_dir"], params["vocab_path"], params) 44 | 45 | tf.compat.v1.logging.info("Creating the checkpoint manager") 46 | logdir = "{}/logdir".format(params["model_dir"]) 47 | checkpoint_dir = "{}/checkpoint".format(params["model_dir"]) 48 | ckpt = 
tf.train.Checkpoint(step=tf.Variable(0), transformer=transformer) 49 | ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=11) 50 | 51 | ckpt.restore(ckpt_manager.latest_checkpoint) 52 | if ckpt_manager.latest_checkpoint: 53 | print("Restored from {}".format(ckpt_manager.latest_checkpoint)) 54 | else: 55 | print("Initializing from scratch.") 56 | 57 | tf.compat.v1.logging.info("Starting the training ...") 58 | train_model(transformer, b, params, ckpt, ckpt_manager) 59 | 60 | 61 | 62 | def eval(model, params): 63 | pass 64 | 65 | 66 | def test(model, params): 67 | assert not params["training"], "change training mode to false" 68 | checkpoint_dir = "{}/checkpoint".format(params["model_dir"]) 69 | logdir = "{}/logdir".format(params["model_dir"]) 70 | 71 | pred = model.predict(input_fn = lambda : batcher(params["data_dir"], params["vocab_path"], params), 72 | yield_single_examples=False) 73 | 74 | yield next(pred) -------------------------------------------------------------------------------- /data_helper.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Vocab: 4 | 5 | SENTENCE_START = '<s>' 6 | SENTENCE_END = '</s>' 7 | 8 | PAD_TOKEN = '[PAD]' 9 | UNKNOWN_TOKEN = '[UNK]' 10 | START_DECODING = '[START]' 11 | STOP_DECODING = '[STOP]' 12 | 13 | def __init__(self, vocab_file, max_size): 14 | 15 | self.word2id = {Vocab.UNKNOWN_TOKEN : 0, Vocab.PAD_TOKEN : 1, Vocab.START_DECODING : 2, Vocab.STOP_DECODING : 3} 16 | self.id2word = {0 : Vocab.UNKNOWN_TOKEN, 1 : Vocab.PAD_TOKEN, 2 : Vocab.START_DECODING, 3 : Vocab.STOP_DECODING} 17 | self.count = 4 18 | 19 | with open(vocab_file, 'r') as f: 20 | for line in f: 21 | pieces = line.split() 22 | if len(pieces) != 2 : 23 | print('Warning : incorrectly formatted line in vocabulary file : %s\n' % line) 24 | continue 25 | 26 | w = pieces[0] 27 | if w in [Vocab.SENTENCE_START, Vocab.SENTENCE_END, Vocab.UNKNOWN_TOKEN, Vocab.PAD_TOKEN, Vocab.START_DECODING, Vocab.STOP_DECODING]: 28 | raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w) 29 | 30 | if w in self.word2id: 31 | raise Exception('Duplicated word in vocabulary file: %s' % w) 32 | 33 | self.word2id[w] = self.count 34 | self.id2word[self.count] = w 35 | self.count += 1 36 | if max_size != 0 and self.count >= max_size: 37 | print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self.count)) 38 | break 39 | 40 | print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self.count, self.id2word[self.count-1])) 41 | 42 | 43 | def word_to_id(self, word): 44 | if word not in self.word2id: 45 | return self.word2id[Vocab.UNKNOWN_TOKEN] 46 | return self.word2id[word] 47 | 48 | def id_to_word(self, word_id): 49 | if word_id not in self.id2word: 50 | raise ValueError('Id not found in vocab: %d' % word_id) 51 | return self.id2word[word_id] 52 | 53 | def size(self): 54 | return self.count 55 | 56 | 57 | class Data_Helper: 58 | def article_to_ids(article_words, vocab): 59 | ids = [] 60 | oovs = [] 61 | unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN) 62 | for w in article_words: 63 | i = vocab.word_to_id(w) 64 | if i == unk_id: # If w is OOV 65 | if w not in oovs: # Add to list of OOVs 66 | oovs.append(w) 67 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV... 68 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
69 | else: 70 | ids.append(i) 71 | return ids, oovs 72 | 73 | 74 | def abstract_to_ids(abstract_words, vocab, article_oovs): 75 | ids = [] 76 | unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN) 77 | for w in abstract_words: 78 | i = vocab.word_to_id(w) 79 | if i == unk_id: # If w is an OOV word 80 | if w in article_oovs: # If w is an in-article OOV 81 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number 82 | ids.append(vocab_idx) 83 | else: # If w is an out-of-article OOV 84 | ids.append(unk_id) # Map to the UNK token id 85 | else: 86 | ids.append(i) 87 | return ids 88 | 89 | 90 | 91 | def output_to_words(id_list, vocab, article_oovs): 92 | words = [] 93 | for i in id_list: 94 | try: 95 | w = vocab.id_to_word(i) # might be [UNK] 96 | except ValueError as e: # w is OOV 97 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode" 98 | article_oov_idx = i - vocab.size() 99 | try: 100 | w = article_oovs[article_oov_idx] 101 | except IndexError as e: # i doesn't correspond to an article oov 102 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs))) 103 | words.append(w) 104 | return words 105 | 106 | 107 | 108 | def abstract_to_sents(abstract): 109 | """Splits abstract text from datafile into list of sentences. 110 | Args: 111 | abstract: string containing <s> and </s> tags for starts and ends of sentences 112 | Returns: 113 | sents: List of sentence strings (no tags) 114 | """ 115 | cur = 0 116 | sents = [] 117 | while True: 118 | try: 119 | start_p = abstract.index(Vocab.SENTENCE_START, cur) 120 | end_p = abstract.index(Vocab.SENTENCE_END, start_p + 1) 121 | cur = end_p + len(Vocab.SENTENCE_END) 122 | sents.append(abstract[start_p+len(Vocab.SENTENCE_START):end_p]) 123 | except ValueError as e: # no more sentences 124 | return sents 125 | 126 | def get_dec_inp_targ_seqs( sequence, max_len, start_id, stop_id): 127 | """ 128 | Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated).
129 | Args: 130 | sequence: List of ids (integers) 131 | max_len: integer 132 | start_id: integer 133 | stop_id: integer 134 | Returns: 135 | inp: sequence length <=max_len starting with start_id 136 | target: sequence same length as input, ending with stop_id only if there was no truncation 137 | """ 138 | inp = [start_id] + sequence[:] 139 | target = sequence[:] 140 | if len(inp) > max_len: # truncate 141 | inp = inp[:max_len] 142 | target = target[:max_len] # no end_token 143 | else: # no truncation 144 | target.append(stop_id) # end token 145 | assert len(inp) == len(target) 146 | return inp, target 147 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import positional_encoding, scaled_dot_product_attention 3 | 4 | class Embedding(tf.keras.layers.Layer): 5 | 6 | def __init__(self, vocab_size, d_model): 7 | super(Embedding, self).__init__() 8 | self.vocab_size = vocab_size 9 | self.d_model = d_model 10 | 11 | self.embedding = tf.keras.layers.Embedding(vocab_size, d_model) 12 | self.pos_encoding = positional_encoding(vocab_size, d_model) 13 | 14 | def call(self, x): 15 | embed_x = self.embedding(x) # (batch_size, target_seq_len, d_model) 16 | embed_x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) 17 | embed_x += self.pos_encoding[:, :tf.shape(x)[1], :] 18 | return embed_x 19 | 20 | 21 | 22 | class MultiHeadAttention(tf.keras.layers.Layer): 23 | def __init__(self, d_model, num_heads): 24 | super(MultiHeadAttention, self).__init__() 25 | self.num_heads = num_heads 26 | self.d_model = d_model 27 | 28 | assert d_model % self.num_heads == 0 29 | 30 | self.depth = d_model // self.num_heads 31 | 32 | self.wq = tf.keras.layers.Dense(d_model) 33 | self.wk = tf.keras.layers.Dense(d_model) 34 | self.wv = tf.keras.layers.Dense(d_model) 35 | 36 | self.dense = tf.keras.layers.Dense(d_model) 37 | 38 | def split_heads(self, x, batch_size): 39 | """Split the last dimension into (num_heads, depth). 
40 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) 41 | """ 42 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 43 | return tf.transpose(x, perm=[0, 2, 1, 3]) 44 | 45 | def call(self, v, k, q, mask): 46 | batch_size = tf.shape(q)[0] 47 | 48 | q = self.wq(q) # (batch_size, seq_len, d_model) 49 | k = self.wk(k) # (batch_size, seq_len, d_model) 50 | v = self.wv(v) # (batch_size, seq_len, d_model) 51 | 52 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 53 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 54 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 55 | 56 | # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) 57 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 58 | scaled_attention, attention_weights = scaled_dot_product_attention( 59 | q, k, v, mask) 60 | 61 | scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) 62 | 63 | concat_attention = tf.reshape(scaled_attention, 64 | (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) 65 | 66 | output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model) 67 | 68 | return output, attention_weights 69 | 70 | 71 | def point_wise_feed_forward_network(d_model, dff): 72 | return tf.keras.Sequential([ 73 | tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff) 74 | tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model) 75 | ]) 76 | 77 | 78 | class EncoderLayer(tf.keras.layers.Layer): 79 | def __init__(self, d_model, num_heads, dff, rate=0.1): 80 | super(EncoderLayer, self).__init__() 81 | 82 | self.mha = MultiHeadAttention(d_model, num_heads) 83 | self.ffn = point_wise_feed_forward_network(d_model, dff) 84 | 85 | self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 86 | self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 87 | 88 | self.dropout1 = tf.keras.layers.Dropout(rate) 89 | self.dropout2 = tf.keras.layers.Dropout(rate) 90 | 91 | def call(self, x, training, mask): 92 | 93 | attn_output, _ = self.mha(x, x, x, mask) # (batch_size, input_seq_len, d_model) 94 | attn_output = self.dropout1(attn_output, training=training) 95 | out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model) 96 | 97 | ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model) 98 | ffn_output = self.dropout2(ffn_output, training=training) 99 | out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model) 100 | 101 | return out2 102 | 103 | 104 | class DecoderLayer(tf.keras.layers.Layer): 105 | def __init__(self, d_model, num_heads, dff, rate=0.1): 106 | super(DecoderLayer, self).__init__() 107 | 108 | self.mha1 = MultiHeadAttention(d_model, num_heads) 109 | self.mha2 = MultiHeadAttention(d_model, num_heads) 110 | 111 | self.ffn = point_wise_feed_forward_network(d_model, dff) 112 | 113 | self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 114 | self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 115 | self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 116 | 117 | self.dropout1 = tf.keras.layers.Dropout(rate) 118 | self.dropout2 = tf.keras.layers.Dropout(rate) 119 | self.dropout3 = tf.keras.layers.Dropout(rate) 120 | 121 | 122 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask): 123 | # enc_output.shape == (batch_size, 
input_seq_len, d_model) 124 | 125 | attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) # (batch_size, target_seq_len, d_model) 126 | attn1 = self.dropout1(attn1, training=training) 127 | out1 = self.layernorm1(attn1 + x) 128 | 129 | attn2, attn_weights_block2 = self.mha2( 130 | enc_output, enc_output, out1, padding_mask) # (batch_size, target_seq_len, d_model) 131 | attn2 = self.dropout2(attn2, training=training) 132 | out2 = self.layernorm2(attn2 + out1) # (batch_size, target_seq_len, d_model) 133 | 134 | ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model) 135 | ffn_output = self.dropout3(ffn_output, training=training) 136 | out3 = self.layernorm3(ffn_output + out2) # (batch_size, target_seq_len, d_model) 137 | 138 | return out3, attn_weights_block1, attn_weights_block2 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | from build_eval_test import build_model, train, test 4 | import os 5 | 6 | def main(): 7 | 8 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--max_enc_len", default=400, help="Encoder input max sequence length", type=int) 12 | 13 | parser.add_argument("--max_dec_len", default=100, help="Decoder input max sequence length", type=int) 14 | 15 | parser.add_argument("--batch_size", default=16, help="batch size", type=int) 16 | 17 | parser.add_argument("--vocab_size", default=50000, help="Vocabulary size", type=int) 18 | 19 | parser.add_argument("--num_layers", default=3, help="Model encoder and decoder number of layers", type=int) 20 | 21 | parser.add_argument("--model_depth", default=512, help="Model Embedding size", type=int) 22 | 23 | parser.add_argument("--num_heads", default=8, help="Multi Attention number of heads", type=int) 24 | 25 | parser.add_argument("--dff", default=2048, help="Dff", type=int) 26 | 27 | parser.add_argument("--seed", default=123, help="Seed", type=int) 28 | 29 | parser.add_argument("--log_step_count_steps", default=1, help="Log each N steps", type=int) 30 | 31 | parser.add_argument("--max_steps",default=230000, help="Max steps for training", type=int) 32 | 33 | parser.add_argument("--save_summary_steps", default=10000, help="Save summaries every N steps", type=int) 34 | 35 | parser.add_argument("--checkpoints_save_steps", default=10000, help="Save checkpoints every N steps", type=int) 36 | 37 | parser.add_argument("--mode", help="training, eval or test options") 38 | 39 | parser.add_argument("--model_dir", help="Model folder") 40 | 41 | parser.add_argument("--data_dir", help="Data Folder") 42 | 43 | parser.add_argument("--vocab_path", help="Vocab path") 44 | 45 | 46 | args = parser.parse_args() 47 | params = vars(args) 48 | print(params) 49 | 50 | assert params["mode"], "mode is required. 
train, test or eval option" 51 | if params["mode"] == "train": 52 | params["training"] = True ; params["eval"] = False ; params["test"] = False 53 | elif params["mode"] == "eval": 54 | params["training"] = False ; params["eval"] = True ; params["test"] = False 55 | elif params["mode"] == "test": 56 | params["training"] = False ; params["eval"] = False ; params["test"] = True; 57 | else: 58 | raise NameError("The mode must be train , test or eval") 59 | assert os.path.exists(params["data_dir"]), "data_dir doesn't exist" 60 | assert os.path.isfile(params["vocab_path"]), "vocab_path doesn't exist" 61 | 62 | 63 | 64 | if params["training"]: 65 | train( params) 66 | elif params["eval"]: 67 | pass 68 | elif not params["training"]: 69 | pass 70 | 71 | 72 | if __name__ == "__main__": 73 | main() -------------------------------------------------------------------------------- /pointer_generator_run_on_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pointer_generator_run_on_colab.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "R0YKfS3YGTk9", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "# Pointer Transformer Generator Tensorflow 2.0.2 beta-1" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "EkYCSmWOE_CF", 42 | "colab_type": "code", 43 | "colab": {} 44 | }, 45 | "source": [ 46 | "!pip install tensorflow-gpu==2.0.0-beta1" 47 | ], 48 | "execution_count": 0, 49 | "outputs": [] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "metadata": { 54 | "id": "aR0InIHa6GI3", 55 | "colab_type": "code", 56 | "colab": {} 57 | }, 58 | "source": [ 59 | "from google.colab import drive\n", 60 | "drive.mount(\"/content/drive\")" 61 | ], 62 | "execution_count": 0, 63 | "outputs": [] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": { 68 | "id": "5GG_DJjc_VIL", 69 | "colab_type": "text" 70 | }, 71 | "source": [ 72 | "\n", 73 | "**Git clone project**" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "metadata": { 79 | "id": "RYWr1D0f6j6K", 80 | "colab_type": "code", 81 | "colab": {} 82 | }, 83 | "source": [ 84 | "%%bash\n", 85 | "git clone https://github.com/steph1793/Pointer_Transformer_Generator " 86 | ], 87 | "execution_count": 0, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "id": "BR5vlKHz9Req", 94 | "colab_type": "text" 95 | }, 96 | "source": [ 97 | "You can modify model_dir in order to save you model in your drive; and data_dir and vocab_path, in repect with their real paths in your drive\n", 98 | "\n", 99 | "\n", 100 | "*** PS : We suppose you have your data on your drive***" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "metadata": { 106 | "id": "nox5S1LI8kN2", 107 | "colab_type": "code", 108 | "colab": {} 109 | }, 110 | "source": [ 111 | "!python Pointer_Transformer_Generator/main.py --max_enc_len=400 \\\n", 112 | "--max_dec_len=100 \\\n", 113 | "--batch_size=16 \\\n", 114 | "--vocab_size=50000 \\\n", 115 | "--num_layers=3 \\\n", 116 | "--model_depth=512 \\\n", 
117 | "--num_heads=8 \\\n", 118 | "--dff=2048 \\\n", 119 | "--seed=123 \\\n", 120 | "--log_step_count_steps=1 \\\n", 121 | "--max_steps=230000 \\\n", 122 | "--mode=train \\\n", 123 | "--save_summary_steps=10000 \\\n", 124 | "--checkpoints_save_steps=10000 \\\n", 125 | "--model_dir=\"transformer_model_dir\" \\\n", 126 | "--data_dir=\"drive/My Drive/pointer_gen/cnn-dailymail/tfrecords_finished_files/chunked\" \\\n", 127 | "--vocab_path=\"drive/My Drive/pointer_gen/cnn-dailymail/tfrecords_finished_files/vocab\" \\" 128 | ], 129 | "execution_count": 0, 130 | "outputs": [] 131 | } 132 | ] 133 | } -------------------------------------------------------------------------------- /predict_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import create_masks 3 | 4 | def predict(features, params, model): 5 | 6 | output = tf.tile([[2]], [params["batch_size"], 1]) # 2 = start_decoding 7 | 8 | for i in range(params["max_dec_len"]): 9 | enc_padding_mask, combined_mask, dec_padding_mask = create_masks(features["enc_input"], output) 10 | 11 | # predictions.shape == (batch_size, seq_len, vocab_size) 12 | predictions, attention_weights = model(features["enc_input"],features["extended_enc_input"], features["max_oov_len"], output, training=params["training"], 13 | enc_padding_mask=enc_padding_mask, 14 | look_ahead_mask=combined_mask, 15 | dec_padding_mask=dec_padding_mask) 16 | 17 | # select the last word from the seq_len dimension 18 | predictions = predictions[: ,-1:, :] # (batch_size, 1, vocab_size) 19 | predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32) 20 | 21 | 22 | # concatentate the predicted_id to the output which is given to the decoder 23 | # as its input. 24 | output = tf.concat([output, predicted_id], axis=-1) 25 | 26 | return output, attention_weights -------------------------------------------------------------------------------- /training_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import create_masks 3 | import time 4 | 5 | class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 6 | def __init__(self, d_model, warmup_steps=4000): 7 | super(CustomSchedule, self).__init__() 8 | 9 | self.d_model = d_model 10 | self.d_model = tf.cast(self.d_model, tf.float32) 11 | 12 | self.warmup_steps = warmup_steps 13 | 14 | def __call__(self, step): 15 | arg1 = tf.math.rsqrt(step) 16 | arg2 = step * (self.warmup_steps ** -1.5) 17 | 18 | return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) 19 | 20 | 21 | def loss_function(loss_object, real, pred): 22 | mask = tf.math.logical_not(tf.math.equal(real, 0)) 23 | loss_ = loss_object(real, pred) 24 | 25 | mask = tf.cast(mask, dtype=loss_.dtype) 26 | loss_ *= mask 27 | 28 | return tf.reduce_mean(loss_) 29 | 30 | 31 | def train_step(features, labels, params, model, optimizer, loss_object, train_loss_metric): 32 | 33 | enc_padding_mask, combined_mask, dec_padding_mask = create_masks(features["enc_input"], labels["dec_input"]) 34 | 35 | with tf.GradientTape() as tape: 36 | output, attn_weights = model(features["enc_input"],features["extended_enc_input"], features["max_oov_len"], labels["dec_input"], training=params["training"], 37 | enc_padding_mask=enc_padding_mask, 38 | look_ahead_mask=combined_mask, 39 | dec_padding_mask=dec_padding_mask) 40 | loss = loss_function(loss_object, labels["dec_target"], output) 41 | 42 | gradients = tape.gradient(loss, 
model.trainable_variables) 43 | optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 44 | train_loss_metric(loss) 45 | 46 | def train_model(model, batcher, params, ckpt, ckpt_manager): 47 | learning_rate = CustomSchedule(params["model_depth"]) 48 | optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) 49 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction='none') 50 | train_loss_metric = tf.keras.metrics.Mean(name="train_loss_metric") 51 | 52 | try: 53 | for batch in batcher: 54 | t0 = time.time() 55 | train_step(batch[0], batch[1], params, model, optimizer, loss_object, train_loss_metric) 56 | t1 = time.time() 57 | 58 | print("step {}, time : {}, loss: {}".format(int(ckpt.step), t1-t0, train_loss_metric.result())) 59 | if int(ckpt.step) % params["checkpoints_save_steps"] ==0 : 60 | ckpt_manager.save(checkpoint_number=int(ckpt.step)) 61 | print("Saved checkpoint for step {}".format(int(ckpt.step))) 62 | ckpt.step.assign_add(1) 63 | 64 | except KeyboardInterrupt: 65 | ckpt_manager.save(int(ckpt.step)) 66 | print("Saved checkpoint for step {}".format(int(ckpt.step))) -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from layers import Embedding, EncoderLayer, DecoderLayer 3 | from utils import _calc_final_dist 4 | 5 | 6 | class Encoder(tf.keras.layers.Layer): 7 | def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 8 | rate=0.1): 9 | super(Encoder, self).__init__() 10 | self.d_model = d_model 11 | self.num_layers = num_layers 12 | 13 | self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 14 | for _ in range(num_layers)] 15 | self.dropout = tf.keras.layers.Dropout(rate) 16 | 17 | def call(self, x, training, mask): 18 | x = self.dropout(x, training=training) 19 | 20 | for i in range(self.num_layers): 21 | x = self.enc_layers[i](x, training, mask) 22 | 23 | return x # (batch_size, input_seq_len, d_model) 24 | 25 | 26 | class Decoder(tf.keras.layers.Layer): 27 | def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, 28 | rate=0.1): 29 | super(Decoder, self).__init__() 30 | self.d_model = d_model 31 | self.num_layers = num_layers 32 | self.num_heads = num_heads 33 | self.depth = d_model // self.num_heads 34 | self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 35 | for _ in range(num_layers)] 36 | self.dropout = tf.keras.layers.Dropout(rate) 37 | self.Wh = tf.keras.layers.Dense(1) 38 | self.Ws = tf.keras.layers.Dense(1) 39 | self.Wx = tf.keras.layers.Dense(1) 40 | self.V = tf.keras.layers.Dense(1) 41 | 42 | 43 | def call(self, embed_x, enc_output, training, look_ahead_mask, padding_mask): 44 | 45 | attention_weights = {} 46 | out = self.dropout(embed_x, training=training) 47 | 48 | for i in range(self.num_layers): 49 | out, block1, block2 = self.dec_layers[i](out, enc_output, training, 50 | look_ahead_mask, padding_mask) 51 | 52 | attention_weights['decoder_layer{}_block1'.format(i+1)] = block1 53 | attention_weights['decoder_layer{}_block2'.format(i+1)] = block2 54 | 55 | # out.shape == (batch_size, target_seq_len, d_model) 56 | 57 | 58 | 59 | #context vectors 60 | enc_out_shape = tf.shape(enc_output) 61 | context = tf.reshape(enc_output,(enc_out_shape[0], enc_out_shape[1], self.num_heads, self.depth) ) # shape : (batch_size, input_seq_len, num_heads, depth) 62 | context = 
tf.transpose(context, [0,2,1,3]) # (batch_size, num_heads, input_seq_len, depth) 63 | context = tf.expand_dims(context, axis=2) # (batch_size, num_heads, 1, input_seq_len, depth) 64 | 65 | attn = tf.expand_dims(block2, axis=-1) # (batch_size, num_heads, target_seq_len, input_seq_len, 1) 66 | 67 | context = context * attn # (batch_size, num_heads, target_seq_len, input_seq_len, depth) 68 | context = tf.reduce_sum(context, axis=3) # (batch_size, num_heads, target_seq_len, depth) 69 | context = tf.transpose(context, [0,2,1,3]) # (batch_size, target_seq_len, num_heads, depth) 70 | context = tf.reshape(context, (tf.shape(context)[0], tf.shape(context)[1], self.d_model)) # (batch_size, target_seq_len, d_model) 71 | 72 | # P_gens computing 73 | a = self.Wx(embed_x) 74 | b = self.Ws(out) 75 | c = self.Wh(context) 76 | p_gens = tf.sigmoid(self.V(a + b + c)) 77 | 78 | return out, attention_weights, p_gens 79 | 80 | 81 | class Transformer(tf.keras.Model): 82 | def __init__(self, num_layers, d_model, num_heads, dff, vocab_size,batch_size, rate=0.1): 83 | super(Transformer, self).__init__() 84 | 85 | self.num_layers =num_layers 86 | self.vocab_size = vocab_size 87 | self.batch_size = batch_size 88 | self.model_depth = d_model 89 | self.num_heads = num_heads 90 | 91 | self.embedding = Embedding(vocab_size, d_model) 92 | self.encoder = Encoder(num_layers, d_model, num_heads, dff, vocab_size, rate) 93 | self.decoder = Decoder(num_layers, d_model, num_heads, dff, vocab_size, rate) 94 | self.final_layer = tf.keras.layers.Dense(vocab_size) 95 | 96 | 97 | def call(self, inp, extended_inp,max_oov_len, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask): 98 | 99 | embed_x = self.embedding(inp) 100 | embed_dec = self.embedding(tar) 101 | 102 | enc_output = self.encoder(embed_x, training, enc_padding_mask) # (batch_size, inp_seq_len, d_model) 103 | 104 | # dec_output.shape == (batch_size, tar_seq_len, d_model) 105 | dec_output, attention_weights, p_gens = self.decoder(embed_dec, enc_output, training, look_ahead_mask, dec_padding_mask) 106 | 107 | output = self.final_layer(dec_output) # (batch_size, tar_seq_len, target_vocab_size) 108 | output = tf.nn.softmax(output) # (batch_size, tar_seq_len, vocab_size) 109 | #output = tf.concat([output, tf.zeros((tf.shape(output)[0], tf.shape(output)[1], max_oov_len))], axis=-1) # (batch_size, targ_seq_len, vocab_size+max_oov_len) 110 | 111 | attn_dists = attention_weights['decoder_layer{}_block2'.format(self.num_layers)] # (batch_size,num_heads, targ_seq_len, inp_seq_len) 112 | attn_dists = tf.reduce_sum(attn_dists, axis=1)/self.num_heads # (batch_size, targ_seq_len, inp_seq_len) 113 | 114 | 115 | final_dists = _calc_final_dist( extended_inp, tf.unstack(output, axis=1) , tf.unstack(attn_dists, axis=1), tf.unstack(p_gens, axis=1), max_oov_len, self.vocab_size, self.batch_size) 116 | final_output =tf.stack(final_dists, axis=1) 117 | 118 | return final_output, attention_weights -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | 6 | def get_angles(pos, i, d_model): 7 | angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model)) 8 | return pos * angle_rates 9 | 10 | def positional_encoding(position, d_model): 11 | angle_rads = get_angles(np.arange(position)[:, np.newaxis], 12 | np.arange(d_model)[np.newaxis, :], 13 | d_model) 14 | 15 | # apply sin to even indices in the 
array; 2i 16 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) 17 | 18 | # apply cos to odd indices in the array; 2i+1 19 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) 20 | 21 | pos_encoding = angle_rads[np.newaxis, ...] 22 | 23 | return tf.cast(pos_encoding, dtype=tf.float32) 24 | 25 | 26 | 27 | def create_padding_mask(seq): 28 | seq = tf.cast(tf.math.equal(seq, 1), tf.float32) 29 | 30 | # add extra dimensions to add the padding 31 | # to the attention logits. 32 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) 33 | 34 | 35 | 36 | def create_look_ahead_mask(size): 37 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) 38 | return mask # (seq_len, seq_len) 39 | 40 | 41 | def create_masks(inp, tar): 42 | # Encoder padding mask 43 | enc_padding_mask = create_padding_mask(inp) 44 | 45 | # Used in the 2nd attention block in the decoder. 46 | # This padding mask is used to mask the encoder outputs. 47 | dec_padding_mask = create_padding_mask(inp) 48 | 49 | # Used in the 1st attention block in the decoder. 50 | # It is used to pad and mask future tokens in the input received by 51 | # the decoder. 52 | look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1]) 53 | dec_target_padding_mask = create_padding_mask(tar) 54 | combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) 55 | 56 | return enc_padding_mask, combined_mask, dec_padding_mask 57 | 58 | 59 | def scaled_dot_product_attention(q, k, v, mask): 60 | """Calculate the attention weights. 61 | q, k, v must have matching leading dimensions. 62 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. 63 | The mask has different shapes depending on its type(padding or look ahead) 64 | but it must be broadcastable for addition. 65 | 66 | Args: 67 | q: query shape == (..., seq_len_q, depth) 68 | k: key shape == (..., seq_len_k, depth) 69 | v: value shape == (..., seq_len_v, depth_v) 70 | mask: Float tensor with shape broadcastable 71 | to (..., seq_len_q, seq_len_k). Defaults to None. 72 | 73 | Returns: 74 | output, attention_weights 75 | """ 76 | 77 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 78 | 79 | # scale matmul_qk 80 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 81 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 82 | 83 | # add the mask to the scaled tensor. 84 | if mask is not None: 85 | scaled_attention_logits += (mask * -1e9) 86 | 87 | # softmax is normalized on the last axis (seq_len_k) so that the scores 88 | # add up to 1. 89 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 90 | 91 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 92 | 93 | return output, attention_weights 94 | 95 | def _calc_final_dist( _enc_batch_extend_vocab, vocab_dists, attn_dists, p_gens, batch_oov_len, vocab_size, batch_size): 96 | """Calculate the final distribution, for the pointer-generator model 97 | 98 | Args: 99 | vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file. 100 | attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays 101 | 102 | Returns: 103 | final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays. 
104 | """ 105 | # Multiply vocab dists by p_gen and attention dists by (1-p_gen) 106 | vocab_dists = [p_gen * dist for (p_gen,dist) in zip(p_gens, vocab_dists)] 107 | attn_dists = [(1-p_gen) * dist for (p_gen,dist) in zip(p_gens, attn_dists)] 108 | 109 | # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words 110 | extended_vsize = vocab_size + batch_oov_len # the maximum (over the batch) size of the extended vocabulary 111 | extra_zeros = tf.zeros((batch_size, batch_oov_len )) 112 | vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists] # list length max_dec_steps of shape (batch_size, extended_vsize) 113 | 114 | # Project the values in the attention distributions onto the appropriate entries in the final distributions 115 | # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution 116 | # This is done for each decoder timestep. 117 | # This is fiddly; we use tf.scatter_nd to do the projection 118 | batch_nums = tf.range(0, limit=batch_size) # shape (batch_size) 119 | batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1) 120 | attn_len = tf.shape(_enc_batch_extend_vocab)[1] # number of states we attend over 121 | batch_nums = tf.tile(batch_nums, [1, attn_len]) # shape (batch_size, attn_len) 122 | indices = tf.stack( (batch_nums, _enc_batch_extend_vocab), axis=2) # shape (batch_size, enc_t, 2) 123 | shape = [batch_size, extended_vsize] 124 | attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists] # list length max_dec_steps (batch_size, extended_vsize) 125 | 126 | # Add the vocab distributions and the copy distributions together to get the final distributions 127 | # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep 128 | # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore. 129 | final_dists = [vocab_dist + copy_dist for (vocab_dist,copy_dist) in zip(vocab_dists_extended, attn_dists_projected)] 130 | 131 | return final_dists --------------------------------------------------------------------------------
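Below is a small, self-contained sketch of the mixing step that `_calc_final_dist` performs for a single decoder timestep. It is not part of the project files; the toy sizes (a 5-word vocabulary, a batch of one, 2 in-article OOVs) and all tensor values are made up purely for illustration.

```python
import tensorflow as tf

batch_size, vocab_size, batch_oov_len = 1, 5, 2
extended_vsize = vocab_size + batch_oov_len

# One decoder step: a generation distribution over the 5 in-vocab words,
# an attention (copy) distribution over 3 encoder positions, and the p_gen gate.
vocab_dist = tf.constant([[0.1, 0.2, 0.3, 0.2, 0.2]])   # (batch_size, vocab_size)
attn_dist = tf.constant([[0.5, 0.3, 0.2]])              # (batch_size, attn_len)
p_gen = tf.constant([[0.8]])                            # (batch_size, 1)

# Encoder tokens expressed in the extended vocabulary: id 2 is an in-vocab word,
# ids 5 and 6 are the first and second in-article OOVs.
enc_batch_extend_vocab = tf.constant([[2, 5, 6]])       # (batch_size, attn_len)

# Weight the two distributions by p_gen and (1 - p_gen).
vocab_dist = p_gen * vocab_dist
attn_dist = (1.0 - p_gen) * attn_dist

# Pad the generation distribution with zeros for the in-article OOV ids.
extra_zeros = tf.zeros((batch_size, batch_oov_len))
vocab_dist_extended = tf.concat([vocab_dist, extra_zeros], axis=1)   # (batch_size, extended_vsize)

# Scatter the copy mass onto the extended-vocabulary ids of the encoder tokens.
attn_len = enc_batch_extend_vocab.shape[1]
batch_nums = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, attn_len])
indices = tf.stack((batch_nums, enc_batch_extend_vocab), axis=2)     # (batch_size, attn_len, 2)
attn_dist_projected = tf.scatter_nd(indices, attn_dist, [batch_size, extended_vsize])

final_dist = vocab_dist_extended + attn_dist_projected
print(final_dist.numpy())  # approximately [[0.08 0.16 0.34 0.16 0.16 0.06 0.04]]; sums to 1
```

The copy mass lands on the extended ids of the encoder tokens, which is how an article word outside the 50k vocabulary (ids 5 and 6 here) can still receive non-zero probability at decoding time.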
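A similarly self-contained illustration of the two decoder-side masks that `create_masks` combines; note that the padding id compared against is 1, the id this project's `Vocab` assigns to `[PAD]`. The sequence values below are made up.

```python
import tensorflow as tf

# A batch with one target sequence of length 4 whose last position is [PAD] (id 1).
tar = tf.constant([[2, 8, 9, 1]])

# Padding mask: 1.0 at padded positions, broadcastable over (batch, heads, seq_q, seq_k).
dec_target_padding_mask = tf.cast(tf.math.equal(tar, 1), tf.float32)[:, tf.newaxis, tf.newaxis, :]

# Look-ahead mask: position t may only attend to positions <= t.
seq_len = tar.shape[1]
look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

# The decoder's first attention block receives the element-wise maximum of the two,
# so a position is masked out if it is either in the future or a pad token.
combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
print(combined_mask.numpy())  # shape (1, 1, 4, 4); masked entries are 1.0
```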