├── LICENSE.md
├── README.md
├── batcher.py
├── build_eval_test.py
├── data_helper.py
├── layers.py
├── main.py
├── pointer_generator_run_on_colab.ipynb
├── predict_helper.py
├── training_helper.py
├── transformer.py
└── utils.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 |
2 | The MIT License (MIT)
3 |
4 | Copyright (c) 2019 David Stephane
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pointer_Transformer_Generator tensorflow 2.0.0
2 |
3 | For the abstractive summarization task, I wanted to experiment with the Transformer model. I recreated a Transformer model (thanks to the TensorFlow Transformer tutorial) and added a pointer module (see this paper for more information on the pointer-generator network: https://arxiv.org/abs/1704.04368).
4 |
5 | PS: I will soon add a section explaining how the pointer module is integrated into the Transformer.
6 |
7 | Please follow the next steps to launch the project :
8 |
9 | ## Step 1 : The data
10 |
11 | ### Option 1 : Download the data
12 | Download the data (chunk files format : tfrecords)
13 | https://drive.google.com/open?id=1uHrMWd7Pbs_-DCl0eeMxePbxgmSce5LO
14 |
15 | ### Option 2 : Download raw data and process it
16 | Use this project :
17 | https://github.com/steph1793/CNN-DailyMail-Bin-To-TFRecords
18 |
19 | ## Step 2 : launch the project :
20 |
21 | ```bash
   | python main.py --max_enc_len=400 \
22 |   --max_dec_len=100 \
23 |   --batch_size=16 \
24 |   --vocab_size=50000 \
25 |   --num_layers=3 \
26 |   --model_depth=512 \
27 |   --num_heads=8 \
28 |   --dff=2048 \
29 |   --seed=123 \
30 |   --log_step_count_steps=1 \
31 |   --max_steps=230000 \
32 |   --mode=train \
33 |   --save_summary_steps=10000 \
34 |   --checkpoints_save_steps=10000 \
35 |   --model_dir=model_folder \
36 |   --data_dir=data_folder \
37 |   --vocab_path=vocab
   | ```
38 |
39 | PS: Feel free to change any of these hyperparameters.
40 | Run `python main.py --help` for more details on each of them.
41 |
42 |
43 |
44 | ## Requirements
45 | - python >= 3.6
46 | - tensorflow 2.0.0
47 | - argparse (Python standard library)
48 | - os (Python standard library)
49 | - glob (Python standard library)
50 | - numpy
51 |
52 |
--------------------------------------------------------------------------------
/batcher.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import glob
3 |
4 | from data_helper import Vocab, Data_Helper
5 |
6 |
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into its two text fields.

    Args:
        example_proto: scalar string tensor holding a serialized Example.

    Returns:
        dict mapping ``"article"`` and ``"abstract"`` to scalar string
        tensors; a feature missing from the record decodes to ``''``.
    """
    schema = {
        field: tf.io.FixedLenFeature([], tf.string, default_value='')
        for field in ('article', 'abstract')
    }
    return tf.io.parse_single_example(example_proto, schema)
17 |
18 |
19 |
def example_generator(filenames, vocab_path, vocab_size, max_enc_len, max_dec_len, training=False):
    """Yield one preprocessed example dict per (article, abstract) TFRecord.

    Args:
        filenames: TFRecord file paths to read.
        vocab_path: vocabulary file passed to ``Vocab``.
        vocab_size: maximum vocabulary size passed to ``Vocab``.
        max_enc_len: articles are truncated to this many whitespace tokens.
        max_dec_len: length limit for the decoder input / target sequences.
        training: when true, records are shuffled (buffer of 1000) and the
            dataset is repeated indefinitely, so this generator never ends.

    Yields:
        dict with keys ``enc_len``, ``enc_input``, ``enc_input_extend_vocab``,
        ``article_oovs``, ``dec_input``, ``target``, ``dec_len``, ``article``,
        ``abstract`` and ``abstract_sents``.
    """
    raw_dataset = tf.data.TFRecordDataset(filenames)
    parsed_dataset = raw_dataset.map(_parse_function)
    if training:
        parsed_dataset = parsed_dataset.shuffle(1000, reshuffle_each_iteration=True).repeat()

    vocab = Vocab(vocab_path, vocab_size)
    # These ids depend only on the vocabulary, so look them up once rather
    # than once per record.
    start_decoding = vocab.word_to_id(vocab.START_DECODING)
    stop_decoding = vocab.word_to_id(vocab.STOP_DECODING)

    for raw_record in parsed_dataset:
        article = raw_record["article"].numpy().decode()
        abstract = raw_record["abstract"].numpy().decode()

        article_words = article.split()[:max_enc_len]
        enc_len = len(article_words)
        enc_input = [vocab.word_to_id(w) for w in article_words]
        enc_input_extend_vocab, article_oovs = Data_Helper.article_to_ids(article_words, vocab)

        abstract_sentences = [sent.strip() for sent in Data_Helper.abstract_to_sents(abstract)]
        abstract = ' '.join(abstract_sentences)
        abstract_words = abstract.split()
        abs_ids = [vocab.word_to_id(w) for w in abstract_words]
        abs_ids_extend_vocab = Data_Helper.abstract_to_ids(abstract_words, vocab, article_oovs)

        # The decoder input is built from plain vocab ids (OOV words become
        # [UNK]); the target is built from the extended ids, so in-article
        # OOVs keep their temporary ids. Both calls truncate identically, so
        # the two sequences have the same length.
        dec_input, _ = Data_Helper.get_dec_inp_targ_seqs(abs_ids, max_dec_len, start_decoding, stop_decoding)
        _, target = Data_Helper.get_dec_inp_targ_seqs(abs_ids_extend_vocab, max_dec_len, start_decoding, stop_decoding)
        dec_len = len(dec_input)

        yield {
            "enc_len": enc_len,
            "enc_input": enc_input,
            "enc_input_extend_vocab": enc_input_extend_vocab,
            "article_oovs": article_oovs,
            "dec_input": dec_input,
            "target": target,
            "dec_len": dec_len,
            "article": article,
            "abstract": abstract,
            "abstract_sents": abstract_sentences,
        }
67 |
68 |
def batch_generator(generator, filenames, vocab_path, vocab_size, max_enc_len, max_dec_len, batch_size, training):
    """Wrap *generator* in a padded, batched ``tf.data.Dataset``.

    Each element is a ``(features, labels)`` pair of dicts. Encoder-side
    sequences are padded to the longest one in the batch. ``dec_input`` and
    ``target`` are padded to exactly ``max_dec_len``. Integer sequences are
    padded with 1, the lengths with -1 and strings with ``b''``. An
    incomplete final batch is dropped.
    """
    element_types = {
        "enc_len": tf.int32,
        "enc_input": tf.int32,
        "enc_input_extend_vocab": tf.int32,
        "article_oovs": tf.string,
        "dec_input": tf.int32,
        "target": tf.int32,
        "dec_len": tf.int32,
        "article": tf.string,
        "abstract": tf.string,
        "abstract_sents": tf.string,
    }
    element_shapes = {
        "enc_len": [],
        "enc_input": [None],
        "enc_input_extend_vocab": [None],
        "article_oovs": [None],
        "dec_input": [None],
        "target": [None],
        "dec_len": [],
        "article": [],
        "abstract": [],
        "abstract_sents": [None],
    }
    dataset = tf.data.Dataset.from_generator(
        generator,
        args=[filenames, vocab_path, vocab_size, max_enc_len, max_dec_len, training],
        output_types=element_types,
        output_shapes=element_shapes,
    )

    # Same as the element shapes, except decoder sequences get a fixed length.
    batch_shapes = dict(element_shapes, dec_input=[max_dec_len], target=[max_dec_len])
    fill_values = {
        "enc_len": -1,
        "enc_input": 1,
        "enc_input_extend_vocab": 1,
        "article_oovs": b'',
        "dec_input": 1,
        "target": 1,
        "dec_len": -1,
        "article": b"",
        "abstract": b"",
        "abstract_sents": b'',
    }
    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes=batch_shapes,
        padding_values=fill_values,
        drop_remainder=True,
    )

    def _to_features_and_labels(batch):
        """Split a padded batch into model inputs and decoder labels."""
        features = {
            "enc_input": batch["enc_input"],
            "extended_enc_input": batch["enc_input_extend_vocab"],
            "article_oovs": batch["article_oovs"],
            "enc_len": batch["enc_len"],
            "article": batch["article"],
            # Width of the padded OOV matrix = most OOVs of any article in the batch.
            "max_oov_len": tf.shape(batch["article_oovs"])[1],
        }
        labels = {
            "dec_input": batch["dec_input"],
            "dec_target": batch["target"],
            "dec_len": batch["dec_len"],
            "abstract": batch["abstract"],
        }
        return features, labels

    return dataset.map(_to_features_and_labels)
133 |
134 |
def batcher(data_path, vocab_path, hpm):
    """Build the batched dataset over every ``*.tfrecords`` file in *data_path*.

    Args:
        data_path: directory containing the ``.tfrecords`` chunk files.
        vocab_path: vocabulary file path.
        hpm: hyperparameter dict; reads ``vocab_size``, ``max_enc_len``,
            ``max_dec_len``, ``batch_size`` and ``training``.

    Returns:
        The dataset produced by ``batch_generator``.

    Raises:
        ValueError: if *data_path* contains no ``.tfrecords`` file. Without
            this check the file list would be empty and the resulting dataset
            would silently produce no batches at all.
    """
    import os

    # glob.escape keeps characters such as '[' in the directory name literal;
    # sorting makes the file order independent of the directory listing.
    pattern = os.path.join(glob.escape(data_path), "*.tfrecords")
    filenames = sorted(glob.glob(pattern))
    if not filenames:
        raise ValueError("No .tfrecords files found in {}".format(data_path))

    return batch_generator(example_generator, filenames, vocab_path, hpm["vocab_size"], hpm["max_enc_len"],
                           hpm["max_dec_len"], hpm["batch_size"], hpm["training"])
141 |
--------------------------------------------------------------------------------
/build_eval_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.training import training_util
3 | from training_helper import train_model
4 | from predict_helper import predict
5 | from batcher import batcher
6 | from transformer import Transformer
7 | import os
8 |
9 |
10 |
def my_model(features, labels, mode, params):
    """Estimator ``model_fn`` that greedily decodes summaries for *features*.

    Only prediction is supported: the returned spec carries predictions and
    no loss or train op. *labels* is accepted for the Estimator
    ``model_fn`` signature but is not used.

    Args:
        features: feature dict produced by ``batcher``.
        labels: unused.
        mode: Estimator mode key, forwarded to ``EstimatorSpec``.
        params: hyperparameter dict; the same keys ``train`` uses to build the
            Transformer, plus those read by ``predict``.

    Returns:
        ``tf.estimator.EstimatorSpec`` with ``predictions={"predictions": ...}``.
    """
    # No module-level ``transformer`` exists for this function to read, so
    # build the network here from the same hyperparameters ``train`` uses.
    transformer = Transformer(
        num_layers=params["num_layers"], d_model=params["model_depth"], num_heads=params["num_heads"],
        dff=params["dff"], vocab_size=params["vocab_size"], batch_size=params["batch_size"])

    predictions, _ = predict(features, params, transformer)
    estimator_spec = tf.estimator.EstimatorSpec(mode, predictions={"predictions": predictions})

    # Keras ``summary()`` prints the table itself and returns None.
    transformer.summary()
    return estimator_spec
18 |
19 |
def build_model(params):
    """Create a ``tf.estimator.Estimator`` that uses ``my_model`` as its model_fn.

    Reads ``seed``, ``log_step_count_steps``, ``save_summary_steps`` and
    ``model_dir`` from *params*. The whole dict is also forwarded as the
    Estimator's ``params``.
    """
    run_config = tf.estimator.RunConfig(
        tf_random_seed=params["seed"],
        log_step_count_steps=params["log_step_count_steps"],
        save_summary_steps=params["save_summary_steps"],
    )
    estimator = tf.estimator.Estimator(
        model_fn=my_model,
        params=params,
        config=run_config,
        model_dir=params["model_dir"],
    )
    return estimator
31 |
32 |
def train(params):
    """Build the Transformer, restore the latest checkpoint if any, and train it.

    Args:
        params: hyperparameter dict; ``training`` must be true. Checkpoints
            live in ``<model_dir>/checkpoint`` (at most 11 are kept).

    Raises:
        AssertionError: if ``params["training"]`` is false.
    """
    assert params["training"], "change training mode to true"

    tf.compat.v1.logging.info("Building the model ...")
    transformer = Transformer(
        num_layers=params["num_layers"], d_model=params["model_depth"], num_heads=params["num_heads"],
        dff=params["dff"], vocab_size=params["vocab_size"], batch_size=params["batch_size"])

    tf.compat.v1.logging.info("Creating the batcher ...")
    b = batcher(params["data_dir"], params["vocab_path"], params)

    tf.compat.v1.logging.info("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["model_dir"])
    # NOTE(review): only the step counter and the model are tracked; the
    # optimizer created in train_model is not, so its state is not restored.
    ckpt = tf.train.Checkpoint(step=tf.Variable(0), transformer=transformer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=11)

    # latest_checkpoint is None on a fresh model_dir; restore(None) does not fail.
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    tf.compat.v1.logging.info("Starting the training ...")
    train_model(transformer, b, params, ckpt, ckpt_manager)
59 |
60 |
61 |
def eval(model, params):
    """Placeholder for evaluation mode, which is not implemented yet.

    Raises:
        NotImplementedError: always, so a caller cannot mistake a silent
            no-op for a completed evaluation.
    """
    raise NotImplementedError("eval mode is not implemented yet")
64 |
65 |
def test(model, params, max_batches=1):
    """Run *model* over ``params["data_dir"]`` and return its prediction batches.

    Args:
        model: an Estimator, e.g. from ``build_model``.
        params: hyperparameter dict; ``training`` must be false.
        max_batches: maximum number of prediction batches to yield. The
            default of 1 keeps the "first batch only" behaviour; pass
            ``None`` to iterate over every batch.

    Returns:
        An iterator over prediction batches (``yield_single_examples=False``).
        It is empty, rather than raising, when the model produces no output.

    Raises:
        AssertionError: immediately, if ``params["training"]`` is true.
    """
    import itertools

    # Checked at call time; inside a generator body this check would only
    # run when the caller first advanced the iterator.
    assert not params["training"], "change training mode to false"

    pred = model.predict(input_fn=lambda: batcher(params["data_dir"], params["vocab_path"], params),
                         yield_single_examples=False)
    # islice simply stops on an exhausted iterator, unlike next() inside a
    # generator, which turns StopIteration into RuntimeError (PEP 479).
    return itertools.islice(pred, max_batches)
--------------------------------------------------------------------------------
/data_helper.py:
--------------------------------------------------------------------------------
1 |
2 |
class Vocab:
    """Word <-> id mapping read from a vocabulary file of ``word count`` lines.

    Ids 0-3 are reserved for ``[UNK]``, ``[PAD]``, ``[START]`` and ``[STOP]``;
    words from the file get consecutive ids starting at 4.
    """

    # Sentence boundary tags inside abstracts, consumed by
    # Data_Helper.abstract_to_sents. They must be non-empty: an empty tag
    # matches at every position and splits abstracts into single characters.
    SENTENCE_START = '<s>'
    SENTENCE_END = '</s>'

    PAD_TOKEN = '[PAD]'
    UNKNOWN_TOKEN = '[UNK]'
    START_DECODING = '[START]'
    STOP_DECODING = '[STOP]'

    def __init__(self, vocab_file, max_size):
        """Read *vocab_file* into the word/id tables.

        Args:
            vocab_file: path of a UTF-8 text file with one ``word count`` pair
                per line. Lines that do not split into exactly two fields are
                skipped with a warning.
            max_size: stop reading once the vocabulary (special tokens
                included) holds this many entries; 0 means no limit.

        Raises:
            Exception: if a reserved token or a duplicated word is in the file.
        """
        self.word2id = {Vocab.UNKNOWN_TOKEN: 0, Vocab.PAD_TOKEN: 1, Vocab.START_DECODING: 2, Vocab.STOP_DECODING: 3}
        self.id2word = {0: Vocab.UNKNOWN_TOKEN, 1: Vocab.PAD_TOKEN, 2: Vocab.START_DECODING, 3: Vocab.STOP_DECODING}
        self.count = 4

        reserved = (Vocab.SENTENCE_START, Vocab.SENTENCE_END, Vocab.UNKNOWN_TOKEN, Vocab.PAD_TOKEN,
                    Vocab.START_DECODING, Vocab.STOP_DECODING)

        # Explicit UTF-8 so lookups match the UTF-8-decoded article text
        # regardless of the platform's locale encoding.
        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                pieces = line.split()
                if len(pieces) != 2:
                    print('Warning : incorrectly formatted line in vocabulary file : %s\n' % line)
                    continue

                w = pieces[0]
                if w in reserved:
                    raise Exception("<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn't be in the vocab file, but %s is" % w)

                if w in self.word2id:
                    raise Exception('Duplicated word in vocabulary file: %s' % w)

                self.word2id[w] = self.count
                self.id2word[self.count] = w
                self.count += 1
                if max_size != 0 and self.count >= max_size:
                    print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self.count))
                    break

        print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self.count, self.id2word[self.count - 1]))

    def word_to_id(self, word):
        """Return the id of *word*, or the ``[UNK]`` id if it is not in the vocab."""
        return self.word2id.get(word, self.word2id[Vocab.UNKNOWN_TOKEN])

    def id_to_word(self, word_id):
        """Return the word for *word_id*.

        Raises:
            ValueError: if *word_id* is not in the vocabulary.
        """
        if word_id not in self.id2word:
            raise ValueError('Id not found in vocab: %d' % word_id)
        return self.id2word[word_id]

    def size(self):
        """Return the number of entries, including the 4 special tokens."""
        return self.count
55 |
56 |
class Data_Helper:
    """Stateless helpers for converting between words and (extended) vocabulary ids.

    "Extended" ids give each distinct out-of-vocabulary (OOV) word of an
    article a temporary id ``vocab.size() + k``, so the pointer can copy it.
    """

    @staticmethod
    def article_to_ids(article_words, vocab):
        """Map article words to ids, giving in-article OOVs temporary ids.

        Returns:
            ``(ids, oovs)``. ``ids[i]`` is the vocab id of
            ``article_words[i]``, or ``vocab.size() + k`` if it is the k-th
            distinct OOV. ``oovs`` lists the distinct OOV words in first-seen
            order.
        """
        ids = []
        oovs = []
        oov_position = {}  # word -> index in oovs; avoids an O(n) list.index per token
        unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN)
        base = vocab.size()
        for w in article_words:
            i = vocab.word_to_id(w)
            if i != unk_id:
                ids.append(i)
                continue
            if w not in oov_position:
                oov_position[w] = len(oovs)
                oovs.append(w)
            ids.append(base + oov_position[w])  # e.g. 50000 for the first OOV, 50001 for the second
        return ids, oovs

    @staticmethod
    def abstract_to_ids(abstract_words, vocab, article_oovs):
        """Map abstract words to extended ids.

        In-article OOVs get their temporary article id. Other OOVs map to
        the ``[UNK]`` id.
        """
        unk_id = vocab.word_to_id(vocab.UNKNOWN_TOKEN)
        base = vocab.size()
        oov_position = {}
        for k, w in enumerate(article_oovs):
            oov_position.setdefault(w, k)  # first occurrence, like list.index
        ids = []
        for w in abstract_words:
            i = vocab.word_to_id(w)
            if i == unk_id:
                ids.append(base + oov_position[w] if w in oov_position else unk_id)
            else:
                ids.append(i)
        return ids

    @staticmethod
    def output_to_words(id_list, vocab, article_oovs):
        """Convert (extended) ids back to words.

        Raises:
            ValueError: if an id is neither in the vocabulary nor a valid
                index into *article_oovs*.
        """
        words = []
        for i in id_list:
            try:
                w = vocab.id_to_word(i)  # might be [UNK]
            except ValueError:  # not a fixed-vocab id, so it must be an article OOV
                assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
                article_oov_idx = i - vocab.size()
                # Check the range explicitly: list indexing raises IndexError
                # (not ValueError) and a negative index would silently wrap.
                if not 0 <= article_oov_idx < len(article_oovs):
                    raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
                w = article_oovs[article_oov_idx]
            words.append(w)
        return words

    @staticmethod
    def abstract_to_sents(abstract):
        """Split abstract text from the data file into a list of sentences.

        Args:
            abstract: string in which each sentence is wrapped between
                ``Vocab.SENTENCE_START`` and ``Vocab.SENTENCE_END`` tags.
                Those tags must be non-empty strings.

        Returns:
            List of sentence strings with the tags removed. Text outside a
            complete start/end pair is ignored.
        """
        cur = 0
        sents = []
        while True:
            try:
                start_p = abstract.index(Vocab.SENTENCE_START, cur)
                end_p = abstract.index(Vocab.SENTENCE_END, start_p + 1)
                cur = end_p + len(Vocab.SENTENCE_END)
                sents.append(abstract[start_p + len(Vocab.SENTENCE_START):end_p])
            except ValueError:  # no more sentences
                return sents

    @staticmethod
    def get_dec_inp_targ_seqs(sequence, max_len, start_id, stop_id):
        """Build the decoder input and target sequences from a reference summary.

        The input is ``[start_id] + sequence`` and the target is
        ``sequence + [stop_id]``. Both are truncated to *max_len*; when
        truncation happens, the target has no stop token.

        Args:
            sequence: list of ids (integers).
            max_len: integer.
            start_id: integer.
            stop_id: integer.

        Returns:
            ``(inp, target)`` of equal length ``<= max_len``.
        """
        inp = [start_id] + sequence[:]
        target = sequence[:]
        if len(inp) > max_len:  # truncate
            inp = inp[:max_len]
            target = target[:max_len]  # no end token
        else:  # no truncation
            target.append(stop_id)  # end token
        assert len(inp) == len(target)
        return inp, target
147 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import positional_encoding, scaled_dot_product_attention
3 |
class Embedding(tf.keras.layers.Layer):
    """Token embedding scaled by sqrt(d_model), plus a positional encoding.

    NOTE(review): the positional table is built with ``vocab_size`` as its
    length argument, so (assuming ``positional_encoding(position, d_model)``
    returns a ``(1, position, d_model)`` table) it is sized by the
    vocabulary, not by the maximum sequence length. TODO confirm in utils.
    """

    def __init__(self, vocab_size, d_model):
        super(Embedding, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_encoding = positional_encoding(vocab_size, d_model)

    def call(self, x):
        """Embed the id matrix *x* and add the first ``seq_len`` position rows."""
        seq_len = tf.shape(x)[1]
        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        return self.embedding(x) * scale + self.pos_encoding[:, :seq_len, :]
19 |
20 |
21 |
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention.

    Projects v/k/q with separate Dense layers, attends per head with
    ``scaled_dot_product_attention``, merges the heads and projects again.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # Every head gets an equal slice of the model dimension.
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, d_model) to (batch, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from *q* over *k*/*v*.

        Returns:
            ``(output, attention_weights)``: output is
            (batch, seq_len_q, d_model), and the weights are
            (batch, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        attended, attention_weights = scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)

        # Undo split_heads: back to (batch, seq_len_q, num_heads, depth), then flatten the heads.
        heads_last = tf.transpose(attended, perm=[0, 2, 1, 3])
        merged = tf.reshape(heads_last, (batch_size, -1, self.d_model))

        return self.dense(merged), attention_weights
69 |
70 |
def point_wise_feed_forward_network(d_model, dff):
    """Return a two-layer position-wise MLP: Dense(dff, relu) -> Dense(d_model)."""
    hidden = tf.keras.layers.Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
    projection = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    return tf.keras.Sequential([hidden, projection])
76 |
77 |
class EncoderLayer(tf.keras.layers.Layer):
    """Transformer encoder block.

    A self-attention sub-layer and a feed-forward sub-layer, each followed by
    dropout, a residual add and layer normalization.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        # Attribute names are kept as-is: object-based checkpoints key on them.
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        """Apply the block to *x*, whose last dimension is d_model; the output has the same shape."""
        attn_output, _ = self.mha(x, x, x, mask)
        out1 = self.layernorm1(x + self.dropout1(attn_output, training=training))

        ffn_output = self.dropout2(self.ffn(out1), training=training)
        return self.layernorm2(out1 + ffn_output)
102 |
103 |
class DecoderLayer(tf.keras.layers.Layer):
    """Transformer decoder block.

    It has three sub-layers, each followed by dropout, a residual add and
    layer normalization:
    1. masked self-attention;
    2. cross-attention over the encoder output;
    3. a feed-forward network.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        # Attribute names are kept as-is: object-based checkpoints key on them.
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        """Run the block.

        Returns:
            ``(output, self_attention_weights, cross_attention_weights)``.
        """
        # 1. Masked self-attention over the decoder input.
        self_attn, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        out1 = self.layernorm1(self.dropout1(self_attn, training=training) + x)

        # 2. Cross-attention: values/keys from the encoder, queries from out1.
        cross_attn, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)
        out2 = self.layernorm2(self.dropout2(cross_attn, training=training) + out1)

        # 3. Position-wise feed-forward network.
        ffn_output = self.dropout3(self.ffn(out2), training=training)
        out3 = self.layernorm3(ffn_output + out2)

        return out3, attn_weights_block1, attn_weights_block2
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import argparse
3 | from build_eval_test import build_model, train, test
4 | import os
5 |
def main():
    """Parse command-line hyperparameters and run the requested mode.

    Only ``--mode=train`` does any work. ``eval`` and ``test`` are accepted
    but only log that they are not implemented yet. Invalid or missing
    arguments exit through ``argparse`` with a usage message.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_enc_len", default=400, help="Encoder input max sequence length", type=int)
    parser.add_argument("--max_dec_len", default=100, help="Decoder input max sequence length", type=int)
    parser.add_argument("--batch_size", default=16, help="batch size", type=int)
    parser.add_argument("--vocab_size", default=50000, help="Vocabulary size", type=int)
    parser.add_argument("--num_layers", default=3, help="Model encoder and decoder number of layers", type=int)
    parser.add_argument("--model_depth", default=512, help="Model Embedding size", type=int)
    parser.add_argument("--num_heads", default=8, help="Multi Attention number of heads", type=int)
    parser.add_argument("--dff", default=2048, help="Dff", type=int)
    parser.add_argument("--seed", default=123, help="Seed", type=int)
    parser.add_argument("--log_step_count_steps", default=1, help="Log each N steps", type=int)
    parser.add_argument("--max_steps", default=230000, help="Max steps for training", type=int)
    parser.add_argument("--save_summary_steps", default=10000, help="Save summaries every N steps", type=int)
    parser.add_argument("--checkpoints_save_steps", default=10000, help="Save checkpoints every N steps", type=int)
    parser.add_argument("--mode", choices=["train", "eval", "test"], help="training, eval or test options")
    parser.add_argument("--model_dir", help="Model folder")
    parser.add_argument("--data_dir", help="Data Folder")
    parser.add_argument("--vocab_path", help="Vocab path")

    args = parser.parse_args()
    params = vars(args)
    print(params)

    if not params["mode"]:
        parser.error("mode is required. train, test or eval option")
    # (training, eval, test) flags for each mode.
    mode_flags = {
        "train": (True, False, False),
        "eval": (False, True, False),
        "test": (False, False, True),
    }
    params["training"], params["eval"], params["test"] = mode_flags[params["mode"]]

    # Test for a missing flag first: os.path.exists(None) raises TypeError
    # instead of returning False.
    if not params["data_dir"] or not os.path.exists(params["data_dir"]):
        parser.error("data_dir doesn't exist")
    if not params["vocab_path"] or not os.path.isfile(params["vocab_path"]):
        parser.error("vocab_path doesn't exist")
    # Without this, "{}/checkpoint".format(None) would write into "None/checkpoint".
    if not params["model_dir"]:
        parser.error("model_dir is required")

    if params["training"]:
        train(params)
    else:
        tf.compat.v1.logging.warning("Mode '%s' is not implemented yet; nothing was run.", params["mode"])


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/pointer_generator_run_on_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "pointer_generator_run_on_colab.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
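25 | "<a href=\"https://colab.research.google.com/github/steph1793/Pointer_Transformer_Generator/blob/master/pointer_generator_run_on_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"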
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "R0YKfS3YGTk9",
32 | "colab_type": "text"
33 | },
34 | "source": [
35 | "# Pointer Transformer Generator (TensorFlow 2.0.0-beta1)"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "metadata": {
41 | "id": "EkYCSmWOE_CF",
42 | "colab_type": "code",
43 | "colab": {}
44 | },
45 | "source": [
46 | "!pip install tensorflow-gpu==2.0.0-beta1"
47 | ],
48 | "execution_count": 0,
49 | "outputs": []
50 | },
51 | {
52 | "cell_type": "code",
53 | "metadata": {
54 | "id": "aR0InIHa6GI3",
55 | "colab_type": "code",
56 | "colab": {}
57 | },
58 | "source": [
59 | "from google.colab import drive\n",
60 | "drive.mount(\"/content/drive\")"
61 | ],
62 | "execution_count": 0,
63 | "outputs": []
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {
68 | "id": "5GG_DJjc_VIL",
69 | "colab_type": "text"
70 | },
71 | "source": [
72 | "\n",
73 | "**Git clone project**"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "metadata": {
79 | "id": "RYWr1D0f6j6K",
80 | "colab_type": "code",
81 | "colab": {}
82 | },
83 | "source": [
84 | "%%bash\n",
85 | "git clone https://github.com/steph1793/Pointer_Transformer_Generator "
86 | ],
87 | "execution_count": 0,
88 | "outputs": []
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {
93 | "id": "BR5vlKHz9Req",
94 | "colab_type": "text"
95 | },
96 | "source": [
97 | "You can modify model_dir in order to save your model in your Drive, and set data_dir and vocab_path to the actual locations of your data and vocabulary in your Drive.\n",
98 | "\n",
99 | "\n",
100 | "***PS: We assume your data is already on your Google Drive.***"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "metadata": {
106 | "id": "nox5S1LI8kN2",
107 | "colab_type": "code",
108 | "colab": {}
109 | },
110 | "source": [
111 | "!python Pointer_Transformer_Generator/main.py --max_enc_len=400 \\\n",
112 | "--max_dec_len=100 \\\n",
113 | "--batch_size=16 \\\n",
114 | "--vocab_size=50000 \\\n",
115 | "--num_layers=3 \\\n",
116 | "--model_depth=512 \\\n",
117 | "--num_heads=8 \\\n",
118 | "--dff=2048 \\\n",
119 | "--seed=123 \\\n",
120 | "--log_step_count_steps=1 \\\n",
121 | "--max_steps=230000 \\\n",
122 | "--mode=train \\\n",
123 | "--save_summary_steps=10000 \\\n",
124 | "--checkpoints_save_steps=10000 \\\n",
125 | "--model_dir=\"transformer_model_dir\" \\\n",
126 | "--data_dir=\"drive/My Drive/pointer_gen/cnn-dailymail/tfrecords_finished_files/chunked\" \\\n",
127 | "--vocab_path=\"drive/My Drive/pointer_gen/cnn-dailymail/tfrecords_finished_files/vocab\" \\"
128 | ],
129 | "execution_count": 0,
130 | "outputs": []
131 | }
132 | ]
133 | }
--------------------------------------------------------------------------------
/predict_helper.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import create_masks
3 |
def predict(features, params, model):
    """Greedily decode a batch of summaries, one token per step.

    Args:
        features: feature dict from ``batcher`` (reads ``enc_input``,
            ``extended_enc_input`` and ``max_oov_len``).
        params: reads ``max_dec_len``, ``training`` and ``vocab_size``.
        model: callable Transformer with the signature used in ``train_step``.

    Returns:
        ``(output, attention_weights)``. ``output`` holds the start id
        followed by ``max_dec_len`` predicted ids; these may be extended
        (copied-OOV) ids. ``attention_weights`` comes from the last step, or
        is None when ``max_dec_len`` is 0.
    """
    start_id = 2  # Vocab.START_DECODING
    unk_id = 0    # Vocab.UNKNOWN_TOKEN
    # Use the actual batch size instead of assuming params["batch_size"].
    batch_size = tf.shape(features["enc_input"])[0]

    output = tf.fill([batch_size, 1], start_id)
    # The decoder input can only contain ids the decoder embedding knows, so
    # copied OOV ids (>= vocab_size) are fed back as [UNK]. ``output`` keeps
    # the extended ids for the caller.
    dec_input = output
    attention_weights = None

    for _ in range(params["max_dec_len"]):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(features["enc_input"], dec_input)

        # predictions.shape == (batch_size, seq_len, extended_vocab_size)
        predictions, attention_weights = model(features["enc_input"], features["extended_enc_input"],
                                               features["max_oov_len"], dec_input, training=params["training"],
                                               enc_padding_mask=enc_padding_mask,
                                               look_ahead_mask=combined_mask,
                                               dec_padding_mask=dec_padding_mask)

        # Keep only the distribution for the last position.
        predictions = predictions[:, -1:, :]  # (batch_size, 1, extended_vocab_size)
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        output = tf.concat([output, predicted_id], axis=-1)
        in_vocab_id = tf.where(predicted_id < params["vocab_size"], predicted_id,
                               tf.fill(tf.shape(predicted_id), unk_id))
        dec_input = tf.concat([dec_input, in_vocab_id], axis=-1)

    return output, attention_weights
--------------------------------------------------------------------------------
/training_helper.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import create_masks
3 | import time
4 |
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up then inverse-square-root learning-rate schedule.

        lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self._d_model_config = d_model  # plain value, for get_config
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # rsqrt is only defined for floating-point tensors; cast so that an
        # integer step counter works as well.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments so the schedule (and optimizer) can be serialized."""
        return {"d_model": self._d_model_config, "warmup_steps": self.warmup_steps}
19 |
20 |
def loss_function(loss_object, real, pred, pad_id=1):
    """Return the mean per-token loss over the non-padding target positions.

    Args:
        loss_object: per-element loss, built with ``reduction='none'``.
        real: target ids.
        pred: model output, passed straight to *loss_object*.
        pad_id: id used to pad the targets. ``batch_generator`` pads
            ``target`` with 1, which is ``Vocab``'s ``[PAD]`` id.

    Returns:
        Scalar: the sum of the unmasked losses divided by the number of
        unmasked tokens, or 0 if every position is padding.
    """
    mask = tf.math.logical_not(tf.math.equal(real, pad_id))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    # Average over real tokens only; reduce_mean would also count the padded
    # positions in the denominator.
    return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))
29 |
30 |
def train_step(features, labels, params, model, optimizer, loss_object, train_loss_metric):
    """Run one optimization step on a batch and accumulate its loss in *train_loss_metric*."""
    dec_input = labels["dec_input"]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(features["enc_input"], dec_input)

    with tf.GradientTape() as tape:
        output, _ = model(
            features["enc_input"],
            features["extended_enc_input"],
            features["max_oov_len"],
            dec_input,
            training=params["training"],
            enc_padding_mask=enc_padding_mask,
            look_ahead_mask=combined_mask,
            dec_padding_mask=dec_padding_mask,
        )
        loss = loss_function(loss_object, labels["dec_target"], output)

    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    train_loss_metric(loss)
45 |
def train_model(model, batcher, params, ckpt, ckpt_manager):
    """Train ``model`` on every batch yielded by ``batcher``.

    Builds an Adam optimizer driven by CustomSchedule(params["model_depth"]),
    runs train_step on each batch (batch[0] = features, batch[1] = labels),
    prints the step duration and the running mean loss, and saves a checkpoint
    whenever ``ckpt.step`` is a multiple of params["checkpoints_save_steps"].
    A KeyboardInterrupt ends training after saving one last checkpoint.

    Args:
        model: the model passed to train_step.
        batcher: iterable of (features, labels) batches.
        params: dict with "model_depth", "checkpoints_save_steps" and the keys
            train_step reads.
        ckpt: checkpoint object exposing a ``step`` variable (read with int()
            and advanced with assign_add).
        ckpt_manager: manager whose save() writes the checkpoints.
    """
    lr_schedule = CustomSchedule(params["model_depth"])
    optimizer = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    # from_logits=False: the Transformer returns mixed probability distributions.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction='none')
    train_loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")

    try:
        for batch in batcher:
            features, labels = batch[0], batch[1]
            started = time.time()
            train_step(features, labels, params, model, optimizer, loss_object, train_loss_metric)
            elapsed = time.time() - started

            step = int(ckpt.step)
            print("step {}, time : {}, loss: {}".format(step, elapsed, train_loss_metric.result()))
            if step % params["checkpoints_save_steps"] == 0:
                ckpt_manager.save(checkpoint_number=step)
                print("Saved checkpoint for step {}".format(step))
            ckpt.step.assign_add(1)

    except KeyboardInterrupt:
        step = int(ckpt.step)
        ckpt_manager.save(step)
        print("Saved checkpoint for step {}".format(step))
--------------------------------------------------------------------------------
/transformer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from layers import Embedding, EncoderLayer, DecoderLayer
3 | from utils import _calc_final_dist
4 |
5 |
class Encoder(tf.keras.layers.Layer):
    """Stack of EncoderLayers applied to already-embedded source tokens.

    ``d_model`` is stored and ``input_vocab_size`` is accepted only for
    signature compatibility; the embedding happens outside this layer.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers

        self.enc_layers = [
            EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        """Apply dropout, then every encoder layer in order.

        Returns a tensor of shape (batch_size, input_seq_len, d_model).
        """
        hidden = self.dropout(x, training=training)
        for layer in self.enc_layers:
            hidden = layer(hidden, training, mask)
        return hidden
24 |
25 |
class Decoder(tf.keras.layers.Layer):
    """Stack of DecoderLayers plus the pointer-generator switch p_gen.

    Besides the decoder states, call() returns for every target position the
    generation probability

        p_gen = sigmoid(V(Wx(x) + Ws(s) + Wh(h*)))

    where x is the decoder input embedding, s the final decoder output and h*
    a context vector built from the LAST layer's encoder-decoder attention
    weights (cf. See et al., 2017, eq. 8).
    ``target_vocab_size`` is accepted only for signature compatibility.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Per-head width used to split enc_output into heads; the reshape in
        # call() only works when d_model is divisible by num_heads.
        self.depth = d_model // self.num_heads
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)
        # Scalar projections that feed the p_gen switch.
        self.Wh = tf.keras.layers.Dense(1)
        self.Ws = tf.keras.layers.Dense(1)
        self.Wx = tf.keras.layers.Dense(1)
        self.V = tf.keras.layers.Dense(1)

    def call(self, embed_x, enc_output, training, look_ahead_mask, padding_mask):
        """Run the decoder stack and compute p_gen.

        Args:
            embed_x: embedded decoder input, (batch_size, target_seq_len, d_model).
            enc_output: encoder output, (batch_size, input_seq_len, d_model).
            training: forwarded to dropout and to each DecoderLayer.
            look_ahead_mask: mask for the decoder self-attention.
            padding_mask: mask for the encoder-decoder attention.

        Returns:
            out: (batch_size, target_seq_len, d_model) decoder states.
            attention_weights: dict with 'decoder_layer{i}_block1' and
                'decoder_layer{i}_block2' entries for i = 1..num_layers.
            p_gens: (batch_size, target_seq_len, 1) generation probabilities.
        """
        attention_weights = {}
        out = self.dropout(embed_x, training=training)

        for i in range(self.num_layers):
            out, block1, block2 = self.dec_layers[i](out, enc_output, training,
                                                     look_ahead_mask, padding_mask)

            attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2

        # out.shape == (batch_size, target_seq_len, d_model)

        # Context vectors from the last layer's encoder-decoder attention
        # block2: (batch_size, num_heads, target_seq_len, input_seq_len).
        # Split enc_output into heads and take the attention-weighted sum over
        # input positions with one batched matmul. This equals broadcasting
        # the weights against every head vector and summing over input_seq_len,
        # but never materializes a 5-D (batch, heads, tgt, inp, depth) tensor.
        enc_out_shape = tf.shape(enc_output)
        enc_heads = tf.reshape(enc_output, (enc_out_shape[0], enc_out_shape[1], self.num_heads, self.depth))
        enc_heads = tf.transpose(enc_heads, [0, 2, 1, 3])  # (batch_size, num_heads, input_seq_len, depth)

        context = tf.matmul(block2, enc_heads)  # (batch_size, num_heads, target_seq_len, depth)
        context = tf.transpose(context, [0, 2, 1, 3])  # (batch_size, target_seq_len, num_heads, depth)
        context = tf.reshape(context, (tf.shape(context)[0], tf.shape(context)[1], self.d_model))  # (batch_size, target_seq_len, d_model)

        # P_gens computing
        a = self.Wx(embed_x)
        b = self.Ws(out)
        c = self.Wh(context)
        p_gens = tf.sigmoid(self.V(a + b + c))

        return out, attention_weights, p_gens
79 |
80 |
class Transformer(tf.keras.Model):
    """Transformer encoder-decoder with a pointer-generator output head.

    Source and target share one Embedding. The decoder's vocabulary
    distribution is mixed with a copy distribution over the source tokens
    (averaged last-layer encoder-decoder attention) by _calc_final_dist.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, batch_size, rate=0.1):
        super(Transformer, self).__init__()

        self.num_layers = num_layers
        self.vocab_size = vocab_size
        # Kept for backward compatibility; call() uses the runtime batch size.
        self.batch_size = batch_size
        self.model_depth = d_model
        self.num_heads = num_heads

        self.embedding = Embedding(vocab_size, d_model)
        self.encoder = Encoder(num_layers, d_model, num_heads, dff, vocab_size, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, vocab_size, rate)
        self.final_layer = tf.keras.layers.Dense(vocab_size)

    def call(self, inp, extended_inp, max_oov_len, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        """Compute extended-vocabulary output distributions.

        Args:
            inp: source token ids fed to the embedding.
            extended_inp: source ids in the extended vocabulary; used as scatter
                indices into the (vocab_size + max_oov_len)-wide distribution.
            max_oov_len: number of extra in-article OOV slots.
            tar: decoder input token ids.
            training: forwarded to encoder and decoder.
            enc_padding_mask, look_ahead_mask, dec_padding_mask: attention
                masks, presumably as produced by utils.create_masks.

        Returns:
            final_output: (batch_size, tar_seq_len, vocab_size + max_oov_len)
                probability distributions.
            attention_weights: the decoder's attention-weight dict.

        The batch dimension is read from ``extended_inp`` at call time, so
        batches whose size differs from the constructor's ``batch_size``
        (e.g. a final partial batch) are handled.
        """
        embed_x = self.embedding(inp)
        embed_dec = self.embedding(tar)

        enc_output = self.encoder(embed_x, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)

        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights, p_gens = self.decoder(embed_dec, enc_output, training, look_ahead_mask, dec_padding_mask)

        output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, vocab_size)
        output = tf.nn.softmax(output)  # (batch_size, tar_seq_len, vocab_size)

        # Copy distribution: last layer's encoder-decoder attention averaged over heads.
        attn_dists = attention_weights['decoder_layer{}_block2'.format(self.num_layers)]  # (batch_size, num_heads, tar_seq_len, inp_seq_len)
        attn_dists = tf.reduce_sum(attn_dists, axis=1) / self.num_heads  # (batch_size, tar_seq_len, inp_seq_len)

        batch_size = tf.shape(extended_inp)[0]
        final_dists = _calc_final_dist(extended_inp,
                                       tf.unstack(output, axis=1),
                                       tf.unstack(attn_dists, axis=1),
                                       tf.unstack(p_gens, axis=1),
                                       max_oov_len, self.vocab_size, batch_size)
        final_output = tf.stack(final_dists, axis=1)

        return final_output, attention_weights
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 |
def get_angles(pos, i, d_model):
    """Return sinusoidal-encoding angles ``pos / 10000 ** (2 * (i // 2) / d_model)``.

    Dimensions 2k and 2k+1 share one frequency. ``pos`` and ``i`` broadcast
    against each other (e.g. a column of positions and a row of dimensions).
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    angle_rates = 1 / np.power(10000, exponent)
    return pos * angle_rates
9 |
def positional_encoding(position, d_model):
    """Build a sinusoidal position table of shape (1, position, d_model), float32.

    Even columns hold sin(angle) and odd columns cos(angle), with the angles
    taken from get_angles; the leading axis has size 1.
    """
    positions = np.arange(position)[:, np.newaxis]  # (position, 1)
    dims = np.arange(d_model)[np.newaxis, :]        # (1, d_model)
    table = get_angles(positions, dims, d_model)    # (position, d_model)

    table[:, 0::2] = np.sin(table[:, 0::2])  # even indices: 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd indices: 2i+1

    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)
24 |
25 |
26 |
def create_padding_mask(seq, pad_id=1):
    """Build an attention mask that hides padding tokens.

    Args:
        seq: rank-2 tensor of token ids, presumably (batch_size, seq_len).
        pad_id: id of the padding token (default 1, the value this module
            has always used).

    Returns:
        float32 tensor of shape (batch_size, 1, 1, seq_len) holding 1.0 where
        ``seq == pad_id`` (position must be masked) and 0.0 elsewhere. The two
        inserted axes let it broadcast against attention logits.
    """
    seq = tf.cast(tf.math.equal(seq, pad_id), tf.float32)
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
33 |
34 |
35 |
def create_look_ahead_mask(size):
    """Causal mask for decoder self-attention.

    Returns a (size, size) float32 matrix with 1.0 strictly above the main
    diagonal (future positions, to be hidden) and 0.0 on and below it.
    """
    visible = tf.linalg.band_part(tf.ones((size, size)), -1, 0)  # lower triangle incl. diagonal
    return 1 - visible
39 |
40 |
def create_masks(inp, tar):
    """Build the three attention masks used by the Transformer.

    Returns:
        (enc_padding_mask, combined_mask, dec_padding_mask) where
        - enc_padding_mask hides padded source positions in encoder self-attention;
        - combined_mask is the element-wise max of the target padding mask and
          the causal look-ahead mask, for decoder self-attention;
        - dec_padding_mask hides padded source positions in the decoder's
          encoder-decoder attention. It equals enc_padding_mask, so the mask is
          computed once and returned in both slots.
    """
    source_padding = create_padding_mask(inp)

    causal = create_look_ahead_mask(tf.shape(tar)[1])
    target_padding = create_padding_mask(tar)
    combined_mask = tf.maximum(target_padding, causal)

    return source_padding, combined_mask, source_padding
57 |
58 |
def scaled_dot_product_attention(q, k, v, mask=None):
    """Calculate the attention weights.

    Computes softmax(q @ k^T / sqrt(depth) + mask * -1e9) @ v.

    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look ahead),
    but it must be broadcastable for addition.

    Args:
        q: query shape == (..., seq_len_q, depth)
        k: key shape == (..., seq_len_k, depth)
        v: value shape == (..., seq_len_v, depth_v)
        mask: optional float tensor with shape broadcastable to
            (..., seq_len_q, seq_len_k); 1.0 marks positions that must not be
            attended to. Defaults to None (no masking).

    Returns:
        output: (..., seq_len_q, depth_v)
        attention_weights: (..., seq_len_q, seq_len_k)
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Push masked logits towards -inf so softmax gives them ~0 weight.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights
94 |
def _calc_final_dist(_enc_batch_extend_vocab, vocab_dists, attn_dists, p_gens, batch_oov_len, vocab_size, batch_size):
    """Mix generation and copy distributions into extended-vocabulary distributions.

    For every decoder step t the result is

        [p_gen_t * P_vocab_t, zeros(batch_oov_len)]
            + scatter((1 - p_gen_t) * a_t onto the source tokens' extended ids)

    Args:
        _enc_batch_extend_vocab: (batch_size, attn_len) integer ids of the source
            tokens in the extended vocabulary (in-article OOVs occupy the slots
            after vocab_size).
        vocab_dists: list (one per decoder step) of (batch_size, vocab_size)
            vocabulary distributions.
        attn_dists: list (one per decoder step) of (batch_size, attn_len)
            attention distributions over the source tokens.
        p_gens: list (one per decoder step) of generation probabilities that
            broadcast against the distributions (e.g. (batch_size, 1)).
        batch_oov_len: number of extra OOV slots in the extended vocabulary.
        vocab_size: size of the base vocabulary.
        batch_size: batch dimension used for the zero padding and scatter shape.

    Returns:
        List (one per step, truncated to the shortest input list) of
        (batch_size, vocab_size + batch_oov_len) tensors. Entries for [PAD]
        decoder steps are meaningless and should be ignored.
    """
    extended_vsize = vocab_size + batch_oov_len
    oov_padding = tf.zeros((batch_size, batch_oov_len))

    # Pair every source position with its batch row so scatter_nd can drop
    # copy probability at [row, extended_token_id]. scatter_nd sums duplicate
    # indices, so a word occurring several times accumulates its attention.
    attn_len = tf.shape(_enc_batch_extend_vocab)[1]
    rows = tf.expand_dims(tf.range(0, limit=batch_size), 1)  # (batch_size, 1)
    rows = tf.tile(rows, [1, attn_len])  # (batch_size, attn_len)
    indices = tf.stack((rows, _enc_batch_extend_vocab), axis=2)  # (batch_size, attn_len, 2)
    target_shape = [batch_size, extended_vsize]

    final_dists = []
    for p_gen, vocab_dist, copy_dist in zip(p_gens, vocab_dists, attn_dists):
        generated = tf.concat([p_gen * vocab_dist, oov_padding], axis=1)
        copied = tf.scatter_nd(indices, (1 - p_gen) * copy_dist, target_shape)
        final_dists.append(generated + copied)

    return final_dists
--------------------------------------------------------------------------------