├── codes
│   ├── __init__.py
│   ├── params.py
│   ├── rouge_batch.py
│   ├── beam_search.py
│   ├── run_model.py
│   ├── NSE.py
│   ├── HierNSE.py
│   ├── utils.py
│   ├── model.py
│   └── model_hier.py
├── data
│   ├── __init__.py
│   └── preprocess.py
└── README.md
/codes/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/codes/params.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | dm_single_close_quote = u'\u2019' # unicode
3 | dm_double_close_quote = u'\u201d'
4 |
5 | # acceptable ways to end a sentence
6 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"]
7 |
8 | # Special tokens: sentence start/end, padding, unknown words, and start/stop of decoding.
9 | SENTENCE_START = '<s>'
10 | SENTENCE_END = '</s>'
11 |
12 | PAD_TOKEN = '[PAD]'
13 | UNKNOWN_TOKEN = '[UNK]'
14 | START_DECODING = '[START]'
15 | STOP_DECODING = '[STOP]'
16 | VOCAB_SIZE = 200000
17 |
--------------------------------------------------------------------------------
/data/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Code by Saptarashmi Bandyopadhyay
3 |
4 | import nltk
5 | import sys
6 | import os
7 | from nltk.stem import PorterStemmer
8 |
9 | porter = PorterStemmer()
10 |
11 | PathToDataset = sys.argv[1]
12 | PathToOutput = sys.argv[2]
13 |
14 | # Create factored (word | stem | POS) files for every split.
15 | for split in ["train", "val", "test"]:
16 |
17 | # Creating a new directory if it doesn't exist.
18 | if not os.path.exists(PathToOutput + split):
19 | os.makedirs(PathToOutput + split)
20 |
21 | print("Processing files in the {} set".format(split))
22 |
23 | files = os.listdir(PathToDataset + split)
24 | # total_num = len(files)
25 |
26 |     for num, filename in enumerate(files):
27 |         f = open(PathToDataset + split + '/' + filename, 'r')
28 |         of = open(PathToOutput + split + '/' + filename, 'w')
29 | 
30 |         for line in f:
31 |             if "@highlight" not in line:
32 |                 # Tag every token with its POS and stem, writing "word | stem | POS" triples.
33 |                 tagged = nltk.pos_tag(nltk.word_tokenize(line))
34 |                 for word, tag in tagged:
35 |                     stemmed = porter.stem(word)
36 |                     of.write(word + ' | ' + stemmed + ' | ' + tag + ' ')
37 |         f.close()
38 |         of.close()
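39 | 
40 | # Usage (assumed from the positional arguments above; both paths need a trailing '/'):
41 | #     python preprocess.py /path/to/tokenized/stories/ /path/to/factored/output/
42 | # The input directory is expected to contain train/, val/ and test/ sub-directories of story files.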
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Requirements
2 | - TensorFlow == 1.9.0
3 | - NLTK
4 | - joblib == 0.13.2
5 |
6 | # Files
7 | - run_model.py: Defines all the training flags, selects the data module and model, and starts training.
8 | - utils.py: Utilities module containing DataGenerators, Parallel model etc.
9 | - DataGenerator: Creates batches for vanilla NSE.
10 | - DataGeneratorHier: Creates batches for hierarchical NSE.
11 | - Memory: All the files below contain a variant of NSE.
12 | - NSE: This is a neural semantic encoder class.
13 | - HierNSE: This is the hier-NSE class.
14 | - models: Each file below contains the encoder, decoder, loss, and optimizer functions built around an NSE variant.
15 | - model.py: Model using vanilla NSE.
16 | - model_hier.py: Model using hier-NSE (use this).
17 |   - model_hier_sc.py: Self-Critic model (use this for RL). It back-propagates through the same multinomial samples that were drawn during the forward pass.
18 | - rouge: Rouge scripts used.
19 | - rouge_batch: A NumPy implementation (faster than existing ones). Used outside the TensorFlow graph.
20 | - Data
21 | - Create a folder named `data`
22 |   - Download the following splits into the `data` folder:
23 | [train](https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_train.txt),
24 | [val](https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_val.txt),
25 | [test](https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt)
26 | - Download the CNN and Daily-Mail tokenized data:
27 | [CNN](https://drive.google.com/file/d/0BzQ6rtO2VN95cmNuc2xwUS1wdEE/view?usp=sharing),
28 | [DM](https://drive.google.com/file/d/0BzQ6rtO2VN95bndCZDdpdXJDV1U/view?usp=sharing)
29 | - Download [GloVe](http://nlp.stanford.edu/data/glove.840B.300d.zip)
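    - Note: `utils.py` loads the word vectors with gensim's `KeyedVectors` and the default `--PathToGlove` is
      `../data/glove.840B.300d.w2v.txt`, so the downloaded GloVe file most likely needs converting to word2vec
      text format first. A sketch, assuming gensim 3.x and its bundled conversion script:

          python -m gensim.scripts.glove2word2vec --input glove.840B.300d.txt --output glove.840B.300d.w2v.txt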
30 |
31 | # Supervised model
32 | - Training
33 |
34 | python run_model.py --model="hier" --mode="train" --PathToCheckpoint=/path/to/checkpoint --PathToTB=/path/to/tensorboard/logs
35 |
36 | - Testing
37 |   - Check the epoch number of the best supervised model in TensorBoard; call it X.
38 |
39 | python run_model.py --model="hier" --mode="test" --PathToCheckpoint=/path/to/checkpoint/model_epochX --PathToResults=/path/to/results
40 |
41 | - Evaluation
42 |
43 | python run_model.py --model="hier" --mode="eval" --PathToResults=/path/to/results
44 |
45 | # Self-critical model
46 | - Training
47 |   - First copy the best supervised checkpoint into the RL checkpoint directory.
48 |
49 |       python run_model.py --model="rlhier" --mode="train" --restore_checkpoint=True --PathToCheckpoint=/path/to/checkpoint --PathToTB=/path/to/tensorboard/logs
50 |
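  - The objective optimized here is presumably the weighted mix suggested by the `eta*` flags in `run_model.py`
    (`eta1` cross-entropy, `eta2` RL, `eta3` entropy), with the RL term in the usual self-critical form, using
    the approximate ROUGE-L reward r(.) from `rouge_batch.py`:

        L_rl = -(r(y_sample) - r(y_greedy)) * sum_t log p(y_sample_t | y_sample_<t, x)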
51 | - Testing
52 |   - Check the epoch number of the best self-critical model in TensorBoard; call it X.
53 |
54 |       python run_model.py --model="rlhier" --mode="test" --restore_checkpoint=True --PathToCheckpoint=/path/to/checkpoint/model_epochX --PathToResults=/path/to/results
55 |
56 | - Evaluation
57 |
58 | python run_model.py --model="rlhier" --mode="eval" --PathToResults=/path/to/results
59 |
--------------------------------------------------------------------------------
/codes/rouge_batch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def rouge_l_fscore_batch(hypothesis, references, end_id, lens, all_steps=False):
5 | """
6 | ROUGE scores computation between labels and predictions.
7 | This is an approximate ROUGE scoring method since we do not glue word pieces
8 | or decode the ids and tokenize the output.
9 | :param hypothesis: (predictions) tensor, model predictions (batch_size, <=max_dec_steps)
10 | :param references: (labels) tensor, gold output. (batch_size, max_dec_steps)
11 | :param end_id: End of sequence ID.
12 | :param lens: Lengths of the sequences excluding the PAD tokens, Shape: (batch_size).
13 | :param all_steps: Whether to calculate rewards for all time-steps.
14 | :return: rouge_l_fscore: approx rouge-l f1 score, Shape: (batch_size).
15 | """
16 |
17 | if all_steps:
18 | batch_fscore = rouge_l_sentence_level_all_batch(hypothesis, references, end_id, lens) # Shape: B x T.
19 | else:
20 |         batch_fscore = rouge_l_sentence_level_final_batch(hypothesis, references, end_id, lens)  # Shape: (B,).
21 |
22 | return batch_fscore
23 |
24 |
25 | def infer_length(seq, end_id):
26 | """
27 | This function is used to calculate the length of given sequence based on the end ID.
28 | :param seq: Input sequence, Shape: B x T.
29 | :param end_id: End of sequence ID.
30 | :return:
31 | """
32 | batch_size = seq.shape[0]
33 | is_end = np.equal(seq, end_id).astype(np.int32) # Shape: B x T.
34 |
35 | # Avoiding the zero length case.
36 | front_zeros = np.zeros([batch_size, 1], dtype=np.int32) # Shape: B x 1.
37 | is_end = np.concatenate([front_zeros, is_end], axis=1) # Shape: B x (T + 1).
38 | is_end = is_end[:, : -1] # Shape: B x T.
39 |
40 | count_end = np.cumsum(is_end, axis=1)
41 | lengths = np.sum(np.equal(count_end, 0).astype(np.int32), axis=1)
42 |
43 | return lengths
44 |
45 |
46 | def rouge_l_sentence_level_final_batch(eval_sentences, ref_sentences, end_id=None, lens=None):
47 | """
48 | Computes ROUGE-L (sentence level) of two collections of sentences.
49 | Source: https://www.microsoft.com/en-us/research/publication/
50 | rouge-a-package-for-automatic-evaluation-of-summaries/
51 | Calculated according to:
52 | R_lcs = LCS(X,Y)/m
53 | P_lcs = LCS(X,Y)/n
54 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
55 | where:
56 | X = reference summary
57 | Y = Candidate summary
58 | m = length of reference summary
59 | n = length of candidate summary
60 |     :param eval_sentences: The sentences that have been picked by the summarizer, Shape: (batch_size, n)
61 |     :param ref_sentences: The sentences from the reference set, Shape: (batch_size, m)
62 | :param end_id: End of sentence ID.
63 | :param lens: Lengths of the sequences excluding PAD tokens.
64 | :return: F_lcs for all sentences in the batch, Shape: (batch_size,)
65 | """
66 | if lens is not None:
67 | n, m = lens, lens
68 | else:
69 | n, m = infer_length(eval_sentences, end_id), infer_length(ref_sentences, end_id)
70 | lcs = _len_lcs_batch(eval_sentences, ref_sentences, n, m)
71 | return np.array(_f_lcs_batch(lcs, n, m)).astype(np.float32)
72 |
73 |
74 | def rouge_l_sentence_level_all_batch(eval_sentences, ref_sentences, end_id, lens):
75 | """
76 | Computes ROUGE-L (sentence level) of two collections of sentences.
77 | Source: https://www.microsoft.com/en-us/research/publication/
78 | rouge-a-package-for-automatic-evaluation-of-summaries/
79 | Calculated according to:
80 | R_lcs = LCS(X,Y)/m
81 | P_lcs = LCS(X,Y)/n
82 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
83 | where:
84 | X = reference summary
85 | Y = Candidate summary
86 | m = length of reference summary
87 | n = length of candidate summary
88 | :param eval_sentences: The sentences that have been picked by the summarizer, Shape: (batch_size, T).
89 | :param ref_sentences: The sentences from the reference set, Shape: (batch_size, T).
90 | :param end_id: End of sentence ID.
91 | :param lens: length of the sequences, Shape: (batch_size).
92 | :return: F_lcs: Shape: B x T.
93 | """
94 | if lens is not None:
95 | m = lens
96 | else:
97 | m = infer_length(ref_sentences, end_id)
98 |
99 | batch_size, steps = eval_sentences.shape
100 |
101 |     n = np.tile(np.arange(1, steps + 1), batch_size)  # Candidate prefix lengths, Shape: (B*T,).
102 |     lcs = _len_lcs_batch(eval_sentences, ref_sentences, n, m, True)  # LCS of every prefix, Shape: B x T.
103 |     lcs = np.reshape(lcs, [-1])  # Shape: (B*T,).
104 | 
105 |     # Reference length per flattened element: np.repeat matches the row-major reshape above
106 |     # (np.tile(m, steps) would misalign examples); m must stay per-example for _len_lcs_batch.
107 |     f1_scores_all_steps = _f_lcs_batch(lcs, n, np.repeat(m, steps))  # Shape: (B*T,).
108 |     f1_scores = np.reshape(f1_scores_all_steps, (batch_size, steps))  # Shape: B x T.
109 |
110 | return f1_scores
111 |
112 |
113 | def _len_lcs_batch(x, y, n, m, all_steps=False):
114 | """
115 |     Returns the length of the Longest Common Sub-sequence between each pair of sequences in the batch.
116 | :param x: sequence of words, Shape: (batch_size, n).
117 | :param y: sequence of words, Shape: (batch_size, m).
118 | :param n: Lengths of the sequences in X, Shape: (batch_size,).
119 | :param m: Lengths of the sequences in Y, Shape: (batch_size,).
120 | :param all_steps: Whether to output LCS of all time-steps.
121 | :return: Lengths of LCS between a batch of x and y, Shape: (batch_size,) / (batch_size, T)
122 | """
123 | table = _lcs_batch(x, y) # Shape: batch_size x len x len.
124 | len_lcs = []
125 | for i in range(x.shape[0]):
126 | if all_steps:
127 | len_lcs.append(table[i, 1:, m[i]])
128 | else:
129 | len_lcs.append(table[i, n[i], m[i]])
130 |
131 | return np.array(len_lcs)
132 |
133 |
134 | def _lcs_batch(x, y):
135 | """
136 |     Computes the LCS dynamic-programming table for a batch of sequence pairs.
137 |     :param x: collection of words, (batch_size, n).
138 |     :param y: collection of words, (batch_size, m).
139 |     :return: DP table of LCS lengths, Shape: (batch_size, n + 1, m + 1).
140 | """
141 | batch_size, n = x.shape
142 | m = y.shape[1]
143 |
144 | table = np.ndarray(shape=[batch_size, n + 1, m + 1], dtype=np.int32)
145 | for i in range(n + 1):
146 | for j in range(m + 1):
147 | if i == 0 or j == 0:
148 | table[:, i, j] = 0
149 | else:
150 | true_indcs = np.argwhere(x[:, i - 1] == y[:, j - 1])
151 | false_indcs = np.argwhere(x[:, i - 1] != y[:, j - 1])
152 | table[true_indcs, i, j] = table[true_indcs, i - 1, j - 1] + 1
153 | table[false_indcs, i, j] = np.maximum(table[false_indcs, i - 1, j], table[false_indcs, i, j - 1])
154 |
155 | return table
156 |
157 |
158 | def _f_lcs_batch(llcs, n, m):
159 | """
160 | Computes the LCS-based F-measure score.
161 | :param llcs: lengths of LCS, Shape: (batch_size,)
162 | :param n: number of words in candidate summary, Shape: (batch_size,)
163 | :param m: number of words in reference summary, Shape: (batch_size,)
164 | :return:
165 | """
166 | r_lcs = llcs / m
167 | p_lcs = llcs / n
168 | beta = p_lcs / (r_lcs + 1e-12)
169 | num = (1 + (beta**2)) * r_lcs * p_lcs
170 | denom = r_lcs + ((beta**2) * p_lcs)
171 | f_lcs = num / (denom + 1e-12)
172 |
173 | return f_lcs
174 |
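175 | 
176 | # Minimal usage sketch with toy, hypothetical id sequences (id 0 plays the role of the stop/pad id here):
177 | if __name__ == '__main__':
178 |     hyp = np.array([[3, 5, 7, 0, 0], [4, 9, 2, 6, 0]])  # Model predictions, Shape: B x T.
179 |     ref = np.array([[3, 5, 8, 7, 0], [4, 2, 6, 0, 0]])  # Gold summaries, Shape: B x T.
180 |     print(rouge_l_fscore_batch(hyp, ref, end_id=0, lens=None))  # Final rewards, Shape: (B,).
181 |     print(rouge_l_fscore_batch(hyp, ref, end_id=0, lens=None, all_steps=True))  # Per-step rewards, Shape: B x T.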
--------------------------------------------------------------------------------
/codes/beam_search.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import params
4 |
5 | FLAGS = tf.app.flags.FLAGS
6 |
7 |
8 | class Hypothesis(object):
9 | """
10 | This is a class that represents a hypothesis.
11 | """
12 |
13 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens):
14 | """
15 | :param tokens: Tokens in this hypothesis.
16 | :param log_probs: log probabilities of each token in the hypothesis.
17 | :param state: Internal state of decoder after decoding the last token.
18 | :param attn_dists: Attention distributions of each token.
19 | :param p_gens: Generation probability of each token.
20 | """
21 | self.tokens = tokens
22 | self.log_probs = log_probs
23 | self.state = state
24 | self.attn_dists = attn_dists
25 | self.p_gens = p_gens
26 |
27 | def extend(self, token, log_prob, state=None, attn_dist=None, p_gen=None):
28 | """
29 | This function extends the hypothesis with the given new token.
30 | :param token: New decoded token.
31 |         :param log_prob: log probability of the token.
32 | :param state: Internal state of decoder after decoding the token.
33 | :param attn_dist: Attention distribution of the token.
34 | :param p_gen: Generation probability of each token.
35 | :return: returns the extended hypothesis
36 | """
37 | if state is None:
38 | state = self.state
39 |
40 | return Hypothesis(
41 | tokens=self.tokens + [token],
42 |             log_probs=self.log_probs + [log_prob],
43 | state=state,
44 | attn_dists=self.attn_dists + [attn_dist],
45 | p_gens=self.p_gens + [p_gen]
46 | )
47 |
48 | def latest_token(self):
49 | """
50 | :return: This function returns the last token in the hypothesis.
51 | """
52 | return self.tokens[-1]
53 |
54 | def avg_log_prob(self):
55 | """
56 |         :return: The average log probability of the hypothesis; without averaging,
57 |         longer sequences would always score lower.
58 | """
59 | return sum(self.log_probs) / len(self.tokens)
60 |
61 |
62 | def run_beam_search(inputs, run_encoder, decode_one_step, vocab):
63 | """
64 | This function performs the beam search decoding for one example.
65 | :param inputs: Inputs to the graph for running encoder and decoder.
66 | [enc_inp, enc_padding_mask, enc_inp_ext_vocab, max_oov_size] all with shapes Bm x T_in.
67 | :param run_encoder:
68 | :param decode_one_step:
69 | :param vocab:
70 | :return:
71 | """
72 | # Run the encoder and fix the resulted final state for decoding process.
73 | init_states = run_encoder(inputs[0], inputs[1])
74 |
75 | # Initialize beam size number of hypothesis.
76 | hyps = [Hypothesis(
77 | tokens=[vocab.word2id(params.START_DECODING)],
78 | log_probs=[0.0],
79 | state=init_states,
80 | attn_dists=[],
81 | p_gens=[]
82 | ) for _ in range(FLAGS.beam_size)]
83 |
84 |     results = []  # This will contain finished hypotheses (those that have emitted the [STOP_DECODING] token).
85 |
86 | for step in range(FLAGS.bs_dec_steps):
87 |
88 | # Stop decoding if beam_size number of complete hypothesis are emitted.
89 | if len(results) == FLAGS.beam_size:
90 | break
91 |
92 | prev_tokens = [h.latest_token() for h in hyps] # Tokens of all hypothesis from previous time step.
93 | # Replacing the OOV tokens with UNK tokens to perform the next decoding step.
94 | prev_tokens = [t if t in range(vocab.size()) else vocab.word2id(params.UNKNOWN_TOKEN) for t in prev_tokens]
95 | prev_states = [h.state for h in hyps] # Internal states of decoder.
96 |
97 | # Running one step of decoder.
98 | step_inputs = [prev_tokens, inputs[1]]
99 | if FLAGS.use_pgen:
100 | step_inputs += [inputs[2], inputs[3]]
101 |
102 | outputs = decode_one_step(step_inputs, prev_states)
103 | topk_ids, topk_log_probs, curr_states = outputs[: 3]
104 |
105 | p_gens = FLAGS.beam_size * [None]
106 | p_attns = FLAGS.beam_size * [None]
107 | if FLAGS.use_pgen:
108 | p_gens, p_attns = outputs[3: 5]
109 |
110 | # Extend each hypothesis with newly decoded predictions.
111 | all_hyps = []
112 | num_org_hyps = 1 if step == 0 else len(hyps) # In the first step, there is only 1 distinct hypothesis.
113 | for i in range(num_org_hyps):
114 | hyp, curr_state, p_gen, p_attn = hyps[i], curr_states[i], p_gens[i], p_attns[i]
115 | for j in range(2 * FLAGS.beam_size):
116 | new_hyp = hyp.extend(token=topk_ids[i][j],
117 | log_prob=topk_log_probs[i][j],
118 | state=curr_state,
119 | attn_dist=p_attn,
120 | p_gen=p_gen)
121 | all_hyps.append(new_hyp)
122 |
123 |         hyps = []  # Top Bm hypotheses.
124 |         for h in sort_hyps(all_hyps):
125 |             if h.latest_token() == vocab.word2id(params.STOP_DECODING):  # If the hypothesis ended.
126 | # Collect this sequence only if it is long enough.
127 | if step >= FLAGS.min_dec_len:
128 | results.append(h)
129 | else: # Use this hypothesis for next decoding step if not ended.
130 | hyps.append(h)
131 |
132 | # Stop if beam size is reached.
133 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
134 | break
135 |
136 |     # If no complete hypotheses were collected, add all current hypotheses to the results.
137 | if len(results) == 0:
138 | results = hyps
139 |
140 | results = sort_hyps(results) # Return best hypothesis in the final beam.
141 |
142 | return results[0]
143 |
144 |
145 | def sort_hyps(hyps):
146 | """
147 |     This function sorts the given hypotheses by their average log probability.
148 | :param hyps: Input hypotheses.
149 | :return:
150 | """
151 | return sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True)
152 |
153 |
154 | def run_beam_search_hier(inputs, run_encoder, decode_one_step, vocab):
155 | """
156 | This function performs the beam search decoding for one example.
157 | :param inputs: Inputs to the graph for running encoder and decoder.
158 | [enc_inp, enc_pad_mask, enc_doc_mask, enc_inp_ext_vocab, max_oov_size] all with shapes Bm x T_in.
159 | :param run_encoder:
160 | :param decode_one_step:
161 | :param vocab:
162 | :return:
163 | """
164 | # Run the encoder and fix the resulted final state for decoding process.
165 | init_states = run_encoder(inputs[: 3])
166 |
167 | # Initialize beam size number of hypothesis.
168 | hyps = [Hypothesis(
169 | tokens=[vocab.word2id(params.START_DECODING)],
170 | log_probs=[0.0],
171 | state=init_states,
172 | attn_dists=[],
173 | p_gens=[]
174 | ) for _ in range(FLAGS.beam_size)]
175 |
176 |     results = []  # This will contain finished hypotheses (those that have emitted the [STOP_DECODING] token).
177 |
178 | for step in range(FLAGS.bs_dec_steps):
179 |
180 | # Stop decoding if beam_size number of complete hypothesis are emitted.
181 | if len(results) == FLAGS.beam_size:
182 | break
183 |
184 | prev_tokens = [h.latest_token() for h in hyps] # Tokens of all hypothesis from previous time step.
185 | # Replacing the OOV tokens with UNK tokens to perform the next decoding step.
186 | prev_tokens = [t if t in range(vocab.size()) else vocab.word2id(params.UNKNOWN_TOKEN) for t in prev_tokens]
187 | prev_states = [h.state for h in hyps] # Internal states of decoder.
188 |
189 | # Preparing inputs for the decoder.
190 | # inputs[1] = np.reshape(inputs[1], [-1, FLAGS.max_enc_sent * FLAGS.max_enc_steps_per_sent]) # B x T_enc.
191 | # Running one step of decoder.
192 | step_inputs = [prev_tokens] + inputs[1: 3]
193 | if FLAGS.use_pgen:
194 | step_inputs += [inputs[3], inputs[4]]
195 |
196 | outputs = decode_one_step(step_inputs, prev_states)
197 | topk_ids, topk_log_probs, curr_states = outputs[: 3]
198 |
199 | p_gens = FLAGS.beam_size * [None]
200 | p_attns = FLAGS.beam_size * [None]
201 | if FLAGS.use_pgen:
202 | p_gens, p_attns = outputs[3: 5]
203 |
204 | # Extend each hypothesis with newly decoded predictions.
205 | all_hyps = []
206 | num_org_hyps = 1 if step == 0 else len(hyps) # In the first step, there is only 1 distinct hypothesis.
207 | for i in range(num_org_hyps):
208 | hyp, curr_state, p_gen, p_attn = hyps[i], curr_states[i], p_gens[i], p_attns[i]
209 | for j in range(2 * FLAGS.beam_size):
210 | new_hyp = hyp.extend(token=topk_ids[i][j],
211 | log_prob=topk_log_probs[i][j],
212 | state=curr_state,
213 | attn_dist=p_attn,
214 | p_gen=p_gen)
215 | all_hyps.append(new_hyp)
216 |
217 |         hyps = []  # Top Bm hypotheses.
218 |         for h in sort_hyps(all_hyps):
219 |             if h.latest_token() == vocab.word2id(params.STOP_DECODING):  # If the hypothesis ended.
220 | # Collect this sequence only if it is long enough.
221 | if step >= FLAGS.min_dec_len:
222 | results.append(h)
223 | else: # Use this hypothesis for next decoding step if not ended.
224 | hyps.append(h)
225 |
226 | # Stop if beam size is reached.
227 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
228 | break
229 |
230 |     # If no complete hypotheses were collected, add all current hypotheses to the results.
231 | if len(results) == 0:
232 | results = hyps
233 |
234 | results = sort_hyps(results) # Return best hypothesis in the final beam.
235 |
236 | return results[0]
237 |
--------------------------------------------------------------------------------
/codes/run_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Latest Command.
3 | # python run_model.py --sample=True --restore_type="all" --enc_steps=24 --dec_steps=12 --max_enc_sent=6
4 | # --max_enc_steps_per_sent=4 --max_dec_sent=4 --max_dec_steps_per_sent=3 --prob_dist="softmax" --rnn="LSTM"
5 | # --skip_rnn=False --use_comp_lstm=True --use_pos_emb=False --use_pgen=True --use_enc_coverage=True
6 | # --use_dec_coverage=False --use_reweight=False --vocab_size=20 --num_epochs=10 --val_every=1 --model="hier2"
7 | # --attn_type="ffn" --scaled=True --num_heads=1 --num_high=4 --num_high_heads=5 --num_high_iters=5 --mode="test"
8 | # --batch_size=3 --val_batch_size=3 --PathToCheckpoint="./my_nse_net/model_epoch10"
9 |
10 | __author__ = "Rajeev Bhatt Ambati"
11 | import tensorflow as tf
12 | from utils import Vocab, DataGenerator, DataGeneratorHier, eval_model
13 | from model import SummarizationModel
14 | from model_hier import SummarizationModelHier
15 | from model_hier_sc import SummarizationModelHierSC
16 | import os
17 | import random
18 | import numpy as np
19 | import time
20 |
21 | FLAGS = tf.app.flags.FLAGS
22 |
23 | # Parameters obtained from the input arguments.
24 |
25 | # Paths.
26 | tf.app.flags.DEFINE_string('PathToDataset', '../data/', 'Path to the datasets.')
27 | tf.app.flags.DEFINE_string('PathToGlove', '../data/glove.840B.300d.w2v.txt',
28 | 'Path to the pre-trained GloVe Word vectors.')
29 | tf.app.flags.DEFINE_string('PathToVocab', '../data/vocab.txt', 'Path to a vocabulary file if already stored. '
30 | 'Otherwise, a new vocabulary file will be stored here.')
31 | tf.app.flags.DEFINE_string('PathToLookups', '../data/lookups.pkl', 'Path to the lookup tables .pkl file if already'
32 | 'stored. Otherwise, a new vocabulary file will be'
33 | 'stored here.')
34 | tf.app.flags.DEFINE_string('PathToResults', '../results/', 'Path to the test results.')
35 | tf.app.flags.DEFINE_string('PathToCheckpoint', './my_nse_net/hier_v14/model_epoch10',
36 | 'Trained model will be stored here.')
37 | tf.app.flags.DEFINE_boolean('sample', False, 'Whether to run on a small sample for debugging instead of a full training run.')
38 | tf.app.flags.DEFINE_bool('permutate', False, 'Whether to permutate or truncate sequences.')
39 | tf.app.flags.DEFINE_boolean('chunk', True, 'Whether to use chunks in place of sentences to use maximum effective '
40 | 'sequence lengths possible')
41 | tf.app.flags.DEFINE_string('PathToTB', 'log/', 'Tensorboard visualization directory.')
42 | tf.app.flags.DEFINE_boolean('restore_checkpoint', True, 'Boolean describing whether training has to be restored from '
43 | 'a checkpoint or start fresh.')
44 | tf.app.flags.DEFINE_string('restore_type', "all", 'String describing whether coverage/momentum parameters have to be'
45 |                                                    ' restored or initialized.')
46 |
47 | # Plain model inputs.
48 | tf.app.flags.DEFINE_integer('enc_steps', 400, 'No. of time steps in the encoder.')
49 | tf.app.flags.DEFINE_integer('dec_steps', 100, 'No. of time steps in the decoder.')
50 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'Max. no of time steps in the decoder during decoding.')
51 | tf.app.flags.DEFINE_integer('min_dec_steps', 35, 'Minimum no. of tokens in a complete hypothesis while decoding.')
52 |
53 | # Hier model inputs.
54 | tf.app.flags.DEFINE_integer('max_enc_sent', 20, 'Max. no. of sentences in the encoder.')
55 | tf.app.flags.DEFINE_integer('max_enc_steps_per_sent', 20, 'Max. no. of tokens per sentence in the encoder.')
56 |
57 | # Common flags.
58 | tf.app.flags.DEFINE_integer('num_layers', 1, "No. of layers in each LSTM.")
59 | tf.app.flags.DEFINE_boolean('use_comp_lstm', True, 'Whether to use an LSTM for compose function.')
60 | tf.app.flags.DEFINE_boolean('use_pgen', True, 'Flag whether pointer mechanism should be used.')
61 | tf.app.flags.DEFINE_boolean('use_pretrained', True, 'Flag whether pre-trained word vectors have to be used.')
62 |
63 | # Common sizes.
64 | tf.app.flags.DEFINE_integer('batch_size', 60, 'No. of examples in a batch of training data.')
65 | tf.app.flags.DEFINE_integer('val_batch_size', 60, 'No. of examples in a batch of validation data.')
66 | tf.app.flags.DEFINE_integer('beam_size', 5, 'Beam size for beam search decoding.')
67 | tf.app.flags.DEFINE_integer('vocab_size', 50000, 'Size of the vocabulary.')
68 | tf.app.flags.DEFINE_integer('dim', 300, 'Dimension of the word embedding, it should be the dimension of the '
69 | 'pre-trained word vectors used.')
70 |
71 | # Common optimizer flags.
72 | tf.app.flags.DEFINE_boolean('use_entropy', True, 'Whether to use entropy of sampling in loss.')
73 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'Maximum gradient norm when gradient clipping.')
74 | tf.app.flags.DEFINE_float('lr', 0.001, 'Learning rate.')
75 |
76 | # Common training flags.
77 | tf.app.flags.DEFINE_boolean('rouge_summary', True, 'A flag whether ROUGE has to be included in the summary.')
78 | tf.app.flags.DEFINE_integer('num_epochs', 20, 'No. of epochs to train the model.')
79 | tf.app.flags.DEFINE_integer('summary_every', 1, 'Write training summaries every few iterations.')
80 | tf.app.flags.DEFINE_integer('val_every', 4, 'No. of training epochs after which model should be validated.')
81 | tf.app.flags.DEFINE_string('mode', 'test', 'train/test mode.')
82 | tf.app.flags.DEFINE_string('model', 'hier', 'Which model to train: plain/hier/rlhier.')
83 | tf.app.flags.DEFINE_list('GPUs', [0], 'GPU ids to be used.')
84 | tf.app.flags.DEFINE_integer('num_pools', 5, 'No. of pools per GPU.')
85 |
86 | # Beam search decoding flags.
87 | tf.app.flags.DEFINE_integer('bs_enc_steps', 400, 'No. of time steps in the encoder.')
88 | tf.app.flags.DEFINE_integer('bs_dec_steps', 100, 'No. of time steps in the decoder.')
89 |
90 | # Hier model inputs.
91 | tf.app.flags.DEFINE_integer('bs_enc_sent', 20, 'Max. no. of sentences in the encoder.')
92 | tf.app.flags.DEFINE_integer('bs_enc_steps_per_sent', 20, 'Max. no. of tokens per sentence in the encoder.')
93 |
94 | # Self critic policy gradients model.
95 | tf.app.flags.DEFINE_boolean('use_self_critic', False, 'Flag whether to use self critical model.')
96 | tf.app.flags.DEFINE_boolean('teacher_forcing', False, 'Flag whether to use teacher-forcing in greedy mode.')
97 | tf.app.flags.DEFINE_integer('num_samples', 1, 'No. of samples')
98 | tf.app.flags.DEFINE_boolean('use_discounted_rewards', False, 'Flag whether discounted rewards have to be used.')
99 | tf.app.flags.DEFINE_boolean('use_intermediate_rewards', False, 'Flag whether intermediate rewards have to be used.')
100 | tf.app.flags.DEFINE_float('gamma', 0.99, 'Discount Factor')
101 | tf.app.flags.DEFINE_float('eta', 2.5E-5, 'RL/MLE scaling factor.')
102 | tf.app.flags.DEFINE_float('eta1', 0.0, 'Cross-entropy weight.')
103 | tf.app.flags.DEFINE_float('eta2', 0.0, 'RL loss weight.')
104 | tf.app.flags.DEFINE_float('eta3', 1E-4, 'Entropy weight.')
105 |
106 |
107 | def main(args):
108 | main_start = time.time()
109 |
110 | tf.set_random_seed(2019)
111 | random.seed(2019)
112 | np.random.seed(2019)
113 |
114 | if len(args) != 1:
115 | raise Exception('Problem with flags: %s' % args)
116 |
117 | # Correcting a few flags for test/eval mode.
118 | if FLAGS.mode != 'train':
119 | FLAGS.batch_size = FLAGS.beam_size
120 | FLAGS.bs_dec_steps = FLAGS.dec_steps
121 |
122 | if FLAGS.model.lower() != "tx":
123 | FLAGS.dec_steps = 1
124 |
125 | assert FLAGS.mode == 'train' or FLAGS.batch_size == FLAGS.beam_size, \
126 | "In test mode, batch size should be equal to beam size."
127 |
128 | assert FLAGS.mode == 'train' or FLAGS.dec_steps == 1 or FLAGS.model.lower() == "tx", \
129 | "In test mode, no. of decoder steps should be one."
130 |
131 | os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
132 | os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(str(gpu_id) for gpu_id in FLAGS.GPUs)
133 |
134 | if not os.path.exists(FLAGS.PathToCheckpoint):
135 | os.makedirs(FLAGS.PathToCheckpoint)
136 |
137 | if FLAGS.mode == "test" and not os.path.exists(FLAGS.PathToResults):
138 | os.makedirs(FLAGS.PathToResults)
139 | os.makedirs(FLAGS.PathToResults + 'predictions')
140 | os.makedirs(FLAGS.PathToResults + 'groundtruths')
141 |
142 | if FLAGS.mode == 'eval':
143 | eval_model(FLAGS.PathToResults)
144 | else:
145 | start = time.time()
146 | vocab = Vocab(max_vocab_size=FLAGS.vocab_size, emb_dim=FLAGS.dim, dataset_path=FLAGS.PathToDataset,
147 | glove_path=FLAGS.PathToGlove, vocab_path=FLAGS.PathToVocab, lookup_path=FLAGS.PathToLookups)
148 |
149 | if FLAGS.model.lower() == "plain":
150 | print("Setting up the plain model.\n")
151 | data = DataGenerator(path_to_dataset=FLAGS.PathToDataset, max_inp_seq_len=FLAGS.enc_steps,
152 | max_out_seq_len=FLAGS.dec_steps, vocab=vocab,
153 | use_pgen=FLAGS.use_pgen, use_sample=FLAGS.sample)
154 | summarizer = SummarizationModel(vocab, data)
155 |
156 | elif FLAGS.model.lower() == "hier":
157 | print("Setting up the hier model.\n")
158 | data = DataGeneratorHier(path_to_dataset=FLAGS.PathToDataset, max_inp_sent=FLAGS.max_enc_sent,
159 | max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent,
160 | max_out_tok=FLAGS.dec_steps, vocab=vocab,
161 | use_pgen=FLAGS.use_pgen, use_sample=FLAGS.sample)
162 | summarizer = SummarizationModelHier(vocab, data)
163 |
164 | elif FLAGS.model.lower() == "rlhier":
165 | print("Setting up the Hier RL model.\n")
166 | data = DataGeneratorHier(path_to_dataset=FLAGS.PathToDataset, max_inp_sent=FLAGS.max_enc_sent,
167 | max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent,
168 | max_out_tok=FLAGS.dec_steps, vocab=vocab,
169 | use_pgen=FLAGS.use_pgen, use_sample=FLAGS.sample)
170 | summarizer = SummarizationModelHierSC(vocab, data)
171 |
172 | else:
173 |             raise ValueError("model flag should be one of plain/hier/rlhier!! \n")
174 |
175 | end = time.time()
176 | print("Setting up vocab, data and model took {:.2f} sec.".format(end - start))
177 |
178 | summarizer.build_graph()
179 |
180 | if FLAGS.mode == 'train':
181 | summarizer.train()
182 | elif FLAGS.mode == "test":
183 | summarizer.test()
184 | else:
185 | raise ValueError("mode should be either train/test!! \n")
186 |
187 | main_end = time.time()
188 | print("Total time elapsed: %.2f \n" % (main_end - main_start))
189 |
190 |
191 | if __name__ == '__main__':
192 | tf.app.run()
193 |
--------------------------------------------------------------------------------
/codes/NSE.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "Rajeev Bhatt Ambati"
3 |
4 | import tensorflow as tf
5 | tf.set_random_seed(2019)
6 |
7 |
8 | def create_rnn_cell(rnn_size, scope):
9 | """
10 | This function creates and returns an RNN cell.
11 | :param rnn_size: Size of the hidden state.
12 | :param scope: scope for the RNN variables.
13 | :return: returns the RNN cell with the necessary specifications.
14 | """
15 | with tf.variable_scope(scope):
16 | cell = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units=rnn_size, reuse=tf.AUTO_REUSE)
17 |
18 | return cell
19 |
20 |
21 | class NSE:
22 | """
23 | This is a Neural Semantic Encoder class.
24 | """
25 |
26 | def __init__(self, batch_size, dim, dense_init, mode='train', use_comp_lstm=False):
27 | """
28 | :param batch_size: No. of examples in a batch of data.
29 | :param dim: Dimension of the memories. (Same as the dimension of wordvecs).
30 | :param dense_init: Dense kernel initializer.
31 | :param mode: 'train/val/test' mode.
32 |         :param use_comp_lstm: Whether the compose function should pass its input through an LSTM before the MLP.
33 | """
34 | self._batch_size, self._dim = batch_size, dim
35 |
36 | self._dense_init = dense_init # Initializer.
37 |
38 | self._mode = mode
39 | self._use_comp_lstm = use_comp_lstm
40 |
41 | self._read_scope = 'read'
42 | self._write_scope = 'write'
43 | self._comp_scope = 'compose'
44 |
45 | # Read LSTM
46 | self._read_lstm = create_rnn_cell(self._dim, self._read_scope)
47 |
48 | # Compose LSTM
49 | if self._use_comp_lstm:
50 | self._comp_lstm = create_rnn_cell(2*self._dim, self._comp_scope)
51 |
52 | # Write LSTM
53 | self._write_lstm = create_rnn_cell(self._dim, self._write_scope)
54 |
55 | def read(self, x_t, state=None):
56 | """
57 | This is the read function.
58 | :param x_t: input sequence x, Shape: B x D.
59 | :param state: Previous hidden state of read LSTM.
60 | :return: r_t: Outputs of the read LSTM, Shape: B x D
61 | """
62 | with tf.variable_scope(self._read_scope, reuse=tf.AUTO_REUSE):
63 | r_t, state = self._read_lstm(x_t, state)
64 |
65 | return r_t, state
66 |
67 | def compose(self, r_t, z_t, m_t, state=None):
68 | """
69 | This is the compose function.
70 | :param r_t: Read from the input x_t, Shape: B x D.
71 | :param z_t: Attention distribution, Shape: B x T.
72 | :param m_t: Memory at time step 't', Shape: B x T x D.
73 | :param state: Previous hidden state of compose LSTM.
74 | :return: m_rt: Retrieved memory, Shape: B x D
75 | c_t: The composed vector, Shape: B x D.
76 | state: Hidden state of compose LSTM after previous time step if using it.
77 | """
78 | # z_t is repeated across dimension.
79 | z_t_rep = tf.tile(tf.expand_dims(z_t, axis=-1), multiples=[1, 1, self._dim]) # Shape: B x T x D.
80 | m_rt = tf.reduce_sum(tf.multiply(z_t_rep, m_t), axis=1) # Retrieved memory, Shape: B x D
81 |
82 | with tf.variable_scope(self._comp_scope, reuse=tf.AUTO_REUSE):
83 | r_m_t = tf.concat([r_t, m_rt], axis=-1) # Shape: B x (2*D)
84 |
85 | # Compose LSTM
86 | if self._use_comp_lstm:
87 | r_m_t, state = self._comp_lstm(r_m_t, state)
88 |
89 | # Dense layer to reduce size from (2*D) to D.
90 | c_t = tf.layers.dense(inputs=r_m_t,
91 | units=self._dim,
92 | activation=None,
93 | kernel_initializer=self._dense_init,
94 | name='MLP') # Composed vector, Shape: B x D
95 | c_t = tf.nn.relu(c_t) # Activation function
96 |
97 | return m_rt, c_t, state
98 |
99 | def write(self, c_t, state=None):
100 | """
101 | This is the write function.
102 | :param c_t: The composed vector, Shape: B x D.
103 | :param state: Previous hidden state of write LSTM.
104 | :return: h_t: The write vector, Shape: B x D
105 | """
106 | with tf.variable_scope(self._write_scope, reuse=tf.AUTO_REUSE):
107 | h_t, state = self._write_lstm(c_t, state)
108 |
109 | return h_t, state
110 |
111 | def attention(self, r_t, m_t, mem_mask):
112 | """
113 | This function computes the attention distribution.
114 | :param r_t: Read from the input x_t, Shape: B x D.
115 | :param m_t: Memory at time step 't', Shape: B x T x D.
116 | :param mem_mask: A mask to indicate the presence of PAD tokens, Shape: B x T.
117 | :return:
118 | attn_dist: The attention distribution at the current time-step, Shape: B x T.
119 | """
120 | # Shapes
121 | attn_len, attn_vec_size = m_t.get_shape().as_list()[1: 3]
122 |
123 | with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
124 | # Input features.
125 | with tf.variable_scope("input", reuse=tf.AUTO_REUSE):
126 | input_features = tf.layers.dense(inputs=r_t,
127 | units=attn_vec_size,
128 | activation=None,
129 | kernel_initializer=self._dense_init,
130 | reuse=tf.AUTO_REUSE,
131 | name='inp_dense') # Shape: B x D.
132 |
133 | input_features = tf.expand_dims(input_features, axis=1) # Shape: B x 1 x D.
134 |
135 | # Memory features.
136 | with tf.variable_scope("memory", reuse=tf.AUTO_REUSE):
137 | memory_features = tf.layers.dense(inputs=m_t,
138 | units=attn_vec_size,
139 | activation=None,
140 | kernel_initializer=self._dense_init,
141 | use_bias=False,
142 | reuse=tf.AUTO_REUSE,
143 | name="memory_dense") # Shape: B x T x D.
144 |
145 | v = tf.get_variable("v", [attn_vec_size])
146 |
147 | scores = tf.reduce_sum(
148 | v * tf.tanh(input_features + memory_features), axis=2) # Shape: B x T.
149 | attn_dist = tf.nn.softmax(scores) # Shape: B x T.
150 |
151 | # Assigning zero probability to the PAD tokens.
152 | # Re-normalizing the probability distribution to sum to one.
153 | attn_dist = tf.multiply(attn_dist, mem_mask)
154 | masked_sums = tf.reduce_sum(attn_dist, axis=1, keepdims=True) # Shape: B x 1
155 | attn_dist = tf.truediv(attn_dist, masked_sums) # Re-normalization.
156 |
157 | return attn_dist
158 |
159 | @staticmethod
160 | def update(z_t, m_t, h_t):
161 | """
162 | This function updates the memory with write vectors as per the retrieved slots.
163 | :param z_t: Retrieved attention distribution over memory, Shape: B x T.
164 | :param m_t: Memory at current time step 't', Shape: B x T x D.
165 | :param h_t: Write vector, Shape: B x D.
166 | :return: new_m: The updated memory, Shape: B x T x D.
167 | """
168 | # Write and erase mask for sentence memories.
169 | write_mask = tf.expand_dims(z_t, axis=2) # Shape: B x T x 1.
170 | erase_mask = tf.ones_like(write_mask) - write_mask # Shape: B x T x 1.
171 |
172 | # Write tensor
173 | write_tensor = tf.expand_dims(h_t, axis=1) # Shape: B x 1 x D.
174 |
175 | # Updated memory.
176 | new_m = tf.add(tf.multiply(m_t, erase_mask), tf.multiply(write_tensor, write_mask))
177 |
178 | return new_m
179 |
180 | def prob_gen(self, m_rt, h_t, r_t):
181 | """
182 | This function calculates the generation probability from the retrieved memory, write vector and the
183 | input read.
184 | :param m_rt: Retrieved memory, Shape: B x D.
185 | :param h_t: Write vector, Shape: B x D.
186 | :param r_t: Read vector from the input, Shape: B x D.
187 | :return: p_gen, Shape: B x 1
188 | """
189 | with tf.variable_scope('pgen', reuse=tf.AUTO_REUSE):
190 | inp = tf.concat([m_rt, h_t, r_t], axis=-1) # Shape: B x (3*D)
191 | p_gen = tf.layers.dense(inp,
192 | units=1,
193 | activation=None,
194 | kernel_initializer=self._dense_init,
195 | name='pgen_dense') # Shape: B x 1
196 |
197 | p_gen = tf.nn.sigmoid(p_gen) # Sigmoid.
198 |
199 | return p_gen
200 |
201 | def step(self, x_t, mem_mask, prev_state, use_pgen=False):
202 | """
203 | This function performs one-step of NSE.
204 | :param x_t: Input in the current time-step, Shape: B x D.
205 | :param mem_mask: Memory mask, Shape: B x T.
206 | :param prev_state: Internal state of NSE after the previous time step.
207 | [0] memory: The NSE memory, Shape: B x T x D.
208 | [1] read_state: Hidden state of the read LSTM, Shape: B x D.
209 | [2] write_state: Hidden state of the write LSTM, Shape: B x D.
210 | [3] comp_state: Hidden state of compose LSTM, Shape: B x (2*D).
211 | :param use_pgen: Flag whether pointer mechanism has to be used.
212 | :return:
213 | outputs: The outputs after the current time step.
214 | [0]: write vector: The written vector to NSE memory, Shape: B x D.
215 | [1]: p_attn: Attention distribution, Shape: B x T_in.
216 | Following additional output in pointer generator mode:
217 | [2]: p_gen: Generation probability, Shape: B x 1.
218 | state: Internal state of NSE after current time step.
219 | [memory, read_state, write_state, comp_state]
220 | """
221 | prev_comp_state = None
222 | prev_memory, prev_read_state, prev_write_state = prev_state[: 3]
223 | if self._use_comp_lstm:
224 | prev_comp_state = prev_state[3]
225 |
226 | if prev_read_state is None: # zero state of read LSTM for the first time step.
227 | prev_read_state = self._read_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32)
228 |
229 | if prev_write_state is None: # zero state of write LSTM for the first time step.
230 | prev_write_state = self._write_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32)
231 |
232 | if (prev_comp_state is None) and self._use_comp_lstm: # zero state of compose LSTM for the first time step.
233 | prev_comp_state = self._comp_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32)
234 |
235 | r_t, curr_read_state = self.read(x_t, prev_read_state) # Read step.
236 | z_t = self.attention(r_t, prev_memory, mem_mask) # Attention.
237 | m_rt, c_t, curr_comp_state = self.compose(r_t, z_t, prev_memory, prev_comp_state) # Compose step.
238 | h_t, curr_write_state = self.write(c_t, prev_write_state) # Write step.
239 | curr_memory = self.update(z_t, prev_memory, h_t) # Memory update step.
240 |
241 | curr_state = [curr_memory, curr_read_state, curr_write_state] # Current NSE states.
242 | if self._use_comp_lstm:
243 | curr_state.append(curr_comp_state)
244 |
245 | outputs = [h_t, z_t] # Outputs after the current step.
246 | if use_pgen: # Pointer generator mode.
247 | p_gen = self.prob_gen(m_rt, h_t, r_t)
248 | outputs.append(p_gen)
249 |
250 | return outputs, curr_state
251 |
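252 | 
253 | # Summary of one NSE step as implemented in `step()` above (M is the memory, z_t the attention weights):
254 | #   r_t  = LSTM_read(x_t)
255 | #   z_t  = softmax(v^T tanh(W_i r_t + W_m M_{t-1}))    (masked and re-normalized over non-PAD slots)
256 | #   m_rt = sum_j z_{t,j} * M_{t-1,j}
257 | #   c_t  = ReLU(Dense([r_t; m_rt]))                     (optionally passed through the compose LSTM first)
258 | #   h_t  = LSTM_write(c_t)
259 | #   M_t  = M_{t-1} * (1 - z_t) + z_t * h_t              (erase the attended slots, write h_t into them)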
--------------------------------------------------------------------------------
/codes/HierNSE.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "Rajeev Bhatt Ambati"
3 |
4 | import tensorflow as tf
5 | tf.set_random_seed(2018)
6 |
7 |
8 | def create_rnn_cell(rnn_size, num_layers, scope):
9 | """
10 | This function creates and returns an RNN cell.
11 | :param rnn_size: Size of the hidden state.
12 | :param num_layers: No. of layers if using a multi_rnn cell.
13 | :param scope: scope for the RNN variables.
14 | :return: returns the RNN cell with the necessary specifications.
15 | """
16 | with tf.variable_scope(scope):
17 | layers = []
18 | for _ in range(num_layers):
19 | cell = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units=rnn_size, reuse=tf.AUTO_REUSE)
20 | layers.append(cell)
21 |
22 | if num_layers > 1:
23 | return tf.nn.rnn_cell.MultiRNNCell(layers)
24 | else:
25 | return layers[0]
26 |
27 |
28 | class HierNSE:
29 | """
30 | This is a Hierarchical Neural Semantic Encoder class.
31 | """
32 |
33 | def __init__(self, batch_size, dim, dense_init, mode='train', use_comp_lstm=False, num_layers=1):
34 | """
35 | :param batch_size: No. of examples in a batch of data.
36 | :param dim: Dimension of the memories. (Same as the dimension of wordvecs).
37 | :param dense_init: Dense kernel initializer.
38 | :param mode: 'train/val/test' mode.
39 |         :param use_comp_lstm: Whether the compose function should pass its input through an LSTM before the MLP.
40 | :param num_layers: Number of layers in the RNN cell.
41 | """
42 | self._batch_size, self._dim = batch_size, dim
43 |
44 | self._dense_init = dense_init # Initializer.
45 |
46 | self._mode = mode
47 | self._use_comp_lstm = use_comp_lstm
48 |
49 | self._read_scope = 'read'
50 | self._write_scope = 'write'
51 | self._comp_scope = 'compose'
52 |
53 | # Read LSTM
54 | self._read_lstm = create_rnn_cell(self._dim, num_layers, self._read_scope)
55 |
56 | # Compose LSTM
57 | if self._use_comp_lstm:
58 | self._comp_lstm = create_rnn_cell(3*self._dim, num_layers, self._comp_scope)
59 |
60 | # Write LSTM
61 | self._write_lstm = create_rnn_cell(self._dim, num_layers, self._write_scope)
62 |
63 | def read(self, x_t, state=None):
64 | """
65 | This is the read function.
66 | :param x_t: input sequence x, Shape: B x D.
67 | :param state: Previous hidden state of read LSTM.
68 | :return: r_t: Outputs of the read LSTM, Shape: B x D
69 | """
70 | with tf.variable_scope(self._read_scope, reuse=tf.AUTO_REUSE):
71 | r_t, state = self._read_lstm(x_t, state)
72 |
73 | return r_t, state
74 |
75 | def compose(self, r_t, zs_t, ms_t, zd_t, md_t, state=None):
76 | """
77 | This is the compose function.
78 | :param r_t: Read from the input x_t, Shape: B x D.
79 | :param zs_t: Attention distribution over sentence memory, Shape: B x T.
80 | :param ms_t: Sentence memory at time step 't', Shape: B x T x D.
81 | :param zd_t: Attention distribution over document memory, Shape: B x S.
82 | :param md_t: Document memory at time step 't', Shape: B x S x D.
83 | :param state: Previous hidden state of compose LSTM.
84 | :return: ms_rt: Retrieved sentence memory, Shape: B x D.
85 | md_rt: Retrieved document memory, Shape: B x D.
86 | c_t: The composed vector, Shape: B x D.
87 | state: Hidden state of compose LSTM after previous time step if using it.
88 | """
89 | # Retrieved memories.
90 | ms_rt = tf.squeeze(tf.matmul(tf.expand_dims(zs_t, axis=1), ms_t), axis=1) # Sentence memory, Shape: B x D.
91 | md_rt = tf.squeeze(tf.matmul(tf.expand_dims(zd_t, axis=1), md_t), axis=1) # Document memory, Shape: B x D.
92 |
93 | with tf.variable_scope(self._comp_scope, reuse=tf.AUTO_REUSE):
94 | r_m_t = tf.concat([r_t, ms_rt, md_rt], axis=-1) # B x (3*D).
95 |
96 | # Compose LSTM
97 | if self._use_comp_lstm:
98 | r_m_t, state = self._comp_lstm(r_m_t, state)
99 |
100 | # Dense layer to reduce size from 3*D to D.
101 | c_t = tf.layers.dense(inputs=r_m_t,
102 | units=self._dim,
103 | activation=None,
104 | kernel_initializer=self._dense_init,
105 | name='MLP') # Composed vector, Shape: B x D.
106 | c_t = tf.nn.relu(c_t) # Activation Function.
107 |
108 | return ms_rt, md_rt, c_t, state
109 |
110 | def write(self, c_t, state=None):
111 | """
112 | This function implements the write operation - equation 5 from the paper [1].
113 | :param c_t: The composed vector, Shape: B x D.
114 | :param state: Previous hidden state of write LSTM.
115 | :return: h_t: The write vector, Shape: B x D
116 | """
117 | with tf.variable_scope(self._write_scope, reuse=tf.AUTO_REUSE):
118 | h_t, state = self._write_lstm(c_t, state)
119 |
120 | return h_t, state
121 |
122 | def attention(self, r_t, m_t, mem_mask, scope="attention"):
123 | """
124 | This function computes the attention distribution.
125 | :param r_t: Read from the input x_t, Shape: B x D.
126 | :param m_t: Memory at time step 't', Shape: B x T' x D.
127 | T' is T for sentence memory and S for document memory.
128 | :param mem_mask: A mask to indicate the presence of PAD tokens, Shape: B x T'.
129 | :param scope: Name of the scope.
130 | :return:
131 | """
132 | # Shapes
133 | batch_size, attn_len, attn_vec_size = m_t.get_shape().as_list()
134 |
135 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
136 | # Input features.
137 | with tf.variable_scope("input"):
138 | input_features = tf.layers.dense(inputs=r_t,
139 | units=attn_vec_size,
140 | activation=None,
141 | kernel_initializer=self._dense_init,
142 | reuse=tf.AUTO_REUSE,
143 | name="inp_dense") # Shape: B x D.
144 | input_features = tf.expand_dims(input_features, axis=1) # Shape: B x 1 x D.
145 |
146 | # Memory features
147 | with tf.variable_scope("memory"):
148 | memory_features = tf.layers.dense(inputs=m_t,
149 | units=attn_vec_size,
150 | activation=None,
151 | kernel_initializer=self._dense_init,
152 | use_bias=False,
153 | reuse=tf.AUTO_REUSE,
154 | name="memory_dense") # Shape: B x T x D.
155 |
156 | v = tf.get_variable("v", [attn_vec_size])
157 |
158 | scores = tf.reduce_sum(
159 | v * tf.tanh(input_features + memory_features), axis=2) # Shape: B x T'.
160 | attn_dist = tf.nn.softmax(scores) # Shape: B x T'
161 |
162 | # Assigning zero probability to the PAD tokens.
163 | # Re-normalizing the probability distribution to sum to one.
164 | attn_dist = tf.multiply(attn_dist, mem_mask)
165 | masked_sums = tf.reduce_sum(attn_dist, axis=1, keepdims=True) # Shape: B x 1
166 | attn_dist = tf.truediv(attn_dist, masked_sums) # Re-normalization.
167 |
168 | return attn_dist
169 |
170 |     @staticmethod
171 | def update(zs_t, ms_t, zd_t, md_t, h_t):
172 | """
173 | This function updates the sentence and document memories as per the retrieved slots.
174 | :param zs_t: Retrieved attention distribution over sentence memory, Shape: B x T.
175 | :param ms_t: Sentence memory at time step 't', Shape: B x T x D.
176 | :param zd_t: Retrieved attention distribution over document memory, Shape: B x S.
177 | :param md_t: Document memory at time step 't', Shape: B x S x D.
178 | :param h_t: Write vector, Shape: B x D.
179 | :return: new_ms: Updated sentence memory, Shape: B x T x D.
180 | new_md: Updated document memory, Shape: B x S x D.
181 | """
182 |
183 | # Write and erase masks for sentence memories.
184 | write_mask_s = tf.expand_dims(zs_t, axis=2) # Shape: B x T x 1.
185 | erase_mask_s = tf.ones_like(write_mask_s) - write_mask_s # Shape: B x T x 1.
186 |
187 | # Write and erase masks for document memories.
188 | write_mask_d = tf.expand_dims(zd_t, axis=2) # Shape: B x S x 1.
189 | erase_mask_d = tf.ones_like(write_mask_d) - write_mask_d # Shape: B x S x 1.
190 |
191 | # Write tensors for sentence and document memories.
192 | write_tensor_s = tf.expand_dims(h_t, axis=1) # Shape: B x 1 x D.
193 | write_tensor_d = tf.expand_dims(h_t, axis=1) # Shape: B x 1 x D.
194 |
195 | new_ms = tf.add(tf.multiply(ms_t, erase_mask_s), tf.multiply(write_tensor_s, write_mask_s))
196 | new_md = tf.add(tf.multiply(md_t, erase_mask_d), tf.multiply(write_tensor_d, write_mask_d))
197 |
198 | return new_ms, new_md
199 |
200 | def prob_gen(self, ms_rt, md_rt, h_t, r_t):
201 | """
202 | This function calculates the generation probability from the retrieved sentence, document memories, write
203 | vector and the input read.
204 | :param ms_rt: Retrieved sentence memory, Shape: B x D.
205 | :param md_rt: Retrieved document memory, Shape: B x D.
206 | :param h_t: Write vector, Shape: B x D
207 | :param r_t: Read vector from the input, Shape: B x D.
208 | :return: p_gen, Shape: B x 1.
209 | """
210 | with tf.variable_scope('pgen', reuse=tf.AUTO_REUSE):
211 | inp = tf.concat([ms_rt, md_rt, h_t, r_t], axis=1) # Shape: B x (4*D).
212 | p_gen = tf.layers.dense(inp,
213 | units=1,
214 | activation=None,
215 | kernel_initializer=self._dense_init,
216 | name='pgen_dense') # Dense Layer, Shape: B x 1.
217 | p_gen = tf.nn.sigmoid(p_gen) # Sigmoid.
218 |
219 | return p_gen
220 |
221 | def step(self, x_t, mem_masks, prev_state, use_pgen=False):
222 | """
223 | This function performs one-step of Hier-NSE.
224 | :param x_t: Input in the current time-step.
225 | :param mem_masks: [mem_mask_s, mem_mask_d]: Masks for sentence and document memory respectively
226 | indicating the presence of PAD tokens, Shape: [B x T, B x S].
227 | :param prev_state: Internal state of the NSE after the previous time-step.
228 | [0]: [memory_s, memory_d]: The NSE sentence and document memory
229 | respectively, Shape: [B x T x D, B x S x D].
230 | [1]: read_state: Hidden state of the Read LSTM, Shape: B x D.
231 | [2]: write_state: Hidden state of the write LSTM, Shape: B x D.
232 | [3]: comp_state: Hidden state of the compose LSTM, Shape: B x (3*D).
233 | :param use_pgen: Flag whether pointer mechanism has to be used.
234 | :return:
235 | outputs: The following outputs after the current time step.
236 | [0]: write vector: The written vector to NSE memory, Shape: B x D.
237 | [1]: p_attn: [zs_t, zd_t] Attention distribution for sentence and
238 | document memories respectively, Shape: [B x T, B x S].
239 | Following additional output in pointer generator mode.
240 | [2]: p_gen: Generation probability, Shape: B x 1.
241 | Following additional output when using coverage.
242 | [3]: [curr_cov_s, curr_cov_d]: Updated coverages for sentence and document memory respectively.
243 | state: The internal state of NSE after the current time-step.
244 | [[memory_s, memory_d], read_state, write_state, comp_state].
245 | """
246 | # Memory masks
247 | mem_mask_s, mem_mask_d = mem_masks
248 |
249 | # NSE Internal states.
250 | prev_comp_state = None
251 | [prev_memory_s, prev_memory_d], prev_read_state, prev_write_state = prev_state[: 3]
252 | if self._use_comp_lstm:
253 | prev_comp_state = prev_state[3]
254 |
255 | if prev_read_state is None: # zero state of read LSTM for the first time step.
256 | prev_read_state = self._read_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32)
257 |
258 | if prev_write_state is None: # zero state of write LSTM for the first time step.
259 | prev_write_state = self._write_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32)
260 |
261 | if (prev_comp_state is None) and self._use_comp_lstm: # zero state of compose LSTM for the first time step.
262 | prev_comp_state = self._comp_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32)
263 |
264 | r_t, curr_read_state = self.read(x_t, prev_read_state) # Read step.
265 | zs_t = self.attention(
266 | r_t=r_t, m_t=prev_memory_s, mem_mask=mem_mask_s
267 | ) # Sentence attention distribution.
268 | zd_t = self.attention(
269 | r_t=r_t, m_t=prev_memory_d, mem_mask=mem_mask_d
270 | ) # Document attention distribution.
271 |
272 | ms_rt, md_rt, c_t, curr_comp_state = self.compose(
273 | r_t, zs_t, prev_memory_s, zd_t, prev_memory_d, prev_comp_state
274 | ) # Compose step.
275 | h_t, curr_write_state = self.write(c_t, prev_write_state) # Write step.
276 | curr_memory_s, curr_memory_d = self.update(
277 | zs_t, prev_memory_s, zd_t, prev_memory_d, h_t
278 | ) # Update step.
279 |
280 | curr_state = [[curr_memory_s, curr_memory_d], curr_read_state, curr_write_state] # Current NSE states.
281 | if self._use_comp_lstm:
282 | curr_state.append(curr_comp_state)
283 |
284 | outputs = [h_t, [zs_t, zd_t]]
285 | if use_pgen:
286 | p_gen = self.prob_gen(ms_rt, md_rt, h_t, r_t) # Generation probability.
287 | outputs.append(p_gen)
288 |
289 | return outputs, curr_state
290 |
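291 | 
292 | # One Hier-NSE step (see `step()` above) differs from the flat NSE only in keeping two memories,
293 | # a sentence memory Ms and a document memory Md, both attended with the same read vector r_t:
294 | #   zs_t = attention(r_t, Ms_{t-1}),   zd_t = attention(r_t, Md_{t-1})
295 | #   c_t  = ReLU(Dense([r_t; ms_rt; md_rt]))             (optionally passed through the compose LSTM first)
296 | #   h_t  = LSTM_write(c_t)
297 | #   Ms_t = Ms_{t-1} * (1 - zs_t) + zs_t * h_t,   Md_t = Md_{t-1} * (1 - zd_t) + zd_t * h_t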
--------------------------------------------------------------------------------
/codes/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "Rajeev Bhatt Ambati"
3 | """
4 | Acknowledgement: Most of the functions related to the vocab object are either inspired by or taken from the
5 |     Pointer-Generator Network repository published by See et al., 2017.
6 | GitHub link: https://github.com/abisee/pointer-generator/blob/master/data.py
7 | """
8 |
9 | import params
10 | import os
11 | import hashlib
12 | from shutil import copyfile
13 | import math
14 | import collections
15 | import random
16 | from random import shuffle, randint, sample
17 | from gensim.models import KeyedVectors
18 |
19 | import pickle
20 | import numpy as np
21 | import tensorflow as tf
22 | # import pyrouge
23 |
24 | # Setting-up seeds
25 | random.seed(2019)
26 | np.random.seed(2019)
27 | tf.set_random_seed(2019)
28 |
29 |
30 | def read_text_file(text_file):
31 |
32 | lines = []
33 | with open(text_file, "r") as f:
34 | for line in f:
35 | lines.append(line.strip())
36 |
37 | return lines
38 |
39 |
40 | def hashhex(s):
41 | """
42 | Returns a heximal formatted SHA1 hash of the input string.
43 | """
44 | h = hashlib.sha1()
45 | h.update(s.encode('utf-8'))
46 |
47 | return h.hexdigest()
48 |
49 |
50 | def get_url_hashes(url_list):
51 | return [hashhex(url) for url in url_list]
52 |
53 |
54 | def create_set(path_to_data, split):
55 | if os.path.exists(path_to_data + split):
56 | return
57 | else:
58 | os.makedirs(path_to_data + split)
59 |
60 | url_list = read_text_file(path_to_data + "all_" + split + ".txt")
61 | url_hashes = get_url_hashes(url_list)
62 | story_fnames = [s + ".story" for s in url_hashes]
63 |
64 | print("{}:".format(split))
65 | for i, s in enumerate(story_fnames):
66 | if os.path.isfile(path_to_data + "cnn_stories_tokenized/" + s):
67 | src = path_to_data + "cnn_stories_tokenized/" + s
68 |
69 | elif os.path.isfile(path_to_data + "dm_stories_tokenized/" + s):
70 | src = path_to_data + "dm_stories_tokenized/" + s
71 |
72 | else:
73 | raise Exception("Story file {} is not found!".format(s))
74 |
75 | trg = path_to_data + split + "/" + s
76 |
77 | copyfile(src, trg)
78 |         print("files done: {}/{}".format(i + 1, len(story_fnames)))
79 |
80 |
81 | def create_train_val_test(path_to_data="../data/"):
82 | """
83 | :param path_to_data: Path to data.
84 | This function creates train, test and validation sets.
85 | :return:
86 | """
87 | # Train.
88 | create_set(path_to_data, "train")
89 |
90 | # Validation.
91 | create_set(path_to_data, "val")
92 |
93 | # Test.
94 | create_set(path_to_data, "test")
95 |
96 |
97 | def fix_missing_period(line):
98 | if "@highlight" in line:
99 | return line
100 |
101 | if line == "":
102 | return line
103 |
104 | if line[-1] in params.END_TOKENS:
105 | return line[: -1] + " " + line[-1]
106 |
107 | return line + " ."
108 |
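# Editor's sketch (not part of the original file): expected behaviour of
# fix_missing_period, given the END_TOKENS list defined in params.py.
# >>> fix_missing_period("hello world")
# 'hello world .'
# >>> fix_missing_period("hello world.")    # an existing end token is split off with a space
# 'hello world .'
# >>> fix_missing_period("@highlight")      # highlight markers are returned unchanged
# '@highlight'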
109 |
110 | def get_article_summary(story_file, art_format="tokens", sum_format="tokens"):
111 | # Read the story file.
112 | f = open(story_file, 'r', encoding='ISO-8859-1')
113 |
114 | # Lowercase everything.
115 | lines = [line.strip().lower() for line in f.readlines()]
116 |
117 |     # Some lines don't have a period at the end; fix them.
118 | lines = [fix_missing_period(line) for line in lines]
119 |
120 | # Separate article and abstract sentences.
121 | article_lines = []
122 | highlights = []
123 | next_is_highlight = False
124 |
125 | for idx, line in enumerate(lines):
126 |
127 | if line == "": # Empty line.
128 | continue
129 |
130 | elif line.startswith("@highlight"):
131 | next_is_highlight = True
132 |
133 | elif next_is_highlight:
134 | highlights.append(line)
135 |
136 | else:
137 | article_lines.append(line)
138 |
139 | # Joining all article sentences together into a string.
140 | article = ' '.join(article_lines)
141 |
142 |     # Joining all highlights together into a single string, adding sentence
143 |     # start and end tokens.
144 | summary = ' '.join(["%s %s %s" % (params.SENTENCE_START, sent,
145 | params.SENTENCE_END) for sent in highlights])
146 |
147 | if art_format == "tokens":
148 |
149 | if sum_format == "tokens":
150 | return article, summary
151 | elif sum_format == "sentences":
152 | return article, highlights
153 | else:
154 |             raise ValueError("Summary format should be either tokens or sentences!! \n")
155 |
156 | elif art_format == "sentences":
157 | if sum_format == "tokens":
158 | return article_lines, summary
159 | elif sum_format == "sentences":
160 | return article_lines, highlights
161 | else:
162 |             raise ValueError("Summary format should be either tokens or sentences!! \n")
163 | else:
164 | raise ValueError("Article format should be either tokens or sentences!! \n")
165 |
166 |
167 | def article2ids(article_words, vocab):
168 | """
169 | This function converts given article words to ID's
170 | :param article_words: article tokens.
171 | :param vocab: The vocabulary object used for lookup tables, vocabulary etc.
172 | :return: The corresponding ID's and a list of OOV tokens.
173 | """
174 | ids = []
175 | oovs = []
176 | unk_id = vocab.word2id(params.UNKNOWN_TOKEN)
177 | for word in article_words:
178 | i = vocab.word2id(word)
179 | if i == unk_id: # Out of vocabulary words.
180 | if word in oovs:
181 | ids.append(vocab.size() + oovs.index(word))
182 | else:
183 | oovs.append(word)
184 | ids.append(vocab.size() + oovs.index(word))
185 | else: # In vocabulary words.
186 | ids.append(i)
187 |
188 | return ids, oovs
189 |
190 |
191 | def summary2ids(summary_words, vocab, article_oovs):
192 | """
193 | This function converts the given summary words to ID's
194 | :param summary_words: summary tokens.
195 | :param vocab: The vocabulary object used for lookup tables, vocabulary etc.
196 | :param article_oovs: OOV tokens in the input article.
197 | :return: The corresponding ID's.
198 | """
199 | num_sent, sent_len = 0, 0
200 | if type(summary_words[0]) is list:
201 | num_sent = len(summary_words)
202 | sent_len = len(summary_words[0])
203 |
204 | cum_words = []
205 | for _, sent_words in enumerate(summary_words):
206 | cum_words += sent_words
207 |
208 | summary_words = cum_words
209 |
210 | ids = []
211 | unk_id = vocab.word2id(params.UNKNOWN_TOKEN)
212 | for word in summary_words:
213 | i = vocab.word2id(word)
214 | if i == unk_id: # Out of vocabulary words.
215 | if word in article_oovs: # In article OOV words.
216 | ids.append(vocab.size() + article_oovs.index(word))
217 |             else:  # OOV words that are not among the article OOVs either.
218 | ids.append(unk_id)
219 | else: # In vocabulary words.
220 | ids.append(i)
221 |
222 | if num_sent != 0:
223 | doc_ids = []
224 | for i in range(num_sent):
225 | doc_ids.append(ids[sent_len * i: sent_len * (i + 1)])
226 |
227 | ids = doc_ids
228 |
229 | return ids
230 |
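# Editor's sketch (not part of the original file): how OOV tokens are mapped into the
# extended vocabulary. Assumes "quixotic" is not in the fixed vocabulary.
# >>> ids, oovs = article2ids(["the", "quixotic", "plan"], vocab)
# >>> oovs                                   # article OOVs, in order of first appearance
# ['quixotic']
# >>> ids[1] == vocab.size()                 # the first OOV gets the first extended-vocab ID
# True
# >>> summary2ids(["quixotic", "idea"], vocab, oovs)[0] == vocab.size()
# True
# A summary OOV that does not appear in the article falls back to the [UNK] ID instead.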
231 |
232 | class Vocab(object):
233 | def __init__(self, max_vocab_size, emb_dim=300, dataset_path='../data/', glove_path='../data/glove.840B.300d.txt',
234 | vocab_path='../data/vocab.txt', lookup_path='../data/lookup.pkl'):
235 |
236 | self.max_size = max_vocab_size
237 | self._dim = emb_dim
238 | self.PathToGloveFile = glove_path
239 | self.PathToVocabFile = vocab_path
240 | self.PathToLookups = lookup_path
241 |
242 | create_train_val_test(dataset_path)
243 |
244 |         stories = os.listdir(dataset_path + 'train')  # Using only train files for the vocabulary.
245 |
246 | # All train Stories.
247 | self._story_files = [
248 | os.path.join(dataset_path + 'train/', s) for s in stories]
249 |
250 | self.vocab = [] # Vocabulary
251 |
252 | # Create the vocab file.
253 | self.create_total_vocab()
254 |
255 | # Create the lookup tables.
256 | self.wvecs = [] # Word vectors.
257 | self._word_to_id = {} # word to ID's lookups
258 | self._id_to_word = {} # ID to word lookups
259 |
260 | self.create_lookup_tables()
261 |
262 | assert len(self._word_to_id.keys()) == len(self._id_to_word.keys()), "Both lookups should have same size."
263 |
264 | def size(self):
265 | return len(self.vocab)
266 |
267 | def word2id(self, word):
268 | """
269 | This function returns the vocabulary ID for word if it is present. Otherwise, returns the ID
270 | for the unknown token.
271 | :param word: input word.
272 | :return: returns the ID.
273 | """
274 | if word in self._word_to_id:
275 | return self._word_to_id[word]
276 | else:
277 | return self._word_to_id[params.UNKNOWN_TOKEN]
278 |
279 | def id2word(self, word_id):
280 | """
281 | This function returns the corresponding word for a given vocabulary ID.
282 | :param word_id: input ID.
283 | :return: returns the word.
284 | """
285 | if word_id in self._id_to_word:
286 | return self._id_to_word[word_id]
287 | else:
288 | raise ValueError("{} is not a valid ID.\n".format(word_id))
289 |
290 | def create_total_vocab(self):
291 |
292 | if os.path.isfile(self.PathToVocabFile):
293 | print("Vocab file exists! \n")
294 |
295 | vocab_f = open(self.PathToVocabFile, 'r')
296 | for line in vocab_f:
297 | word = line.split()[0]
298 | self.vocab.append(word)
299 |
300 | return
301 | else:
302 | print("Vocab file NOT found!! \n")
303 | print("Creating a vocab file! \n")
304 |
305 | vocab_counter = collections.Counter()
306 |
307 | for idx, story in enumerate(self._story_files):
308 | article, summary = get_article_summary(story)
309 |
310 | art_tokens = article.split(' ') # Article tokens.
311 | sum_tokens = summary.split(' ') # Summary tokens.
312 |             sum_tokens = [t for t in sum_tokens if t not in  # Removing sentence start/end tokens.
313 | [params.SENTENCE_START, params.SENTENCE_END]]
314 |
315 | assert (params.SENTENCE_START not in sum_tokens) and (params.SENTENCE_END not in sum_tokens), \
316 |                 "SENTENCE_START and SENTENCE_END shouldn't be present in sum_tokens"
317 |
318 | tokens = art_tokens + sum_tokens
319 | tokens = [t.strip() for t in tokens]
320 | tokens = [t for t in tokens if t != ''] # Removing empty tokens.
321 |
322 | vocab_counter.update(tokens) # Keeping a count of the tokens.
323 |
324 | print("\r{}/{} files read!".format(idx + 1, len(self._story_files)))
325 |
326 | print("\n Writing the vocab file! \n")
327 | f = open(self.PathToVocabFile, 'w', encoding='utf-8')
328 |
329 | for word, count in vocab_counter.most_common(params.VOCAB_SIZE):
330 | f.write(word + ' ' + str(count) + '\n')
331 | self.vocab.append(word)
332 |
333 | f.close()
334 |
335 | def create_small_vocab(self):
336 | """
337 | This function selects a few words out of the total vocabulary.
338 | """
339 |
340 | # Read the vocab file and assign id's to each word till the max_size.
341 | vocab_f = open(self.PathToVocabFile, 'r')
342 |
343 | for line in vocab_f:
344 | word = line.split()[0]
345 |
346 | if word in [params.SENTENCE_START, params.SENTENCE_END, params.UNKNOWN_TOKEN,
347 | params.PAD_TOKEN, params.START_DECODING, params.STOP_DECODING]:
348 |                 raise Exception('SENTENCE_START, SENTENCE_END, [UNK], [PAD], [START], \
349 |                     [STOP] shouldn\'t be in the vocab file, but %s is' % word)
350 |
351 | self.vocab.append(word)
352 | print("\r{}/{} words created!".format(len(self.vocab), self.max_size))
353 |
354 | if len(self.vocab) == self.max_size:
355 | print("\n Max size of the vocabulary reached! Stopping reading! \n")
356 | break
357 |
358 | def create_lookup_tables_(self):
359 | """
360 | This function creates lookup tables for word vectors, word to IDs
361 | and ID to words. First max_size words from GloVe that are also found in the small vocab are used
362 | to create the lookup tables.
363 | """
364 |
365 | if os.path.isfile(self.PathToLookups):
366 | print('\n Lookup tables found :) \n')
367 | f = open(self.PathToLookups, 'rb')
368 | data = pickle.load(f)
369 |
370 | self._word_to_id = data['word2id']
371 | self._id_to_word = data['id2word']
372 | self.wvecs = data['wvecs']
373 | self.vocab = list(self._word_to_id.keys())
374 |
375 | print('Lookup tables collected for {} tokens.\n'.format(len(self.vocab)))
376 | return
377 | else:
378 | print('\n Lookup files NOT found!! \n')
379 | print('\n Creating the lookup tables! \n')
380 |
381 | self.create_small_vocab() # Creating a small vocabulary.
382 | self.wvecs = [] # Word vectors.
383 |
384 | glove_f = open(self.PathToGloveFile, 'r', encoding='utf8')
385 | count = 0
386 |
387 | # [UNK], [PAD], [START] and [STOP] get ids 0, 1, 2, 3.
388 | for w in [params.UNKNOWN_TOKEN, params.PAD_TOKEN, params.START_DECODING, params.STOP_DECODING]:
389 | self._word_to_id[w] = count
390 | self._id_to_word[count] = w
391 | self.wvecs.append(np.random.uniform(-0.1, 0.1, (self._dim,)).astype(np.float32))
392 | count += 1
393 |
394 | print("\r Created tables for {}".format(w))
395 |
396 | for line in glove_f:
397 | vals = line.rstrip().split(' ')
398 | w = vals[0]
399 | vec = np.array(vals[1:]).astype(np.float32)
400 |
401 | if w in self.vocab:
402 | self._word_to_id[w] = count
403 | self._id_to_word[count] = w
404 | self.wvecs.append(vec)
405 | count += 1
406 |
407 | print("\r Created tables for {}".format(w))
408 |
409 | if count == self.max_size:
410 | print("\r Maximum vocab size reached! \n")
411 | break
412 |
413 | print("\n Lookup tables created for {} tokens. \n".format(count))
414 |
415 | self.wvecs = np.array(self.wvecs).astype(np.float32) # Converting to a Numpy array.
416 | self.vocab = list(self._word_to_id.keys()) # Adjusting the vocabulary to found pre-trained vectors.
417 |
418 | # Saving the lookup tables.
419 | f = open(self.PathToLookups, 'wb')
420 | data = {'word2id': self._word_to_id,
421 | 'id2word': self._id_to_word,
422 | 'wvecs': self.wvecs}
423 | pickle.dump(data, f)
424 |
425 | def create_lookup_tables(self):
426 | """
427 | This function creates lookup tables for word vectors, word to IDs
428 | and ID to words. First max_size words from GloVe that are also found in the small vocab are used
429 | to create the lookup tables.
430 | """
431 |
432 | if os.path.isfile(self.PathToLookups):
433 | print('\n Lookup tables found :) \n')
434 | f = open(self.PathToLookups, 'rb')
435 | data = pickle.load(f)
436 |
437 | self._word_to_id = data['word2id']
438 | self._id_to_word = data['id2word']
439 | self.wvecs = data['wvecs']
440 | self.vocab = list(self._word_to_id.keys())
441 |
442 | print('Lookup tables collected for {} tokens.\n'.format(len(self.vocab)))
443 | return
444 | else:
445 | print('\n Lookup files NOT found!! \n')
446 | print('\n Creating the lookup tables! \n')
447 |
448 | self.wvecs = [] # Word vectors.
449 |
450 | word2vec = KeyedVectors.load_word2vec_format(self.PathToGloveFile, binary=False)
451 |
452 | count = 0
453 | # [UNK], [PAD], [START] and [STOP] get ids 0, 1, 2, 3.
454 | for w in [params.UNKNOWN_TOKEN, params.PAD_TOKEN, params.START_DECODING, params.STOP_DECODING]:
455 | self._word_to_id[w] = count
456 | self._id_to_word[count] = w
457 | self.wvecs.append(np.random.uniform(-0.1, 0.1, (self._dim,)).astype(np.float32))
458 | count += 1
459 |
460 | print("\r Created tables for {}".format(w))
461 |
462 | vocab_f = open(self.PathToVocabFile, "r")
463 | for line in vocab_f:
464 | word = line.split()[0]
465 |
466 | if word in word2vec:
467 | self._word_to_id[word] = count
468 | self._id_to_word[count] = word
469 | self.wvecs.append(word2vec[word])
470 | count += 1
471 |
472 | print("\r Created tables for {}".format(word))
473 |
474 | if count == self.max_size:
475 | print("\r Maximum vocab size reached! \n")
476 | break
477 |
478 | print("\n Lookup tables created for {} tokens. \n".format(count))
479 |
480 | self.wvecs = np.array(self.wvecs).astype(np.float32) # Converting to a Numpy array.
481 | self.vocab = list(self._word_to_id.keys()) # Adjusting the vocabulary to found pre-trained vectors.
482 |
483 | # Saving the lookup tables.
484 | f = open(self.PathToLookups, 'wb')
485 | data = {'word2id': self._word_to_id,
486 | 'id2word': self._id_to_word,
487 | 'wvecs': self.wvecs}
488 | pickle.dump(data, f)
489 |
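# Editor's sketch (not part of the original file): typical construction and lookups.
# The paths are placeholders; building the vocabulary requires the tokenized
# CNN/DailyMail stories and the GloVe vectors on disk.
# >>> vocab = Vocab(max_vocab_size=params.VOCAB_SIZE, emb_dim=300,
# ...               dataset_path='../data/', glove_path='../data/glove.840B.300d.txt')
# >>> vocab.word2id(params.PAD_TOKEN)        # special tokens get the first IDs, [PAD] -> 1
# 1
# >>> vocab.id2word(vocab.word2id('the'))
# 'the'
# >>> vocab.wvecs.shape                      # (vocabulary size, embedding dimension)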
490 |
491 | class DataGenerator(object):
492 |
493 | def __init__(self, path_to_dataset, max_inp_seq_len, max_out_seq_len, vocab, use_pgen=False, use_sample=False):
494 | # Train files.
495 | train_stories = os.listdir(path_to_dataset + 'train')
496 | self.train_files = [os.path.join(path_to_dataset + 'train', s) for s in train_stories]
497 | self.num_train_examples = len(self.train_files)
498 | shuffle(self.train_files)
499 |
500 | # Validation files.
501 | val_stories = os.listdir(path_to_dataset + 'val')
502 | self.val_files = [os.path.join(path_to_dataset + 'val', s) for s in val_stories]
503 | self.num_val_examples = len(self.val_files)
504 | shuffle(self.val_files)
505 |
506 | # Test files.
507 | test_stories = os.listdir(path_to_dataset + 'test')
508 | self.test_files = [os.path.join(path_to_dataset + 'test', s) for s in test_stories]
509 | self.num_test_examples = len(self.test_files)
510 | # shuffle(self.test_files)
511 |
512 | self._max_enc_steps = max_inp_seq_len # Max. no. of tokens in the input sequence.
513 | self._max_dec_steps = max_out_seq_len # Max. no. of tokens in the output sequence.
514 |
515 | self.vocab = vocab # Vocabulary instance.
516 | self._use_pgen = use_pgen # Whether pointer mechanism should be used.
517 |
518 | self._ptr = 0 # Pointer for batching the data.
519 |
520 | if use_sample:
521 | # **************************** PATCH ************************* #
522 | self.train_files = self.train_files[:20]
523 | self.num_train_examples = len(self.train_files)
524 | self.val_files = self.val_files[:20]
525 | self.num_val_examples = len(self.val_files)
526 | self.test_files = self.test_files[:23]
527 | self.num_test_examples = len(self.test_files)
528 | # **************************** PATCH ************************* #
529 |
530 | print("Split the data as follows:\n")
531 | print("\t\t Training: {} examples. \n".format(self.num_train_examples))
532 | print("\t\t Validation: {} examples. \n".format(self.num_val_examples))
533 | print("\t\t Test: {} examples. \n".format(self.num_test_examples))
534 |
535 | def get_train_val_batch(self, batch_size, split='train', permutate=False):
536 | if split == 'train':
537 | num_examples = self.num_train_examples
538 | files = self.train_files
539 | elif split == 'val':
540 | num_examples = self.num_val_examples
541 | files = self.val_files
542 | else:
543 | raise ValueError("split is neither train nor val. check the function call!")
544 |
545 | enc_inp = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32)
546 | dec_inp = np.ndarray(shape=(batch_size, self._max_dec_steps), dtype=np.int32)
547 | dec_out = np.ndarray(shape=(batch_size, self._max_dec_steps), dtype=np.int32)
548 |
549 | enc_inp_ext_vocab = None
550 | max_oov_size = -np.infty
551 | if self._use_pgen:
552 | enc_inp_ext_vocab = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32)
553 |
554 | # Shuffle files at the start of an epoch.
555 | if self._ptr == 0:
556 | shuffle(files)
557 |
558 | # Start and end index for a batch of data.
559 | start = self._ptr
560 | end = self._ptr + batch_size
561 | self._ptr = end
562 |
563 | for i in range(start, end):
564 | j = i - start # Index of the example in current batch.
565 | article, summary = get_article_summary(files[i])
566 | enc_inp_tokens = article.split(' ') # Article tokens.
567 |
568 | # Article Tokens
569 | if len(enc_inp_tokens) >= self._max_enc_steps: # Truncate.
570 | if permutate:
571 | indcs = sorted(sample(range(len(enc_inp_tokens)), self._max_enc_steps))
572 | enc_inp_tokens = [enc_inp_tokens[i] for i in indcs]
573 | else:
574 | enc_inp_tokens = enc_inp_tokens[: self._max_enc_steps]
575 | else: # Pad.
576 | enc_inp_tokens += (self._max_enc_steps - len(enc_inp_tokens)) * [params.PAD_TOKEN]
577 |
578 | # Encoder Input
579 | enc_inp_ids = [self.vocab.word2id(w) for w in enc_inp_tokens] # Word to ID's
580 |
581 | # Summary Tokens
582 | sum_tokens = summary.split(' ') # Summary tokens.
583 |             sum_tokens = [t for t in sum_tokens if t not in  # Removing sentence start/end tokens.
584 | [params.SENTENCE_START, params.SENTENCE_END]]
585 |
586 | if len(sum_tokens) > self._max_dec_steps - 1: # Truncate.
587 | sum_tokens = sum_tokens[: self._max_dec_steps - 1]
588 |
589 | # Decoder Input
590 | dec_inp_tokens = [params.START_DECODING] + sum_tokens
591 | if len(dec_inp_tokens) < self._max_dec_steps:
592 | dec_inp_tokens += (self._max_dec_steps - len(dec_inp_tokens)) * [params.PAD_TOKEN]
593 |
594 | # Decoder Output
595 | dec_out_tokens = sum_tokens + [params.STOP_DECODING]
596 | dec_out_len = len(dec_out_tokens)
597 | if dec_out_len < self._max_dec_steps:
598 | dec_out_tokens += (self._max_dec_steps - dec_out_len) * [params.PAD_TOKEN]
599 |
600 | dec_inp_ids = [self.vocab.word2id(w) for w in dec_inp_tokens]
601 | dec_out_ids = [self.vocab.word2id(w) for w in dec_out_tokens]
602 |
603 | enc_inp_ids_ext_vocab = None
604 | if self._use_pgen:
605 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self.vocab)
606 | dec_out_ids = summary2ids(dec_out_tokens, self.vocab, article_oovs)
607 |
608 | if len(article_oovs) > max_oov_size:
609 | max_oov_size = len(article_oovs)
610 |
611 | # Appending to the batch of inputs.
612 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32) # Appending to the enc_inp batch.
613 | dec_inp[j] = np.array(dec_inp_ids).astype(np.int32) # Appending to the dec_inp batch.
614 | dec_out[j] = np.array(dec_out_ids).astype(np.int32) # Appending to the dec_out batch.
615 |
616 | if self._use_pgen:
617 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32)
618 |
619 | # Resetting the pointer after the last batch
620 | if self._ptr == num_examples:
621 | self._ptr = 0
622 |
623 | # Setting the pointer for the last batch
624 | if self._ptr + batch_size > num_examples:
625 | self._ptr = num_examples - batch_size
626 |
627 | enc_padding_mask = (enc_inp != self.vocab.word2id(params.PAD_TOKEN)).astype(np.float32)
628 | dec_padding_mask = (dec_out != self.vocab.word2id(params.PAD_TOKEN)).astype(np.float32)
629 |
630 | batches = [enc_inp, enc_padding_mask, dec_inp, dec_out, dec_padding_mask]
631 | if self._use_pgen:
632 | batches += [enc_inp_ext_vocab, max_oov_size]
633 |
634 | return batches
635 |
636 | def get_test_batch(self, batch_size):
637 | num_examples = self.num_test_examples
638 | files = self.test_files
639 |
640 | enc_inp = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32)
641 |
642 | enc_inp_ext_vocab = None
643 | max_oov_size = -np.infty
644 | if self._use_pgen:
645 | enc_inp_ext_vocab = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32)
646 |
647 | summaries = [] # Used in 'test' mode.
648 | ext_vocabs = [] # Extended vocabularies.
649 |
650 | # # Shuffle files at the start of an epoch.
651 | # if self._ptr == 0:
652 | # shuffle(files)
653 |
654 | # Start and end index for a batch of data.
655 | start = self._ptr
656 | end = self._ptr + batch_size
657 | self._ptr = end
658 |
659 | for i in range(start, end):
660 | j = i - start # Index of the example in current batch.
661 | article, summary = get_article_summary(files[i])
662 | enc_inp_tokens = article.split(' ')
663 |
664 | if len(enc_inp_tokens) >= self._max_enc_steps: # Truncate.
665 | enc_inp_tokens = enc_inp_tokens[: self._max_enc_steps]
666 | else: # Pad.
667 | enc_inp_tokens += (self._max_enc_steps - len(enc_inp_tokens)) * [params.PAD_TOKEN]
668 |
669 | # Encoder Input representation in fixed vocabulary.
670 | enc_inp_ids = [self.vocab.word2id(w) for w in enc_inp_tokens] # Word to ID's
671 |
672 | # Encoder Input representation in extended vocabulary.
673 | enc_inp_ids_ext_vocab = None
674 | article_oovs = None
675 | if self._use_pgen:
676 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self.vocab)
677 |
678 | if len(article_oovs) > max_oov_size:
679 | max_oov_size = len(article_oovs)
680 |
681 | # Appending to the input batch.
682 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32)
683 |
684 | if self._use_pgen:
685 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32)
686 | ext_vocabs.append(article_oovs)
687 |
688 | sum_tokens = summary.split(' ')
689 | summaries.append(sum_tokens)
690 |
691 | # Resetting the pointer after the last batch
692 | if self._ptr == num_examples:
693 | self._ptr = 0
694 |
695 | # Setting the pointer for the last batch
696 | if self._ptr + batch_size > num_examples:
697 | self._ptr = num_examples - batch_size
698 |
699 | # Repeat a single input beam size times.
700 | # enc_inp = np.repeat(enc_inp, repeats=beam_size, axis=0) # Shape: Bm x T_in.
701 | # enc_inp_ext_vocab = np.repeat(enc_inp_ext_vocab, repeats=beam_size, axis=0) # Shape: Bm x T_in.
702 | enc_padding_mask = (enc_inp != self.vocab.word2id(params.PAD_TOKEN)).astype(np.float32) # Shape: B x T_in.
703 |
704 | # Example indices.
705 | indices = list(range(start, end))
706 | batches = [indices, summaries, files[start: end], enc_inp, enc_padding_mask]
707 | if self._use_pgen:
708 | batches += [enc_inp_ext_vocab, ext_vocabs, max_oov_size]
709 |
710 | return batches
711 |
712 | def get_batch(self, batch_size, split='train', permutate=False):
713 | if split == 'train' or split == 'val':
714 | return self.get_train_val_batch(batch_size, split, permutate)
715 | elif split == 'test':
716 | return self.get_test_batch(batch_size)
717 | else:
718 | raise ValueError('split should be either of train/val/test only!! \n')
719 |
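# Editor's sketch (not part of the original file): pulling batches in a training loop.
# The sequence lengths here are illustrative, not the defaults used by run_model.py.
# >>> gen = DataGenerator(path_to_dataset='../data/', max_inp_seq_len=400,
# ...                     max_out_seq_len=100, vocab=vocab, use_pgen=True)
# >>> batch = gen.get_batch(batch_size=16, split='train')
# >>> enc_inp, enc_pad_mask, dec_inp, dec_out, dec_pad_mask, enc_inp_ext, max_oov = batch
# >>> enc_inp.shape
# (16, 400)
# The internal pointer advances on every call; once an epoch finishes it resets to 0
# and the file list is reshuffled at the start of the next epoch.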
720 |
721 | class DataGeneratorHier(object):
722 |
723 | def __init__(self, path_to_dataset, max_inp_sent, max_inp_tok_per_sent, max_out_tok, vocab,
724 | use_pgen=False, use_sample=False):
725 |
726 | self._max_enc_sent = max_inp_sent # Max. no of sentences in the encoder sequence.
727 | self._max_enc_tok_per_sent = max_inp_tok_per_sent # Max. no of tokens per sentence in the encoder sequence.
728 | self._max_enc_tok = self._max_enc_sent * self._max_enc_tok_per_sent
729 | self._max_dec_tok = max_out_tok # Max. no of tokens in the decoder sequence.
730 |
731 | self._vocab = vocab # Vocabulary object.
732 | self._use_pgen = use_pgen # Flag whether pointer-generator mechanism is used.
733 |
734 | self._ptr = 0 # Pointer for batching the data.
735 |
736 | # Train files.
737 | train_stories = os.listdir(path_to_dataset + 'train')
738 | self._train_files = [os.path.join(path_to_dataset + 'train/', s) for s in train_stories]
739 | self.num_train_examples = len(self._train_files)
740 | shuffle(self._train_files)
741 |
742 | # Validation files.
743 | val_stories = os.listdir(path_to_dataset + 'val')
744 | self._val_files = [os.path.join(path_to_dataset + 'val/', s) for s in val_stories]
745 | self.num_val_examples = len(self._val_files)
746 | shuffle(self._val_files)
747 |
748 | # Test files.
749 | test_stories = os.listdir(path_to_dataset + 'test')
750 | self._test_files = [os.path.join(path_to_dataset + 'test/', s) for s in test_stories]
751 | self.num_test_examples = len(self._test_files)
752 | # shuffle(self._test_files)
753 |
754 | if use_sample:
755 | # ***************************** PATCH ***************************** #
756 | train_idx = randint(0, self.num_train_examples - 20)
757 | self._train_files = self._train_files[train_idx: train_idx + 20]
758 | self.num_train_examples = len(self._train_files)
759 |
760 | val_idx = randint(0, self.num_val_examples - 20)
761 | self._val_files = self._val_files[val_idx: val_idx + 20]
762 | self.num_val_examples = len(self._val_files)
763 |
764 | test_idx = randint(0, self.num_test_examples - 23)
765 | self._test_files = self._test_files[test_idx: test_idx + 23]
766 | self.num_test_examples = len(self._test_files)
767 | # ***************************** PATCH ***************************** #
768 |
769 | print("Split the data as follows:\n")
770 | print("\t\t Training: {} examples. \n".format(self.num_train_examples))
771 | print("\t\t Validation: {} examples. \n".format(self.num_val_examples))
772 | print("\t\t Test: {} examples. \n".format(self.num_test_examples))
773 |
774 | def get_batch(self, batch_size, split="train", permutate=False, chunk=False):
775 | if split == "train" or split == "val":
776 | return self._get_train_val_batch(batch_size, split, permutate, chunk)
777 |
778 | elif split == "test":
779 | return self._get_test_batch(batch_size, chunk)
780 |
781 | else:
782 | raise ValueError("Split should be either of train/val/test only!! \n")
783 |
784 | def _get_train_val_batch(self, batch_size, split="train", permutate=False, chunk=False):
785 | if split == "train":
786 | num_examples = self.num_train_examples
787 | files = self._train_files
788 |
789 | elif split == "val":
790 | num_examples = self.num_val_examples
791 | files = self._val_files
792 |
793 | else:
794 | raise ValueError("Split is neither train nor val. Check the function call!")
795 |
796 | enc_inp = np.ndarray(shape=[batch_size, self._max_enc_sent, self._max_enc_tok_per_sent], dtype=np.int32)
797 | dec_inp = np.ndarray(shape=[batch_size, self._max_dec_tok], dtype=np.int32)
798 | dec_out = np.ndarray(shape=[batch_size, self._max_dec_tok], dtype=np.int32)
799 |
800 | # Additional inputs in the pointer-generator mode.
801 | max_oov_size = -np.infty
802 | enc_inp_ext_vocab = None
803 | if self._use_pgen:
804 | enc_inp_ext_vocab = np.ndarray(shape=[batch_size, self._max_enc_tok], dtype=np.int32)
805 |
806 | # Shuffle files at the start of an epoch.
807 | if self._ptr == 0:
808 | shuffle(files)
809 |
810 | # Start and end index for a batch of data.
811 | start = self._ptr
812 | end = self._ptr + batch_size
813 | self._ptr = end
814 |
815 | for i in range(start, end):
816 | j = i - start # Index of the example in current batch.
817 | article_sents, summary = get_article_summary(files[i], art_format="sentences")
818 |
819 |             # When chunking, reshape the article into fixed-size chunks of tokens.
820 | if chunk:
821 | art_words = ' '.join(article_sents)
822 |
823 | article_lines = [[]]
824 | word_count = 0
825 | for word in art_words.split(' '):
826 | article_lines[-1].append(word)
827 | word_count += 1
828 |
829 | if word_count == self._max_enc_tok_per_sent:
830 | word_count = 0
831 | article_lines.append([])
832 |
833 | article_sents = []
834 | for line in article_lines:
835 | article_sents.append(' '.join(line))
836 |
837 | if len(article_sents) >= self._max_enc_sent: # Truncate no. of sentences.
838 | if permutate:
839 | indcs = sorted(sample(range(len(article_sents)), self._max_enc_sent))
840 | article_sents = [article_sents[i] for i in indcs]
841 | else:
842 | article_sents = article_sents[: self._max_enc_sent]
843 | else:
844 | article_sents += (self._max_enc_sent -
845 | len(article_sents)) * [''] # Add empty sentences.
846 |
847 | enc_inp_ids = []
848 | enc_inp_tokens = []
849 | for sent_idx, art_sent in enumerate(article_sents):
850 |
851 | enc_sent_tokens = art_sent.split(' ')
852 | if len(enc_sent_tokens) >= self._max_enc_tok_per_sent: # Truncate no. of tokens.
853 | if permutate:
854 | indcs = sorted(sample(range(len(enc_sent_tokens)), self._max_enc_tok_per_sent))
855 | enc_sent_tokens = [enc_sent_tokens[i] for i in indcs]
856 | else:
857 | enc_sent_tokens = enc_sent_tokens[: self._max_enc_tok_per_sent]
858 | else:
859 | enc_sent_tokens += (self._max_enc_tok_per_sent
860 | - len(enc_sent_tokens)) * [params.PAD_TOKEN] # Pad.
861 |
862 | # Encoder sentence representation in the fixed vocabulary.
863 | enc_sent_ids = [self._vocab.word2id(w) for w in enc_sent_tokens]
864 |
865 | # Appending to the lists.
866 | enc_inp_ids.append(enc_sent_ids)
867 | enc_inp_tokens += enc_sent_tokens
868 |
869 | # Summary tokens.
870 | sum_tokens = summary.split(' ') # Summary tokens.
871 | sum_tokens = [t for t in sum_tokens if t not in
872 |                           [params.SENTENCE_START, params.SENTENCE_END]]  # Removing sentence start/end tokens.
873 |
874 | if len(sum_tokens) > self._max_dec_tok - 1: # Truncate.
875 | sum_tokens = sum_tokens[: self._max_dec_tok - 1]
876 |
877 | # Decoder Input.
878 | dec_inp_tokens = [params.START_DECODING] + sum_tokens
879 | if len(dec_inp_tokens) < self._max_dec_tok:
880 | dec_inp_tokens += (self._max_dec_tok - len(dec_inp_tokens)) * [params.PAD_TOKEN]
881 |
882 | # Decoder Output.
883 | dec_out_tokens = sum_tokens + [params.STOP_DECODING]
884 | if len(dec_out_tokens) < self._max_dec_tok:
885 | dec_out_tokens += (self._max_dec_tok - len(dec_out_tokens)) * [params.PAD_TOKEN]
886 |
887 | dec_inp_ids = [self._vocab.word2id(w) for w in dec_inp_tokens]
888 | dec_out_ids = [self._vocab.word2id(w) for w in dec_out_tokens]
889 |
890 | # Encoder input, decoder output representation in extended vocabulary.
891 | enc_inp_ids_ext_vocab = None
892 | if self._use_pgen:
893 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self._vocab)
894 | dec_out_ids = summary2ids(dec_out_tokens, self._vocab, article_oovs)
895 |
896 | if len(article_oovs) > max_oov_size:
897 | max_oov_size = len(article_oovs)
898 |
899 | # Appending to the batch of inputs.
900 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32)
901 | dec_inp[j] = np.array(dec_inp_ids).astype(np.int32)
902 | dec_out[j] = np.array(dec_out_ids).astype(np.int32)
903 |
904 | if self._use_pgen:
905 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32)
906 |
907 | # Resetting the pointer after the last batch.
908 | if self._ptr == num_examples:
909 | self._ptr = 0
910 |
911 | # Setting the pointer for the last batch.
912 | if self._ptr + batch_size > num_examples:
913 | self._ptr = num_examples - batch_size
914 |
915 | # Padding masks.
916 | pad_id = self._vocab.word2id(params.PAD_TOKEN)
917 | enc_pad_mask = (enc_inp != pad_id).astype(np.float32) # B x S_in x T_in.
918 | enc_doc_mask = np.sum(enc_pad_mask, axis=2) # B x S_in.
919 | enc_doc_mask = np.greater(enc_doc_mask, 0).astype(np.float32) # B x S_in.
920 | dec_pad_mask = (dec_out != pad_id).astype(np.float32) # B x T_dec.
921 |
922 | batches = [enc_inp, enc_pad_mask, enc_doc_mask, dec_inp, dec_out, dec_pad_mask]
923 | if self._use_pgen:
924 | batches += [enc_inp_ext_vocab, max_oov_size]
925 |
926 | return batches
927 |
928 | def _get_test_batch(self, batch_size, chunk=False):
929 | num_examples = self.num_test_examples
930 | files = self._test_files
931 |
932 | enc_inp = np.ndarray(shape=[batch_size, self._max_enc_sent, self._max_enc_tok_per_sent], dtype=np.int32)
933 |
934 | max_oov_size = -np.infty
935 | enc_inp_ext_vocab = None
936 | if self._use_pgen:
937 | enc_inp_ext_vocab = np.ndarray(shape=[batch_size, self._max_enc_tok], dtype=np.int32)
938 |
939 | summaries = [] # Used in 'test' mode.
940 | ext_vocabs = [] # Extended vocabularies.
941 |
942 | # # Shuffle files at the start of an epoch.
943 | # if self._ptr == 0:
944 | # shuffle(files)
945 |
946 | # Start and end index for a batch of data.
947 | start = self._ptr
948 | end = self._ptr + batch_size
949 | self._ptr = end
950 |
951 | for i in range(start, end):
952 | j = i - start
953 | article_sents, summary = get_article_summary(files[i], art_format="sentences", sum_format="sentences")
954 |
955 |             # When chunking, reshape the article into fixed-size chunks of tokens.
956 | if chunk:
957 | art_words = ' '.join(article_sents)
958 |
959 | article_lines = [[]]
960 | word_count = 0
961 | for word in art_words.split(' '):
962 | article_lines[-1].append(word)
963 | word_count += 1
964 |
965 | if word_count == self._max_enc_tok_per_sent:
966 | word_count = 0
967 | article_lines.append([])
968 |
969 | article_sents = []
970 | for line in article_lines:
971 | article_sents.append(' '.join(line))
972 |
973 | if len(article_sents) >= self._max_enc_sent: # Truncate no. of sentences.
974 | article_sents = article_sents[: self._max_enc_sent]
975 | else: # Add empty sentences.
976 | article_sents += (self._max_enc_sent -
977 | len(article_sents)) * ['']
978 |
979 | enc_inp_ids = []
980 | enc_inp_tokens = []
981 | for sent_idx, art_sent in enumerate(article_sents):
982 | # Break if max. no. of encoder sentences is reached.
983 | if sent_idx >= self._max_enc_sent:
984 | break
985 |
986 | enc_sent_tokens = art_sent.split(' ')
987 | if len(enc_sent_tokens) >= self._max_enc_tok_per_sent: # Truncate no. of tokens.
988 | enc_sent_tokens = enc_sent_tokens[: self._max_enc_tok_per_sent]
989 | else: # Pad.
990 | enc_sent_tokens += (self._max_enc_tok_per_sent -
991 | len(enc_sent_tokens)) * [params.PAD_TOKEN]
992 |
993 | # Encoder representation in the fixed vocabulary.
994 | enc_sent_ids = [self._vocab.word2id(w) for w in enc_sent_tokens]
995 |
996 | # Appending to the lists.
997 | enc_inp_ids.append(enc_sent_ids)
998 | enc_inp_tokens += enc_sent_tokens
999 |
1000 | # Encoder input representation in the extended vocabulary.
1001 | enc_inp_ids_ext_vocab = None
1002 | article_oovs = None
1003 | if self._use_pgen:
1004 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self._vocab)
1005 |
1006 | if len(article_oovs) > max_oov_size:
1007 | max_oov_size = len(article_oovs)
1008 |
1009 | # Appending to the input batch.
1010 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32)
1011 |
1012 | if self._use_pgen:
1013 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32)
1014 | ext_vocabs.append(article_oovs)
1015 |
1016 | # Summaries.
1017 | summaries.append(summary)
1018 |
1019 | # Resetting the pointer after the last batch.
1020 | if self._ptr == num_examples:
1021 | self._ptr = 0
1022 |
1023 | # Setting the pointer for the last batch.
1024 | if self._ptr + batch_size > num_examples:
1025 | self._ptr = num_examples - batch_size
1026 |
1027 | # Padding masks.
1028 | pad_id = self._vocab.word2id(params.PAD_TOKEN)
1029 | enc_pad_mask = (enc_inp != pad_id).astype(np.float32) # B x S_in x T_in.
1030 | enc_doc_mask = np.sum(enc_pad_mask, axis=2) # B x S_in.
1031 | enc_doc_mask = np.greater(enc_doc_mask, 0).astype(np.float32) # B x S_in.
1032 |
1033 | indices = list(range(start, end))
1034 | batches = [indices, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask]
1035 | if self._use_pgen:
1036 | batches += [enc_inp_ext_vocab, ext_vocabs, max_oov_size]
1037 |
1038 | return batches
1039 |
1040 |
1041 | def average_gradients(tower_grads):
1042 | """
1043 | This function collects gradients from all towers and returns the average gradient.
1044 | :param tower_grads: List of gradients from all towers.
1045 | :return: List of average gradients of all variables.
1046 | """
1047 | average_grads = []
1048 | for grad_and_vars in zip(*tower_grads):
1049 |         # For M variables and N GPUs, zip(*tower_grads) yields one grad_and_vars per variable, of the form
1050 | # ((grad0_gpu0, var0_gpu0), (grad0_gpu1, var0_gpu1), ..., (grad0_gpuN, var0_gpuN))
1051 | # ((grad1_gpu0, var1_gpu0), (grad1_gpu1, var1_gpu1), ..., (grad1_gpuN, var1_gpuN))
1052 | # ......
1053 | # ((gradM_gpu0, varM_gpu0), (gradM_gpu1, varM_gpu1), ..., (gradM_gpuN, varM_gpuN))
1054 |
1055 | grads = []
1056 | for g, _ in grad_and_vars:
1057 | # Adding an extra dimension for concatenation later.
1058 | expanded_g = tf.expand_dims(g, 0)
1059 | grads.append(expanded_g)
1060 |
1061 | grad = tf.concat(axis=0, values=grads) # Concatenation along added dimension.
1062 | grad = tf.reduce_mean(grad, axis=0) # Average gradient.
1063 |
1064 | # Variable name is same in all the GPUs. So, it suffices to use the one in the 1st GPU.
1065 | v = grad_and_vars[0][1]
1066 | grad_and_var = (grad, v)
1067 | average_grads.append(grad_and_var)
1068 |
1069 | return average_grads
1070 |
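# Editor's sketch (not part of the original file): expected structure of tower_grads for
# two towers and one variable; the constants stand in for real per-GPU gradients.
# >>> v = tf.Variable([0.0], name='toy_var')
# >>> tower_grads = [[(tf.constant([2.0]), v)],    # (grad, var) pairs computed on GPU 0
# ...                [(tf.constant([4.0]), v)]]    # (grad, var) pairs computed on GPU 1
# >>> avg = average_gradients(tower_grads)         # -> [(averaged gradient, v)]
# Evaluating avg[0][0] in a session yields [3.0], the element-wise mean of the two gradients.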
1071 |
1072 | def eval_model(path_to_results):
1073 | """
1074 |     This function runs the ROUGE evaluation using pyrouge.
1075 |     :param path_to_results: Directory containing the 'predictions/' and 'groundtruths/' sub-directories.
1076 | """
1077 | r = pyrouge.Rouge155()
1078 | r.system_dir = path_to_results + 'predictions/'
1079 | r.model_dir = path_to_results + 'groundtruths/'
1080 |     r.system_filename_pattern = r'(\d+)_pred.txt'
1081 | r.model_filename_pattern = '#ID#_gt.txt'
1082 | rouge_results = r.convert_and_evaluate()
1083 |
1084 | rouge_dict = r.output_to_dict(rouge_results)
1085 | rouge_log(rouge_dict, path_to_results)
1086 |
1087 |
1088 | def rouge_log(results_dict, path_to_results):
1089 | """
1090 | This function saves the rouge results into a file.
1091 | :param results_dict: Dictionary output from pyrouge consisting of ROUGE results.
1092 | :param path_to_results: Path where the results file has to be stored.
1093 | :return:
1094 | """
1095 | log_str = ""
1096 | for x in ["1", "2", "l"]:
1097 |
1098 | log_str += "\nROUGE-%s:\n" % x
1099 | for y in ["f_score", "recall", "precision"]:
1100 | key = "rouge_%s_%s" % (x, y)
1101 | key_cb = key + "_cb"
1102 | key_ce = key + "_ce"
1103 | val = results_dict[key]
1104 | val_cb = results_dict[key_cb]
1105 | val_ce = results_dict[key_ce]
1106 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
1107 |
1108 | results_file = path_to_results + 'ROUGE_results.txt'
1109 | with open(results_file, "w") as f:
1110 | f.write(log_str)
1111 |
1112 |
1113 | def get_running_avg_loss(loss, running_avg_loss, decay=0.99):
1114 | """
1115 |     This function updates the running average of the loss.
1116 | :param loss: Loss at the current step.
1117 | :param running_avg_loss: Running average loss after the previous step.
1118 | :param decay: The decay rate.
1119 | :return:
1120 | """
1121 |
1122 | if running_avg_loss == 0:
1123 | running_avg_loss = loss
1124 | else:
1125 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
1126 |
1127 | return running_avg_loss
1128 |
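# Editor's sketch (not part of the original file): the running average is an exponential
# moving average that is seeded with the first observed loss.
# >>> avg = get_running_avg_loss(loss=5.0, running_avg_loss=0)    # first step -> 5.0
# >>> get_running_avg_loss(loss=3.0, running_avg_loss=avg)        # 0.99 * 5.0 + 0.01 * 3.0
# 4.98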
1129 |
1130 | def get_dependencies(tensor):
1131 | dependencies = set()
1132 | dependencies.update(tensor.op.inputs)
1133 | for sub_op in tensor.op.inputs:
1134 | dependencies.update(get_dependencies(sub_op))
1135 |
1136 | return dependencies
1137 |
1138 |
1139 | def get_placeholder_dependencies(tensor):
1140 | dependencies = get_dependencies(tensor)
1141 | dependencies = [tensor for tensor in dependencies if tensor.op.type == "Placeholder"]
1142 |
1143 | return dependencies
1144 |
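# Editor's sketch (not part of the original file): these helpers walk the graph backwards
# from a tensor, which is handy for checking which placeholders a loss actually depends on.
# >>> x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
# >>> y = tf.reduce_sum(x * 2.0)
# >>> [t.name for t in get_placeholder_dependencies(y)]
# ['x:0']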
--------------------------------------------------------------------------------
/codes/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "Rajeev Bhatt Ambati"
3 |
4 | from NSE import NSE
5 | from utils import average_gradients, get_running_avg_loss
6 | import params
7 | from beam_search import run_beam_search
8 | from rouge_batch import rouge_l_fscore_batch as rouge_l_fscore
9 |
10 | import random
11 | import time
12 | import math
13 | import tensorflow as tf
14 | from tensorflow.nn.rnn_cell import LSTMStateTuple
15 | from tensorflow.contrib.framework.python.framework import checkpoint_utils
16 | import numpy as np
17 | from joblib import Parallel, delayed
18 |
19 | tf.set_random_seed(2019)
20 | random.seed(2019)
21 | tf.reset_default_graph()
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 |
26 | class SummarizationModel(object):
27 | def __init__(self, vocab, data):
28 | self.vocab = vocab
29 | self.data = data
30 |
31 | self._dense_init = tf.contrib.layers.variance_scaling_initializer()
32 |
33 | self._config = tf.ConfigProto()
34 | self._config.gpu_options.allow_growth = True
35 |
36 | self._best_val_loss = np.infty
37 | self._num_gpus = len(FLAGS.GPUs)
38 |
39 | self._saver = None # Saver.
40 | self._init = None # Initializer.
41 | self._sess = None # Session.
42 |
43 | def _create_placeholders(self):
44 | """
45 | This function creates the placeholders needed for the computation graph.
46 | :return:
47 | """
48 | self._global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')
49 | self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
50 |
51 | # Word embedding.
52 | if FLAGS.use_pretrained:
53 | with tf.variable_scope('embed', reuse=tf.AUTO_REUSE):
54 | self.word_embedding = tf.get_variable(name="word_embedding",
55 | shape=[self.vocab.size(), FLAGS.dim],
56 | initializer=tf.constant_initializer(self.vocab.wvecs),
57 | dtype=tf.float32,
58 | trainable=True)
59 | else:
60 | self.word_embedding = tf.get_variable(name="word_embedding",
61 | shape=[self.vocab.size(), FLAGS.dim],
62 | dtype=tf.float32,
63 | trainable=True)
64 |
65 | # Graph Inputs/Outputs.
66 | self._enc_in = tf.placeholder(tf.int32, name='enc_in',
67 | shape=[FLAGS.batch_size, FLAGS.enc_steps]) # Shape: B x T_enc.
68 | self._enc_pad_mask = tf.placeholder(tf.float32, name='enc_pad_mask',
69 | shape=[FLAGS.batch_size, FLAGS.enc_steps]) # Shape: B x T_enc.
70 | self._dec_in = tf.placeholder(tf.int32, name='dec_in',
71 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec.
72 | inputs = [self._enc_in, self._enc_pad_mask, self._dec_in]
73 |
74 | if FLAGS.use_pgen:
75 | self._enc_in_ext_vocab = tf.placeholder(tf.int32, name='enc_in_ext_vocab',
76 | shape=[FLAGS.batch_size,
77 | FLAGS.enc_steps]) # Shape: B x T_enc.
78 | self._max_oov_size = tf.placeholder(tf.int32, name='max_oov_size')
79 | inputs.append(self._enc_in_ext_vocab)
80 |
81 | if FLAGS.mode == "train":
82 | self._dec_out = tf.placeholder(tf.int32, name='dec_out',
83 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec.
84 | self._dec_pad_mask = tf.placeholder(tf.float32, name='dec_pad_mask',
85 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec.
86 | inputs += [self._dec_out, self._dec_pad_mask]
87 |
88 | return inputs
89 |
90 | def _create_writers(self):
91 | """
92 | This function creates the summaries and writers needed for visualization through tensorboard.
93 | :return: writers
94 | """
95 | self._mean_crossentropy_loss = tf.placeholder(tf.float32, name='mean_crossentropy_loss')
96 | self._mean_rouge_score = tf.placeholder(dtype=tf.float32, name="mean_rouge")
97 |
98 | # Summaries.
99 | cross_entropy_summary = tf.summary.scalar('cross_entropy', self._mean_crossentropy_loss)
100 | rouge_summary = None
101 | if FLAGS.rouge_summary:
102 | rouge_summary = tf.summary.scalar('rouge_score', self._mean_rouge_score)
103 | self._summaries = tf.summary.merge_all()
104 |
105 | summary_list = [cross_entropy_summary]
106 | if FLAGS.rouge_summary:
107 | summary_list.append(rouge_summary)
108 |
109 | self._val_summaries = tf.summary.merge(summary_list, name="validation_summaries")
110 |
111 | # Summary Writers.
112 | self._train_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'train') # , self._sess.graph)
113 | self._val_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'val')
114 |
115 | def build_graph(self):
116 | if FLAGS.mode == 'train':
117 | if len(FLAGS.GPUs) > 1:
118 | self._parallel_model() # Parallel model in case of multiple-GPUs.
119 | else:
120 | self._single_model() # Single model for a single GPU.
121 |
122 | if FLAGS.mode == 'test':
123 | inputs = self._create_placeholders() # [enc_in, dec_in, enc_in_ext_vocab]
124 | self._forward(inputs) # Predictions Shape: Bm x 1
125 |
126 | self._saver = tf.train.Saver(tf.global_variables()) # Saver.
127 | self._init = tf.global_variables_initializer() # Initializer.
128 | self._sess = tf.Session(config=self._config) # Session.
129 |
130 | # Restoring the trained model.
131 | self._saver.restore(self._sess, FLAGS.PathToCheckpoint)
132 |
133 | def _forward(self, inputs):
134 | """
135 | This function creates the TensorFlow computation graph.
136 | :param inputs: A list of input placeholders.
137 | [0]: enc_in, Encoder input sequence of ID's, Shape: B x T_in.
138 | [1]: enc_pad_mask: Mask for padding tokens in encoder input, Shape: B x T_in.
139 | [2]: dec_in: Input to the decoder, Shape: B x T_out.
140 | Following additional input in pointer generator mode.
141 | [3]: enc_in_ext_vocab: Encoder input representation in the extended vocabulary, Shape: B x T_in.
142 | [4]: labels: Ground-Truth labels, Only in train mode, Shape: B x T_out.
143 | [5]: dec_pad_mask: Mask for padding tokens in decoder output, Shape: B x T_out
144 | :return: returns loss in train/val mode and predictions in test mode.
145 | """
146 | batch_size = inputs[0].get_shape().as_list()[0] # Batch-size.
147 |
148 | # The NSE instance.
149 | self._nse = NSE(batch_size=batch_size, dim=FLAGS.dim, dense_init=self._dense_init,
150 | mode=FLAGS.mode, use_comp_lstm=FLAGS.use_comp_lstm)
151 |
152 | # Encoder.
153 | self._prev_states, self._p_attns_enc = self._encoder(inputs[: 2])
154 |
155 | # Decoder.
156 | outputs = self._decoder(inputs[1: 4], self._prev_states)
157 |
158 | if FLAGS.mode == 'test':
159 | self._topk_ids, self._topk_log_probs, self._curr_states, self._p_attns_dec = outputs[: 4]
160 |
161 | if FLAGS.use_pgen:
162 | self._p_gens = outputs[4]
163 |
164 | else:
165 | probs, p_attns_dec = outputs
166 |
167 | crossentropy_loss = self._get_crossentropy_loss(probs, inputs[-2], inputs[-1])
168 |
169 | # Evaluating a few samples.
170 | samples = []
171 | if FLAGS.rouge_summary:
172 | samples = self._get_samples(inputs[1: 5], self._prev_states)
173 |
174 | return crossentropy_loss, samples
175 |
176 | def _encoder(self, inputs):
177 | """
178 | This is the encoder.
179 | :param inputs: A list of the following inputs.
180 | enc_in: Encoder input sequence of ID's, Shape: B x T_in.
181 | enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x T_in.
182 | :return:
183 |             state: Internal states of NSE after the last encoding step.
184 |             p_attns: Encoder attention distributions, Shape: T_in * [B x T_in]
185 |                      (left empty in this implementation).
186 |             The returned internal states are:
187 | memory: NSE memory after the last time step, Shape: B x T_in x D.
188 | read_state: Hidden state of the read LSTM after the last time step, (c, h) Shape: 2 * [B x D].
189 | write_state: Hidden state of write LSTM after the last encoding step, (c, h) Shape: 2 * [B x D].
190 | comp_state: Hidden state of compose LSTM after last encoding step, (c, h) Shape: 2 * [B x (2*D)].
191 | """
192 | enc_in, enc_pad_mask = inputs
193 |
194 | # Converting ID's to word-vectors.
195 | enc_in_vecs = tf.nn.embedding_lookup(params=self.word_embedding, ids=enc_in) # Shape: B x T_in x D.
196 | enc_in_vecs = tf.cast(enc_in_vecs, dtype=tf.float32) # Cast to float32
197 |
198 | p_attns = []
199 | state = [enc_in_vecs, None, None, None]
200 | for i in range(FLAGS.enc_steps):
201 | x_t = enc_in_vecs[:, i, :] # Shape: B x D.
202 | output, state = self._nse.step(
203 | x_t=x_t, mem_mask=enc_pad_mask, prev_state=state) # One step of NSE.
204 |
205 | return state, p_attns
206 |
207 | def _decoder(self, inputs, state):
208 | """
209 | This is the decoder.
210 | :param inputs: List of the following inputs.
211 | [0]: enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x T_in.
212 | [1]: dec_in: Input to the decoder.
213 | Test mode: Output of previous time step, Shape: Bm x 1.
214 | Train/Val mode: decoder input sequence, Shape: B x T_out.
215 | [2]: enc_in_ext_vocab (Only in test mode):
216 | Encoder input representation in the extended vocabulary in
217 | pointer generator mode, Shape: B x T_in.
218 | :param state: Internal states of NSE after the encoding is done.
219 | [0]: memory: NSE memory after the last time step of encoder, Shape: B x T_in x D.
220 | [1]: read_state: Hidden state of read LSTM after the previous step, (c, h), Shape: (B x D, B x D)
221 | [2]: write_state: Hidden state of write LSTM after the previous step, (c, h), Shape: (B x D, B x D)
222 | [3]: comp_state: Hidden state of compose LSTM after the previous step, (c, h), Shape: (B x D, B x D)
223 | :return:
224 | In test mode:
225 | [topk_ids, topk_log_probs, state, p_attns] + [p_gens]
226 | In train mode:
227 | [p_cums, p_attns]
228 | """
229 | enc_pad_mask, dec_in = inputs[0], inputs[1]
230 |
231 | # Converting ID's to word-vectors.
232 | dec_in_vecs = tf.nn.embedding_lookup(params=self.word_embedding, ids=dec_in) # Shape: B x T_out x D.
233 | dec_in_vecs = tf.cast(dec_in_vecs, tf.float32) # Cast to tf.float32.
234 |
235 | writes = []
236 | p_attns = []
237 | p_gens = []
238 | for i in range(FLAGS.dec_steps):
239 | x_t = dec_in_vecs[:, i, :] # Shape: B x D.
240 | output, state = self._nse.step(
241 | x_t=x_t, mem_mask=enc_pad_mask, prev_state=state, use_pgen=FLAGS.use_pgen
242 | ) # One-step of NSE.
243 |
244 | # Appending the outputs.
245 | writes.append(output[0])
246 | p_attns.append(output[1])
247 | if FLAGS.use_pgen:
248 | p_gens.append(output[2])
249 |
250 | p_vocabs = self._get_vocab_dist(writes) # Shape: T_out * [B x vsize]
251 |
252 | p_cums = p_vocabs
253 | if FLAGS.use_pgen:
254 | enc_in_ext_vocab = inputs[2] # Shape: T_in x B
255 | p_cums = self._get_cumulative_dist(
256 | p_vocabs, p_gens, p_attns, enc_in_ext_vocab
257 | ) # Shape: T_out * [B x ext_vsize]
258 |
259 | if FLAGS.mode == 'test':
260 | p_final = p_cums[0] # Shape: Bm x ext_vsize
261 |
262 | # The top k predictions, topk_ids, Shape: Bm x (2*Bm).
263 | # Respective probabilities, topk_probs, Shape: Bm x (2*Bm).
264 | topk_probs, topk_ids = tf.nn.top_k(p_final, k=2*FLAGS.beam_size, name='topk_preds')
265 | topk_log_probs = tf.log(tf.clip_by_value(topk_probs, 1e-10, 1.0))
266 |
267 | outputs = [topk_ids, topk_log_probs, state, p_attns]
268 | if FLAGS.use_pgen:
269 | outputs.append(p_gens)
270 |
271 | return outputs
272 | else:
273 | return p_cums, p_attns
274 |
275 | def _get_samples(self, inputs, state):
276 | """
277 |         This function greedily decodes sample summaries used for the ROUGE summary.
278 | :param inputs: List of the following inputs.
279 | [0]: enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x T_in.
280 | [1]: dec_in: Input to the decoder.
281 | Test mode: Output of previous time step, Shape: Bm x 1.
282 | Train/Val mode: decoder input sequence, Shape: B x T_out.
283 |             [2]: enc_in_ext_vocab (Only in pointer-generator mode):
284 | Encoder input representation in the extended vocabulary in
285 | pointer generator mode, Shape: B x T_in.
286 | :param state: Internal states of NSE after the encoding is done.
287 | [0]: memory: NSE memory after the last time step of encoder, Shape: B x T_in x D.
288 | [1]: read_state: Hidden state of read LSTM after the previous step, (c, h), Shape: (B x D, B x D)
289 | [2]: write_state: Hidden state of write LSTM after the previous step, (c, h), Shape: (B x D, B x D)
290 | [3]: comp_state: Hidden state of compose LSTM after the previous step, (c, h), Shape: (B x D, B x D)
291 |         :return:
292 |             samples: Greedily decoded token ID's, Shape: B x T_dec.
293 |                      At each step the most probable token of the cumulative distribution
294 |                      is fed back as the next input (ID's from the extended vocabulary are
295 |                      mapped back to the [UNK] ID first).
296 | """
297 | unk_id = self.vocab.word2id(params.UNKNOWN_TOKEN) # Unknown ID.
298 | start_id = self.vocab.word2id(params.START_DECODING) # Start ID.
299 |
300 | enc_pad_mask = inputs[0] # Encoder mask.
301 | batch_size = enc_pad_mask.get_shape().as_list()[0]
302 | samples = [] # Sampling Outputs.
303 | for i in range(FLAGS.dec_steps):
304 |
305 | if i == 0:
306 | id_t = tf.fill([batch_size], start_id) # Shape: B x .
307 | else:
308 | id_t = samples[-1][:, 0]
309 | # Replacing the ID's from external vocabulary (if any) to UNK id's.
310 | id_t = tf.where(
311 | tf.less(id_t, self.vocab.size()), id_t, unk_id * tf.ones_like(id_t)
312 | )
313 |
314 | # Getting the word vector.
315 | x_t = tf.nn.embedding_lookup(params=self.word_embedding, ids=id_t) # Shape: B x D.
316 | x_t = tf.cast(x_t, dtype=tf.float32)
317 |
318 | output, state = self._nse.step(
319 | x_t=x_t, mem_mask=enc_pad_mask, prev_state=state, use_pgen=FLAGS.use_pgen
320 | ) # One-step of NSE.
321 |
322 | # Output probability distribution.
323 | p_vocab = self._get_vocab_dist([output[0]]) # Shape: [B x vsize].
324 |
325 | # Calculate cumulative probability distribution using pointer mechanism.
326 | p_cum = p_vocab
327 | if FLAGS.use_pgen:
328 | p_cum = self._get_cumulative_dist(
329 | p_vocabs=p_vocab, p_gens=[output[2]], p_attns=[output[1]], enc_in_ext_vocab=inputs[2]
330 | ) # Shape: T_dec * [B x ext_vsize].
331 |
332 | # Greedy sampling.
333 | _, gs_sample = tf.nn.top_k(p_cum[0], k=FLAGS.num_samples) # Shape: B x 1.
334 | samples.append(gs_sample)
335 |
336 | samples = tf.concat(samples, axis=1) # Shape: B x T_dec.
337 |
338 | return samples
339 |
340 | def run_encoder(self, enc_in_batch, enc_pad_mask):
341 | """
342 | This function calculates the internal states of NSE after the last step of encoding.
343 | :param enc_in_batch: A batch of encoder input sequence, Shape: Bm x T_enc.
344 | :param enc_pad_mask: Padding mask for encoder input, Shape: Bm x T_enc.
345 | :return:
346 | NSE internal states after encoding.
347 | final_memory: NSE memory after last time-step of encoding, Shape: 1 x mem_size x D.
348 | final_read_state: Hidden state of read LSTM after last time-step of encoding,
349 | (c, h) Shape: (1 x D, 1 x D).
350 | final_write_state: Hidden state of write LSTM after last time-step of encoding,
351 | (c, h) Shape: (1 x D, 1 x D).
352 | final_comp_state: Hidden state of compose LSTM after last time-step of encoding,
353 | (c, h) Shape: (1 x 2D, 1 x 2D).
354 | Additionally, returns the following when using encoder coverage:
355 | p_attns: Encoder attention distributions, Shape: T_enc * [1 x T_enc].
356 | """
357 | to_return = [self._prev_states]
358 |
359 | outputs = self._sess.run(to_return, feed_dict={self._enc_in: enc_in_batch,
360 | self._enc_pad_mask: enc_pad_mask})
361 | states = outputs[0]
362 |
363 | final_memory, final_read_state, final_write_state = states[: 3]
364 |         # Since the states contain repeated values, slice only the first one.
365 | states[0] = final_memory[0, np.newaxis, :, :] # Memory, Shape: 1 x mem_size x D.
366 | states[1] = LSTMStateTuple(final_read_state.c[0, np.newaxis, :],
367 | final_read_state.h[0, np.newaxis, :]) # Shape: (1 x D, 1 x D).
368 | states[2] = LSTMStateTuple(final_write_state.c[0, np.newaxis, :],
369 | final_write_state.h[0, np.newaxis, :]) # Shape: (1 x D, 1 x D).
370 |
371 | if FLAGS.use_comp_lstm:
372 | final_comp_state = states[3]
373 | states[3] = LSTMStateTuple(final_comp_state.c[0, np.newaxis, :],
374 | final_comp_state.h[0, np.newaxis, :]) # Shape: (1 x (2D), 1 x (2D)).
375 |
376 | return states
377 |
378 | def decode_one_step(self, inputs, prev_states):
379 | """
380 | This function performs one step of decoding.
381 | :param inputs:
382 | dec_in_batch: The input to the decoder. This is the output from previous time-step, Shape: Bm * [1 x]
383 | enc_pad_mask: Padding mask for the memory, Shape: Bm x T.
384 | In pointer generator mode, there are following additional inputs:
385 | enc_in_ex_vocab_batch: Encoder input sequence represented in extended vocabulary, Shape: Bm x T_in.
386 | max_oov_size: Size of the largest OOV tokens in the current batch, Shape: ()
387 | :param prev_states: previous internal states of NSE of all Bm hypothesis, Bm * [prev_state].
388 | where state is a list of internal states of NSE for a single hypothesis.
389 | prev_state = [prev_memory, prev_read_state, prev_write_state]
390 | prev_memory: NSE memory after the previous time step, Shape: 1 x T x D.
391 | prev_read_state: Hidden state of read LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D].
392 | prev_write_state: Hidden state of write LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D].
393 | prev_comp_state: Hidden state of compose LSTM after previous time step, (c, h) Shape: [1 x 2D, 1 x 2D].
394 | :return:
395 | topk_ids: Top-k predictions in the current step, Shape: Bm x (2*Bm)
396 | topk_log_probs: log probabilities of top-k predictions, Shape: Bm x (2*Bm)
397 | curr_states: Current internal states of NSE, Bm * [state].
398 | [0]: memory: NSE memory, Shape: Bm x T x D.
399 | [1]: read_state, (c, h) Shape: (Bm x D, Bm x D).
400 | [2]: write_state, (c, h) Shape: (Bm x D, Bm x D).
401 | [3]: comp_state, (c, h) Shape: (Bm x 2D, Bm x 2D).
402 | p_gens: Generation probabilities, Shape: Bm x .
403 | p_attns: Attention probabilities, Shape: Bm x mem_size.
404 | """
405 | # Decoder input
406 | dec_in = inputs[0] # Shape: Bm * [1 x]
407 | dec_in = np.stack(dec_in, axis=0) # Shape: Bm x
408 | inputs[0] = np.expand_dims(dec_in, axis=-1) # Shape: Bm x 1
409 |
410 | # Previous memories of Bm hypothesis.
411 | prev_memories = [state[0] for state in prev_states] # Shape: Bm * [1 x T x D].
412 | prev_memories = np.concatenate(prev_memories, axis=0) # Shape: Bm x T x D.
413 |
414 | # Previous read states of Bm hypothesis.
415 | # Cell states.
416 | prev_read_states_c = [state[1].c for state in prev_states] # Shape: Bm * [1 x D].
417 | prev_read_states_c = np.concatenate(prev_read_states_c, axis=0) # Shape: Bm x D.
418 |
419 | # Hidden states.
420 | prev_read_states_h = [state[1].h for state in prev_states] # Shape: Bm * [1 x D].
421 | prev_read_states_h = np.concatenate(prev_read_states_h, axis=0) # Shape: Bm x D.
422 |
423 | prev_read_states = LSTMStateTuple(prev_read_states_c, prev_read_states_h)
424 |
425 | # Previous write states of Bm hypothesis.
426 | # Cell states.
427 | prev_write_states_c = [state[2].c for state in prev_states] # Shape: Bm * [1 x D].
428 | prev_write_states_c = np.concatenate(prev_write_states_c, axis=0) # Shape: Bm x D.
429 |
430 | # Hidden states.
431 | prev_write_states_h = [state[2].h for state in prev_states] # Shape: Bm * [1 x D].
432 | prev_write_states_h = np.concatenate(prev_write_states_h, axis=0) # Shape: Bm x D.
433 |
434 | prev_write_states = LSTMStateTuple(prev_write_states_c, prev_write_states_h)
435 |
436 | feed_dict = {
437 | self._dec_in: inputs[0],
438 | self._enc_pad_mask: inputs[1],
439 | self._prev_states[0]: prev_memories,
440 | self._prev_states[1]: prev_read_states,
441 | self._prev_states[2]: prev_write_states
442 | }
443 |
444 |         # Previous compose states of Bm hypotheses.
445 |         # Cell states.
446 |         if FLAGS.use_comp_lstm:
447 |             prev_comp_states_c = [state[3].c for state in prev_states]  # Shape: Bm * [1 x (2D)].
448 | prev_comp_states_c = np.concatenate(prev_comp_states_c, axis=0) # Shape: Bm x (2D).
449 |
450 | prev_comp_states_h = [state[3].h for state in prev_states] # Shape: Bm * [1 x (2D)].
451 | prev_comp_states_h = np.concatenate(prev_comp_states_h, axis=0) # Shape: Bm x (2D).
452 |
453 | prev_comp_states = LSTMStateTuple(prev_comp_states_c, prev_comp_states_h)
454 |
455 | feed_dict[self._prev_states[3]] = prev_comp_states
456 |
457 | if FLAGS.use_pgen:
458 | feed_dict[self._enc_in_ext_vocab] = inputs[2]
459 | feed_dict[self._max_oov_size] = inputs[3]
460 |
461 | to_return = [self._topk_ids, self._topk_log_probs, self._curr_states]
462 |
463 | if FLAGS.use_pgen:
464 | to_return += [self._p_gens, self._p_attns_dec]
465 |
466 | outputs = self._sess.run(to_return, feed_dict=feed_dict)
467 |
468 | # Preparing the state outputs into lists.
469 | curr_states = outputs[2]
470 |
471 | # Current memories.
472 | curr_memories = np.split(curr_states[0], FLAGS.beam_size, axis=0) # Bm * [1 x T x D].
473 |
474 | # Current read states.
475 | curr_read_states_c = np.split(curr_states[1].c, FLAGS.beam_size, axis=0) # Bm * [1 x D].
476 | curr_read_states_h = np.split(curr_states[1].h, FLAGS.beam_size, axis=0) # Bm * [1 x D].
477 | curr_read_states = [LSTMStateTuple(c, h)
478 | for c, h in zip(curr_read_states_c, curr_read_states_h)] # Bm * [(1 x D, 1 x D)].
479 |
480 | # Current write states.
481 | curr_write_states_c = np.split(curr_states[2].c, FLAGS.beam_size, axis=0) # Bm * [1 x D].
482 | curr_write_states_h = np.split(curr_states[2].h, FLAGS.beam_size, axis=0) # Bm * [1 x D].
483 | curr_write_states = [LSTMStateTuple(c, h)
484 | for c, h in zip(curr_write_states_c, curr_write_states_h)] # Bm * [(1 x D, 1 x D)].
485 |
486 | if FLAGS.use_comp_lstm:
487 | # Current compose states.
488 | curr_comp_states_c = np.split(curr_states[3].c, FLAGS.beam_size, axis=0) # Bm * [1 x 2D].
489 | curr_comp_states_h = np.split(curr_states[3].h, FLAGS.beam_size, axis=0) # Bm * [1 x 2D].
490 | curr_comp_states = [LSTMStateTuple(c, h)
491 | for c, h in zip(curr_comp_states_c, curr_comp_states_h)] # Bm * [(1 x 2D, 1 x 2D)].
492 |
493 | curr_states_list = [[memory, read_state, write_state, comp_state]
494 | for memory, read_state, write_state, comp_state in
495 | zip(curr_memories, curr_read_states, curr_write_states, curr_comp_states)]
496 | else:
497 | # Forming a list of internal states for Bm hypothesis.
498 | curr_states_list = [[memory, read_state, write_state] for memory, read_state, write_state
499 | in zip(curr_memories, curr_read_states, curr_write_states)]
500 |
501 | outputs[2] = curr_states_list
502 |
503 | # Generation probabilities.
504 | if FLAGS.use_pgen:
505 | p_gens = outputs[3] # Shape: 1 x [Bm x 1].
506 | p_gens = p_gens[0] # Only one time-step in decoding phase, Shape: Bm x 1.
507 | p_gens = np.squeeze(p_gens, axis=1) # Shape: Bm x
508 | outputs[3] = np.split(p_gens, FLAGS.beam_size, axis=0) # Shape: Bm * [1]
509 |
510 | # Attention probabilities.
511 | if FLAGS.use_pgen:
512 | p_attns = outputs[4] # Shape: 1 x [Bm x mem_size].
513 | p_attns = p_attns[0] # Only one time-step in the decoding phase, Shape: Bm x mem_size.
514 | outputs[4] = np.split(p_attns, FLAGS.beam_size, axis=0) # Shape: Bm * [1 x mem_size].
515 |
516 | return outputs
517 |
518 | def _get_vocab_dist(self, inputs):
519 | """
520 | This function passes the NSE hidden states obtained from decoder through the output layer (dense layer
521 | followed by a softmax layer) and returns the vocabulary distribution thus obtained.
522 | :param inputs: List of hidden states, Shape: T * [B x D].
523 | :return: p_vocab: List of vocabulary distributions for each time step, Shape: T * [B x vsize].
524 | """
525 | steps = len(inputs)
526 | vsize = self.vocab.size()
527 |
528 | inputs = tf.concat(inputs, axis=0) # Shape: (T*B) x D
529 |
530 | with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
531 | scores = tf.layers.dense(inputs,
532 | units=vsize,
533 | activation=None,
534 | kernel_initializer=self._dense_init,
535 | name='output') # Shape: (T*B) x vsize
536 |
537 | p_vocab = tf.nn.softmax(scores, name='prob_scores') # Shape: (T*B) x vsize
538 |
539 | p_vocab = tf.split(p_vocab, num_or_size_splits=steps, axis=0) # Shape: T * [B x vsize]
540 |
541 | return p_vocab
542 |
543 | def _get_cumulative_dist(self, p_vocabs, p_gens, p_attns, enc_in_ext_vocab):
544 | """
545 | This function calculates the cumulative probability distribution from the vocabulary distribution,
546 | attention distribution and generation probabilities.
547 | :param p_vocabs: A list of vocabulary distributions for each time step, Shape: T_out * [B x vsize].
548 | :param p_gens: A list of generation probabilities for each time step, Shape: T_out * [B x 1].
549 | :param p_attns: A list of attention distributions for each time step, Shape: T_out * [B x T_in].
550 | :param enc_in_ext_vocab: Encoder input represented using extended vocabulary, Shape: B x T_in
551 | :return:
552 | """
553 | vsize = self.vocab.size()
554 | ext_vsize = vsize + self._max_oov_size
555 | batch_size, enc_steps = enc_in_ext_vocab.get_shape().as_list()
556 |
557 | p_vocabs = [p_gen * p_vocab for p_gen, p_vocab in zip(p_gens, p_vocabs)] # Shape: T_out * [B x vsize]
558 | p_attns = [(1.0 - p_gen) * p_attn for p_gen, p_attn in zip(p_gens, p_attns)] # Shape: T_out * [B x T_in]
559 |
560 | zero_vocab = tf.zeros(shape=[batch_size, self._max_oov_size])
561 | # Shape: T_out * [B x ext_vsize]
562 | p_vocabs_ext = [tf.concat([p_vocab, zero_vocab], axis=-1) for p_vocab in p_vocabs]
563 |
564 | idx = tf.range(0, limit=batch_size) # Shape: B x
565 | idx = tf.expand_dims(idx, axis=-1) # Shape: B x 1
566 | idx = tf.tile(idx, [1, enc_steps]) # Shape: B x T_in
567 | indices = tf.stack([idx, enc_in_ext_vocab], axis=-1) # Shape: B x T_in x 2
568 |
569 |         # First, a zero matrix of shape B x ext_vsize is created. Then, the attention score for each input token
570 |         # indexed as in indices is looked up in p_attn and added at the corresponding location,
571 |         # Shape: T_out * [B x ext_vsize]
572 | p_attn_cum = [tf.scatter_nd(indices, p_attn, [batch_size, ext_vsize]) for p_attn in p_attns]
573 |
574 | # Cumulative distribution, Shape: T_out * [B x ext_vsize]
575 | p_cum = [gen_prob + copy_prob for gen_prob, copy_prob in zip(p_vocabs_ext, p_attn_cum)]
576 |
577 | return p_cum
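
    # NOTE (editor): a minimal NumPy sketch (not part of the original code) of the pointer-generator
    # merge computed above, for a single example and a single decoding step. All sizes, ids and
    # probabilities below are made up for illustration.
    #
    #   import numpy as np
    #
    #   vsize, max_oov_size = 5, 2
    #   p_vocab = np.array([0.1, 0.4, 0.2, 0.2, 0.1])   # distribution over the fixed vocabulary
    #   p_attn = np.array([0.5, 0.3, 0.2])              # attention over the T_in source tokens
    #   p_gen = 0.8                                     # generation probability
    #   enc_in_ext_vocab = np.array([1, 5, 2])          # source ids; 5 is an OOV slot (>= vsize)
    #
    #   p_cum = np.zeros(vsize + max_oov_size)
    #   p_cum[: vsize] = p_gen * p_vocab                              # generator part (zeros on OOV slots)
    #   np.add.at(p_cum, enc_in_ext_vocab, (1.0 - p_gen) * p_attn)    # copy part, scatter-add (cf. tf.scatter_nd)
    #   assert np.isclose(p_cum.sum(), 1.0)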
578 |
579 | @ staticmethod
580 | def _get_crossentropy_loss(probs, labels, mask):
581 | """
582 | This function calculates the crossentropy loss from ground-truth and the probability scores.
583 | :param probs: Predicted probabilities, Shape: T * [B x vocab_size]
584 | :param labels: Ground Truth labels, Shape: B x T.
585 | :param mask: Mask to exclude PAD tokens while calculating loss, Shape: B x T.
586 | :return: average mini-batch loss.
587 | """
588 | if type(probs) is list:
589 | probs = tf.stack(probs, axis=1) # Shape: B x T x vsize.
590 |
591 | true_probs = tf.reduce_sum(
592 | probs * tf.one_hot(labels, depth=tf.shape(probs)[2]), axis=2
593 | ) # Shape: B x T.
594 | logprobs = -tf.log(tf.clip_by_value(true_probs, 1e-10, 1.0)) # Shape: B x T.
595 | xe_loss = tf.reduce_sum(logprobs * mask) / tf.reduce_sum(mask)
596 |
597 | return xe_loss
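
    # NOTE (editor): a tiny NumPy sketch (illustration only, not part of the original code) of the
    # masked cross-entropy above: PAD positions are zeroed out by the mask and the loss is averaged
    # over the non-PAD tokens only. The numbers are made up.
    #
    #   import numpy as np
    #
    #   true_probs = np.array([[0.9, 0.5, 0.1]])   # p(correct token) at each of T = 3 steps, B = 1
    #   mask = np.array([[1.0, 1.0, 0.0]])         # the last step is a PAD token
    #   logprobs = -np.log(np.clip(true_probs, 1e-10, 1.0))
    #   xe_loss = np.sum(logprobs * mask) / np.sum(mask)   # average over the 2 real tokens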
598 |
599 | def _restore(self):
600 | """
601 | This function restores parameters from a saved checkpoint.
602 | :return:
603 | """
604 | restore_path = FLAGS.PathToCheckpoint + "model_epoch10"
605 |
606 | if FLAGS.restore_checkpoint and tf.train.checkpoint_exists(restore_path):
607 | start = time.time()
608 |
609 | # Initializing all variables.
610 | print("Initializing all variables!\n")
611 | self._sess.run(self._init)
612 |
613 | # Restoring checkpoint variables.
614 | vars_restore = [v for v in self._global_vars if "Adam" not in v.name]
615 | restore_saver = tf.train.Saver(vars_restore)
616 | print("Restoring non-Adam parameters from a previous checkpoint.\n")
617 | restore_saver.restore(self._sess, restore_path)
618 |
619 | end = time.time()
620 | print("Restoring model took %.2f sec. \n" % (end - start))
621 | else:
622 | start = time.time()
623 | self._init.run(session=self._sess)
624 | end = time.time()
625 |             print("Running initializer took %.2f sec. \n" % (end - start))
626 |
627 | def train(self):
628 | """
629 | This function performs the training followed by validation after every few epochs and saves the best model
630 | if validation loss is decreased. It also writes the training summaries for visualization using tensorboard.
631 | :return:
632 | """
633 | # Restore from previous checkpoint if found one.
634 | self._restore()
635 |
636 | # No. of iterations
637 | num_train_iters_per_epoch = int(math.ceil(self.data.num_train_examples / FLAGS.batch_size))
638 |
639 | # Running Averages
640 | running_avg_crossentropy_loss = 0.0
641 | running_avg_rouge = 0.0
642 | for epoch in range(1, 1 + FLAGS.num_epochs):
643 | # Training
644 | for iteration in range(1, 1 + num_train_iters_per_epoch):
645 | feed_dict = self._get_feed_dict(split='train')
646 |
647 | to_return = [self._train_op, self._crossentropy_loss, self._global_step, self._summaries]
648 | if FLAGS.rouge_summary:
649 | to_return.append(self._samples)
650 |
651 | # Evaluate summaries for last batch.
652 | feed_dict[self._mean_crossentropy_loss] = running_avg_crossentropy_loss
653 | if FLAGS.rouge_summary:
654 | feed_dict[self._mean_rouge_score] = running_avg_rouge
655 |
656 | outputs = self._sess.run(to_return, feed_dict=feed_dict)
657 | _, crossentropy_loss, global_step = outputs[: 3]
658 |
659 | # Calculating ROUGE score.
660 | rouge_score = 0.0
661 | if FLAGS.rouge_summary:
662 | samples = outputs[4]
663 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x .
664 | gt_labels = feed_dict[self._dec_out]
665 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False)
666 |
667 | # Updating the running average losses.
668 | running_avg_crossentropy_loss = get_running_avg_loss(
669 | crossentropy_loss, running_avg_crossentropy_loss
670 | )
671 | running_avg_rouge = get_running_avg_loss(
672 | np.mean(rouge_score), running_avg_rouge
673 | )
674 |
675 | if ((iteration - 2) % FLAGS.summary_every == 0) or (iteration == num_train_iters_per_epoch):
676 | train_summary = outputs[3]
677 | self._train_writer.add_summary(train_summary, global_step)
678 |
679 | print("\rTraining Iteration: {}/{} ({:.1f}%)".format(
680 | iteration, num_train_iters_per_epoch, iteration * 100 / num_train_iters_per_epoch,
681 | ))
682 |
683 | if epoch % FLAGS.val_every == 0:
684 | self._validate()
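
    # NOTE (editor): get_running_avg_loss is imported from utils.py (not shown in this file). A
    # plausible exponential-moving-average form is sketched below only to make the running averages
    # above concrete; the decay value is an assumption, not taken from the original code.
    #
    #   def get_running_avg_loss(loss, running_avg_loss, decay=0.99):
    #       if running_avg_loss == 0.0:   # first update
    #           return loss
    #       return decay * running_avg_loss + (1.0 - decay) * loss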
685 |
686 | def _validate(self):
687 | """
688 | This function validates the saved model.
689 |         The epoch index used in the checkpoint name is derived from the global step.
690 | :return: N/A.
691 | """
692 | # Validation
693 | num_val_iters_per_epoch = int(math.ceil(self.data.num_val_examples / FLAGS.val_batch_size))
694 | num_train_iters_per_epoch = int(math.ceil(self.data.num_train_examples / FLAGS.batch_size))
695 |
696 | outputs = None
697 | total_crossentropy_loss = 0.0
698 | total_rouge_score = 0.0
699 | for iteration in range(1, 1 + num_val_iters_per_epoch):
700 | feed_dict = self._get_feed_dict(split='val')
701 | to_return = [self._crossentropy_loss]
702 | if FLAGS.rouge_summary:
703 | to_return.append(self._samples)
704 |
705 | if iteration == num_val_iters_per_epoch:
706 | feed_dict[self._mean_crossentropy_loss] = total_crossentropy_loss / (num_val_iters_per_epoch - 1)
707 | feed_dict[self._mean_rouge_score] = total_rouge_score / (num_val_iters_per_epoch - 1)
708 | to_return += [self._val_summaries, self._global_step]
709 |
710 | outputs = self._sess.run(to_return, feed_dict=feed_dict)
711 | crossentropy_loss = outputs[0]
712 |
713 | # Calculating ROUGE score.
714 | rouge_score = 0.0
715 | if FLAGS.rouge_summary:
716 | samples = outputs[1]
717 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x .
718 | gt_labels = feed_dict[self._dec_out]
719 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False)
720 |
721 | print("\rValidation Iteration: {}/{} ({:.1f}%)".format(
722 | iteration, num_val_iters_per_epoch, iteration * 100 / num_val_iters_per_epoch,
723 | ))
724 |
725 | # Updating the total losses.
726 | total_crossentropy_loss += crossentropy_loss
727 | total_rouge_score += np.mean(rouge_score)
728 |
729 | # Writing the validation summaries for visualization.
730 | val_summary, global_step = outputs[-2], outputs[-1]
731 | self._val_writer.add_summary(val_summary, global_step)
732 | epoch = global_step // num_train_iters_per_epoch
733 |
734 | # Cumulative loss.
735 | avg_xe_loss = total_crossentropy_loss / num_val_iters_per_epoch
736 |         if avg_xe_loss < self._best_val_loss:
737 |             print("\rValidation loss improved :).")
738 |             self._best_val_loss = avg_xe_loss
739 |             self._saver.save(self._sess, FLAGS.PathToCheckpoint + "model_epoch" + str(epoch))
740 | def test(self):
741 | """
742 | This function predicts the outputs for the test set.
743 | :return:
744 | """
745 | num_gpus = len(FLAGS.GPUs) # No. of GPUs.
746 | batch_size = num_gpus * FLAGS.num_pools # No. of examples processed per a single test iteration.
747 |
748 | # No. of iterations
749 | num_test_iters_per_epoch = int(math.ceil(self.data.num_test_examples / batch_size))
750 |
751 | for iteration in range(1, 1 + num_test_iters_per_epoch):
752 | start = time.time()
753 | # [indices, summaries, files[start: end], enc_inp, enc_padding_mask,
754 | # enc_inp_ext_vocab, ext_vocabs, max_oov_size]
755 | input_batches = self.data.get_batch(batch_size, "test")
756 |
757 | # Split the data into batches of size "num_pools" per GPU.
758 | input_batches_per_gpu = [[] for _ in range(num_gpus)]
759 | for i, input_batch in enumerate(input_batches):
760 | for j, idx in zip(range(num_gpus), range(0, batch_size, FLAGS.num_pools)):
761 | if i < 5 or (FLAGS.use_pgen and i < 7): # First 5 inputs (7 inputs in pgen mode).
762 | input_batches_per_gpu[j].append(
763 | input_batch[idx: idx + FLAGS.num_pools]
764 | )
765 | else: # max_oov_size is the same for all examples.
766 | input_batches_per_gpu[j].append(input_batch)
767 |
768 | # Appending the GPU id.
769 | for gpu_id in range(num_gpus):
770 | input_batches_per_gpu[gpu_id].append(gpu_id)
771 |
772 | Parallel(n_jobs=num_gpus, backend="threading")(
773 | map(delayed(self._test_one_gpu), input_batches_per_gpu))
774 |
775 | end = time.time()
776 | print("\rTesting Iteration: {}/{} ({:.1f}%) in {:.1f} sec.".format(
777 | iteration, num_test_iters_per_epoch, iteration * 100 / num_test_iters_per_epoch, end - start))
778 |
779 | def _test_one_gpu(self, inputs):
780 | """
781 | This function performs testing on the inputs of a single GPU.
782 | :param inputs:
783 | :return:
784 | """
785 | input_batches = inputs[: -1]
786 | gpu_idx = inputs[-1]
787 |
788 | with tf.device('/gpu:%d' % gpu_idx):
789 | self._test_one_pool(input_batches)
790 |
791 | def _test_one_pool(self, inputs):
792 | """
793 | This function performs testing on the inputs of a single GPU over parallel pools.
794 | :param inputs: All the first 7 inputs have "num_pools" examples each.
795 | [indices, summaries, files[start: end], enc_inp, enc_padding_mask,
796 | enc_inp_ext_vocab, ext_vocabs, max_oov_size]
797 | :return:
798 | """
799 | # Split the inputs into each example per pool.
800 | input_batches_per_pool = [[] for _ in range(FLAGS.num_pools)]
801 | for i, input_batch in enumerate(inputs):
802 | for j in range(FLAGS.num_pools):
803 | # Splitting the GT summaries.
804 | if i == 0 or i == 1 or i == 2 or (FLAGS.use_pgen and i == 6):
805 | input_batches_per_pool[j].append([input_batch[j]])
806 |
807 | # Repeating the same Numpy array for beam size no. of times.
808 | elif i == 3 or i == 4 or (FLAGS.use_pgen and i == 5):
809 | input_batches_per_pool[j].append(
810 | np.repeat(input_batch[np.newaxis, j], repeats=FLAGS.beam_size, axis=0))
811 |
812 | # max_oov_size is same for all examples in the pool.
813 | else:
814 | input_batches_per_pool[j].append(input_batch)
815 |
816 | Parallel(n_jobs=FLAGS.num_pools, backend="threading")(
817 | map(delayed(self._test_one_ex), input_batches_per_pool))
818 |
819 | def _test_one_ex(self, inputs):
820 | """
821 | This function performs testing on one example.
822 | :param inputs:
823 | [indices, summaries, files[start: end], enc_inp, enc_padding_mask,
824 | enc_inp_ext_vocab, ext_vocabs, max_oov_size]
825 | :return:
826 | """
827 | iteration = inputs[0][0]
828 | input_batches = inputs[1:]
829 | # Inputs for running beam search on one example.
830 | beam_search_inputs = input_batches[2: 4]
831 | if FLAGS.use_pgen:
832 | beam_search_inputs += [input_batches[4], input_batches[6]]
833 |
834 | best_hyp = run_beam_search(beam_search_inputs, self.run_encoder, self.decode_one_step, self.vocab)
835 | pred_ids = best_hyp.tokens # Shape: T_dec * [1 x]
836 | pred_ids = np.stack(pred_ids, axis=0) # Shape: T_dec x
837 | pred_ids = pred_ids[np.newaxis, :] # Shape: 1 x T_dec
838 |
839 | # Writing the outputs.
840 | if FLAGS.use_pgen:
841 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1], input_batches[5])
842 | else:
843 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1])
844 |
845 | def _get_feed_dict(self, split='train'):
846 |
847 | if split == "train":
848 | input_batches = self.data.get_batch(FLAGS.batch_size, split="train")
849 |
850 | elif split == "val":
851 | input_batches = self.data.get_batch(FLAGS.val_batch_size, split="val")
852 |
853 | else:
854 | raise ValueError("Split should be either train/val!! \n")
855 |
856 | feed_dict = {
857 | self._enc_in: input_batches[0],
858 | self._enc_pad_mask: input_batches[1],
859 | self._dec_in: input_batches[2],
860 | self._dec_out: input_batches[3],
861 | self._dec_pad_mask: input_batches[4]
862 | }
863 |
864 | if FLAGS.use_pgen:
865 | feed_dict[self._enc_in_ext_vocab] = input_batches[5]
866 | feed_dict[self._max_oov_size] = input_batches[6]
867 |
868 | return feed_dict
869 |
870 | @staticmethod
871 | def make_html_safe(s):
872 | """
873 | Replace any angled brackets in string s to avoid interfering with HTML attention visualizer.
874 | :param s: Input string.
875 | :return:
876 | """
877 |         s = s.replace("<", "&lt;")
878 |         s = s.replace(">", "&gt;")
879 |
880 | return s
881 |
882 | def _write_file(self, pred, pred_name, gt, gt_name, ext_vocab=None):
883 | """
884 | This function writes tokens in vals to a .txt file with given name.
885 | :param pred: Predicted ID's. Shape: 1 x T
886 | :param pred_name: Name of the file in which predictions will be written.
887 | :param gt: Ground truth summary, a list of sentences (strings).
888 | :param gt_name: Name of the file in which GTs will be written.
889 | :param ext_vocab: Extended vocabulary for each example. [ext_words x]
890 | :return: _pred, _gt files will be created for ROUGE evaluation.
891 | """
892 | # Writing predictions.
893 | vsize = self.vocab.size()
894 |
895 | # Removing the [START] token.
896 | pred = pred[1:]
897 |
898 |         # Converting ID's to words.
899 | pred_words = []
900 | for t in pred:
901 | try:
902 | pred_words.append(self.vocab.id2word(t))
903 | except ValueError:
904 | pred_words.append(ext_vocab[t - vsize])
905 |
906 | # Considering tokens only till STOP token.
907 | try:
908 | stop_idx = pred_words.index(params.STOP_DECODING)
909 | except ValueError:
910 | stop_idx = len(pred_words)
911 | pred_words = pred_words[: stop_idx]
912 |
913 | # Creating sentences out of the predicted sequence.
914 | pred_sents = []
915 | while pred_words:
916 | try:
917 | period_idx = pred_words.index(".")
918 | except ValueError:
919 | period_idx = len(pred_words)
920 |
921 | # Append the sentence.
922 | sent = pred_words[: period_idx + 1]
923 | pred_sents.append(" ".join(sent))
924 |
925 | # Consider the remaining words now.
926 | pred_words = pred_words[period_idx + 1:]
927 |
928 | # Making HTML safe.
929 | pred_sents = [self.make_html_safe(s) for s in pred_sents]
930 | gt_sents = [self.make_html_safe(s) for s in gt]
931 |
932 | # Writing predicted sentences.
933 | f = open(pred_name, 'w', encoding='utf-8')
934 | for i, sent in enumerate(pred_sents):
935 | f.write(sent) if i == len(pred_sents) - 1 else f.write(sent + "\n")
936 | f.close()
937 |
938 | # Writing GT sentences.
939 | f = open(gt_name, 'w', encoding='utf-8')
940 | for i, sent in enumerate(gt_sents):
941 | f.write(sent) if i == len(gt_sents) - 1 else f.write(sent + "\n")
942 | f.close()
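
    # NOTE (editor): illustration only. The splitting loop above cuts the predicted token sequence
    # after every "." and joins each piece with spaces, one sentence per output line. For example
    # (made-up tokens):
    #
    #   pred_words = ["the", "cat", "sat", ".", "it", "slept", "."]
    #   # -> pred_sents == ["the cat sat .", "it slept ."]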
943 |
944 | def _write_outputs(self, index, preds, gts, files, ext_vocabs=None):
945 | """
946 |         This function writes the predicted and ground-truth summaries to text files.
947 |         :param index: Number of the test example.
948 |         :param preds: The predictions. Shape: B x T
949 |         :param gts: The ground truths. Shape: B x T
950 | :param files: The names of the files.
951 | :param ext_vocabs: Extended vocabularies for each example, Shape: B * [ext_words x]
952 | :return: Saves the predictions and GT's in a .txt format.
953 | """
954 | for i in range(len(files)):
955 | file, pred, gt = files[i], preds[i], gts[i]
956 | name_pred = FLAGS.PathToResults + 'predictions/' + '%06d_pred.txt' % index
957 | name_gt = FLAGS.PathToResults + 'groundtruths/' + '%06d_gt.txt' % index
958 |
959 | ext_vocab = None
960 | if FLAGS.use_pgen:
961 | ext_vocab = ext_vocabs[i]
962 |
963 | self._write_file(pred, name_pred, gt, name_gt, ext_vocab)
964 |
965 | def _single_model(self):
966 | placeholders = self._create_placeholders()
967 | self._crossentropy_loss, self._samples = self._forward(placeholders)
968 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer
969 | self._train_op = sgd_solver.minimize(
970 | self._crossentropy_loss, global_step=self._global_step, name='train_op')
971 |
972 | self._saver = tf.train.Saver(tf.global_variables()) # Saver.
973 | self._init = tf.global_variables_initializer() # Initializer.
974 | self._sess = tf.Session(config=self._config) # Session.
975 |
976 | # Train and validation summary writers.
977 | if FLAGS.mode == 'train':
978 | self._create_writers()
979 |
980 | self._global_vars = tf.global_variables()
981 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables())))
982 | # print(tf.trainable_variables())
983 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()])
984 | # print('No. of params = {:d}'.format(int(no_params)))
985 |
986 | def _parallel_model(self):
987 | with tf.Graph().as_default(), tf.device('/cpu:0'):
988 | placeholders = self._create_placeholders()
989 |
990 | # Splitting the placeholders for each GPU.
991 | placeholders_per_gpu = [[] for _ in range(self._num_gpus)]
992 | for placeholder in placeholders:
993 | splits = tf.split(placeholder, num_or_size_splits=self._num_gpus, axis=0)
994 | for i, split in enumerate(splits):
995 | placeholders_per_gpu[i].append(split)
996 |
997 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer
998 | tower_grads = [] # Gradients calculated in each tower
999 | with tf.variable_scope(tf.get_variable_scope()):
1000 | all_losses = []
1001 | all_samples = []
1002 | for i in range(self._num_gpus):
1003 | with tf.device('/gpu:%d' % i):
1004 | with tf.name_scope('tower_%d' % i):
1005 | crossentropy_loss, samples = self._forward(placeholders_per_gpu[i])
1006 |
1007 | # Gathering samples from all GPUs.
1008 | all_samples.append(samples)
1009 |
1010 | tf.get_variable_scope().reuse_variables()
1011 |
1012 |                         grads = sgd_solver.compute_gradients(crossentropy_loss)
1013 | tower_grads.append(grads)
1014 |
1015 | # Updating the losses
1016 | all_losses.append(crossentropy_loss)
1017 |
1018 | self._crossentropy_loss = tf.add_n(all_losses) / self._num_gpus
1019 |
1020 | # Samples from all the GPUs.
1021 | if FLAGS.rouge_summary:
1022 | self._samples = tf.concat(all_samples, axis=0)
1023 | else:
1024 | self._samples = []
1025 |
1026 | # Synchronization Point
1027 | gradients, variables = zip(*average_gradients(tower_grads))
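
            # NOTE (editor): average_gradients is imported from utils.py (not shown in this file). A
            # minimal sketch of the usual multi-tower averaging it is assumed to perform, for reference:
            #
            #   def average_gradients(tower_grads):
            #       # tower_grads: num_gpus * [[(grad, var), ...]], variables shared across towers.
            #       averaged = []
            #       for grads_and_var in zip(*tower_grads):
            #           grads = [g for g, _ in grads_and_var]
            #           averaged.append((tf.add_n(grads) / len(grads), grads_and_var[0][1]))
            #       return averaged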
1028 | gradients = [tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad) for grad in gradients]
1029 | gradients, global_norm = tf.clip_by_global_norm(gradients, FLAGS.max_grad_norm)
1030 |
1031 | # Summary for the global norm
1032 | tf.summary.scalar('global_norm', global_norm)
1033 |
1034 | # Histograms for gradients.
1035 | for grad, var in zip(gradients, variables):
1036 | if grad is not None:
1037 | tf.summary.histogram(var.op.name + '/gradients', grad)
1038 |
1039 | # Histograms for variables.
1040 | for var in tf.trainable_variables():
1041 | tf.summary.histogram(var.op.name, var)
1042 |
1043 | self._train_op = sgd_solver.apply_gradients(
1044 | zip(gradients, variables), global_step=self._global_step)
1045 |
1046 | self._saver = tf.train.Saver(tf.global_variables()) # Saver
1047 | self._init = tf.global_variables_initializer() # Initializer.
1048 | self._sess = tf.Session(config=self._config) # Session.
1049 |
1050 | # Train and validation summary writers.
1051 | if FLAGS.mode == 'train':
1052 | self._create_writers()
1053 |
1054 | self._global_vars = tf.global_variables()
1055 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables())))
1056 | # print(tf.trainable_variables())
1057 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()])
1058 | # print('No. of params = {:d}'.format(int(no_params)))
1059 |
--------------------------------------------------------------------------------
/codes/model_hier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "Rajeev Bhatt Ambati"
3 |
4 | from HierNSE import HierNSE
5 | from utils import average_gradients, get_running_avg_loss
6 | import params as params
7 | from beam_search import run_beam_search_hier
8 | from rouge_batch import rouge_l_fscore_batch as rouge_l_fscore
9 |
10 | import random
11 | import math
12 | import time
13 |
14 | import tensorflow as tf
15 | from tensorflow.nn.rnn_cell import LSTMStateTuple
16 | import numpy as np
17 | from joblib import Parallel, delayed
18 |
19 | tf.set_random_seed(2019)
20 | random.seed(2019)
21 | tf.reset_default_graph()
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 |
26 | class SummarizationModelHier(object):
27 | def __init__(self, vocab, data):
28 | self._vocab = vocab
29 | self._data = data
30 |
31 | self._dense_init = tf.contrib.layers.xavier_initializer()
32 |
33 | self._config = tf.ConfigProto()
34 | self._config.gpu_options.allow_growth = True
35 |
36 | self._best_val_loss = np.infty
37 | self._num_gpus = len(FLAGS.GPUs)
38 |
39 | self._sess = None
40 | self._saver = None
41 | self._init = None
42 |
43 | def _create_placeholders(self):
44 | """
45 | This function creates the placeholders needed for the computation graph.
46 | [enc_in, enc_pad_mask, enc_doc_mask, dec_in, enc_in_ext_vocab, labels, dec_pad_mask]
47 | :return:
48 | """
49 | self._global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')
50 | self._learning_rate = tf.placeholder(tf.float32, name='learning_rate')
51 |
52 | # Word embedding.
53 | if FLAGS.use_pretrained:
54 | with tf.variable_scope("embed", reuse=tf.AUTO_REUSE):
55 | self._word_embedding = tf.get_variable(name="word_embedding",
56 | shape=[self._vocab.size(), FLAGS.dim],
57 | initializer=tf.constant_initializer(self._vocab.wvecs),
58 | dtype=tf.float32,
59 | trainable=True)
60 |
61 | else:
62 | self._word_embedding = tf.get_variable(name="word_embedding",
63 | shape=[self._vocab.size(), FLAGS.dim],
64 | dtype=tf.float32,
65 | trainable=True)
66 |
67 | # Graph Inputs/Outputs.
68 | self._enc_in = tf.placeholder(dtype=tf.int32, name='enc_in',
69 | shape=[FLAGS.batch_size, FLAGS.max_enc_sent,
70 | FLAGS.max_enc_steps_per_sent]) # Shape: B x S_in x T_in.
71 | self._enc_pad_mask = tf.placeholder(dtype=tf.float32, name='enc_pad_mask',
72 | shape=[FLAGS.batch_size, FLAGS.max_enc_sent,
73 | FLAGS.max_enc_steps_per_sent]) # Shape: B x S_in x T_in.
74 | self._enc_doc_mask = tf.placeholder(dtype=tf.float32, name='enc_doc_mask',
75 | shape=[FLAGS.batch_size, FLAGS.max_enc_sent]) # Shape: B x S_in.
76 | self._dec_in = tf.placeholder(dtype=tf.int32, name='dec_in',
77 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec.
78 |
79 | inputs = [self._enc_in, self._enc_pad_mask, self._enc_doc_mask, self._dec_in]
80 |
81 | # Additional inputs in pointer-generator mode.
82 | if FLAGS.use_pgen:
83 | self._enc_in_ext_vocab = tf.placeholder(dtype=tf.int32,
84 | name='enc_inp_ext_vocab',
85 | shape=[FLAGS.batch_size,
86 | FLAGS.max_enc_sent *
87 | FLAGS.max_enc_steps_per_sent]) # Shape: B x T_enc.
88 | self._max_oov_size = tf.placeholder(dtype=tf.int32, name='max_oov_size')
89 | inputs.append(self._enc_in_ext_vocab)
90 |
91 | # Additional ground-truth's when in training.
92 | if FLAGS.mode.lower() == "train":
93 | self._dec_out = tf.placeholder(dtype=tf.int32, name='dec_out',
94 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec.
95 | self._dec_pad_mask = tf.placeholder(dtype=tf.float32, name='dec_pad_mask',
96 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec.
97 | inputs += [self._dec_out, self._dec_pad_mask]
98 |
99 | return inputs
100 |
101 | def _create_writers(self):
102 | """
103 | This function creates the summaries and writers needed for visualization through tensorboard.
104 | :return: writers.
105 | """
106 | self._mean_crossentropy_loss = tf.placeholder(dtype=tf.float32, name='mean_crossentropy_loss')
107 | self._mean_rouge_score = tf.placeholder(dtype=tf.float32, name="mean_rouge")
108 |
109 | # Summaries.
110 | cross_entropy_summary = tf.summary.scalar('cross_entropy', self._mean_crossentropy_loss)
111 | rouge_summary = None
112 | if FLAGS.rouge_summary:
113 | rouge_summary = tf.summary.scalar('rouge_score', self._mean_rouge_score)
114 | self._summaries = tf.summary.merge_all()
115 |
116 | summary_list = [cross_entropy_summary]
117 | if FLAGS.rouge_summary:
118 | summary_list.append(rouge_summary)
119 |
120 | self._val_summaries = tf.summary.merge(summary_list, name="validation_summaries")
121 |
122 | # Summary writers.
123 | self._train_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'train') # , self._sess.graph)
124 | self._val_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'val')
125 |
126 | def build_graph(self):
127 | start = time.time()
128 | if FLAGS.mode == 'train':
129 | if len(FLAGS.GPUs) > 1:
130 | self._parallel_model() # Parallel model in case of multiple-GPUs.
131 | else:
132 | self._single_model() # Single model for a single GPU/CPU.
133 |
134 | if FLAGS.mode == 'test':
135 | inputs = self._create_placeholders() # [enc_in, dec_in, enc_in_ext_vocab]
136 | self._forward(inputs) # Predictions Shape: Bm x 1
137 |
138 | self._saver = tf.train.Saver(tf.global_variables()) # Saver.
139 | self._init = tf.global_variables_initializer() # Initializer.
140 | self._sess = tf.Session(config=self._config) # Session.
141 |
142 | # Restoring the trained model.
143 | self._saver.restore(self._sess, FLAGS.PathToCheckpoint)
144 |
145 | end = time.time()
146 | print("build_graph took %.2f sec. \n" % (end - start))
147 |
148 | def _forward(self, inputs):
149 | """
150 | This function creates the TensorFlow computation graph.
151 | :param inputs: A list of input placeholders.
152 | [0] enc_in: Encoder input sequence of ID's, Shape: B x S_in x T_in.
153 |             [1] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in.
154 | [2] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in.
155 | [3] dec_in: Decoder input sequence of ID's, Shape: B x T_dec.
156 |
157 | Following additional input in pointer-generator mode.
158 | [4] enc_in_ext_vocab: Encoder input representation in the extended vocabulary, Shape: B x T_enc.
159 |
160 | [5] labels: Ground-Truth labels, Only in train mode, Shape: B x T_dec.
161 | [6] dec_pad_mask: Decoder output mask, Shape: B x T_dec.
162 | :return: returns loss in train mode and predictions in test mode.
163 | """
164 | batch_size = inputs[0].get_shape().as_list()[0] # Batch-size
165 | # The NSE instance
166 | self._nse = HierNSE(batch_size=batch_size, dim=FLAGS.dim, dense_init=self._dense_init,
167 | mode=FLAGS.mode, use_comp_lstm=FLAGS.use_comp_lstm, num_layers=FLAGS.num_layers)
168 |
169 |         # Encoder, also used during the testing phase.
170 | self._prev_states = self._encoder(inputs[: 3])
171 |
172 | # Decoder.
173 | outputs = self._decoder(inputs[1: 5] + [inputs[-1]], self._prev_states)
174 |
175 | if FLAGS.mode.lower() == "test":
176 | self._topk_ids, self._topk_log_probs, self._curr_states, self._p_attns = outputs[: 4]
177 |
178 | if FLAGS.use_pgen:
179 | self._p_gens = outputs[4]
180 |
181 | else:
182 | probs, p_attns = outputs
183 |
184 | crossentropy_loss = self._get_crossentropy_loss(probs=probs, labels=inputs[-2], mask=inputs[-1])
185 |
186 | # Evaluating a few samples.
187 | samples = []
188 | if FLAGS.rouge_summary:
189 | samples = self._get_samples(inputs[1: 5], self._prev_states)
190 |
191 | return crossentropy_loss, samples
192 |
193 | def _encoder(self, inputs):
194 | """
195 | This is the encoder.
196 | :param inputs: A list of the following inputs.
197 | enc_in: Encoder input sequence of ID's, Shape: B x S_in x T_in.
198 | enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in.
199 | enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in.
200 | :return:
201 | A list of internal states of NSE after the last encoding step.
202 | [0] memory: [sent_mems, doc_mem] The sentence and document memories respectively,
203 | Shape: [B x S_in x T_in x D, B x S_in x D].
204 | [1] read_state: Hidden state of the read LSTM after the last encoding step, (c, h) Shape: 2 * [B x D].
205 | [2] write_state: Hidden state of the write LSTM after the last encoding step, (c, h) Shape: 2 * [B x D].
206 | [3] comp_state: Hidden state of the compose LSTM after the last encoding step, (c, h) Shape: 2 * [B x D].
207 | """
208 | with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
209 | enc_in, enc_pad_mask, enc_doc_mask = inputs
210 |
211 | # Converting ID's to word-vectors.
212 | enc_in_vecs = tf.nn.embedding_lookup(params=self._word_embedding, ids=enc_in) # Shape: B x S_in x T_in x D.
213 | enc_in_vecs = tf.cast(enc_in_vecs, dtype=tf.float32) # Cast to float32.
214 |
215 | # Document memory.
216 | doc_mem = tf.reduce_mean(enc_in_vecs, axis=2, name="document_memory") # Shape: B x S_in x D.
217 |
218 | new_sent_mems = [] # New sentence memories.
219 | lstm_states = [None, None, None] # LSTM states.
220 |
221 | for i in range(FLAGS.max_enc_sent):
222 | # Mask, memory of ith sentence.
223 | sent_i_mask = enc_pad_mask[:, i, :] # Shape: B x T_in.
224 | sent_i_mem = enc_in_vecs[:, i, :, :] # Shape: B x T_in x D.
225 |
226 | mem_masks = [sent_i_mask, enc_doc_mask] # ith-sentence and document masks.
227 | state = [[sent_i_mem, doc_mem]] + lstm_states # NSE internal state.
228 |
229 | for j in range(FLAGS.max_enc_steps_per_sent):
230 | # j-th token from the ith sentence.
231 | x_t = enc_in_vecs[:, i, j, :] # Shape: B x D.
232 | output, state = self._nse.step(x_t=x_t, mem_masks=mem_masks, prev_state=state)
233 |
234 | # Update
235 | new_sent_mems.append(state[0][0]) # ith-sentence memory.
236 | doc_mem = state[0][1] # Document memory.
237 | lstm_states = state[1:] # Read, write, compose states.
238 |
239 | new_sent_mems = tf.concat(new_sent_mems, axis=1) # Shape: B x (S_in*T_in) x D.
240 | all_states = [[new_sent_mems, doc_mem]] + lstm_states
241 |
242 | return all_states
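
    # NOTE (editor): illustration only (made-up shapes, not part of the original code). The document
    # memory built above is just the per-sentence mean of the word vectors, one D-dimensional slot
    # per sentence:
    #
    #   import numpy as np
    #
    #   B, S_in, T_in, D = 2, 3, 4, 5
    #   enc_in_vecs = np.random.rand(B, S_in, T_in, D)
    #   doc_mem = enc_in_vecs.mean(axis=2)   # Shape: B x S_in x D  (cf. tf.reduce_mean(..., axis=2))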
243 |
244 | def _decoder(self, inputs, all_states):
245 | """
246 | This is the decoder.
247 | :param inputs: A list of the following inputs.
248 | [0] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in.
249 | [1] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in.
250 | [2] dec_in: Input to the decoder, Shape: B x T_dec.
251 | Following additional inputs in pointer generator mode:
252 | [3] enc_in_ext_vocab: (For pointer generator mode)
253 | Encoder input representation in the extended vocabulary, Shape: B x T_enc.
254 | [4] dec_pad_mask: Decoder mask to indicate the presence of PAD tokens, Shape: B x T_dec.
255 |
256 | :param all_states: The internal states of NSE after the last encoding step.
257 | [0] memory: [sent_mems, doc_mem] The sentence and document memories respectively,
258 | Shape: [B x T_enc x D, B x S_in x D].
259 | [1] read_state: Hidden state of the read LSTM after the last encoding step,(c, h) Shape: 2 * [B x D].
260 | [2] write_state: Hidden state of the write LSTM after the last encoding step,(c, h) Shape: 2 * [B x D].
261 | [3] comp_state: Hidden state of the compose LSTM after the last encoding step,(c, h) Shape: 2 * [B x D].
262 | :return:
263 | In test mode:
264 | [topk_ids, topk_log_probs, state, p_attns] + [p_gens]
265 | In train mode:
266 | [p_cums, p_attns]
267 | """
268 | with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
269 | enc_pad_mask, enc_doc_mask, dec_in = inputs[: 3]
270 |
271 | # Concatenating the pad masks of all sentences.
272 | enc_pad_mask = tf.reshape(
273 | enc_pad_mask, [-1, FLAGS.max_enc_sent*FLAGS.max_enc_steps_per_sent]) # Shape: B x T_enc.
274 |
275 | # Converting ID's to word vectors.
276 | dec_in_vecs = tf.nn.embedding_lookup(params=self._word_embedding, ids=dec_in) # Shape: B x T_dec x D.
277 | dec_in_vecs = tf.cast(dec_in_vecs, dtype=tf.float32) # Shape: B x T_dec x D.
278 |
279 | # Memory masks.
280 | mem_masks = [enc_pad_mask, enc_doc_mask] # Shape: [B x T_enc, B x S_in].
281 |
282 | # NSE internal states.
283 | sent_mems, doc_mem = all_states[0] # Shape: B x T_enc x D, B x S_in x D.
284 | state = [[sent_mems, doc_mem]] + all_states[1:]
285 |
286 | writes = []
287 | p_attns = []
288 | p_gens = []
289 | for i in range(FLAGS.dec_steps):
290 |
291 | x_t = dec_in_vecs[:, i, :] # Shape: B x D.
292 | output, state = self._nse.step(
293 | x_t=x_t, mem_masks=mem_masks, prev_state=state, use_pgen=FLAGS.use_pgen
294 | )
295 |
296 | # Appending the outputs.
297 | writes.append(output[0])
298 | p_attns.append(output[1])
299 |
300 | if FLAGS.use_pgen:
301 | p_gens.append(output[2])
302 |
303 | p_vocabs = self._get_vocab_dist(writes) # Shape: T_dec * [B x vsize].
304 |
305 | p_cums = p_vocabs
306 | if FLAGS.use_pgen:
307 | enc_in_ext_vocab = inputs[3]
308 | p_cums = self._get_cumulative_dist(
309 | p_vocabs, p_gens, p_attns, enc_in_ext_vocab
310 | ) # Shape: T_dec * [B x ext_vsize].
311 |
312 | if FLAGS.mode.lower() == "test":
313 | p_final = p_cums[0]
314 |
315 | # The top k predictions, topk_ids, Shape: Bm * (2*Bm).
316 | # Respective probabilities, topk_probs, Shape: Bm * (2*Bm).
317 | topk_probs, topk_ids = tf.nn.top_k(p_final, k=2*FLAGS.beam_size, name="topk_preds")
318 | topk_log_probs = tf.log(tf.clip_by_value(topk_probs, 1e-10, 1.0))
319 |
320 | outputs = [topk_ids, topk_log_probs, state, p_attns]
321 | if FLAGS.use_pgen:
322 | outputs.append(p_gens)
323 |
324 | return outputs
325 | else:
326 | return p_cums, p_attns
327 |
328 | def _get_samples(self, inputs, state):
329 | """
330 | This function samples greedily from decoder output to calculate the ROUGE score.
331 | :param inputs: A list of the following inputs.
332 | [0] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in.
333 | [1] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in.
334 | [2] dec_in: Input to the decoder, Shape: B x T_dec.
335 | Following additional inputs in pointer generator mode:
336 | [3] enc_in_ext_vocab: (For pointer generator mode)
337 | Encoder input representation in the extended vocabulary, Shape: B x T_enc.
338 | :param state: The internal states of NSE after the last encoding step.
339 | [0] memory: [sent_mems, doc_mem] The sentence and document memories respectively,
340 | Shape: [B x T_enc x D, B x S_in x D].
341 | [1] read_state: Hidden state of the read LSTM after the last encoding step,(c, h) Shape: 2 * [B x D].
342 | [2] write_state: Hidden state of the write LSTM after the last encoding step,(c, h) Shape: 2 * [B x D].
343 | [3] comp_state: Hidden state of the compose LSTM after the last encoding step,(c, h) Shape: 2 * [B x D].
344 | :return:
345 |             samples: The greedily decoded output ID's (top-1 token from the cumulative
346 |                 distribution at each step), Shape: B x T_dec.
347 | """
348 | unk_id = self._vocab.word2id(params.UNKNOWN_TOKEN) # Unknown ID.
349 | start_id = self._vocab.word2id(params.START_DECODING) # Start ID.
350 |
351 | with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
352 | enc_pad_mask, enc_doc_mask = inputs[: 2]
353 | batch_size = enc_pad_mask.get_shape().as_list()[0]
354 | # Concatenating the pad masks of all sentences.
355 | enc_pad_mask = tf.reshape(
356 | enc_pad_mask,
357 | [-1, FLAGS.max_enc_sent * FLAGS.max_enc_steps_per_sent]) # Shape: B x T_enc.
358 |
359 | # Memory masks.
360 | mem_masks = [enc_pad_mask, enc_doc_mask] # Shape: [B x T_enc, B x S_in].
361 |
362 | samples = [] # Sampling outputs.
363 | for i in range(FLAGS.dec_steps):
364 |
365 | if i == 0:
366 | id_t = tf.fill([batch_size], start_id) # Shape: B x .
367 | else:
368 | id_t = samples[-1][:, 0]
369 | # Replacing the ID's from external vocabulary (if any) to UNK id's.
370 | id_t = tf.where(
371 | tf.less(id_t, self._vocab.size()), id_t, unk_id * tf.ones_like(id_t)
372 | )
373 |
374 | # Getting the word vector.
375 | x_t = tf.nn.embedding_lookup(params=self._word_embedding, ids=id_t) # Shape: B x D.
376 | x_t = tf.cast(x_t, dtype=tf.float32)
377 |
378 | output, state = self._nse.step(
379 | x_t=x_t, mem_masks=mem_masks, prev_state=state, use_pgen=FLAGS.use_pgen,
380 | )
381 |
382 | # Output probability distribution.
383 | p_vocab = self._get_vocab_dist([output[0]]) # Shape: [B x vsize].
384 |
385 | # Calculate cumulative probability distribution using pointer mechanism.
386 | p_cum = p_vocab
387 | if FLAGS.use_pgen:
388 | p_cum = self._get_cumulative_dist(
389 | p_vocabs=p_vocab, p_gens=[output[2]], p_attns=[output[1]], enc_in_ext_vocab=inputs[-1]
390 | ) # Shape: T_dec * [B x ext_vsize].
391 |
392 | # Greedy sampling.
393 | _, gs_sample = tf.nn.top_k(p_cum[0], k=FLAGS.num_samples) # Shape: B x 1.
394 | samples.append(gs_sample)
395 |
396 | samples = tf.concat(samples, axis=1) # Shape: B x T_dec.
397 |
398 | return samples
399 |
400 | def run_encoder(self, inputs):
401 | """
402 | This function calculates the internal states of NSE after the last step of encoding.
403 | :param inputs:
404 | enc_in_batch: A batch of encoder input sequence, Shape: Bm x S_in x T_in.
405 | enc_pad_mask: Encoder input mask, Shape: Bm x S_in x T_in.
406 | enc_doc_mask: Encoder document mask, Shape: Bm x S_in.
407 | :return:
408 | NSE internal states after encoding:
409 | final_memory: [sent_mems, doc_mem] The sentence and document memories respectively after last time-step
410 | of encoding. Shape: [1 x T_enc x D, 1 x S_in x D].
411 | final_read_state: Hidden state of read LSTM after last time-step of encoding,(c, h)
412 | Shape: (1 x D, 1 x D).
413 | final_write_state: Hidden state of write LSTM after last time-step of encoding,(c, h)
414 | Shape: (1 x D, 1 x D).
415 | final_comp_state: Hidden state of compose LSTM after last time-step of encoding,(c, h),
416 | Shape: (1 x 3D, 1 x 3D).
417 | Attention distributions: sentence, document attention.
418 | T_enc * [1 x T_enc, 1 x S_in].
419 |
420 | """
421 | to_return = [self._prev_states[0][0], self._prev_states[0][1]] + self._prev_states[1:]
422 |
423 | outputs = self._sess.run(to_return, feed_dict={self._enc_in: inputs[0],
424 | self._enc_pad_mask: inputs[1],
425 | self._enc_doc_mask: inputs[2]})
426 | final_memory_s, final_memory_d = outputs[: 2]
427 |
428 | # memory Shape: [Bm x T_enc x D, Bm x S_in x D].
429 | # read state Shape: (Bm x D, Bm x D).
430 | # write state Shape: (Bm x D, Bm x D).
431 | final_read_state, final_write_state = outputs[2: 4]
432 |
433 |         # Since the states contain repeated values, slice only the first one.
434 | final_memory_s = final_memory_s[0, np.newaxis, :, :] # Shape: 1 x T_enc x D.
435 | final_memory_d = final_memory_d[0, np.newaxis, :, :] # Shape: 1 x S_in x D.
436 |
437 | final_memory = [final_memory_s, final_memory_d]
438 | final_comp_state = None
439 |
440 | def get_state_slice(inp_states):
441 | if FLAGS.num_layers == 1:
442 | return LSTMStateTuple(inp_states.c[0, np.newaxis, :],
443 | inp_states.h[0, np.newaxis, :])
444 |
445 | else:
446 | sliced_states = []
447 | for _, layer_state in enumerate(inp_states):
448 | if FLAGS.rnn.lower() == "lstm":
449 | sliced_states.append(
450 | LSTMStateTuple(layer_state.c[0, np.newaxis, :],
451 | layer_state.h[0, np.newaxis, :])
452 | )
453 | else:
454 | sliced_states.append(
455 | layer_state[0, np.newaxis, :]
456 | )
457 |
458 | return tuple(sliced_states)
459 |
460 | final_read_state = get_state_slice(final_read_state) # Shape: (1 x D, 1 x D).
461 | final_write_state = get_state_slice(final_write_state) # Shape: (1 x D, 1 x D).
462 | if FLAGS.use_comp_lstm:
463 | final_comp_state = get_state_slice(outputs[4]) # Shape: (1 x 3D, 1 x 3D).
464 |
465 | state = [final_memory, final_read_state, final_write_state]
466 | if FLAGS.use_comp_lstm:
467 | state.append(final_comp_state)
468 |
469 | return state
470 |
471 | def decode_one_step(self, inputs, prev_states):
472 | """
473 | This function performs one step of decoding.
474 | :param inputs:
475 | [0] dec_in_batch:
476 | The input to the decoder. This is the output from previous time-step, Shape: Bm * [1 x]
477 | [1] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens.
478 | Useful to calculate the attention over memory, Shape: Bm x T_enc.
479 | [2] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: Bm x S_in.
480 | In pointer generator mode, there are following additional inputs:
481 | [3]: enc_in_ex_vocab_batch:
482 | Encoder input sequence represented in extended vocabulary, Shape: Bm x T_enc.
483 |                 [4]: max_oov_size: Maximum number of OOV tokens across examples in the current batch, Shape: ()
484 |         :param prev_states: previous internal states of NSE of all Bm hypotheses, Bm * [prev_state].
485 |                        where each state is a list of internal states of NSE for a single hypothesis.
486 | prev_state = [prev_memory, prev_read_state, prev_write_state]
487 |                 prev_memory: [sent_mems, doc_mem] The sentence and document memories respectively after the
488 | previous time-step, Shape: [1 x T_enc x D, 1 x S_in x D].
489 | prev_read_state: Hidden state of read LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D].
490 | prev_write_state: Hidden state of write LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D].
491 | prev_comp_state: Hidden state of compose LSTM after previous time step, (c, h) Shape: [1 x 3D, 1 x 3D].
492 | :return:
493 | topk_ids: Top-k predictions in the current step, Shape: Bm x (2*Bm)
494 | topk_log_probs: log probabilities of top-k predictions, Shape: Bm x (2*Bm)
495 | curr_states: Current internal states of NSE, Bm * [state].
496 | [0]: memory: NSE sentence and document memories. [Bm x T_enc x D, Bm x S_in x D].
497 | [1]: read_state, (c, h) Shape: (Bm x D, Bm x D).
498 | [2]: write_state, (c, h) Shape: (Bm x D, Bm x D).
499 | [3]: comp_state, (c, h) Shape: (Bm x 3D, Bm x 3D).
500 | p_gens: Generation probabilities, Shape: Bm x .
501 | p_attns: Attention probabilities, Shape: [Bm x T_in, Bm x S_in].
502 | """
503 | # Decoder input
504 | dec_in = inputs[0] # Shape: Bm * [1 x]
505 | dec_in = np.stack(dec_in, axis=0) # Shape: Bm x
506 | inputs[0] = np.expand_dims(dec_in, axis=-1) # Shape: Bm x 1
507 |
508 | # Previous memories of Bm hypothesis.
509 | # Sentence memories.
510 | prev_memories_s = [state[0][0] for state in prev_states] # Shape: Bm * [1 x T_enc x D].
511 | prev_memories_s = np.concatenate(prev_memories_s, axis=0) # Shape: Bm x T_enc x D.
512 |
513 | # Document memory.
514 | prev_memories_d = [state[0][1] for state in prev_states] # Shape: Bm * [1 x S_in x D].
515 | prev_memories_d = np.concatenate(prev_memories_d, axis=0) # Shape: Bm x S_in x D.
516 |
517 | prev_memories = [prev_memories_s, prev_memories_d]
518 |
519 | def get_combined_states(inp_states):
520 | """
521 | A function to combine the states of Bm hypothesis.
522 | :param inp_states: List of states of Bm hypothesis. Bm * [(s1, s2, ..., s_l)]
523 | :return:
524 | """
525 | if FLAGS.num_layers == 1:
526 | # Cell states.
527 | combined_states_c = [hyp_state.c for hyp_state in inp_states] # Shape: Bm * [1 x D].
528 | combined_states_c = np.concatenate(combined_states_c, axis=0) # Shape: Bm x D.
529 |
530 | # Hidden states.
531 | combined_states_h = [hyp_state.h for hyp_state in inp_states] # Shape: Bm * [1 x D].
532 | combined_states_h = np.concatenate(combined_states_h, axis=0) # Shape: Bm x D.
533 |
534 | combined_states = LSTMStateTuple(combined_states_c,
535 | combined_states_h) # Shape: (Bm x D, Bm x D).
536 |
537 | else:
538 | combined_states = []
539 | for i in range(FLAGS.num_layers):
540 | # Cell states.
541 | layer_state_c = [layer_state[i].c for layer_state in inp_states] # Shape: Bm * [1 x D].
542 | layer_state_c = np.concatenate(layer_state_c, axis=0) # Shape: Bm x D.
543 |
544 | # Hidden states.
545 | layer_state_h = [layer_state[i].h for layer_state in inp_states] # Shape: Bm * [1 x D].
546 | layer_state_h = np.concatenate(layer_state_h, axis=0) # Shape: Bm x D.
547 |
548 | combined_states.append(
549 | LSTMStateTuple(layer_state_c, layer_state_h)
550 | )
551 |
552 | combined_states = tuple(combined_states)
553 |
554 | return combined_states
555 |
556 | prev_read_states = get_combined_states([state[1] for state in prev_states])
557 | prev_write_states = get_combined_states([state[2] for state in prev_states])
558 | if FLAGS.use_comp_lstm:
559 | prev_comp_states = get_combined_states([state[3] for state in prev_states])
560 | else:
561 | prev_comp_states = None
562 |
563 | feed_dict = {
564 | self._dec_in: inputs[0],
565 | self._enc_pad_mask: inputs[1],
566 | self._enc_doc_mask: inputs[2],
567 | self._prev_states[0][0]: prev_memories[0],
568 | self._prev_states[0][1]: prev_memories[1],
569 | self._prev_states[1]: prev_read_states,
570 | self._prev_states[2]: prev_write_states
571 | }
572 |
573 | if FLAGS.use_comp_lstm:
574 | feed_dict[self._prev_states[3]] = prev_comp_states
575 |
576 | if FLAGS.use_pgen:
577 | feed_dict[self._enc_in_ext_vocab] = inputs[3]
578 | feed_dict[self._max_oov_size] = inputs[4]
579 |
580 | to_return = [self._topk_ids, self._topk_log_probs, self._curr_states[0][0], self._curr_states[0][1],
581 | self._curr_states[1], self._curr_states[2]]
582 |
583 | if FLAGS.use_comp_lstm:
584 | to_return.append(self._curr_states[3])
585 |
586 | if FLAGS.use_pgen:
587 | to_return += [self._p_gens, self._p_attns]
588 |
589 | outputs = self._sess.run(to_return, feed_dict=feed_dict)
590 |
591 | # Preparing the next values (inputs and states for next time step).
592 | next_values = outputs[: 2]
593 |
594 | # Current memories.
595 | curr_memories_s = np.split(outputs[2], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x T_enc x D].
596 | curr_memories_d = np.split(outputs[3], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x S_in x D].
597 | curr_memories = [[mem_s, mem_d] for mem_s, mem_d in
598 | zip(curr_memories_s, curr_memories_d)] # Shape: Bm * [[1 x T_enc x D, 1 x S_in x D]].
599 |
600 | def get_states_split(inp_states):
601 | """
602 | This function splits the states for Bm hypothesis.
603 | :param inp_states:
604 | if just one layer:
605 | A NumPy array of shape Bm x D.
606 | for multiple layers:
607 | A tuple of states of RNN layers (s1, s2, ...., s_l).
608 | :return:
609 | """
610 | if FLAGS.num_layers == 1:
611 | split_states_c = np.split(inp_states.c, FLAGS.beam_size, axis=0) # Shape: Bm * [1 x D].
612 | split_states_h = np.split(inp_states.h, FLAGS.beam_size, axis=0) # Shape: Bm * [1 x D].
613 | split_states = [LSTMStateTuple(c, h)
614 | for c, h in zip(split_states_c, split_states_h)] # Shape: Bm * [(1 x D, 1 x D)].
615 |
616 | else:
617 | split_states = []
618 | for i in range(FLAGS.beam_size):
619 | hyp_state_c = [layer_state.c[i, np.newaxis, :]
620 | for layer_state in inp_states] # num_layers * [1 x D].
621 | hyp_state_h = [layer_state.h[i, np.newaxis, :]
622 | for layer_state in inp_states] # num_layers * [1 x D].
623 | hyp_state = [LSTMStateTuple(c, h)
624 | for c, h in zip(hyp_state_c, hyp_state_h)] # num_layers * [(1 x D, 1 x D)]
625 | split_states.append(tuple(hyp_state))
626 |
627 | return split_states
628 |
629 | # Split the states for Bm hypothesis.
630 | curr_read_states = get_states_split(outputs[4])
631 | curr_write_states = get_states_split(outputs[5])
632 | if FLAGS.use_comp_lstm:
633 | curr_comp_states = get_states_split(outputs[6])
634 | else:
635 | curr_comp_states = None
636 |
637 | if FLAGS.use_comp_lstm:
638 |
639 | curr_states_list = [[memory, read_state, write_state, comp_state]
640 | for memory, read_state, write_state, comp_state in
641 | zip(curr_memories, curr_read_states, curr_write_states, curr_comp_states)]
642 | else:
643 | # Forming a list of internal states for Bm hypothesis.
644 | curr_states_list = [[memory, read_state, write_state] for memory, read_state, write_state
645 | in zip(curr_memories, curr_read_states, curr_write_states)]
646 |
647 | next_values.append(curr_states_list)
648 |
649 | # Generation probabilities.
650 | if FLAGS.use_pgen:
651 | p_gens = outputs[-2] # Shape: 1 x [Bm x 1].
652 | p_gens = p_gens[0] # Only one time-step in decoding phase, Shape: Bm x 1.
653 | p_gens = np.squeeze(p_gens, axis=1) # Shape: Bm x
654 | p_gens = np.split(p_gens, FLAGS.beam_size, axis=0) # Shape: Bm * [1]
655 | next_values.append(p_gens)
656 |
657 | # Attention probabilities.
658 | if FLAGS.use_pgen:
659 | p_attns = outputs[-1] # Shape: 1 x [[Bm x T_enc, Bm x S_in]].
660 | p_attns = p_attns[0] # Only one time-step in the decoding phase, Shape: [Bm x T_enc, Bm x S_in].
661 | p_attns_s = np.split(p_attns[0], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x T_enc].
662 | p_attns_d = np.split(p_attns[1], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x S_in].
663 | p_attns = [[attn_s, attn_d] for attn_s, attn_d
664 | in zip(p_attns_s, p_attns_d)] # Shape: Bm * [[1 x T_enc, 1 x S_in]].
665 | next_values.append(p_attns)
666 |
667 | return next_values
668 |
669 | def _get_vocab_dist(self, inputs):
670 | """
671 | This function passes the NSE hidden states obtained from decoder through the output layer (dense layer
672 | followed by a softmax layer) and returns the vocabulary distribution thus obtained.
673 | :param inputs: List of hidden states, Shape: T_dec * [B x D]
674 | :return: p_vocab: List of vocabulary distributions for each time-step, Shape: T_dec * [B x vsize].
675 | """
676 | steps = len(inputs)
677 | vsize = self._vocab.size()
678 |
679 | inputs = tf.concat(inputs, axis=0) # Shape: (B*T_dec) x D.
680 |
681 | with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
682 | scores = tf.layers.dense(inputs,
683 | units=vsize,
684 | activation=None,
685 | kernel_initializer=self._dense_init,
686 | kernel_regularizer=None,
687 | name='output') # Shape: (B*T_dec) x vsize.
688 |
689 | p_vocab = tf.nn.softmax(scores, name="prob_scores") # Shape: (B*T_dec) x vsize.
690 |
691 | p_vocab = tf.split(p_vocab, num_or_size_splits=steps, axis=0) # Shape: T_dec * [B x vsize].
692 |
693 | return p_vocab
694 |
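# Shape walk-through (illustrative): with T_dec = 3 decoder steps, batch B = 2, hidden
# size D and vocabulary size vsize, the projection above does
#   3 * [2 x D]  --concat-->  6 x D  --dense-->  6 x vsize  --softmax-->  6 x vsize
#   --split-->  3 * [2 x vsize],
# i.e. one dense layer is shared across all time-steps by stacking them along the
# batch axis before projecting to the vocabulary.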
695 | def _get_cumulative_dist(self, p_vocabs, p_gens, p_attns, enc_in_ext_vocab):
696 | """
697 | This function calculates the cumulative probability distribution from the vocabulary distribution, attention
698 | distribution and generation probabilities.
699 | :param p_vocabs: A list of vocabulary distributions for each time-step, Shape: T_dec * [B x vsize].
700 | :param p_gens: A list of generation probabilities for each time-step, Shape: T_dec * [B x 1]
701 | :param p_attns: A list of attention distributions for each time-step, T_dec * [[sent_attn, doc_attn]],
702 | Shape: T_dec * [[B x T_enc, B x S_in]].
703 | :param enc_in_ext_vocab: Encoder input represented using extended vocabulary, Shape: B x T_enc.
704 | :return: p_cum: A list of cumulative (pointer-generator) distributions, Shape: T_dec * [B x ext_vsize].
705 | """
706 | vsize = self._vocab.size()
707 | ext_vsize = vsize + self._max_oov_size
708 | batch_size, enc_steps = enc_in_ext_vocab.get_shape().as_list()
709 |
710 | p_vocabs = [tf.multiply(p_gen, p_vocab)
711 | for p_gen, p_vocab in zip(p_gens, p_vocabs)] # Shape: T_dec * [B x vsize].
712 | p_attns = [tf.multiply(tf.subtract(1.0, p_gen), p_attn[0])
713 | for p_gen, p_attn in zip(p_gens, p_attns)] # Shape: T_dec * [B x T_enc].
714 |
715 | zero_vocab = tf.zeros(shape=[batch_size, self._max_oov_size])
716 | # Shape: T_dec * [B x ext_vsize].
717 | p_vocabs_ext = [tf.concat([p_vocab, zero_vocab], axis=-1) for p_vocab in p_vocabs]
718 |
719 | idx = tf.range(0, limit=batch_size) # Shape: B x.
720 | idx = tf.expand_dims(idx, axis=-1) # Shape: B x 1.
721 | idx = tf.tile(idx, [1, enc_steps]) # Shape: B x T_enc.
722 | indices = tf.stack([idx, enc_in_ext_vocab], axis=-1) # Shape: B x T_enc x 2.
723 |
724 | # First, A zero matrix of shape B x ext_vsize is created. Then, the attention score for each input token
725 | # indexed as in indices is looked in p_attn and updated in the corresponding location,
726 | # Shape: T_dec * [B x ext_vsize]
727 | p_attn_cum = [tf.scatter_nd(indices, p_attn, [batch_size, ext_vsize]) for p_attn in p_attns]
728 |
729 | # Cumulative distribution, Shape: T_dec * [B x ext_vsize].
730 | p_cum = [tf.add(gen_prob, copy_prob) for gen_prob, copy_prob in zip(p_vocabs_ext, p_attn_cum)]
731 |
732 | return p_cum
733 |
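# Worked example (illustrative): with B = 1, T_enc = 3, vsize = 5 and one OOV word
# (ext_vsize = 6), enc_in_ext_vocab = [[2, 5, 2]] and an attention row (already scaled
# by 1 - p_gen above) of [[0.2, 0.5, 0.3]], indices = [[[0, 2], [0, 5], [0, 2]]] and
# tf.scatter_nd sums duplicate positions, giving p_attn_cum = [[0, 0, 0.5, 0, 0, 0.5]].
# Adding this copy distribution to the zero-padded vocabulary distribution (scaled by
# p_gen) gives the final pointer-generator distribution over the extended vocabulary.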
734 | @staticmethod
735 | def _get_crossentropy_loss(probs, labels, mask):
736 | """
737 | This function calculates the crossentropy loss from ground-truth and the probability scores.
738 | :param probs: Predicted probabilities, Shape: T * [B x vocab_size]
739 | :param labels: Ground Truth labels, Shape: B x T.
740 | :param mask: Mask to exclude PAD tokens while calculating loss, Shape: B x T.
741 | :return: average mini-batch loss.
742 | """
743 | if isinstance(probs, list):
744 | probs = tf.stack(probs, axis=1) # Shape: B x T x vsize.
745 |
746 | true_probs = tf.reduce_sum(
747 | probs * tf.one_hot(labels, depth=tf.shape(probs)[2]), axis=2
748 | ) # Shape: B x T.
749 | logprobs = -tf.log(tf.clip_by_value(true_probs, 1e-10, 1.0)) # Shape: B x T.
750 | xe_loss = tf.reduce_sum(logprobs * mask) / tf.reduce_sum(mask)
751 |
752 | return xe_loss
753 |
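# In other words (illustrative restatement of the code above):
#   xe_loss = sum_{b,t} mask[b,t] * (-log probs[b, t, labels[b,t]]) / sum_{b,t} mask[b,t],
# so PAD positions (mask = 0) contribute to neither the numerator nor the denominator,
# and clipping to [1e-10, 1.0] guards against log(0) when a label gets zero probability.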
754 | def _restore(self):
755 | """
756 | This function restores parameters from a saved checkpoint.
757 | :return:
758 | """
759 | restore_path = FLAGS.PathToCheckpoint + "model_epoch10"
760 |
761 | if FLAGS.restore_checkpoint and tf.train.checkpoint_exists(restore_path):
762 | start = time.time()
763 |
764 | # Initializing all variables.
765 | print("Initializing all variables!\n")
766 | self._sess.run(self._init)
767 |
768 | # Restoring checkpoint variables.
769 | vars_restore = [v for v in self._global_vars if "Adam" not in v.name]
770 | restore_saver = tf.train.Saver(vars_restore)
771 | print("Restoring non-Adam parameters from a previous checkpoint.\n")
772 | restore_saver.restore(self._sess, restore_path)
773 |
774 | end = time.time()
775 | print("Restoring model took %.2f sec. \n" % (end - start))
776 | else:
777 | start = time.time()
778 | self._init.run(session=self._sess)
779 | end = time.time()
780 | print("Running initializer took %.2f time. \n" % (end - start))
781 |
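# Note on the restore logic above: excluding variables whose name contains "Adam" reloads
# only the model weights, while the Adam slot variables (first/second moment estimates)
# keep the values from self._init; this is why all variables are initialized first and
# the checkpoint is then restored on top.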
782 | def train(self):
783 | """
784 | This function performs training, runs validation every few epochs, and saves the best model
785 | whenever the validation loss improves. It also writes training summaries for visualization in TensorBoard.
786 | :return:
787 | """
788 | # Restore from previous checkpoint if found one.
789 | self._restore()
790 |
791 | # No. of iterations.
792 | num_train_iters_per_epoch = int(math.ceil(self._data.num_train_examples / FLAGS.batch_size))
793 |
794 | # Running averages.
795 | running_avg_crossentropy_loss = 0.0
796 | running_avg_rouge = 0.0
797 | for epoch in range(1, 1 + FLAGS.num_epochs):
798 | # Training.
799 | for iteration in range(1, 1 + num_train_iters_per_epoch):
800 | start = time.time()
801 | feed_dict = self._get_feed_dict(split="train")
802 |
803 | to_return = [self._train_op, self._crossentropy_loss, self._global_step, self._summaries]
804 | if FLAGS.rouge_summary:
805 | to_return.append(self._samples)
806 |
807 | # Evaluate summaries for last batch.
808 | feed_dict[self._mean_crossentropy_loss] = running_avg_crossentropy_loss
809 | if FLAGS.rouge_summary:
810 | feed_dict[self._mean_rouge_score] = running_avg_rouge
811 |
812 | outputs = self._sess.run(to_return, feed_dict=feed_dict)
813 | _, crossentropy_loss, global_step = outputs[: 3]
814 |
815 | # Calculating ROUGE score.
816 | rouge_score = 0.0
817 | if FLAGS.rouge_summary:
818 | samples = outputs[4]
819 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x .
820 | gt_labels = feed_dict[self._dec_out]
821 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False)
822 |
823 | # Updating the running averages.
824 | running_avg_crossentropy_loss = get_running_avg_loss(
825 | crossentropy_loss, running_avg_crossentropy_loss
826 | )
827 | running_avg_rouge = get_running_avg_loss(
828 | np.mean(rouge_score), running_avg_rouge
829 | )
830 |
831 | if ((iteration - 2) % FLAGS.summary_every == 0) or (iteration == num_train_iters_per_epoch):
832 | train_summary = outputs[3]
833 | self._train_writer.add_summary(train_summary, global_step)
834 |
835 | end = time.time()
836 | print("\rTraining Iteration: {}/{} ({:.1f}%) took {:.2f} sec.".format(
837 | iteration, num_train_iters_per_epoch, iteration * 100 / num_train_iters_per_epoch, end - start
838 | ))
839 |
840 | if epoch % FLAGS.val_every == 0:
841 | start = time.time()
842 | self._validate()
843 | end = time.time()
844 | print("Validation took {:.2f} sec.".format(end - start))
845 |
846 | def _validate(self):
847 | """
848 | This function runs validation and saves a checkpoint when the validation loss improves.
849 | :return:
850 | """
851 | # Validation
852 | num_val_iters_per_epoch = int(math.ceil(self._data.num_val_examples / FLAGS.val_batch_size))
853 | num_train_iters_per_epoch = int(math.ceil(self._data.num_train_examples / FLAGS.batch_size))
854 |
855 | outputs = None
856 | total_crossentropy_loss = 0.0
857 | total_rouge_score = 0.0
858 | for iteration in range(1, 1 + num_val_iters_per_epoch):
859 | feed_dict = self._get_feed_dict(split="val")
860 | to_return = [self._crossentropy_loss]
861 | if FLAGS.rouge_summary:
862 | to_return.append(self._samples)
863 |
864 | if iteration == num_val_iters_per_epoch:
865 | feed_dict[self._mean_crossentropy_loss] = total_crossentropy_loss / (num_val_iters_per_epoch - 1)
866 | feed_dict[self._mean_rouge_score] = total_rouge_score / (num_val_iters_per_epoch - 1)
867 | to_return += [self._val_summaries, self._global_step]
868 |
869 | outputs = self._sess.run(to_return, feed_dict=feed_dict)
870 | crossentropy_loss = outputs[0]
871 |
872 | # Calculating ROUGE score.
873 | rouge_score = 0.0
874 | if FLAGS.rouge_summary:
875 | samples = outputs[1]
876 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x .
877 | gt_labels = feed_dict[self._dec_out]
878 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False)
879 |
880 | print("\rValidation Iteration: {}/{} ({:.1f}%)".format(
881 | iteration, num_val_iters_per_epoch, iteration * 100 / num_val_iters_per_epoch,
882 | ))
883 |
884 | # Updating the total losses.
885 | total_crossentropy_loss += crossentropy_loss
886 | total_rouge_score += np.mean(rouge_score)
887 |
888 | # Writing the validation summaries for visualization.
889 | val_summary, global_step = outputs[-2], outputs[-1]
890 | self._val_writer.add_summary(val_summary, global_step)
891 | epoch = global_step // num_train_iters_per_epoch
892 |
893 | # Cumulative loss.
894 | avg_xe_loss = total_crossentropy_loss / num_val_iters_per_epoch
895 | if avg_xe_loss < self._best_val_loss:
896 | print("\rValidation loss improved :).")
897 | self._best_val_loss = avg_xe_loss  # Track the best validation loss so later epochs compare against it.
898 | self._saver.save(self._sess, FLAGS.PathToCheckpoint + "model_epoch" + str(epoch))
899 | def test(self):
900 | """
901 | This function predicts the outputs for the test set.
902 | :return:
903 | """
904 | num_gpus = len(FLAGS.GPUs) # Total no. of GPUs.
905 | batch_size = num_gpus * FLAGS.num_pools # Total no. of examples processed in a single test iteration.
906 |
907 | # No. of iterations
908 | num_test_iters_per_epoch = int(math.ceil(self._data.num_test_examples / batch_size))
909 |
910 | for iteration in range(1, 1 + num_test_iters_per_epoch):
911 | start = time.time()
912 | # [indices, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask,
913 | # enc_inp_ext_vocab, ext_vocabs, max_oov_size]
914 | input_batches = self._data.get_batch(
915 | batch_size, split="test", permutate=FLAGS.permutate, chunk=FLAGS.chunk
916 | )
917 |
918 | # Split the inputs into batches of size "num_pools" per GPU.
919 | input_batches_per_gpu = [[] for _ in range(num_gpus)]
920 | for i, input_batch in enumerate(input_batches):
921 | for j, idx in zip(range(num_gpus), range(0, batch_size, FLAGS.num_pools)):
922 |
923 | if i < 5 or (FLAGS.use_pgen and i < 7): # First 5 inputs (7 inputs in pgen mode).
924 | input_batches_per_gpu[j].append(
925 | input_batch[idx: idx + FLAGS.num_pools])
926 | else: # max_oov_size is the same for all examples.
927 | input_batches_per_gpu[j].append(input_batch)
928 |
930 | # Appending the GPU id.
931 | for gpu_id in range(num_gpus):
932 | input_batches_per_gpu[gpu_id].append(gpu_id)
933 |
934 | Parallel(n_jobs=num_gpus, backend="threading")(
935 | map(delayed(self._test_one_gpu), input_batches_per_gpu)
936 | )
937 | end = time.time()
938 | print("\rTesting Iteration: {}/{} ({:.1f}%) in {:.1f} sec.".format(
939 | iteration, num_test_iters_per_epoch, iteration * 100 / num_test_iters_per_epoch, end - start))
940 |
941 | def _test_one_gpu(self, inputs):
942 | """
943 | This function performs testing on the inputs of a single GPU.
944 | :return:
945 | """
946 | input_batches = inputs[: -1]
947 | gpu_idx = inputs[-1]
948 |
949 | with tf.device('/gpu:%d' % gpu_idx):
950 | self._test_one_pool(input_batches)
951 |
952 | def _test_one_pool(self, inputs):
953 | """
954 | This function performs testing on the inputs of a single GPU over parallel pools.
955 | :param inputs: All the first 8 inputs have "num_pools" examples each.
956 | [indices, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask,
957 | enc_inp_ext_vocab, ext_vocabs, max_oov_size]
958 | :return:
959 | """
960 | # Split the inputs into each example per pool.
961 | input_batches_per_pool = [[] for _ in range(FLAGS.num_pools)]
962 | for i, input_batch in enumerate(inputs):
963 | for j in range(FLAGS.num_pools):
964 | # Keeping list-valued inputs (index, GT summary, file name and, in pgen mode, the ext vocabulary) per example.
965 | if i == 0 or i == 1 or i == 2 or (FLAGS.use_pgen and i == 7):
966 | input_batches_per_pool[j].append([input_batch[j]])
967 |
968 | # Repeating each input NumPy array beam_size times (one copy per beam hypothesis).
969 | elif i == 3 or i == 4 or i == 5 or (FLAGS.use_pgen and i == 6):
970 | input_batches_per_pool[j].append(
971 | np.repeat(input_batch[np.newaxis, j], repeats=FLAGS.beam_size, axis=0))
972 |
973 | # max_oov_size is same for all examples in the pool.
974 | else:
975 | input_batches_per_pool[j].append(input_batch)
976 |
977 | Parallel(n_jobs=FLAGS.num_pools, backend="threading")(
978 | map(delayed(self._test_one_ex), input_batches_per_pool))
979 |
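# Note: backend="threading" keeps every pool in the same process, so all workers share
# this object's TensorFlow graph and session; sess.run releases the GIL during graph
# execution, which is why plain threads are used here instead of separate processes.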
980 | def _test_one_ex(self, inputs):
981 | """
982 | This function performs testing on one example.
983 | :param inputs:
984 | [index, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask,
985 | enc_inp_ext_vocab, ext_vocabs, max_oov_size]
986 | :return:
987 | """
988 | iteration = inputs[0][0]
989 | input_batches = inputs[1:]
990 | # Inputs for running beam search on one example.
991 | beam_search_inputs = input_batches[2: 5]
992 | if FLAGS.use_pgen:
993 | beam_search_inputs += [input_batches[5], input_batches[7]]
994 |
995 | best_hyp = run_beam_search_hier(
996 | beam_search_inputs, self.run_encoder, self.decode_one_step, self._vocab
997 | )
998 | pred_ids = best_hyp.tokens # Shape: T_dec * [1 x]
999 | pred_ids = np.stack(pred_ids, axis=0) # Shape: T_dec x
1000 | pred_ids = pred_ids[np.newaxis, :] # Shape: 1 x T_dec
1001 |
1002 | # Writing the outputs.
1003 | if FLAGS.use_pgen:
1004 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1], input_batches[6])
1005 | else:
1006 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1])
1007 |
1008 | def _get_feed_dict(self, split):
1009 | """
1010 | Returns a feed_dict assigning data batches to the following placeholders:
1011 | [enc_in, enc_pad_mask, enc_doc_mask, dec_in, labels, dec_pad_mask, enc_inp_ext_vocab, max_oov_size]
1012 | :param split: One of "train" or "val".
1013 | :return:
1014 | """
1015 |
1016 | if split == "train":
1017 | input_batches = self._data.get_batch(
1018 | FLAGS.batch_size, split="train", permutate=FLAGS.permutate, chunk=FLAGS.chunk
1019 | )
1020 |
1021 | elif split == "val":
1022 | input_batches = self._data.get_batch(
1023 | FLAGS.val_batch_size, split="val", permutate=FLAGS.permutate, chunk=FLAGS.chunk
1024 | )
1025 |
1026 | else:
1027 | raise ValueError("split should be either train/val!! \n")
1028 |
1029 | feed_dict = {
1030 | self._enc_in: input_batches[0],
1031 | self._enc_pad_mask: input_batches[1],
1032 | self._enc_doc_mask: input_batches[2],
1033 | self._dec_in: input_batches[3],
1034 | self._dec_out: input_batches[4],
1035 | self._dec_pad_mask: input_batches[5]
1036 | }
1037 |
1038 | if FLAGS.use_pgen:
1039 | feed_dict[self._enc_in_ext_vocab] = input_batches[6]
1040 | feed_dict[self._max_oov_size] = input_batches[7]
1041 |
1042 | return feed_dict
1043 |
1044 | @staticmethod
1045 | def make_html_safe(s):
1046 | """
1047 | Replace any angled brackets in string s to avoid interfering with HTML attention visualizer.
1048 | :param s: Input string.
1049 | :return:
1050 | """
1051 | s.replace("<", "<")
1052 | s.replace(">", ">")
1053 |
1054 | return s
1055 |
1056 | def _write_file(self, pred, pred_name, gt, gt_name, ext_vocab=None):
1057 | """
1058 | This function writes one predicted summary and its ground-truth summary to .txt files with the given names.
1059 | :param pred: Predicted token IDs for one example, Shape: T x .
1060 | :param pred_name: Name of the file in which predictions will be written.
1061 | :param gt: Ground truth summary, a list of sentences (strings).
1062 | :param gt_name: Name of the file in which GTs will be written.
1063 | :param ext_vocab: Extended vocabulary for each example. [ext_words x]
1064 | :return: _pred, _gt files will be created for ROUGE evaluation.
1065 | """
1066 | # Writing predictions.
1067 | vsize = self._vocab.size()
1068 |
1069 | # Removing the [START] token.
1070 | pred = pred[1:]
1071 |
1072 | # Converting IDs back to words (OOV IDs are looked up in the extended vocabulary).
1073 | pred_words = []
1074 | for t in pred:
1075 | try:
1076 | pred_words.append(self._vocab.id2word(t))
1077 | except ValueError:
1078 | pred_words.append(ext_vocab[t - vsize])
1079 |
1080 | # Considering tokens only till STOP token.
1081 | try:
1082 | stop_idx = pred_words.index(params.STOP_DECODING)
1083 | except ValueError:
1084 | stop_idx = len(pred_words)
1085 | pred_words = pred_words[: stop_idx]
1086 |
1087 | # Creating sentences out of the predicted sequence.
1088 | pred_sents = []
1089 | while pred_words:
1090 | try:
1091 | period_idx = pred_words.index(".")
1092 | except ValueError:
1093 | period_idx = len(pred_words)
1094 |
1095 | # Append the sentence.
1096 | sent = pred_words[: period_idx + 1]
1097 | pred_sents.append(" ".join(sent))
1098 |
1099 | # Consider the remaining words now.
1100 | pred_words = pred_words[period_idx + 1:]
1101 |
1102 | # Making HTML safe.
1103 | pred_sents = [self.make_html_safe(s) for s in pred_sents]
1104 | gt_sents = [self.make_html_safe(s) for s in gt]
1105 |
1106 | # Writing predicted sentences.
1107 | f = open(pred_name, 'w', encoding='utf-8')
1108 | for i, sent in enumerate(pred_sents):
1109 | f.write(sent if i == len(pred_sents) - 1 else sent + "\n")
1110 | f.close()
1111 |
1112 | # Writing GT sentences.
1113 | f = open(gt_name, 'w', encoding='utf-8')
1114 | for i, sent in enumerate(gt_sents):
1115 | f.write(sent if i == len(gt_sents) - 1 else sent + "\n")
1116 | f.close()
1117 |
1118 | def _write_outputs(self, index, preds, gts, files, ext_vocabs=None):
1119 | """
1120 | This function writes the predictions and the ground-truth summaries to output files.
1121 | :param index: Number of the test example.
1122 | :param preds: The predicted token IDs. Shape: B x T
1123 | :param gts: The ground-truth summaries, a list of B examples, each a list of sentences.
1124 | :param files: The names of the files.
1125 | :param ext_vocabs: Extended vocabularies for each example, Shape: B * [ext_words x]
1126 | :return: Saves the predictions and GT's in a .txt format.
1127 | """
1128 | for i in range(len(files)):
1129 | file, pred, gt = files[i], preds[i], gts[i]
1130 | name_pred = FLAGS.PathToResults + 'predictions/' + '%06d_pred.txt' % index
1131 | name_gt = FLAGS.PathToResults + 'groundtruths/' + '%06d_gt.txt' % index
1132 |
1133 | ext_vocab = None
1134 | if FLAGS.use_pgen:
1135 | ext_vocab = ext_vocabs[i]
1136 |
1137 | self._write_file(pred, name_pred, gt, name_gt, ext_vocab)
1138 |
1139 | def _single_model(self):
1140 | # with tf.device('/gpu'):
1141 | placeholders = self._create_placeholders()
1142 | self._crossentropy_loss, self._samples = self._forward(placeholders)
1143 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer
1144 | self._train_op = sgd_solver.minimize(
1145 | self._crossentropy_loss, global_step=self._global_step, name='train_op')
1146 |
1147 | self._saver = tf.train.Saver(var_list=tf.global_variables()) # Saver.
1148 | self._init = tf.global_variables_initializer() # Initializer.
1149 | self._sess = tf.Session(config=self._config) # Session.
1150 |
1151 | # Train and validation summary writers.
1152 | if FLAGS.mode == 'train':
1153 | self._create_writers()
1154 |
1155 | self._global_vars = tf.global_variables()
1156 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables())))
1157 | # print(tf.trainable_variables())
1158 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()])
1159 | # print('No. of params = {:d}'.format(int(no_params)))
1160 |
1161 | def _parallel_model(self):
1162 | with tf.Graph().as_default(), tf.device('/cpu:0'):
1163 | placeholders = self._create_placeholders()
1164 |
1165 | # Splitting the placeholders for each GPU.
1166 | placeholders_per_gpu = [[] for _ in range(self._num_gpus)]
1167 | for placeholder in placeholders:
1168 | splits = tf.split(placeholder, num_or_size_splits=self._num_gpus, axis=0)
1169 | for i, split in enumerate(splits):
1170 | placeholders_per_gpu[i].append(split)
1171 |
1172 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer
1173 | tower_grads = [] # Gradients calculated in each tower
1174 | with tf.variable_scope(tf.get_variable_scope()):
1175 | all_losses = []
1176 | all_samples = []
1177 | for i in range(self._num_gpus):
1178 | with tf.device('/gpu:%d' % i):
1179 | with tf.name_scope('tower_%d' % i):
1180 | crossentropy_loss, samples = self._forward(placeholders_per_gpu[i])
1181 |
1182 | # Gathering samples from all GPUs.
1183 | all_samples.append(samples)
1184 |
1185 | tf.get_variable_scope().reuse_variables()
1186 |
1187 | grads = sgd_solver.compute_gradients(crossentropy_loss)
1188 | tower_grads.append(grads)
1189 |
1190 | # Updating the losses
1191 | all_losses.append(crossentropy_loss)
1192 |
1193 | self._crossentropy_loss = tf.add_n(all_losses) / self._num_gpus
1194 |
1195 | # Samples from all the GPUs.
1196 | if FLAGS.rouge_summary:
1197 | self._samples = tf.concat(all_samples, axis=0)
1198 | else:
1199 | self._samples = []
1200 |
1201 | # Synchronization Point
1202 | gradients, variables = zip(*average_gradients(tower_grads))
1203 | gradients, global_norm = tf.clip_by_global_norm(gradients, FLAGS.max_grad_norm)
1204 |
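# Assumption: average_gradients (presumably defined in utils.py, not shown here) takes
# tower_grads, a list with one [(grad, var), ...] list per GPU, and returns a single
# [(avg_grad, var), ...] list in which each gradient is the mean of the per-tower
# gradients of the shared variable, i.e. avg_grad_v = (1 / num_gpus) * sum_i grad_v_i.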
1205 | # Summary for the global norm
1206 | tf.summary.scalar('global_norm', global_norm)
1207 |
1208 | # Histograms for gradients.
1209 | for grad, var in zip(gradients, variables):
1210 | # if (grad is not None) and ("word_embedding" in var.name):
1211 | tf.summary.histogram(var.op.name + '/gradients', grad)
1212 |
1213 | # Histograms for variables.
1214 | for var in tf.trainable_variables():
1215 | # if "word_embedding" in var.name:
1216 | tf.summary.histogram(var.op.name, var)
1217 |
1218 | self._train_op = sgd_solver.apply_gradients(
1219 | zip(gradients, variables), global_step=self._global_step)
1220 | self._saver = tf.train.Saver(var_list=tf.global_variables(),
1221 | max_to_keep=None) # Saver.
1222 | self._init = tf.global_variables_initializer() # Initializer.
1223 | self._sess = tf.Session(config=self._config) # Session.
1224 |
1225 | # Train and validation summary writers.
1226 | if FLAGS.mode == 'train':
1227 | self._create_writers()
1228 |
1229 | self._global_vars = tf.global_variables()
1230 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables())))
1231 | # print(tf.trainable_variables())
1232 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()])
1233 | # print('No. of params = {:d}'.format(int(no_params)))
1234 |
--------------------------------------------------------------------------------