├── codes ├── __init__.py ├── params.py ├── rouge_batch.py ├── beam_search.py ├── run_model.py ├── NSE.py ├── HierNSE.py ├── utils.py ├── model.py └── model_hier.py ├── data ├── __init__.py └── preprocess.py └── README.md /codes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /codes/params.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | dm_single_close_quote = u'\u2019' # unicode 3 | dm_double_close_quote = u'\u201d' 4 | 5 | # acceptable ways to end a sentence 6 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] 7 | 8 | # Vocabulary id's for sentence start, end, pad, unknown token, start and stop decoding. 9 | SENTENCE_START = '' 10 | SENTENCE_END = '' 11 | 12 | PAD_TOKEN = '[PAD]' 13 | UNKNOWN_TOKEN = '[UNK]' 14 | START_DECODING = '[START]' 15 | STOP_DECODING = '[STOP]' 16 | VOCAB_SIZE = 200000 17 | -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Code by Saptarashmi Bandyopadhyay 3 | 4 | import nltk 5 | import sys 6 | import os 7 | from nltk.stem import PorterStemmer 8 | 9 | porter = PorterStemmer() 10 | 11 | PathToDataset = sys.argv[1] 12 | PathToOutput = sys.argv[2] 13 | 14 | # Factoring. 15 | for split in ["train", "val", "test"]: 16 | 17 | # Creating a new directory if it doesn't exist. 18 | if not os.path.exists(PathToOutput + split): 19 | os.makedirs(PathToOutput + split) 20 | 21 | print("Processing files in the {} set".format(split)) 22 | 23 | files = os.listdir(PathToDataset + split) 24 | # total_num = len(files) 25 | 26 | for num, filename in enumerate(files): 27 | f = open(PathToDataset + split + '/' + filename, 'r') 28 | of = open(PathToOutput + split + '/' + filename, 'w') 29 | 30 | for line in f: 31 | if "@highlight" not in line: 32 | text = line.encode('utf8') 33 | word = nltk.pos_tag(nltk.word_tokenize(text.decode('utf8'))) 34 | for i in word: 35 | stemmed = porter.stem(i[0]) 36 | oword = i[0] + ' | ' + stemmed + ' | ' + i[1] + ' ' 37 | of.write(oword) 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Requirements 2 | - Tensorflow == 1.9.0 3 | - NLTK 4 | - joblib == 0.13.2 5 | 6 | # Files 7 | - run_model.py: This file has all the flags needed for training, it chooses the data module, models and starts training. 8 | - utils.py: Utilities module containing DataGenerators, Parallel model etc. 9 | - DataGenerator: Creates batches for vanilla NSE. 10 | - DataGeneratorHier: Creates batches for hierarchical NSE. 11 | - Memory: All the files below contain a variant of NSE. 12 | - NSE: This is a neural semantic encoder class. 13 | - HierNSE: This is the hier-NSE class. 14 | - models: All the files below contain an encoder, decoder, loss, optimizer functions that use an NSE. 15 | - model.py: Model using vanilla NSE. 16 | - model_hier.py: Model using hier-NSE (use this). 17 | - model_hier_sc.py: Self-Critic model (use this). It carefully back-propagates through the same multinomial samples that are sampled while forward pass. 18 | - rouge: Rouge scripts used. 19 | - rouge_batch: A NumPy implementation (faster than existing ones). Used outside the TensorFlow graph. 20 | - Data 21 | - Create a folder named `data` 22 | - Download the following splits into data folder: 23 | [train](https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_train.txt), 24 | [val](https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_val.txt), 25 | [test](https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt) 26 | - Download the CNN and Daily-Mail tokenized data: 27 | [CNN](https://drive.google.com/file/d/0BzQ6rtO2VN95cmNuc2xwUS1wdEE/view?usp=sharing), 28 | [DM](https://drive.google.com/file/d/0BzQ6rtO2VN95bndCZDdpdXJDV1U/view?usp=sharing) 29 | - Download [GloVe](http://nlp.stanford.edu/data/glove.840B.300d.zip) 30 | 31 | # supervised model 32 | - Training 33 | 34 | python run_model.py --model="hier" --mode="train" --PathToCheckpoint=/path/to/checkpoint --PathToTB=/path/to/tensorboard/logs 35 | 36 | - Testing 37 | - Check the epoch number of the best supervised model from TensorBoard, let it be X 38 | 39 | python run_model.py --model="hier" --mode="test" --PathToCheckpoint=/path/to/checkpoint/model_epochX --PathToResults=/path/to/results 40 | 41 | - Evaluation 42 | 43 | python run_model.py --model="hier" --mode="eval" --PathToResults=/path/to/results 44 | 45 | # self-critical model 46 | - Training 47 | - First copy the best supervised model to the rl checkpoint. 48 | 49 | python run_model.py --model="rlhier" --mode="train" --restore=True --PathToCheckpoint=/path/to/checkpoint --PathToTB=/path/to/tensorboard/logs 50 | 51 | - Testing 52 | - Check the epoch number of the best supervised model from TensorBoard, let it be X. 53 | 54 | python run_model.py --model="rlhier" --mode="train" --restore=True --PathToCheckpoint=/path/to/checkpoint/model_epochX --PathToResults=/path/to/results 55 | 56 | - Evaluation 57 | 58 | python run_model.py --model="rlhier" --mode="eval" --PathToResults=/path/to/results 59 | -------------------------------------------------------------------------------- /codes/rouge_batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rouge_l_fscore_batch(hypothesis, references, end_id, lens, all_steps=False): 5 | """ 6 | ROUGE scores computation between labels and predictions. 7 | This is an approximate ROUGE scoring method since we do not glue word pieces 8 | or decode the ids and tokenize the output. 9 | :param hypothesis: (predictions) tensor, model predictions (batch_size, <=max_dec_steps) 10 | :param references: (labels) tensor, gold output. (batch_size, max_dec_steps) 11 | :param end_id: End of sequence ID. 12 | :param lens: Lengths of the sequences excluding the PAD tokens, Shape: (batch_size). 13 | :param all_steps: Whether to calculate rewards for all time-steps. 14 | :return: rouge_l_fscore: approx rouge-l f1 score, Shape: (batch_size). 15 | """ 16 | 17 | if all_steps: 18 | batch_fscore = rouge_l_sentence_level_all_batch(hypothesis, references, end_id, lens) # Shape: B x T. 19 | else: 20 | batch_fscore = rouge_l_sentence_level_final_batch(hypothesis, references, end_id, lens) # Shape: B x . 21 | 22 | return batch_fscore 23 | 24 | 25 | def infer_length(seq, end_id): 26 | """ 27 | This function is used to calculate the length of given sequence based on the end ID. 28 | :param seq: Input sequence, Shape: B x T. 29 | :param end_id: End of sequence ID. 30 | :return: 31 | """ 32 | batch_size = seq.shape[0] 33 | is_end = np.equal(seq, end_id).astype(np.int32) # Shape: B x T. 34 | 35 | # Avoiding the zero length case. 36 | front_zeros = np.zeros([batch_size, 1], dtype=np.int32) # Shape: B x 1. 37 | is_end = np.concatenate([front_zeros, is_end], axis=1) # Shape: B x (T + 1). 38 | is_end = is_end[:, : -1] # Shape: B x T. 39 | 40 | count_end = np.cumsum(is_end, axis=1) 41 | lengths = np.sum(np.equal(count_end, 0).astype(np.int32), axis=1) 42 | 43 | return lengths 44 | 45 | 46 | def rouge_l_sentence_level_final_batch(eval_sentences, ref_sentences, end_id=None, lens=None): 47 | """ 48 | Computes ROUGE-L (sentence level) of two collections of sentences. 49 | Source: https://www.microsoft.com/en-us/research/publication/ 50 | rouge-a-package-for-automatic-evaluation-of-summaries/ 51 | Calculated according to: 52 | R_lcs = LCS(X,Y)/m 53 | P_lcs = LCS(X,Y)/n 54 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 55 | where: 56 | X = reference summary 57 | Y = Candidate summary 58 | m = length of reference summary 59 | n = length of candidate summary 60 | :param eval_sentences: The sentences that have been picked by the summarizer, Shape: (batch_size, m) 61 | :param ref_sentences: The sentences from the reference set, Shape: (batch_size, n) 62 | :param end_id: End of sentence ID. 63 | :param lens: Lengths of the sequences excluding PAD tokens. 64 | :return: F_lcs for all sentences in the batch, Shape: (batch_size,) 65 | """ 66 | if lens is not None: 67 | n, m = lens, lens 68 | else: 69 | n, m = infer_length(eval_sentences, end_id), infer_length(ref_sentences, end_id) 70 | lcs = _len_lcs_batch(eval_sentences, ref_sentences, n, m) 71 | return np.array(_f_lcs_batch(lcs, n, m)).astype(np.float32) 72 | 73 | 74 | def rouge_l_sentence_level_all_batch(eval_sentences, ref_sentences, end_id, lens): 75 | """ 76 | Computes ROUGE-L (sentence level) of two collections of sentences. 77 | Source: https://www.microsoft.com/en-us/research/publication/ 78 | rouge-a-package-for-automatic-evaluation-of-summaries/ 79 | Calculated according to: 80 | R_lcs = LCS(X,Y)/m 81 | P_lcs = LCS(X,Y)/n 82 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 83 | where: 84 | X = reference summary 85 | Y = Candidate summary 86 | m = length of reference summary 87 | n = length of candidate summary 88 | :param eval_sentences: The sentences that have been picked by the summarizer, Shape: (batch_size, T). 89 | :param ref_sentences: The sentences from the reference set, Shape: (batch_size, T). 90 | :param end_id: End of sentence ID. 91 | :param lens: length of the sequences, Shape: (batch_size). 92 | :return: F_lcs: Shape: B x T. 93 | """ 94 | if lens is not None: 95 | m = lens 96 | else: 97 | m = infer_length(ref_sentences, end_id) 98 | 99 | batch_size, steps = eval_sentences.shape 100 | 101 | n = np.tile(np.arange(1, steps + 1), batch_size) # Shape: (B*T) x. 102 | m = np.tile(m, steps) # Shape: (B*T) x. 103 | 104 | # Calculate F1 scores. 105 | lcs = _len_lcs_batch(eval_sentences, ref_sentences, n, m, True) # Shape: B x T. 106 | lcs = np.reshape(lcs, [-1]) # Shape: (B*T) x. 107 | f1_scores_all_steps = _f_lcs_batch(lcs, n, m) # Shape: (B*T,) 108 | f1_scores = np.reshape(f1_scores_all_steps, (batch_size, steps)) # Shape: B x T. 109 | 110 | return f1_scores 111 | 112 | 113 | def _len_lcs_batch(x, y, n, m, all_steps=False): 114 | """ 115 | Returns the length of Longest Common Sub-sequence between two steps. 116 | :param x: sequence of words, Shape: (batch_size, n). 117 | :param y: sequence of words, Shape: (batch_size, m). 118 | :param n: Lengths of the sequences in X, Shape: (batch_size,). 119 | :param m: Lengths of the sequences in Y, Shape: (batch_size,). 120 | :param all_steps: Whether to output LCS of all time-steps. 121 | :return: Lengths of LCS between a batch of x and y, Shape: (batch_size,) / (batch_size, T) 122 | """ 123 | table = _lcs_batch(x, y) # Shape: batch_size x len x len. 124 | len_lcs = [] 125 | for i in range(x.shape[0]): 126 | if all_steps: 127 | len_lcs.append(table[i, 1:, m[i]]) 128 | else: 129 | len_lcs.append(table[i, n[i], m[i]]) 130 | 131 | return np.array(len_lcs) 132 | 133 | 134 | def _lcs_batch(x, y): 135 | """ 136 | Computes the length of LCS between two seqs. 137 | :param x: collection of words, (batch_size, n). 138 | :param y: collection of words, (batch_size, m). 139 | :return: 140 | """ 141 | batch_size, n = x.shape 142 | m = y.shape[1] 143 | 144 | table = np.ndarray(shape=[batch_size, n + 1, m + 1], dtype=np.int32) 145 | for i in range(n + 1): 146 | for j in range(m + 1): 147 | if i == 0 or j == 0: 148 | table[:, i, j] = 0 149 | else: 150 | true_indcs = np.argwhere(x[:, i - 1] == y[:, j - 1]) 151 | false_indcs = np.argwhere(x[:, i - 1] != y[:, j - 1]) 152 | table[true_indcs, i, j] = table[true_indcs, i - 1, j - 1] + 1 153 | table[false_indcs, i, j] = np.maximum(table[false_indcs, i - 1, j], table[false_indcs, i, j - 1]) 154 | 155 | return table 156 | 157 | 158 | def _f_lcs_batch(llcs, n, m): 159 | """ 160 | Computes the LCS-based F-measure score. 161 | :param llcs: lengths of LCS, Shape: (batch_size,) 162 | :param n: number of words in candidate summary, Shape: (batch_size,) 163 | :param m: number of words in reference summary, Shape: (batch_size,) 164 | :return: 165 | """ 166 | r_lcs = llcs / m 167 | p_lcs = llcs / n 168 | beta = p_lcs / (r_lcs + 1e-12) 169 | num = (1 + (beta**2)) * r_lcs * p_lcs 170 | denom = r_lcs + ((beta**2) * p_lcs) 171 | f_lcs = num / (denom + 1e-12) 172 | 173 | return f_lcs 174 | -------------------------------------------------------------------------------- /codes/beam_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import params as params 4 | 5 | FLAGS = tf.app.flags.FLAGS 6 | 7 | 8 | class Hypothesis(object): 9 | """ 10 | This is a class that represents a hypothesis. 11 | """ 12 | 13 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens): 14 | """ 15 | :param tokens: Tokens in this hypothesis. 16 | :param log_probs: log probabilities of each token in the hypothesis. 17 | :param state: Internal state of decoder after decoding the last token. 18 | :param attn_dists: Attention distributions of each token. 19 | :param p_gens: Generation probability of each token. 20 | """ 21 | self.tokens = tokens 22 | self.log_probs = log_probs 23 | self.state = state 24 | self.attn_dists = attn_dists 25 | self.p_gens = p_gens 26 | 27 | def extend(self, token, log_prob, state=None, attn_dist=None, p_gen=None): 28 | """ 29 | This function extends the hypothesis with the given new token. 30 | :param token: New decoded token. 31 | :param log_prob: log probability ot the token. 32 | :param state: Internal state of decoder after decoding the token. 33 | :param attn_dist: Attention distribution of the token. 34 | :param p_gen: Generation probability of each token. 35 | :return: returns the extended hypothesis 36 | """ 37 | if state is None: 38 | state = self.state 39 | 40 | return Hypothesis( 41 | tokens=self.tokens + [token], 42 | log_probs=self.log_probs + log_prob, 43 | state=state, 44 | attn_dists=self.attn_dists + [attn_dist], 45 | p_gens=self.p_gens + [p_gen] 46 | ) 47 | 48 | def latest_token(self): 49 | """ 50 | :return: This function returns the last token in the hypothesis. 51 | """ 52 | return self.tokens[-1] 53 | 54 | def avg_log_prob(self): 55 | """ 56 | :return: This function returns the average of the log probability of hypothesis. 57 | Otherwise, longer sequences will have less probability. 58 | """ 59 | return sum(self.log_probs) / len(self.tokens) 60 | 61 | 62 | def run_beam_search(inputs, run_encoder, decode_one_step, vocab): 63 | """ 64 | This function performs the beam search decoding for one example. 65 | :param inputs: Inputs to the graph for running encoder and decoder. 66 | [enc_inp, enc_padding_mask, enc_inp_ext_vocab, max_oov_size] all with shapes Bm x T_in. 67 | :param run_encoder: 68 | :param decode_one_step: 69 | :param vocab: 70 | :return: 71 | """ 72 | # Run the encoder and fix the resulted final state for decoding process. 73 | init_states = run_encoder(inputs[0], inputs[1]) 74 | 75 | # Initialize beam size number of hypothesis. 76 | hyps = [Hypothesis( 77 | tokens=[vocab.word2id(params.START_DECODING)], 78 | log_probs=[0.0], 79 | state=init_states, 80 | attn_dists=[], 81 | p_gens=[] 82 | ) for _ in range(FLAGS.beam_size)] 83 | 84 | results = [] # This will contain finished hypothesis (those have emitted [STOP_DECODING] token.) 85 | 86 | for step in range(FLAGS.bs_dec_steps): 87 | 88 | # Stop decoding if beam_size number of complete hypothesis are emitted. 89 | if len(results) == FLAGS.beam_size: 90 | break 91 | 92 | prev_tokens = [h.latest_token() for h in hyps] # Tokens of all hypothesis from previous time step. 93 | # Replacing the OOV tokens with UNK tokens to perform the next decoding step. 94 | prev_tokens = [t if t in range(vocab.size()) else vocab.word2id(params.UNKNOWN_TOKEN) for t in prev_tokens] 95 | prev_states = [h.state for h in hyps] # Internal states of decoder. 96 | 97 | # Running one step of decoder. 98 | step_inputs = [prev_tokens, inputs[1]] 99 | if FLAGS.use_pgen: 100 | step_inputs += [inputs[2], inputs[3]] 101 | 102 | outputs = decode_one_step(step_inputs, prev_states) 103 | topk_ids, topk_log_probs, curr_states = outputs[: 3] 104 | 105 | p_gens = FLAGS.beam_size * [None] 106 | p_attns = FLAGS.beam_size * [None] 107 | if FLAGS.use_pgen: 108 | p_gens, p_attns = outputs[3: 5] 109 | 110 | # Extend each hypothesis with newly decoded predictions. 111 | all_hyps = [] 112 | num_org_hyps = 1 if step == 0 else len(hyps) # In the first step, there is only 1 distinct hypothesis. 113 | for i in range(num_org_hyps): 114 | hyp, curr_state, p_gen, p_attn = hyps[i], curr_states[i], p_gens[i], p_attns[i] 115 | for j in range(2 * FLAGS.beam_size): 116 | new_hyp = hyp.extend(token=topk_ids[i][j], 117 | log_prob=topk_log_probs[i][j], 118 | state=curr_state, 119 | attn_dist=p_attn, 120 | p_gen=p_gen) 121 | all_hyps.append(new_hyp) 122 | 123 | hyps = [] # Top Bm hypothesis. 124 | for h in sort_hyps(all_hyps): 125 | if h.latest_token == vocab.word2id(params.STOP_DECODING): # If hypothesis ended. 126 | # Collect this sequence only if it is long enough. 127 | if step >= FLAGS.min_dec_len: 128 | results.append(h) 129 | else: # Use this hypothesis for next decoding step if not ended. 130 | hyps.append(h) 131 | 132 | # Stop if beam size is reached. 133 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size: 134 | break 135 | 136 | # If no complete hypothesis were collected, add all current hypothesis to the results. 137 | if len(results) == 0: 138 | results = hyps 139 | 140 | results = sort_hyps(results) # Return best hypothesis in the final beam. 141 | 142 | return results[0] 143 | 144 | 145 | def sort_hyps(hyps): 146 | """ 147 | This function sorts the given hypothesis based on its probability. 148 | :param hyps: Input hypotheses. 149 | :return: 150 | """ 151 | return sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True) 152 | 153 | 154 | def run_beam_search_hier(inputs, run_encoder, decode_one_step, vocab): 155 | """ 156 | This function performs the beam search decoding for one example. 157 | :param inputs: Inputs to the graph for running encoder and decoder. 158 | [enc_inp, enc_pad_mask, enc_doc_mask, enc_inp_ext_vocab, max_oov_size] all with shapes Bm x T_in. 159 | :param run_encoder: 160 | :param decode_one_step: 161 | :param vocab: 162 | :return: 163 | """ 164 | # Run the encoder and fix the resulted final state for decoding process. 165 | init_states = run_encoder(inputs[: 3]) 166 | 167 | # Initialize beam size number of hypothesis. 168 | hyps = [Hypothesis( 169 | tokens=[vocab.word2id(params.START_DECODING)], 170 | log_probs=[0.0], 171 | state=init_states, 172 | attn_dists=[], 173 | p_gens=[] 174 | ) for _ in range(FLAGS.beam_size)] 175 | 176 | results = [] # This will contain finished hypothesis (those have emitted [STOP_DECODING] token.) 177 | 178 | for step in range(FLAGS.bs_dec_steps): 179 | 180 | # Stop decoding if beam_size number of complete hypothesis are emitted. 181 | if len(results) == FLAGS.beam_size: 182 | break 183 | 184 | prev_tokens = [h.latest_token() for h in hyps] # Tokens of all hypothesis from previous time step. 185 | # Replacing the OOV tokens with UNK tokens to perform the next decoding step. 186 | prev_tokens = [t if t in range(vocab.size()) else vocab.word2id(params.UNKNOWN_TOKEN) for t in prev_tokens] 187 | prev_states = [h.state for h in hyps] # Internal states of decoder. 188 | 189 | # Preparing inputs for the decoder. 190 | # inputs[1] = np.reshape(inputs[1], [-1, FLAGS.max_enc_sent * FLAGS.max_enc_steps_per_sent]) # B x T_enc. 191 | # Running one step of decoder. 192 | step_inputs = [prev_tokens] + inputs[1: 3] 193 | if FLAGS.use_pgen: 194 | step_inputs += [inputs[3], inputs[4]] 195 | 196 | outputs = decode_one_step(step_inputs, prev_states) 197 | topk_ids, topk_log_probs, curr_states = outputs[: 3] 198 | 199 | p_gens = FLAGS.beam_size * [None] 200 | p_attns = FLAGS.beam_size * [None] 201 | if FLAGS.use_pgen: 202 | p_gens, p_attns = outputs[3: 5] 203 | 204 | # Extend each hypothesis with newly decoded predictions. 205 | all_hyps = [] 206 | num_org_hyps = 1 if step == 0 else len(hyps) # In the first step, there is only 1 distinct hypothesis. 207 | for i in range(num_org_hyps): 208 | hyp, curr_state, p_gen, p_attn = hyps[i], curr_states[i], p_gens[i], p_attns[i] 209 | for j in range(2 * FLAGS.beam_size): 210 | new_hyp = hyp.extend(token=topk_ids[i][j], 211 | log_prob=topk_log_probs[i][j], 212 | state=curr_state, 213 | attn_dist=p_attn, 214 | p_gen=p_gen) 215 | all_hyps.append(new_hyp) 216 | 217 | hyps = [] # Top Bm hypothesis. 218 | for h in sort_hyps(all_hyps): 219 | if h.latest_token == vocab.word2id(params.STOP_DECODING): # If hypothesis ended. 220 | # Collect this sequence only if it is long enough. 221 | if step >= FLAGS.min_dec_len: 222 | results.append(h) 223 | else: # Use this hypothesis for next decoding step if not ended. 224 | hyps.append(h) 225 | 226 | # Stop if beam size is reached. 227 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size: 228 | break 229 | 230 | # If no complete hypothesis were collected, add all current hypothesis to the results. 231 | if len(results) == 0: 232 | results = hyps 233 | 234 | results = sort_hyps(results) # Return best hypothesis in the final beam. 235 | 236 | return results[0] 237 | -------------------------------------------------------------------------------- /codes/run_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Latest Command. 3 | # python run_model.py --sample=True --restore_type="all" --enc_steps=24 --dec_steps=12 --max_enc_sent=6 4 | # --max_enc_steps_per_sent=4 --max_dec_sent=4 --max_dec_steps_per_sent=3 --prob_dist="softmax" --rnn="LSTM" 5 | # --skip_rnn=False --use_comp_lstm=True --use_pos_emb=False --use_pgen=True --use_enc_coverage=True 6 | # --use_dec_coverage=False --use_reweight=False --vocab_size=20 --num_epochs=10 --val_every=1 --model="hier2" 7 | # --attn_type="ffn" --scaled=True --num_heads=1 --num_high=4 --num_high_heads=5 --num_high_iters=5 --mode="test" 8 | # --batch_size=3 --val_batch_size=3 --PathToCheckpoint="./my_nse_net/model_epoch10" 9 | 10 | __author__ = "Rajeev Bhatt Ambati" 11 | import tensorflow as tf 12 | from utils import Vocab, DataGenerator, DataGeneratorHier, eval_model 13 | from model import SummarizationModel 14 | from model_hier import SummarizationModelHier 15 | from model_hier_sc import SummarizationModelHierSC 16 | import os 17 | import random 18 | import numpy as np 19 | import time 20 | 21 | FLAGS = tf.app.flags.FLAGS 22 | 23 | # Parameters obtained from the input arguments. 24 | 25 | # Paths. 26 | tf.app.flags.DEFINE_string('PathToDataset', '../data/', 'Path to the datasets.') 27 | tf.app.flags.DEFINE_string('PathToGlove', '../data/glove.840B.300d.w2v.txt', 28 | 'Path to the pre-trained GloVe Word vectors.') 29 | tf.app.flags.DEFINE_string('PathToVocab', '../data/vocab.txt', 'Path to a vocabulary file if already stored. ' 30 | 'Otherwise, a new vocabulary file will be stored here.') 31 | tf.app.flags.DEFINE_string('PathToLookups', '../data/lookups.pkl', 'Path to the lookup tables .pkl file if already' 32 | 'stored. Otherwise, a new vocabulary file will be' 33 | 'stored here.') 34 | tf.app.flags.DEFINE_string('PathToResults', '../results/', 'Path to the test results.') 35 | tf.app.flags.DEFINE_string('PathToCheckpoint', './my_nse_net/hier_v14/model_epoch10', 36 | 'Trained model will be stored here.') 37 | tf.app.flags.DEFINE_boolean('sample', False, 'Sample debugging the code or training for long.') 38 | tf.app.flags.DEFINE_bool('permutate', False, 'Whether to permutate or truncate sequences.') 39 | tf.app.flags.DEFINE_boolean('chunk', True, 'Whether to use chunks in place of sentences to use maximum effective ' 40 | 'sequence lengths possible') 41 | tf.app.flags.DEFINE_string('PathToTB', 'log/', 'Tensorboard visualization directory.') 42 | tf.app.flags.DEFINE_boolean('restore_checkpoint', True, 'Boolean describing whether training has to be restored from ' 43 | 'a checkpoint or start fresh.') 44 | tf.app.flags.DEFINE_string('restore_type', "all", 'String describing whether coverage/momentum parameters has to be' 45 | 'restored or initialized.') 46 | 47 | # Plain model inputs. 48 | tf.app.flags.DEFINE_integer('enc_steps', 400, 'No. of time steps in the encoder.') 49 | tf.app.flags.DEFINE_integer('dec_steps', 100, 'No. of time steps in the decoder.') 50 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'Max. no of time steps in the decoder during decoding.') 51 | tf.app.flags.DEFINE_integer('min_dec_steps', 35, 'Minimum no. of tokens in a complete hypothesis while decoding.') 52 | 53 | # Hier model inputs. 54 | tf.app.flags.DEFINE_integer('max_enc_sent', 20, 'Max. no. of sentences in the encoder.') 55 | tf.app.flags.DEFINE_integer('max_enc_steps_per_sent', 20, 'Max. no. of tokens per a sentence in the encoder.') 56 | 57 | # Common flags. 58 | tf.app.flags.DEFINE_integer('num_layers', 1, "No. of layers in each LSTM.") 59 | tf.app.flags.DEFINE_boolean('use_comp_lstm', True, 'Whether to use an LSTM for compose function.') 60 | tf.app.flags.DEFINE_boolean('use_pgen', True, 'Flag whether pointer mechanism should be used.') 61 | tf.app.flags.DEFINE_boolean('use_pretrained', True, 'Flag whether pre-trained word-vectors has to be used.') 62 | 63 | # Common sizes. 64 | tf.app.flags.DEFINE_integer('batch_size', 60, 'No. of examples in a batch of training data.') 65 | tf.app.flags.DEFINE_integer('val_batch_size', 60, 'No. of examples in a batch of validation data.') 66 | tf.app.flags.DEFINE_integer('beam_size', 5, 'Beam size for beam search decoding.') 67 | tf.app.flags.DEFINE_integer('vocab_size', 50000, 'Size of the vocabulary.') 68 | tf.app.flags.DEFINE_integer('dim', 300, 'Dimension of the word embedding, it should be the dimension of the ' 69 | 'pre-trained word vectors used.') 70 | 71 | # Common optimizer flags. 72 | tf.app.flags.DEFINE_boolean('use_entropy', True, 'Whether to use entropy of sampling in loss.') 73 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'Maximum gradient norm when gradient clipping.') 74 | tf.app.flags.DEFINE_float('lr', 0.001, 'Learning rate.') 75 | 76 | # Common training flags. 77 | tf.app.flags.DEFINE_boolean('rouge_summary', True, 'A flag whether ROUGE has to be included in the summary.') 78 | tf.app.flags.DEFINE_integer('num_epochs', 20, 'No. of epochs to train the model.') 79 | tf.app.flags.DEFINE_integer('summary_every', 1, 'Write training summaries every few iterations.') 80 | tf.app.flags.DEFINE_integer('val_every', 4, 'No. of training epochs after which model should be validated.') 81 | tf.app.flags.DEFINE_string('mode', 'test', 'train/test mode.') 82 | tf.app.flags.DEFINE_string('model', 'hier', 'Which model to train: plain/hier/rlhier.') 83 | tf.app.flags.DEFINE_list('GPUs', [0], 'GPU ids to be used.') 84 | tf.app.flags.DEFINE_integer('num_pools', 5, 'No. of pools per GPU.') 85 | 86 | # Beam search decoding flags. 87 | tf.app.flags.DEFINE_integer('bs_enc_steps', 400, 'No. of time steps in the encoder.') 88 | tf.app.flags.DEFINE_integer('bs_dec_steps', 100, 'No. of time steps in the decoder.') 89 | 90 | # Hier model inputs. 91 | tf.app.flags.DEFINE_integer('bs_enc_sent', 20, 'Max. no. of sentences in the encoder.') 92 | tf.app.flags.DEFINE_integer('bs_enc_steps_per_sent', 20, 'Max. no. of tokens per a sentence in the encoder.') 93 | 94 | # Self critic policy gradients model. 95 | tf.app.flags.DEFINE_boolean('use_self_critic', False, 'Flag whether to use self critical model.') 96 | tf.app.flags.DEFINE_boolean('teacher_forcing', False, 'Flag whether to use teacher-forcing in greedy mode.') 97 | tf.app.flags.DEFINE_integer('num_samples', 1, 'No. of samples') 98 | tf.app.flags.DEFINE_boolean('use_discounted_rewards', False, 'Flag whether discounted rewards has to be used.') 99 | tf.app.flags.DEFINE_boolean('use_intermediate_rewards', False, 'Flag whether intermediate rewards has to be used.') 100 | tf.app.flags.DEFINE_float('gamma', 0.99, 'Discount Factor') 101 | tf.app.flags.DEFINE_float('eta', 2.5E-5, 'RL/MLE scaling factor.') 102 | tf.app.flags.DEFINE_float('eta1', 0.0, 'Cross-entropy weight.') 103 | tf.app.flags.DEFINE_float('eta2', 0.0, 'RL loss weight.') 104 | tf.app.flags.DEFINE_float('eta3', 1E-4, 'Entropy weight.') 105 | 106 | 107 | def main(args): 108 | main_start = time.time() 109 | 110 | tf.set_random_seed(2019) 111 | random.seed(2019) 112 | np.random.seed(2019) 113 | 114 | if len(args) != 1: 115 | raise Exception('Problem with flags: %s' % args) 116 | 117 | # Correcting a few flags for test/eval mode. 118 | if FLAGS.mode != 'train': 119 | FLAGS.batch_size = FLAGS.beam_size 120 | FLAGS.bs_dec_steps = FLAGS.dec_steps 121 | 122 | if FLAGS.model.lower() != "tx": 123 | FLAGS.dec_steps = 1 124 | 125 | assert FLAGS.mode == 'train' or FLAGS.batch_size == FLAGS.beam_size, \ 126 | "In test mode, batch size should be equal to beam size." 127 | 128 | assert FLAGS.mode == 'train' or FLAGS.dec_steps == 1 or FLAGS.model.lower() == "tx", \ 129 | "In test mode, no. of decoder steps should be one." 130 | 131 | os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0' 132 | os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(str(gpu_id) for gpu_id in FLAGS.GPUs) 133 | 134 | if not os.path.exists(FLAGS.PathToCheckpoint): 135 | os.makedirs(FLAGS.PathToCheckpoint) 136 | 137 | if FLAGS.mode == "test" and not os.path.exists(FLAGS.PathToResults): 138 | os.makedirs(FLAGS.PathToResults) 139 | os.makedirs(FLAGS.PathToResults + 'predictions') 140 | os.makedirs(FLAGS.PathToResults + 'groundtruths') 141 | 142 | if FLAGS.mode == 'eval': 143 | eval_model(FLAGS.PathToResults) 144 | else: 145 | start = time.time() 146 | vocab = Vocab(max_vocab_size=FLAGS.vocab_size, emb_dim=FLAGS.dim, dataset_path=FLAGS.PathToDataset, 147 | glove_path=FLAGS.PathToGlove, vocab_path=FLAGS.PathToVocab, lookup_path=FLAGS.PathToLookups) 148 | 149 | if FLAGS.model.lower() == "plain": 150 | print("Setting up the plain model.\n") 151 | data = DataGenerator(path_to_dataset=FLAGS.PathToDataset, max_inp_seq_len=FLAGS.enc_steps, 152 | max_out_seq_len=FLAGS.dec_steps, vocab=vocab, 153 | use_pgen=FLAGS.use_pgen, use_sample=FLAGS.sample) 154 | summarizer = SummarizationModel(vocab, data) 155 | 156 | elif FLAGS.model.lower() == "hier": 157 | print("Setting up the hier model.\n") 158 | data = DataGeneratorHier(path_to_dataset=FLAGS.PathToDataset, max_inp_sent=FLAGS.max_enc_sent, 159 | max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent, 160 | max_out_tok=FLAGS.dec_steps, vocab=vocab, 161 | use_pgen=FLAGS.use_pgen, use_sample=FLAGS.sample) 162 | summarizer = SummarizationModelHier(vocab, data) 163 | 164 | elif FLAGS.model.lower() == "rlhier": 165 | print("Setting up the Hier RL model.\n") 166 | data = DataGeneratorHier(path_to_dataset=FLAGS.PathToDataset, max_inp_sent=FLAGS.max_enc_sent, 167 | max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent, 168 | max_out_tok=FLAGS.dec_steps, vocab=vocab, 169 | use_pgen=FLAGS.use_pgen, use_sample=FLAGS.sample) 170 | summarizer = SummarizationModelHierSC(vocab, data) 171 | 172 | else: 173 | raise ValueError("model flag should be either of plain/hier/bayesian/shared!! \n") 174 | 175 | end = time.time() 176 | print("Setting up vocab, data and model took {:.2f} sec.".format(end - start)) 177 | 178 | summarizer.build_graph() 179 | 180 | if FLAGS.mode == 'train': 181 | summarizer.train() 182 | elif FLAGS.mode == "test": 183 | summarizer.test() 184 | else: 185 | raise ValueError("mode should be either train/test!! \n") 186 | 187 | main_end = time.time() 188 | print("Total time elapsed: %.2f \n" % (main_end - main_start)) 189 | 190 | 191 | if __name__ == '__main__': 192 | tf.app.run() 193 | -------------------------------------------------------------------------------- /codes/NSE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Rajeev Bhatt Ambati" 3 | 4 | import tensorflow as tf 5 | tf.set_random_seed(2019) 6 | 7 | 8 | def create_rnn_cell(rnn_size, scope): 9 | """ 10 | This function creates and returns an RNN cell. 11 | :param rnn_size: Size of the hidden state. 12 | :param scope: scope for the RNN variables. 13 | :return: returns the RNN cell with the necessary specifications. 14 | """ 15 | with tf.variable_scope(scope): 16 | cell = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units=rnn_size, reuse=tf.AUTO_REUSE) 17 | 18 | return cell 19 | 20 | 21 | class NSE: 22 | """ 23 | This is a Neural Semantic Encoder class. 24 | """ 25 | 26 | def __init__(self, batch_size, dim, dense_init, mode='train', use_comp_lstm=False): 27 | """ 28 | :param batch_size: No. of examples in a batch of data. 29 | :param dim: Dimension of the memories. (Same as the dimension of wordvecs). 30 | :param dense_init: Dense kernel initializer. 31 | :param mode: 'train/val/test' mode. 32 | :param use_comp_lstm: Flag if LSTM should be used for compose function or MLP. 33 | """ 34 | self._batch_size, self._dim = batch_size, dim 35 | 36 | self._dense_init = dense_init # Initializer. 37 | 38 | self._mode = mode 39 | self._use_comp_lstm = use_comp_lstm 40 | 41 | self._read_scope = 'read' 42 | self._write_scope = 'write' 43 | self._comp_scope = 'compose' 44 | 45 | # Read LSTM 46 | self._read_lstm = create_rnn_cell(self._dim, self._read_scope) 47 | 48 | # Compose LSTM 49 | if self._use_comp_lstm: 50 | self._comp_lstm = create_rnn_cell(2*self._dim, self._comp_scope) 51 | 52 | # Write LSTM 53 | self._write_lstm = create_rnn_cell(self._dim, self._write_scope) 54 | 55 | def read(self, x_t, state=None): 56 | """ 57 | This is the read function. 58 | :param x_t: input sequence x, Shape: B x D. 59 | :param state: Previous hidden state of read LSTM. 60 | :return: r_t: Outputs of the read LSTM, Shape: B x D 61 | """ 62 | with tf.variable_scope(self._read_scope, reuse=tf.AUTO_REUSE): 63 | r_t, state = self._read_lstm(x_t, state) 64 | 65 | return r_t, state 66 | 67 | def compose(self, r_t, z_t, m_t, state=None): 68 | """ 69 | This is the compose function. 70 | :param r_t: Read from the input x_t, Shape: B x D. 71 | :param z_t: Attention distribution, Shape: B x T. 72 | :param m_t: Memory at time step 't', Shape: B x T x D. 73 | :param state: Previous hidden state of compose LSTM. 74 | :return: m_rt: Retrieved memory, Shape: B x D 75 | c_t: The composed vector, Shape: B x D. 76 | state: Hidden state of compose LSTM after previous time step if using it. 77 | """ 78 | # z_t is repeated across dimension. 79 | z_t_rep = tf.tile(tf.expand_dims(z_t, axis=-1), multiples=[1, 1, self._dim]) # Shape: B x T x D. 80 | m_rt = tf.reduce_sum(tf.multiply(z_t_rep, m_t), axis=1) # Retrieved memory, Shape: B x D 81 | 82 | with tf.variable_scope(self._comp_scope, reuse=tf.AUTO_REUSE): 83 | r_m_t = tf.concat([r_t, m_rt], axis=-1) # Shape: B x (2*D) 84 | 85 | # Compose LSTM 86 | if self._use_comp_lstm: 87 | r_m_t, state = self._comp_lstm(r_m_t, state) 88 | 89 | # Dense layer to reduce size from (2*D) to D. 90 | c_t = tf.layers.dense(inputs=r_m_t, 91 | units=self._dim, 92 | activation=None, 93 | kernel_initializer=self._dense_init, 94 | name='MLP') # Composed vector, Shape: B x D 95 | c_t = tf.nn.relu(c_t) # Activation function 96 | 97 | return m_rt, c_t, state 98 | 99 | def write(self, c_t, state=None): 100 | """ 101 | This is the write function. 102 | :param c_t: The composed vector, Shape: B x D. 103 | :param state: Previous hidden state of write LSTM. 104 | :return: h_t: The write vector, Shape: B x D 105 | """ 106 | with tf.variable_scope(self._write_scope, reuse=tf.AUTO_REUSE): 107 | h_t, state = self._write_lstm(c_t, state) 108 | 109 | return h_t, state 110 | 111 | def attention(self, r_t, m_t, mem_mask): 112 | """ 113 | This function computes the attention distribution. 114 | :param r_t: Read from the input x_t, Shape: B x D. 115 | :param m_t: Memory at time step 't', Shape: B x T x D. 116 | :param mem_mask: A mask to indicate the presence of PAD tokens, Shape: B x T. 117 | :return: 118 | attn_dist: The attention distribution at the current time-step, Shape: B x T. 119 | """ 120 | # Shapes 121 | attn_len, attn_vec_size = m_t.get_shape().as_list()[1: 3] 122 | 123 | with tf.variable_scope("attention", reuse=tf.AUTO_REUSE): 124 | # Input features. 125 | with tf.variable_scope("input", reuse=tf.AUTO_REUSE): 126 | input_features = tf.layers.dense(inputs=r_t, 127 | units=attn_vec_size, 128 | activation=None, 129 | kernel_initializer=self._dense_init, 130 | reuse=tf.AUTO_REUSE, 131 | name='inp_dense') # Shape: B x D. 132 | 133 | input_features = tf.expand_dims(input_features, axis=1) # Shape: B x 1 x D. 134 | 135 | # Memory features. 136 | with tf.variable_scope("memory", reuse=tf.AUTO_REUSE): 137 | memory_features = tf.layers.dense(inputs=m_t, 138 | units=attn_vec_size, 139 | activation=None, 140 | kernel_initializer=self._dense_init, 141 | use_bias=False, 142 | reuse=tf.AUTO_REUSE, 143 | name="memory_dense") # Shape: B x T x D. 144 | 145 | v = tf.get_variable("v", [attn_vec_size]) 146 | 147 | scores = tf.reduce_sum( 148 | v * tf.tanh(input_features + memory_features), axis=2) # Shape: B x T. 149 | attn_dist = tf.nn.softmax(scores) # Shape: B x T. 150 | 151 | # Assigning zero probability to the PAD tokens. 152 | # Re-normalizing the probability distribution to sum to one. 153 | attn_dist = tf.multiply(attn_dist, mem_mask) 154 | masked_sums = tf.reduce_sum(attn_dist, axis=1, keepdims=True) # Shape: B x 1 155 | attn_dist = tf.truediv(attn_dist, masked_sums) # Re-normalization. 156 | 157 | return attn_dist 158 | 159 | @staticmethod 160 | def update(z_t, m_t, h_t): 161 | """ 162 | This function updates the memory with write vectors as per the retrieved slots. 163 | :param z_t: Retrieved attention distribution over memory, Shape: B x T. 164 | :param m_t: Memory at current time step 't', Shape: B x T x D. 165 | :param h_t: Write vector, Shape: B x D. 166 | :return: new_m: The updated memory, Shape: B x T x D. 167 | """ 168 | # Write and erase mask for sentence memories. 169 | write_mask = tf.expand_dims(z_t, axis=2) # Shape: B x T x 1. 170 | erase_mask = tf.ones_like(write_mask) - write_mask # Shape: B x T x 1. 171 | 172 | # Write tensor 173 | write_tensor = tf.expand_dims(h_t, axis=1) # Shape: B x 1 x D. 174 | 175 | # Updated memory. 176 | new_m = tf.add(tf.multiply(m_t, erase_mask), tf.multiply(write_tensor, write_mask)) 177 | 178 | return new_m 179 | 180 | def prob_gen(self, m_rt, h_t, r_t): 181 | """ 182 | This function calculates the generation probability from the retrieved memory, write vector and the 183 | input read. 184 | :param m_rt: Retrieved memory, Shape: B x D. 185 | :param h_t: Write vector, Shape: B x D. 186 | :param r_t: Read vector from the input, Shape: B x D. 187 | :return: p_gen, Shape: B x 1 188 | """ 189 | with tf.variable_scope('pgen', reuse=tf.AUTO_REUSE): 190 | inp = tf.concat([m_rt, h_t, r_t], axis=-1) # Shape: B x (3*D) 191 | p_gen = tf.layers.dense(inp, 192 | units=1, 193 | activation=None, 194 | kernel_initializer=self._dense_init, 195 | name='pgen_dense') # Shape: B x 1 196 | 197 | p_gen = tf.nn.sigmoid(p_gen) # Sigmoid. 198 | 199 | return p_gen 200 | 201 | def step(self, x_t, mem_mask, prev_state, use_pgen=False): 202 | """ 203 | This function performs one-step of NSE. 204 | :param x_t: Input in the current time-step, Shape: B x D. 205 | :param mem_mask: Memory mask, Shape: B x T. 206 | :param prev_state: Internal state of NSE after the previous time step. 207 | [0] memory: The NSE memory, Shape: B x T x D. 208 | [1] read_state: Hidden state of the read LSTM, Shape: B x D. 209 | [2] write_state: Hidden state of the write LSTM, Shape: B x D. 210 | [3] comp_state: Hidden state of compose LSTM, Shape: B x (2*D). 211 | :param use_pgen: Flag whether pointer mechanism has to be used. 212 | :return: 213 | outputs: The outputs after the current time step. 214 | [0]: write vector: The written vector to NSE memory, Shape: B x D. 215 | [1]: p_attn: Attention distribution, Shape: B x T_in. 216 | Following additional output in pointer generator mode: 217 | [2]: p_gen: Generation probability, Shape: B x 1. 218 | state: Internal state of NSE after current time step. 219 | [memory, read_state, write_state, comp_state] 220 | """ 221 | prev_comp_state = None 222 | prev_memory, prev_read_state, prev_write_state = prev_state[: 3] 223 | if self._use_comp_lstm: 224 | prev_comp_state = prev_state[3] 225 | 226 | if prev_read_state is None: # zero state of read LSTM for the first time step. 227 | prev_read_state = self._read_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32) 228 | 229 | if prev_write_state is None: # zero state of write LSTM for the first time step. 230 | prev_write_state = self._write_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32) 231 | 232 | if (prev_comp_state is None) and self._use_comp_lstm: # zero state of compose LSTM for the first time step. 233 | prev_comp_state = self._comp_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32) 234 | 235 | r_t, curr_read_state = self.read(x_t, prev_read_state) # Read step. 236 | z_t = self.attention(r_t, prev_memory, mem_mask) # Attention. 237 | m_rt, c_t, curr_comp_state = self.compose(r_t, z_t, prev_memory, prev_comp_state) # Compose step. 238 | h_t, curr_write_state = self.write(c_t, prev_write_state) # Write step. 239 | curr_memory = self.update(z_t, prev_memory, h_t) # Memory update step. 240 | 241 | curr_state = [curr_memory, curr_read_state, curr_write_state] # Current NSE states. 242 | if self._use_comp_lstm: 243 | curr_state.append(curr_comp_state) 244 | 245 | outputs = [h_t, z_t] # Outputs after the current step. 246 | if use_pgen: # Pointer generator mode. 247 | p_gen = self.prob_gen(m_rt, h_t, r_t) 248 | outputs.append(p_gen) 249 | 250 | return outputs, curr_state 251 | -------------------------------------------------------------------------------- /codes/HierNSE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Rajeev Bhatt Ambati" 3 | 4 | import tensorflow as tf 5 | tf.set_random_seed(2018) 6 | 7 | 8 | def create_rnn_cell(rnn_size, num_layers, scope): 9 | """ 10 | This function creates and returns an RNN cell. 11 | :param rnn_size: Size of the hidden state. 12 | :param num_layers: No. of layers if using a multi_rnn cell. 13 | :param scope: scope for the RNN variables. 14 | :return: returns the RNN cell with the necessary specifications. 15 | """ 16 | with tf.variable_scope(scope): 17 | layers = [] 18 | for _ in range(num_layers): 19 | cell = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units=rnn_size, reuse=tf.AUTO_REUSE) 20 | layers.append(cell) 21 | 22 | if num_layers > 1: 23 | return tf.nn.rnn_cell.MultiRNNCell(layers) 24 | else: 25 | return layers[0] 26 | 27 | 28 | class HierNSE: 29 | """ 30 | This is a Hierarchical Neural Semantic Encoder class. 31 | """ 32 | 33 | def __init__(self, batch_size, dim, dense_init, mode='train', use_comp_lstm=False, num_layers=1): 34 | """ 35 | :param batch_size: No. of examples in a batch of data. 36 | :param dim: Dimension of the memories. (Same as the dimension of wordvecs). 37 | :param dense_init: Dense kernel initializer. 38 | :param mode: 'train/val/test' mode. 39 | :param use_comp_lstm: Flag if LSTM should be used for compose function or MLP. 40 | :param num_layers: Number of layers in the RNN cell. 41 | """ 42 | self._batch_size, self._dim = batch_size, dim 43 | 44 | self._dense_init = dense_init # Initializer. 45 | 46 | self._mode = mode 47 | self._use_comp_lstm = use_comp_lstm 48 | 49 | self._read_scope = 'read' 50 | self._write_scope = 'write' 51 | self._comp_scope = 'compose' 52 | 53 | # Read LSTM 54 | self._read_lstm = create_rnn_cell(self._dim, num_layers, self._read_scope) 55 | 56 | # Compose LSTM 57 | if self._use_comp_lstm: 58 | self._comp_lstm = create_rnn_cell(3*self._dim, num_layers, self._comp_scope) 59 | 60 | # Write LSTM 61 | self._write_lstm = create_rnn_cell(self._dim, num_layers, self._write_scope) 62 | 63 | def read(self, x_t, state=None): 64 | """ 65 | This is the read function. 66 | :param x_t: input sequence x, Shape: B x D. 67 | :param state: Previous hidden state of read LSTM. 68 | :return: r_t: Outputs of the read LSTM, Shape: B x D 69 | """ 70 | with tf.variable_scope(self._read_scope, reuse=tf.AUTO_REUSE): 71 | r_t, state = self._read_lstm(x_t, state) 72 | 73 | return r_t, state 74 | 75 | def compose(self, r_t, zs_t, ms_t, zd_t, md_t, state=None): 76 | """ 77 | This is the compose function. 78 | :param r_t: Read from the input x_t, Shape: B x D. 79 | :param zs_t: Attention distribution over sentence memory, Shape: B x T. 80 | :param ms_t: Sentence memory at time step 't', Shape: B x T x D. 81 | :param zd_t: Attention distribution over document memory, Shape: B x S. 82 | :param md_t: Document memory at time step 't', Shape: B x S x D. 83 | :param state: Previous hidden state of compose LSTM. 84 | :return: ms_rt: Retrieved sentence memory, Shape: B x D. 85 | md_rt: Retrieved document memory, Shape: B x D. 86 | c_t: The composed vector, Shape: B x D. 87 | state: Hidden state of compose LSTM after previous time step if using it. 88 | """ 89 | # Retrieved memories. 90 | ms_rt = tf.squeeze(tf.matmul(tf.expand_dims(zs_t, axis=1), ms_t), axis=1) # Sentence memory, Shape: B x D. 91 | md_rt = tf.squeeze(tf.matmul(tf.expand_dims(zd_t, axis=1), md_t), axis=1) # Document memory, Shape: B x D. 92 | 93 | with tf.variable_scope(self._comp_scope, reuse=tf.AUTO_REUSE): 94 | r_m_t = tf.concat([r_t, ms_rt, md_rt], axis=-1) # B x (3*D). 95 | 96 | # Compose LSTM 97 | if self._use_comp_lstm: 98 | r_m_t, state = self._comp_lstm(r_m_t, state) 99 | 100 | # Dense layer to reduce size from 3*D to D. 101 | c_t = tf.layers.dense(inputs=r_m_t, 102 | units=self._dim, 103 | activation=None, 104 | kernel_initializer=self._dense_init, 105 | name='MLP') # Composed vector, Shape: B x D. 106 | c_t = tf.nn.relu(c_t) # Activation Function. 107 | 108 | return ms_rt, md_rt, c_t, state 109 | 110 | def write(self, c_t, state=None): 111 | """ 112 | This function implements the write operation - equation 5 from the paper [1]. 113 | :param c_t: The composed vector, Shape: B x D. 114 | :param state: Previous hidden state of write LSTM. 115 | :return: h_t: The write vector, Shape: B x D 116 | """ 117 | with tf.variable_scope(self._write_scope, reuse=tf.AUTO_REUSE): 118 | h_t, state = self._write_lstm(c_t, state) 119 | 120 | return h_t, state 121 | 122 | def attention(self, r_t, m_t, mem_mask, scope="attention"): 123 | """ 124 | This function computes the attention distribution. 125 | :param r_t: Read from the input x_t, Shape: B x D. 126 | :param m_t: Memory at time step 't', Shape: B x T' x D. 127 | T' is T for sentence memory and S for document memory. 128 | :param mem_mask: A mask to indicate the presence of PAD tokens, Shape: B x T'. 129 | :param scope: Name of the scope. 130 | :return: 131 | """ 132 | # Shapes 133 | batch_size, attn_len, attn_vec_size = m_t.get_shape().as_list() 134 | 135 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 136 | # Input features. 137 | with tf.variable_scope("input"): 138 | input_features = tf.layers.dense(inputs=r_t, 139 | units=attn_vec_size, 140 | activation=None, 141 | kernel_initializer=self._dense_init, 142 | reuse=tf.AUTO_REUSE, 143 | name="inp_dense") # Shape: B x D. 144 | input_features = tf.expand_dims(input_features, axis=1) # Shape: B x 1 x D. 145 | 146 | # Memory features 147 | with tf.variable_scope("memory"): 148 | memory_features = tf.layers.dense(inputs=m_t, 149 | units=attn_vec_size, 150 | activation=None, 151 | kernel_initializer=self._dense_init, 152 | use_bias=False, 153 | reuse=tf.AUTO_REUSE, 154 | name="memory_dense") # Shape: B x T x D. 155 | 156 | v = tf.get_variable("v", [attn_vec_size]) 157 | 158 | scores = tf.reduce_sum( 159 | v * tf.tanh(input_features + memory_features), axis=2) # Shape: B x T'. 160 | attn_dist = tf.nn.softmax(scores) # Shape: B x T' 161 | 162 | # Assigning zero probability to the PAD tokens. 163 | # Re-normalizing the probability distribution to sum to one. 164 | attn_dist = tf.multiply(attn_dist, mem_mask) 165 | masked_sums = tf.reduce_sum(attn_dist, axis=1, keepdims=True) # Shape: B x 1 166 | attn_dist = tf.truediv(attn_dist, masked_sums) # Re-normalization. 167 | 168 | return attn_dist 169 | 170 | @ staticmethod 171 | def update(zs_t, ms_t, zd_t, md_t, h_t): 172 | """ 173 | This function updates the sentence and document memories as per the retrieved slots. 174 | :param zs_t: Retrieved attention distribution over sentence memory, Shape: B x T. 175 | :param ms_t: Sentence memory at time step 't', Shape: B x T x D. 176 | :param zd_t: Retrieved attention distribution over document memory, Shape: B x S. 177 | :param md_t: Document memory at time step 't', Shape: B x S x D. 178 | :param h_t: Write vector, Shape: B x D. 179 | :return: new_ms: Updated sentence memory, Shape: B x T x D. 180 | new_md: Updated document memory, Shape: B x S x D. 181 | """ 182 | 183 | # Write and erase masks for sentence memories. 184 | write_mask_s = tf.expand_dims(zs_t, axis=2) # Shape: B x T x 1. 185 | erase_mask_s = tf.ones_like(write_mask_s) - write_mask_s # Shape: B x T x 1. 186 | 187 | # Write and erase masks for document memories. 188 | write_mask_d = tf.expand_dims(zd_t, axis=2) # Shape: B x S x 1. 189 | erase_mask_d = tf.ones_like(write_mask_d) - write_mask_d # Shape: B x S x 1. 190 | 191 | # Write tensors for sentence and document memories. 192 | write_tensor_s = tf.expand_dims(h_t, axis=1) # Shape: B x 1 x D. 193 | write_tensor_d = tf.expand_dims(h_t, axis=1) # Shape: B x 1 x D. 194 | 195 | new_ms = tf.add(tf.multiply(ms_t, erase_mask_s), tf.multiply(write_tensor_s, write_mask_s)) 196 | new_md = tf.add(tf.multiply(md_t, erase_mask_d), tf.multiply(write_tensor_d, write_mask_d)) 197 | 198 | return new_ms, new_md 199 | 200 | def prob_gen(self, ms_rt, md_rt, h_t, r_t): 201 | """ 202 | This function calculates the generation probability from the retrieved sentence, document memories, write 203 | vector and the input read. 204 | :param ms_rt: Retrieved sentence memory, Shape: B x D. 205 | :param md_rt: Retrieved document memory, Shape: B x D. 206 | :param h_t: Write vector, Shape: B x D 207 | :param r_t: Read vector from the input, Shape: B x D. 208 | :return: p_gen, Shape: B x 1. 209 | """ 210 | with tf.variable_scope('pgen', reuse=tf.AUTO_REUSE): 211 | inp = tf.concat([ms_rt, md_rt, h_t, r_t], axis=1) # Shape: B x (4*D). 212 | p_gen = tf.layers.dense(inp, 213 | units=1, 214 | activation=None, 215 | kernel_initializer=self._dense_init, 216 | name='pgen_dense') # Dense Layer, Shape: B x 1. 217 | p_gen = tf.nn.sigmoid(p_gen) # Sigmoid. 218 | 219 | return p_gen 220 | 221 | def step(self, x_t, mem_masks, prev_state, use_pgen=False): 222 | """ 223 | This function performs one-step of Hier-NSE. 224 | :param x_t: Input in the current time-step. 225 | :param mem_masks: [mem_mask_s, mem_mask_d]: Masks for sentence and document memory respectively 226 | indicating the presence of PAD tokens, Shape: [B x T, B x S]. 227 | :param prev_state: Internal state of the NSE after the previous time-step. 228 | [0]: [memory_s, memory_d]: The NSE sentence and document memory 229 | respectively, Shape: [B x T x D, B x S x D]. 230 | [1]: read_state: Hidden state of the Read LSTM, Shape: B x D. 231 | [2]: write_state: Hidden state of the write LSTM, Shape: B x D. 232 | [3]: comp_state: Hidden state of the compose LSTM, Shape: B x (3*D). 233 | :param use_pgen: Flag whether pointer mechanism has to be used. 234 | :return: 235 | outputs: The following outputs after the current time step. 236 | [0]: write vector: The written vector to NSE memory, Shape: B x D. 237 | [1]: p_attn: [zs_t, zd_t] Attention distribution for sentence and 238 | document memories respectively, Shape: [B x T, B x S]. 239 | Following additional output in pointer generator mode. 240 | [2]: p_gen: Generation probability, Shape: B x 1. 241 | Following additional output when using coverage. 242 | [3]: [curr_cov_s, curr_cov_d]: Updated coverages for sentence and document memory respectively. 243 | state: The internal state of NSE after the current time-step. 244 | [[memory_s, memory_d], read_state, write_state, comp_state]. 245 | """ 246 | # Memory masks 247 | mem_mask_s, mem_mask_d = mem_masks 248 | 249 | # NSE Internal states. 250 | prev_comp_state = None 251 | [prev_memory_s, prev_memory_d], prev_read_state, prev_write_state = prev_state[: 3] 252 | if self._use_comp_lstm: 253 | prev_comp_state = prev_state[3] 254 | 255 | if prev_read_state is None: # zero state of read LSTM for the first time step. 256 | prev_read_state = self._read_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32) 257 | 258 | if prev_write_state is None: # zero state of write LSTM for the first time step. 259 | prev_write_state = self._write_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32) 260 | 261 | if (prev_comp_state is None) and self._use_comp_lstm: # zero state of compose LSTM for the first time step. 262 | prev_comp_state = self._comp_lstm.zero_state(batch_size=self._batch_size, dtype=tf.float32) 263 | 264 | r_t, curr_read_state = self.read(x_t, prev_read_state) # Read step. 265 | zs_t = self.attention( 266 | r_t=r_t, m_t=prev_memory_s, mem_mask=mem_mask_s 267 | ) # Sentence attention distribution. 268 | zd_t = self.attention( 269 | r_t=r_t, m_t=prev_memory_d, mem_mask=mem_mask_d 270 | ) # Document attention distribution. 271 | 272 | ms_rt, md_rt, c_t, curr_comp_state = self.compose( 273 | r_t, zs_t, prev_memory_s, zd_t, prev_memory_d, prev_comp_state 274 | ) # Compose step. 275 | h_t, curr_write_state = self.write(c_t, prev_write_state) # Write step. 276 | curr_memory_s, curr_memory_d = self.update( 277 | zs_t, prev_memory_s, zd_t, prev_memory_d, h_t 278 | ) # Update step. 279 | 280 | curr_state = [[curr_memory_s, curr_memory_d], curr_read_state, curr_write_state] # Current NSE states. 281 | if self._use_comp_lstm: 282 | curr_state.append(curr_comp_state) 283 | 284 | outputs = [h_t, [zs_t, zd_t]] 285 | if use_pgen: 286 | p_gen = self.prob_gen(ms_rt, md_rt, h_t, r_t) # Generation probability. 287 | outputs.append(p_gen) 288 | 289 | return outputs, curr_state 290 | -------------------------------------------------------------------------------- /codes/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Rajeev Bhatt Ambati" 3 | """ 4 | Acknowledgement: Most of the functions related to vocab object are either inspired/took from the 5 | Pointer Generator Network repository published by See. et al 2017. 6 | GitHub link: https://github.com/abisee/pointer-generator/blob/master/data.py 7 | """ 8 | 9 | import params 10 | import os 11 | import hashlib 12 | from shutil import copyfile 13 | import math 14 | import collections 15 | import random 16 | from random import shuffle, randint, sample 17 | from gensim.models import KeyedVectors 18 | 19 | import pickle 20 | import numpy as np 21 | import tensorflow as tf 22 | # import pyrouge 23 | 24 | # Setting-up seeds 25 | random.seed(2019) 26 | np.random.seed(2019) 27 | tf.set_random_seed(2019) 28 | 29 | 30 | def read_text_file(text_file): 31 | 32 | lines = [] 33 | with open(text_file, "r") as f: 34 | for line in f: 35 | lines.append(line.strip()) 36 | 37 | return lines 38 | 39 | 40 | def hashhex(s): 41 | """ 42 | Returns a heximal formatted SHA1 hash of the input string. 43 | """ 44 | h = hashlib.sha1() 45 | h.update(s.encode('utf-8')) 46 | 47 | return h.hexdigest() 48 | 49 | 50 | def get_url_hashes(url_list): 51 | return [hashhex(url) for url in url_list] 52 | 53 | 54 | def create_set(path_to_data, split): 55 | if os.path.exists(path_to_data + split): 56 | return 57 | else: 58 | os.makedirs(path_to_data + split) 59 | 60 | url_list = read_text_file(path_to_data + "all_" + split + ".txt") 61 | url_hashes = get_url_hashes(url_list) 62 | story_fnames = [s + ".story" for s in url_hashes] 63 | 64 | print("{}:".format(split)) 65 | for i, s in enumerate(story_fnames): 66 | if os.path.isfile(path_to_data + "cnn_stories_tokenized/" + s): 67 | src = path_to_data + "cnn_stories_tokenized/" + s 68 | 69 | elif os.path.isfile(path_to_data + "dm_stories_tokenized/" + s): 70 | src = path_to_data + "dm_stories_tokenized/" + s 71 | 72 | else: 73 | raise Exception("Story file {} is not found!".format(s)) 74 | 75 | trg = path_to_data + split + "/" + s 76 | 77 | copyfile(src, trg) 78 | print("files done: {}/{}".format(i, len(story_fnames))) 79 | 80 | 81 | def create_train_val_test(path_to_data="../data/"): 82 | """ 83 | :param path_to_data: Path to data. 84 | This function creates train, test and validation sets. 85 | :return: 86 | """ 87 | # Train. 88 | create_set(path_to_data, "train") 89 | 90 | # Validation. 91 | create_set(path_to_data, "val") 92 | 93 | # Test. 94 | create_set(path_to_data, "test") 95 | 96 | 97 | def fix_missing_period(line): 98 | if "@highlight" in line: 99 | return line 100 | 101 | if line == "": 102 | return line 103 | 104 | if line[-1] in params.END_TOKENS: 105 | return line[: -1] + " " + line[-1] 106 | 107 | return line + " ." 108 | 109 | 110 | def get_article_summary(story_file, art_format="tokens", sum_format="tokens"): 111 | # Read the story file. 112 | f = open(story_file, 'r', encoding='ISO-8859-1') 113 | 114 | # Lowercase everything. 115 | lines = [line.strip().lower() for line in f.readlines()] 116 | 117 | # Some lines don't have periods in the end. Correct'em. 118 | lines = [fix_missing_period(line) for line in lines] 119 | 120 | # Separate article and abstract sentences. 121 | article_lines = [] 122 | highlights = [] 123 | next_is_highlight = False 124 | 125 | for idx, line in enumerate(lines): 126 | 127 | if line == "": # Empty line. 128 | continue 129 | 130 | elif line.startswith("@highlight"): 131 | next_is_highlight = True 132 | 133 | elif next_is_highlight: 134 | highlights.append(line) 135 | 136 | else: 137 | article_lines.append(line) 138 | 139 | # Joining all article sentences together into a string. 140 | article = ' '.join(article_lines) 141 | 142 | # Joining all highlights together into a string including sentence 143 | # start and end . 144 | summary = ' '.join(["%s %s %s" % (params.SENTENCE_START, sent, 145 | params.SENTENCE_END) for sent in highlights]) 146 | 147 | if art_format == "tokens": 148 | 149 | if sum_format == "tokens": 150 | return article, summary 151 | elif sum_format == "sentences": 152 | return article, highlights 153 | else: 154 | raise ValueError("Article format should be either tokens or sentences!! \n") 155 | 156 | elif art_format == "sentences": 157 | if sum_format == "tokens": 158 | return article_lines, summary 159 | elif sum_format == "sentences": 160 | return article_lines, highlights 161 | else: 162 | raise ValueError("Article format should be either tokens or sentences!! \n") 163 | else: 164 | raise ValueError("Article format should be either tokens or sentences!! \n") 165 | 166 | 167 | def article2ids(article_words, vocab): 168 | """ 169 | This function converts given article words to ID's 170 | :param article_words: article tokens. 171 | :param vocab: The vocabulary object used for lookup tables, vocabulary etc. 172 | :return: The corresponding ID's and a list of OOV tokens. 173 | """ 174 | ids = [] 175 | oovs = [] 176 | unk_id = vocab.word2id(params.UNKNOWN_TOKEN) 177 | for word in article_words: 178 | i = vocab.word2id(word) 179 | if i == unk_id: # Out of vocabulary words. 180 | if word in oovs: 181 | ids.append(vocab.size() + oovs.index(word)) 182 | else: 183 | oovs.append(word) 184 | ids.append(vocab.size() + oovs.index(word)) 185 | else: # In vocabulary words. 186 | ids.append(i) 187 | 188 | return ids, oovs 189 | 190 | 191 | def summary2ids(summary_words, vocab, article_oovs): 192 | """ 193 | This function converts the given summary words to ID's 194 | :param summary_words: summary tokens. 195 | :param vocab: The vocabulary object used for lookup tables, vocabulary etc. 196 | :param article_oovs: OOV tokens in the input article. 197 | :return: The corresponding ID's. 198 | """ 199 | num_sent, sent_len = 0, 0 200 | if type(summary_words[0]) is list: 201 | num_sent = len(summary_words) 202 | sent_len = len(summary_words[0]) 203 | 204 | cum_words = [] 205 | for _, sent_words in enumerate(summary_words): 206 | cum_words += sent_words 207 | 208 | summary_words = cum_words 209 | 210 | ids = [] 211 | unk_id = vocab.word2id(params.UNKNOWN_TOKEN) 212 | for word in summary_words: 213 | i = vocab.word2id(word) 214 | if i == unk_id: # Out of vocabulary words. 215 | if word in article_oovs: # In article OOV words. 216 | ids.append(vocab.size() + article_oovs.index(word)) 217 | else: # Both OOV and article OOV words. 218 | ids.append(unk_id) 219 | else: # In vocabulary words. 220 | ids.append(i) 221 | 222 | if num_sent != 0: 223 | doc_ids = [] 224 | for i in range(num_sent): 225 | doc_ids.append(ids[sent_len * i: sent_len * (i + 1)]) 226 | 227 | ids = doc_ids 228 | 229 | return ids 230 | 231 | 232 | class Vocab(object): 233 | def __init__(self, max_vocab_size, emb_dim=300, dataset_path='../data/', glove_path='../data/glove.840B.300d.txt', 234 | vocab_path='../data/vocab.txt', lookup_path='../data/lookup.pkl'): 235 | 236 | self.max_size = max_vocab_size 237 | self._dim = emb_dim 238 | self.PathToGloveFile = glove_path 239 | self.PathToVocabFile = vocab_path 240 | self.PathToLookups = lookup_path 241 | 242 | create_train_val_test(dataset_path) 243 | 244 | stories = os.listdir(dataset_path + 'train') # Using only train files for Vocab, 245 | 246 | # All train Stories. 247 | self._story_files = [ 248 | os.path.join(dataset_path + 'train/', s) for s in stories] 249 | 250 | self.vocab = [] # Vocabulary 251 | 252 | # Create the vocab file. 253 | self.create_total_vocab() 254 | 255 | # Create the lookup tables. 256 | self.wvecs = [] # Word vectors. 257 | self._word_to_id = {} # word to ID's lookups 258 | self._id_to_word = {} # ID to word lookups 259 | 260 | self.create_lookup_tables() 261 | 262 | assert len(self._word_to_id.keys()) == len(self._id_to_word.keys()), "Both lookups should have same size." 263 | 264 | def size(self): 265 | return len(self.vocab) 266 | 267 | def word2id(self, word): 268 | """ 269 | This function returns the vocabulary ID for word if it is present. Otherwise, returns the ID 270 | for the unknown token. 271 | :param word: input word. 272 | :return: returns the ID. 273 | """ 274 | if word in self._word_to_id: 275 | return self._word_to_id[word] 276 | else: 277 | return self._word_to_id[params.UNKNOWN_TOKEN] 278 | 279 | def id2word(self, word_id): 280 | """ 281 | This function returns the corresponding word for a given vocabulary ID. 282 | :param word_id: input ID. 283 | :return: returns the word. 284 | """ 285 | if word_id in self._id_to_word: 286 | return self._id_to_word[word_id] 287 | else: 288 | raise ValueError("{} is not a valid ID.\n".format(word_id)) 289 | 290 | def create_total_vocab(self): 291 | 292 | if os.path.isfile(self.PathToVocabFile): 293 | print("Vocab file exists! \n") 294 | 295 | vocab_f = open(self.PathToVocabFile, 'r') 296 | for line in vocab_f: 297 | word = line.split()[0] 298 | self.vocab.append(word) 299 | 300 | return 301 | else: 302 | print("Vocab file NOT found!! \n") 303 | print("Creating a vocab file! \n") 304 | 305 | vocab_counter = collections.Counter() 306 | 307 | for idx, story in enumerate(self._story_files): 308 | article, summary = get_article_summary(story) 309 | 310 | art_tokens = article.split(' ') # Article tokens. 311 | sum_tokens = summary.split(' ') # Summary tokens. 312 | sum_tokens = [t for t in sum_tokens if t not in # Removing , tokens. 313 | [params.SENTENCE_START, params.SENTENCE_END]] 314 | 315 | assert (params.SENTENCE_START not in sum_tokens) and (params.SENTENCE_END not in sum_tokens), \ 316 | " and shouldn't be present in sum_tokens" 317 | 318 | tokens = art_tokens + sum_tokens 319 | tokens = [t.strip() for t in tokens] 320 | tokens = [t for t in tokens if t != ''] # Removing empty tokens. 321 | 322 | vocab_counter.update(tokens) # Keeping a count of the tokens. 323 | 324 | print("\r{}/{} files read!".format(idx + 1, len(self._story_files))) 325 | 326 | print("\n Writing the vocab file! \n") 327 | f = open(self.PathToVocabFile, 'w', encoding='utf-8') 328 | 329 | for word, count in vocab_counter.most_common(params.VOCAB_SIZE): 330 | f.write(word + ' ' + str(count) + '\n') 331 | self.vocab.append(word) 332 | 333 | f.close() 334 | 335 | def create_small_vocab(self): 336 | """ 337 | This function selects a few words out of the total vocabulary. 338 | """ 339 | 340 | # Read the vocab file and assign id's to each word till the max_size. 341 | vocab_f = open(self.PathToVocabFile, 'r') 342 | 343 | for line in vocab_f: 344 | word = line.split()[0] 345 | 346 | if word in [params.SENTENCE_START, params.SENTENCE_END, params.UNKNOWN_TOKEN, 347 | params.PAD_TOKEN, params.START_DECODING, params.STOP_DECODING]: 348 | raise Exception(', , [UNK], [PAD], [START], \ 349 | [STOP] shouldn\'t be in the vocab file, but %s is' % word) 350 | 351 | self.vocab.append(word) 352 | print("\r{}/{} words created!".format(len(self.vocab), self.max_size)) 353 | 354 | if len(self.vocab) == self.max_size: 355 | print("\n Max size of the vocabulary reached! Stopping reading! \n") 356 | break 357 | 358 | def create_lookup_tables_(self): 359 | """ 360 | This function creates lookup tables for word vectors, word to IDs 361 | and ID to words. First max_size words from GloVe that are also found in the small vocab are used 362 | to create the lookup tables. 363 | """ 364 | 365 | if os.path.isfile(self.PathToLookups): 366 | print('\n Lookup tables found :) \n') 367 | f = open(self.PathToLookups, 'rb') 368 | data = pickle.load(f) 369 | 370 | self._word_to_id = data['word2id'] 371 | self._id_to_word = data['id2word'] 372 | self.wvecs = data['wvecs'] 373 | self.vocab = list(self._word_to_id.keys()) 374 | 375 | print('Lookup tables collected for {} tokens.\n'.format(len(self.vocab))) 376 | return 377 | else: 378 | print('\n Lookup files NOT found!! \n') 379 | print('\n Creating the lookup tables! \n') 380 | 381 | self.create_small_vocab() # Creating a small vocabulary. 382 | self.wvecs = [] # Word vectors. 383 | 384 | glove_f = open(self.PathToGloveFile, 'r', encoding='utf8') 385 | count = 0 386 | 387 | # [UNK], [PAD], [START] and [STOP] get ids 0, 1, 2, 3. 388 | for w in [params.UNKNOWN_TOKEN, params.PAD_TOKEN, params.START_DECODING, params.STOP_DECODING]: 389 | self._word_to_id[w] = count 390 | self._id_to_word[count] = w 391 | self.wvecs.append(np.random.uniform(-0.1, 0.1, (self._dim,)).astype(np.float32)) 392 | count += 1 393 | 394 | print("\r Created tables for {}".format(w)) 395 | 396 | for line in glove_f: 397 | vals = line.rstrip().split(' ') 398 | w = vals[0] 399 | vec = np.array(vals[1:]).astype(np.float32) 400 | 401 | if w in self.vocab: 402 | self._word_to_id[w] = count 403 | self._id_to_word[count] = w 404 | self.wvecs.append(vec) 405 | count += 1 406 | 407 | print("\r Created tables for {}".format(w)) 408 | 409 | if count == self.max_size: 410 | print("\r Maximum vocab size reached! \n") 411 | break 412 | 413 | print("\n Lookup tables created for {} tokens. \n".format(count)) 414 | 415 | self.wvecs = np.array(self.wvecs).astype(np.float32) # Converting to a Numpy array. 416 | self.vocab = list(self._word_to_id.keys()) # Adjusting the vocabulary to found pre-trained vectors. 417 | 418 | # Saving the lookup tables. 419 | f = open(self.PathToLookups, 'wb') 420 | data = {'word2id': self._word_to_id, 421 | 'id2word': self._id_to_word, 422 | 'wvecs': self.wvecs} 423 | pickle.dump(data, f) 424 | 425 | def create_lookup_tables(self): 426 | """ 427 | This function creates lookup tables for word vectors, word to IDs 428 | and ID to words. First max_size words from GloVe that are also found in the small vocab are used 429 | to create the lookup tables. 430 | """ 431 | 432 | if os.path.isfile(self.PathToLookups): 433 | print('\n Lookup tables found :) \n') 434 | f = open(self.PathToLookups, 'rb') 435 | data = pickle.load(f) 436 | 437 | self._word_to_id = data['word2id'] 438 | self._id_to_word = data['id2word'] 439 | self.wvecs = data['wvecs'] 440 | self.vocab = list(self._word_to_id.keys()) 441 | 442 | print('Lookup tables collected for {} tokens.\n'.format(len(self.vocab))) 443 | return 444 | else: 445 | print('\n Lookup files NOT found!! \n') 446 | print('\n Creating the lookup tables! \n') 447 | 448 | self.wvecs = [] # Word vectors. 449 | 450 | word2vec = KeyedVectors.load_word2vec_format(self.PathToGloveFile, binary=False) 451 | 452 | count = 0 453 | # [UNK], [PAD], [START] and [STOP] get ids 0, 1, 2, 3. 454 | for w in [params.UNKNOWN_TOKEN, params.PAD_TOKEN, params.START_DECODING, params.STOP_DECODING]: 455 | self._word_to_id[w] = count 456 | self._id_to_word[count] = w 457 | self.wvecs.append(np.random.uniform(-0.1, 0.1, (self._dim,)).astype(np.float32)) 458 | count += 1 459 | 460 | print("\r Created tables for {}".format(w)) 461 | 462 | vocab_f = open(self.PathToVocabFile, "r") 463 | for line in vocab_f: 464 | word = line.split()[0] 465 | 466 | if word in word2vec: 467 | self._word_to_id[word] = count 468 | self._id_to_word[count] = word 469 | self.wvecs.append(word2vec[word]) 470 | count += 1 471 | 472 | print("\r Created tables for {}".format(word)) 473 | 474 | if count == self.max_size: 475 | print("\r Maximum vocab size reached! \n") 476 | break 477 | 478 | print("\n Lookup tables created for {} tokens. \n".format(count)) 479 | 480 | self.wvecs = np.array(self.wvecs).astype(np.float32) # Converting to a Numpy array. 481 | self.vocab = list(self._word_to_id.keys()) # Adjusting the vocabulary to found pre-trained vectors. 482 | 483 | # Saving the lookup tables. 484 | f = open(self.PathToLookups, 'wb') 485 | data = {'word2id': self._word_to_id, 486 | 'id2word': self._id_to_word, 487 | 'wvecs': self.wvecs} 488 | pickle.dump(data, f) 489 | 490 | 491 | class DataGenerator(object): 492 | 493 | def __init__(self, path_to_dataset, max_inp_seq_len, max_out_seq_len, vocab, use_pgen=False, use_sample=False): 494 | # Train files. 495 | train_stories = os.listdir(path_to_dataset + 'train') 496 | self.train_files = [os.path.join(path_to_dataset + 'train', s) for s in train_stories] 497 | self.num_train_examples = len(self.train_files) 498 | shuffle(self.train_files) 499 | 500 | # Validation files. 501 | val_stories = os.listdir(path_to_dataset + 'val') 502 | self.val_files = [os.path.join(path_to_dataset + 'val', s) for s in val_stories] 503 | self.num_val_examples = len(self.val_files) 504 | shuffle(self.val_files) 505 | 506 | # Test files. 507 | test_stories = os.listdir(path_to_dataset + 'test') 508 | self.test_files = [os.path.join(path_to_dataset + 'test', s) for s in test_stories] 509 | self.num_test_examples = len(self.test_files) 510 | # shuffle(self.test_files) 511 | 512 | self._max_enc_steps = max_inp_seq_len # Max. no. of tokens in the input sequence. 513 | self._max_dec_steps = max_out_seq_len # Max. no. of tokens in the output sequence. 514 | 515 | self.vocab = vocab # Vocabulary instance. 516 | self._use_pgen = use_pgen # Whether pointer mechanism should be used. 517 | 518 | self._ptr = 0 # Pointer for batching the data. 519 | 520 | if use_sample: 521 | # **************************** PATCH ************************* # 522 | self.train_files = self.train_files[:20] 523 | self.num_train_examples = len(self.train_files) 524 | self.val_files = self.val_files[:20] 525 | self.num_val_examples = len(self.val_files) 526 | self.test_files = self.test_files[:23] 527 | self.num_test_examples = len(self.test_files) 528 | # **************************** PATCH ************************* # 529 | 530 | print("Split the data as follows:\n") 531 | print("\t\t Training: {} examples. \n".format(self.num_train_examples)) 532 | print("\t\t Validation: {} examples. \n".format(self.num_val_examples)) 533 | print("\t\t Test: {} examples. \n".format(self.num_test_examples)) 534 | 535 | def get_train_val_batch(self, batch_size, split='train', permutate=False): 536 | if split == 'train': 537 | num_examples = self.num_train_examples 538 | files = self.train_files 539 | elif split == 'val': 540 | num_examples = self.num_val_examples 541 | files = self.val_files 542 | else: 543 | raise ValueError("split is neither train nor val. check the function call!") 544 | 545 | enc_inp = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32) 546 | dec_inp = np.ndarray(shape=(batch_size, self._max_dec_steps), dtype=np.int32) 547 | dec_out = np.ndarray(shape=(batch_size, self._max_dec_steps), dtype=np.int32) 548 | 549 | enc_inp_ext_vocab = None 550 | max_oov_size = -np.infty 551 | if self._use_pgen: 552 | enc_inp_ext_vocab = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32) 553 | 554 | # Shuffle files at the start of an epoch. 555 | if self._ptr == 0: 556 | shuffle(files) 557 | 558 | # Start and end index for a batch of data. 559 | start = self._ptr 560 | end = self._ptr + batch_size 561 | self._ptr = end 562 | 563 | for i in range(start, end): 564 | j = i - start # Index of the example in current batch. 565 | article, summary = get_article_summary(files[i]) 566 | enc_inp_tokens = article.split(' ') # Article tokens. 567 | 568 | # Article Tokens 569 | if len(enc_inp_tokens) >= self._max_enc_steps: # Truncate. 570 | if permutate: 571 | indcs = sorted(sample(range(len(enc_inp_tokens)), self._max_enc_steps)) 572 | enc_inp_tokens = [enc_inp_tokens[i] for i in indcs] 573 | else: 574 | enc_inp_tokens = enc_inp_tokens[: self._max_enc_steps] 575 | else: # Pad. 576 | enc_inp_tokens += (self._max_enc_steps - len(enc_inp_tokens)) * [params.PAD_TOKEN] 577 | 578 | # Encoder Input 579 | enc_inp_ids = [self.vocab.word2id(w) for w in enc_inp_tokens] # Word to ID's 580 | 581 | # Summary Tokens 582 | sum_tokens = summary.split(' ') # Summary tokens. 583 | sum_tokens = [t for t in sum_tokens if t not in # Removing , tokens. 584 | [params.SENTENCE_START, params.SENTENCE_END]] 585 | 586 | if len(sum_tokens) > self._max_dec_steps - 1: # Truncate. 587 | sum_tokens = sum_tokens[: self._max_dec_steps - 1] 588 | 589 | # Decoder Input 590 | dec_inp_tokens = [params.START_DECODING] + sum_tokens 591 | if len(dec_inp_tokens) < self._max_dec_steps: 592 | dec_inp_tokens += (self._max_dec_steps - len(dec_inp_tokens)) * [params.PAD_TOKEN] 593 | 594 | # Decoder Output 595 | dec_out_tokens = sum_tokens + [params.STOP_DECODING] 596 | dec_out_len = len(dec_out_tokens) 597 | if dec_out_len < self._max_dec_steps: 598 | dec_out_tokens += (self._max_dec_steps - dec_out_len) * [params.PAD_TOKEN] 599 | 600 | dec_inp_ids = [self.vocab.word2id(w) for w in dec_inp_tokens] 601 | dec_out_ids = [self.vocab.word2id(w) for w in dec_out_tokens] 602 | 603 | enc_inp_ids_ext_vocab = None 604 | if self._use_pgen: 605 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self.vocab) 606 | dec_out_ids = summary2ids(dec_out_tokens, self.vocab, article_oovs) 607 | 608 | if len(article_oovs) > max_oov_size: 609 | max_oov_size = len(article_oovs) 610 | 611 | # Appending to the batch of inputs. 612 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32) # Appending to the enc_inp batch. 613 | dec_inp[j] = np.array(dec_inp_ids).astype(np.int32) # Appending to the dec_inp batch. 614 | dec_out[j] = np.array(dec_out_ids).astype(np.int32) # Appending to the dec_out batch. 615 | 616 | if self._use_pgen: 617 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32) 618 | 619 | # Resetting the pointer after the last batch 620 | if self._ptr == num_examples: 621 | self._ptr = 0 622 | 623 | # Setting the pointer for the last batch 624 | if self._ptr + batch_size > num_examples: 625 | self._ptr = num_examples - batch_size 626 | 627 | enc_padding_mask = (enc_inp != self.vocab.word2id(params.PAD_TOKEN)).astype(np.float32) 628 | dec_padding_mask = (dec_out != self.vocab.word2id(params.PAD_TOKEN)).astype(np.float32) 629 | 630 | batches = [enc_inp, enc_padding_mask, dec_inp, dec_out, dec_padding_mask] 631 | if self._use_pgen: 632 | batches += [enc_inp_ext_vocab, max_oov_size] 633 | 634 | return batches 635 | 636 | def get_test_batch(self, batch_size): 637 | num_examples = self.num_test_examples 638 | files = self.test_files 639 | 640 | enc_inp = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32) 641 | 642 | enc_inp_ext_vocab = None 643 | max_oov_size = -np.infty 644 | if self._use_pgen: 645 | enc_inp_ext_vocab = np.ndarray(shape=(batch_size, self._max_enc_steps), dtype=np.int32) 646 | 647 | summaries = [] # Used in 'test' mode. 648 | ext_vocabs = [] # Extended vocabularies. 649 | 650 | # # Shuffle files at the start of an epoch. 651 | # if self._ptr == 0: 652 | # shuffle(files) 653 | 654 | # Start and end index for a batch of data. 655 | start = self._ptr 656 | end = self._ptr + batch_size 657 | self._ptr = end 658 | 659 | for i in range(start, end): 660 | j = i - start # Index of the example in current batch. 661 | article, summary = get_article_summary(files[i]) 662 | enc_inp_tokens = article.split(' ') 663 | 664 | if len(enc_inp_tokens) >= self._max_enc_steps: # Truncate. 665 | enc_inp_tokens = enc_inp_tokens[: self._max_enc_steps] 666 | else: # Pad. 667 | enc_inp_tokens += (self._max_enc_steps - len(enc_inp_tokens)) * [params.PAD_TOKEN] 668 | 669 | # Encoder Input representation in fixed vocabulary. 670 | enc_inp_ids = [self.vocab.word2id(w) for w in enc_inp_tokens] # Word to ID's 671 | 672 | # Encoder Input representation in extended vocabulary. 673 | enc_inp_ids_ext_vocab = None 674 | article_oovs = None 675 | if self._use_pgen: 676 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self.vocab) 677 | 678 | if len(article_oovs) > max_oov_size: 679 | max_oov_size = len(article_oovs) 680 | 681 | # Appending to the input batch. 682 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32) 683 | 684 | if self._use_pgen: 685 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32) 686 | ext_vocabs.append(article_oovs) 687 | 688 | sum_tokens = summary.split(' ') 689 | summaries.append(sum_tokens) 690 | 691 | # Resetting the pointer after the last batch 692 | if self._ptr == num_examples: 693 | self._ptr = 0 694 | 695 | # Setting the pointer for the last batch 696 | if self._ptr + batch_size > num_examples: 697 | self._ptr = num_examples - batch_size 698 | 699 | # Repeat a single input beam size times. 700 | # enc_inp = np.repeat(enc_inp, repeats=beam_size, axis=0) # Shape: Bm x T_in. 701 | # enc_inp_ext_vocab = np.repeat(enc_inp_ext_vocab, repeats=beam_size, axis=0) # Shape: Bm x T_in. 702 | enc_padding_mask = (enc_inp != self.vocab.word2id(params.PAD_TOKEN)).astype(np.float32) # Shape: B x T_in. 703 | 704 | # Example indices. 705 | indices = list(range(start, end)) 706 | batches = [indices, summaries, files[start: end], enc_inp, enc_padding_mask] 707 | if self._use_pgen: 708 | batches += [enc_inp_ext_vocab, ext_vocabs, max_oov_size] 709 | 710 | return batches 711 | 712 | def get_batch(self, batch_size, split='train', permutate=False): 713 | if split == 'train' or split == 'val': 714 | return self.get_train_val_batch(batch_size, split, permutate) 715 | elif split == 'test': 716 | return self.get_test_batch(batch_size) 717 | else: 718 | raise ValueError('split should be either of train/val/test only!! \n') 719 | 720 | 721 | class DataGeneratorHier(object): 722 | 723 | def __init__(self, path_to_dataset, max_inp_sent, max_inp_tok_per_sent, max_out_tok, vocab, 724 | use_pgen=False, use_sample=False): 725 | 726 | self._max_enc_sent = max_inp_sent # Max. no of sentences in the encoder sequence. 727 | self._max_enc_tok_per_sent = max_inp_tok_per_sent # Max. no of tokens per sentence in the encoder sequence. 728 | self._max_enc_tok = self._max_enc_sent * self._max_enc_tok_per_sent 729 | self._max_dec_tok = max_out_tok # Max. no of tokens in the decoder sequence. 730 | 731 | self._vocab = vocab # Vocabulary object. 732 | self._use_pgen = use_pgen # Flag whether pointer-generator mechanism is used. 733 | 734 | self._ptr = 0 # Pointer for batching the data. 735 | 736 | # Train files. 737 | train_stories = os.listdir(path_to_dataset + 'train') 738 | self._train_files = [os.path.join(path_to_dataset + 'train/', s) for s in train_stories] 739 | self.num_train_examples = len(self._train_files) 740 | shuffle(self._train_files) 741 | 742 | # Validation files. 743 | val_stories = os.listdir(path_to_dataset + 'val') 744 | self._val_files = [os.path.join(path_to_dataset + 'val/', s) for s in val_stories] 745 | self.num_val_examples = len(self._val_files) 746 | shuffle(self._val_files) 747 | 748 | # Test files. 749 | test_stories = os.listdir(path_to_dataset + 'test') 750 | self._test_files = [os.path.join(path_to_dataset + 'test/', s) for s in test_stories] 751 | self.num_test_examples = len(self._test_files) 752 | # shuffle(self._test_files) 753 | 754 | if use_sample: 755 | # ***************************** PATCH ***************************** # 756 | train_idx = randint(0, self.num_train_examples - 20) 757 | self._train_files = self._train_files[train_idx: train_idx + 20] 758 | self.num_train_examples = len(self._train_files) 759 | 760 | val_idx = randint(0, self.num_val_examples - 20) 761 | self._val_files = self._val_files[val_idx: val_idx + 20] 762 | self.num_val_examples = len(self._val_files) 763 | 764 | test_idx = randint(0, self.num_test_examples - 23) 765 | self._test_files = self._test_files[test_idx: test_idx + 23] 766 | self.num_test_examples = len(self._test_files) 767 | # ***************************** PATCH ***************************** # 768 | 769 | print("Split the data as follows:\n") 770 | print("\t\t Training: {} examples. \n".format(self.num_train_examples)) 771 | print("\t\t Validation: {} examples. \n".format(self.num_val_examples)) 772 | print("\t\t Test: {} examples. \n".format(self.num_test_examples)) 773 | 774 | def get_batch(self, batch_size, split="train", permutate=False, chunk=False): 775 | if split == "train" or split == "val": 776 | return self._get_train_val_batch(batch_size, split, permutate, chunk) 777 | 778 | elif split == "test": 779 | return self._get_test_batch(batch_size, chunk) 780 | 781 | else: 782 | raise ValueError("Split should be either of train/val/test only!! \n") 783 | 784 | def _get_train_val_batch(self, batch_size, split="train", permutate=False, chunk=False): 785 | if split == "train": 786 | num_examples = self.num_train_examples 787 | files = self._train_files 788 | 789 | elif split == "val": 790 | num_examples = self.num_val_examples 791 | files = self._val_files 792 | 793 | else: 794 | raise ValueError("Split is neither train nor val. Check the function call!") 795 | 796 | enc_inp = np.ndarray(shape=[batch_size, self._max_enc_sent, self._max_enc_tok_per_sent], dtype=np.int32) 797 | dec_inp = np.ndarray(shape=[batch_size, self._max_dec_tok], dtype=np.int32) 798 | dec_out = np.ndarray(shape=[batch_size, self._max_dec_tok], dtype=np.int32) 799 | 800 | # Additional inputs in the pointer-generator mode. 801 | max_oov_size = -np.infty 802 | enc_inp_ext_vocab = None 803 | if self._use_pgen: 804 | enc_inp_ext_vocab = np.ndarray(shape=[batch_size, self._max_enc_tok], dtype=np.int32) 805 | 806 | # Shuffle files at the start of an epoch. 807 | if self._ptr == 0: 808 | shuffle(files) 809 | 810 | # Start and end index for a batch of data. 811 | start = self._ptr 812 | end = self._ptr + batch_size 813 | self._ptr = end 814 | 815 | for i in range(start, end): 816 | j = i - start # Index of the example in current batch. 817 | article_sents, summary = get_article_summary(files[i], art_format="sentences") 818 | 819 | # When chunking reshaping the data. 820 | if chunk: 821 | art_words = ' '.join(article_sents) 822 | 823 | article_lines = [[]] 824 | word_count = 0 825 | for word in art_words.split(' '): 826 | article_lines[-1].append(word) 827 | word_count += 1 828 | 829 | if word_count == self._max_enc_tok_per_sent: 830 | word_count = 0 831 | article_lines.append([]) 832 | 833 | article_sents = [] 834 | for line in article_lines: 835 | article_sents.append(' '.join(line)) 836 | 837 | if len(article_sents) >= self._max_enc_sent: # Truncate no. of sentences. 838 | if permutate: 839 | indcs = sorted(sample(range(len(article_sents)), self._max_enc_sent)) 840 | article_sents = [article_sents[i] for i in indcs] 841 | else: 842 | article_sents = article_sents[: self._max_enc_sent] 843 | else: 844 | article_sents += (self._max_enc_sent - 845 | len(article_sents)) * [''] # Add empty sentences. 846 | 847 | enc_inp_ids = [] 848 | enc_inp_tokens = [] 849 | for sent_idx, art_sent in enumerate(article_sents): 850 | 851 | enc_sent_tokens = art_sent.split(' ') 852 | if len(enc_sent_tokens) >= self._max_enc_tok_per_sent: # Truncate no. of tokens. 853 | if permutate: 854 | indcs = sorted(sample(range(len(enc_sent_tokens)), self._max_enc_tok_per_sent)) 855 | enc_sent_tokens = [enc_sent_tokens[i] for i in indcs] 856 | else: 857 | enc_sent_tokens = enc_sent_tokens[: self._max_enc_tok_per_sent] 858 | else: 859 | enc_sent_tokens += (self._max_enc_tok_per_sent 860 | - len(enc_sent_tokens)) * [params.PAD_TOKEN] # Pad. 861 | 862 | # Encoder sentence representation in the fixed vocabulary. 863 | enc_sent_ids = [self._vocab.word2id(w) for w in enc_sent_tokens] 864 | 865 | # Appending to the lists. 866 | enc_inp_ids.append(enc_sent_ids) 867 | enc_inp_tokens += enc_sent_tokens 868 | 869 | # Summary tokens. 870 | sum_tokens = summary.split(' ') # Summary tokens. 871 | sum_tokens = [t for t in sum_tokens if t not in 872 | [params.SENTENCE_START, params.SENTENCE_END]] # Removing , tokens. 873 | 874 | if len(sum_tokens) > self._max_dec_tok - 1: # Truncate. 875 | sum_tokens = sum_tokens[: self._max_dec_tok - 1] 876 | 877 | # Decoder Input. 878 | dec_inp_tokens = [params.START_DECODING] + sum_tokens 879 | if len(dec_inp_tokens) < self._max_dec_tok: 880 | dec_inp_tokens += (self._max_dec_tok - len(dec_inp_tokens)) * [params.PAD_TOKEN] 881 | 882 | # Decoder Output. 883 | dec_out_tokens = sum_tokens + [params.STOP_DECODING] 884 | if len(dec_out_tokens) < self._max_dec_tok: 885 | dec_out_tokens += (self._max_dec_tok - len(dec_out_tokens)) * [params.PAD_TOKEN] 886 | 887 | dec_inp_ids = [self._vocab.word2id(w) for w in dec_inp_tokens] 888 | dec_out_ids = [self._vocab.word2id(w) for w in dec_out_tokens] 889 | 890 | # Encoder input, decoder output representation in extended vocabulary. 891 | enc_inp_ids_ext_vocab = None 892 | if self._use_pgen: 893 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self._vocab) 894 | dec_out_ids = summary2ids(dec_out_tokens, self._vocab, article_oovs) 895 | 896 | if len(article_oovs) > max_oov_size: 897 | max_oov_size = len(article_oovs) 898 | 899 | # Appending to the batch of inputs. 900 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32) 901 | dec_inp[j] = np.array(dec_inp_ids).astype(np.int32) 902 | dec_out[j] = np.array(dec_out_ids).astype(np.int32) 903 | 904 | if self._use_pgen: 905 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32) 906 | 907 | # Resetting the pointer after the last batch. 908 | if self._ptr == num_examples: 909 | self._ptr = 0 910 | 911 | # Setting the pointer for the last batch. 912 | if self._ptr + batch_size > num_examples: 913 | self._ptr = num_examples - batch_size 914 | 915 | # Padding masks. 916 | pad_id = self._vocab.word2id(params.PAD_TOKEN) 917 | enc_pad_mask = (enc_inp != pad_id).astype(np.float32) # B x S_in x T_in. 918 | enc_doc_mask = np.sum(enc_pad_mask, axis=2) # B x S_in. 919 | enc_doc_mask = np.greater(enc_doc_mask, 0).astype(np.float32) # B x S_in. 920 | dec_pad_mask = (dec_out != pad_id).astype(np.float32) # B x T_dec. 921 | 922 | batches = [enc_inp, enc_pad_mask, enc_doc_mask, dec_inp, dec_out, dec_pad_mask] 923 | if self._use_pgen: 924 | batches += [enc_inp_ext_vocab, max_oov_size] 925 | 926 | return batches 927 | 928 | def _get_test_batch(self, batch_size, chunk=False): 929 | num_examples = self.num_test_examples 930 | files = self._test_files 931 | 932 | enc_inp = np.ndarray(shape=[batch_size, self._max_enc_sent, self._max_enc_tok_per_sent], dtype=np.int32) 933 | 934 | max_oov_size = -np.infty 935 | enc_inp_ext_vocab = None 936 | if self._use_pgen: 937 | enc_inp_ext_vocab = np.ndarray(shape=[batch_size, self._max_enc_tok], dtype=np.int32) 938 | 939 | summaries = [] # Used in 'test' mode. 940 | ext_vocabs = [] # Extended vocabularies. 941 | 942 | # # Shuffle files at the start of an epoch. 943 | # if self._ptr == 0: 944 | # shuffle(files) 945 | 946 | # Start and end index for a batch of data. 947 | start = self._ptr 948 | end = self._ptr + batch_size 949 | self._ptr = end 950 | 951 | for i in range(start, end): 952 | j = i - start 953 | article_sents, summary = get_article_summary(files[i], art_format="sentences", sum_format="sentences") 954 | 955 | # When chunking reshaping the data. 956 | if chunk: 957 | art_words = ' '.join(article_sents) 958 | 959 | article_lines = [[]] 960 | word_count = 0 961 | for word in art_words.split(' '): 962 | article_lines[-1].append(word) 963 | word_count += 1 964 | 965 | if word_count == self._max_enc_tok_per_sent: 966 | word_count = 0 967 | article_lines.append([]) 968 | 969 | article_sents = [] 970 | for line in article_lines: 971 | article_sents.append(' '.join(line)) 972 | 973 | if len(article_sents) >= self._max_enc_sent: # Truncate no. of sentences. 974 | article_sents = article_sents[: self._max_enc_sent] 975 | else: # Add empty sentences. 976 | article_sents += (self._max_enc_sent - 977 | len(article_sents)) * [''] 978 | 979 | enc_inp_ids = [] 980 | enc_inp_tokens = [] 981 | for sent_idx, art_sent in enumerate(article_sents): 982 | # Break if max. no. of encoder sentences is reached. 983 | if sent_idx >= self._max_enc_sent: 984 | break 985 | 986 | enc_sent_tokens = art_sent.split(' ') 987 | if len(enc_sent_tokens) >= self._max_enc_tok_per_sent: # Truncate no. of tokens. 988 | enc_sent_tokens = enc_sent_tokens[: self._max_enc_tok_per_sent] 989 | else: # Pad. 990 | enc_sent_tokens += (self._max_enc_tok_per_sent - 991 | len(enc_sent_tokens)) * [params.PAD_TOKEN] 992 | 993 | # Encoder representation in the fixed vocabulary. 994 | enc_sent_ids = [self._vocab.word2id(w) for w in enc_sent_tokens] 995 | 996 | # Appending to the lists. 997 | enc_inp_ids.append(enc_sent_ids) 998 | enc_inp_tokens += enc_sent_tokens 999 | 1000 | # Encoder input representation in the extended vocabulary. 1001 | enc_inp_ids_ext_vocab = None 1002 | article_oovs = None 1003 | if self._use_pgen: 1004 | enc_inp_ids_ext_vocab, article_oovs = article2ids(enc_inp_tokens, self._vocab) 1005 | 1006 | if len(article_oovs) > max_oov_size: 1007 | max_oov_size = len(article_oovs) 1008 | 1009 | # Appending to the input batch. 1010 | enc_inp[j] = np.array(enc_inp_ids).astype(np.int32) 1011 | 1012 | if self._use_pgen: 1013 | enc_inp_ext_vocab[j] = np.array(enc_inp_ids_ext_vocab).astype(np.int32) 1014 | ext_vocabs.append(article_oovs) 1015 | 1016 | # Summaries. 1017 | summaries.append(summary) 1018 | 1019 | # Resetting the pointer after the last batch. 1020 | if self._ptr == num_examples: 1021 | self._ptr = 0 1022 | 1023 | # Setting the pointer for the last batch. 1024 | if self._ptr + batch_size > num_examples: 1025 | self._ptr = num_examples - batch_size 1026 | 1027 | # Padding masks. 1028 | pad_id = self._vocab.word2id(params.PAD_TOKEN) 1029 | enc_pad_mask = (enc_inp != pad_id).astype(np.float32) # B x S_in x T_in. 1030 | enc_doc_mask = np.sum(enc_pad_mask, axis=2) # B x S_in. 1031 | enc_doc_mask = np.greater(enc_doc_mask, 0).astype(np.float32) # B x S_in. 1032 | 1033 | indices = list(range(start, end)) 1034 | batches = [indices, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask] 1035 | if self._use_pgen: 1036 | batches += [enc_inp_ext_vocab, ext_vocabs, max_oov_size] 1037 | 1038 | return batches 1039 | 1040 | 1041 | def average_gradients(tower_grads): 1042 | """ 1043 | This function collects gradients from all towers and returns the average gradient. 1044 | :param tower_grads: List of gradients from all towers. 1045 | :return: List of average gradients of all variables. 1046 | """ 1047 | average_grads = [] 1048 | for grad_and_vars in zip(*tower_grads): 1049 | # For M variables and N GPUs, tower_grads is of the form 1050 | # ((grad0_gpu0, var0_gpu0), (grad0_gpu1, var0_gpu1), ..., (grad0_gpuN, var0_gpuN)) 1051 | # ((grad1_gpu0, var1_gpu0), (grad1_gpu1, var1_gpu1), ..., (grad1_gpuN, var1_gpuN)) 1052 | # ...... 1053 | # ((gradM_gpu0, varM_gpu0), (gradM_gpu1, varM_gpu1), ..., (gradM_gpuN, varM_gpuN)) 1054 | 1055 | grads = [] 1056 | for g, _ in grad_and_vars: 1057 | # Adding an extra dimension for concatenation later. 1058 | expanded_g = tf.expand_dims(g, 0) 1059 | grads.append(expanded_g) 1060 | 1061 | grad = tf.concat(axis=0, values=grads) # Concatenation along added dimension. 1062 | grad = tf.reduce_mean(grad, axis=0) # Average gradient. 1063 | 1064 | # Variable name is same in all the GPUs. So, it suffices to use the one in the 1st GPU. 1065 | v = grad_and_vars[0][1] 1066 | grad_and_var = (grad, v) 1067 | average_grads.append(grad_and_var) 1068 | 1069 | return average_grads 1070 | 1071 | 1072 | def eval_model(path_to_results): 1073 | """ 1074 | This function is for the ROUGE evaluation. 1075 | :return: 1076 | """ 1077 | r = pyrouge.Rouge155() 1078 | r.system_dir = path_to_results + 'predictions/' 1079 | r.model_dir = path_to_results + 'groundtruths/' 1080 | r.system_filename_pattern = '(\d+)_pred.txt' 1081 | r.model_filename_pattern = '#ID#_gt.txt' 1082 | rouge_results = r.convert_and_evaluate() 1083 | 1084 | rouge_dict = r.output_to_dict(rouge_results) 1085 | rouge_log(rouge_dict, path_to_results) 1086 | 1087 | 1088 | def rouge_log(results_dict, path_to_results): 1089 | """ 1090 | This function saves the rouge results into a file. 1091 | :param results_dict: Dictionary output from pyrouge consisting of ROUGE results. 1092 | :param path_to_results: Path where the results file has to be stored. 1093 | :return: 1094 | """ 1095 | log_str = "" 1096 | for x in ["1", "2", "l"]: 1097 | 1098 | log_str += "\nROUGE-%s:\n" % x 1099 | for y in ["f_score", "recall", "precision"]: 1100 | key = "rouge_%s_%s" % (x, y) 1101 | key_cb = key + "_cb" 1102 | key_ce = key + "_ce" 1103 | val = results_dict[key] 1104 | val_cb = results_dict[key_cb] 1105 | val_ce = results_dict[key_ce] 1106 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce) 1107 | 1108 | results_file = path_to_results + 'ROUGE_results.txt' 1109 | with open(results_file, "w") as f: 1110 | f.write(log_str) 1111 | 1112 | 1113 | def get_running_avg_loss(loss, running_avg_loss, decay=0.99): 1114 | """ 1115 | This function updates the running averages loss. 1116 | :param loss: Loss at the current step. 1117 | :param running_avg_loss: Running average loss after the previous step. 1118 | :param decay: The decay rate. 1119 | :return: 1120 | """ 1121 | 1122 | if running_avg_loss == 0: 1123 | running_avg_loss = loss 1124 | else: 1125 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss 1126 | 1127 | return running_avg_loss 1128 | 1129 | 1130 | def get_dependencies(tensor): 1131 | dependencies = set() 1132 | dependencies.update(tensor.op.inputs) 1133 | for sub_op in tensor.op.inputs: 1134 | dependencies.update(get_dependencies(sub_op)) 1135 | 1136 | return dependencies 1137 | 1138 | 1139 | def get_placeholder_dependencies(tensor): 1140 | dependencies = get_dependencies(tensor) 1141 | dependencies = [tensor for tensor in dependencies if tensor.op.type == "Placeholder"] 1142 | 1143 | return dependencies 1144 | -------------------------------------------------------------------------------- /codes/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Rajeev Bhatt Ambati" 3 | 4 | from NSE import NSE 5 | from utils import average_gradients, get_running_avg_loss 6 | import params as params 7 | from beam_search import run_beam_search 8 | from rouge_batch import rouge_l_fscore_batch as rouge_l_fscore 9 | 10 | import random 11 | import time 12 | import math 13 | import tensorflow as tf 14 | from tensorflow.nn.rnn_cell import LSTMStateTuple 15 | from tensorflow.contrib.framework.python.framework import checkpoint_utils 16 | import numpy as np 17 | from joblib import Parallel, delayed 18 | 19 | tf.set_random_seed(2019) 20 | random.seed(2019) 21 | tf.reset_default_graph() 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | 26 | class SummarizationModel(object): 27 | def __init__(self, vocab, data): 28 | self.vocab = vocab 29 | self.data = data 30 | 31 | self._dense_init = tf.contrib.layers.variance_scaling_initializer() 32 | 33 | self._config = tf.ConfigProto() 34 | self._config.gpu_options.allow_growth = True 35 | 36 | self._best_val_loss = np.infty 37 | self._num_gpus = len(FLAGS.GPUs) 38 | 39 | self._saver = None # Saver. 40 | self._init = None # Initializer. 41 | self._sess = None # Session. 42 | 43 | def _create_placeholders(self): 44 | """ 45 | This function creates the placeholders needed for the computation graph. 46 | :return: 47 | """ 48 | self._global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') 49 | self.learning_rate = tf.placeholder(tf.float32, name='learning_rate') 50 | 51 | # Word embedding. 52 | if FLAGS.use_pretrained: 53 | with tf.variable_scope('embed', reuse=tf.AUTO_REUSE): 54 | self.word_embedding = tf.get_variable(name="word_embedding", 55 | shape=[self.vocab.size(), FLAGS.dim], 56 | initializer=tf.constant_initializer(self.vocab.wvecs), 57 | dtype=tf.float32, 58 | trainable=True) 59 | else: 60 | self.word_embedding = tf.get_variable(name="word_embedding", 61 | shape=[self.vocab.size(), FLAGS.dim], 62 | dtype=tf.float32, 63 | trainable=True) 64 | 65 | # Graph Inputs/Outputs. 66 | self._enc_in = tf.placeholder(tf.int32, name='enc_in', 67 | shape=[FLAGS.batch_size, FLAGS.enc_steps]) # Shape: B x T_enc. 68 | self._enc_pad_mask = tf.placeholder(tf.float32, name='enc_pad_mask', 69 | shape=[FLAGS.batch_size, FLAGS.enc_steps]) # Shape: B x T_enc. 70 | self._dec_in = tf.placeholder(tf.int32, name='dec_in', 71 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec. 72 | inputs = [self._enc_in, self._enc_pad_mask, self._dec_in] 73 | 74 | if FLAGS.use_pgen: 75 | self._enc_in_ext_vocab = tf.placeholder(tf.int32, name='enc_in_ext_vocab', 76 | shape=[FLAGS.batch_size, 77 | FLAGS.enc_steps]) # Shape: B x T_enc. 78 | self._max_oov_size = tf.placeholder(tf.int32, name='max_oov_size') 79 | inputs.append(self._enc_in_ext_vocab) 80 | 81 | if FLAGS.mode == "train": 82 | self._dec_out = tf.placeholder(tf.int32, name='dec_out', 83 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec. 84 | self._dec_pad_mask = tf.placeholder(tf.float32, name='dec_pad_mask', 85 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec. 86 | inputs += [self._dec_out, self._dec_pad_mask] 87 | 88 | return inputs 89 | 90 | def _create_writers(self): 91 | """ 92 | This function creates the summaries and writers needed for visualization through tensorboard. 93 | :return: writers 94 | """ 95 | self._mean_crossentropy_loss = tf.placeholder(tf.float32, name='mean_crossentropy_loss') 96 | self._mean_rouge_score = tf.placeholder(dtype=tf.float32, name="mean_rouge") 97 | 98 | # Summaries. 99 | cross_entropy_summary = tf.summary.scalar('cross_entropy', self._mean_crossentropy_loss) 100 | rouge_summary = None 101 | if FLAGS.rouge_summary: 102 | rouge_summary = tf.summary.scalar('rouge_score', self._mean_rouge_score) 103 | self._summaries = tf.summary.merge_all() 104 | 105 | summary_list = [cross_entropy_summary] 106 | if FLAGS.rouge_summary: 107 | summary_list.append(rouge_summary) 108 | 109 | self._val_summaries = tf.summary.merge(summary_list, name="validation_summaries") 110 | 111 | # Summary Writers. 112 | self._train_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'train') # , self._sess.graph) 113 | self._val_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'val') 114 | 115 | def build_graph(self): 116 | if FLAGS.mode == 'train': 117 | if len(FLAGS.GPUs) > 1: 118 | self._parallel_model() # Parallel model in case of multiple-GPUs. 119 | else: 120 | self._single_model() # Single model for a single GPU. 121 | 122 | if FLAGS.mode == 'test': 123 | inputs = self._create_placeholders() # [enc_in, dec_in, enc_in_ext_vocab] 124 | self._forward(inputs) # Predictions Shape: Bm x 1 125 | 126 | self._saver = tf.train.Saver(tf.global_variables()) # Saver. 127 | self._init = tf.global_variables_initializer() # Initializer. 128 | self._sess = tf.Session(config=self._config) # Session. 129 | 130 | # Restoring the trained model. 131 | self._saver.restore(self._sess, FLAGS.PathToCheckpoint) 132 | 133 | def _forward(self, inputs): 134 | """ 135 | This function creates the TensorFlow computation graph. 136 | :param inputs: A list of input placeholders. 137 | [0]: enc_in, Encoder input sequence of ID's, Shape: B x T_in. 138 | [1]: enc_pad_mask: Mask for padding tokens in encoder input, Shape: B x T_in. 139 | [2]: dec_in: Input to the decoder, Shape: B x T_out. 140 | Following additional input in pointer generator mode. 141 | [3]: enc_in_ext_vocab: Encoder input representation in the extended vocabulary, Shape: B x T_in. 142 | [4]: labels: Ground-Truth labels, Only in train mode, Shape: B x T_out. 143 | [5]: dec_pad_mask: Mask for padding tokens in decoder output, Shape: B x T_out 144 | :return: returns loss in train/val mode and predictions in test mode. 145 | """ 146 | batch_size = inputs[0].get_shape().as_list()[0] # Batch-size. 147 | 148 | # The NSE instance. 149 | self._nse = NSE(batch_size=batch_size, dim=FLAGS.dim, dense_init=self._dense_init, 150 | mode=FLAGS.mode, use_comp_lstm=FLAGS.use_comp_lstm) 151 | 152 | # Encoder. 153 | self._prev_states, self._p_attns_enc = self._encoder(inputs[: 2]) 154 | 155 | # Decoder. 156 | outputs = self._decoder(inputs[1: 4], self._prev_states) 157 | 158 | if FLAGS.mode == 'test': 159 | self._topk_ids, self._topk_log_probs, self._curr_states, self._p_attns_dec = outputs[: 4] 160 | 161 | if FLAGS.use_pgen: 162 | self._p_gens = outputs[4] 163 | 164 | else: 165 | probs, p_attns_dec = outputs 166 | 167 | crossentropy_loss = self._get_crossentropy_loss(probs, inputs[-2], inputs[-1]) 168 | 169 | # Evaluating a few samples. 170 | samples = [] 171 | if FLAGS.rouge_summary: 172 | samples = self._get_samples(inputs[1: 5], self._prev_states) 173 | 174 | return crossentropy_loss, samples 175 | 176 | def _encoder(self, inputs): 177 | """ 178 | This is the encoder. 179 | :param inputs: A list of the following inputs. 180 | enc_in: Encoder input sequence of ID's, Shape: B x T_in. 181 | enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x T_in. 182 | :return: 183 | output: 184 | [0]: p_attn: Attention distribution while encoding, Shape: B x T_in. 185 | [1]: prev_coverage: Coverage vector after the last encoding step, Shape: B x T_in. 186 | Returns the internal states of NSE after the last encoding step. 187 | memory: NSE memory after the last time step, Shape: B x T_in x D. 188 | read_state: Hidden state of the read LSTM after the last time step, (c, h) Shape: 2 * [B x D]. 189 | write_state: Hidden state of write LSTM after the last encoding step, (c, h) Shape: 2 * [B x D]. 190 | comp_state: Hidden state of compose LSTM after last encoding step, (c, h) Shape: 2 * [B x (2*D)]. 191 | """ 192 | enc_in, enc_pad_mask = inputs 193 | 194 | # Converting ID's to word-vectors. 195 | enc_in_vecs = tf.nn.embedding_lookup(params=self.word_embedding, ids=enc_in) # Shape: B x T_in x D. 196 | enc_in_vecs = tf.cast(enc_in_vecs, dtype=tf.float32) # Cast to float32 197 | 198 | p_attns = [] 199 | state = [enc_in_vecs, None, None, None] 200 | for i in range(FLAGS.enc_steps): 201 | x_t = enc_in_vecs[:, i, :] # Shape: B x D. 202 | output, state = self._nse.step( 203 | x_t=x_t, mem_mask=enc_pad_mask, prev_state=state) # One step of NSE. 204 | 205 | return state, p_attns 206 | 207 | def _decoder(self, inputs, state): 208 | """ 209 | This is the decoder. 210 | :param inputs: List of the following inputs. 211 | [0]: enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x T_in. 212 | [1]: dec_in: Input to the decoder. 213 | Test mode: Output of previous time step, Shape: Bm x 1. 214 | Train/Val mode: decoder input sequence, Shape: B x T_out. 215 | [2]: enc_in_ext_vocab (Only in test mode): 216 | Encoder input representation in the extended vocabulary in 217 | pointer generator mode, Shape: B x T_in. 218 | :param state: Internal states of NSE after the encoding is done. 219 | [0]: memory: NSE memory after the last time step of encoder, Shape: B x T_in x D. 220 | [1]: read_state: Hidden state of read LSTM after the previous step, (c, h), Shape: (B x D, B x D) 221 | [2]: write_state: Hidden state of write LSTM after the previous step, (c, h), Shape: (B x D, B x D) 222 | [3]: comp_state: Hidden state of compose LSTM after the previous step, (c, h), Shape: (B x D, B x D) 223 | :return: 224 | In test mode: 225 | [topk_ids, topk_log_probs, state, p_attns] + [p_gens] 226 | In train mode: 227 | [p_cums, p_attns] 228 | """ 229 | enc_pad_mask, dec_in = inputs[0], inputs[1] 230 | 231 | # Converting ID's to word-vectors. 232 | dec_in_vecs = tf.nn.embedding_lookup(params=self.word_embedding, ids=dec_in) # Shape: B x T_out x D. 233 | dec_in_vecs = tf.cast(dec_in_vecs, tf.float32) # Cast to tf.float32. 234 | 235 | writes = [] 236 | p_attns = [] 237 | p_gens = [] 238 | for i in range(FLAGS.dec_steps): 239 | x_t = dec_in_vecs[:, i, :] # Shape: B x D. 240 | output, state = self._nse.step( 241 | x_t=x_t, mem_mask=enc_pad_mask, prev_state=state, use_pgen=FLAGS.use_pgen 242 | ) # One-step of NSE. 243 | 244 | # Appending the outputs. 245 | writes.append(output[0]) 246 | p_attns.append(output[1]) 247 | if FLAGS.use_pgen: 248 | p_gens.append(output[2]) 249 | 250 | p_vocabs = self._get_vocab_dist(writes) # Shape: T_out * [B x vsize] 251 | 252 | p_cums = p_vocabs 253 | if FLAGS.use_pgen: 254 | enc_in_ext_vocab = inputs[2] # Shape: T_in x B 255 | p_cums = self._get_cumulative_dist( 256 | p_vocabs, p_gens, p_attns, enc_in_ext_vocab 257 | ) # Shape: T_out * [B x ext_vsize] 258 | 259 | if FLAGS.mode == 'test': 260 | p_final = p_cums[0] # Shape: Bm x ext_vsize 261 | 262 | # The top k predictions, topk_ids, Shape: Bm x (2*Bm). 263 | # Respective probabilities, topk_probs, Shape: Bm x (2*Bm). 264 | topk_probs, topk_ids = tf.nn.top_k(p_final, k=2*FLAGS.beam_size, name='topk_preds') 265 | topk_log_probs = tf.log(tf.clip_by_value(topk_probs, 1e-10, 1.0)) 266 | 267 | outputs = [topk_ids, topk_log_probs, state, p_attns] 268 | if FLAGS.use_pgen: 269 | outputs.append(p_gens) 270 | 271 | return outputs 272 | else: 273 | return p_cums, p_attns 274 | 275 | def _get_samples(self, inputs, state): 276 | """ 277 | This is the decoder. 278 | :param inputs: List of the following inputs. 279 | [0]: enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x T_in. 280 | [1]: dec_in: Input to the decoder. 281 | Test mode: Output of previous time step, Shape: Bm x 1. 282 | Train/Val mode: decoder input sequence, Shape: B x T_out. 283 | [2]: enc_in_ext_vocab (Only in test mode): 284 | Encoder input representation in the extended vocabulary in 285 | pointer generator mode, Shape: B x T_in. 286 | :param state: Internal states of NSE after the encoding is done. 287 | [0]: memory: NSE memory after the last time step of encoder, Shape: B x T_in x D. 288 | [1]: read_state: Hidden state of read LSTM after the previous step, (c, h), Shape: (B x D, B x D) 289 | [2]: write_state: Hidden state of write LSTM after the previous step, (c, h), Shape: (B x D, B x D) 290 | [3]: comp_state: Hidden state of compose LSTM after the previous step, (c, h), Shape: (B x D, B x D) 291 | :return: 292 | In test mode: 293 | [topk_ids, topk_log_probs, state, p_attns] + [p_gens] 294 | In train mode: 295 | [p_cums, p_attns] 296 | """ 297 | unk_id = self.vocab.word2id(params.UNKNOWN_TOKEN) # Unknown ID. 298 | start_id = self.vocab.word2id(params.START_DECODING) # Start ID. 299 | 300 | enc_pad_mask = inputs[0] # Encoder mask. 301 | batch_size = enc_pad_mask.get_shape().as_list()[0] 302 | samples = [] # Sampling Outputs. 303 | for i in range(FLAGS.dec_steps): 304 | 305 | if i == 0: 306 | id_t = tf.fill([batch_size], start_id) # Shape: B x . 307 | else: 308 | id_t = samples[-1][:, 0] 309 | # Replacing the ID's from external vocabulary (if any) to UNK id's. 310 | id_t = tf.where( 311 | tf.less(id_t, self.vocab.size()), id_t, unk_id * tf.ones_like(id_t) 312 | ) 313 | 314 | # Getting the word vector. 315 | x_t = tf.nn.embedding_lookup(params=self.word_embedding, ids=id_t) # Shape: B x D. 316 | x_t = tf.cast(x_t, dtype=tf.float32) 317 | 318 | output, state = self._nse.step( 319 | x_t=x_t, mem_mask=enc_pad_mask, prev_state=state, use_pgen=FLAGS.use_pgen 320 | ) # One-step of NSE. 321 | 322 | # Output probability distribution. 323 | p_vocab = self._get_vocab_dist([output[0]]) # Shape: [B x vsize]. 324 | 325 | # Calculate cumulative probability distribution using pointer mechanism. 326 | p_cum = p_vocab 327 | if FLAGS.use_pgen: 328 | p_cum = self._get_cumulative_dist( 329 | p_vocabs=p_vocab, p_gens=[output[2]], p_attns=[output[1]], enc_in_ext_vocab=inputs[2] 330 | ) # Shape: T_dec * [B x ext_vsize]. 331 | 332 | # Greedy sampling. 333 | _, gs_sample = tf.nn.top_k(p_cum[0], k=FLAGS.num_samples) # Shape: B x 1. 334 | samples.append(gs_sample) 335 | 336 | samples = tf.concat(samples, axis=1) # Shape: B x T_dec. 337 | 338 | return samples 339 | 340 | def run_encoder(self, enc_in_batch, enc_pad_mask): 341 | """ 342 | This function calculates the internal states of NSE after the last step of encoding. 343 | :param enc_in_batch: A batch of encoder input sequence, Shape: Bm x T_enc. 344 | :param enc_pad_mask: Padding mask for encoder input, Shape: Bm x T_enc. 345 | :return: 346 | NSE internal states after encoding. 347 | final_memory: NSE memory after last time-step of encoding, Shape: 1 x mem_size x D. 348 | final_read_state: Hidden state of read LSTM after last time-step of encoding, 349 | (c, h) Shape: (1 x D, 1 x D). 350 | final_write_state: Hidden state of write LSTM after last time-step of encoding, 351 | (c, h) Shape: (1 x D, 1 x D). 352 | final_comp_state: Hidden state of compose LSTM after last time-step of encoding, 353 | (c, h) Shape: (1 x 2D, 1 x 2D). 354 | Additionally, returns the following when using encoder coverage: 355 | p_attns: Encoder attention distributions, Shape: T_enc * [1 x T_enc]. 356 | """ 357 | to_return = [self._prev_states] 358 | 359 | outputs = self._sess.run(to_return, feed_dict={self._enc_in: enc_in_batch, 360 | self._enc_pad_mask: enc_pad_mask}) 361 | states = outputs[0] 362 | 363 | final_memory, final_read_state, final_write_state = states[: 3] 364 | # Since the states repeated values, slicing only first one. 365 | states[0] = final_memory[0, np.newaxis, :, :] # Memory, Shape: 1 x mem_size x D. 366 | states[1] = LSTMStateTuple(final_read_state.c[0, np.newaxis, :], 367 | final_read_state.h[0, np.newaxis, :]) # Shape: (1 x D, 1 x D). 368 | states[2] = LSTMStateTuple(final_write_state.c[0, np.newaxis, :], 369 | final_write_state.h[0, np.newaxis, :]) # Shape: (1 x D, 1 x D). 370 | 371 | if FLAGS.use_comp_lstm: 372 | final_comp_state = states[3] 373 | states[3] = LSTMStateTuple(final_comp_state.c[0, np.newaxis, :], 374 | final_comp_state.h[0, np.newaxis, :]) # Shape: (1 x (2D), 1 x (2D)). 375 | 376 | return states 377 | 378 | def decode_one_step(self, inputs, prev_states): 379 | """ 380 | This function performs one step of decoding. 381 | :param inputs: 382 | dec_in_batch: The input to the decoder. This is the output from previous time-step, Shape: Bm * [1 x] 383 | enc_pad_mask: Padding mask for the memory, Shape: Bm x T. 384 | In pointer generator mode, there are following additional inputs: 385 | enc_in_ex_vocab_batch: Encoder input sequence represented in extended vocabulary, Shape: Bm x T_in. 386 | max_oov_size: Size of the largest OOV tokens in the current batch, Shape: () 387 | :param prev_states: previous internal states of NSE of all Bm hypothesis, Bm * [prev_state]. 388 | where state is a list of internal states of NSE for a single hypothesis. 389 | prev_state = [prev_memory, prev_read_state, prev_write_state] 390 | prev_memory: NSE memory after the previous time step, Shape: 1 x T x D. 391 | prev_read_state: Hidden state of read LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D]. 392 | prev_write_state: Hidden state of write LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D]. 393 | prev_comp_state: Hidden state of compose LSTM after previous time step, (c, h) Shape: [1 x 2D, 1 x 2D]. 394 | :return: 395 | topk_ids: Top-k predictions in the current step, Shape: Bm x (2*Bm) 396 | topk_log_probs: log probabilities of top-k predictions, Shape: Bm x (2*Bm) 397 | curr_states: Current internal states of NSE, Bm * [state]. 398 | [0]: memory: NSE memory, Shape: Bm x T x D. 399 | [1]: read_state, (c, h) Shape: (Bm x D, Bm x D). 400 | [2]: write_state, (c, h) Shape: (Bm x D, Bm x D). 401 | [3]: comp_state, (c, h) Shape: (Bm x 2D, Bm x 2D). 402 | p_gens: Generation probabilities, Shape: Bm x . 403 | p_attns: Attention probabilities, Shape: Bm x mem_size. 404 | """ 405 | # Decoder input 406 | dec_in = inputs[0] # Shape: Bm * [1 x] 407 | dec_in = np.stack(dec_in, axis=0) # Shape: Bm x 408 | inputs[0] = np.expand_dims(dec_in, axis=-1) # Shape: Bm x 1 409 | 410 | # Previous memories of Bm hypothesis. 411 | prev_memories = [state[0] for state in prev_states] # Shape: Bm * [1 x T x D]. 412 | prev_memories = np.concatenate(prev_memories, axis=0) # Shape: Bm x T x D. 413 | 414 | # Previous read states of Bm hypothesis. 415 | # Cell states. 416 | prev_read_states_c = [state[1].c for state in prev_states] # Shape: Bm * [1 x D]. 417 | prev_read_states_c = np.concatenate(prev_read_states_c, axis=0) # Shape: Bm x D. 418 | 419 | # Hidden states. 420 | prev_read_states_h = [state[1].h for state in prev_states] # Shape: Bm * [1 x D]. 421 | prev_read_states_h = np.concatenate(prev_read_states_h, axis=0) # Shape: Bm x D. 422 | 423 | prev_read_states = LSTMStateTuple(prev_read_states_c, prev_read_states_h) 424 | 425 | # Previous write states of Bm hypothesis. 426 | # Cell states. 427 | prev_write_states_c = [state[2].c for state in prev_states] # Shape: Bm * [1 x D]. 428 | prev_write_states_c = np.concatenate(prev_write_states_c, axis=0) # Shape: Bm x D. 429 | 430 | # Hidden states. 431 | prev_write_states_h = [state[2].h for state in prev_states] # Shape: Bm * [1 x D]. 432 | prev_write_states_h = np.concatenate(prev_write_states_h, axis=0) # Shape: Bm x D. 433 | 434 | prev_write_states = LSTMStateTuple(prev_write_states_c, prev_write_states_h) 435 | 436 | feed_dict = { 437 | self._dec_in: inputs[0], 438 | self._enc_pad_mask: inputs[1], 439 | self._prev_states[0]: prev_memories, 440 | self._prev_states[1]: prev_read_states, 441 | self._prev_states[2]: prev_write_states 442 | } 443 | 444 | # Previous compose state of Bm hypothesis. 445 | # Cell states. 446 | if FLAGS.use_comp_lstm: 447 | prev_comp_states_c = [state[3].h for state in prev_states] # Shape: Bm * [1 x (2D)]. 448 | prev_comp_states_c = np.concatenate(prev_comp_states_c, axis=0) # Shape: Bm x (2D). 449 | 450 | prev_comp_states_h = [state[3].h for state in prev_states] # Shape: Bm * [1 x (2D)]. 451 | prev_comp_states_h = np.concatenate(prev_comp_states_h, axis=0) # Shape: Bm x (2D). 452 | 453 | prev_comp_states = LSTMStateTuple(prev_comp_states_c, prev_comp_states_h) 454 | 455 | feed_dict[self._prev_states[3]] = prev_comp_states 456 | 457 | if FLAGS.use_pgen: 458 | feed_dict[self._enc_in_ext_vocab] = inputs[2] 459 | feed_dict[self._max_oov_size] = inputs[3] 460 | 461 | to_return = [self._topk_ids, self._topk_log_probs, self._curr_states] 462 | 463 | if FLAGS.use_pgen: 464 | to_return += [self._p_gens, self._p_attns_dec] 465 | 466 | outputs = self._sess.run(to_return, feed_dict=feed_dict) 467 | 468 | # Preparing the state outputs into lists. 469 | curr_states = outputs[2] 470 | 471 | # Current memories. 472 | curr_memories = np.split(curr_states[0], FLAGS.beam_size, axis=0) # Bm * [1 x T x D]. 473 | 474 | # Current read states. 475 | curr_read_states_c = np.split(curr_states[1].c, FLAGS.beam_size, axis=0) # Bm * [1 x D]. 476 | curr_read_states_h = np.split(curr_states[1].h, FLAGS.beam_size, axis=0) # Bm * [1 x D]. 477 | curr_read_states = [LSTMStateTuple(c, h) 478 | for c, h in zip(curr_read_states_c, curr_read_states_h)] # Bm * [(1 x D, 1 x D)]. 479 | 480 | # Current write states. 481 | curr_write_states_c = np.split(curr_states[2].c, FLAGS.beam_size, axis=0) # Bm * [1 x D]. 482 | curr_write_states_h = np.split(curr_states[2].h, FLAGS.beam_size, axis=0) # Bm * [1 x D]. 483 | curr_write_states = [LSTMStateTuple(c, h) 484 | for c, h in zip(curr_write_states_c, curr_write_states_h)] # Bm * [(1 x D, 1 x D)]. 485 | 486 | if FLAGS.use_comp_lstm: 487 | # Current compose states. 488 | curr_comp_states_c = np.split(curr_states[3].c, FLAGS.beam_size, axis=0) # Bm * [1 x 2D]. 489 | curr_comp_states_h = np.split(curr_states[3].h, FLAGS.beam_size, axis=0) # Bm * [1 x 2D]. 490 | curr_comp_states = [LSTMStateTuple(c, h) 491 | for c, h in zip(curr_comp_states_c, curr_comp_states_h)] # Bm * [(1 x 2D, 1 x 2D)]. 492 | 493 | curr_states_list = [[memory, read_state, write_state, comp_state] 494 | for memory, read_state, write_state, comp_state in 495 | zip(curr_memories, curr_read_states, curr_write_states, curr_comp_states)] 496 | else: 497 | # Forming a list of internal states for Bm hypothesis. 498 | curr_states_list = [[memory, read_state, write_state] for memory, read_state, write_state 499 | in zip(curr_memories, curr_read_states, curr_write_states)] 500 | 501 | outputs[2] = curr_states_list 502 | 503 | # Generation probabilities. 504 | if FLAGS.use_pgen: 505 | p_gens = outputs[3] # Shape: 1 x [Bm x 1]. 506 | p_gens = p_gens[0] # Only one time-step in decoding phase, Shape: Bm x 1. 507 | p_gens = np.squeeze(p_gens, axis=1) # Shape: Bm x 508 | outputs[3] = np.split(p_gens, FLAGS.beam_size, axis=0) # Shape: Bm * [1] 509 | 510 | # Attention probabilities. 511 | if FLAGS.use_pgen: 512 | p_attns = outputs[4] # Shape: 1 x [Bm x mem_size]. 513 | p_attns = p_attns[0] # Only one time-step in the decoding phase, Shape: Bm x mem_size. 514 | outputs[4] = np.split(p_attns, FLAGS.beam_size, axis=0) # Shape: Bm * [1 x mem_size]. 515 | 516 | return outputs 517 | 518 | def _get_vocab_dist(self, inputs): 519 | """ 520 | This function passes the NSE hidden states obtained from decoder through the output layer (dense layer 521 | followed by a softmax layer) and returns the vocabulary distribution thus obtained. 522 | :param inputs: List of hidden states, Shape: T * [B x D]. 523 | :return: p_vocab: List of vocabulary distributions for each time step, Shape: T * [B x vsize]. 524 | """ 525 | steps = len(inputs) 526 | vsize = self.vocab.size() 527 | 528 | inputs = tf.concat(inputs, axis=0) # Shape: (T*B) x D 529 | 530 | with tf.variable_scope('output', reuse=tf.AUTO_REUSE): 531 | scores = tf.layers.dense(inputs, 532 | units=vsize, 533 | activation=None, 534 | kernel_initializer=self._dense_init, 535 | name='output') # Shape: (T*B) x vsize 536 | 537 | p_vocab = tf.nn.softmax(scores, name='prob_scores') # Shape: (T*B) x vsize 538 | 539 | p_vocab = tf.split(p_vocab, num_or_size_splits=steps, axis=0) # Shape: T * [B x vsize] 540 | 541 | return p_vocab 542 | 543 | def _get_cumulative_dist(self, p_vocabs, p_gens, p_attns, enc_in_ext_vocab): 544 | """ 545 | This function calculates the cumulative probability distribution from the vocabulary distribution, 546 | attention distribution and generation probabilities. 547 | :param p_vocabs: A list of vocabulary distributions for each time step, Shape: T_out * [B x vsize]. 548 | :param p_gens: A list of generation probabilities for each time step, Shape: T_out * [B x 1]. 549 | :param p_attns: A list of attention distributions for each time step, Shape: T_out * [B x T_in]. 550 | :param enc_in_ext_vocab: Encoder input represented using extended vocabulary, Shape: B x T_in 551 | :return: 552 | """ 553 | vsize = self.vocab.size() 554 | ext_vsize = vsize + self._max_oov_size 555 | batch_size, enc_steps = enc_in_ext_vocab.get_shape().as_list() 556 | 557 | p_vocabs = [p_gen * p_vocab for p_gen, p_vocab in zip(p_gens, p_vocabs)] # Shape: T_out * [B x vsize] 558 | p_attns = [(1.0 - p_gen) * p_attn for p_gen, p_attn in zip(p_gens, p_attns)] # Shape: T_out * [B x T_in] 559 | 560 | zero_vocab = tf.zeros(shape=[batch_size, self._max_oov_size]) 561 | # Shape: T_out * [B x ext_vsize] 562 | p_vocabs_ext = [tf.concat([p_vocab, zero_vocab], axis=-1) for p_vocab in p_vocabs] 563 | 564 | idx = tf.range(0, limit=batch_size) # Shape: B x 565 | idx = tf.expand_dims(idx, axis=-1) # Shape: B x 1 566 | idx = tf.tile(idx, [1, enc_steps]) # Shape: B x T_in 567 | indices = tf.stack([idx, enc_in_ext_vocab], axis=-1) # Shape: B x T_in x 2 568 | 569 | # First, A zero matrix of shape B x ext_vsize is created. Then, the attention score for each input token 570 | # indexed as in indices is looked in p_attn and updated in the corresponding location, 571 | # Shape: T_out * [B x ext_vsize] 572 | p_attn_cum = [tf.scatter_nd(indices, p_attn, [batch_size, ext_vsize]) for p_attn in p_attns] 573 | 574 | # Cumulative distribution, Shape: T_out * [B x ext_vsize] 575 | p_cum = [gen_prob + copy_prob for gen_prob, copy_prob in zip(p_vocabs_ext, p_attn_cum)] 576 | 577 | return p_cum 578 | 579 | @ staticmethod 580 | def _get_crossentropy_loss(probs, labels, mask): 581 | """ 582 | This function calculates the crossentropy loss from ground-truth and the probability scores. 583 | :param probs: Predicted probabilities, Shape: T * [B x vocab_size] 584 | :param labels: Ground Truth labels, Shape: B x T. 585 | :param mask: Mask to exclude PAD tokens while calculating loss, Shape: B x T. 586 | :return: average mini-batch loss. 587 | """ 588 | if type(probs) is list: 589 | probs = tf.stack(probs, axis=1) # Shape: B x T x vsize. 590 | 591 | true_probs = tf.reduce_sum( 592 | probs * tf.one_hot(labels, depth=tf.shape(probs)[2]), axis=2 593 | ) # Shape: B x T. 594 | logprobs = -tf.log(tf.clip_by_value(true_probs, 1e-10, 1.0)) # Shape: B x T. 595 | xe_loss = tf.reduce_sum(logprobs * mask) / tf.reduce_sum(mask) 596 | 597 | return xe_loss 598 | 599 | def _restore(self): 600 | """ 601 | This function restores parameters from a saved checkpoint. 602 | :return: 603 | """ 604 | restore_path = FLAGS.PathToCheckpoint + "model_epoch10" 605 | 606 | if FLAGS.restore_checkpoint and tf.train.checkpoint_exists(restore_path): 607 | start = time.time() 608 | 609 | # Initializing all variables. 610 | print("Initializing all variables!\n") 611 | self._sess.run(self._init) 612 | 613 | # Restoring checkpoint variables. 614 | vars_restore = [v for v in self._global_vars if "Adam" not in v.name] 615 | restore_saver = tf.train.Saver(vars_restore) 616 | print("Restoring non-Adam parameters from a previous checkpoint.\n") 617 | restore_saver.restore(self._sess, restore_path) 618 | 619 | end = time.time() 620 | print("Restoring model took %.2f sec. \n" % (end - start)) 621 | else: 622 | start = time.time() 623 | self._init.run(session=self._sess) 624 | end = time.time() 625 | print("Running initializer took %.2f time. \n" % (end - start)) 626 | 627 | def train(self): 628 | """ 629 | This function performs the training followed by validation after every few epochs and saves the best model 630 | if validation loss is decreased. It also writes the training summaries for visualization using tensorboard. 631 | :return: 632 | """ 633 | # Restore from previous checkpoint if found one. 634 | self._restore() 635 | 636 | # No. of iterations 637 | num_train_iters_per_epoch = int(math.ceil(self.data.num_train_examples / FLAGS.batch_size)) 638 | 639 | # Running Averages 640 | running_avg_crossentropy_loss = 0.0 641 | running_avg_rouge = 0.0 642 | for epoch in range(1, 1 + FLAGS.num_epochs): 643 | # Training 644 | for iteration in range(1, 1 + num_train_iters_per_epoch): 645 | feed_dict = self._get_feed_dict(split='train') 646 | 647 | to_return = [self._train_op, self._crossentropy_loss, self._global_step, self._summaries] 648 | if FLAGS.rouge_summary: 649 | to_return.append(self._samples) 650 | 651 | # Evaluate summaries for last batch. 652 | feed_dict[self._mean_crossentropy_loss] = running_avg_crossentropy_loss 653 | if FLAGS.rouge_summary: 654 | feed_dict[self._mean_rouge_score] = running_avg_rouge 655 | 656 | outputs = self._sess.run(to_return, feed_dict=feed_dict) 657 | _, crossentropy_loss, global_step = outputs[: 3] 658 | 659 | # Calculating ROUGE score. 660 | rouge_score = 0.0 661 | if FLAGS.rouge_summary: 662 | samples = outputs[4] 663 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x . 664 | gt_labels = feed_dict[self._dec_out] 665 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False) 666 | 667 | # Updating the running average losses. 668 | running_avg_crossentropy_loss = get_running_avg_loss( 669 | crossentropy_loss, running_avg_crossentropy_loss 670 | ) 671 | running_avg_rouge = get_running_avg_loss( 672 | np.mean(rouge_score), running_avg_rouge 673 | ) 674 | 675 | if ((iteration - 2) % FLAGS.summary_every == 0) or (iteration == num_train_iters_per_epoch): 676 | train_summary = outputs[3] 677 | self._train_writer.add_summary(train_summary, global_step) 678 | 679 | print("\rTraining Iteration: {}/{} ({:.1f}%)".format( 680 | iteration, num_train_iters_per_epoch, iteration * 100 / num_train_iters_per_epoch, 681 | )) 682 | 683 | if epoch % FLAGS.val_every == 0: 684 | self._validate() 685 | 686 | def _validate(self): 687 | """ 688 | This function validates the saved model. 689 | :param epoch: Used while writing the validation summary. 690 | :return: N/A. 691 | """ 692 | # Validation 693 | num_val_iters_per_epoch = int(math.ceil(self.data.num_val_examples / FLAGS.val_batch_size)) 694 | num_train_iters_per_epoch = int(math.ceil(self.data.num_train_examples / FLAGS.batch_size)) 695 | 696 | outputs = None 697 | total_crossentropy_loss = 0.0 698 | total_rouge_score = 0.0 699 | for iteration in range(1, 1 + num_val_iters_per_epoch): 700 | feed_dict = self._get_feed_dict(split='val') 701 | to_return = [self._crossentropy_loss] 702 | if FLAGS.rouge_summary: 703 | to_return.append(self._samples) 704 | 705 | if iteration == num_val_iters_per_epoch: 706 | feed_dict[self._mean_crossentropy_loss] = total_crossentropy_loss / (num_val_iters_per_epoch - 1) 707 | feed_dict[self._mean_rouge_score] = total_rouge_score / (num_val_iters_per_epoch - 1) 708 | to_return += [self._val_summaries, self._global_step] 709 | 710 | outputs = self._sess.run(to_return, feed_dict=feed_dict) 711 | crossentropy_loss = outputs[0] 712 | 713 | # Calculating ROUGE score. 714 | rouge_score = 0.0 715 | if FLAGS.rouge_summary: 716 | samples = outputs[1] 717 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x . 718 | gt_labels = feed_dict[self._dec_out] 719 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False) 720 | 721 | print("\rValidation Iteration: {}/{} ({:.1f}%)".format( 722 | iteration, num_val_iters_per_epoch, iteration * 100 / num_val_iters_per_epoch, 723 | )) 724 | 725 | # Updating the total losses. 726 | total_crossentropy_loss += crossentropy_loss 727 | total_rouge_score += np.mean(rouge_score) 728 | 729 | # Writing the validation summaries for visualization. 730 | val_summary, global_step = outputs[-2], outputs[-1] 731 | self._val_writer.add_summary(val_summary, global_step) 732 | epoch = global_step // num_train_iters_per_epoch 733 | 734 | # Cumulative loss. 735 | avg_xe_loss = total_crossentropy_loss / num_val_iters_per_epoch 736 | if avg_xe_loss < self._best_val_loss: 737 | print("\rValidation loss improved :).") 738 | self._saver.save(self._sess, FLAGS.PathToCheckpoint + "model_epoch" + str(epoch)) 739 | 740 | def test(self): 741 | """ 742 | This function predicts the outputs for the test set. 743 | :return: 744 | """ 745 | num_gpus = len(FLAGS.GPUs) # No. of GPUs. 746 | batch_size = num_gpus * FLAGS.num_pools # No. of examples processed per a single test iteration. 747 | 748 | # No. of iterations 749 | num_test_iters_per_epoch = int(math.ceil(self.data.num_test_examples / batch_size)) 750 | 751 | for iteration in range(1, 1 + num_test_iters_per_epoch): 752 | start = time.time() 753 | # [indices, summaries, files[start: end], enc_inp, enc_padding_mask, 754 | # enc_inp_ext_vocab, ext_vocabs, max_oov_size] 755 | input_batches = self.data.get_batch(batch_size, "test") 756 | 757 | # Split the data into batches of size "num_pools" per GPU. 758 | input_batches_per_gpu = [[] for _ in range(num_gpus)] 759 | for i, input_batch in enumerate(input_batches): 760 | for j, idx in zip(range(num_gpus), range(0, batch_size, FLAGS.num_pools)): 761 | if i < 5 or (FLAGS.use_pgen and i < 7): # First 5 inputs (7 inputs in pgen mode). 762 | input_batches_per_gpu[j].append( 763 | input_batch[idx: idx + FLAGS.num_pools] 764 | ) 765 | else: # max_oov_size is the same for all examples. 766 | input_batches_per_gpu[j].append(input_batch) 767 | 768 | # Appending the GPU id. 769 | for gpu_id in range(num_gpus): 770 | input_batches_per_gpu[gpu_id].append(gpu_id) 771 | 772 | Parallel(n_jobs=num_gpus, backend="threading")( 773 | map(delayed(self._test_one_gpu), input_batches_per_gpu)) 774 | 775 | end = time.time() 776 | print("\rTesting Iteration: {}/{} ({:.1f}%) in {:.1f} sec.".format( 777 | iteration, num_test_iters_per_epoch, iteration * 100 / num_test_iters_per_epoch, end - start)) 778 | 779 | def _test_one_gpu(self, inputs): 780 | """ 781 | This function performs testing on the inputs of a single GPU. 782 | :param inputs: 783 | :return: 784 | """ 785 | input_batches = inputs[: -1] 786 | gpu_idx = inputs[-1] 787 | 788 | with tf.device('/gpu:%d' % gpu_idx): 789 | self._test_one_pool(input_batches) 790 | 791 | def _test_one_pool(self, inputs): 792 | """ 793 | This function performs testing on the inputs of a single GPU over parallel pools. 794 | :param inputs: All the first 7 inputs have "num_pools" examples each. 795 | [indices, summaries, files[start: end], enc_inp, enc_padding_mask, 796 | enc_inp_ext_vocab, ext_vocabs, max_oov_size] 797 | :return: 798 | """ 799 | # Split the inputs into each example per pool. 800 | input_batches_per_pool = [[] for _ in range(FLAGS.num_pools)] 801 | for i, input_batch in enumerate(inputs): 802 | for j in range(FLAGS.num_pools): 803 | # Splitting the GT summaries. 804 | if i == 0 or i == 1 or i == 2 or (FLAGS.use_pgen and i == 6): 805 | input_batches_per_pool[j].append([input_batch[j]]) 806 | 807 | # Repeating the same Numpy array for beam size no. of times. 808 | elif i == 3 or i == 4 or (FLAGS.use_pgen and i == 5): 809 | input_batches_per_pool[j].append( 810 | np.repeat(input_batch[np.newaxis, j], repeats=FLAGS.beam_size, axis=0)) 811 | 812 | # max_oov_size is same for all examples in the pool. 813 | else: 814 | input_batches_per_pool[j].append(input_batch) 815 | 816 | Parallel(n_jobs=FLAGS.num_pools, backend="threading")( 817 | map(delayed(self._test_one_ex), input_batches_per_pool)) 818 | 819 | def _test_one_ex(self, inputs): 820 | """ 821 | This function performs testing on one example. 822 | :param inputs: 823 | [indices, summaries, files[start: end], enc_inp, enc_padding_mask, 824 | enc_inp_ext_vocab, ext_vocabs, max_oov_size] 825 | :return: 826 | """ 827 | iteration = inputs[0][0] 828 | input_batches = inputs[1:] 829 | # Inputs for running beam search on one example. 830 | beam_search_inputs = input_batches[2: 4] 831 | if FLAGS.use_pgen: 832 | beam_search_inputs += [input_batches[4], input_batches[6]] 833 | 834 | best_hyp = run_beam_search(beam_search_inputs, self.run_encoder, self.decode_one_step, self.vocab) 835 | pred_ids = best_hyp.tokens # Shape: T_dec * [1 x] 836 | pred_ids = np.stack(pred_ids, axis=0) # Shape: T_dec x 837 | pred_ids = pred_ids[np.newaxis, :] # Shape: 1 x T_dec 838 | 839 | # Writing the outputs. 840 | if FLAGS.use_pgen: 841 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1], input_batches[5]) 842 | else: 843 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1]) 844 | 845 | def _get_feed_dict(self, split='train'): 846 | 847 | if split == "train": 848 | input_batches = self.data.get_batch(FLAGS.batch_size, split="train") 849 | 850 | elif split == "val": 851 | input_batches = self.data.get_batch(FLAGS.val_batch_size, split="val") 852 | 853 | else: 854 | raise ValueError("Split should be either train/val!! \n") 855 | 856 | feed_dict = { 857 | self._enc_in: input_batches[0], 858 | self._enc_pad_mask: input_batches[1], 859 | self._dec_in: input_batches[2], 860 | self._dec_out: input_batches[3], 861 | self._dec_pad_mask: input_batches[4] 862 | } 863 | 864 | if FLAGS.use_pgen: 865 | feed_dict[self._enc_in_ext_vocab] = input_batches[5] 866 | feed_dict[self._max_oov_size] = input_batches[6] 867 | 868 | return feed_dict 869 | 870 | @staticmethod 871 | def make_html_safe(s): 872 | """ 873 | Replace any angled brackets in string s to avoid interfering with HTML attention visualizer. 874 | :param s: Input string. 875 | :return: 876 | """ 877 | s.replace("<", "<") 878 | s.replace(">", ">") 879 | 880 | return s 881 | 882 | def _write_file(self, pred, pred_name, gt, gt_name, ext_vocab=None): 883 | """ 884 | This function writes tokens in vals to a .txt file with given name. 885 | :param pred: Predicted ID's. Shape: 1 x T 886 | :param pred_name: Name of the file in which predictions will be written. 887 | :param gt: Ground truth summary, a list of sentences (strings). 888 | :param gt_name: Name of the file in which GTs will be written. 889 | :param ext_vocab: Extended vocabulary for each example. [ext_words x] 890 | :return: _pred, _gt files will be created for ROUGE evaluation. 891 | """ 892 | # Writing predictions. 893 | vsize = self.vocab.size() 894 | 895 | # Removing the [START] token. 896 | pred = pred[1:] 897 | 898 | # Converting words to ID's 899 | pred_words = [] 900 | for t in pred: 901 | try: 902 | pred_words.append(self.vocab.id2word(t)) 903 | except ValueError: 904 | pred_words.append(ext_vocab[t - vsize]) 905 | 906 | # Considering tokens only till STOP token. 907 | try: 908 | stop_idx = pred_words.index(params.STOP_DECODING) 909 | except ValueError: 910 | stop_idx = len(pred_words) 911 | pred_words = pred_words[: stop_idx] 912 | 913 | # Creating sentences out of the predicted sequence. 914 | pred_sents = [] 915 | while pred_words: 916 | try: 917 | period_idx = pred_words.index(".") 918 | except ValueError: 919 | period_idx = len(pred_words) 920 | 921 | # Append the sentence. 922 | sent = pred_words[: period_idx + 1] 923 | pred_sents.append(" ".join(sent)) 924 | 925 | # Consider the remaining words now. 926 | pred_words = pred_words[period_idx + 1:] 927 | 928 | # Making HTML safe. 929 | pred_sents = [self.make_html_safe(s) for s in pred_sents] 930 | gt_sents = [self.make_html_safe(s) for s in gt] 931 | 932 | # Writing predicted sentences. 933 | f = open(pred_name, 'w', encoding='utf-8') 934 | for i, sent in enumerate(pred_sents): 935 | f.write(sent) if i == len(pred_sents) - 1 else f.write(sent + "\n") 936 | f.close() 937 | 938 | # Writing GT sentences. 939 | f = open(gt_name, 'w', encoding='utf-8') 940 | for i, sent in enumerate(gt_sents): 941 | f.write(sent) if i == len(gt_sents) - 1 else f.write(sent + "\n") 942 | f.close() 943 | 944 | def _write_outputs(self, index, preds, gts, files, ext_vocabs=None): 945 | """ 946 | This function writes the input files 947 | :param index: Number of the test example. 948 | :param preds: The predictions. Shape: B x T 949 | :param: gts: The ground truths. Shape: B x T 950 | :param files: The names of the files. 951 | :param ext_vocabs: Extended vocabularies for each example, Shape: B * [ext_words x] 952 | :return: Saves the predictions and GT's in a .txt format. 953 | """ 954 | for i in range(len(files)): 955 | file, pred, gt = files[i], preds[i], gts[i] 956 | name_pred = FLAGS.PathToResults + 'predictions/' + '%06d_pred.txt' % index 957 | name_gt = FLAGS.PathToResults + 'groundtruths/' + '%06d_gt.txt' % index 958 | 959 | ext_vocab = None 960 | if FLAGS.use_pgen: 961 | ext_vocab = ext_vocabs[i] 962 | 963 | self._write_file(pred, name_pred, gt, name_gt, ext_vocab) 964 | 965 | def _single_model(self): 966 | placeholders = self._create_placeholders() 967 | self._crossentropy_loss, self._samples = self._forward(placeholders) 968 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer 969 | self._train_op = sgd_solver.minimize( 970 | self._crossentropy_loss, global_step=self._global_step, name='train_op') 971 | 972 | self._saver = tf.train.Saver(tf.global_variables()) # Saver. 973 | self._init = tf.global_variables_initializer() # Initializer. 974 | self._sess = tf.Session(config=self._config) # Session. 975 | 976 | # Train and validation summary writers. 977 | if FLAGS.mode == 'train': 978 | self._create_writers() 979 | 980 | self._global_vars = tf.global_variables() 981 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables()))) 982 | # print(tf.trainable_variables()) 983 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()]) 984 | # print('No. of params = {:d}'.format(int(no_params))) 985 | 986 | def _parallel_model(self): 987 | with tf.Graph().as_default(), tf.device('/cpu:0'): 988 | placeholders = self._create_placeholders() 989 | 990 | # Splitting the placeholders for each GPU. 991 | placeholders_per_gpu = [[] for _ in range(self._num_gpus)] 992 | for placeholder in placeholders: 993 | splits = tf.split(placeholder, num_or_size_splits=self._num_gpus, axis=0) 994 | for i, split in enumerate(splits): 995 | placeholders_per_gpu[i].append(split) 996 | 997 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer 998 | tower_grads = [] # Gradients calculated in each tower 999 | with tf.variable_scope(tf.get_variable_scope()): 1000 | all_losses = [] 1001 | all_samples = [] 1002 | for i in range(self._num_gpus): 1003 | with tf.device('/gpu:%d' % i): 1004 | with tf.name_scope('tower_%d' % i): 1005 | crossentropy_loss, samples = self._forward(placeholders_per_gpu[i]) 1006 | 1007 | # Gathering samples from all GPUs. 1008 | all_samples.append(samples) 1009 | 1010 | tf.get_variable_scope().reuse_variables() 1011 | 1012 | grads = sgd_solver.compute_gradients(loss_to_minimize) 1013 | tower_grads.append(grads) 1014 | 1015 | # Updating the losses 1016 | all_losses.append(crossentropy_loss) 1017 | 1018 | self._crossentropy_loss = tf.add_n(all_losses) / self._num_gpus 1019 | 1020 | # Samples from all the GPUs. 1021 | if FLAGS.rouge_summary: 1022 | self._samples = tf.concat(all_samples, axis=0) 1023 | else: 1024 | self._samples = [] 1025 | 1026 | # Synchronization Point 1027 | gradients, variables = zip(*average_gradients(tower_grads)) 1028 | gradients = [tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad) for grad in gradients] 1029 | gradients, global_norm = tf.clip_by_global_norm(gradients, FLAGS.max_grad_norm) 1030 | 1031 | # Summary for the global norm 1032 | tf.summary.scalar('global_norm', global_norm) 1033 | 1034 | # Histograms for gradients. 1035 | for grad, var in zip(gradients, variables): 1036 | if grad is not None: 1037 | tf.summary.histogram(var.op.name + '/gradients', grad) 1038 | 1039 | # Histograms for variables. 1040 | for var in tf.trainable_variables(): 1041 | tf.summary.histogram(var.op.name, var) 1042 | 1043 | self._train_op = sgd_solver.apply_gradients( 1044 | zip(gradients, variables), global_step=self._global_step) 1045 | 1046 | self._saver = tf.train.Saver(tf.global_variables()) # Saver 1047 | self._init = tf.global_variables_initializer() # Initializer. 1048 | self._sess = tf.Session(config=self._config) # Session. 1049 | 1050 | # Train and validation summary writers. 1051 | if FLAGS.mode == 'train': 1052 | self._create_writers() 1053 | 1054 | self._global_vars = tf.global_variables() 1055 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables()))) 1056 | # print(tf.trainable_variables()) 1057 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()]) 1058 | # print('No. of params = {:d}'.format(int(no_params))) 1059 | -------------------------------------------------------------------------------- /codes/model_hier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Rajeev Bhatt Ambati" 3 | 4 | from HierNSE import HierNSE 5 | from utils import average_gradients, get_running_avg_loss 6 | import params as params 7 | from beam_search import run_beam_search_hier 8 | from rouge_batch import rouge_l_fscore_batch as rouge_l_fscore 9 | 10 | import random 11 | import math 12 | import time 13 | 14 | import tensorflow as tf 15 | from tensorflow.nn.rnn_cell import LSTMStateTuple 16 | import numpy as np 17 | from joblib import Parallel, delayed 18 | 19 | tf.set_random_seed(2019) 20 | random.seed(2019) 21 | tf.reset_default_graph() 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | 26 | class SummarizationModelHier(object): 27 | def __init__(self, vocab, data): 28 | self._vocab = vocab 29 | self._data = data 30 | 31 | self._dense_init = tf.contrib.layers.xavier_initializer() 32 | 33 | self._config = tf.ConfigProto() 34 | self._config.gpu_options.allow_growth = True 35 | 36 | self._best_val_loss = np.infty 37 | self._num_gpus = len(FLAGS.GPUs) 38 | 39 | self._sess = None 40 | self._saver = None 41 | self._init = None 42 | 43 | def _create_placeholders(self): 44 | """ 45 | This function creates the placeholders needed for the computation graph. 46 | [enc_in, enc_pad_mask, enc_doc_mask, dec_in, enc_in_ext_vocab, labels, dec_pad_mask] 47 | :return: 48 | """ 49 | self._global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') 50 | self._learning_rate = tf.placeholder(tf.float32, name='learning_rate') 51 | 52 | # Word embedding. 53 | if FLAGS.use_pretrained: 54 | with tf.variable_scope("embed", reuse=tf.AUTO_REUSE): 55 | self._word_embedding = tf.get_variable(name="word_embedding", 56 | shape=[self._vocab.size(), FLAGS.dim], 57 | initializer=tf.constant_initializer(self._vocab.wvecs), 58 | dtype=tf.float32, 59 | trainable=True) 60 | 61 | else: 62 | self._word_embedding = tf.get_variable(name="word_embedding", 63 | shape=[self._vocab.size(), FLAGS.dim], 64 | dtype=tf.float32, 65 | trainable=True) 66 | 67 | # Graph Inputs/Outputs. 68 | self._enc_in = tf.placeholder(dtype=tf.int32, name='enc_in', 69 | shape=[FLAGS.batch_size, FLAGS.max_enc_sent, 70 | FLAGS.max_enc_steps_per_sent]) # Shape: B x S_in x T_in. 71 | self._enc_pad_mask = tf.placeholder(dtype=tf.float32, name='enc_pad_mask', 72 | shape=[FLAGS.batch_size, FLAGS.max_enc_sent, 73 | FLAGS.max_enc_steps_per_sent]) # Shape: B x S_in x T_in. 74 | self._enc_doc_mask = tf.placeholder(dtype=tf.float32, name='enc_doc_mask', 75 | shape=[FLAGS.batch_size, FLAGS.max_enc_sent]) # Shape: B x S_in. 76 | self._dec_in = tf.placeholder(dtype=tf.int32, name='dec_in', 77 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec. 78 | 79 | inputs = [self._enc_in, self._enc_pad_mask, self._enc_doc_mask, self._dec_in] 80 | 81 | # Additional inputs in pointer-generator mode. 82 | if FLAGS.use_pgen: 83 | self._enc_in_ext_vocab = tf.placeholder(dtype=tf.int32, 84 | name='enc_inp_ext_vocab', 85 | shape=[FLAGS.batch_size, 86 | FLAGS.max_enc_sent * 87 | FLAGS.max_enc_steps_per_sent]) # Shape: B x T_enc. 88 | self._max_oov_size = tf.placeholder(dtype=tf.int32, name='max_oov_size') 89 | inputs.append(self._enc_in_ext_vocab) 90 | 91 | # Additional ground-truth's when in training. 92 | if FLAGS.mode.lower() == "train": 93 | self._dec_out = tf.placeholder(dtype=tf.int32, name='dec_out', 94 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec. 95 | self._dec_pad_mask = tf.placeholder(dtype=tf.float32, name='dec_pad_mask', 96 | shape=[FLAGS.batch_size, FLAGS.dec_steps]) # Shape: B x T_dec. 97 | inputs += [self._dec_out, self._dec_pad_mask] 98 | 99 | return inputs 100 | 101 | def _create_writers(self): 102 | """ 103 | This function creates the summaries and writers needed for visualization through tensorboard. 104 | :return: writers. 105 | """ 106 | self._mean_crossentropy_loss = tf.placeholder(dtype=tf.float32, name='mean_crossentropy_loss') 107 | self._mean_rouge_score = tf.placeholder(dtype=tf.float32, name="mean_rouge") 108 | 109 | # Summaries. 110 | cross_entropy_summary = tf.summary.scalar('cross_entropy', self._mean_crossentropy_loss) 111 | rouge_summary = None 112 | if FLAGS.rouge_summary: 113 | rouge_summary = tf.summary.scalar('rouge_score', self._mean_rouge_score) 114 | self._summaries = tf.summary.merge_all() 115 | 116 | summary_list = [cross_entropy_summary] 117 | if FLAGS.rouge_summary: 118 | summary_list.append(rouge_summary) 119 | 120 | self._val_summaries = tf.summary.merge(summary_list, name="validation_summaries") 121 | 122 | # Summary writers. 123 | self._train_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'train') # , self._sess.graph) 124 | self._val_writer = tf.summary.FileWriter(FLAGS.PathToTB + 'val') 125 | 126 | def build_graph(self): 127 | start = time.time() 128 | if FLAGS.mode == 'train': 129 | if len(FLAGS.GPUs) > 1: 130 | self._parallel_model() # Parallel model in case of multiple-GPUs. 131 | else: 132 | self._single_model() # Single model for a single GPU/CPU. 133 | 134 | if FLAGS.mode == 'test': 135 | inputs = self._create_placeholders() # [enc_in, dec_in, enc_in_ext_vocab] 136 | self._forward(inputs) # Predictions Shape: Bm x 1 137 | 138 | self._saver = tf.train.Saver(tf.global_variables()) # Saver. 139 | self._init = tf.global_variables_initializer() # Initializer. 140 | self._sess = tf.Session(config=self._config) # Session. 141 | 142 | # Restoring the trained model. 143 | self._saver.restore(self._sess, FLAGS.PathToCheckpoint) 144 | 145 | end = time.time() 146 | print("build_graph took %.2f sec. \n" % (end - start)) 147 | 148 | def _forward(self, inputs): 149 | """ 150 | This function creates the TensorFlow computation graph. 151 | :param inputs: A list of input placeholders. 152 | [0] enc_in: Encoder input sequence of ID's, Shape: B x S_in x T_in. 153 | [1] enc_pad_mask: Encoder input mask to indicate the presence of PAd tokens, Shape: B x S_in x T_in. 154 | [2] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in. 155 | [3] dec_in: Decoder input sequence of ID's, Shape: B x T_dec. 156 | 157 | Following additional input in pointer-generator mode. 158 | [4] enc_in_ext_vocab: Encoder input representation in the extended vocabulary, Shape: B x T_enc. 159 | 160 | [5] labels: Ground-Truth labels, Only in train mode, Shape: B x T_dec. 161 | [6] dec_pad_mask: Decoder output mask, Shape: B x T_dec. 162 | :return: returns loss in train mode and predictions in test mode. 163 | """ 164 | batch_size = inputs[0].get_shape().as_list()[0] # Batch-size 165 | # The NSE instance 166 | self._nse = HierNSE(batch_size=batch_size, dim=FLAGS.dim, dense_init=self._dense_init, 167 | mode=FLAGS.mode, use_comp_lstm=FLAGS.use_comp_lstm, num_layers=FLAGS.num_layers) 168 | 169 | # Encoder, used while testing phase. 170 | self._prev_states = self._encoder(inputs[: 3]) 171 | 172 | # Decoder. 173 | outputs = self._decoder(inputs[1: 5] + [inputs[-1]], self._prev_states) 174 | 175 | if FLAGS.mode.lower() == "test": 176 | self._topk_ids, self._topk_log_probs, self._curr_states, self._p_attns = outputs[: 4] 177 | 178 | if FLAGS.use_pgen: 179 | self._p_gens = outputs[4] 180 | 181 | else: 182 | probs, p_attns = outputs 183 | 184 | crossentropy_loss = self._get_crossentropy_loss(probs=probs, labels=inputs[-2], mask=inputs[-1]) 185 | 186 | # Evaluating a few samples. 187 | samples = [] 188 | if FLAGS.rouge_summary: 189 | samples = self._get_samples(inputs[1: 5], self._prev_states) 190 | 191 | return crossentropy_loss, samples 192 | 193 | def _encoder(self, inputs): 194 | """ 195 | This is the encoder. 196 | :param inputs: A list of the following inputs. 197 | enc_in: Encoder input sequence of ID's, Shape: B x S_in x T_in. 198 | enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in. 199 | enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in. 200 | :return: 201 | A list of internal states of NSE after the last encoding step. 202 | [0] memory: [sent_mems, doc_mem] The sentence and document memories respectively, 203 | Shape: [B x S_in x T_in x D, B x S_in x D]. 204 | [1] read_state: Hidden state of the read LSTM after the last encoding step, (c, h) Shape: 2 * [B x D]. 205 | [2] write_state: Hidden state of the write LSTM after the last encoding step, (c, h) Shape: 2 * [B x D]. 206 | [3] comp_state: Hidden state of the compose LSTM after the last encoding step, (c, h) Shape: 2 * [B x D]. 207 | """ 208 | with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE): 209 | enc_in, enc_pad_mask, enc_doc_mask = inputs 210 | 211 | # Converting ID's to word-vectors. 212 | enc_in_vecs = tf.nn.embedding_lookup(params=self._word_embedding, ids=enc_in) # Shape: B x S_in x T_in x D. 213 | enc_in_vecs = tf.cast(enc_in_vecs, dtype=tf.float32) # Cast to float32. 214 | 215 | # Document memory. 216 | doc_mem = tf.reduce_mean(enc_in_vecs, axis=2, name="document_memory") # Shape: B x S_in x D. 217 | 218 | new_sent_mems = [] # New sentence memories. 219 | lstm_states = [None, None, None] # LSTM states. 220 | 221 | for i in range(FLAGS.max_enc_sent): 222 | # Mask, memory of ith sentence. 223 | sent_i_mask = enc_pad_mask[:, i, :] # Shape: B x T_in. 224 | sent_i_mem = enc_in_vecs[:, i, :, :] # Shape: B x T_in x D. 225 | 226 | mem_masks = [sent_i_mask, enc_doc_mask] # ith-sentence and document masks. 227 | state = [[sent_i_mem, doc_mem]] + lstm_states # NSE internal state. 228 | 229 | for j in range(FLAGS.max_enc_steps_per_sent): 230 | # j-th token from the ith sentence. 231 | x_t = enc_in_vecs[:, i, j, :] # Shape: B x D. 232 | output, state = self._nse.step(x_t=x_t, mem_masks=mem_masks, prev_state=state) 233 | 234 | # Update 235 | new_sent_mems.append(state[0][0]) # ith-sentence memory. 236 | doc_mem = state[0][1] # Document memory. 237 | lstm_states = state[1:] # Read, write, compose states. 238 | 239 | new_sent_mems = tf.concat(new_sent_mems, axis=1) # Shape: B x (S_in*T_in) x D. 240 | all_states = [[new_sent_mems, doc_mem]] + lstm_states 241 | 242 | return all_states 243 | 244 | def _decoder(self, inputs, all_states): 245 | """ 246 | This is the decoder. 247 | :param inputs: A list of the following inputs. 248 | [0] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in. 249 | [1] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in. 250 | [2] dec_in: Input to the decoder, Shape: B x T_dec. 251 | Following additional inputs in pointer generator mode: 252 | [3] enc_in_ext_vocab: (For pointer generator mode) 253 | Encoder input representation in the extended vocabulary, Shape: B x T_enc. 254 | [4] dec_pad_mask: Decoder mask to indicate the presence of PAD tokens, Shape: B x T_dec. 255 | 256 | :param all_states: The internal states of NSE after the last encoding step. 257 | [0] memory: [sent_mems, doc_mem] The sentence and document memories respectively, 258 | Shape: [B x T_enc x D, B x S_in x D]. 259 | [1] read_state: Hidden state of the read LSTM after the last encoding step,(c, h) Shape: 2 * [B x D]. 260 | [2] write_state: Hidden state of the write LSTM after the last encoding step,(c, h) Shape: 2 * [B x D]. 261 | [3] comp_state: Hidden state of the compose LSTM after the last encoding step,(c, h) Shape: 2 * [B x D]. 262 | :return: 263 | In test mode: 264 | [topk_ids, topk_log_probs, state, p_attns] + [p_gens] 265 | In train mode: 266 | [p_cums, p_attns] 267 | """ 268 | with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE): 269 | enc_pad_mask, enc_doc_mask, dec_in = inputs[: 3] 270 | 271 | # Concatenating the pad masks of all sentences. 272 | enc_pad_mask = tf.reshape( 273 | enc_pad_mask, [-1, FLAGS.max_enc_sent*FLAGS.max_enc_steps_per_sent]) # Shape: B x T_enc. 274 | 275 | # Converting ID's to word vectors. 276 | dec_in_vecs = tf.nn.embedding_lookup(params=self._word_embedding, ids=dec_in) # Shape: B x T_dec x D. 277 | dec_in_vecs = tf.cast(dec_in_vecs, dtype=tf.float32) # Shape: B x T_dec x D. 278 | 279 | # Memory masks. 280 | mem_masks = [enc_pad_mask, enc_doc_mask] # Shape: [B x T_enc, B x S_in]. 281 | 282 | # NSE internal states. 283 | sent_mems, doc_mem = all_states[0] # Shape: B x T_enc x D, B x S_in x D. 284 | state = [[sent_mems, doc_mem]] + all_states[1:] 285 | 286 | writes = [] 287 | p_attns = [] 288 | p_gens = [] 289 | for i in range(FLAGS.dec_steps): 290 | 291 | x_t = dec_in_vecs[:, i, :] # Shape: B x D. 292 | output, state = self._nse.step( 293 | x_t=x_t, mem_masks=mem_masks, prev_state=state, use_pgen=FLAGS.use_pgen 294 | ) 295 | 296 | # Appending the outputs. 297 | writes.append(output[0]) 298 | p_attns.append(output[1]) 299 | 300 | if FLAGS.use_pgen: 301 | p_gens.append(output[2]) 302 | 303 | p_vocabs = self._get_vocab_dist(writes) # Shape: T_dec * [B x vsize]. 304 | 305 | p_cums = p_vocabs 306 | if FLAGS.use_pgen: 307 | enc_in_ext_vocab = inputs[3] 308 | p_cums = self._get_cumulative_dist( 309 | p_vocabs, p_gens, p_attns, enc_in_ext_vocab 310 | ) # Shape: T_dec * [B x ext_vsize]. 311 | 312 | if FLAGS.mode.lower() == "test": 313 | p_final = p_cums[0] 314 | 315 | # The top k predictions, topk_ids, Shape: Bm * (2*Bm). 316 | # Respective probabilities, topk_probs, Shape: Bm * (2*Bm). 317 | topk_probs, topk_ids = tf.nn.top_k(p_final, k=2*FLAGS.beam_size, name="topk_preds") 318 | topk_log_probs = tf.log(tf.clip_by_value(topk_probs, 1e-10, 1.0)) 319 | 320 | outputs = [topk_ids, topk_log_probs, state, p_attns] 321 | if FLAGS.use_pgen: 322 | outputs.append(p_gens) 323 | 324 | return outputs 325 | else: 326 | return p_cums, p_attns 327 | 328 | def _get_samples(self, inputs, state): 329 | """ 330 | This function samples greedily from decoder output to calculate the ROUGE score. 331 | :param inputs: A list of the following inputs. 332 | [0] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens, Shape: B x S_in x T_in. 333 | [1] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: B x S_in. 334 | [2] dec_in: Input to the decoder, Shape: B x T_dec. 335 | Following additional inputs in pointer generator mode: 336 | [3] enc_in_ext_vocab: (For pointer generator mode) 337 | Encoder input representation in the extended vocabulary, Shape: B x T_enc. 338 | :param state: The internal states of NSE after the last encoding step. 339 | [0] memory: [sent_mems, doc_mem] The sentence and document memories respectively, 340 | Shape: [B x T_enc x D, B x S_in x D]. 341 | [1] read_state: Hidden state of the read LSTM after the last encoding step,(c, h) Shape: 2 * [B x D]. 342 | [2] write_state: Hidden state of the write LSTM after the last encoding step,(c, h) Shape: 2 * [B x D]. 343 | [3] comp_state: Hidden state of the compose LSTM after the last encoding step,(c, h) Shape: 2 * [B x D]. 344 | :return: 345 | probs: Probabilities used for sampling, Shape: B x T_dec x V. 346 | samples: The samples, Shape: B x T_dec. 347 | """ 348 | unk_id = self._vocab.word2id(params.UNKNOWN_TOKEN) # Unknown ID. 349 | start_id = self._vocab.word2id(params.START_DECODING) # Start ID. 350 | 351 | with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE): 352 | enc_pad_mask, enc_doc_mask = inputs[: 2] 353 | batch_size = enc_pad_mask.get_shape().as_list()[0] 354 | # Concatenating the pad masks of all sentences. 355 | enc_pad_mask = tf.reshape( 356 | enc_pad_mask, 357 | [-1, FLAGS.max_enc_sent * FLAGS.max_enc_steps_per_sent]) # Shape: B x T_enc. 358 | 359 | # Memory masks. 360 | mem_masks = [enc_pad_mask, enc_doc_mask] # Shape: [B x T_enc, B x S_in]. 361 | 362 | samples = [] # Sampling outputs. 363 | for i in range(FLAGS.dec_steps): 364 | 365 | if i == 0: 366 | id_t = tf.fill([batch_size], start_id) # Shape: B x . 367 | else: 368 | id_t = samples[-1][:, 0] 369 | # Replacing the ID's from external vocabulary (if any) to UNK id's. 370 | id_t = tf.where( 371 | tf.less(id_t, self._vocab.size()), id_t, unk_id * tf.ones_like(id_t) 372 | ) 373 | 374 | # Getting the word vector. 375 | x_t = tf.nn.embedding_lookup(params=self._word_embedding, ids=id_t) # Shape: B x D. 376 | x_t = tf.cast(x_t, dtype=tf.float32) 377 | 378 | output, state = self._nse.step( 379 | x_t=x_t, mem_masks=mem_masks, prev_state=state, use_pgen=FLAGS.use_pgen, 380 | ) 381 | 382 | # Output probability distribution. 383 | p_vocab = self._get_vocab_dist([output[0]]) # Shape: [B x vsize]. 384 | 385 | # Calculate cumulative probability distribution using pointer mechanism. 386 | p_cum = p_vocab 387 | if FLAGS.use_pgen: 388 | p_cum = self._get_cumulative_dist( 389 | p_vocabs=p_vocab, p_gens=[output[2]], p_attns=[output[1]], enc_in_ext_vocab=inputs[-1] 390 | ) # Shape: T_dec * [B x ext_vsize]. 391 | 392 | # Greedy sampling. 393 | _, gs_sample = tf.nn.top_k(p_cum[0], k=FLAGS.num_samples) # Shape: B x 1. 394 | samples.append(gs_sample) 395 | 396 | samples = tf.concat(samples, axis=1) # Shape: B x T_dec. 397 | 398 | return samples 399 | 400 | def run_encoder(self, inputs): 401 | """ 402 | This function calculates the internal states of NSE after the last step of encoding. 403 | :param inputs: 404 | enc_in_batch: A batch of encoder input sequence, Shape: Bm x S_in x T_in. 405 | enc_pad_mask: Encoder input mask, Shape: Bm x S_in x T_in. 406 | enc_doc_mask: Encoder document mask, Shape: Bm x S_in. 407 | :return: 408 | NSE internal states after encoding: 409 | final_memory: [sent_mems, doc_mem] The sentence and document memories respectively after last time-step 410 | of encoding. Shape: [1 x T_enc x D, 1 x S_in x D]. 411 | final_read_state: Hidden state of read LSTM after last time-step of encoding,(c, h) 412 | Shape: (1 x D, 1 x D). 413 | final_write_state: Hidden state of write LSTM after last time-step of encoding,(c, h) 414 | Shape: (1 x D, 1 x D). 415 | final_comp_state: Hidden state of compose LSTM after last time-step of encoding,(c, h), 416 | Shape: (1 x 3D, 1 x 3D). 417 | Attention distributions: sentence, document attention. 418 | T_enc * [1 x T_enc, 1 x S_in]. 419 | 420 | """ 421 | to_return = [self._prev_states[0][0], self._prev_states[0][1]] + self._prev_states[1:] 422 | 423 | outputs = self._sess.run(to_return, feed_dict={self._enc_in: inputs[0], 424 | self._enc_pad_mask: inputs[1], 425 | self._enc_doc_mask: inputs[2]}) 426 | final_memory_s, final_memory_d = outputs[: 2] 427 | 428 | # memory Shape: [Bm x T_enc x D, Bm x S_in x D]. 429 | # read state Shape: (Bm x D, Bm x D). 430 | # write state Shape: (Bm x D, Bm x D). 431 | final_read_state, final_write_state = outputs[2: 4] 432 | 433 | # Since the states repeated values, slicing only first one. 434 | final_memory_s = final_memory_s[0, np.newaxis, :, :] # Shape: 1 x T_enc x D. 435 | final_memory_d = final_memory_d[0, np.newaxis, :, :] # Shape: 1 x S_in x D. 436 | 437 | final_memory = [final_memory_s, final_memory_d] 438 | final_comp_state = None 439 | 440 | def get_state_slice(inp_states): 441 | if FLAGS.num_layers == 1: 442 | return LSTMStateTuple(inp_states.c[0, np.newaxis, :], 443 | inp_states.h[0, np.newaxis, :]) 444 | 445 | else: 446 | sliced_states = [] 447 | for _, layer_state in enumerate(inp_states): 448 | if FLAGS.rnn.lower() == "lstm": 449 | sliced_states.append( 450 | LSTMStateTuple(layer_state.c[0, np.newaxis, :], 451 | layer_state.h[0, np.newaxis, :]) 452 | ) 453 | else: 454 | sliced_states.append( 455 | layer_state[0, np.newaxis, :] 456 | ) 457 | 458 | return tuple(sliced_states) 459 | 460 | final_read_state = get_state_slice(final_read_state) # Shape: (1 x D, 1 x D). 461 | final_write_state = get_state_slice(final_write_state) # Shape: (1 x D, 1 x D). 462 | if FLAGS.use_comp_lstm: 463 | final_comp_state = get_state_slice(outputs[4]) # Shape: (1 x 3D, 1 x 3D). 464 | 465 | state = [final_memory, final_read_state, final_write_state] 466 | if FLAGS.use_comp_lstm: 467 | state.append(final_comp_state) 468 | 469 | return state 470 | 471 | def decode_one_step(self, inputs, prev_states): 472 | """ 473 | This function performs one step of decoding. 474 | :param inputs: 475 | [0] dec_in_batch: 476 | The input to the decoder. This is the output from previous time-step, Shape: Bm * [1 x] 477 | [1] enc_pad_mask: Encoder input mask to indicate the presence of PAD tokens. 478 | Useful to calculate the attention over memory, Shape: Bm x T_enc. 479 | [2] enc_doc_mask: Encoder document mask to indicate the presence of empty sentences, Shape: Bm x S_in. 480 | In pointer generator mode, there are following additional inputs: 481 | [3]: enc_in_ex_vocab_batch: 482 | Encoder input sequence represented in extended vocabulary, Shape: Bm x T_enc. 483 | [4]: max_oov_size: Size of the largest OOV tokens in the current batch, Shape: () 484 | :param prev_states: previous internal states of NSE of all Bm hypothesis, Bm * [prev_state]. 485 | where state is a list of internal states of NSE for a single hypothesis. 486 | prev_state = [prev_memory, prev_read_state, prev_write_state] 487 | prev_memory: [sent_mems, doc_mem] The sentence and document memories respectively after last 488 | previous time-step, Shape: [1 x T_enc x D, 1 x S_in x D]. 489 | prev_read_state: Hidden state of read LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D]. 490 | prev_write_state: Hidden state of write LSTM after previous time step, (c, h) Shape: [1 x D, 1 x D]. 491 | prev_comp_state: Hidden state of compose LSTM after previous time step, (c, h) Shape: [1 x 3D, 1 x 3D]. 492 | :return: 493 | topk_ids: Top-k predictions in the current step, Shape: Bm x (2*Bm) 494 | topk_log_probs: log probabilities of top-k predictions, Shape: Bm x (2*Bm) 495 | curr_states: Current internal states of NSE, Bm * [state]. 496 | [0]: memory: NSE sentence and document memories. [Bm x T_enc x D, Bm x S_in x D]. 497 | [1]: read_state, (c, h) Shape: (Bm x D, Bm x D). 498 | [2]: write_state, (c, h) Shape: (Bm x D, Bm x D). 499 | [3]: comp_state, (c, h) Shape: (Bm x 3D, Bm x 3D). 500 | p_gens: Generation probabilities, Shape: Bm x . 501 | p_attns: Attention probabilities, Shape: [Bm x T_in, Bm x S_in]. 502 | """ 503 | # Decoder input 504 | dec_in = inputs[0] # Shape: Bm * [1 x] 505 | dec_in = np.stack(dec_in, axis=0) # Shape: Bm x 506 | inputs[0] = np.expand_dims(dec_in, axis=-1) # Shape: Bm x 1 507 | 508 | # Previous memories of Bm hypothesis. 509 | # Sentence memories. 510 | prev_memories_s = [state[0][0] for state in prev_states] # Shape: Bm * [1 x T_enc x D]. 511 | prev_memories_s = np.concatenate(prev_memories_s, axis=0) # Shape: Bm x T_enc x D. 512 | 513 | # Document memory. 514 | prev_memories_d = [state[0][1] for state in prev_states] # Shape: Bm * [1 x S_in x D]. 515 | prev_memories_d = np.concatenate(prev_memories_d, axis=0) # Shape: Bm x S_in x D. 516 | 517 | prev_memories = [prev_memories_s, prev_memories_d] 518 | 519 | def get_combined_states(inp_states): 520 | """ 521 | A function to combine the states of Bm hypothesis. 522 | :param inp_states: List of states of Bm hypothesis. Bm * [(s1, s2, ..., s_l)] 523 | :return: 524 | """ 525 | if FLAGS.num_layers == 1: 526 | # Cell states. 527 | combined_states_c = [hyp_state.c for hyp_state in inp_states] # Shape: Bm * [1 x D]. 528 | combined_states_c = np.concatenate(combined_states_c, axis=0) # Shape: Bm x D. 529 | 530 | # Hidden states. 531 | combined_states_h = [hyp_state.h for hyp_state in inp_states] # Shape: Bm * [1 x D]. 532 | combined_states_h = np.concatenate(combined_states_h, axis=0) # Shape: Bm x D. 533 | 534 | combined_states = LSTMStateTuple(combined_states_c, 535 | combined_states_h) # Shape: (Bm x D, Bm x D). 536 | 537 | else: 538 | combined_states = [] 539 | for i in range(FLAGS.num_layers): 540 | # Cell states. 541 | layer_state_c = [layer_state[i].c for layer_state in inp_states] # Shape: Bm * [1 x D]. 542 | layer_state_c = np.concatenate(layer_state_c, axis=0) # Shape: Bm x D. 543 | 544 | # Hidden states. 545 | layer_state_h = [layer_state[i].h for layer_state in inp_states] # Shape: Bm * [1 x D]. 546 | layer_state_h = np.concatenate(layer_state_h, axis=0) # Shape: Bm x D. 547 | 548 | combined_states.append( 549 | LSTMStateTuple(layer_state_c, layer_state_h) 550 | ) 551 | 552 | combined_states = tuple(combined_states) 553 | 554 | return combined_states 555 | 556 | prev_read_states = get_combined_states([state[1] for state in prev_states]) 557 | prev_write_states = get_combined_states([state[2] for state in prev_states]) 558 | if FLAGS.use_comp_lstm: 559 | prev_comp_states = get_combined_states([state[3] for state in prev_states]) 560 | else: 561 | prev_comp_states = None 562 | 563 | feed_dict = { 564 | self._dec_in: inputs[0], 565 | self._enc_pad_mask: inputs[1], 566 | self._enc_doc_mask: inputs[2], 567 | self._prev_states[0][0]: prev_memories[0], 568 | self._prev_states[0][1]: prev_memories[1], 569 | self._prev_states[1]: prev_read_states, 570 | self._prev_states[2]: prev_write_states 571 | } 572 | 573 | if FLAGS.use_comp_lstm: 574 | feed_dict[self._prev_states[3]] = prev_comp_states 575 | 576 | if FLAGS.use_pgen: 577 | feed_dict[self._enc_in_ext_vocab] = inputs[3] 578 | feed_dict[self._max_oov_size] = inputs[4] 579 | 580 | to_return = [self._topk_ids, self._topk_log_probs, self._curr_states[0][0], self._curr_states[0][1], 581 | self._curr_states[1], self._curr_states[2]] 582 | 583 | if FLAGS.use_comp_lstm: 584 | to_return.append(self._curr_states[3]) 585 | 586 | if FLAGS.use_pgen: 587 | to_return += [self._p_gens, self._p_attns] 588 | 589 | outputs = self._sess.run(to_return, feed_dict=feed_dict) 590 | 591 | # Preparing the next values (inputs and states for next time step). 592 | next_values = outputs[: 2] 593 | 594 | # Current memories. 595 | curr_memories_s = np.split(outputs[2], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x T_enc x D]. 596 | curr_memories_d = np.split(outputs[3], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x S_in x D]. 597 | curr_memories = [[mem_s, mem_d] for mem_s, mem_d in 598 | zip(curr_memories_s, curr_memories_d)] # Shape: Bm * [[1 x T_enc x D, 1 x S_in x D]]. 599 | 600 | def get_states_split(inp_states): 601 | """ 602 | This function splits the states for Bm hypothesis. 603 | :param inp_states: 604 | if just one layer: 605 | A NumPy array of shape Bm x D. 606 | for multiple layers: 607 | A tuple of states of RNN layers (s1, s2, ...., s_l). 608 | :return: 609 | """ 610 | if FLAGS.num_layers == 1: 611 | split_states_c = np.split(inp_states.c, FLAGS.beam_size, axis=0) # Shape: Bm * [1 x D]. 612 | split_states_h = np.split(inp_states.h, FLAGS.beam_size, axis=0) # Shape: Bm * [1 x D]. 613 | split_states = [LSTMStateTuple(c, h) 614 | for c, h in zip(split_states_c, split_states_h)] # Shape: Bm * [(1 x D, 1 x D)]. 615 | 616 | else: 617 | split_states = [] 618 | for i in range(FLAGS.beam_size): 619 | hyp_state_c = [layer_state.c[i, np.newaxis, :] 620 | for layer_state in inp_states] # num_layers * [1 x D]. 621 | hyp_state_h = [layer_state.h[i, np.newaxis, :] 622 | for layer_state in inp_states] # num_layers * [1 x D]. 623 | hyp_state = [LSTMStateTuple(c, h) 624 | for c, h in zip(hyp_state_c, hyp_state_h)] # num_layers * [(1 x D, 1 x D)] 625 | split_states.append(tuple(hyp_state)) 626 | 627 | return split_states 628 | 629 | # Split the states for Bm hypothesis. 630 | curr_read_states = get_states_split(outputs[4]) 631 | curr_write_states = get_states_split(outputs[5]) 632 | if FLAGS.use_comp_lstm: 633 | curr_comp_states = get_states_split(outputs[6]) 634 | else: 635 | curr_comp_states = None 636 | 637 | if FLAGS.use_comp_lstm: 638 | 639 | curr_states_list = [[memory, read_state, write_state, comp_state] 640 | for memory, read_state, write_state, comp_state in 641 | zip(curr_memories, curr_read_states, curr_write_states, curr_comp_states)] 642 | else: 643 | # Forming a list of internal states for Bm hypothesis. 644 | curr_states_list = [[memory, read_state, write_state] for memory, read_state, write_state 645 | in zip(curr_memories, curr_read_states, curr_write_states)] 646 | 647 | next_values.append(curr_states_list) 648 | 649 | # Generation probabilities. 650 | if FLAGS.use_pgen: 651 | p_gens = outputs[-2] # Shape: 1 x [Bm x 1]. 652 | p_gens = p_gens[0] # Only one time-step in decoding phase, Shape: Bm x 1. 653 | p_gens = np.squeeze(p_gens, axis=1) # Shape: Bm x 654 | p_gens = np.split(p_gens, FLAGS.beam_size, axis=0) # Shape: Bm * [1] 655 | next_values.append(p_gens) 656 | 657 | # Attention probabilities. 658 | if FLAGS.use_pgen: 659 | p_attns = outputs[-1] # Shape: 1 x [[Bm x T_enc, Bm x S_in]]. 660 | p_attns = p_attns[0] # Only one time-step in the decoding phase, Shape: [Bm x T_enc, Bm x S_in]. 661 | p_attns_s = np.split(p_attns[0], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x T_enc]. 662 | p_attns_d = np.split(p_attns[1], FLAGS.beam_size, axis=0) # Shape: Bm * [1 x S_in]. 663 | p_attns = [[attn_s, attn_d] for attn_s, attn_d 664 | in zip(p_attns_s, p_attns_d)] # Shape: Bm * [[1 x T_enc, 1 x S_in]]. 665 | next_values.append(p_attns) 666 | 667 | return next_values 668 | 669 | def _get_vocab_dist(self, inputs): 670 | """ 671 | This function passes the NSE hidden states obtained from decoder through the output layer (dense layer 672 | followed by a softmax layer) and returns the vocabulary distribution thus obtained. 673 | :param inputs: List of hidden states, Shape: T_dec * [B x D] 674 | :return: p_vocab: List of vocabulary distributions for each time-step, Shape: T_dec * [B x vsize]. 675 | """ 676 | steps = len(inputs) 677 | vsize = self._vocab.size() 678 | 679 | inputs = tf.concat(inputs, axis=0) # Shape: (B*T_dec) x D. 680 | 681 | with tf.variable_scope("output", reuse=tf.AUTO_REUSE): 682 | scores = tf.layers.dense(inputs, 683 | units=vsize, 684 | activation=None, 685 | kernel_initializer=self._dense_init, 686 | kernel_regularizer=None, 687 | name='output') # Shape: (B*T_dec) x vsize. 688 | 689 | p_vocab = tf.nn.softmax(scores, name="prob_scores") # Shape: (B*T_dec) x vsize. 690 | 691 | p_vocab = tf.split(p_vocab, num_or_size_splits=steps, axis=0) # Shape: T_dec * [B x vsize]. 692 | 693 | return p_vocab 694 | 695 | def _get_cumulative_dist(self, p_vocabs, p_gens, p_attns, enc_in_ext_vocab): 696 | """ 697 | This function calculates the cumulative probability distribution from the vocabulary distribution, attention 698 | distribution and generation probabilities. 699 | :param p_vocabs: A list of vocabulary distributions for each time-step, Shape: T_dec * [B x vsize]. 700 | :param p_gens: A list of generation probabilities for each time-step, Shape: T_dec * [B x 1] 701 | :param p_attns: A list of attention distributions for each time-step, T_dec * [[sent_attn, doc_attn]], 702 | Shape: T_dec * [[B x T_enc, B x S_in]]. 703 | :param enc_in_ext_vocab: Encoder input represented using extended vocabulary, Shape: B x T_enc. 704 | :return: 705 | """ 706 | vsize = self._vocab.size() 707 | ext_vsize = vsize + self._max_oov_size 708 | batch_size, enc_steps = enc_in_ext_vocab.get_shape().as_list() 709 | 710 | p_vocabs = [tf.multiply(p_gen, p_vocab) 711 | for p_gen, p_vocab in zip(p_gens, p_vocabs)] # Shape: T_dec * [B x vsize]. 712 | p_attns = [tf.multiply(tf.subtract(1.0, p_gen), p_attn[0]) 713 | for p_gen, p_attn in zip(p_gens, p_attns)] # Shape: T_dec * [B x T_enc]. 714 | 715 | zero_vocab = tf.zeros(shape=[batch_size, self._max_oov_size]) 716 | # Shape: T_dec * [B x ext_vsize]. 717 | p_vocabs_ext = [tf.concat([p_vocab, zero_vocab], axis=-1) for p_vocab in p_vocabs] 718 | 719 | idx = tf.range(0, limit=batch_size) # Shape: B x. 720 | idx = tf.expand_dims(idx, axis=-1) # Shape: B x 1. 721 | idx = tf.tile(idx, [1, enc_steps]) # Shape: B x T_enc. 722 | indices = tf.stack([idx, enc_in_ext_vocab], axis=-1) # Shape: B x T_enc x 2. 723 | 724 | # First, A zero matrix of shape B x ext_vsize is created. Then, the attention score for each input token 725 | # indexed as in indices is looked in p_attn and updated in the corresponding location, 726 | # Shape: T_dec * [B x ext_vsize] 727 | p_attn_cum = [tf.scatter_nd(indices, p_attn, [batch_size, ext_vsize]) for p_attn in p_attns] 728 | 729 | # Cumulative distribution, Shape: T_dec * [B x ext_vsize]. 730 | p_cum = [tf.add(gen_prob, copy_prob) for gen_prob, copy_prob in zip(p_vocabs_ext, p_attn_cum)] 731 | 732 | return p_cum 733 | 734 | @ staticmethod 735 | def _get_crossentropy_loss(probs, labels, mask): 736 | """ 737 | This function calculates the crossentropy loss from ground-truth and the probability scores. 738 | :param probs: Predicted probabilities, Shape: T * [B x vocab_size] 739 | :param labels: Ground Truth labels, Shape: B x T. 740 | :param mask: Mask to exclude PAD tokens while calculating loss, Shape: B x T. 741 | :return: average mini-batch loss. 742 | """ 743 | if type(probs) is list: 744 | probs = tf.stack(probs, axis=1) # Shape: B x T x vsize. 745 | 746 | true_probs = tf.reduce_sum( 747 | probs * tf.one_hot(labels, depth=tf.shape(probs)[2]), axis=2 748 | ) # Shape: B x T. 749 | logprobs = -tf.log(tf.clip_by_value(true_probs, 1e-10, 1.0)) # Shape: B x T. 750 | xe_loss = tf.reduce_sum(logprobs * mask) / tf.reduce_sum(mask) 751 | 752 | return xe_loss 753 | 754 | def _restore(self): 755 | """ 756 | This function restores parameters from a saved checkpoint. 757 | :return: 758 | """ 759 | restore_path = FLAGS.PathToCheckpoint + "model_epoch10" 760 | 761 | if FLAGS.restore_checkpoint and tf.train.checkpoint_exists(restore_path): 762 | start = time.time() 763 | 764 | # Initializing all variables. 765 | print("Initializing all variables!\n") 766 | self._sess.run(self._init) 767 | 768 | # Restoring checkpoint variables. 769 | vars_restore = [v for v in self._global_vars if "Adam" not in v.name] 770 | restore_saver = tf.train.Saver(vars_restore) 771 | print("Restoring non-Adam parameters from a previous checkpoint.\n") 772 | restore_saver.restore(self._sess, restore_path) 773 | 774 | end = time.time() 775 | print("Restoring model took %.2f sec. \n" % (end - start)) 776 | else: 777 | start = time.time() 778 | self._init.run(session=self._sess) 779 | end = time.time() 780 | print("Running initializer took %.2f time. \n" % (end - start)) 781 | 782 | def train(self): 783 | """ 784 | This function performs the training followed by validation after every few epochs and saves the best model 785 | if validation loss is decreased. It also writes the training summaries for visualization using tensorboard. 786 | :return: 787 | """ 788 | # Restore from previous checkpoint if found one. 789 | self._restore() 790 | 791 | # No. of iterations. 792 | num_train_iters_per_epoch = int(math.ceil(self._data.num_train_examples / FLAGS.batch_size)) 793 | 794 | # Running averages. 795 | running_avg_crossentropy_loss = 0.0 796 | running_avg_rouge = 0.0 797 | for epoch in range(1, 1 + FLAGS.num_epochs): 798 | # Training. 799 | for iteration in range(1, 1 + num_train_iters_per_epoch): 800 | start = time.time() 801 | feed_dict = self._get_feed_dict(split="train") 802 | 803 | to_return = [self._train_op, self._crossentropy_loss, self._global_step, self._summaries] 804 | if FLAGS.rouge_summary: 805 | to_return.append(self._samples) 806 | 807 | # Evaluate summaries for last batch. 808 | feed_dict[self._mean_crossentropy_loss] = running_avg_crossentropy_loss 809 | if FLAGS.rouge_summary: 810 | feed_dict[self._mean_rouge_score] = running_avg_rouge 811 | 812 | outputs = self._sess.run(to_return, feed_dict=feed_dict) 813 | _, crossentropy_loss, global_step = outputs[: 3] 814 | 815 | # Calculating ROUGE score. 816 | rouge_score = 0.0 817 | if FLAGS.rouge_summary: 818 | samples = outputs[4] 819 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x . 820 | gt_labels = feed_dict[self._dec_out] 821 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False) 822 | 823 | # Updating the running averages. 824 | running_avg_crossentropy_loss = get_running_avg_loss( 825 | crossentropy_loss, running_avg_crossentropy_loss 826 | ) 827 | running_avg_rouge = get_running_avg_loss( 828 | np.mean(rouge_score), running_avg_rouge 829 | ) 830 | 831 | if ((iteration - 2) % FLAGS.summary_every == 0) or (iteration == num_train_iters_per_epoch): 832 | train_summary = outputs[3] 833 | self._train_writer.add_summary(train_summary, global_step) 834 | 835 | end = time.time() 836 | print("\rTraining Iteration: {}/{} ({:.1f}%) took {:.2f} sec.".format( 837 | iteration, num_train_iters_per_epoch, iteration * 100 / num_train_iters_per_epoch, end - start 838 | )) 839 | 840 | if epoch % FLAGS.val_every == 0: 841 | start = time.time() 842 | self._validate() 843 | end = time.time() 844 | print("Validation took {:.2f} sec.".format(end - start)) 845 | 846 | def _validate(self): 847 | """ 848 | This function validates the saved model. 849 | :return: 850 | """ 851 | # Validation 852 | num_val_iters_per_epoch = int(math.ceil(self._data.num_val_examples / FLAGS.val_batch_size)) 853 | num_train_iters_per_epoch = int(math.ceil(self._data.num_train_examples / FLAGS.batch_size)) 854 | 855 | outputs = None 856 | total_crossentropy_loss = 0.0 857 | total_rouge_score = 0.0 858 | for iteration in range(1, 1 + num_val_iters_per_epoch): 859 | feed_dict = self._get_feed_dict(split="val") 860 | to_return = [self._crossentropy_loss] 861 | if FLAGS.rouge_summary: 862 | to_return.append(self._samples) 863 | 864 | if iteration == num_val_iters_per_epoch: 865 | feed_dict[self._mean_crossentropy_loss] = total_crossentropy_loss / (num_val_iters_per_epoch - 1) 866 | feed_dict[self._mean_rouge_score] = total_rouge_score / (num_val_iters_per_epoch - 1) 867 | to_return += [self._val_summaries, self._global_step] 868 | 869 | outputs = self._sess.run(to_return, feed_dict=feed_dict) 870 | crossentropy_loss = outputs[0] 871 | 872 | # Calculating ROUGE score. 873 | rouge_score = 0.0 874 | if FLAGS.rouge_summary: 875 | samples = outputs[1] 876 | lens = np.sum(feed_dict[self._dec_pad_mask], axis=1).astype(np.int32) # Shape: B x . 877 | gt_labels = feed_dict[self._dec_out] 878 | rouge_score = rouge_l_fscore(samples, gt_labels, None, lens, False) 879 | 880 | print("\rValidation Iteration: {}/{} ({:.1f}%)".format( 881 | iteration, num_val_iters_per_epoch, iteration * 100 / num_val_iters_per_epoch, 882 | )) 883 | 884 | # Updating the total losses. 885 | total_crossentropy_loss += crossentropy_loss 886 | total_rouge_score += np.mean(rouge_score) 887 | 888 | # Writing the validation summaries for visualization. 889 | val_summary, global_step = outputs[-2], outputs[-1] 890 | self._val_writer.add_summary(val_summary, global_step) 891 | epoch = global_step // num_train_iters_per_epoch 892 | 893 | # Cumulative loss. 894 | avg_xe_loss = total_crossentropy_loss / num_val_iters_per_epoch 895 | if avg_xe_loss < self._best_val_loss: 896 | print("\rValidation loss improved :).") 897 | self._saver.save(self._sess, FLAGS.PathToCheckpoint + "model_epoch" + str(epoch)) 898 | 899 | def test(self): 900 | """ 901 | This function predicts the outputs for the test set. 902 | :return: 903 | """ 904 | num_gpus = len(FLAGS.GPUs) # Total no. of GPUs. 905 | batch_size = num_gpus * FLAGS.num_pools # Total no. of examples processed per a single test iteration. 906 | 907 | # No. of iterations 908 | num_test_iters_per_epoch = int(math.ceil(self._data.num_test_examples / batch_size)) 909 | 910 | for iteration in range(1, 1 + num_test_iters_per_epoch): 911 | start = time.time() 912 | # [indices, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask, 913 | # enc_inp_ext_vocab, ext_vocabs, max_oov_size] 914 | input_batches = self._data.get_batch( 915 | batch_size, split="test", permutate=FLAGS.permutate, chunk=FLAGS.chunk 916 | ) 917 | 918 | # Split the inputs into batches of size "num_pools" per GPU. 919 | input_batches_per_gpu = [[] for _ in range(num_gpus)] 920 | for i, input_batch in enumerate(input_batches): 921 | for j, idx in zip(range(num_gpus), range(0, batch_size, FLAGS.num_pools)): 922 | 923 | if i < 5 or (FLAGS.use_pgen and i < 7): # First 5 inputs (7 inputs in pgen mode). 924 | input_batches_per_gpu[j].append( 925 | input_batch[idx: idx + FLAGS.num_pools]) 926 | else: # max_oov_size is the same for all examples. 927 | input_batches_per_gpu[j].append(input_batch) 928 | 929 | # Appending 930 | # Appending the GPU id. 931 | for gpu_id in range(num_gpus): 932 | input_batches_per_gpu[gpu_id].append(gpu_id) 933 | 934 | Parallel(n_jobs=num_gpus, backend="threading")( 935 | map(delayed(self._test_one_gpu), input_batches_per_gpu) 936 | ) 937 | end = time.time() 938 | print("\rTesting Iteration: {}/{} ({:.1f}%) in {:.1f} sec.".format( 939 | iteration, num_test_iters_per_epoch, iteration * 100 / num_test_iters_per_epoch, end - start)) 940 | 941 | def _test_one_gpu(self, inputs): 942 | """ 943 | This function performs testing on the inputs of a single GPU. 944 | :return: 945 | """ 946 | input_batches = inputs[: -1] 947 | gpu_idx = inputs[-1] 948 | 949 | with tf.device('/gpu:%d' % gpu_idx): 950 | self._test_one_pool(input_batches) 951 | 952 | def _test_one_pool(self, inputs): 953 | """ 954 | This function performs testing on the inputs of a single GPU over parallel pools. 955 | :param inputs: All the first 8 inputs have "num_pools" examples each. 956 | [indices, summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask, 957 | enc_inp_ext_vocab, ext_vocabs, max_oov_size] 958 | :return: 959 | """ 960 | # Split the inputs into each example per pool. 961 | input_batches_per_pool = [[] for _ in range(FLAGS.num_pools)] 962 | for i, input_batch in enumerate(inputs): 963 | for j in range(FLAGS.num_pools): 964 | # Splitting the GT summaries. 965 | if i == 0 or i == 1 or i == 2 or (FLAGS.use_pgen and i == 7): 966 | input_batches_per_pool[j].append([input_batch[j]]) 967 | 968 | # Repeating the input Numpy arrays for beam size no. of times. 969 | elif i == 3 or i == 4 or i == 5 or (FLAGS.use_pgen and i == 6): 970 | input_batches_per_pool[j].append( 971 | np.repeat(input_batch[np.newaxis, j], repeats=FLAGS.beam_size, axis=0)) 972 | 973 | # max_oov_size is same for all examples in the pool. 974 | else: 975 | input_batches_per_pool[j].append(input_batch) 976 | 977 | Parallel(n_jobs=FLAGS.num_pools, backend="threading")( 978 | map(delayed(self._test_one_ex), input_batches_per_pool)) 979 | 980 | def _test_one_ex(self, inputs): 981 | """ 982 | This function performs testing on one example. 983 | :param inputs: 984 | [summaries, files[start: end], enc_inp, enc_pad_mask, enc_doc_mask, 985 | enc_inp_ext_vocab, ext_vocabs, max_oov_size, iteration] 986 | :return: 987 | """ 988 | iteration = inputs[0][0] 989 | input_batches = inputs[1:] 990 | # Inputs for running beam search on one example. 991 | beam_search_inputs = input_batches[2: 5] 992 | if FLAGS.use_pgen: 993 | beam_search_inputs += [input_batches[5], input_batches[7]] 994 | 995 | best_hyp = run_beam_search_hier( 996 | beam_search_inputs, self.run_encoder, self.decode_one_step, self._vocab 997 | ) 998 | pred_ids = best_hyp.tokens # Shape: T_dec * [1 x] 999 | pred_ids = np.stack(pred_ids, axis=0) # Shape: T_dec x 1000 | pred_ids = pred_ids[np.newaxis, :] # Shape: 1 x T_dec 1001 | 1002 | # Writing the outputs. 1003 | if FLAGS.use_pgen: 1004 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1], input_batches[6]) 1005 | else: 1006 | self._write_outputs(iteration, pred_ids, input_batches[0], input_batches[1]) 1007 | 1008 | def _get_feed_dict(self, split): 1009 | """ 1010 | Returns a feed_dict assigning data batches to the following placeholders: 1011 | [enc_in, enc_pad_mask, enc_doc_mask, dec_in, labels, dec_pad_mask, enc_inp_ext_vocab, max_oov_size] 1012 | :param split: 1013 | :return: 1014 | """ 1015 | 1016 | if split == "train": 1017 | input_batches = self._data.get_batch( 1018 | FLAGS.batch_size, split="train", permutate=FLAGS.permutate, chunk=FLAGS.chunk 1019 | ) 1020 | 1021 | elif split == "val": 1022 | input_batches = self._data.get_batch( 1023 | FLAGS.val_batch_size, split="val", permutate=FLAGS.permutate, chunk=FLAGS.chunk 1024 | ) 1025 | 1026 | else: 1027 | raise ValueError("split should be either train/val!! \n") 1028 | 1029 | feed_dict = { 1030 | self._enc_in: input_batches[0], 1031 | self._enc_pad_mask: input_batches[1], 1032 | self._enc_doc_mask: input_batches[2], 1033 | self._dec_in: input_batches[3], 1034 | self._dec_out: input_batches[4], 1035 | self._dec_pad_mask: input_batches[5] 1036 | } 1037 | 1038 | if FLAGS.use_pgen: 1039 | feed_dict[self._enc_in_ext_vocab] = input_batches[6] 1040 | feed_dict[self._max_oov_size] = input_batches[7] 1041 | 1042 | return feed_dict 1043 | 1044 | @staticmethod 1045 | def make_html_safe(s): 1046 | """ 1047 | Replace any angled brackets in string s to avoid interfering with HTML attention visualizer. 1048 | :param s: Input string. 1049 | :return: 1050 | """ 1051 | s.replace("<", "<") 1052 | s.replace(">", ">") 1053 | 1054 | return s 1055 | 1056 | def _write_file(self, pred, pred_name, gt, gt_name, ext_vocab=None): 1057 | """ 1058 | This function writes tokens in vals to a .txt file with given name. 1059 | :param pred: Predicted ID's. Shape: 1 x T 1060 | :param pred_name: Name of the file in which predictions will be written. 1061 | :param gt: Ground truth summary, a list of sentences (strings). 1062 | :param gt_name: Name of the file in which GTs will be written. 1063 | :param ext_vocab: Extended vocabulary for each example. [ext_words x] 1064 | :return: _pred, _gt files will be created for ROUGE evaluation. 1065 | """ 1066 | # Writing predictions. 1067 | vsize = self._vocab.size() 1068 | 1069 | # Removing the [START] token. 1070 | pred = pred[1:] 1071 | 1072 | # Converting words to ID's 1073 | pred_words = [] 1074 | for t in pred: 1075 | try: 1076 | pred_words.append(self._vocab.id2word(t)) 1077 | except ValueError: 1078 | pred_words.append(ext_vocab[t - vsize]) 1079 | 1080 | # Considering tokens only till STOP token. 1081 | try: 1082 | stop_idx = pred_words.index(params.STOP_DECODING) 1083 | except ValueError: 1084 | stop_idx = len(pred_words) 1085 | pred_words = pred_words[: stop_idx] 1086 | 1087 | # Creating sentences out of the predicted sequence. 1088 | pred_sents = [] 1089 | while pred_words: 1090 | try: 1091 | period_idx = pred_words.index(".") 1092 | except ValueError: 1093 | period_idx = len(pred_words) 1094 | 1095 | # Append the sentence. 1096 | sent = pred_words[: period_idx + 1] 1097 | pred_sents.append(" ".join(sent)) 1098 | 1099 | # Consider the remaining words now. 1100 | pred_words = pred_words[period_idx + 1:] 1101 | 1102 | # Making HTML safe. 1103 | pred_sents = [self.make_html_safe(s) for s in pred_sents] 1104 | gt_sents = [self.make_html_safe(s) for s in gt] 1105 | 1106 | # Writing predicted sentences. 1107 | f = open(pred_name, 'w', encoding='utf-8') 1108 | for i, sent in enumerate(pred_sents): 1109 | f.write(sent) if i == len(pred_sents) - 1 else f.write(sent + "\n") 1110 | f.close() 1111 | 1112 | # Writing GT sentences. 1113 | f = open(gt_name, 'w', encoding='utf-8') 1114 | for i, sent in enumerate(gt_sents): 1115 | f.write(sent) if i == len(gt_sents) - 1 else f.write(sent + "\n") 1116 | f.close() 1117 | 1118 | def _write_outputs(self, index, preds, gts, files, ext_vocabs=None): 1119 | """ 1120 | This function writes the input files 1121 | :param index: Number of the test example. 1122 | :param preds: The predictions. Shape: B x T 1123 | :param: gts: The ground truths. Shape: B x T 1124 | :param files: The names of the files. 1125 | :param ext_vocabs: Extended vocabularies for each example, Shape: B * [ext_words x] 1126 | :return: Saves the predictions and GT's in a .txt format. 1127 | """ 1128 | for i in range(len(files)): 1129 | file, pred, gt = files[i], preds[i], gts[i] 1130 | name_pred = FLAGS.PathToResults + 'predictions/' + '%06d_pred.txt' % index 1131 | name_gt = FLAGS.PathToResults + 'groundtruths/' + '%06d_gt.txt' % index 1132 | 1133 | ext_vocab = None 1134 | if FLAGS.use_pgen: 1135 | ext_vocab = ext_vocabs[i] 1136 | 1137 | self._write_file(pred, name_pred, gt, name_gt, ext_vocab) 1138 | 1139 | def _single_model(self): 1140 | # with tf.device('/gpu'): 1141 | placeholders = self._create_placeholders() 1142 | self._crossentropy_loss, self._samples = self._forward(placeholders) 1143 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer 1144 | self._train_op = sgd_solver.minimize( 1145 | self._crossentropy_loss, global_step=self._global_step, name='train_op') 1146 | 1147 | self._saver = tf.train.Saver(var_list=tf.global_variables()) # Saver. 1148 | self._init = tf.global_variables_initializer() # Initializer. 1149 | self._sess = tf.Session(config=self._config) # Session. 1150 | 1151 | # Train and validation summary writers. 1152 | if FLAGS.mode == 'train': 1153 | self._create_writers() 1154 | 1155 | self._global_vars = tf.global_variables() 1156 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables()))) 1157 | # print(tf.trainable_variables()) 1158 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()]) 1159 | # print('No. of params = {:d}'.format(int(no_params))) 1160 | 1161 | def _parallel_model(self): 1162 | with tf.Graph().as_default(), tf.device('/cpu:0'): 1163 | placeholders = self._create_placeholders() 1164 | 1165 | # Splitting the placeholders for each GPU. 1166 | placeholders_per_gpu = [[] for _ in range(self._num_gpus)] 1167 | for placeholder in placeholders: 1168 | splits = tf.split(placeholder, num_or_size_splits=self._num_gpus, axis=0) 1169 | for i, split in enumerate(splits): 1170 | placeholders_per_gpu[i].append(split) 1171 | 1172 | sgd_solver = tf.train.AdamOptimizer(learning_rate=FLAGS.lr) # Optimizer 1173 | tower_grads = [] # Gradients calculated in each tower 1174 | with tf.variable_scope(tf.get_variable_scope()): 1175 | all_losses = [] 1176 | all_samples = [] 1177 | for i in range(self._num_gpus): 1178 | with tf.device('/gpu:%d' % i): 1179 | with tf.name_scope('tower_%d' % i): 1180 | crossentropy_loss, samples = self._forward(placeholders_per_gpu[i]) 1181 | 1182 | # Gathering samples from all GPUs. 1183 | all_samples.append(samples) 1184 | 1185 | tf.get_variable_scope().reuse_variables() 1186 | 1187 | grads = sgd_solver.compute_gradients(crossentropy_loss) 1188 | tower_grads.append(grads) 1189 | 1190 | # Updating the losses 1191 | all_losses.append(crossentropy_loss) 1192 | 1193 | self._crossentropy_loss = tf.add_n(all_losses) / self._num_gpus 1194 | 1195 | # Samples from all the GPUs. 1196 | if FLAGS.rouge_summary: 1197 | self._samples = tf.concat(all_samples, axis=0) 1198 | else: 1199 | self._samples = [] 1200 | 1201 | # Synchronization Point 1202 | gradients, variables = zip(*average_gradients(tower_grads)) 1203 | gradients, global_norm = tf.clip_by_global_norm(gradients, FLAGS.max_grad_norm) 1204 | 1205 | # Summary for the global norm 1206 | tf.summary.scalar('global_norm', global_norm) 1207 | 1208 | # Histograms for gradients. 1209 | for grad, var in zip(gradients, variables): 1210 | # if (grad is not None) and ("word_embedding" in var.name): 1211 | tf.summary.histogram(var.op.name + '/gradients', grad) 1212 | 1213 | # Histograms for variables. 1214 | for var in tf.trainable_variables(): 1215 | # if "word_embedding" in var.name: 1216 | tf.summary.histogram(var.op.name, var) 1217 | 1218 | self._train_op = sgd_solver.apply_gradients( 1219 | zip(gradients, variables), global_step=self._global_step) 1220 | self._saver = tf.train.Saver(var_list=tf.global_variables(), 1221 | max_to_keep=None) # Saver. 1222 | self._init = tf.global_variables_initializer() # Initializer. 1223 | self._sess = tf.Session(config=self._config) # Session. 1224 | 1225 | # Train and validation summary writers. 1226 | if FLAGS.mode == 'train': 1227 | self._create_writers() 1228 | 1229 | self._global_vars = tf.global_variables() 1230 | # print('No. of variables = {}\n'.format(len(tf.trainable_variables()))) 1231 | # print(tf.trainable_variables()) 1232 | # no_params = np.sum([np.product([xi.value for xi in x.get_shape()]) for x in tf.trainable_variables()]) 1233 | # print('No. of params = {:d}'.format(int(no_params))) 1234 | --------------------------------------------------------------------------------