├── LICENSE ├── README.md ├── data_utils.py ├── document_summarizer_training_testing.py ├── model_docsum.py ├── model_utils.py ├── my_flags.py ├── my_model.py ├── reward_utils.py └── scripts └── oracle-estimator ├── estimate_multiple_oracles.py └── rouge.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Shashi Narayan 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Refresh: Ranking Sentences for Extractive Summarization with Reinforcement Learning 2 | 3 | This repository releases our code for the Refresh model, which improves on our code for [Sidenet](https://github.com/shashiongithub/sidenet). The code uses TensorFlow 0.10; please use the upgrade scripts provided by TensorFlow to convert it to newer versions. 4 | 5 | Please contact me at shashi.narayan@gmail.com with any questions. 6 | 7 | Please cite this paper if you use our code or data: 8 | 9 | **Ranking Sentences for Extractive Summarization with Reinforcement Learning, Shashi Narayan, Shay B. Cohen and Mirella Lapata, NAACL 2018.** 10 | 11 | > Single document summarization is the task of producing a shorter version of a document while preserving its principal information content. In this paper we conceptualize extractive summarization as a sentence ranking task and propose a novel training algorithm which globally optimizes the ROUGE evaluation metric through a reinforcement learning objective. We use our algorithm to train a neural summarization model on the CNN and DailyMail datasets and demonstrate experimentally that it outperforms state-of-the-art extractive and abstractive systems when evaluated automatically and by humans. 12 | 13 | ## CNN and DailyMail Data 14 | 15 | In addition to our code, we provide links below to additional files that are not included in this repository.
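The preprocessed data listed below stores every word as an integer id taken from the word embedding file, with `PAD_ID = 0` and `UNK_ID = 1`. The following is a rough sketch of that mapping. It assumes the mapping mirrors `DataProcessor.prepare_vocab_embeddingdict` and the `process_to_chop_pad` helper in `data_utils.py`: ids 0 and 1 are reserved, the first word after the embedding file's header line gets id 2, and each sentence is chopped or padded with `PAD_ID` to a fixed length. The helper names (`build_vocab`, `sentence_to_ids`, `chop_pad`) are hypothetical. Mapping out-of-vocabulary words to `UNK_ID` is an assumption about how the released data was produced.

```python
PAD_ID = 0
UNK_ID = 1


def build_vocab(embedding_lines):
    """Hypothetical helper: map words to ids the way data_utils.py appears to.

    The first line of a word2vec-style text file is a "<vocab_size> <dim>"
    header; the word on line i (i >= 1) receives id i + 1, so real words
    start at id 2, after the reserved _PAD and _UNK entries.
    """
    vocab = {"_PAD": PAD_ID, "_UNK": UNK_ID}
    for linecount, line in enumerate(embedding_lines):
        if linecount == 0:
            continue  # skip the header line
        vocab[line.split()[0]] = linecount + 1
    return vocab


def sentence_to_ids(tokens, vocab):
    """Assumption: out-of-vocabulary tokens map to UNK_ID."""
    return [vocab.get(token, UNK_ID) for token in tokens]


def chop_pad(ids, required_size):
    """Truncate to required_size, or right-pad with PAD_ID (like process_to_chop_pad)."""
    if len(ids) >= required_size:
        return ids[:required_size]
    return ids + [PAD_ID] * (required_size - len(ids))


if __name__ == "__main__":
    lines = ["3 2", "the 0.1 0.2", "cat 0.3 0.4", "sat 0.5 0.6"]
    vocab = build_vocab(lines)
    ids = sentence_to_ids(["the", "dog", "sat"], vocab)
    print(ids)               # [2, 1, 4]
    print(chop_pad(ids, 5))  # [2, 1, 4, 0, 0]
```

Under the same reading, the embedding vector for word id `k` sits in row `k - 2` of `word_embedding_array`. This agrees with the model being built with `len(vocab_dict)-2` vocabulary entries.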
16 | 17 | #### Preprocessed Data and Word Embedding File 18 | 19 | * [Pretrained word embeddings](http://bollin.inf.ed.ac.uk/public/direct/Refresh-NAACL18-1-billion-benchmark-wordembeddings.tar.gz) trained on "1 billion word language modeling benchmark r13output" (405MB) 20 | * [Preprocessed CNN and DailyMail data](http://bollin.inf.ed.ac.uk/public/direct/Refresh-NAACL18-preprocessed-input-data.tar.gz): Articles are tokenized and sentence-segmented, keeping the original case. Each word is then replaced by its id in the word embedding file (PAD_ID = 0, UNK_ID = 1). (1.9GB) 21 | * [Original Test and Validation mainbody data](http://bollin.inf.ed.ac.uk/public/direct/Refresh-NAACL18-CNN-DM-Filtered-TokenizedSegmented.tar.gz): These files are used to assemble summaries. (35MB) 22 | * [Gold Test and Validation highlights](http://bollin.inf.ed.ac.uk/public/direct/Refresh-NAACL18-baseline-gold-data.tar.gz): These files are used to estimate ROUGE scores. (11MB) 23 | 24 | #### Best Pretrained Models 25 | 26 | We train for a fixed number of epochs and estimate the ROUGE score on the validation set after each epoch. The released models are the ones that perform best on the validation set. 27 | 28 | * [CNN and DailyMail Pretrained Models](http://bollin.inf.ed.ac.uk/public/direct/Refresh-NAACL18-pretrained-models.tar.gz) (1.8GB) 29 | 30 | #### Human Evaluation Data 31 | 32 | We selected 20 articles (10 from CNN and 10 from DailyMail). Please see our paper for the experimental setup.
33 | 34 | * [CNN and DailyMail Human Evaluation Data](http://bollin.inf.ed.ac.uk/public/direct/Refresh-NAACL18-human-evaluations.tar.gz) 35 | 36 | ## Training and Evaluation Instructions 37 | 38 | Download the data using the links above, then either set the following parameters in `my_flags.py` or pass them as command-line arguments: 39 | 40 | ``` 41 | pretrained_wordembedding: /address/data/1-billion-word-language-modeling-benchmark-r13output.word2vec.vec (Pretrained word embedding file trained on the 1 billion word benchmark data) 42 | preprocessed_data_directory: /address/data/preprocessed-input-directory (Preprocessed news articles) 43 | gold_summary_directory: /address/data/Baseline-Gold-Models (Gold summary directory) 44 | doc_sentence_directory: /address/data/CNN-DM-Filtered-TokenizedSegmented (Directory where document sentences are kept) 45 | ``` 46 | 47 | #### CNN 48 | 49 | ``` 50 | mkdir -p /address/to/training/directory/cnn-reinforcementlearn-singlesample-from-moracle-noatt-sample5 51 | 52 | # Training 53 | python document_summarizer_training_testing.py --use_gpu /gpu:2 --data_mode cnn --train_dir /address/to/training/directory/cnn-reinforcementlearn-singlesample-from-moracle-noatt-sample5 --num_sample_rollout 5 > /address/to/training/directory/cnn-reinforcementlearn-singlesample-from-moracle-noatt-sample5/train.log 54 | 55 | # Evaluation 56 | python document_summarizer_training_testing.py --use_gpu /gpu:2 --data_mode cnn --exp_mode test --model_to_load 11 --train_dir /address/to/training/directory/cnn-reinforcementlearn-singlesample-from-moracle-noatt-sample5 --num_sample_rollout 5 > /address/to/training/directory/cnn-reinforcementlearn-singlesample-from-moracle-noatt-sample5/test.model11.log 57 | ``` 58 | 59 | #### DailyMail 60 | 61 | ``` 62 | mkdir -p /address/to/training/directory/dailymail-reinforcementlearn-singlesample-from-moracle-noatt-sample15 63 | 64 | # Training 65 | python 
document_summarizer_training_testing.py --use_gpu /gpu:2 --data_mode
dailymail --train_dir /address/to/training/directory/dailymail-reinforcementlearn-singlesample-from-moracle-noatt-sample15 --num_sample_rollout 15 > /address/to/training/directory/dailymail-reinforcementlearn-singlesample-from-moracle-noatt-sample15/train.log 66 | 67 | # Evaluation 68 | python document_summarizer_training_testing.py --use_gpu /gpu:2 --data_mode dailymail --exp_mode test --model_to_load 7 --train_dir /address/to/training/directory/dailymail-reinforcementlearn-singlesample-from-moracle-noatt-sample15 --num_sample_rollout 15 > /address/to/training/directory/dailymail-reinforcementlearn-singlesample-from-moracle-noatt-sample15/test.model7.log 69 | ``` 70 | 71 | ## Oracle Estimation 72 | 73 | See `scripts/oracle-estimator` for computing multiple oracles for training on your own dataset. 74 | 75 | ## Blog post and Live Demo 76 | 77 | You can find a live demo of Refresh [here](http://kinloch.inf.ed.ac.uk/sidenet.html). 78 | 79 | See [here](https://nurture.ai/p/e5c2a653-404a-4af8-b35f-e9e0d17fd272) for a light introduction to our paper, written by [nurture.ai](https://nurture.ai).
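The oracle estimation scripts produce several candidate extracts (oracles) per document, each with a pre-estimated reward. The following is a hedged, plain-Python restatement of how `data_utils.py` appears to consume them; it is not the repository's NumPy batching code. Each blank-line-separated block of a `*.label.multipleoracle` file holds the filename, then the original document length, then up to `num_sample_rollout` lines of selected sentence ids followed by the reward. During training, one oracle is drawn uniformly at random (empty slots are filled with the first, best oracle) and turned into per-sentence `[1, 0]` / `[0, 1]` labels. The function names `parse_oracle_block` and `sample_oracle` are hypothetical.

```python
import random


def parse_oracle_block(block, num_sample_rollout):
    """Parse one blank-line-separated block of a *.label.multipleoracle file.

    Layout assumed from Data.populate_data: filename, original document
    length, then one line per oracle: "<sentid> <sentid> ... <reward>".
    """
    lines = block.strip().split("\n")
    filename = lines[0].strip()
    doc_length = int(lines[1].strip())
    oracles, rewards = [], []
    for line in lines[2:num_sample_rollout + 2]:
        fields = line.split()
        oracles.append([int(field) for field in fields[:-1]])
        rewards.append(float(fields[-1]))
    return filename, doc_length, oracles, rewards


def sample_oracle(oracles, rewards, num_sample_rollout, max_doc_length, rng=random):
    """Draw one oracle uniformly and return (one-hot labels, reward).

    Slots beyond the available oracles repeat the first (best) one, so
    the best oracle is more likely to be drawn when few oracles exist.
    """
    slots = [(oracles[i], rewards[i]) if i < len(oracles) else (oracles[0], rewards[0])
             for i in range(num_sample_rollout)]
    sentids, reward = slots[rng.randint(0, num_sample_rollout - 1)]
    # [1, 0] marks a sentence in the extract, [0, 1] a sentence left out
    labels = [[1, 0] if i in sentids else [0, 1] for i in range(max_doc_length)]
    return labels, reward


if __name__ == "__main__":
    block = "cnn-example\n5\n0 2 0.45\n1 0.30\n"
    name, length, oracles, rewards = parse_oracle_block(block, num_sample_rollout=5)
    labels, reward = sample_oracle(oracles, rewards, 5, 4, rng=random.Random(0))
    print(name, length, oracles, rewards)
    print(labels, reward)
```

Because empty slots repeat the best oracle, documents with fewer oracles than `--num_sample_rollout` are biased towards their first extract. The commands above use 5 samples for CNN and 15 for DailyMail.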
80 | 81 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | #################################### 7 | 8 | """ 9 | Document Summarization Modules and Models 10 | """ 11 | 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | import random 20 | import os 21 | 22 | from my_flags import FLAGS 23 | from model_utils import convert_logits_to_softmax, predict_topranked 24 | 25 | # Special IDs 26 | PAD_ID = 0 27 | UNK_ID = 1 28 | 29 | class Data: 30 | def __init__(self, vocab_dict, data_type): 31 | self.filenames = [] 32 | self.docs = [] 33 | self.titles = [] 34 | self.images = [] 35 | self.labels = [] 36 | self.rewards = [] 37 | self.weights = [] 38 | 39 | self.fileindices = [] 40 | 41 | self.data_type = data_type 42 | 43 | # populate the data 44 | self.populate_data(vocab_dict, data_type) 45 | 46 | # Write to files 47 | self.write_to_files(data_type) 48 | 49 | def write_prediction_summaries(self, pred_logits, modelname, session=None): 50 | print("Writing predictions and final summaries ...") 51 | 52 | # Convert to softmax logits 53 | pred_logits = convert_logits_to_softmax(pred_logits, session=session) 54 | # Save Output Logits 55 | np.save(FLAGS.train_dir+"/"+modelname+"."+self.data_type+"-prediction", pred_logits) 56 | 57 | # Writing 58 | pred_labels = predict_topranked(pred_logits, self.weights, self.filenames) 59 | self.write_predictions(modelname+"."+self.data_type, pred_logits, pred_labels) 60 | self.process_predictions_topranked(modelname+"."+self.data_type) 61 | 62 | def write_predictions(self, file_prefix, np_predictions, np_labels): 63 | 
foutput = open(FLAGS.train_dir+"/"+file_prefix+".predictions", "w") 64 | for fileindex in self.fileindices: 65 | filename = self.filenames[fileindex] 66 | foutput.write(filename+"\n") 67 | 68 | sentcount = 0 69 | for sentpred, sentlabel in zip(np_predictions[fileindex], np_labels[fileindex]): 70 | one_prob = sentpred[0] 71 | label = sentlabel[0] 72 | 73 | if sentcount < len(self.weights[fileindex]): 74 | foutput.write(str(int(label))+"\t"+str(one_prob)+"\n") 75 | else: 76 | break 77 | 78 | sentcount += 1 79 | foutput.write("\n") 80 | foutput.close() 81 | 82 | def process_predictions_topranked(self, file_prefix): 83 | predictiondata = open(FLAGS.train_dir+"/"+file_prefix+".predictions").read().strip().split("\n\n") 84 | # print len(predictiondata) 85 | 86 | summary_dirname = FLAGS.train_dir+"/"+file_prefix+"-summary-topranked" 87 | os.system("mkdir "+summary_dirname) 88 | 89 | for item in predictiondata: 90 | # print(item) 91 | 92 | itemdata = item.strip().split("\n") 93 | # print len(itemdata) 94 | 95 | filename = itemdata[0] 96 | # print filename 97 | 98 | # The predictions file already has the top three sentences marked 99 | final_sentids = [] 100 | for sentid in range(len(itemdata[1:])): 101 | label_score = itemdata[sentid+1].split() 102 | if label_score[0] == "1": 103 | final_sentids.append(sentid) 104 | 105 | # Create final summary files 106 | fileid = filename.split("-")[-1] # cnn-fileid, dailymail-fileid 107 | summary_file = open(summary_dirname+"/"+fileid+".model", "w") 108 | 109 | # Read the document sentences: always use the original sentences 110 | sent_filename = FLAGS.doc_sentence_directory + "/" + self.data_type +"/mainbody/"+fileid+".mainbody" 111 | docsents = open(sent_filename).readlines() 112 | 113 | # Top-ranked three sentences 114 | selected_sents = [docsents[sentid] for sentid in final_sentids if sentid < len(docsents)] 115 | # print(selected_sents) 116 | 117 | summary_file.write("".join(selected_sents)+"\n") 118 | 
summary_file.close() 119 | 120 | def
get_batch(self, startidx, endidx): 121 | # This is very fast if you keep everything in Numpy 122 | 123 | def process_to_chop_pad(orgids, requiredsize): 124 | if (len(orgids) >= requiredsize): 125 | return orgids[:requiredsize] 126 | else: 127 | padids = [PAD_ID] * (requiredsize - len(orgids)) 128 | return (orgids + padids) 129 | 130 | # Numpy dtype 131 | dtype = np.float16 if FLAGS.use_fp16 else np.float32 132 | 133 | # For training, (endidx-startidx) = FLAGS.batch_size; otherwise it is as specified by the caller 134 | batch_docnames = np.empty((endidx-startidx), dtype="S60") # File ID: "cnn-" or "dailymail-" followed by a 40-character fileid 135 | batch_docs = np.empty(((endidx-startidx), (FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length), FLAGS.max_sent_length), dtype="int32") 136 | batch_label = np.empty(((endidx-startidx), FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype) # Single best oracle, used for JP models or accuracy estimation 137 | batch_weight = np.empty(((endidx-startidx), FLAGS.max_doc_length), dtype=dtype) 138 | batch_oracle_multiple = np.empty(((endidx-startidx), 1, FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype) 139 | batch_reward_multiple = np.empty(((endidx-startidx), 1), dtype=dtype) 140 | 141 | batch_idx = 0 142 | for fileindex in self.fileindices[startidx:endidx]: 143 | # Document Names 144 | batch_docnames[batch_idx] = self.filenames[fileindex] 145 | 146 | # Document 147 | doc_wordids = [] # [FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length, FLAGS.max_sent_length] 148 | for idx in range(FLAGS.max_doc_length): 149 | thissent = [] 150 | if idx < len(self.docs[fileindex]): 151 | thissent = self.docs[fileindex][idx][:] 152 | thissent = process_to_chop_pad(thissent, FLAGS.max_sent_length) # [FLAGS.max_sent_length] 153 | doc_wordids.append(thissent) 154 | for idx in range(FLAGS.max_title_length): 155 | thissent = [] 156 | if idx < len(self.titles[fileindex]): 157 
| thissent =
self.titles[fileindex][idx][:] 158 | thissent = process_to_chop_pad(thissent, FLAGS.max_sent_length) # [FLAGS.max_sent_length] 159 | doc_wordids.append(thissent) 160 | for idx in range(FLAGS.max_image_length): 161 | thissent = [] 162 | if idx < len(self.images[fileindex]): 163 | thissent = self.images[fileindex][idx][:] 164 | thissent = process_to_chop_pad(thissent, FLAGS.max_sent_length) # [FLAGS.max_sent_length] 165 | doc_wordids.append(thissent) 166 | batch_docs[batch_idx] = np.array(doc_wordids[:], dtype="int32") 167 | 168 | # Labels: Select the single best 169 | labels_vecs = [[1, 0] if (item in self.labels[fileindex][0]) else [0, 1] for item in range(FLAGS.max_doc_length)] 170 | batch_label[batch_idx] = np.array(labels_vecs[:], dtype=dtype) 171 | 172 | # Weights 173 | weights = process_to_chop_pad(self.weights[fileindex][:], FLAGS.max_doc_length) 174 | batch_weight[batch_idx] = np.array(weights[:], dtype=dtype) 175 | 176 | # Multiple Labels and rewards 177 | labels_set = [] # [FLAGS.num_sample_rollout, FLAGS.max_doc_length, FLAGS.target_label_size] 178 | reward_set = [] # [FLAGS.num_sample_rollout] 179 | for idx in range(FLAGS.num_sample_rollout): 180 | thislabels = [] 181 | if idx < len(self.labels[fileindex]): 182 | thislabels = [[1, 0] if (item in self.labels[fileindex][idx]) else [0, 1] for item in range(FLAGS.max_doc_length)] 183 | reward_set.append(self.rewards[fileindex][idx]) 184 | else: 185 | # Simply copy the best one 186 | thislabels = [[1, 0] if (item in self.labels[fileindex][0]) else [0, 1] for item in range(FLAGS.max_doc_length)] 187 | reward_set.append(self.rewards[fileindex][0]) 188 | labels_set.append(thislabels) 189 | # Randomly sample one oracle label (and its reward) 190 | randidx_oracle = random.randint(0, (FLAGS.num_sample_rollout-1)) 191 | batch_oracle_multiple[batch_idx][0] = np.array(labels_set[randidx_oracle][:], dtype=dtype) 192 | batch_reward_multiple[batch_idx] = 
np.array([reward_set[randidx_oracle]],
dtype=dtype) 193 | 194 | # increase batch count 195 | batch_idx += 1 196 | 197 | return batch_docnames, batch_docs, batch_label, batch_weight, batch_oracle_multiple, batch_reward_multiple 198 | 199 | def shuffle_fileindices(self): 200 | random.shuffle(self.fileindices) 201 | 202 | def write_to_files(self, data_type): 203 | full_data_file_prefix = FLAGS.train_dir + "/" + FLAGS.data_mode + "." + data_type 204 | print("Writing data files with prefix (.filename, .doc, .title, .image, .label, .weight, .reward): %s"%full_data_file_prefix) 205 | 206 | ffilenames = open(full_data_file_prefix+".filename", "w") 207 | fdoc = open(full_data_file_prefix+".doc", "w") 208 | ftitle = open(full_data_file_prefix+".title", "w") 209 | fimage = open(full_data_file_prefix+".image", "w") 210 | flabel = open(full_data_file_prefix+".label", "w") 211 | fweight = open(full_data_file_prefix+".weight", "w") 212 | freward = open(full_data_file_prefix+".reward", "w") 213 | 214 | for filename, doc, title, image, label, weight, reward in zip(self.filenames, self.docs, self.titles, self.images, self.labels, self.weights, self.rewards): 215 | ffilenames.write(filename+"\n") 216 | fdoc.write("\n".join([" ".join([str(item) for item in itemlist]) for itemlist in doc])+"\n\n") 217 | ftitle.write("\n".join([" ".join([str(item) for item in itemlist]) for itemlist in title])+"\n\n") 218 | fimage.write("\n".join([" ".join([str(item) for item in itemlist]) for itemlist in image])+"\n\n") 219 | flabel.write("\n".join([" ".join([str(item) for item in itemlist]) for itemlist in label])+"\n\n") 220 | fweight.write(" ".join([str(item) for item in weight])+"\n") 221 | freward.write(" ".join([str(item) for item in reward])+"\n") 222 | 223 | ffilenames.close() 224 | fdoc.close() 225 | ftitle.close() 226 | fimage.close() 227 | flabel.close() 228 | fweight.close() 229 | freward.close() 230 | 231 | def populate_data(self, vocab_dict, data_type): 232 | 233 | full_data_file_prefix = 
FLAGS.preprocessed_data_directory +
"/" + FLAGS.data_mode + "." + data_type 234 | print("Data file prefix (.doc, .title, .image, .label.multipleoracle): %s"%full_data_file_prefix) 235 | 236 | # Process doc, title, image, label 237 | doc_data_list = open(full_data_file_prefix+".doc").read().strip().split("\n\n") 238 | title_data_list = open(full_data_file_prefix+".title").read().strip().split("\n\n") 239 | image_data_list = open(full_data_file_prefix+".image").read().strip().split("\n\n") 240 | label_data_list = open(full_data_file_prefix+".label.multipleoracle").read().strip().split("\n\n") 241 | 242 | print("Data sizes: %d %d %d %d"%(len(doc_data_list), len(title_data_list), len(image_data_list), len(label_data_list))) 243 | 244 | print("Reading data (no padding to save memory) ...") 245 | doccount = 0 246 | for doc_data, title_data, image_data, label_data in zip(doc_data_list, title_data_list, image_data_list, label_data_list): 247 | 248 | doc_lines = doc_data.strip().split("\n") 249 | title_lines = title_data.strip().split("\n") 250 | image_lines = image_data.strip().split("\n") 251 | label_lines = label_data.strip().split("\n") 252 | 253 | filename = doc_lines[0].strip() 254 | 255 | if ((filename == title_lines[0].strip()) and (filename == image_lines[0].strip()) and (filename == label_lines[0].strip())): 256 | # Put filename 257 | self.filenames.append(filename) 258 | 259 | # Doc 260 | thisdoc = [] 261 | for line in doc_lines[1:FLAGS.max_doc_length+1]: 262 | thissent = [int(item) for item in line.strip().split()] 263 | thisdoc.append(thissent) 264 | self.docs.append(thisdoc) 265 | 266 | # Title 267 | thistitle = [] 268 | for line in title_lines[1:FLAGS.max_title_length+1]: 269 | thissent = [int(item) for item in line.strip().split()] 270 | thistitle.append(thissent) 271 | self.titles.append(thistitle) 272 | 273 | # Image 274 | thisimage = [] 275 | for line in image_lines[1:FLAGS.max_image_length+1]: 276 | thissent = [int(item) for item in line.strip().split()] 277 | thisimage.append(thissent) 
278 | self.images.append(thisimage) 279 | 280 | # Weights 281 | originaldoclen = int(label_lines[1].strip()) 282 | thisweight = [1 for item in range(originaldoclen)][:FLAGS.max_doc_length] 283 | self.weights.append(thisweight) 284 | 285 | # Labels (multiple oracles and pre-estimated rewards) 286 | thislabel = [] 287 | thisreward = [] 288 | for line in label_lines[2:FLAGS.num_sample_rollout+2]: 289 | thislabel.append([int(item) for item in line.split()[:-1]]) 290 | thisreward.append(float(line.split()[-1])) 291 | self.labels.append(thislabel) 292 | self.rewards.append(thisreward) 293 | 294 | else: 295 | print("Filenames do not match across %s.* files. Exiting!"%full_data_file_prefix) 296 | exit(1) # non-zero exit status: the input files are inconsistent 297 | 298 | if doccount%10000==0: 299 | print("%d ..."%doccount) 300 | doccount += 1 301 | 302 | # Set file indices as a list, so that random.shuffle can reorder them in place under Python 3 303 | self.fileindices = list(range(len(self.filenames))) 304 | 305 | class DataProcessor: 306 | def prepare_news_data(self, vocab_dict, data_type="training"): 307 | data = Data(vocab_dict, data_type) 308 | return data 309 | 310 | def prepare_vocab_embeddingdict(self): 311 | # Numpy dtype 312 | dtype = np.float16 if FLAGS.use_fp16 else np.float32 313 | 314 | vocab_dict = {} 315 | word_embedding_array = [] 316 | 317 | # Add padding 318 | vocab_dict["_PAD"] = PAD_ID 319 | # Add UNK 320 | vocab_dict["_UNK"] = UNK_ID 321 | 322 | # Read word embedding file 323 | wordembed_filename = FLAGS.pretrained_wordembedding 324 | print("Reading pretrained word embeddings file: %s"%wordembed_filename) 325 | 326 | embed_line = "" 327 | linecount = 0 328 | with open(wordembed_filename, "r") as fembedd: 329 | for line in fembedd: 330 | if linecount == 0: 331 | vocabsize = int(line.split()[0]) 332 | # Initialize a fixed-size empty array 333 | word_embedding_array = np.empty((vocabsize, FLAGS.wordembed_size), dtype=dtype) 334 | else: 335 | 
linedata = line.split() 336 | vocab_dict[linedata[0]] = linecount + 1 337 | embeddata = [float(item) for item in linedata[1:]][0:FLAGS.wordembed_size] 338 |
word_embedding_array[linecount-1] = np.array(embeddata, dtype=dtype) 339 | 340 | if linecount%100000 == 0: 341 | print(str(linecount)+" ...") 342 | linecount += 1 343 | print("Read pretrained embeddings: %s"%str(word_embedding_array.shape)) 344 | 345 | print("Size of vocab: %d (_PAD:0, _UNK:1)"%len(vocab_dict)) 346 | vocabfilename = FLAGS.train_dir+"/vocab.txt" 347 | print("Writing vocab file: %s"%vocabfilename) 348 | 349 | foutput = open(vocabfilename,"w") 350 | vocab_list = [(vocab_dict[key], key) for key in vocab_dict.keys()] 351 | vocab_list.sort() 352 | vocab_list = [item[1] for item in vocab_list] 353 | foutput.write("\n".join(vocab_list)+"\n") 354 | foutput.close() 355 | return vocab_dict, word_embedding_array 356 | 357 | -------------------------------------------------------------------------------- /document_summarizer_training_testing.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | # Comments: Jan 2017 7 | # Improved for Reinforcement Learning 8 | #################################### 9 | 10 | """ 11 | Document Summarization System 12 | """ 13 | 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | import math 19 | import os 20 | import random 21 | import sys 22 | import time 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | from reward_utils import Reward_Generator 28 | from data_utils import DataProcessor 29 | from my_flags import FLAGS 30 | from my_model import MY_Model 31 | 32 | 33 | ######################## Batch Testing a model on some dataset ############ 34 | 35 | def batch_predict_with_a_model(data, model, session=None): 36 | 37 | data_logits = [] 38 | data_labels = [] 39 | data_weights = [] 40 | 41 | step = 1 42 | while (step * FLAGS.batch_size) <= 
len(data.fileindices): 43 | # Get batch data as Numpy Arrays : Without shuffling 44 | batch_docnames, batch_docs, batch_label, batch_weight, batch_oracle_multiple, batch_reward_multiple = data.get_batch(((step-1)*FLAGS.batch_size), (step * FLAGS.batch_size)) 45 | batch_logits = session.run(model.logits, feed_dict={model.document_placeholder: batch_docs}) 46 | 47 | data_logits.append(batch_logits) 48 | data_labels.append(batch_label) 49 | data_weights.append(batch_weight) 50 | 51 | # Increase step 52 | step += 1 53 | 54 | # Check if any data left 55 | if (len(data.fileindices) > ((step-1)*FLAGS.batch_size)): 56 | # Get last batch as Numpy Arrays 57 | batch_docnames, batch_docs, batch_label, batch_weight, batch_oracle_multiple, batch_reward_multiple = data.get_batch(((step-1)*FLAGS.batch_size), len(data.fileindices)) 58 | batch_logits = session.run(model.logits, feed_dict={model.document_placeholder: batch_docs}) 59 | 60 | data_logits.append(batch_logits) 61 | data_labels.append(batch_label) 62 | data_weights.append(batch_weight) 63 | # print(data_logits) 64 | 65 | # Convert lists to tensors (TF 0.10 signature: tf.concat(concat_dim, values)) 66 | data_logits = tf.concat(0, data_logits) 67 | data_labels = tf.concat(0, data_labels) 68 | data_weights = tf.concat(0, data_weights) 69 | # print(data_logits,data_labels,data_weights) 70 | return data_logits, data_labels, data_weights 71 | 72 | ######################## Training Mode ########################### 73 | 74 | def train(): 75 | """ 76 | Training Mode: Create a new model and train the network 77 | """ 78 | 79 | # Training: enter a new default graph AND the device scope ("A and B" would only enter B) 80 | with tf.Graph().as_default(), tf.device(FLAGS.use_gpu): 81 | 82 | config = tf.ConfigProto(allow_soft_placement = True) 83 | 84 | # Start a session 85 | with tf.Session(config = config) as sess: 86 | 87 | ### Prepare data for training 88 | print("Prepare vocab dict and read pretrained word embeddings ...") 89 | 
vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict() 90 | # vocab_dict contains
_PAD and _UNK but not word_embedding_array 91 | 92 | print("Prepare training data ...") 93 | train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training") 94 | 95 | print("Prepare validation data ...") 96 | validation_data = DataProcessor().prepare_news_data(vocab_dict, data_type="validation") 97 | 98 | print("Prepare ROUGE reward generator ...") 99 | rouge_generator = Reward_Generator() 100 | 101 | # Create Model with various operations 102 | model = MY_Model(sess, len(vocab_dict)-2) 103 | 104 | # Optionally resume training from a saved model (disabled below) 105 | start_epoch = 1 106 | # selected_modelpath = FLAGS.train_dir+"/model.ckpt.epoch-"+str(start_epoch-1) 107 | # if not (os.path.isfile(selected_modelpath)): 108 | # print("Model not found in checkpoint folder.") 109 | # exit(0) 110 | # # Reload saved model and test 111 | # print("Reading model parameters from %s" % selected_modelpath) 112 | # model.saver.restore(sess, selected_modelpath) 113 | # print("Model loaded.") 114 | 115 | # Initialize word embedding before training 116 | print("Initialize word embedding vocabulary with pretrained embeddings ...") 117 | sess.run(model.vocab_embed_variable.assign(word_embedding_array)) 118 | 119 | ########### Start (No Mixer) Training : Reinforcement learning ################ 120 | # Reward-aware training via reward-weighted cross entropy (CE) 121 | # No curriculum learning: no annealing, consider the full document as in MRT 122 | # Multiple Samples (include gold sample), No future reward, Similar to MRT 123 | # During training, PYROUGE is not used, to avoid repeatedly rewriting files 124 | # Approximate MRT with multiple pre-estimated oracle samples 125 | # June 2017: Use Single sample from multiple oracles 126 | ############################################################################### 127 | 128 | print("Start Reinforcement Training (single rollout from largest prob mass) ...") 129 | 130 | for epoch in range(start_epoch, 
FLAGS.train_epoch_wce + 1): 131 | print("MRT: Epoch
"+str(epoch)) 132 | 133 | print("MRT: Epoch "+str(epoch)+" : Reshuffle training document indices") 134 | train_data.shuffle_fileindices() 135 | 136 | print("MRT: Epoch "+str(epoch)+" : Restore Rouge Dict") 137 | rouge_generator.restore_rouge_dict() 138 | 139 | # Start Batch Training 140 | step = 1 141 | while (step * FLAGS.batch_size) <= len(train_data.fileindices): 142 | # Get batch data as Numpy Arrays 143 | batch_docnames, batch_docs, batch_label, batch_weight, batch_oracle_multiple, batch_reward_multiple = train_data.get_batch(((step-1)*FLAGS.batch_size), 144 | (step * FLAGS.batch_size)) 145 | # print(batch_docnames) 146 | # print(batch_label[0]) 147 | # print(batch_weight[0]) 148 | # print(batch_oracle_multiple[0]) 149 | # print(batch_reward_multiple[0]) 150 | # exit(0) 151 | 152 | # Print the progress 153 | if (step % FLAGS.training_checkpoint) == 0: 154 | 155 | ce_loss_val, ce_loss_sum, acc_val, acc_sum = sess.run([model.rewardweighted_cross_entropy_loss_multisample, model.rewardweighted_ce_multisample_loss_summary, 156 | model.accuracy, model.taccuracy_summary], 157 | feed_dict={model.document_placeholder: batch_docs, 158 | model.predicted_multisample_label_placeholder: batch_oracle_multiple, 159 | model.actual_reward_multisample_placeholder: batch_reward_multiple, 160 | model.label_placeholder: batch_label, 161 | model.weight_placeholder: batch_weight}) 162 | 163 | # Print Summary to Tensor Board 164 | model.summary_writer.add_summary(ce_loss_sum, ((epoch-1)*len(train_data.fileindices)+ step*FLAGS.batch_size)) 165 | model.summary_writer.add_summary(acc_sum, ((epoch-1)*len(train_data.fileindices)+step*FLAGS.batch_size)) 166 | 167 | print("MRT: Epoch "+str(epoch)+" : Covered " + str(step*FLAGS.batch_size)+"/"+str(len(train_data.fileindices)) + 168 | " : Minibatch Reward Weighted Multisample CE Loss= {:.6f}".format(ce_loss_val) + " : Minibatch training accuracy= {:.6f}".format(acc_val)) 169 | 170 | # Run optimizer: optimize policy network 171 | 
sess.run([model.train_op_policynet_expreward], feed_dict={model.document_placeholder: batch_docs, 172 | model.predicted_multisample_label_placeholder: batch_oracle_multiple, 173 | model.actual_reward_multisample_placeholder: batch_reward_multiple, 174 | model.weight_placeholder: batch_weight}) 175 | 176 | # Increase step 177 | step += 1 178 | 179 | # if step == 20: 180 | # break 181 | 182 | # Save Model 183 | print("MRT: Epoch "+str(epoch)+" : Saving model after epoch completion") 184 | checkpoint_path = os.path.join(FLAGS.train_dir, "model.ckpt.epoch-"+str(epoch)) 185 | model.saver.save(sess, checkpoint_path) 186 | 187 | # Backup Rouge Dict 188 | print("MRT: Epoch "+str(epoch)+" : Saving rouge dictionary") 189 | rouge_generator.save_rouge_dict() 190 | 191 | # Performance on the validation set 192 | print("MRT: Epoch "+str(epoch)+" : Performance on the validation data") 193 | # Get Predictions: Prohibit the use of gold labels 194 | validation_logits, validation_labels, validation_weights = batch_predict_with_a_model(validation_data, model, session=sess) 195 | # Validation Accuracy and Prediction 196 | validation_acc, validation_sum = sess.run([model.final_accuracy, model.vaccuracy_summary], feed_dict={model.logits_placeholder: validation_logits.eval(session=sess), 197 | model.label_placeholder: validation_labels.eval(session=sess), 198 | model.weight_placeholder: validation_weights.eval(session=sess)}) 199 | # Print Validation Summary 200 | model.summary_writer.add_summary(validation_sum, (epoch*len(train_data.fileindices))) 201 | 202 | print("MRT: Epoch "+str(epoch)+" : Validation ("+str(len(validation_data.fileindices))+") accuracy= {:.6f}".format(validation_acc)) 203 | # Writing validation predictions and final summaries 204 | print("MRT: Epoch "+str(epoch)+" : Writing final validation summaries") 205 | validation_data.write_prediction_summaries(validation_logits, "model.ckpt.epoch-"+str(epoch), session=sess) 206 | # Estimate Rouge Scores 207 | 
rouge_score =
rouge_generator.get_full_rouge(FLAGS.train_dir+"/model.ckpt.epoch-"+str(epoch)+".validation-summary-topranked", "validation") 208 | print("MRT: Epoch "+str(epoch)+" : Validation ("+str(len(validation_data.fileindices))+") rouge= {:.6f}".format(rouge_score)) 209 | 210 | # break 211 | 212 | print("Optimization Finished!") 213 | 214 | # ######################## Test Mode ########################### 215 | 216 | def test(): 217 | """ 218 | Test Mode: Loads an existing model and tests it on the test set 219 | """ 220 | 221 | # Testing: use a fresh graph as the default graph 222 | 223 | with tf.Graph().as_default(), tf.device(FLAGS.use_gpu): 224 | 225 | config = tf.ConfigProto(allow_soft_placement = True) 226 | 227 | # Start a session 228 | with tf.Session(config = config) as sess: 229 | 230 | ### Prepare data for testing 231 | print("Prepare vocab dict and read pretrained word embeddings ...") 232 | vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict() 233 | # vocab_dict contains _PAD and _UNK but not word_embedding_array 234 | 235 | print("Prepare test data ...") 236 | test_data = DataProcessor().prepare_news_data(vocab_dict, data_type="test") 237 | 238 | # Create Model with various operations 239 | model = MY_Model(sess, len(vocab_dict)-2) 240 | 241 | # # Initialize word embedding before training 242 | # print("Initialize word embedding vocabulary with pretrained embeddings ...") 243 | # sess.run(model.vocab_embed_variable.assign(word_embedding_array)) 244 | 245 | # Select the model 246 | if (os.path.isfile(FLAGS.train_dir+"/model.ckpt.epoch-"+str(FLAGS.model_to_load))): 247 | selected_modelpath = FLAGS.train_dir+"/model.ckpt.epoch-"+str(FLAGS.model_to_load) 248 | else: 249 | print("Model not found in checkpoint folder.") 250 | exit(0) 251 | 252 | # Reload saved model and test 253 | print("Reading model parameters from %s" % selected_modelpath) 254 | model.saver.restore(sess, selected_modelpath) 255 | print("Model loaded.") 256 | 257 | # Initialize
word embedding before training 258 | print("Initialize word embedding vocabulary with pretrained embeddings ...") 259 | sess.run(model.vocab_embed_variable.assign(word_embedding_array)) 260 | 261 | # Test Accuracy and Prediction 262 | print("Performance on the test data:") 263 | FLAGS.authorise_gold_label = False 264 | test_logits, test_labels, test_weights = batch_predict_with_a_model(test_data, model, session=sess) 265 | test_acc = sess.run(model.final_accuracy, feed_dict={model.logits_placeholder: test_logits.eval(session=sess), 266 | model.label_placeholder: test_labels.eval(session=sess), 267 | model.weight_placeholder: test_weights.eval(session=sess)}) 268 | # Print Test Summary 269 | print("Test ("+str(len(test_data.fileindices))+") accuracy= {:.6f}".format(test_acc)) 270 | # Writing test predictions and final summaries 271 | test_data.write_prediction_summaries(test_logits, "model.ckpt.epoch-"+str(FLAGS.model_to_load), session=sess) 272 | 273 | ######################## Main Function ########################### 274 | 275 | def main(_): 276 | if FLAGS.exp_mode == "train": 277 | train() 278 | else: 279 | test() 280 | 281 | if __name__ == "__main__": 282 | tf.app.run() 283 | 284 | 285 | 286 | 287 | -------------------------------------------------------------------------------- /model_docsum.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | #################################### 7 | 8 | """ 9 | Document Summarization Modules and Models 10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | from tensorflow.python.ops import variable_scope 19 | from tensorflow.python.ops import seq2seq 20 | from tensorflow.python.ops import 
math_ops 21 | 22 | # from tf.nn import variable_scope 23 | from my_flags import FLAGS 24 | from model_utils import * 25 | 26 | ### Various types of extractor 27 | 28 | def sentence_extractor_nonseqrnn_noatt(sents_ext, encoder_state): 29 | """Implements Sentence Extractor: No attention and non-sequential RNN 30 | Args: 31 | sents_ext: Embedding of sentences to label for extraction 32 | encoder_state: encoder_state 33 | Returns: 34 | extractor output and logits 35 | """ 36 | # Define Variables 37 | weight = variable_on_cpu('weight', [FLAGS.size, FLAGS.target_label_size], tf.random_normal_initializer()) 38 | bias = variable_on_cpu('bias', [FLAGS.target_label_size], tf.random_normal_initializer()) 39 | 40 | # Get RNN output 41 | rnn_extractor_output, _ = simple_rnn(sents_ext, initial_state=encoder_state) 42 | 43 | with variable_scope.variable_scope("Reshape-Out"): 44 | rnn_extractor_output = reshape_list2tensor(rnn_extractor_output, FLAGS.max_doc_length, FLAGS.size) 45 | 46 | # Get Final logits without softmax 47 | extractor_output_forlogits = tf.reshape(rnn_extractor_output, [-1, FLAGS.size]) 48 | logits = tf.matmul(extractor_output_forlogits, weight) + bias 49 | # logits: [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 50 | logits = tf.reshape(logits, [-1, FLAGS.max_doc_length, FLAGS.target_label_size]) 51 | return rnn_extractor_output, logits 52 | 53 | def sentence_extractor_nonseqrnn_titimgatt(sents_ext, encoder_state, titleimages): 54 | """Implements Sentence Extractor: Non-sequential RNN with attention over title-images 55 | Args: 56 | sents_ext: Embedding of sentences to label for extraction 57 | encoder_state: encoder_state 58 | titleimages: Embeddings of title and images in the document 59 | Returns: 60 | extractor output and logits 61 | """ 62 | 63 | # Define Variables 64 | weight = variable_on_cpu('weight', [FLAGS.size, FLAGS.target_label_size], tf.random_normal_initializer()) 65 | bias = variable_on_cpu('bias', [FLAGS.target_label_size], 
tf.random_normal_initializer()) 66 | 67 | # Get RNN output 68 | rnn_extractor_output, _ = simple_attentional_rnn(sents_ext, titleimages, initial_state=encoder_state) 69 | 70 | with variable_scope.variable_scope("Reshape-Out"): 71 | rnn_extractor_output = reshape_list2tensor(rnn_extractor_output, FLAGS.max_doc_length, FLAGS.size) 72 | 73 | # Get Final logits without softmax 74 | extractor_output_forlogits = tf.reshape(rnn_extractor_output, [-1, FLAGS.size]) 75 | logits = tf.matmul(extractor_output_forlogits, weight) + bias 76 | # logits: [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 77 | logits = tf.reshape(logits, [-1, FLAGS.max_doc_length, FLAGS.target_label_size]) 78 | return rnn_extractor_output, logits 79 | 80 | def sentence_extractor_seqrnn_docatt(sents_ext, encoder_outputs, encoder_state, sents_labels): 81 | """Implements Sentence Extractor: Sequential RNN with attention over sentences during encoding 82 | Args: 83 | sents_ext: Embedding of sentences to label for extraction 84 | encoder_outputs, encoder_state 85 | sents_labels: Gold sent labels for training 86 | Returns: 87 | extractor output and logits 88 | """ 89 | # Define MLP Variables 90 | weights = { 91 | 'h1': variable_on_cpu('weight_1', [2*FLAGS.size, FLAGS.size], tf.random_normal_initializer()), 92 | 'h2': variable_on_cpu('weight_2', [FLAGS.size, FLAGS.size], tf.random_normal_initializer()), 93 | 'out': variable_on_cpu('weight_out', [FLAGS.size, FLAGS.target_label_size], tf.random_normal_initializer()) 94 | } 95 | biases = { 96 | 'b1': variable_on_cpu('bias_1', [FLAGS.size], tf.random_normal_initializer()), 97 | 'b2': variable_on_cpu('bias_2', [FLAGS.size], tf.random_normal_initializer()), 98 | 'out': variable_on_cpu('bias_out', [FLAGS.target_label_size], tf.random_normal_initializer()) 99 | } 100 | 101 | # Shift sents_ext for RNN 102 | with variable_scope.variable_scope("Shift-SentExt"): 103 | # Create embeddings for special symbol (let's assume all 0) and put in the front by
shifting by one 104 | special_tensor = tf.zeros_like(sents_ext[0]) # tf.ones_like(sents_ext[0]) 105 | sents_ext_shifted = [special_tensor] + sents_ext[:-1] 106 | 107 | # Reshape sents_labels for RNN (Only used for cross entropy training) 108 | with variable_scope.variable_scope("Reshape-Label"): 109 | # only used for training 110 | sents_labels = reshape_tensor2list(sents_labels, FLAGS.max_doc_length, FLAGS.target_label_size) 111 | 112 | # Define Sequential Decoder 113 | extractor_outputs, logits = jporg_attentional_seqrnn_decoder(sents_ext_shifted, encoder_outputs, encoder_state, sents_labels, weights, biases) 114 | 115 | # Final logits without softmax 116 | with variable_scope.variable_scope("Reshape-Out"): 117 | logits = reshape_list2tensor(logits, FLAGS.max_doc_length, FLAGS.target_label_size) 118 | extractor_outputs = reshape_list2tensor(extractor_outputs, FLAGS.max_doc_length, 2*FLAGS.size) 119 | 120 | return extractor_outputs, logits 121 | 122 | 123 | def policy_network(vocab_embed_variable, document_placeholder, label_placeholder): 124 | """Build the policy core network. 125 | Args: 126 | vocab_embed_variable: [vocab_size, FLAGS.wordembed_size], embeddings without PAD and UNK 127 | document_placeholder: [None,(FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length), FLAGS.max_sent_length] 128 | label_placeholder: Gold label [None, FLAGS.max_doc_length, FLAGS.target_label_size], only used during cross entropy training of JP's model. 
129 | Returns: 130 | Outputs of sentence extractor and logits without softmax 131 | """ 132 | 133 | with tf.variable_scope('PolicyNetwork') as scope: 134 | 135 | ### Full Word embedding Lookup Variable 136 | # PADDING embedding non-trainable 137 | pad_embed_variable = variable_on_cpu("pad_embed", [1, FLAGS.wordembed_size], tf.constant_initializer(0), trainable=False) 138 | # UNK embedding trainable 139 | unk_embed_variable = variable_on_cpu("unk_embed", [1, FLAGS.wordembed_size], tf.constant_initializer(0), trainable=True) 140 | # Get fullvocab_embed_variable 141 | fullvocab_embed_variable = tf.concat(0, [pad_embed_variable, unk_embed_variable, vocab_embed_variable]) 142 | # print(fullvocab_embed_variable) 143 | 144 | ### Lookup layer 145 | with tf.variable_scope('Lookup') as scope: 146 | document_placeholder_flat = tf.reshape(document_placeholder, [-1]) 147 | document_word_embedding = tf.nn.embedding_lookup(fullvocab_embed_variable, document_placeholder_flat, name="Lookup") 148 | document_word_embedding = tf.reshape(document_word_embedding, [-1, (FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length), 149 | FLAGS.max_sent_length, FLAGS.wordembed_size]) 150 | # print(document_word_embedding) 151 | 152 | ### Convolution Layer 153 | with tf.variable_scope('ConvLayer') as scope: 154 | document_word_embedding = tf.reshape(document_word_embedding, [-1, FLAGS.max_sent_length, FLAGS.wordembed_size]) 155 | document_sent_embedding = conv1d_layer_sentence_representation(document_word_embedding) # [None, sentembed_size] 156 | document_sent_embedding = tf.reshape(document_sent_embedding, [-1, (FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length), 157 | FLAGS.sentembed_size]) 158 | # print(document_sent_embedding) 159 | 160 | ### Reshape Tensor to List [-1, (max_doc_length+max_title_length+max_image_length), sentembed_size] -> List of [-1, sentembed_size] 161 | with variable_scope.variable_scope("ReshapeDoc_TensorToList"): 162 | 
document_sent_embedding = reshape_tensor2list(document_sent_embedding, (FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length), FLAGS.sentembed_size) 163 | # print(document_sent_embedding) 164 | 165 | # document_sents_enc 166 | document_sents_enc = document_sent_embedding[:FLAGS.max_doc_length] 167 | if FLAGS.doc_encoder_reverse: 168 | document_sents_enc = document_sents_enc[::-1] 169 | 170 | # document_sents_ext 171 | document_sents_ext = document_sent_embedding[:FLAGS.max_doc_length] 172 | 173 | # document_sents_titimg 174 | document_sents_titimg = document_sent_embedding[FLAGS.max_doc_length:] 175 | 176 | ### Document Encoder 177 | with tf.variable_scope('DocEnc') as scope: 178 | encoder_outputs, encoder_state = simple_rnn(document_sents_enc) 179 | 180 | ### Sentence Label Extractor 181 | with tf.variable_scope('SentExt') as scope: 182 | if (FLAGS.attend_encoder) and (len(document_sents_titimg) != 0): 183 | # Multiple decoder 184 | print("Multiple decoders are not implemented yet.") 185 | exit(0) 186 | # # Decoder to attend captions 187 | # attendtitimg_extractor_output, _ = simple_attentional_rnn(document_sents_ext, document_sents_titimg, initial_state=encoder_state) 188 | # # Attend previous decoder 189 | # logits = sentence_extractor_seqrnn_docatt(document_sents_ext, attendtitimg_extractor_output, encoder_state, label_placeholder) 190 | 191 | elif (not FLAGS.attend_encoder) and (len(document_sents_titimg) != 0): 192 | # Attend only titimages during decoding 193 | extractor_output, logits = sentence_extractor_nonseqrnn_titimgatt(document_sents_ext, encoder_state, document_sents_titimg) 194 | 195 | elif (FLAGS.attend_encoder) and (len(document_sents_titimg) == 0): 196 | # JP model: attend encoder 197 | extractor_output, logits = sentence_extractor_seqrnn_docatt(document_sents_ext, encoder_outputs, encoder_state, label_placeholder) 198 | 199 | else: 200 | # Attend nothing 201 | extractor_output, logits =
sentence_extractor_nonseqrnn_noatt(document_sents_ext, encoder_state) 202 | 203 | # print(extractor_output) 204 | # print(logits) 205 | return extractor_output, logits 206 | 207 | def baseline_future_reward_estimator(extractor_output): 208 | """Implements linear regression to estimate future rewards 209 | Args: 210 | extractor_output: [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.size or 2*FLAGS.size] 211 | Output: 212 | rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 213 | """ 214 | 215 | with tf.variable_scope('FutureRewardEstimator') as scope: 216 | 217 | last_size = extractor_output.get_shape()[2].value 218 | 219 | # Define Variables 220 | weight = variable_on_cpu('weight', [last_size, 1], tf.random_normal_initializer()) 221 | bias = variable_on_cpu('bias', [1], tf.random_normal_initializer()) 222 | 223 | extractor_output_forreward = tf.reshape(extractor_output, [-1, last_size]) 224 | future_rewards = tf.matmul(extractor_output_forreward, weight) + bias 225 | # future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length, 1] 226 | future_rewards = tf.reshape(future_rewards, [-1, FLAGS.max_doc_length, 1]) 227 | future_rewards = tf.squeeze(future_rewards) 228 | return future_rewards 229 | 230 | def baseline_single_future_reward_estimator(extractor_output): 231 | """Implements linear regression to estimate future rewards for whole document 232 | Args: 233 | extractor_output: [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.size or 2*FLAGS.size] 234 | Output: 235 | rewards: [FLAGS.batch_size] 236 | """ 237 | 238 | with tf.variable_scope('FutureRewardEstimator') as scope: 239 | 240 | last_size = extractor_output.get_shape()[2].value 241 | 242 | # Define Variables 243 | weight = variable_on_cpu('weight', [FLAGS.max_doc_length*last_size, 1], tf.random_normal_initializer()) 244 | bias = variable_on_cpu('bias', [1], tf.random_normal_initializer()) 245 | 246 | extractor_output_forreward = tf.reshape(extractor_output, [-1, FLAGS.max_doc_length*last_size]) # [FLAGS.batch_size, 
FLAGS.max_doc_length*(FLAGS.size or 2*FLAGS.size)] 247 | future_rewards = tf.matmul(extractor_output_forreward, weight) + bias # [FLAGS.batch_size, 1] 248 | # future_rewards: [FLAGS.batch_size, 1] 249 | future_rewards = tf.squeeze(future_rewards) # [FLAGS.batch_size] 250 | return future_rewards 251 | 252 | ### Loss Functions 253 | 254 | def mean_square_loss_doclevel(future_rewards, actual_reward): 255 | """Implements mean_square_loss for future reward prediction 256 | Args: 257 | future_rewards: [FLAGS.batch_size] 258 | actual_reward: [FLAGS.batch_size] 259 | Returns: 260 | Float Value 261 | """ 262 | with tf.variable_scope('MeanSquareLoss') as scope: 263 | sq_loss = tf.square(future_rewards - actual_reward) # [FLAGS.batch_size] 264 | 265 | mean_sq_loss = tf.reduce_mean(sq_loss) 266 | 267 | tf.add_to_collection('mean_square_loss', mean_sq_loss) 268 | 269 | return mean_sq_loss 270 | 271 | def mean_square_loss(future_rewards, actual_reward, weights): 272 | """Implements mean_square_loss for future reward prediction 273 | Args: 274 | future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 275 | actual_reward: [FLAGS.batch_size] 276 | weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 277 | Returns: 278 | Float Value 279 | """ 280 | with tf.variable_scope('MeanSquareLoss') as scope: 281 | actual_reward = tf.expand_dims(actual_reward, 1) # [FLAGS.batch_size, 1] 282 | sq_loss = tf.square(future_rewards - actual_reward) # [FLAGS.batch_size, FLAGS.max_doc_length] 283 | 284 | mean_sq_loss = 0 285 | if FLAGS.weighted_loss: 286 | sq_loss = tf.mul(sq_loss, weights) 287 | sq_loss_sum = tf.reduce_sum(sq_loss) 288 | valid_sentences = tf.reduce_sum(weights) 289 | mean_sq_loss = sq_loss_sum / valid_sentences 290 | else: 291 | mean_sq_loss = tf.reduce_mean(sq_loss) 292 | 293 | tf.add_to_collection('mean_square_loss', mean_sq_loss) 294 | 295 | return mean_sq_loss 296 | 297 | def cross_entropy_loss(logits, labels, weights): 298 | """Estimate cost of
predictions 299 | Add summary for "cost" and "cost/avg". 300 | Args: 301 | logits: Logits from inference(). [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 302 | labels: Sentence extraction gold labels [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 303 | weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 304 | Returns: 305 | Cross-entropy Cost 306 | """ 307 | with tf.variable_scope('CrossEntropyLoss') as scope: 308 | # Reshape logits and labels to match the requirement of softmax_cross_entropy_with_logits 309 | logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 310 | labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 311 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) # [FLAGS.batch_size*FLAGS.max_doc_length] 312 | cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 313 | if FLAGS.weighted_loss: 314 | cross_entropy = tf.mul(cross_entropy, weights) 315 | 316 | # Cross entropy / document 317 | cross_entropy = tf.reduce_sum(cross_entropy, reduction_indices=1) # [FLAGS.batch_size] 318 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='crossentropy') 319 | 320 | # ## Cross entropy / sentence 321 | # cross_entropy_sum = tf.reduce_sum(cross_entropy) 322 | # valid_sentences = tf.reduce_sum(weights) 323 | # cross_entropy_mean = cross_entropy_sum / valid_sentences 324 | 325 | # cross_entropy = -tf.reduce_sum(labels * tf.log(logits), reduction_indices=1) 326 | # cross_entropy_mean = tf.reduce_mean(cross_entropy, name='crossentropy') 327 | 328 | tf.add_to_collection('cross_entropy_loss', cross_entropy_mean) 329 | # # # The total loss is defined as the cross entropy loss plus all of 330 | # # # the weight decay terms (L2 loss).
331 | # # return tf.add_n(tf.get_collection('losses'), name='total_loss') 332 | return cross_entropy_mean 333 | 334 | def predict_labels(logits): 335 | """ Predict self labels 336 | logits: Logits from inference(). [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 337 | Return [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 338 | """ 339 | with tf.variable_scope('PredictLabels') as scope: 340 | # Reshape logits for argmax and argmin 341 | logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 342 | # Get labels predicted using these logits 343 | logits_argmax = tf.argmax(logits, 1) # [FLAGS.batch_size*FLAGS.max_doc_length] 344 | logits_argmax = tf.reshape(logits_argmax, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 345 | logits_argmax = tf.expand_dims(logits_argmax, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 346 | 347 | logits_argmin = tf.argmin(logits, 1) # [FLAGS.batch_size*FLAGS.max_doc_length] 348 | logits_argmin = tf.reshape(logits_argmin, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 349 | logits_argmin = tf.expand_dims(logits_argmin, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 350 | 351 | # Convert argmin and argmax to labels, works only if FLAGS.target_label_size = 2 352 | labels = tf.concat(2, [logits_argmin, logits_argmax]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 353 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 354 | labels = tf.cast(labels, dtype) 355 | 356 | return labels 357 | 358 | def estimate_ltheta_ot(logits, labels, future_rewards, actual_rewards, weights): 359 | """ 360 | Args: 361 | logits: Logits from inference(). 
[FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 362 | labels: Label placeholder for self prediction [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 363 | future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 364 | actual_rewards: [FLAGS.batch_size] 365 | weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 366 | Returns: 367 | [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 368 | """ 369 | with tf.variable_scope('LTheta_Ot') as scope: 370 | # Get Reward Weights: External reward - Predicted reward 371 | actual_rewards = tf.tile(actual_rewards, [FLAGS.max_doc_length]) # [FLAGS.batch_size * FLAGS.max_doc_length] , [a,b] * 3 = [a, b, a, b, a, b] 372 | actual_rewards = tf.reshape(actual_rewards, [FLAGS.max_doc_length, -1]) # [FLAGS.max_doc_length, FLAGS.batch_size], # [[a,b], [a,b], [a,b]] 373 | actual_rewards = tf.transpose(actual_rewards) # [FLAGS.batch_size, FLAGS.max_doc_length] # [[a,a,a], [b,b,b]] 374 | 375 | diff_act_pred = actual_rewards - future_rewards # [FLAGS.batch_size, FLAGS.max_doc_length] 376 | diff_act_pred = tf.expand_dims(diff_act_pred, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 377 | # Convert (FLAGS.target_label_size = 2) 378 | diff_act_pred = tf.concat(2, [diff_act_pred, diff_act_pred]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 379 | 380 | # Reshape logits so that softmax is applied per sentence over FLAGS.target_label_size 381 | logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 382 | logits = tf.nn.softmax(logits) 383 | logits = tf.reshape(logits, [-1, FLAGS.max_doc_length, FLAGS.target_label_size]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 384 | 385 | # Get the difference 386 | diff_logits_indicator = logits - labels # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 387 | 388 | 389 | # Multiply with reward 390 |
d_ltheta_ot = tf.mul(diff_act_pred, diff_logits_indicator) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 391 | 392 | # Multiply with weight 393 | weights = tf.expand_dims(weights, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 394 | weights = tf.concat(2, [weights, weights]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 395 | d_ltheta_ot = tf.mul(d_ltheta_ot, weights) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 396 | 397 | return d_ltheta_ot 398 | 399 | # def estimate_ltheta_ot_mixer(logits, labels_gold, labels_pred, future_rewards, actual_rewards, weights, annealing_step): 400 | # """ 401 | # Args: 402 | # logits: Logits from inference(). [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 403 | # labels_gold: Label placeholdr for gold labels [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 404 | # labels_pred: Label placeholdr for self prediction [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 405 | # future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 406 | # actual_reward: [FLAGS.batch_size] 407 | # weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 408 | # annealing_step: [1], single value but in tensor form 409 | # Returns: 410 | # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 411 | # """ 412 | # with tf.variable_scope('LTheta_Ot_Mixer') as scope: 413 | # print(annealing_step) 414 | 415 | # policygradloss_length = tf.reduce_sum(annealing_step) * FLAGS.annealing_step_delta 416 | # crossentryloss_length = FLAGS.max_doc_length - policygradloss_length 417 | 418 | 419 | 420 | # # Reshape logits and partition 421 | # logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 422 | # logits = tf.nn.softmax(logits) 423 | # logits = tf.reshape(logits, [-1, FLAGS.max_doc_length, FLAGS.target_label_size]) # [FLAGS.batch_size, 
FLAGS.max_doc_length, FLAGS.target_label_size] 424 | # logits_list = reshape_tensor2list(logits, FLAGS.max_doc_length, FLAGS.target_label_size) 425 | # logits_ce_gold_list = logits_list[0:crossentryloss_length] 426 | # logits_ce_gold = reshape_list2tensor(logits_ce_gold_list, crossentryloss_length, FLAGS.target_label_size) # [FLAGS.batch_size, crossentryloss_length, FLAGS.target_label_size] 427 | # logits_reward_list = logits_list[crossentryloss_length:] 428 | # logits_reward = reshape_list2tensor(logits_reward_list, policygradloss_length, FLAGS.target_label_size) # [FLAGS.batch_size, policygradloss_length, FLAGS.target_label_size] 429 | 430 | # # Crossentropy loss with gold labels: partition gold_labels 431 | # labels_gold_list = reshape_tensor2list(labels_gold, FLAGS.max_doc_length, FLAGS.target_label_size) 432 | # labels_gold_used_list = labels_gold_list[0:crossentryloss_length] 433 | # labels_gold_used = reshape_list2tensor(labels_gold_used_list, crossentryloss_length, FLAGS.target_label_size) # [FLAGS.batch_size, crossentryloss_length, FLAGS.target_label_size] 434 | 435 | # # d_ltheta_ot : cross entropy 436 | # diff_logits_goldlabels = logits_ce_gold - labels_gold_used # [FLAGS.batch_size, crossentryloss_length, FLAGS.target_label_size] 437 | 438 | # # Policy gradient for rest 439 | 440 | # # Get Reward Weights: External reward - Predicted reward 441 | # actual_rewards = tf.tile(actual_rewards, [FLAGS.max_doc_length]) # [FLAGS.batch_size * FLAGS.max_doc_length] , [a,b] * 3 = [a, b, a, b, a, b] 442 | # actual_rewards = tf.reshape(actual_rewards, [FLAGS.max_doc_length, -1]) # [FLAGS.max_doc_length, FLAGS.batch_size], # [[a,b], [a,b], [a,b]] 443 | # actual_rewards = tf.transpose(actual_rewards) # [FLAGS.batch_size, FLAGS.max_doc_length] # [[a,a,a], [b,b,b]] 444 | # diff_act_pred = actual_rewards - future_rewards # [FLAGS.batch_size, FLAGS.max_doc_length] 445 | # diff_act_pred = tf.expand_dims(diff_act_pred, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 446 | # 
# Convert (FLAGS.target_label_size = 2) 447 | # diff_act_pred = tf.concat(2, [diff_act_pred, diff_act_pred]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 448 | 449 | # # Get used reward diff 450 | # diff_act_pred_list = reshape_tensor2list(diff_act_pred, FLAGS.max_doc_length, FLAGS.target_label_size) 451 | # diff_reward_act_pred_used_list = diff_act_pred_list[crossentryloss_length:] 452 | # diff_reward_act_pred_used = reshape_list2tensor(diff_reward_act_pred_used_list, policygradloss_length, FLAGS.target_label_size) # [FLAGS.batch_size, policygradloss_length, FLAGS.target_label_size] 453 | 454 | # # Partition predicted labels 455 | # labels_pred_list = reshape_tensor2list(labels_pred, FLAGS.max_doc_length, FLAGS.target_label_size) 456 | # labels_pred_used_list = labels_pred_list[crossentryloss_length:] 457 | # labels_pred_used = reshape_list2tensor(labels_pred_used_list, policygradloss_length, FLAGS.target_label_size) # [FLAGS.batch_size, policygradloss_length, FLAGS.target_label_size] 458 | 459 | # # d_ltheta_ot : reward weighted 460 | # diff_logits_predlabels = logits_reward - labels_pred_used # [FLAGS.batch_size, policygradloss_length, FLAGS.target_label_size] 461 | # # Multiply with reward 462 | # reward_weighted_diff_logits_predlabels = tf.mul(diff_reward_act_pred_used, diff_logits_predlabels) # [FLAGS.batch_size, policygradloss_length, FLAGS.target_label_size] 463 | 464 | # # Concat both part 465 | # d_ltheta_ot_mixer = tf.concat(1, [diff_logits_goldlabels, reward_weighted_diff_logits_predlabels]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 466 | 467 | # # Multiply with weight 468 | # weights = tf.expand_dims(weights, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 469 | # weights = tf.concat(2, [weights, weights]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 470 | # d_ltheta_ot_mixer = tf.mul(d_ltheta_ot_mixer, weights) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 
471 | 472 | # return d_ltheta_ot_mixer 473 | 474 | def reward_weighted_cross_entropy_loss_multisample(logits, labels, actual_rewards, weights): 475 | """Estimate cost of predictions 476 | Add summary for "cost" and "cost/avg". 477 | Args: 478 | logits: Logits from inference(). [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 479 | labels: Label placeholder for multiple sampled prediction [FLAGS.batch_size, 1, FLAGS.max_doc_length, FLAGS.target_label_size] 480 | actual_rewards: [FLAGS.batch_size, 1] 481 | weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 482 | Returns: 483 | Cross-entropy Cost 484 | """ 485 | 486 | with tf.variable_scope('RWCELossMultiSample') as scope: 487 | # Expand logits and weights for rollouts 488 | logits_temp = tf.expand_dims(logits, 1) # [FLAGS.batch_size, 1, FLAGS.max_doc_length, FLAGS.target_label_size] 489 | weights_temp = tf.expand_dims(weights, 1) # [FLAGS.batch_size, 1, FLAGS.max_doc_length] 490 | logits_expanded = logits_temp 491 | weights_expanded = weights_temp 492 | # for ridx in range(1,FLAGS.num_sample_rollout): 493 | # logits_expanded = tf.concat(1, [logits_expanded, logits_temp]) # [FLAGS.batch_size, n++, FLAGS.max_doc_length, FLAGS.target_label_size] 494 | # weights_expanded = tf.concat(1, [weights_expanded, weights_temp]) # [FLAGS.batch_size, n++, FLAGS.max_doc_length] 495 | 496 | # Reshape logits and labels to match the requirement of softmax_cross_entropy_with_logits 497 | logits_expanded = tf.reshape(logits_expanded, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*1*FLAGS.max_doc_length, FLAGS.target_label_size] 498 | labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*1*FLAGS.max_doc_length, FLAGS.target_label_size] 499 | 500 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits_expanded, labels) # [FLAGS.batch_size*1*FLAGS.max_doc_length] 501 | cross_entropy = tf.reshape(cross_entropy, [-1, 1, FLAGS.max_doc_length]) #
[FLAGS.batch_size, 1, FLAGS.max_doc_length] 502 | if FLAGS.weighted_loss: 503 | cross_entropy = tf.mul(cross_entropy, weights_expanded) # [FLAGS.batch_size, 1, FLAGS.max_doc_length] 504 | 505 | # Reshape actual rewards 506 | actual_rewards = tf.reshape(actual_rewards, [-1]) # [FLAGS.batch_size*1] 507 | # Example (3 documents x 2 rollouts, max_doc_length = 2): [[a, b], [c, d], [e, f]] 3x2 => [a, b, c, d, e, f] [6] 508 | actual_rewards = tf.tile(actual_rewards, [FLAGS.max_doc_length]) # [FLAGS.batch_size * 1 * FLAGS.max_doc_length] 509 | # [a, b, c, d, e, f] * 2 = [a, b, c, d, e, f, a, b, c, d, e, f] [12] 510 | actual_rewards = tf.reshape(actual_rewards, [FLAGS.max_doc_length, -1]) # [FLAGS.max_doc_length, FLAGS.batch_size*1] 511 | # [[a, b, c, d, e, f], [a, b, c, d, e, f]] [2, 6] 512 | actual_rewards = tf.transpose(actual_rewards) # [FLAGS.batch_size*1, FLAGS.max_doc_length] 513 | # [[a,a], [b,b], [c,c], [d,d], [e,e], [f,f]] [6 x 2] 514 | actual_rewards = tf.reshape(actual_rewards, [-1, 1, FLAGS.max_doc_length]) # [FLAGS.batch_size, 1, FLAGS.max_doc_length], 515 | # [[[a,a], [b,b]], [[c,c], [d,d]], [[e,e], [f,f]]] [3 x 2 x 2] 516 | 517 | # Multiply with reward 518 | reward_weighted_cross_entropy = tf.mul(cross_entropy, actual_rewards) # [FLAGS.batch_size, 1, FLAGS.max_doc_length] 519 | 520 | # Cross entropy / sample / document 521 | reward_weighted_cross_entropy = tf.reduce_sum(reward_weighted_cross_entropy, reduction_indices=2) # [FLAGS.batch_size, 1] 522 | reward_weighted_cross_entropy_mean = tf.reduce_mean(reward_weighted_cross_entropy, name='rewardweightedcemultisample') 523 | 524 | tf.add_to_collection('reward_cross_entropy_loss_multisample', reward_weighted_cross_entropy_mean) 525 | 526 | return reward_weighted_cross_entropy_mean 527 | 528 | def reward_weighted_cross_entropy_loss(logits, labels, actual_rewards, weights): 529 | """Estimate cost of predictions 530 | Add summary for "cost" and "cost/avg". 531 | Args: 532 | logits: Logits from inference().
[FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 533 | labels: Label placeholder for self prediction [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 534 | actual_rewards: [FLAGS.batch_size] 535 | weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 536 | Returns: 537 | Cross-entropy Cost 538 | """ 539 | 540 | with tf.variable_scope('RewardWeightedCrossEntropyLoss') as scope: 541 | 542 | # Reshape logits and labels to match the requirement of softmax_cross_entropy_with_logits 543 | logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 544 | labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 545 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) # [FLAGS.batch_size*FLAGS.max_doc_length] 546 | cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 547 | if FLAGS.weighted_loss: 548 | cross_entropy = tf.mul(cross_entropy, weights) # [FLAGS.batch_size, FLAGS.max_doc_length] 549 | 550 | # Reshape actual rewards 551 | actual_rewards = tf.tile(actual_rewards, [FLAGS.max_doc_length]) # [FLAGS.batch_size * FLAGS.max_doc_length] , [a,b] * 3 = [a, b, a, b, a, b] 552 | actual_rewards = tf.reshape(actual_rewards, [FLAGS.max_doc_length, -1]) # [FLAGS.max_doc_length, FLAGS.batch_size], # [[a,b], [a,b], [a,b]] 553 | actual_rewards = tf.transpose(actual_rewards) # [FLAGS.batch_size, FLAGS.max_doc_length] # [[a,a,a], [b,b,b]] 554 | 555 | # Multiply with reward 556 | reward_weighted_cross_entropy = tf.mul(cross_entropy, actual_rewards) # [FLAGS.batch_size, FLAGS.max_doc_length] 557 | 558 | # Cross entropy / document 559 | reward_weighted_cross_entropy = tf.reduce_sum(reward_weighted_cross_entropy, reduction_indices=1) # [FLAGS.batch_size] 560 | reward_weighted_cross_entropy_mean = 
tf.reduce_mean(reward_weighted_cross_entropy, name='rewardweightedcrossentropy') 561 | 562 | tf.add_to_collection('reward_cross_entropy_loss', reward_weighted_cross_entropy_mean) 563 | 564 | return reward_weighted_cross_entropy_mean 565 | 566 | # def reward_weighted_cross_entropy_loss(logits, labels, future_rewards, actual_rewards, weights): 567 | # """Estimate cost of predictions 568 | # Add summary for "cost" and "cost/avg". 569 | # Args: 570 | # logits: Logits from inference(). [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 571 | # labels: Label placeholder for self prediction [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 572 | # future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 573 | # actual_rewards: [FLAGS.batch_size] 574 | # weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 575 | # Returns: 576 | # Cross-entropy Cost 577 | # """ 578 | 579 | # with tf.variable_scope('RewardWeightedCrossEntropyLoss') as scope: 580 | # # Get Reward Weights: External reward - Predicted reward 581 | # actual_rewards = tf.tile(actual_rewards, [FLAGS.max_doc_length]) # [FLAGS.batch_size * FLAGS.max_doc_length] , [a,b] * 3 = [a, b, a, b, a, b] 582 | # actual_rewards = tf.reshape(actual_rewards, [FLAGS.max_doc_length, -1]) # [FLAGS.max_doc_length, FLAGS.batch_size], # [[a,b], [a,b], [a,b]] 583 | # actual_rewards = tf.transpose(actual_rewards) # [FLAGS.batch_size, FLAGS.max_doc_length] # [[a,a,a], [b,b,b]] 584 | 585 | # # Error: actual_rewards = tf.reshape(tf.tile(actual_rewards, [FLAGS.max_doc_length]),[-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 586 | 587 | # diff_act_pred = future_rewards - actual_rewards # actual_rewards - future_rewards # [FLAGS.batch_size, FLAGS.max_doc_length] 588 | 589 | # # Reshape logits and labels to match the requirement of softmax_cross_entropy_with_logits 590 | # logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # 
[FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 591 | # labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 592 | # cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) # [FLAGS.batch_size*FLAGS.max_doc_length] 593 | # cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 594 | # if FLAGS.weighted_loss: 595 | # cross_entropy = tf.mul(cross_entropy, weights) # [FLAGS.batch_size, FLAGS.max_doc_length] 596 | 597 | # # Multiply with reward 598 | # reward_weighted_cross_entropy = tf.mul(cross_entropy, diff_act_pred) # [FLAGS.batch_size, FLAGS.max_doc_length] 599 | 600 | # # Cross entropy / document 601 | # reward_weighted_cross_entropy = tf.reduce_sum(reward_weighted_cross_entropy, reduction_indices=1) # [FLAGS.batch_size] 602 | # reward_weighted_cross_entropy_mean = tf.reduce_mean(reward_weighted_cross_entropy, name='rewardweightedcrossentropy') 603 | 604 | # tf.add_to_collection('reward_cross_entropy_loss', reward_weighted_cross_entropy_mean) 605 | 606 | # return reward_weighted_cross_entropy_mean 607 | 608 | # def temp_reward_weighted_cross_entropy_loss(logits, labels, future_rewards, actual_rewards, weights): 609 | # """Estimate cost of predictions 610 | # Add summary for "cost" and "cost/avg". 611 | # Args: 612 | # logits: Logits from inference(). 
[FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 613 | # labels: Label placeholder for self prediction [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 614 | # future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 615 | # actual_rewards: [FLAGS.batch_size] 616 | # weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 617 | # Returns: 618 | # Cross-entropy Cost 619 | # """ 620 | 621 | # with tf.variable_scope('TempRewardWeightedCrossEntropyLoss') as scope: 622 | # # Get Reward Weights: External reward - Predicted reward 623 | # actual_rewards = tf.tile(actual_rewards, [FLAGS.max_doc_length]) # [FLAGS.batch_size * FLAGS.max_doc_length] , [a,b] * 3 = [a, b, a, b, a, b] 624 | # actual_rewards = tf.reshape(actual_rewards, [FLAGS.max_doc_length, -1]) # [FLAGS.max_doc_length, FLAGS.batch_size], # [[a,b], [a,b], [a,b]] 625 | # actual_rewards = tf.transpose(actual_rewards) # [FLAGS.batch_size, FLAGS.max_doc_length] # [[a,a,a], [b,b,b]] 626 | 627 | # diff_act_pred = future_rewards - actual_rewards # actual_rewards - future_rewards # [FLAGS.batch_size, FLAGS.max_doc_length] 628 | 629 | # # Reshape logits and labels to match the requirement of softmax_cross_entropy_with_logits 630 | # logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 631 | # labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 632 | # cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) # [FLAGS.batch_size*FLAGS.max_doc_length] 633 | # cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 634 | # if FLAGS.weighted_loss: 635 | # cross_entropy = tf.mul(cross_entropy, weights) # [FLAGS.batch_size, FLAGS.max_doc_length] 636 | 637 | # # Multiply with reward 638 | # reward_weighted_cross_entropy = tf.mul(cross_entropy, diff_act_pred) 
# [FLAGS.batch_size, FLAGS.max_doc_length] 639 | 640 | # # Cross entropy / document 641 | # reward_weighted_cross_entropy = tf.reduce_sum(reward_weighted_cross_entropy, reduction_indices=1) # [FLAGS.batch_size] 642 | # reward_weighted_cross_entropy_mean = tf.reduce_mean(reward_weighted_cross_entropy, name='rewardweightedcrossentropy') 643 | 644 | # optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 645 | 646 | # # Compute gradients of policy network 647 | # policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PolicyNetwork") 648 | # # print(policy_network_variables) 649 | 650 | # # Compute gradients of policy network 651 | # grads_and_vars = optimizer.compute_gradients(reward_weighted_cross_entropy_mean, var_list=policy_network_variables) 652 | # # print(grads_and_vars) 653 | 654 | # return actual_rewards, cross_entropy, diff_act_pred, reward_weighted_cross_entropy, reward_weighted_cross_entropy_mean, grads_and_vars 655 | 656 | 657 | # def cross_entropy_loss_selfprediction(logits, weights): 658 | # """Optimizing expected reward: Weighted cross entropy 659 | # args: 660 | # logits: Logits without softmax. 
[FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 661 | # weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 662 | # return: 663 | # [FLAGS.batch_size, FLAGS.max_doc_length] 664 | # """ 665 | # with tf.variable_scope('SelfPredCrossEntropyLoss') as scope: 666 | # # Reshape logits for argmax and argmin 667 | # logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 668 | 669 | # # Get labels if predicted using these logits 670 | # logits_argmax = tf.argmax(logits, 1) # [FLAGS.batch_size*FLAGS.max_doc_length] 671 | # logits_argmax = tf.reshape(logits_argmax, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 672 | # logits_argmax = tf.expand_dims(logits_argmax, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 673 | # logits_argmin = tf.argmin(logits, 1) # [FLAGS.batch_size*FLAGS.max_doc_length] 674 | # logits_argmin = tf.reshape(logits_argmin, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 675 | # logits_argmin = tf.expand_dims(logits_argmin, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 676 | # # Convert argmin and argmax to labels, works only if FLAGS.target_label_size = 2 677 | # labels = tf.concat(2, [logits_argmin, logits_argmax]) # [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 678 | # dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 679 | # labels = tf.cast(labels, dtype) 680 | # labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 681 | 682 | # # softmax_cross_entropy_with_logits 683 | # cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) # [FLAGS.batch_size*FLAGS.max_doc_length] 684 | # cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 685 | # if FLAGS.weighted_loss: 686 | # cross_entropy = tf.mul(cross_entropy, weights) 687 
| # return cross_entropy 688 | 689 | # def weighted_cross_entropy_loss(logits, future_rewards, actual_reward, weights): 690 | # """Optimizing expected reward: Weighted cross entropy 691 | # args: 692 | # logits: Logits without softmax. [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 693 | # future_rewards: [FLAGS.batch_size, FLAGS.max_doc_length] 694 | # actual_reward: [FLAGS.batch_size] 695 | # weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 696 | # """ 697 | 698 | # with tf.variable_scope('WeightedCrossEntropyLoss') as scope: 699 | # # Get Weights: External reward - Predicted reward 700 | # actual_reward = tf.reshape(tf.tile(actual_reward, [FLAGS.max_doc_length]),[-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 701 | # diff_act_pred = future_rewards - actual_reward # actual_reward - future_rewards # [FLAGS.batch_size, FLAGS.max_doc_length] 702 | 703 | # # Reshape logits for argmax and argmin 704 | # logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 705 | 706 | # # Get labels if predicted using these logits 707 | # logits_argmax = tf.argmax(logits, 1) # [FLAGS.batch_size*FLAGS.max_doc_length] 708 | # logits_argmax = tf.reshape(logits_argmax, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 709 | # logits_argmax = tf.expand_dims(logits_argmax, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 710 | # logits_argmin = tf.argmin(logits, 1) # [FLAGS.batch_size*FLAGS.max_doc_length] 711 | # logits_argmin = tf.reshape(logits_argmin, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 712 | # logits_argmin = tf.expand_dims(logits_argmin, 2) # [FLAGS.batch_size, FLAGS.max_doc_length, 1] 713 | # # Convert argmin and argmax to labels, works only if FLAGS.target_label_size = 2 714 | # labels = tf.concat(2, [logits_argmin, logits_argmax]) # [FLAGS.batch_size, FLAGS.max_doc_length, 
FLAGS.target_label_size] 715 | # dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 716 | # labels = tf.cast(labels, dtype) 717 | # labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 718 | 719 | # # softmax_cross_entropy_with_logits 720 | # cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) # [FLAGS.batch_size*FLAGS.max_doc_length] 721 | # cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 722 | # if FLAGS.weighted_loss: 723 | # cross_entropy = tf.mul(cross_entropy, weights) 724 | # # Multiply with reward 725 | # cross_entropy = tf.mul(cross_entropy, diff_act_pred) 726 | 727 | # # Cross entropy / document 728 | # cross_entropy = tf.reduce_sum(cross_entropy, reduction_indices=1) # [FLAGS.batch_size] 729 | # cross_entropy_mean = tf.reduce_mean(cross_entropy, name='crossentropy') 730 | 731 | # tf.add_to_collection('reward_cross_entropy_loss', cross_entropy_mean) 732 | # # # # The total loss is defined as the cross entropy loss plus all of 733 | # # # # the weight decay terms (L2 loss). 
734 | # # # return tf.add_n(tf.get_collection('losses'), name='total_loss') 735 | # return cross_entropy_mean 736 | 737 | ### Training functions 738 | 739 | def train_cross_entropy_loss(cross_entropy_loss): 740 | """ Training with Gold Label: Pretraining network to start with a better policy 741 | Args: cross_entropy_loss 742 | """ 743 | with tf.variable_scope('TrainCrossEntropyLoss') as scope: 744 | 745 | optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 746 | 747 | # Compute gradients of policy network 748 | policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PolicyNetwork") 749 | # print(policy_network_variables) 750 | grads_and_vars = optimizer.compute_gradients(cross_entropy_loss, var_list=policy_network_variables) 751 | # print(grads_and_vars) 752 | 753 | # Apply Gradients 754 | return optimizer.apply_gradients(grads_and_vars) 755 | 756 | def train_meansq_loss(futreward_meansq_loss): 757 | """ Training the future reward estimator: minimize its mean squared loss 758 | Args: futreward_meansq_loss 759 | """ 760 | with tf.variable_scope('TrainMeanSqLoss') as scope: 761 | optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 762 | 763 | # Compute gradients of Future reward estimator 764 | futreward_estimator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="FutureRewardEstimator") 765 | # print(futreward_estimator_variables) 766 | grads_and_vars = optimizer.compute_gradients(futreward_meansq_loss, var_list=futreward_estimator_variables) 767 | # print(grads_and_vars) 768 | 769 | # Apply Gradients 770 | return optimizer.apply_gradients(grads_and_vars) 771 | 772 | def train_neg_expectedreward(reward_weighted_cross_entropy_loss_multisample): 773 | """Training with Policy Gradient: Optimizing expected reward 774 | args: 775 | reward_weighted_cross_entropy_loss_multisample 776 | """ 777 | with tf.variable_scope('TrainExpReward') as scope: 778 | 779
| optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 780 | 781 | # Compute gradients of policy network 782 | policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PolicyNetwork") 783 | # print(policy_network_variables) 784 | 785 | # Compute gradients of policy network 786 | grads_and_vars = optimizer.compute_gradients(reward_weighted_cross_entropy_loss_multisample, var_list=policy_network_variables) 787 | # print(grads_and_vars) 788 | 789 | # Clip gradient: Pascanu et al. 2013, Exploding gradient problem 790 | grads_and_vars_capped_norm = [(tf.clip_by_norm(grad, 5.0), var) for grad, var in grads_and_vars] 791 | 792 | # Apply Gradients 793 | # return optimizer.apply_gradients(grads_and_vars) 794 | return optimizer.apply_gradients(grads_and_vars_capped_norm) 795 | 796 | # def train_neg_expectedreward(reward_weighted_cross_entropy_loss): 797 | # """Training with Policy Gradient: Optimizing expected reward 798 | # args: 799 | # reward_weighted_cross_entropy_loss 800 | # """ 801 | # with tf.variable_scope('TrainExpReward') as scope: 802 | 803 | # optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 804 | 805 | # # Compute gradients of policy network 806 | # policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PolicyNetwork") 807 | # # print(policy_network_variables) 808 | 809 | # # Compute gradients of policy network 810 | # grads_and_vars = optimizer.compute_gradients(reward_weighted_cross_entropy_loss, var_list=policy_network_variables) 811 | # # print(grads_and_vars) 812 | 813 | # # Clip gradient: Pascanu et al. 
2013, Exploding gradient problem 814 | # grads_and_vars_capped_norm = [(tf.clip_by_norm(grad, 5.0), var) for grad, var in grads_and_vars] 815 | 816 | # # Apply Gradients 817 | # # return optimizer.apply_gradients(grads_and_vars) 818 | # return optimizer.apply_gradients(grads_and_vars_capped_norm) 819 | 820 | # def train_neg_expectedreward(logits, d_ltheta_ot): 821 | # """Training with Policy Gradient: Optimizing expected reward 822 | # args: 823 | # logits: Logits without softmax. [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 824 | # d_ltheta_ot: Placeholder [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 825 | # """ 826 | # with tf.variable_scope('TrainExpReward') as scope: 827 | 828 | # optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 829 | 830 | # # Modify logits with d_ltheta_ot 831 | # logits = tf.mul(logits, d_ltheta_ot) 832 | 833 | # # Compute gradients of policy network 834 | # policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PolicyNetwork") 835 | # # print(policy_network_variables) 836 | 837 | # # Compute gradients of policy network 838 | # grads_and_vars = optimizer.compute_gradients(logits, var_list=policy_network_variables) 839 | # # print(grads_and_vars) 840 | 841 | # # Clip gradient: Pascanu et al. 
2013, Exploding gradient problem 842 | # grads_and_vars_capped_norm = [(tf.clip_by_norm(grad, 5.0), var) for grad, var in grads_and_vars] 843 | 844 | # # Apply Gradients 845 | # # return optimizer.apply_gradients(grads_and_vars) 846 | # return optimizer.apply_gradients(grads_and_vars_capped_norm) 847 | 848 | # def temp_train_neg_expectedreward(logits, d_ltheta_ot): 849 | # with tf.variable_scope('TempTrainExpReward') as scope: 850 | 851 | # optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='adam') 852 | 853 | # # Modify logits with d_ltheta_ot 854 | # logits = tf.mul(logits, d_ltheta_ot) 855 | 856 | # # Compute gradients of policy network 857 | # policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PolicyNetwork") 858 | # # print(policy_network_variables) 859 | 860 | # # Compute gradients of policy network 861 | # grads_and_vars = optimizer.compute_gradients(logits, var_list=policy_network_variables) 862 | 863 | # grads_and_vars_capped_norm = [(tf.clip_by_norm(grad, 5.0), var) for grad, var in grads_and_vars] 864 | 865 | # grads_and_vars_capped_val = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in grads_and_vars] 866 | 867 | 868 | 869 | # # tf.clip_by_norm(t, clip_norm, axes=None, name=None) 870 | # # https://www.tensorflow.org/versions/r0.11/api_docs/python/train/gradient_clipping 871 | 872 | 873 | 874 | # return grads_and_vars, grads_and_vars_capped_norm, grads_and_vars_capped_val 875 | 876 | 877 | 878 | 879 | ### Accuracy Calculations 880 | 881 | def accuracy(logits, labels, weights): 882 | """Estimate accuracy of predictions 883 | Args: 884 | logits: Logits from inference(). 
[FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 885 | labels: Sentence extraction gold labels [FLAGS.batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 886 | weights: Weights to avoid padded part [FLAGS.batch_size, FLAGS.max_doc_length] 887 | Returns: 888 | Accuracy: average per-sentence accuracy (averaged per document first when FLAGS.weighted_loss is set) 889 | """ 890 | with tf.variable_scope('Accuracy') as scope: 891 | logits = tf.reshape(logits, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 892 | labels = tf.reshape(labels, [-1, FLAGS.target_label_size]) # [FLAGS.batch_size*FLAGS.max_doc_length, FLAGS.target_label_size] 893 | correct_pred = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1)) # [FLAGS.batch_size*FLAGS.max_doc_length] 894 | correct_pred = tf.reshape(correct_pred, [-1, FLAGS.max_doc_length]) # [FLAGS.batch_size, FLAGS.max_doc_length] 895 | correct_pred = tf.cast(correct_pred, tf.float32) 896 | # Get Accuracy 897 | accuracy = tf.reduce_mean(correct_pred, name='accuracy') 898 | if FLAGS.weighted_loss: 899 | correct_pred = tf.mul(correct_pred, weights) 900 | correct_pred = tf.reduce_sum(correct_pred, reduction_indices=1) # [FLAGS.batch_size] 901 | doc_lengths = tf.reduce_sum(weights, reduction_indices=1) # [FLAGS.batch_size] 902 | correct_pred_avg = tf.div(correct_pred, doc_lengths) 903 | accuracy = tf.reduce_mean(correct_pred_avg, name='accuracy') 904 | return accuracy 905 | 906 | # Improve it to show exact accuracy (top three ranked ones), not all. 
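The reward broadcasting in `reward_weighted_cross_entropy_loss` (tile, reshape to `[FLAGS.max_doc_length, -1]`, transpose) turns one scalar reward per document into a `[batch_size, max_doc_length]` grid in which every sentence of document *b* carries reward *b*; that grid then scales the masked per-sentence cross-entropy before it is summed per document and averaged over the batch. The pure-Python sketch below is not part of the repository: it mirrors that sequence on plain nested lists, assuming `FLAGS.weighted_loss` is on. The helper names `tile_reshape_transpose`, `softmax_cross_entropy` and `reward_weighted_ce` are hypothetical.

```python
import math


def tile_reshape_transpose(rewards, max_doc_length):
    """Mimic tf.tile -> tf.reshape([max_doc_length, -1]) -> tf.transpose on a 1-D list.

    Returns a [batch_size][max_doc_length] grid where row b repeats rewards[b].
    """
    batch = len(rewards)
    tiled = rewards * max_doc_length  # like tf.tile(rewards, [max_doc_length]): [a, b, a, b, ...]
    # like tf.reshape(tiled, [max_doc_length, -1]): each row is one copy of the rewards
    rows = [tiled[i * batch:(i + 1) * batch] for i in range(max_doc_length)]
    # like tf.transpose: [batch_size, max_doc_length]
    return [list(col) for col in zip(*rows)]


def softmax_cross_entropy(logits, label):
    """Cross-entropy between softmax(logits) and a (one-hot or soft) label vector."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(l * (x - log_z) for x, l in zip(logits, label))


def reward_weighted_ce(logits, labels, rewards, weights):
    """Per-sentence CE * padding mask * broadcast reward, summed per document, averaged over the batch."""
    max_doc_length = len(weights[0])
    reward_grid = tile_reshape_transpose(rewards, max_doc_length)
    per_doc = []
    for b in range(len(logits)):
        total = 0.0
        for s in range(max_doc_length):
            ce = softmax_cross_entropy(logits[b][s], labels[b][s]) * weights[b][s]
            total += ce * reward_grid[b][s]
        per_doc.append(total)
    return sum(per_doc) / len(per_doc)


if __name__ == "__main__":
    print(tile_reshape_transpose([1.0, 2.0], 3))
```

In TensorFlow the same grid could also be built as `tf.tile(tf.expand_dims(actual_rewards, 1), [1, FLAGS.max_doc_length])`. Reshaping the tiled vector directly to `[-1, FLAGS.max_doc_length]`, as in the line marked `# Error:` in the commented-out variant, would interleave rewards from different documents (`[a, b, a, b, a, b]` becomes `[[a, b, a], [b, a, b]]`), which is why the transpose step is needed. For the multisample variant, the `[batch_size, 1]` rewards are first flattened with `tf.reshape(actual_rewards, [-1])`, after which the same sketch applies.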
907 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | #################################### 7 | 8 | """ 9 | Document Summarization Model Utilities 10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | from tensorflow.python.ops import variable_scope 19 | from tensorflow.python.ops import seq2seq 20 | 21 | # from tf.nn import variable_scope 22 | from my_flags import FLAGS 23 | 24 | ### Get Variable 25 | 26 | def variable_on_cpu(name, shape, initializer, trainable=True): 27 | """Helper to create a Variable stored on CPU memory. 28 | Args: 29 | name: name of the variable 30 | shape: list of ints 31 | initializer: initializer for Variable 32 | trainable: is trainable 33 | Returns: 34 | Variable Tensor 35 | """ 36 | with tf.device('/cpu:0'): 37 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 38 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype, trainable=trainable) 39 | return var 40 | 41 | def get_vocab_embed_variable(vocab_size): 42 | '''Returns vocab_embed_variable without any local initialization 43 | ''' 44 | vocab_embed_variable = "" 45 | if FLAGS.trainable_wordembed: 46 | vocab_embed_variable = variable_on_cpu("vocab_embed", [vocab_size, FLAGS.wordembed_size], tf.constant_initializer(0), trainable=True) 47 | else: 48 | vocab_embed_variable = variable_on_cpu("vocab_embed", [vocab_size, FLAGS.wordembed_size], tf.constant_initializer(0), trainable=False) 49 | 50 | return vocab_embed_variable 51 | 52 | def get_lstm_cell(): 53 | """Define LSTM Cell 54 | """ 55 | single_cell = 
tf.nn.rnn_cell.BasicLSTMCell(FLAGS.size) if (FLAGS.lstm_cell == "lstm") else tf.nn.rnn_cell.GRUCell(FLAGS.size) 56 | cell = single_cell 57 | if FLAGS.num_layers > 1: 58 | cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * FLAGS.num_layers) 59 | return cell 60 | 61 | ### Reshaping 62 | 63 | def reshape_tensor2list(tensor, n_steps, n_input): 64 | """Reshape tensor [?, n_steps, n_input] to lists of n_steps items with [?, n_input] 65 | """ 66 | # Prepare data shape to match `rnn` function requirements 67 | # Current data input shape (batch_size, n_steps, n_input) 68 | # Required shape: 'n_steps' tensors list of shape (batch_size, n_input) 69 | # 70 | # Permuting batch_size and n_steps 71 | tensor = tf.transpose(tensor, perm=[1, 0, 2], name='transpose') 72 | # Reshaping to (n_steps*batch_size, n_input) 73 | tensor = tf.reshape(tensor, [-1, n_input], name='reshape') 74 | # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input) 75 | tensor = tf.split(0, n_steps, tensor, name='split') 76 | return tensor 77 | 78 | def reshape_list2tensor(listoftensors, n_steps, n_input): 79 | """Reshape lists of n_steps items with [?, n_input] to tensor [?, n_steps, n_input] 80 | """ 81 | # Reverse of _reshape_tensor2list 82 | tensor = tf.concat(0, listoftensors, name="concat") # [n_steps * ?, n_input] 83 | tensor = tf.reshape(tensor, [n_steps, -1, n_input], name='reshape') # [n_steps, ?, n_input] 84 | tensor = tf.transpose(tensor, perm=[1, 0, 2], name='transpose') # [?, n_steps, n_input] 85 | return tensor 86 | 87 | ### Convolution, LSTM, RNNs 88 | 89 | def multilayer_perceptron(final_output, weights, biases): 90 | """MLP over output with attention over enc outputs 91 | Args: 92 | final_output: [batch_size x 2*size] 93 | Returns: 94 | logit: [batch_size x target_label_size] 95 | """ 96 | 97 | # Layer 1 98 | layer_1 = tf.add(tf.matmul(final_output, weights["h1"]), biases["b1"]) 99 | layer_1 = tf.nn.relu(layer_1) 100 | 101 | # Layer 2 102 | layer_2 = 
tf.add(tf.matmul(layer_1, weights["h2"]), biases["b2"]) 103 | layer_2 = tf.nn.relu(layer_2) 104 | 105 | # output layer 106 | layer_out = tf.add(tf.matmul(layer_2, weights["out"]), biases["out"]) 107 | 108 | return layer_out 109 | 110 | 111 | def conv1d_layer_sentence_representation(sent_wordembeddings): 112 | """Apply multiple conv1d filters to extract sentence representations 113 | Args: 114 | sent_wordembeddings: [None, max_sent_length, wordembed_size] 115 | Returns: 116 | sent_representations: [None, sentembed_size] 117 | """ 118 | 119 | representation_from_filters = [] 120 | 121 | output_channel = 0 122 | if FLAGS.handle_filter_output == "sum": 123 | output_channel = FLAGS.sentembed_size 124 | else: # concat 125 | output_channel = FLAGS.sentembed_size // FLAGS.max_filter_length # integer division: true division is enabled by the __future__ import 126 | if (output_channel * FLAGS.max_filter_length != FLAGS.sentembed_size): 127 | print("Error: Make sure (output_channel * FLAGS.max_filter_length) is equal to FLAGS.sentembed_size.") 128 | exit(0) 129 | 130 | for filterwidth in xrange(1,FLAGS.max_filter_length+1): 131 | # print(filterwidth) 132 | 133 | with tf.variable_scope("Conv1D_%d"%filterwidth) as scope: 134 | 135 | # Convolution 136 | conv_filter = variable_on_cpu("conv_filter_%d" % filterwidth, [filterwidth, FLAGS.wordembed_size, output_channel], tf.truncated_normal_initializer()) 137 | # print(conv_filter.name, conv_filter.get_shape()) 138 | conv = tf.nn.conv1d(sent_wordembeddings, conv_filter, 1, padding='VALID') # [None, out_width=(max_sent_length-(filterwidth-1)), output_channel] 139 | conv_biases = variable_on_cpu("conv_biases_%d" % filterwidth, [output_channel], tf.constant_initializer(0.0)) 140 | pre_activation = tf.nn.bias_add(conv, conv_biases) 141 | conv = tf.nn.relu(pre_activation) # [None, out_width, output_channel] 142 | # print(conv.name, conv.get_shape()) 143 | 144 | # Max pool: Reshape conv to use max_pool 145 | conv_reshaped = tf.expand_dims(conv, 1) # [None, out_height:1, out_width, output_channel] 146 | # 
print(conv_reshaped.name, conv_reshaped.get_shape()) 147 | out_height = conv_reshaped.get_shape()[1].value 148 | out_width = conv_reshaped.get_shape()[2].value 149 | # print(out_height,out_width) 150 | maxpool = tf.nn.max_pool(conv_reshaped, [1,out_height,out_width,1], [1,1,1,1], padding='VALID') # [None, 1, 1, output_channel] 151 | # print(maxpool.name, maxpool.get_shape()) 152 | 153 | # Local Response Normalization 154 | maxpool_norm = tf.nn.lrn(maxpool, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) # Settings from cifar10 155 | # print(maxpool_norm.name, maxpool_norm.get_shape()) 156 | 157 | # Get back to original dimension 158 | maxpool_sqz = tf.squeeze(maxpool_norm, [1,2]) # [None, output_channel] 159 | # print(maxpool_sqz.name, maxpool_sqz.get_shape()) 160 | 161 | representation_from_filters.append(maxpool_sqz) 162 | # print(representation_from_filters) 163 | 164 | final_representation = [] 165 | with tf.variable_scope("FinalOut") as scope: 166 | if FLAGS.handle_filter_output == "sum": 167 | final_representation = tf.add_n(representation_from_filters) 168 | else: 169 | final_representation = tf.concat(1, representation_from_filters) 170 | 171 | return final_representation 172 | 173 | def simple_rnn(rnn_input, initial_state=None): 174 | """Implements Simple RNN 175 | Args: 176 | rnn_input: List of tensors of sizes [-1, sentembed_size] 177 | Returns: 178 | encoder_outputs, encoder_state 179 | """ 180 | # Setup cell 181 | cell_enc = get_lstm_cell() 182 | 183 | # Setup RNNs 184 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 185 | rnn_outputs, rnn_state = tf.nn.rnn(cell_enc, rnn_input, dtype=dtype, initial_state=initial_state) 186 | # print(rnn_outputs) 187 | # print(rnn_state) 188 | 189 | return rnn_outputs, rnn_state 190 | 191 | def simple_attentional_rnn(rnn_input, attention_state_list, initial_state=None): 192 | """Implements Simple Attentional RNN 193 | Args: 194 | rnn_input: List of tensors of sizes [-1, sentembed_size] 195 | attention_state_list: List of tensors 
of sizes [-1, sentembed_size] 196 | Returns: 197 | outputs, state 198 | """ 199 | 200 | # Reshape attention_state_list to tensor 201 | attention_states = reshape_list2tensor(attention_state_list, len(attention_state_list), FLAGS.sentembed_size) 202 | 203 | # Setup cell 204 | cell = get_lstm_cell() 205 | 206 | # Setup attentional RNNs 207 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 208 | 209 | # if initial_state == None: 210 | # batch_size = tf.shape(rnn_input[0])[0] 211 | # initial_state = cell.zero_state(batch_size, dtype) 212 | 213 | rnn_outputs, rnn_state = seq2seq.attention_decoder(rnn_input, initial_state, attention_states, cell, 214 | output_size=None, num_heads=1, loop_function=None, dtype=dtype, 215 | scope=None, initial_state_attention=False) 216 | # print(rnn_outputs) 217 | # print(rnn_state) 218 | return rnn_outputs, rnn_state 219 | 220 | ### Special decoders 221 | 222 | def jporg_attentional_seqrnn_decoder(sents_ext, encoder_outputs, encoder_state, sents_labels, weights, biases): 223 | """ 224 | Implements JP's special decoder: attention over encoder 225 | """ 226 | 227 | # Setup cell 228 | cell_ext = get_lstm_cell() 229 | 230 | # Define Sequential Decoder 231 | with variable_scope.variable_scope("JP_Decoder"): 232 | state = encoder_state 233 | extractor_logits = [] 234 | extractor_outputs = [] 235 | prev = None 236 | for i, inp in enumerate(sents_ext): 237 | if prev is not None: 238 | with variable_scope.variable_scope("loop_function"): 239 | inp = _loop_function(inp, extractor_logits[-1], sents_labels[i-1]) 240 | if i > 0: 241 | variable_scope.get_variable_scope().reuse_variables() 242 | # Create Cell 243 | output, state = cell_ext(inp, state) 244 | prev = output 245 | 246 | # Convert output to logit 247 | with variable_scope.variable_scope("mlp"): 248 | combined_output = [] # batch_size, 2*size 249 | if FLAGS.doc_encoder_reverse: 250 | combined_output = tf.concat(1, [output, encoder_outputs[(FLAGS.max_doc_length - 1) - i]]) 251 | else: 252 
| combined_output = tf.concat(1, [output, encoder_outputs[i]]) 253 | 254 | logit = multilayer_perceptron(combined_output, weights, biases) 255 | 256 | extractor_logits.append(logit) 257 | extractor_outputs.append(combined_output) 258 | 259 | return extractor_outputs, extractor_logits 260 | 261 | ### Private Functions 262 | 263 | def _loop_function(current_inp, ext_logits, gold_logits): 264 | """ Update current input wrt previous logits 265 | Args: 266 | current_inp: [batch_size x sentence_embedding_size] 267 | ext_logits: [batch_size x target_label_size] [1, 0] 268 | gold_logits: [batch_size x target_label_size] 269 | Returns: 270 | updated_inp: [batch_size x sentence_embedding_size] 271 | """ 272 | 273 | prev_logits = gold_logits 274 | if not FLAGS.authorise_gold_label: 275 | prev_logits = ext_logits 276 | prev_logits = tf.nn.softmax(prev_logits) # [batch_size x target_label_size] 277 | 278 | prev_logits = tf.split(1, FLAGS.target_label_size, prev_logits) # [[batch_size], [batch_size], ...] 
279 | prev_weight_one = prev_logits[0] 280 | 281 | updated_inp = tf.mul(current_inp, prev_weight_one) 282 | # print(updated_inp) 283 | 284 | return updated_inp 285 | 286 | 287 | ### SoftMax and Predictions 288 | 289 | def convert_logits_to_softmax(batch_logits, session=None): 290 | """ Convert logits to probabilities 291 | batch_logits: [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 292 | """ 293 | # Convert logits [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] to probabilities 294 | batch_logits = tf.reshape(batch_logits, [-1, FLAGS.target_label_size]) 295 | batch_softmax_logits = tf.nn.softmax(batch_logits) 296 | batch_softmax_logits = tf.reshape(batch_softmax_logits, [-1, FLAGS.max_doc_length, FLAGS.target_label_size]) 297 | # Convert back to numpy array 298 | batch_softmax_logits = batch_softmax_logits.eval(session=session) 299 | return batch_softmax_logits 300 | 301 | def predict_topranked(batch_softmax_logits, batch_weights, batch_filenames): 302 | """ Predict top ranked outputs: cnn:3, dm:4 303 | batch_softmax_logits: Numpy Array [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 304 | batch_weights: Numpy Array [batch_size, FLAGS.max_doc_length] 305 | batch_filenames: String [batch_size] 306 | Return: 307 | batch_predicted_labels: [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 308 | """ 309 | 310 | def process_to_chop_pad(orgids, requiredsize): 311 | if (len(orgids) >= requiredsize): 312 | return orgids[:requiredsize] 313 | else: 314 | padids = [0] * (requiredsize - len(orgids)) 315 | return (orgids + padids) 316 | 317 | batch_size = batch_softmax_logits.shape[0] 318 | 319 | # Numpy dtype 320 | dtype = np.float16 if FLAGS.use_fp16 else np.float32 321 | 322 | batch_predicted_labels = np.empty((batch_size, FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype) 323 | 324 | for batch_idx in range(batch_size): 325 | softmax_logits = batch_softmax_logits[batch_idx] 326 | weights = 
process_to_chop_pad(batch_weights[batch_idx], FLAGS.max_doc_length) 327 | filename = batch_filenames[batch_idx] 328 | 329 | # Find the top-scoring sentences to consider for the summary; on ties, prefer sentences with lower indices 330 | oneprob_sentidx = {} 331 | for sentidx in range(FLAGS.max_doc_length): 332 | prob = softmax_logits[sentidx][0] # probability of predicting one 333 | weight = weights[sentidx] 334 | if weight == 1: 335 | if prob not in oneprob_sentidx: 336 | oneprob_sentidx[prob] = [sentidx] 337 | else: 338 | oneprob_sentidx[prob].append(sentidx) 339 | else: 340 | break 341 | oneprob_keys = list(oneprob_sentidx.keys()) # list() keeps .sort() working under Python 3 342 | oneprob_keys.sort(reverse=True) 343 | 344 | # Rank sentences by score: on ties, lower indices are ranked first 345 | sentindices = [] 346 | for oneprob in oneprob_keys: 347 | sent_withsamescore = oneprob_sentidx[oneprob] 348 | sent_withsamescore.sort() 349 | sentindices += sent_withsamescore 350 | 351 | # Select top sentences: 3 for CNN, 4 for DailyMail 352 | final_sentences = [] 353 | if filename.startswith("cnn-"): 354 | final_sentences = sentindices[:3] 355 | elif filename.startswith("dailymail-"): 356 | final_sentences = sentindices[:4] 357 | else: 358 | print(filename) 359 | print("Filename does not start with 'cnn-' or 'dailymail-'.") 360 | exit(1) 361 | 362 | # Final Labels 363 | labels_vecs = [[1, 0] if (sentidx in final_sentences) else [0, 1] for sentidx in range(FLAGS.max_doc_length)] 364 | batch_predicted_labels[batch_idx] = np.array(labels_vecs[:], dtype=dtype) 365 | 366 | return batch_predicted_labels 367 | 368 | def predict_toprankedthree(batch_softmax_logits, batch_weights): 369 | """ Predict labels for the three top-ranked sentences 370 | batch_softmax_logits: Numpy Array [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 371 | batch_weights: Numpy Array [batch_size, FLAGS.max_doc_length]; for the validation and test sets it is not padded, so its rows may be shorter than FLAGS.max_doc_length.
372 | Return: 373 | batch_predicted_labels: [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 374 | """ 375 | 376 | def process_to_chop_pad(orgids, requiredsize): 377 | if (len(orgids) >= requiredsize): 378 | return orgids[:requiredsize] 379 | else: 380 | padids = [0] * (requiredsize - len(orgids)) 381 | return (orgids + padids) 382 | 383 | batch_size = batch_softmax_logits.shape[0] 384 | 385 | # Numpy dtype 386 | dtype = np.float16 if FLAGS.use_fp16 else np.float32 387 | 388 | batch_predicted_labels = np.empty((batch_size, FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype) 389 | 390 | for batch_idx in range(batch_size): 391 | softmax_logits = batch_softmax_logits[batch_idx] 392 | weights = process_to_chop_pad(batch_weights[batch_idx], FLAGS.max_doc_length) 393 | 394 | # Find the three top-scoring sentences to consider for the summary; on ties, prefer sentences with lower indices 395 | oneprob_sentidx = {} 396 | for sentidx in range(FLAGS.max_doc_length): 397 | prob = softmax_logits[sentidx][0] # probability of predicting one 398 | weight = weights[sentidx] 399 | if weight == 1: 400 | if prob not in oneprob_sentidx: 401 | oneprob_sentidx[prob] = [sentidx] 402 | else: 403 | oneprob_sentidx[prob].append(sentidx) 404 | else: 405 | break 406 | oneprob_keys = list(oneprob_sentidx.keys()) # list() keeps .sort() working under Python 3 407 | oneprob_keys.sort(reverse=True) 408 | 409 | # Rank sentences by score: on ties, lower indices are ranked first 410 | sentindices = [] 411 | for oneprob in oneprob_keys: 412 | sent_withsamescore = oneprob_sentidx[oneprob] 413 | sent_withsamescore.sort() 414 | sentindices += sent_withsamescore 415 | 416 | # Select top 3 417 | final_sentences = sentindices[:3] 418 | 419 | # Final Labels 420 | labels_vecs = [[1, 0] if (sentidx in final_sentences) else [0, 1] for sentidx in range(FLAGS.max_doc_length)] 421 | batch_predicted_labels[batch_idx] = np.array(labels_vecs[:], dtype=dtype) 422 | 423 | return batch_predicted_labels 424 | 425 | def
sample_three_forsummary(batch_softmax_logits): 426 | """ Sample three ones to select in the summary 427 | batch_softmax_logits: Numpy Array [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 428 | Return: 429 | batch_predicted_labels: [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 430 | """ 431 | 432 | batch_size = batch_softmax_logits.shape[0] 433 | 434 | # Numpy dtype 435 | dtype = np.float16 if FLAGS.use_fp16 else np.float32 436 | 437 | batch_sampled_labels = np.empty((batch_size, FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype) 438 | 439 | for batch_idx in range(batch_size): 440 | softmax_logits = batch_softmax_logits[batch_idx] # [FLAGS.max_doc_length, FLAGS.target_label_size] 441 | 442 | # Collect probabilities for predicting one for a sentence 443 | sentence_ids = range(FLAGS.max_doc_length) 444 | sentence_oneprobs = [softmax_logits[sentidx][0] for sentidx in sentence_ids] 445 | normalized_sentence_oneprobs = [item/sum(sentence_oneprobs) for item in sentence_oneprobs] 446 | 447 | # Sample three sentences to select for summary from this distribution 448 | final_sentences = np.random.choice(sentence_ids, p=normalized_sentence_oneprobs, size=3, replace=False) 449 | 450 | # Final Labels 451 | labels_vecs = [[1, 0] if (sentidx in final_sentences) else [0, 1] for sentidx in range(FLAGS.max_doc_length)] 452 | batch_sampled_labels[batch_idx] = np.array(labels_vecs[:], dtype=dtype) 453 | 454 | return batch_sampled_labels 455 | 456 | def smaple_with_numpy_random_choice(sentence_ids, normalized_sentence_oneprobs, no_ones_tosample): 457 | sampled_final_sentences = np.random.choice(sentence_ids, p=normalized_sentence_oneprobs, size=no_ones_tosample, replace=False) 458 | sampled_final_sentences.sort() 459 | return sampled_final_sentences 460 | 461 | 462 | def multisample_three_forsummary(batch_softmax_logits, batch_gold_label, batch_weight): 463 | """ Sample three ones to select in the summary: Mix from gold and sampled, one sample always 
includes gold 464 | batch_softmax_logits: Numpy Array [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 465 | batch_gold_label: Numpy Array [batch_size, FLAGS.max_doc_length, FLAGS.target_label_size] 466 | batch_weight: Numpy Array [batch_size, FLAGS.max_doc_length] 467 | Return: 468 | # batch_gold_sampled_label_multisample: [batch_size, FLAGS.num_sample_rollout, FLAGS.max_doc_length, FLAGS.target_label_size] 469 | batch_gold_sampled_labelstr_multisample: [batch_size, FLAGS.num_sample_rollout] 470 | """ 471 | 472 | # Start Sampling for each document 473 | batch_size = batch_softmax_logits.shape[0] 474 | dtype = np.float16 if FLAGS.use_fp16 else np.float32 475 | # batch_gold_sampled_label_multisample = np.empty((batch_size, FLAGS.num_sample_rollout, FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype) 476 | batch_gold_sampled_labelstr_multisample = np.empty((batch_size, FLAGS.num_sample_rollout), dtype="S20") 477 | 478 | # Number of ones to sample (this is needed if a document has fewer than three sentences) 479 | batch_gold_label_one, _ = np.split(batch_gold_label, [1], axis=2) # [batch_size, max_doc_length, 1] 480 | # print(batch_gold_label_one[0]) 481 | batch_gold_label_tosample_one = np.squeeze(batch_gold_label_one, axis=2) # [batch_size, max_doc_length] 482 | # print(batch_gold_label_tosample_one[0]) 483 | batch_gold_label_tosample_count = np.sum(batch_gold_label_tosample_one , axis=1) # [batch_size] 484 | # print(batch_gold_label_tosample_count[0]) 485 | 486 | # Probabilities used to Sample 487 | batch_oneprob_usetosample, _ = np.split(batch_softmax_logits, [1], axis=2) # [batch_size, max_doc_length, 1] 488 | # print(batch_oneprob_usetosample[0]) 489 | batch_oneprob_usetosample = np.squeeze(batch_oneprob_usetosample, axis=2) # [batch_size, max_doc_length] 490 | # print(batch_oneprob_usetosample[0]) 491 | 492 | for batch_idx in range(batch_size): 493 | 494 | # Always keep the oracle sample 495 | #
batch_gold_sampled_label_multisample[batch_idx][0] = np.array(batch_gold_label[batch_idx][:], dtype=dtype) 496 | one_labels = [] 497 | for sentidx in range(FLAGS.max_doc_length): 498 | if int(batch_gold_label[batch_idx][sentidx][0]) == 1: 499 | one_labels.append(str(sentidx)) 500 | batch_gold_sampled_labelstr_multisample[batch_idx][0] = "-".join(one_labels) 501 | 502 | # Rest: Sample (FLAGS.num_sample_rollout - 1) times 503 | sentence_oneprobs = batch_oneprob_usetosample[batch_idx] # [max_doc_length] 504 | # print(sentence_oneprobs) 505 | sentence_ids = range(len(sentence_oneprobs)) # [max_doc_length] 506 | 507 | # Make sure that nothing is sampled from outside the document: 508 | # take the weight into account while normalising, so that w = 0 gives p = exactly 0. 509 | # If w is all zero, no_ones_tosample will be 0 and we won't be 510 | # here in the first place; otherwise the normalised probabilities sum to 1. 511 | 512 | weight_samplepart = batch_weight[batch_idx] # [max_doc_length] 513 | # print(weight_samplepart) 514 | 515 | # Smooth normalization, weight considered.
The number of nonzero entries will always be >= no_ones_tosample, because w = 1 for them 516 | # Get sum: this will never be zero as w is never all 0 517 | l1_norm_sum = sum(np.multiply(sentence_oneprobs, weight_samplepart)) + (0.000000000001*sum(weight_samplepart)) 518 | normalized_sentence_oneprobs = [((item_prob+0.000000000001)*item_weight)/l1_norm_sum for item_prob, item_weight in zip(sentence_oneprobs, weight_samplepart)] 519 | # print(normalized_sentence_oneprobs) 520 | 521 | # Number of ones to sample 522 | no_ones_tosample = int(batch_gold_label_tosample_count[batch_idx]) 523 | # print(no_ones_tosample) 524 | 525 | # Start sampling (FLAGS.num_sample_rollout - 1) times 526 | for rollout_idx in range(1, FLAGS.num_sample_rollout): 527 | sampled_final_sentences_sorted = smaple_with_numpy_random_choice(sentence_ids, normalized_sentence_oneprobs, no_ones_tosample) 528 | 529 | # # Final Labels # This step would cause the following loop to recompute the same sample; it neither ignores duplicates nor takes advantage of the pool 530 | # sampled_labels_vecs = [[1, 0] if (sentidx in sampled_final_sentences) else [0, 1] for sentidx in sentence_ids] # [max_doc_length, target_label_size] 531 | # # Store 532 | # batch_gold_sampled_label_multisample[batch_idx][rollout_idx] = np.array(sampled_labels_vecs[:], dtype=dtype) 533 | batch_gold_sampled_labelstr_multisample[batch_idx][rollout_idx] = "-".join([str(sentidx) for sentidx in sampled_final_sentences_sorted]) 534 | 535 | return batch_gold_sampled_labelstr_multisample 536 | # return batch_gold_sampled_label_multisample, batch_gold_sampled_labelstr_multisample 537 | -------------------------------------------------------------------------------- /my_flags.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | #################################### 7 | 8 | """ 9 | My flags 10 |
| """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | ########### ============ Set Global FLAGS ============= ############# 20 | 21 | ### Temporary Directory to avoid conflict with others 22 | 23 | # VERY IMPORTANT # : SET this directory as TMP by exporting it 24 | 25 | tf.app.flags.DEFINE_string("tmp_directory", "/tmp", "Temporary directory used by rouge code.") 26 | 27 | tf.app.flags.DEFINE_string("use_gpu", "/gpu:3", "Specify which gpu to use.") 28 | 29 | ### Global setting 30 | 31 | tf.app.flags.DEFINE_string("exp_mode", "train", "Training 'train' or Test 'test' Mode.") 32 | 33 | tf.app.flags.DEFINE_integer("model_to_load", 100, "Model to load for testing.") 34 | 35 | tf.app.flags.DEFINE_boolean("use_fp16", False, "Use fp16 instead of fp32.") 36 | 37 | tf.app.flags.DEFINE_string("data_mode", "cnn", "cnn or dailymail or cnn-dailymail") 38 | 39 | ### Pretrained wordembeddings features 40 | 41 | tf.app.flags.DEFINE_integer("wordembed_size", 200, "Size of wordembedding (<= 200).") 42 | 43 | tf.app.flags.DEFINE_boolean("trainable_wordembed", False, "Is wordembedding trainable?") 44 | # UNK and PAD are always trainable and non-trainable respectively. 45 | 46 | ### Sentence level features 47 | 48 | tf.app.flags.DEFINE_integer("max_sent_length", 100, "Maximum sentence length (word per sent.)") 49 | 50 | tf.app.flags.DEFINE_integer("sentembed_size", 350, "Size of sentence embedding.") 51 | 52 | ### Document level features 53 | 54 | tf.app.flags.DEFINE_integer("max_doc_length", 110, "Maximum Document length (sent. 
per document).") 55 | 56 | tf.app.flags.DEFINE_integer("max_title_length", 0, "Maximum number of top titles to consider.") # 1 57 | 58 | tf.app.flags.DEFINE_integer("max_image_length", 0, "Maximum number of top image captions to consider.") # 10 59 | 60 | tf.app.flags.DEFINE_integer("target_label_size", 2, "Size of target label (1/0).") 61 | 62 | ### Convolution Layer features 63 | 64 | tf.app.flags.DEFINE_integer("max_filter_length", 7, "Maximum filter length.") 65 | # Filters of sizes 1 to max_filter_length will be used, each producing 66 | # one vector. 1-7 same as Kim and JP. max_filter_length <= 67 | # max_sent_length 68 | 69 | tf.app.flags.DEFINE_string("handle_filter_output", "concat", "sum or concat") 70 | # If concat, make sure that sentembed_size is a multiple of max_filter_length. 71 | # Sum is JP's model 72 | 73 | ### LSTM Features 74 | 75 | tf.app.flags.DEFINE_integer("size", 600, "Size of each model layer.") 76 | 77 | tf.app.flags.DEFINE_integer("num_layers", 1, "Number of layers in the model.") 78 | 79 | tf.app.flags.DEFINE_string("lstm_cell", "lstm", "Type of LSTM Cell: lstm or gru.") 80 | 81 | ### Encoder Layer features 82 | 83 | # Document Encoder: Unidirectional LSTM-RNNs 84 | tf.app.flags.DEFINE_boolean("doc_encoder_reverse", True, "Encode sentences in reverse order if true, otherwise in order.") 85 | 86 | ### Extractor Layer features 87 | 88 | tf.app.flags.DEFINE_boolean("attend_encoder", False, "Attend encoder outputs (JP model).") 89 | 90 | tf.app.flags.DEFINE_boolean("authorise_gold_label", True, "Authorise Gold Label for JP's Model.") 91 | 92 | ### Reinforcement Learning 93 | 94 | tf.app.flags.DEFINE_boolean("rouge_reward_fscore", True, "Fscore if true, otherwise recall.") # Not used, always use fscore 95 | 96 | tf.app.flags.DEFINE_integer("train_epoch_wce", 20, "Number of training epochs per step.") 97 | 98 | tf.app.flags.DEFINE_integer("num_sample_rollout", 10, "Number of Multiple Oracles Used.") # default 10 99 | 100 | ### Training features 101 | 102 |
tf.app.flags.DEFINE_string("train_dir", "/address/to/training/directory", "Training directory.") 103 | 104 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.") 105 | 106 | tf.app.flags.DEFINE_boolean("weighted_loss", True, "Weighted loss to ignore padded parts.") 107 | 108 | tf.app.flags.DEFINE_integer("batch_size", 20, "Batch size to use during training.") 109 | 110 | tf.app.flags.DEFINE_integer("training_checkpoint", 1, "How many training steps to do per checkpoint.") 111 | 112 | ###### Input file addresses: No change needed 113 | 114 | # Pretrained wordembeddings data 115 | 116 | tf.app.flags.DEFINE_string("pretrained_wordembedding", 117 | "/address/data/1-billion-word-language-modeling-benchmark-r13output.word2vec.vec", 118 | "Pretrained wordembedding file trained on the One Billion Word benchmark data.") 119 | 120 | # Data directory address 121 | 122 | tf.app.flags.DEFINE_string("preprocessed_data_directory", "/address/data/preprocessed-input-directory", 123 | "Preprocessed news articles for various types of word embeddings.") 124 | 125 | tf.app.flags.DEFINE_string("gold_summary_directory", 126 | "/address/data/Baseline-Gold-Models", 127 | "Gold summary directory.") 128 | 129 | tf.app.flags.DEFINE_string("doc_sentence_directory", 130 | "/address/data/CNN-DM-Filtered-TokenizedSegmented", 131 | "Directory where document sentences are kept.") 132 | 133 | ############ Create FLAGS 134 | FLAGS = tf.app.flags.FLAGS 135 | -------------------------------------------------------------------------------- /my_model.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | # Comments: Jan 2017 7 | # Improved for Reinforcement Learning 8 | #################################### 9 | 10 | """ 11 | Document Summarization Final Model 12 | """ 13 | 14 | from __future__ import
absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | import math 19 | import os 20 | import random 21 | import sys 22 | import time 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | import model_docsum 28 | from my_flags import FLAGS 29 | import model_utils 30 | 31 | ###################### Define Final Network ############################ 32 | 33 | class MY_Model: 34 | def __init__(self, sess, vocab_size): 35 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 36 | 37 | ### A few variables that are initialised here 38 | # Word embedding variable 39 | self.vocab_embed_variable = model_utils.get_vocab_embed_variable(vocab_size) 40 | 41 | ### Define Placeholders 42 | self.document_placeholder = tf.placeholder("int32", [None, 43 | (FLAGS.max_doc_length + FLAGS.max_title_length + FLAGS.max_image_length), 44 | FLAGS.max_sent_length], name='doc-ph') 45 | self.label_placeholder = tf.placeholder(dtype, [None, FLAGS.max_doc_length, FLAGS.target_label_size], name='label-ph') 46 | self.weight_placeholder = tf.placeholder(dtype, [None, FLAGS.max_doc_length], name='weight-ph') 47 | 48 | # Reward-related placeholders: pass the rewards as placeholders so that they stay constant for the RL optimizer 49 | self.actual_reward_multisample_placeholder = tf.placeholder(dtype, [None, 1], name='actual-reward-multisample-ph') # [FLAGS.batch_size, Single Sample] 50 | 51 | # Self predicted label placeholder 52 | self.predicted_multisample_label_placeholder = tf.placeholder(dtype, [None, 1, FLAGS.max_doc_length, FLAGS.target_label_size], name='pred-multisample-label-ph') 53 | 54 | # Only used for test/validation corpus 55 | self.logits_placeholder = tf.placeholder(dtype, [None, FLAGS.max_doc_length, FLAGS.target_label_size], name='logits-ph') 56 | 57 | ### Define Policy Core Network: Consists of Encoder, Decoder and Convolution.
58 | self.extractor_output, self.logits = model_docsum.policy_network(self.vocab_embed_variable, self.document_placeholder, self.label_placeholder) 59 | 60 | ### Define Reward-Weighted Cross Entropy Loss 61 | self.rewardweighted_cross_entropy_loss_multisample = model_docsum.reward_weighted_cross_entropy_loss_multisample(self.logits, self.predicted_multisample_label_placeholder, 62 | self.actual_reward_multisample_placeholder, self.weight_placeholder) 63 | 64 | ### Define training operators 65 | self.train_op_policynet_expreward = model_docsum.train_neg_expectedreward(self.rewardweighted_cross_entropy_loss_multisample) 66 | 67 | # accuracy operation : exact match 68 | self.accuracy = model_docsum.accuracy(self.logits, self.label_placeholder, self.weight_placeholder) 69 | # final accuracy operation 70 | self.final_accuracy = model_docsum.accuracy(self.logits_placeholder, self.label_placeholder, self.weight_placeholder) 71 | 72 | # Create a saver. 73 | self.saver = tf.train.Saver(tf.all_variables(), max_to_keep=None) 74 | 75 | # Scalar Summary Operations 76 | self.rewardweighted_ce_multisample_loss_summary = tf.scalar_summary("rewardweighted-cross-entropy-multisample-loss", self.rewardweighted_cross_entropy_loss_multisample) 77 | self.taccuracy_summary = tf.scalar_summary("training_accuracy", self.accuracy) 78 | self.vaccuracy_summary = tf.scalar_summary("validation_accuracy", self.final_accuracy) 79 | 80 | # # Build the summary operation based on the TF collection of Summaries. 81 | # # self.summary_op = tf.merge_all_summaries() 82 | 83 | # Build an initialization operation to run below. 84 | init = tf.initialize_all_variables() 85 | 86 | # Start running operations on the Graph. 
87 | sess.run(init) 88 | 89 | # Create Summary Graph for Tensorboard 90 | self.summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) 91 | -------------------------------------------------------------------------------- /reward_utils.py: -------------------------------------------------------------------------------- 1 | #################################### 2 | # Author: Shashi Narayan 3 | # Date: September 2016 4 | # Project: Document Summarization 5 | # H2020 Summa Project 6 | #################################### 7 | 8 | """ 9 | Document Summarization Modules and Models 10 | """ 11 | 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | import random 20 | import os 21 | import re 22 | import os.path 23 | 24 | from pyrouge import Rouge155 25 | import json 26 | from multiprocessing import Pool 27 | from contextlib import closing 28 | 29 | from my_flags import FLAGS 30 | 31 | def _rouge(system_dir, gold_dir): 32 | # Run rouge 33 | r = Rouge155() 34 | r.system_dir = system_dir 35 | r.model_dir = gold_dir 36 | r.system_filename_pattern = '([a-zA-Z0-9]*).model' 37 | r.model_filename_pattern = '#ID#.gold' 38 | output = r.convert_and_evaluate(rouge_args="-e /address/to/rouge/data/directory/rouge/data -a -c 95 -m -n 4 -w 1.2") 39 | # print output 40 | output_dict = r.output_to_dict(output) 41 | # print output_dict 42 | 43 | # avg_rscore = 0 44 | # if FLAGS.rouge_reward_fscore: 45 | # avg_rscore = (output_dict["rouge_1_f_score"]+output_dict["rouge_2_f_score"]+ 46 | # output_dict["rouge_3_f_score"]+output_dict["rouge_4_f_score"]+ 47 | # output_dict["rouge_l_f_score"])/5.0 48 | # else: 49 | # avg_rscore = (output_dict["rouge_1_recall"]+output_dict["rouge_2_recall"]+ 50 | # output_dict["rouge_3_recall"]+output_dict["rouge_4_recall"]+ 51 | # output_dict["rouge_l_recall"])/5.0 52 | 53 | avg_rscore = 
(output_dict["rouge_1_f_score"]+output_dict["rouge_2_f_score"]+output_dict["rouge_l_f_score"])/3.0 54 | 55 | return avg_rscore 56 | 57 | def _rouge_wrapper_traindata(docname, final_labels, final_labels_str): 58 | # Gold Summary Directory : Always use original sentences 59 | gold_summary_directory = FLAGS.gold_summary_directory + "/gold-"+FLAGS.data_mode+"-training-org" 60 | gold_summary_fileaddress = gold_summary_directory + "/" + docname + ".gold" 61 | 62 | # Prepare Gold Model File 63 | os.system("mkdir -p "+FLAGS.tmp_directory+"/gold-"+docname+"-"+final_labels_str) 64 | os.system("cp "+gold_summary_fileaddress+" "+FLAGS.tmp_directory+"/gold-"+docname+"-"+final_labels_str+"/") 65 | 66 | # Document Sentence: Always use original sentences to generate summaries 67 | doc_sent_fileaddress = FLAGS.doc_sentence_directory + "/" + FLAGS.data_mode + "/training-sent/"+docname+".summary.final.org_sents" 68 | doc_sents = open(doc_sent_fileaddress).readlines() 69 | 70 | # Prepare Model file 71 | os.system("mkdir -p "+FLAGS.tmp_directory+"/model-"+docname+"-"+final_labels_str) 72 | 73 | # Write selected sentences 74 | labels_ones = [idx for idx in range(len(final_labels[:len(doc_sents)])) if final_labels[idx]=="1"] 75 | model_highlights = [doc_sents[idx] for idx in labels_ones] 76 | foutput = open(FLAGS.tmp_directory+"/model-"+docname+"-"+final_labels_str+"/"+docname+".model" , "w") 77 | foutput.write("".join(model_highlights)) 78 | foutput.close() 79 | 80 | return _rouge(FLAGS.tmp_directory+"/model-"+docname+"-"+final_labels_str, FLAGS.tmp_directory+"/gold-"+docname+"-"+final_labels_str) 81 | 82 | def _multi_run_wrapper(args): 83 | return _rouge_wrapper_traindata(*args) 84 | 85 | def _get_lcs(a, b): 86 | lengths = [[0 for j in range(len(b)+1)] for i in range(len(a)+1)] 87 | # row 0 and column 0 are initialized to 0 already 88 | for i, x in enumerate(a): 89 | for j, y in enumerate(b): 90 | if x == y: 91 | lengths[i+1][j+1] = lengths[i][j] + 1 92 | else: 93 | lengths[i+1][j+1] = 
max(lengths[i+1][j], lengths[i][j+1]) 94 | # read the longest common subsequence out of the matrix 95 | result = [] 96 | x, y = len(a), len(b) 97 | while x != 0 and y != 0: 98 | if lengths[x][y] == lengths[x-1][y]: 99 | x -= 1 100 | elif lengths[x][y] == lengths[x][y-1]: 101 | y -= 1 102 | else: 103 | assert a[x-1] == b[y-1] 104 | result = [a[x-1]] + result 105 | x -= 1 106 | y -= 1 107 | return len(result) 108 | 109 | def _get_ngram_sets(highlights): 110 | set_1gram = set() 111 | set_2gram = set() 112 | set_3gram = set() 113 | set_4gram = set() 114 | fullen = len(highlights) 115 | for widx in range(fullen): 116 | # 1gram 117 | set_1gram.add(str(highlights[widx])) 118 | # 2gram 119 | if (widx+1) < fullen: 120 | set_2gram.add(str(highlights[widx])+"-"+str(highlights[widx+1])) 121 | # 3gram 122 | if (widx+2) < fullen: 123 | set_3gram.add(str(highlights[widx])+"-"+str(highlights[widx+1])+"-"+str(highlights[widx+2])) 124 | # 4gram 125 | if (widx+3) < fullen: 126 | set_4gram.add(str(highlights[widx])+"-"+str(highlights[widx+1])+"-"+str(highlights[widx+2])+"-"+str(highlights[widx+3])) 127 | return set_1gram, set_2gram, set_3gram, set_4gram 128 | 129 | def _rouge_wrapper_traindata_nopyrouge(docname, final_labels_str, document, highlights): 130 | cand_highlights_full = [] 131 | for sentidx in final_labels_str.split("-"): 132 | cand_highlights_full += [wordid for wordid in document[int(sentidx)] if wordid != 0] 133 | cand_highlights_full.append(0) 134 | highlights_full = [] 135 | for sent in highlights: 136 | highlights_full += sent 137 | highlights_full.append(0) 138 | # print(cand_highlights_full,highlights_full) 139 | 140 | # Get sets 141 | cand_1gram, cand_2gram, cand_3gram, cand_4gram = _get_ngram_sets(cand_highlights_full) 142 | # print(cand_1gram, cand_2gram, cand_3gram, cand_4gram) 143 | gold_1gram, gold_2gram, gold_3gram, gold_4gram = _get_ngram_sets(highlights_full) 144 | # print(gold_1gram, gold_2gram, gold_3gram, gold_4gram) 145 | 146 | # Get ROUGE-N recalls 147 |
    rouge_recall_1 = 0
    if len(gold_1gram) != 0:
        rouge_recall_1 = float(len(gold_1gram.intersection(cand_1gram)))/float(len(gold_1gram))
    rouge_recall_2 = 0
    if len(gold_2gram) != 0:
        rouge_recall_2 = float(len(gold_2gram.intersection(cand_2gram)))/float(len(gold_2gram))
    rouge_recall_3 = 0
    if len(gold_3gram) != 0:
        rouge_recall_3 = float(len(gold_3gram.intersection(cand_3gram)))/float(len(gold_3gram))
    rouge_recall_4 = 0
    if len(gold_4gram) != 0:
        rouge_recall_4 = float(len(gold_4gram.intersection(cand_4gram)))/float(len(gold_4gram))

    # Get ROUGE-L
    len_lcs = _get_lcs(cand_highlights_full, highlights_full)
    r = 0 if (len_lcs == 0) else (float(len_lcs)/len(cand_highlights_full))
    p = 0 if (len_lcs == 0) else (float(len_lcs)/len(highlights_full))
    b = 0 if (r == 0) else (p / r)
    rouge_recall_l = 0 if (len_lcs == 0) else (((1+(b*b))*r*p)/(r+(b*b*p)))

    rouge_recall_average = (rouge_recall_1+rouge_recall_2+rouge_recall_3+rouge_recall_4+rouge_recall_l)/5.0
    # print(rouge_recall_1, rouge_recall_2, rouge_recall_3, rouge_recall_4, rouge_recall_l, rouge_recall_average)

    # Get final labels
    final_labels = [[1, 0] if (str(sentidx) in final_labels_str.split("-")) else [0, 1] for sentidx in range(FLAGS.max_doc_length)]  # [max_doc_length, target_label_size]

    return rouge_recall_average, final_labels

def _multi_run_wrapper_nopyrouge(args):
    return _rouge_wrapper_traindata_nopyrouge(*args)

class Reward_Generator:
    def __init__(self):
        self.rouge_dict = {}

        # Start a pool
        self.pool = Pool(10)

    def save_rouge_dict(self):
        with open(FLAGS.train_dir+"/rouge-dict.json", 'w') as outfile:
            json.dump(self.rouge_dict, outfile)

    def restore_rouge_dict(self):
        self.rouge_dict = {}
        if os.path.isfile(FLAGS.train_dir+"/rouge-dict.json"):
            with open(FLAGS.train_dir+"/rouge-dict.json") as data_file:
                self.rouge_dict = json.load(data_file)

    def get_full_rouge(self, system_dir, datatype):
        # Gold directory: always use the original files
        gold_summary_directory = FLAGS.gold_summary_directory + "/gold-"+FLAGS.data_mode+"-"+datatype+"-orgcase"

        rouge_score = _rouge(system_dir, gold_summary_directory)

        # Delete any tmp files
        os.system("rm -r "+FLAGS.tmp_directory+"/tmp*")

        return rouge_score

    # def get_batch_rouge(self, batch_docnames, batch_predicted_labels):

    #     # Numpy dtype
    #     dtype = np.float16 if FLAGS.use_fp16 else np.float32

    #     # Batch Size
    #     batch_size = len(batch_docnames)

    #     # batch_rouge
    #     batch_rouge = np.empty(batch_size, dtype=dtype)

    #     # Estimate list of arguments to run pool
    #     didx_list = []
    #     docname_labels_list = []
    #     for docindex in range(batch_size):
    #         docname = batch_docnames[docindex]
    #         predicted_labels = batch_predicted_labels[docindex]

    #         # Prepare final labels for summary generation
    #         final_labels = [str(int(predicted_labels[sentidx][0])) for sentidx in range(FLAGS.max_doc_length)]
    #         # print(final_labels)

    #         isfound = False
    #         rougescore = 0.0
    #         if docname in self.rouge_dict:
    #             final_labels_string = "".join(final_labels)
    #             if final_labels_string in self.rouge_dict[docname]:
    #                 rougescore = self.rouge_dict[docname][final_labels_string]
    #                 isfound = True

    #         if isfound:
    #             # Update batch_rouge
    #             batch_rouge[docindex] = rougescore
    #         else:
    #             didx_list.append(docindex)
    #             docname_labels_list.append((docname, final_labels))

    #     # Run parallel pool
    #     if(len(didx_list) > 0):
    #         # Run in parallel
    #         rougescore_list = self.pool.map(_multi_run_wrapper,docname_labels_list)
    #         # Process results
    #         for didx, rougescore, docname_labels in zip(didx_list, rougescore_list, docname_labels_list):
    #             # Update batch_rouge
    #             batch_rouge[didx] = rougescore

    #             # Update rouge dict
    #             docname = docname_labels[0]
    #             final_labels_string = "".join(docname_labels[1])
    #             if docname not in self.rouge_dict:
    #                 self.rouge_dict[docname] = {final_labels_string:rougescore}
    #             else:
    #                 self.rouge_dict[docname][final_labels_string] = rougescore
    #         # Delete any tmp file
    #         os.system("rm -r "+ FLAGS.tmp_directory+"/tmp* " + FLAGS.tmp_directory+"/gold-* " + FLAGS.tmp_directory+"/model-*")
    #     # print(self.rouge_dict)
    #     return batch_rouge

    def get_batch_rouge_withmultisample(self, batch_docnames, batch_predicted_labels_multisample):
        """
        Args:
          batch_docnames: [batch_size]
          batch_predicted_labels_multisample: [batch_size, rollout_count, FLAGS.max_doc_length, FLAGS.target_label_size]
        Return:
          rougescore: [batch_size, FLAGS.num_sample_rollout]
        """

        # Numpy dtype
        dtype = np.float16 if FLAGS.use_fp16 else np.float32

        # Batch size and sample rollout count
        batch_size = len(batch_docnames)
        rollout_count = batch_predicted_labels_multisample.shape[1]

        # batch_rouge
        batch_rouge_multisample = np.empty((batch_size, rollout_count), dtype=dtype)

        # Build a dict of all rollout labels and the docname_labels_list to run
        docname_labels_rollout_dict = {}
        docname_labels_list = []
        for docindex in range(batch_size):
            docname = batch_docnames[docindex]
            # print(docname)

            for rolloutidx in range(rollout_count):
                predicted_labels = batch_predicted_labels_multisample[docindex][rolloutidx]  # [FLAGS.max_doc_length, FLAGS.target_label_size]
                # Prepare final labels for summary generation
                final_labels = []
                final_labels_sindices = []
                for sentidx in range(FLAGS.max_doc_length):
                    final_labels.append(str(int(predicted_labels[sentidx][0])))
                    if int(predicted_labels[sentidx][0]) == 1:
                        final_labels_sindices.append(str(sentidx+1))
                final_labels_string = "-".join(final_labels_sindices)

                # print(final_labels,final_labels_string)

                isfound = False
                rougescore = 0.0
                if docname in self.rouge_dict:
                    if final_labels_string in self.rouge_dict[docname]:
                        rougescore = self.rouge_dict[docname][final_labels_string]
                        isfound = True

                if isfound:
                    # Update batch_rouge
                    batch_rouge_multisample[docindex][rolloutidx] = rougescore
                else:
                    if docname not in docname_labels_rollout_dict:
                        docname_labels_rollout_dict[docname] = [docindex, {final_labels_string:[rolloutidx]}]
                        docname_labels_list.append((docname, final_labels, final_labels_string))
                    else:
                        if final_labels_string not in docname_labels_rollout_dict[docname][1]:
                            docname_labels_rollout_dict[docname][1][final_labels_string] = [rolloutidx]
                            docname_labels_list.append((docname, final_labels, final_labels_string))
                        else:
                            docname_labels_rollout_dict[docname][1][final_labels_string].append(rolloutidx)
                            # no need to add to docname_labels_list

        # print(docname_labels_list)
        # Run parallel pool
        if(len(docname_labels_list) > 0):
            # Run in parallel
            with closing(Pool(10)) as mypool:
                rougescore_list = mypool.map(_multi_run_wrapper,docname_labels_list)
            # rougescore_list = self.pool.map(_multi_run_wrapper,docname_labels_list)

            # Process results
            for rougescore, docname_labels in zip(rougescore_list, docname_labels_list):
                docname = docname_labels[0]
                final_labels = docname_labels[1]
                final_labels_string = docname_labels[2]

                # Update batch_rouge
                docindex = docname_labels_rollout_dict[docname][0]
                for rolloutidx in docname_labels_rollout_dict[docname][1][final_labels_string]:
                    batch_rouge_multisample[docindex][rolloutidx] = rougescore

                # Update rouge dict
                if docname not in self.rouge_dict:
                    self.rouge_dict[docname] = {final_labels_string:rougescore}
                else:
                    self.rouge_dict[docname][final_labels_string] = rougescore
            # Delete any tmp files
            os.system("rm -r "+ FLAGS.tmp_directory+"/tmp* " + FLAGS.tmp_directory+"/gold-* " + FLAGS.tmp_directory+"/model-*")
            # print(self.rouge_dict)

        return batch_rouge_multisample

    def get_batch_rouge_withmultisample_nopyrouge(self, batch_docnames, batch_predicted_labels_multisample_str, batch_docs, batch_highlights_nonnumpy):
        """
        Args:
          batch_docnames: [batch_size]
          batch_predicted_labels_multisample_str: [batch_size, rollout_count]
          batch_docs: [batch_size, FLAGS.max_doc_length, FLAGS.max_sent_length]
          batch_highlights_nonnumpy: [batch_size, highlights_lengths, each_highlights]

        Return:
          rougescore: [batch_size, FLAGS.num_sample_rollout]
          batch_gold_sampled_label_multisample: [batch_size, FLAGS.num_sample_rollout, FLAGS.max_doc_length, FLAGS.target_label_size]
        """

        # Numpy dtype
        dtype = np.float16 if FLAGS.use_fp16 else np.float32

        # Batch size and sample rollout count
        batch_size = len(batch_docnames)
        rollout_count = batch_predicted_labels_multisample_str.shape[1]

        # batch_rouge
        batch_rouge_multisample = np.empty((batch_size, rollout_count), dtype=dtype)
        batch_gold_sampled_label_multisample = np.empty((batch_size, rollout_count, FLAGS.max_doc_length, FLAGS.target_label_size), dtype=dtype)

        # Build a dict of all rollout labels and the docname_labels_list to run
        docname_labels_rollout_dict = {}
        docname_labels_list = []
        for docindex in range(batch_size):
            docname = batch_docnames[docindex]
            document = batch_docs[docindex]
            highlights = batch_highlights_nonnumpy[docindex]
            # print(docname)

            for rolloutidx in range(rollout_count):
                final_labels_string = batch_predicted_labels_multisample_str[docindex][rolloutidx]
                # print(final_labels_string)

                if docname not in docname_labels_rollout_dict:
                    docname_labels_rollout_dict[docname] = [docindex, {final_labels_string:[rolloutidx]}]
                    docname_labels_list.append((docname, final_labels_string, document, highlights))
                else:
                    if final_labels_string not in docname_labels_rollout_dict[docname][1]:
                        docname_labels_rollout_dict[docname][1][final_labels_string] = [rolloutidx]
                        docname_labels_list.append((docname, final_labels_string, document, highlights))
                    else:
                        docname_labels_rollout_dict[docname][1][final_labels_string].append(rolloutidx)
                        # no need to add to docname_labels_list

                # isfound = False
                # rougescore = 0.0
                # if docname in self.rouge_dict:
                #     if final_labels_string in self.rouge_dict[docname]:
                #         rougescore = self.rouge_dict[docname][final_labels_string]
                #         isfound = True

                # if isfound:
                #     # Update batch_rouge
                #     batch_rouge_multisample[docindex][rolloutidx] = rougescore
                # else:
                #     if docname not in docname_labels_rollout_dict:
                #         docname_labels_rollout_dict[docname] = [docindex, {final_labels_string:[rolloutidx]}]
                #         docname_labels_list.append((docname, final_labels_string, document, highlights))
                #     else:
                #         if final_labels_string not in docname_labels_rollout_dict[docname][1]:
                #             docname_labels_rollout_dict[docname][1][final_labels_string] = [rolloutidx]
                #             docname_labels_list.append((docname, final_labels_string, document, highlights))
                #         else:
                #             docname_labels_rollout_dict[docname][1][final_labels_string].append(rolloutidx)
                #             # no need to add to docname_labels_list

        # print(docname_labels_rollout_dict )
        # print(docname_labels_list)

        # Run parallel pool
        if(len(docname_labels_list) > 0):
            # Run in parallel
            # with closing(Pool(10)) as mypool:
            #     rougescore_finallabels_list = mypool.map(_multi_run_wrapper_nopyrouge,docname_labels_list)
            rougescore_finallabels_list = self.pool.map(_multi_run_wrapper_nopyrouge,docname_labels_list)

            # Process results
            for rougescore_finallabels, docname_labels in zip(rougescore_finallabels_list, docname_labels_list):
                rougescore = rougescore_finallabels[0]
                finallabels = rougescore_finallabels[1]
                docname = docname_labels[0]
                final_labels_string = docname_labels[1]

                # Update batch_rouge
                docindex = docname_labels_rollout_dict[docname][0]
                for rolloutidx in docname_labels_rollout_dict[docname][1][final_labels_string]:
                    batch_rouge_multisample[docindex][rolloutidx] = rougescore
                    batch_gold_sampled_label_multisample[docindex][rolloutidx] = np.array(finallabels[:], dtype=dtype)

                # # Update rouge dict
                # if docname not in self.rouge_dict:
                #     self.rouge_dict[docname] = {final_labels_string:rougescore}
                # else:
                #     self.rouge_dict[docname][final_labels_string] = rougescore

            # print(self.rouge_dict)

        return batch_rouge_multisample, batch_gold_sampled_label_multisample

--------------------------------------------------------------------------------
/scripts/oracle-estimator/estimate_multiple_oracles.py:
--------------------------------------------------------------------------------

# Credits
# Written by Shashi Narayan to use original ROUGE
# Improved by Yang Liu to use a much faster ROUGE


import os
import re
import itertools
from multiprocessing import Pool
import sys
from preprocess import rouge
import codecs


def cal_rouge(fullset, sentdata, golddata):
    fullset.sort()
    model_highlights = [sentdata[idx] for idx in range(len(sentdata)) if idx in fullset]
    rouge_1 = rouge.rouge_n(model_highlights, golddata, 1)['f']
    rouge_2 = rouge.rouge_n(model_highlights, golddata, 2)['f']
    rouge_l = rouge.rouge_l_summary_level(model_highlights, golddata)['f']
    rouge_score = (rouge_1 + rouge_2 + rouge_l)/3.0
    return (rouge_score, fullset)


def _multi_run_wrapper(args):
    return cal_rouge(*args)


def get_fileids(topdir, newstype, datatype, server):
    if newstype == "cnn":
        return open(topdir + "/Temp-ServerFileIds/cnn-" + datatype + "-fileids.txt." + server).read().strip().split(
            "\n")

    if newstype == "dailymail":
        return open(
            topdir + "/Temp-ServerFileIds/dailymail-" + datatype + "-fileids.txt." + server).read().strip().split("\n")

    print("Error: Only CNN and DailyMail allowed.")
    exit(0)


if __name__ == "__main__":

    args = sys.argv[1:]

    pool = Pool(int(args[0]))
    data_dir = args[1]
    task = int(args[2])
    sent_limit = 4

    fullcount = 0

    mainbody_dir = os.path.join(data_dir, 'article')
    highlights_dir = os.path.join(data_dir, 'abstracts')

    _listfiles = os.listdir(mainbody_dir)
    all_l = len(_listfiles)
    _listfiles = sorted(_listfiles)
    if (task == 0):
        listfiles = _listfiles[:int(all_l / 3)]
    elif (task == 1):
        listfiles = _listfiles[int(all_l / 3):int(all_l * 2 / 3)]
    elif (task == 2):
        listfiles = _listfiles[int(all_l * 2 / 3):]

    fianllabeldir = os.path.join(data_dir, 'article-oracles')
    if (not os.path.exists(fianllabeldir)):
        os.mkdir(fianllabeldir)

    for summaryfname in listfiles:
        if not re.match(r'^\d', summaryfname):
            continue
        summaryfname = summaryfname.split('.')[0]
        if os.path.isfile(os.path.join(fianllabeldir, summaryfname + ".f-sent")) and os.path.isfile(
                os.path.join(fianllabeldir, summaryfname + ".moracle")):
            continue
        print(summaryfname)
        sentfull = os.path.join(mainbody_dir, summaryfname + ".article")
        sentdata = (codecs.open(sentfull, encoding='utf-8').readlines())  # full doc
        goldfull = os.path.join(highlights_dir, summaryfname + ".abstracts")
        golddata = (codecs.open(goldfull, encoding='utf-8').readlines())  # gold abstract

        rougesentwisefile = os.path.join(fianllabeldir, summaryfname + ".f-sent")

        sentids_lst = [[sentid] for sentid in range(len(sentdata))]
        rougescore_sentwise = []
        for sentids in sentids_lst:
            rougescore_sentwise.append(cal_rouge(sentids, sentdata, golddata))

        # print rougescore_sentwise
        foutput = open(rougesentwisefile, "w")
        foutput.write("\n".join([str(item[0]) + "\t" + str(item[1][0]) for item in rougescore_sentwise]) + "\n")
        foutput.close()
        rougescore_sentwise.sort(reverse=True)

        toprougesentences = [item[1][0] for item in rougescore_sentwise[:10]]
        toprougesentences.sort()

        labelfullmodif = fianllabeldir + "/" + summaryfname + ".moracle"
        rougescore_sentids = []
        rougescore_sentids += rougescore_sentwise[:10][:]

        arguments_list = []
        for itemcount in range(2, sent_limit + 1):
            arguments_list += [(list(sentids), sentdata, golddata) for sentids in
                               itertools.combinations(toprougesentences, itemcount)]
        # Add the scored 2..sent_limit combinations to the top single-sentence candidates
        rougescore_sentids += pool.map(_multi_run_wrapper, arguments_list)

        # Process results
        # for rougescore, arguments in zip(rougescore_list, arguments_list):
        #     rougescore_sentids.append((rougescore, arguments[0]))

        # rougescore_sentids = []
        # for sentids in sentids_lst:
        #     rougescore_sentids.append(cal_rouge(sentids, sentdata, golddata))

        rougescore_sentids.sort(reverse=True)

        foutput = open(labelfullmodif, "w")
        for item in rougescore_sentids:
            foutput.write((" ".join([str(sentidx) for sentidx in item[1]])) + "\t" + str(item[0]) + "\n")
        foutput.close()

        if fullcount % 100 == 0:
            print(fullcount)
        fullcount += 1

--------------------------------------------------------------------------------
/scripts/oracle-estimator/rouge.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ROUGE Metric Implementation

This is a very slightly modified version of:
https://github.com/pltrdy/seq2seq/blob/master/seq2seq/metrics/rouge.py

---

ROUGE metric implementation.

This is a modified and slightly extended version of
https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py.
"""
from __future__ import absolute_import
from __future__ import division, print_function, unicode_literals
import itertools


def _get_ngrams(n, text):
    """Calculates n-grams.

    Args:
      n: which n-grams to calculate
      text: An array of tokens

    Returns:
      A set of n-grams
    """
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set


def _split_into_words(sentences):
    """Splits multiple sentences into words and flattens the result"""
    return list(itertools.chain(*[_.split(" ") for _ in sentences]))


def _get_word_ngrams(n, sentences):
    """Calculates word n-grams for multiple sentences.
    """
    assert len(sentences) > 0
    assert n > 0

    words = _split_into_words(sentences)
    return _get_ngrams(n, words)


def _len_lcs(x, y):
    """
    Returns the length of the Longest Common Subsequence between sequences x
    and y.
    Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence

    Args:
      x: sequence of words
      y: sequence of words

    Returns:
      integer: Length of LCS between x and y
    """
    table = _lcs(x, y)
    n, m = len(x), len(y)
    return table[n, m]


def _lcs(x, y):
    """
    Computes the length of the longest common subsequence (lcs) between two
    strings. The implementation below uses a dynamic programming algorithm and runs
    in O(nm) time where n = len(x) and m = len(y).
    Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence

    Args:
      x: collection of words
      y: collection of words

    Returns:
      Dictionary mapping each coordinate (i, j) to the LCS length of x[:i] and y[:j]
    """
    n, m = len(x), len(y)
    table = dict()
    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 or j == 0:
                table[i, j] = 0
            elif x[i - 1] == y[j - 1]:
                table[i, j] = table[i - 1, j - 1] + 1
            else:
                table[i, j] = max(table[i - 1, j], table[i, j - 1])
    return table


def _recon_lcs(x, y):
    """
    Returns the longest common subsequence of x and y.
    Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence

    Args:
      x: sequence of words
      y: sequence of words

    Returns:
      sequence: LCS of x and y
    """
    i, j = len(x), len(y)
    table = _lcs(x, y)

    def _recon(i, j):
        """private recon calculation"""
        if i == 0 or j == 0:
            return []
        elif x[i - 1] == y[j - 1]:
            return _recon(i - 1, j - 1) + [(x[i - 1], i)]
        elif table[i - 1, j] > table[i, j - 1]:
            return _recon(i - 1, j)
        else:
            return _recon(i, j - 1)

    recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
    return recon_tuple


def rouge_n(evaluated_sentences, reference_sentences, n=2):
    """
    Computes ROUGE-N of two text collections of sentences.
    Source: http://research.microsoft.com/en-us/um/people/cyl/download/
    papers/rouge-working-note-v1.3.1.pdf

    Args:
      evaluated_sentences: The sentences that have been picked by the
                           summarizer
      reference_sentences: The sentences from the reference set
      n: Size of ngram.  Defaults to 2.

    Returns:
      A dict with keys "f" (F1), "p" (precision) and "r" (recall) for ROUGE-N

    Raises:
      ValueError: raises exception if a param has len <= 0
    """
    if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
        raise ValueError("Collections must contain at least 1 sentence.")

    evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
    reference_ngrams = _get_word_ngrams(n, reference_sentences)
    reference_count = len(reference_ngrams)
    evaluated_count = len(evaluated_ngrams)

    # Gets the overlapping ngrams between evaluated and reference
    overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
    overlapping_count = len(overlapping_ngrams)

    # Handle edge case. This isn't mathematically correct, but it's good enough
    if evaluated_count == 0:
        precision = 0.0
    else:
        precision = overlapping_count / evaluated_count

    if reference_count == 0:
        recall = 0.0
    else:
        recall = overlapping_count / reference_count

    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))

    return {"f": f1_score, "p": precision, "r": recall}


def _union_lcs(evaluated_sentences, reference_sentence, prev_union=None):
    """
    Returns LCS_u(r_i, C) which is the LCS score of the union longest common
    subsequence between reference sentence r_i and candidate summary C.
    For example:
    if r_i = w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8
    and c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1
    is "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5".
    The union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5"
    and LCS_u(r_i, C) = 4/5.

    Args:
      evaluated_sentences: The sentences that have been picked by the
                           summarizer
      reference_sentence: One of the sentences in the reference summaries
      prev_union: set of LCS words accumulated over earlier reference
                  sentences (None starts from an empty set)

    Returns:
      tuple: (number of LCS words not already in prev_union, updated union set)

    Raises:
      ValueError: raises exception if evaluated_sentences is empty
    """
    if prev_union is None:
        prev_union = set()

    if len(evaluated_sentences) <= 0:
        raise ValueError("Collections must contain at least 1 sentence.")

    lcs_union = prev_union
    prev_count = len(prev_union)
    reference_words = _split_into_words([reference_sentence])

    combined_lcs_length = 0
    for eval_s in evaluated_sentences:
        evaluated_words = _split_into_words([eval_s])
        lcs = set(_recon_lcs(reference_words, evaluated_words))
        combined_lcs_length += len(lcs)
        lcs_union = lcs_union.union(lcs)

    new_lcs_count = len(lcs_union) - prev_count
    return new_lcs_count, lcs_union


def rouge_l_summary_level(evaluated_sentences, reference_sentences):
    """
    Computes ROUGE-L (summary level) of two text collections of sentences.
    http://research.microsoft.com/en-us/um/people/cyl/download/papers/
    rouge-working-note-v1.3.1.pdf

    Calculated according to:
    R_lcs = SUM(1, u)[LCS(r_i,C)]/m
    P_lcs = SUM(1, u)[LCS(r_i,C)]/n
    F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)

    where:
    SUM(i,u) = SUM from i through u
    u = number of sentences in reference summary
    C = Candidate summary made up of v sentences
    m = number of words in reference summary
    n = number of words in candidate summary

    Args:
      evaluated_sentences: The sentences that have been picked by the
                           summarizer
      reference_sentences: The sentences of the reference summary

    Returns:
      A dict with keys "f" (F_lcs), "p" (P_lcs) and "r" (R_lcs)

    Raises:
      ValueError: raises exception if a param has len <= 0
    """
    if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
        raise ValueError("Collections must contain at least 1 sentence.")

    # number of distinct words in reference sentences
    m = len(set(_split_into_words(reference_sentences)))

    # number of distinct words in evaluated sentences
    n = len(set(_split_into_words(evaluated_sentences)))

    # print("m,n %d %d" % (m, n))
    union_lcs_sum_across_all_references = 0
    union = set()
    for ref_s in reference_sentences:
        lcs_count, union = _union_lcs(evaluated_sentences,
                                      ref_s,
                                      prev_union=union)
        union_lcs_sum_across_all_references += lcs_count

    llcs = union_lcs_sum_across_all_references
    r_lcs = llcs / m
    p_lcs = llcs / n
    beta = p_lcs / (r_lcs + 1e-12)
    num = (1 + (beta ** 2)) * r_lcs * p_lcs
    denom = r_lcs + ((beta ** 2) * p_lcs)
    f_lcs = num / (denom + 1e-12)
    return {"f": f_lcs, "p": p_lcs, "r": r_lcs}

--------------------------------------------------------------------------------
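The candidate search in estimate_multiple_oracles.py can be illustrated in isolation. The following is a minimal sketch, not the script itself: it scores every sentence on its own, keeps the `top_k` best (10 in the script), scores every combination of 2 to `sent_limit` (4 in the script) of those sentences, and ranks the top single-sentence candidates together with all combinations. To stay self-contained it assumes a hypothetical `unigram_f1` scorer (set-based unigram F1) in place of the script's average of the ROUGE-1, ROUGE-2 and summary-level ROUGE-L F-scores from rouge.py; `rank_oracles` is likewise a hypothetical name.

```python
import itertools
from typing import List, Sequence, Tuple


def unigram_f1(candidate: Sequence[str], reference: Sequence[str]) -> float:
    """Hypothetical stand-in scorer: F1 over the sets of unigrams.

    The real script averages ROUGE-1, ROUGE-2 and ROUGE-L F-scores instead.
    """
    cand = set(" ".join(candidate).split())
    ref = set(" ".join(reference).split())
    overlap = len(cand & ref)
    if overlap == 0:
        return 0.0
    p = overlap / len(cand)
    r = overlap / len(ref)
    return 2 * p * r / (p + r)


def rank_oracles(sentences: List[str], gold: List[str], top_k: int = 10,
                 sent_limit: int = 4) -> List[Tuple[float, List[int]]]:
    """Return (score, sorted sentence ids) candidates, best first."""
    # 1. Score each sentence alone and keep the top_k best.
    single = [(unigram_f1([s], gold), [i]) for i, s in enumerate(sentences)]
    single.sort(reverse=True)
    top = sorted(item[1][0] for item in single[:top_k])

    # 2. Score every 2..sent_limit combination of the top sentences only,
    #    which bounds the search instead of trying all subsets of the article.
    candidates = single[:top_k]
    for size in range(2, sent_limit + 1):
        for combo in itertools.combinations(top, size):
            picked = [sentences[i] for i in combo]
            candidates.append((unigram_f1(picked, gold), list(combo)))

    # 3. Rank all candidates by score (highest first).
    candidates.sort(reverse=True)
    return candidates


if __name__ == "__main__":
    doc = ["the cat sat on the mat", "dogs bark loudly",
           "the cat was black", "it rained all day"]
    gold = ["the black cat sat on the mat"]
    for score, ids in rank_oracles(doc, gold)[:3]:
        print(ids, round(score, 3))
```

Restricting combinations to the top-scoring single sentences is what keeps the search tractable: with `top_k=10` and `sent_limit=4` it scores at most 45 + 120 + 210 = 375 multi-sentence combinations per document, regardless of article length.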