├── README.md
├── Saved-Models
│   └── .placeholder
├── dataloader.py
├── eval_model.py
├── finished_files
│   ├── .placeholder
│   └── vocabulary.bin
├── main.py
├── models.py
└── preprocess.py

/README.md:
--------------------------------------------------------------------------------
1 | # Text-Summarization
2 | 
3 | Text Summarization using Pointer Attention Models
4 | 
5 | This is a Pytorch implementation of [Get To The Point: Summarization with Pointer-Generator Networks by See et al.](https://arxiv.org/abs/1704.04368)
6 | 
7 | ## Dependencies
8 | You require:
9 | 
10 | 1. [Pytorch](https://pytorch.org/) v0.2 with CUDA support
11 | 2. [Visdom](https://github.com/facebookresearch/visdom/) visualization package for easy monitoring of training progress in the web browser
12 | 3. tqdm for terminal-level progress updates
13 | 
14 | ## How to get the data
15 | 
16 | Download the fully-preprocessed data splits [here](https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail) and save yourself the trouble of downloading CoreNLP and running the tokenization yourself.
17 | This implementation only uses the `Finished_Files` directory containing the `train.bin`, `val.bin` and `test.bin` splits. Move these files to the `finished_files` directory.
18 | 
19 | The only other required file, `vocabulary.bin`, is already provided in `finished_files` in this repo.
20 | 
21 | Alternatively, follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and preprocess the dataset.
22 | Use the provided `preprocess.py` script instead to generate the data splits, as this version does not have Tensorflow dependencies.
23 | 
24 | 
25 | ## How to run
26 | 
27 | The original hyperparameters as described in the paper are memory intensive, so this implementation uses a smaller RNN `hidden_size` as the default setting. All other hyperparameters are kept the same.
28 | You can trade off `vocabulary_size`, `batch_size`, `hidden_size`, `max_abstract_size` and `max_article_size` to fit your memory budget.
29 | 
30 | 1. Fire up a visdom server :
31 | `python -m visdom.server`
32 | 
33 | 2. To train using default settings, from the repo's root directory :
34 | `CUDA_VISIBLE_DEVICES=0 python main.py`
35 | 
36 | 3. Monitor the training progress by going to `127.0.0.1:8097` in your web browser, or the corresponding remote URL if you're executing your code over SSH.
37 | 
38 | Configurations can be changed using command line options.
39 | 
40 | Run `python main.py --help` to get a list of all options.
41 | 
42 | 
43 | The model is evaluated periodically during training on a sample from `test.bin`, and decoding is done using beam search. The model is also saved after every epoch in `Saved-Models`.
44 | 
45 | 
46 | To bootstrap with pre-trained embeddings, you will need to obtain pre-trained GloVe/Word2Vec embeddings for the words in your vocabulary. OOV words can be assigned random values. Save this as a Pytorch tensor in `embeds.pkl` and make sure the vocabulary size matches the size of the tensor.
47 | The default setting is to initialize with random word embeddings, since that has been reported to perform better.
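
Below is a minimal sketch of how such an `embeds.pkl` could be built from a GloVe text file. The GloVe filename, the 300-dimensional size and the random scale for OOV rows are assumptions; the only hard requirement is that the saved tensor's size matches the embedding matrix built in `main.py` (vocabulary size plus one row, with index 0 reserved for PAD), so adjust the truncation to match your `--truncate-vocab` setting.

```python
# Sketch only -- build embeds.pkl from pre-trained GloVe vectors.
# Assumes a 300-d GloVe text file (e.g. glove.6B.300d.txt) is available locally.
import cPickle as pickle
import torch

with open('finished_files/vocabulary.bin', 'rb') as f:
    vocab = pickle.load(f)                  # list of (word, count) tuples
words = [w for w, _ in vocab[:50000]]       # use the same truncation you train with

glove = {}
with open('glove.6B.300d.txt') as f:        # assumed path to the GloVe vectors
    for line in f:
        parts = line.rstrip().split(' ')
        glove[parts[0]] = [float(x) for x in parts[1:]]

embed_size = 300
# One row per vocabulary word plus row 0 for PAD; rows for OOV words stay random.
embeds = torch.randn(len(words) + 1, embed_size).mul_(0.1)
for i, word in enumerate(words):
    if word in glove:
        embeds[i + 1] = torch.FloatTensor(glove[word])

with open('embeds.pkl', 'wb') as f:
    pickle.dump(embeds, f)                  # loaded by main.py when --bootstrap 1
```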
48 | 49 | -------------------------------------------------------------------------------- /Saved-Models/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hashbangCoder/Text-Summarization/dd0105a225d65ed9f7925e788036dbd9e65ca0f6/Saved-Models/.placeholder -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import os, random,pdb, torch 3 | from itertools import groupby 4 | from tqdm import tqdm 5 | import pdb 6 | import numpy as np 7 | 8 | class dataloader(): 9 | def __init__(self, batchSize, epochs, vocab, train_path, test_path, max_article_size=400, max_abstract_size=140, test_mode=False): 10 | self.maxEpochs = epochs 11 | self.epoch = 1 12 | self.batchSize = batchSize 13 | self.iterInd = 0 14 | self.globalInd = 1 15 | #self.vocab = vocab #list of all vocabulary words 16 | self.word2id, self.id2word = self.getVocabMap(vocab) 17 | self.vocabSize = len(vocab) 18 | self.max_article_size = max_article_size 19 | self.max_abstract_size = max_abstract_size 20 | self.test_mode = test_mode 21 | 22 | assert os.path.isfile(train_path) and os.path.isfile(test_path), 'Invalid paths to train/test datafiles' 23 | self.train_path = train_path 24 | self.test_path = test_path 25 | if not self.test_mode: 26 | print 'Loading training data from disk...will take a minute...' 27 | with open(self.train_path,'rb') as f: 28 | self.train_data = pickle.load(f) 29 | self.trainSamples = len(self.train_data) 30 | else: 31 | print 'Initializing Dataloader in test mode with only eval-dataset...' 32 | #Load eval set 33 | print 'Loading eval data from disk...' 
34 | with open(self.test_path,'rb') as f: 35 | self.test_data = pickle.load(f) 36 | self.testSamples = len(self.test_data) 37 | 38 | # self.loadEvalBatch() 39 | if not self.test_mode: 40 | self.stopFlag = False 41 | self.pbar = tqdm(total=self.trainSamples * self.maxEpochs) 42 | self.pbar.set_description('Epoch : %d/%d' % (self.epoch, self.maxEpochs)) 43 | 44 | def getVocabMap(self, vocab): 45 | word2id, id2word = {}, {} 46 | for i, word in enumerate(vocab): 47 | word2id[word] = i+1 # reserve 0 for pad 48 | id2word[i+1] = word 49 | id2word[0] = '' 50 | # word2id[''] = len(word2id) + 1 51 | # id2word[len(word2id)] = '' 52 | # word2id[''] = len(word2id) + 1 53 | # id2word[len(word2id)] = '' 54 | return word2id, id2word 55 | 56 | def makeEncoderInput(self, article): 57 | # tokenize article 58 | # list of OOV words in article 59 | self.encUnkCount = 1 60 | _intArticle, extIntArticle = [], [] 61 | article_oov = [] 62 | art_len = min(self.max_article_size, len(article)) 63 | for word_ind, word in enumerate(article[:art_len]): 64 | try: 65 | _intArticle.append(self.word2id[word.lower().strip()]) 66 | extIntArticle.append(self.word2id[word.lower().strip()]) 67 | except KeyError: 68 | _intArticle.append(self.word2id['']) 69 | extIntArticle.append(self.vocabSize + self.encUnkCount) 70 | article_oov.append(word) 71 | #article_oov_ind.append(word_ind) 72 | self.encUnkCount += 1 73 | 74 | return _intArticle, extIntArticle, article_oov, art_len 75 | 76 | def makeDecoderInput(self, abstract, article_oov): 77 | _intAbstract, extIntAbstract = [], [] 78 | abs_len = min(self.max_abstract_size, len(abstract)) 79 | # tokenize abstract 80 | self.decUnkCount = 0 81 | for word in abstract[:abs_len]: 82 | try: 83 | _intAbstract.append(self.word2id[word.lower().strip()]) 84 | extIntAbstract.append(self.word2id[word.lower().strip()]) 85 | except KeyError: 86 | _intAbstract.append(self.word2id['']) 87 | #check if OOV word present in article 88 | if word in article_oov: 89 | extIntAbstract.append(self.vocabSize + article_oov.index(word) + 1) 90 | else: 91 | extIntAbstract.append(self.word2id['']) 92 | self.decUnkCount += 1 93 | return _intAbstract, extIntAbstract, abs_len 94 | 95 | def preproc(self, samples): 96 | # batchArticles --> tensor batch of articles with ids 97 | # batchRevArticles --> tensor batch of reversed articles with ids 98 | # batchExtArticles --> tensor batch of articles with ids and replaced by temp OOV ids 99 | # batchAbstracts --> tensor batch of abstract (input for decoder) with ids 100 | # batchTargets --> tensor batch of target abstracts 101 | # max_article_oov --> max number of OOV tokens in article batch 102 | 103 | 104 | # limit max article size to 400 tokens 105 | extIntArticles, intRevArticles, intAbstract, intTargets, extIntAbstracts = [], [], [], [], [] 106 | art_lens, abs_lens= [], [] 107 | maxLen = 0 108 | max_article_oov = 0 109 | for sampl in samples: 110 | article = sampl['article'].split(' ') 111 | abstract = sampl['abstract'].split(' ') 112 | # get article and abstract int-tokenized 113 | _intArticle, _extIntArticle, article_oov, art_len = self.makeEncoderInput(article) 114 | if max_article_oov < len(article_oov): 115 | max_article_oov = len(article_oov) 116 | _intRevArticle = list(reversed(_intArticle)) 117 | _intAbstract, _extIntAbstract, abs_len = self.makeDecoderInput(abstract, article_oov) 118 | 119 | # append stopping/start tokens and increment length by 1 120 | intAbstract.append([self.word2id['']] + _intAbstract) 121 | # append end token 122 | 
intTargets.append(_extIntAbstract + [self.word2id['']]) 123 | abs_len += 1 124 | extIntArticles.append(_extIntArticle) 125 | intRevArticles.append(_intRevArticle) 126 | art_lens.append(art_len) 127 | abs_lens.append(abs_len) 128 | 129 | padExtArticles = [torch.LongTensor(item + [0] * (max(art_lens) - len(item))) for item in extIntArticles] 130 | padRevArticles = [torch.LongTensor(item + [0] * (max(art_lens) - len(item))) for item in intRevArticles] 131 | padAbstracts = [torch.LongTensor(item + [0] * (max(abs_lens) - len(item))) for item in intAbstract] 132 | padTargets = [torch.LongTensor(item + [0] * (max(abs_lens) - len(item))) for item in intTargets] 133 | 134 | batchExtArticles = torch.stack(padExtArticles, 0) 135 | # replace temp ids with unk token id for enc input 136 | batchArticles = batchExtArticles.clone().masked_fill_((batchExtArticles > self.vocabSize), self.word2id['']) 137 | batchRevArticles = torch.stack(padRevArticles, 0) 138 | batchAbstracts = torch.stack(padAbstracts, 0) 139 | batchTargets = torch.stack(padTargets, 0) 140 | art_lens = torch.LongTensor(art_lens) 141 | abs_lens = torch.LongTensor(abs_lens) 142 | return batchArticles, batchExtArticles, batchRevArticles, batchAbstracts, batchTargets, art_lens, abs_lens, max_article_oov, article_oov 143 | 144 | def getBatch(self, num_samples=None): 145 | if num_samples is None: 146 | num_samples = self.batchSize 147 | 148 | if self.epoch > self.maxEpochs: 149 | print 'Maximum Epoch Limit reached' 150 | self.stopFlag = True 151 | return None 152 | 153 | if self.iterInd + num_samples > self.trainSamples: 154 | data = [self.train_data[i] for i in xrange(self.iterInd, self.trainSamples)] 155 | else: 156 | data = [self.train_data[i] for i in xrange(self.iterInd, self.iterInd + num_samples)] 157 | 158 | batchData = self.preproc(data) 159 | 160 | self.globalInd += 1 161 | self.iterInd += num_samples 162 | if self.iterInd > self.trainSamples: 163 | self.iterInd = 0 164 | self.epoch += 1 165 | self.globalInd = 1 166 | self.pbar.set_description('Epoch : %d/%d' % (self.epoch, self.maxEpochs)) 167 | 168 | return batchData 169 | 170 | def getEvalBatch(self, num_samples=1): 171 | # select first sample for eval 172 | data = [self.test_data[i] for i in range(num_samples)] 173 | batchData = self.evalPreproc(data[0]) 174 | return batchData 175 | 176 | def evalPreproc(self, sample): 177 | # sample length = 1 178 | # limit max article size to 400 tokens 179 | extIntArticles, intRevArticles = [], [] 180 | max_article_oov = 0 181 | article = sample['article'].split(' ') 182 | # get article int-tokenized 183 | _intArticle, _extIntArticle, article_oov, _ = self.makeEncoderInput(article) 184 | if max_article_oov < len(article_oov): 185 | max_article_oov = len(article_oov) 186 | _intRevArticle = list(reversed(_intArticle)) 187 | # _intAbstract, _extIntAbstract, abs_len = self.makeDecoderInput(abstract, article_oov) 188 | 189 | extIntArticles.append(_extIntArticle) 190 | intRevArticles.append(_intRevArticle) 191 | 192 | padExtArticles = [torch.LongTensor(item) for item in extIntArticles] 193 | padRevArticles = [torch.LongTensor(item) for item in intRevArticles] 194 | 195 | batchExtArticles = torch.stack(padExtArticles, 0) 196 | # replace temp ids with unk token id for enc input 197 | batchArticles = batchExtArticles.clone().masked_fill_((batchExtArticles > self.vocabSize), self.word2id['']) 198 | batchRevArticles = torch.stack(padRevArticles, 0) 199 | 200 | return batchArticles, batchRevArticles, batchExtArticles, max_article_oov, article_oov, 
sample['article'], sample['abstract'] 201 | 202 | def getEvalSample(self, index=None): 203 | if index is None: 204 | rand_index = np.random.randint(0, self.testSamples-1) 205 | data = self.test_data[rand_index] 206 | return self.evalPreproc(data) 207 | 208 | elif isinstance(index, int) and (index>=0 and index < self.testSamples): 209 | data = self.test_data[index] 210 | return self.evalPreproc(data) 211 | 212 | def getInputTextSample(self, tokenized_text): 213 | extIntArticles, intRevArticles = [], [] 214 | max_article_oov = 0 215 | # get article int-tokenized 216 | _intArticle, _extIntArticle, article_oov, _ = self.makeEncoderInput(tokenized_text) 217 | if max_article_oov < len(article_oov): 218 | max_article_oov = len(article_oov) 219 | _intRevArticle = list(reversed(_intArticle)) 220 | 221 | extIntArticles.append(_extIntArticle) 222 | intRevArticles.append(_intRevArticle) 223 | 224 | padExtArticles = [torch.LongTensor(item) for item in extIntArticles] 225 | padRevArticles = [torch.LongTensor(item) for item in intRevArticles] 226 | 227 | batchExtArticles = torch.stack(padExtArticles, 0) 228 | # replace temp ids with unk token id for enc input 229 | batchArticles = batchExtArticles.clone().masked_fill_((batchExtArticles > self.vocabSize), self.word2id['']) 230 | batchRevArticles = torch.stack(padRevArticles, 0) 231 | 232 | return batchArticles, batchRevArticles, batchExtArticles, max_article_oov, article_oov 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /eval_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import cPickle as pickle 4 | import argparse 5 | import pdb, os 6 | import numpy as np 7 | import models 8 | from torch.nn.utils import clip_grad_norm 9 | from tqdm import tqdm 10 | import dataloader 11 | from visdom import Visdom 12 | from nltk.tokenize import word_tokenize 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument("--train-file", dest="train_file", help="Path to train datafile", default='finished_files/train.bin', type=str) 17 | parser.add_argument("--test-file", dest="test_file", help="Path to test/eval datafile", default='finished_files/test.bin', type=str) 18 | parser.add_argument("--vocab-file", dest="vocab_file", help="Path to vocabulary datafile", default='finished_files/vocabulary.bin', type=str) 19 | 20 | parser.add_argument("--max-abstract-size", dest="max_abstract_size", help="Maximum size of abstract for decoder input", default=110, type=int) 21 | parser.add_argument("--max-article-size", dest="max_article_size", help="Maximum size of article for encoder input", default=300, type=int) 22 | parser.add_argument("--batch-size", dest="batchSize", help="Mini-batch size", default=32, type=int) 23 | parser.add_argument("--embed-size", dest="embedSize", help="Size of word embedding", default=300, type=int) 24 | parser.add_argument("--hidden-size", dest="hiddenSize", help="Size of hidden to model", default=128, type=int) 25 | 26 | parser.add_argument("--lambda", dest="lmbda", help="Hyperparameter for auxillary cost", default=1, type=float) 27 | parser.add_argument("--beam-size", dest="beam_size", help="beam size for beam search decoding", default=4, type=int) 28 | parser.add_argument("--max-decode", dest="max_decode", help="Maximum length of decoded output", default=120, type=int) 29 | parser.add_argument("--truncate-vocab", dest="trunc_vocab", help="size of truncated Vocabulary <= 
50000 [to save memory]", default=50000, type=int) 30 | parser.add_argument("--bootstrap", dest="bootstrap", help="Bootstrap word embeds with GloVe?", default=0, type=int) 31 | parser.add_argument("--print-ground-truth", dest="print_ground_truth", help="Print the article and abstract", default=1, type=int) 32 | 33 | parser.add_argument("--load-model", dest="load_model", help="Directory from which to load trained models", default=None, type=str) 34 | parser.add_argument("--article", dest="article_path", help="Path to article text file", default=None, type=str) 35 | 36 | opt = parser.parse_args() 37 | vis = Visdom() 38 | 39 | assert opt.load_model is not None and os.path.isfile(opt.vocab_file), 'Invalid Path to trained model file' 40 | 41 | 42 | ### utility code for displaying generated abstract 43 | def displayOutput(all_summaries, article, abstract, article_oov, show_ground_truth=True): 44 | special_tokens = ['','','',''] 45 | print '*' * 150 46 | print '\n' 47 | if show_ground_truth: 48 | print 'ARTICLE TEXT : \n', article 49 | print 'ACTUAL ABSTRACT : \n', abstract 50 | for i, summary in enumerate(all_summaries): 51 | generated_summary = ' '.join([dl.id2word[ind] if ind<=dl.vocabSize else article_oov[ind % dl.vocabSize] for ind in summary]) 52 | for token in special_tokens: 53 | generated_summary.replace(token, '') 54 | print 'GENERATED ABSTRACT #%d : \n' %(i+1), generated_summary 55 | print '*' * 150 56 | return 57 | 58 | # Utility code to save model to disk 59 | def save_model(net, optimizer,all_summaries, article_string, abs_string): 60 | save_dict = dict({'model': net.state_dict(), 'optim': optimizer.state_dict(), 'epoch': dl.epoch, 'iter':dl.iterInd, 'summaries':all_summaries, 'article':article_string, 'abstract_gold':abs_string}) 61 | print '\n','-' * 60 62 | print 'Saving Model to : ', opt.save_dir 63 | save_name = opt.save_dir + 'savedModel_E%d_%d.pth' % (dl.epoch, dl.iterInd) 64 | torch.save(save_dict, save_name) 65 | print '-' * 60 66 | return 67 | 68 | 69 | 70 | assert opt.trunc_vocab <= 50000, 'Invalid value for --truncate-vocab' 71 | assert os.path.isfile(opt.vocab_file), 'Invalid Path to vocabulary file' 72 | with open(opt.vocab_file) as f: 73 | vocab = pickle.load(f) #list of tuples of word,count. Convert to list of words 74 | vocab = [item[0] for item in vocab[:-(5+ 50000 - opt.trunc_vocab)]] # Truncate vocabulary to conserve memory 75 | vocab += ['', '', '', '', ''] # add special token to vocab to bring total count to 50k 76 | 77 | dl = dataloader.dataloader(opt.batchSize, None, vocab, opt.train_file, opt.test_file, 78 | opt.max_article_size, opt.max_abstract_size, test_mode=True) 79 | 80 | 81 | wordEmbed = torch.nn.Embedding(len(vocab) + 1, opt.embedSize, 0) 82 | print 'Building SummaryNet...' 83 | net = models.SummaryNet(opt.embedSize, opt.hiddenSize, dl.vocabSize, wordEmbed, 84 | start_id=dl.word2id[''], stop_id=dl.word2id[''], unk_id=dl.word2id[''], 85 | max_decode=opt.max_decode, beam_size=opt.beam_size, lmbda=opt.lmbda) 86 | net = net.cuda() 87 | 88 | print 'Loading weights from file...might take a minute...' 
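# NOTE: the checkpoint written by save_model() in main.py is a dict with keys
# 'model', 'optim', 'epoch', 'iter', 'summaries', 'article' and 'abstract_gold';
# only the 'model' state_dict is needed for evaluation here.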
89 | saved_file = torch.load(opt.load_model) 90 | net.load_state_dict(saved_file['model']) 91 | print '\n','*'*30, 'LOADED WEIGHTS FROM MODEL FILE : %s' %opt.load_model,'*'*30 92 | 93 | ############################################################################################ 94 | # Set model to eval mode 95 | ############################################################################################ 96 | net.eval() 97 | print '\n\n' 98 | 99 | # Run x times to get x random test data samples for output 100 | for _ in range(5): 101 | # If article file provided 102 | if opt.article_path is not None and os.path.isfile(opt.article_path): 103 | with open(opt.article_path,'r') as f: 104 | article_string = f.read().strip() 105 | article_tokenized = word_tokenize(article_string) 106 | _article, _revArticle, _extArticle, max_article_oov, article_oov = dl.getInputTextSample(article_tokenized) 107 | abs_string = '**No abstract available**' 108 | else: 109 | # pull random test sample 110 | data_batch = dl.getEvalSample() 111 | _article, _revArticle, _extArticle, max_article_oov, article_oov, article_string, abs_string = dl.getEvalSample() 112 | 113 | _article = Variable(_article.cuda(), volatile=True) 114 | _extArticle = Variable(_extArticle.cuda(), volatile=True) 115 | _revArticle = Variable(_revArticle.cuda(), volatile=True) 116 | all_summaries = net((_article, _revArticle, _extArticle), max_article_oov, decode_flag=True) 117 | 118 | displayOutput(all_summaries, article_string, abs_string, article_oov, show_ground_truth=opt.print_ground_truth) 119 | 120 | 121 | -------------------------------------------------------------------------------- /finished_files/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hashbangCoder/Text-Summarization/dd0105a225d65ed9f7925e788036dbd9e65ca0f6/finished_files/.placeholder -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import cPickle as pickle 4 | import argparse 5 | import pdb, os 6 | import numpy as np 7 | import models 8 | from torch.nn.utils import clip_grad_norm 9 | from tqdm import tqdm 10 | import dataloader 11 | from visdom import Visdom 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--train-file", dest="train_file", help="Path to train datafile", default='finished_files/train.bin', type=str) 16 | parser.add_argument("--test-file", dest="test_file", help="Path to test/eval datafile", default='finished_files/test.bin', type=str) 17 | parser.add_argument("--vocab-file", dest="vocab_file", help="Path to vocabulary datafile", default='finished_files/vocabulary.bin', type=str) 18 | 19 | parser.add_argument("--max-abstract-size", dest="max_abstract_size", help="Maximum size of abstract for decoder input", default=110, type=int) 20 | parser.add_argument("--max-article-size", dest="max_article_size", help="Maximum size of article for encoder input", default=300, type=int) 21 | parser.add_argument("--num-epochs", dest="epochs", help="Number of epochs", default=10, type=int) 22 | parser.add_argument("--batch-size", dest="batchSize", help="Mini-batch size", default=32, type=int) 23 | parser.add_argument("--embed-size", dest="embedSize", help="Size of word embedding", default=300, type=int) 24 | parser.add_argument("--hidden-size", dest="hiddenSize", help="Size of hidden to 
model", default=128, type=int) 25 | 26 | parser.add_argument("--learning-rate", dest="lr", help="Learning Rate", default=0.001, type=float) 27 | parser.add_argument("--lambda", dest="lmbda", help="Hyperparameter for auxillary cost", default=1, type=float) 28 | parser.add_argument("--beam-size", dest="beam_size", help="beam size for beam search decoding", default=4, type=int) 29 | parser.add_argument("--max-decode", dest="max_decode", help="Maximum length of decoded output", default=40, type=int) 30 | parser.add_argument("--grad-clip", dest="grad_clip", help="Clip gradients of RNN model", default=2, type=float) 31 | parser.add_argument("--truncate-vocab", dest="trunc_vocab", help="size of truncated Vocabulary <= 50000 [to save memory]", default=50000, type=int) 32 | parser.add_argument("--bootstrap", dest="bootstrap", help="Bootstrap word embeds with GloVe?", default=0, type=int) 33 | parser.add_argument("--print-ground-truth", dest="print_ground_truth", help="Print the article and abstract", default=1, type=int) 34 | 35 | parser.add_argument("--eval-freq", dest="eval_freq", help="How frequently (every mini-batch) to evaluate model", default=20000, type=int) 36 | parser.add_argument("--save-dir", dest="save_dir", help="Directory to save trained models", default='Saved-Models/', type=str) 37 | parser.add_argument("--load-model", dest="load_model", help="Directory from which to load trained models", default=None, type=str) 38 | 39 | opt = parser.parse_args() 40 | vis = Visdom() 41 | 42 | 43 | ### evaluation code 44 | def evalModel(model): 45 | # set model to eval mode 46 | model.eval() 47 | print '\n\n' 48 | print '*'*30, ' MODEL EVALUATION ', '*'*30 49 | 50 | _article, _revArticle, _extArticle, max_article_oov, article_oov, article_string, abs_string = dl.getEvalBatch() 51 | _article = Variable(_article.cuda(), volatile=True) 52 | _extArticle = Variable(_extArticle.cuda(), volatile=True) 53 | _revArticle = Variable(_revArticle.cuda(), volatile=True) 54 | all_summaries = model((_article, _revArticle, _extArticle), max_article_oov, decode_flag=True) 55 | model.train() 56 | return all_summaries, article_string, abs_string, article_oov 57 | 58 | ### utility code for displaying generated abstract 59 | def displayOutput(all_summaries, article, abstract, article_oov, show_ground_truth=False): 60 | print '*' * 80 61 | print '\n' 62 | if show_ground_truth: 63 | print 'ARTICLE TEXT : \n', article 64 | print 'ACTUAL ABSTRACT : \n', abstract 65 | for i, summary in enumerate(all_summaries): 66 | generated_summary = ' '.join([dl.id2word[ind] if ind<=dl.vocabSize else article_oov[ind % dl.vocabSize] for ind in summary]) 67 | print 'GENERATED ABSTRACT #%d : \n' %(i+1), generated_summary 68 | print '*' * 80 69 | return 70 | 71 | # Utility code to save model to disk 72 | def save_model(net, optimizer,all_summaries, article_string, abs_string): 73 | save_dict = dict({'model': net.state_dict(), 'optim': optimizer.state_dict(), 'epoch': dl.epoch, 'iter':dl.iterInd, 'summaries':all_summaries, 'article':article_string, 'abstract_gold':abs_string}) 74 | print '\n','-' * 60 75 | print 'Saving Model to : ', opt.save_dir 76 | save_name = opt.save_dir + 'savedModel_E%d_%d.pth' % (dl.epoch, dl.iterInd) 77 | torch.save(save_dict, save_name) 78 | print '-' * 60 79 | return 80 | 81 | 82 | 83 | assert opt.trunc_vocab <= 50000, 'Invalid value for --truncate-vocab' 84 | assert os.path.isfile(opt.vocab_file), 'Invalid Path to vocabulary file' 85 | with open(opt.vocab_file) as f: 86 | vocab = pickle.load(f) #list of tuples of 
word,count. Convert to list of words 87 | vocab = [item[0] for item in vocab[:-(5+ 50000 - opt.trunc_vocab)]] # Truncate vocabulary to conserve memory 88 | vocab += ['', '', '', '', ''] # add special token to vocab to bring total count to 50k 89 | 90 | dl = dataloader.dataloader(opt.batchSize, opt.epochs, vocab, opt.train_file, opt.test_file, 91 | opt.max_article_size, opt.max_abstract_size) 92 | 93 | 94 | if opt.bootstrap: 95 | # bootstrap with pretrained embeddings 96 | wordEmbed = torch.nn.Embedding(len(vocab) + 1, 300, 0) 97 | print 'Bootstrapping with pretrained GloVe word vectors...' 98 | assert os.path.isfile('embeds.pkl'), 'Cannot find pretrained Word embeddings to bootstrap' 99 | with open('embeds.pkl', 'rb') as f: 100 | embeds = pickle.load(f) 101 | assert wordEmbed.weight.size() == embeds.size() 102 | wordEmbed.weight.data[1:,:] = embeds 103 | else: 104 | # learn embeddings from scratch (default) 105 | wordEmbed = torch.nn.Embedding(len(vocab) + 1, opt.embedSize, 0) 106 | 107 | print 'Building and initializing SummaryNet...' 108 | net = models.SummaryNet(opt.embedSize, opt.hiddenSize, dl.vocabSize, wordEmbed, 109 | start_id=dl.word2id[''], stop_id=dl.word2id[''], unk_id=dl.word2id[''], 110 | max_decode=opt.max_decode, beam_size=opt.beam_size, lmbda=opt.lmbda) 111 | net = net.cuda() 112 | optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr) 113 | 114 | if opt.load_model is not None and os.path.isfile(opt.load_model): 115 | saved_file = torch.load(opt.load_model) 116 | net.load_state_dict(saved_file['model']) 117 | optimizer.load_state_dict(saved_file['optim']) 118 | dl.epoch = saved_file['epoch'] 119 | dl.iterInd = saved_file['iter'] 120 | dl.pbar.update(dl.iterInd) 121 | print '\n','*'*30, 'RESUME FROM CHECKPOINT : %s' %opt.load_model,'*'*30 122 | 123 | else: 124 | print '\n','*'*30, 'START TRAINING','*'*30 125 | 126 | #dl.iterInd = 287226 127 | #dl.pbar.update(dl.iterInd) 128 | all_loss = [] 129 | win = None 130 | ### Training loop' 131 | while dl.epoch <= opt.epochs: 132 | data_batch = dl.getBatch(opt.batchSize) 133 | batchArticles, batchExtArticles, batchRevArticles, batchAbstracts, batchTargets, _, _, max_article_oov, article_oov = data_batch 134 | # end of training/max epoch reached 135 | if data_batch is None: 136 | print '-'*50, 'END OF TRAINING', '-'*50 137 | break 138 | 139 | batchArticles = Variable(batchArticles.cuda()) 140 | batchExtArticles = Variable(batchExtArticles.cuda()) 141 | batchRevArticles = Variable(batchRevArticles.cuda()) 142 | batchTargets = Variable(batchTargets.cuda()) 143 | batchAbstracts = Variable(batchAbstracts.cuda()) 144 | 145 | losses = net((batchArticles, batchExtArticles, batchRevArticles, batchAbstracts, batchTargets), max_article_oov) 146 | batch_loss = losses.mean() 147 | 148 | batch_loss.backward() 149 | # gradient clipping by norm 150 | clip_grad_norm(net.parameters(), opt.grad_clip) 151 | optimizer.step() 152 | optimizer.zero_grad() 153 | 154 | # update loss ticker 155 | dl.pbar.set_postfix(loss=batch_loss.cpu().data[0]) 156 | dl.pbar.update(opt.batchSize) 157 | 158 | # save losses periodically 159 | if dl.iterInd % 50: 160 | all_loss.append(batch_loss.cpu().data.tolist()[0]) 161 | title = 'Pointer Model with Coverage' 162 | if win is None: 163 | win = vis.line(Y=np.array(all_loss), X=np.arange(1, len(all_loss)+1), opts=dict(title=title, xlabel='#Mini-Batches (x%d)' %(opt.batchSize), 164 | ylabel='Train-Loss')) 165 | vis.line(Y=np.array(all_loss), X=np.arange(1, len(all_loss)+1), win=win, update='replace', opts=dict(title=title, 
xlabel='#Mini-Batches (x%d)' %(opt.batchSize), 166 | ylabel='Train-Loss')) 167 | 168 | # evaluate model periodically 169 | if dl.iterInd % opt.eval_freq < opt.batchSize and dl.iterInd > opt.batchSize: 170 | all_summaries, article_string, abs_string, article_oov = evalModel(net) 171 | displayOutput(all_summaries, article_string, abs_string, article_oov, show_ground_truth=opt.print_ground_truth) 172 | 173 | #if dl.epoch > 1 and dl.iterInd == 0: 174 | if dl.iterInd % (6*opt.eval_freq) < opt.batchSize and dl.iterInd > opt.batchSize: 175 | save_model(net, optimizer, all_summaries, article_string, abs_string) 176 | 177 | del batch_loss, batchArticles, batchExtArticles, batchRevArticles, batchAbstracts, batchTargets 178 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.nn import LSTM, GRU, Linear, LSTMCell, Module 4 | import torch.nn.functional as F 5 | import pdb 6 | 7 | 8 | # idea similar to https://github.com/abisee/pointer-generator/blob/master/beam_search.py 9 | class Hypothesis(object): 10 | def __init__(self, token_id, hidden_state, cell_state, log_prob): 11 | self._h = hidden_state 12 | self._c = cell_state 13 | self.log_prob = log_prob 14 | self.full_prediction = token_id # list 15 | self.survivability = self.log_prob/ float(len(self.full_prediction)) 16 | 17 | def extend(self, token_id, hidden_state, cell_state, log_prob): 18 | return Hypothesis(token_id= self.full_prediction + [token_id], 19 | hidden_state=hidden_state, 20 | cell_state=cell_state, 21 | log_prob= self.log_prob + log_prob) 22 | 23 | # encoder net for the article 24 | class Encoder(Module): 25 | def __init__(self, input_size, hidden_size, wordEmbed): 26 | super(Encoder,self).__init__() 27 | self.input_size = input_size 28 | self.hidden_size = hidden_size 29 | 30 | self.word_embed = wordEmbed 31 | self.fwd_rnn = LSTM(self.input_size, self.hidden_size, batch_first=True) 32 | self.bkwd_rnn = LSTM(self.input_size, self.hidden_size, batch_first=True) 33 | self.output_cproj = Linear(self.hidden_size * 2, self.hidden_size) 34 | self.output_hproj = Linear(self.hidden_size * 2, self.hidden_size) 35 | 36 | def forward(self, _input, rev_input): 37 | batch_size, max_len = _input.size(0), _input.size(1) 38 | embed_fwd = self.word_embed(_input) 39 | embed_rev = self.word_embed(rev_input) 40 | 41 | # get mask for location of PAD 42 | mask = _input.eq(0).detach() 43 | 44 | fwd_out, fwd_state = self.fwd_rnn(embed_fwd) 45 | bkwd_out, bkwd_state = self.bkwd_rnn(embed_rev) 46 | hidden_cat = torch.cat((fwd_out, bkwd_out), 2) 47 | 48 | # inverse of mask 49 | inv_mask = mask.eq(0).unsqueeze(2).expand(batch_size, max_len, self.hidden_size * 2).float().detach() 50 | hidden_out = hidden_cat * inv_mask 51 | final_hidd_proj = self.output_hproj(torch.cat((fwd_state[0].squeeze(0), bkwd_state[0].squeeze(0)), 1)) 52 | final_cell_proj = self.output_cproj(torch.cat((fwd_state[1].squeeze(0), bkwd_state[1].squeeze(0)), 1)) 53 | 54 | return hidden_out, final_hidd_proj, final_cell_proj, mask 55 | 56 | 57 | 58 | # TODO Enhancement: Project input embedding with previous context vector for current input 59 | class PointerAttentionDecoder(Module): 60 | def __init__(self, input_size, hidden_size, vocab_size, wordEmbed): 61 | super(PointerAttentionDecoder, self).__init__() 62 | self.input_size = input_size 63 | self.hidden_size = hidden_size 64 | self.vocab_size = 
vocab_size 65 | self.word_embed = wordEmbed 66 | 67 | #self.decoderRNN = LSTMCell(self.input_size, self.hidden_size) 68 | self.decoderRNN = LSTM(self.input_size, self.hidden_size, batch_first=True) 69 | #params for attention 70 | self.Wh = Linear(2 * self.hidden_size, 2*self. hidden_size) 71 | self.Ws = Linear(self.hidden_size, 2*self.hidden_size) 72 | self.w_c = Linear(1, 2*self.hidden_size) 73 | self.v = Linear(2*self.hidden_size, 1) 74 | 75 | # parameters for p_gen 76 | self.w_h = Linear(2 * self.hidden_size, 1) # double due to concat of BiDi encoder states 77 | self.w_s = Linear(self.hidden_size, 1) 78 | self.w_x = Linear(self.input_size, 1) 79 | 80 | #params for output proj 81 | 82 | self.V = Linear(self.hidden_size * 3, self.vocab_size) 83 | self.min_length = 40 84 | 85 | def setValues(self, start_id, stop_id, unk_id, beam_size, max_decode=40, lmbda=1): 86 | # start/stop tokens 87 | self.start_id = start_id 88 | self.stop_id = stop_id 89 | self.unk_id = unk_id 90 | self.max_decode_steps = max_decode 91 | # max_article_oov -> max number of OOV in articles i.e. enc inputs. Will be set for each batch individually 92 | self.max_article_oov = None 93 | self.beam_size = beam_size 94 | self.lmbda = lmbda 95 | 96 | 97 | def forward(self, enc_states, enc_final_state, enc_mask, _input, article_inds, targets, decode=False): 98 | # enc_states -> output states of encoder 99 | # enc_final_state -> final output of encoder 100 | # enc_mask -> mask indicating location of PAD in encoder input 101 | # _input -> decoder inputs 102 | # article_inds -> modified encoder input with temporary OOV ids for each OOV token 103 | # targets -> decoder targets 104 | # decode -> Boolean flag for train/eval mode 105 | 106 | if decode is True: 107 | return self.decode(enc_states, enc_final_state, enc_mask, article_inds) 108 | 109 | 110 | batch_size, max_enc_len, enc_size = enc_states.size() 111 | max_dec_len = _input.size(1) 112 | # coverage initially zero 113 | coverage = Variable(torch.zeros(batch_size, max_enc_len).cuda()) 114 | dec_lens = (_input > 0).float().sum(1) 115 | state = enc_final_state[0].unsqueeze(0),enc_final_state[1].unsqueeze(0) 116 | 117 | enc_proj = self.Wh(enc_states.view(batch_size*max_enc_len, enc_size)).view(batch_size, max_enc_len, -1) 118 | embed_input = self.word_embed(_input) 119 | 120 | lm_loss, cov_loss = [], [] 121 | hidden, _ = self.decoderRNN(embed_input, state) 122 | 123 | # step through decoder hidden states 124 | for _step in range(max_dec_len): 125 | _h = hidden[:, _step, :] 126 | target = targets[:, _step].unsqueeze(1) 127 | 128 | dec_proj = self.Ws(_h).unsqueeze(1).expand_as(enc_proj) 129 | cov_proj = self.w_c(coverage.view(-1, 1)).view(batch_size, max_enc_len, -1) 130 | e_t = self.v(F.tanh(enc_proj + dec_proj + cov_proj).view(batch_size*max_enc_len, -1)) 131 | 132 | # mask to -INF before applying softmax 133 | attn_scores = e_t.view(batch_size, max_enc_len) 134 | del e_t 135 | attn_scores.masked_fill_(enc_mask, -float('inf')) 136 | attn_scores = F.softmax(attn_scores) 137 | 138 | context = attn_scores.unsqueeze(1).bmm(enc_states).squeeze(1) 139 | p_vocab = F.softmax(self.V(torch.cat((_h, context), 1))) #output proj calculation 140 | p_gen = F.sigmoid(self.w_h(context) + self.w_s(_h) + self.w_x(embed_input[:, _step, :])) # p_gen calculation 141 | p_gen = p_gen.view(-1, 1) 142 | weighted_Pvocab = p_gen * p_vocab 143 | weighted_attn = (1-p_gen)* attn_scores 144 | 145 | if self.max_article_oov > 0: 146 | ext_vocab = Variable(torch.zeros(batch_size, self.max_article_oov).cuda()) 
#create OOV (but in-article) zero vectors 147 | combined_vocab = torch.cat((weighted_Pvocab, ext_vocab), 1) 148 | del ext_vocab 149 | else: 150 | combined_vocab = weighted_Pvocab 151 | 152 | del weighted_Pvocab 153 | assert article_inds.data.min() >=0 and article_inds.data.max() <= (self.vocab_size+ self.max_article_oov), 'Recheck OOV indexes!' 154 | 155 | # scatter article word probs to combined vocab prob. 156 | # subtract one to account for 0-index 157 | article_inds_masked = article_inds.add(-1).masked_fill_(enc_mask, 0) 158 | combined_vocab = combined_vocab.scatter_add(1, article_inds_masked, weighted_attn) 159 | 160 | # mask the output to account for PAD 161 | # subtract one from target for 0-index 162 | target_mask_0 = target.ne(0).detach() 163 | target_mask_p = target.eq(0).detach() 164 | target = target - 1 165 | output = combined_vocab.gather(1, target.masked_fill_(target_mask_p, 0)) 166 | lm_loss.append(output.log().mul(-1) * target_mask_0.float()) 167 | 168 | coverage = coverage + attn_scores 169 | 170 | # Coverage Loss 171 | # take minimum across both attn_scores and coverage 172 | _cov_loss, _ = torch.stack((coverage, attn_scores), 2).min(2) 173 | cov_loss.append(_cov_loss.sum(1)) 174 | 175 | # add individual losses 176 | total_masked_loss = torch.cat(lm_loss, 1).sum(1).div(dec_lens) + self.lmbda*torch.stack(cov_loss, 1).sum(1).div(dec_lens) 177 | return total_masked_loss 178 | 179 | def decode_step(self, enc_states, state, _input, enc_mask, article_inds): 180 | # decode for one step with beam search 181 | # for first step, batch_size =1 182 | # successive steps batch_size = beam_size 183 | batch_size, max_enc_len, enc_size = enc_states.size() 184 | 185 | # coverage initially zero 186 | coverage = Variable(torch.zeros(batch_size, max_enc_len).cuda()) 187 | 188 | enc_proj = self.Wh(enc_states.view(batch_size*max_enc_len, enc_size)).view(batch_size, max_enc_len, -1) 189 | embed_input = self.word_embed(_input) 190 | 191 | _h, _c = self.decoderRNN(embed_input, state)[1] 192 | _h = _h.squeeze(0) 193 | dec_proj = self.Ws(_h).unsqueeze(1).expand_as(enc_proj) 194 | cov_proj = self.w_c(coverage.view(-1, 1)).view(batch_size, max_enc_len, -1) 195 | e_t = self.v(F.tanh(enc_proj + dec_proj + cov_proj).view(batch_size*max_enc_len, -1)) 196 | attn_scores = e_t.view(batch_size, max_enc_len) 197 | del e_t 198 | attn_scores.masked_fill_(enc_mask, -float('inf')) 199 | attn_scores = F.softmax(attn_scores) 200 | 201 | context = attn_scores.unsqueeze(1).bmm(enc_states) 202 | p_vocab = F.softmax(self.V(torch.cat((_h, context.squeeze(1)), 1))) # output proj calculation 203 | p_gen = F.sigmoid(self.w_h(context.squeeze(1)) + self.w_s(_h) + self.w_x(embed_input[:, 0, :])) # p_gen calculation 204 | p_gen = p_gen.view(-1, 1) 205 | weighted_Pvocab = p_gen * p_vocab 206 | weighted_attn = (1-p_gen)* attn_scores 207 | 208 | if self.max_article_oov > 0: 209 | ext_vocab = Variable(torch.zeros(batch_size, self.max_article_oov).cuda()) # create OOV (but in-article) zero vectors 210 | combined_vocab = torch.cat((weighted_Pvocab, ext_vocab), 1) 211 | del ext_vocab 212 | else: 213 | combined_vocab = weighted_Pvocab 214 | assert article_inds.data.min() >=0 and article_inds.data.max() <= (self.vocab_size+ self.max_article_oov), 'Recheck OOV indexes!' 215 | 216 | # scatter article word probs to combined vocab prob. 
217 | # subtract one to account for 0-index 218 | combined_vocab = combined_vocab.scatter_add(1, article_inds.add(-1), weighted_attn) 219 | 220 | return combined_vocab, _h, _c.squeeze(0) 221 | 222 | 223 | def getOverallTopk(self, vocab_probs, _h, _c, all_hyps, results): 224 | # return top-k values i.e. top-k over all beams i.e. next step input ids 225 | # return hidden, cell states corresponding to topk 226 | probs, inds = vocab_probs.topk(k=self.beam_size, dim=1) 227 | probs = probs.log().data 228 | inds = inds.data 229 | inds.add_(1) 230 | candidates = [] 231 | assert len(all_hyps) == probs.size(0), '# Hypothesis and log-prob size dont match' 232 | # cycle through all hypothesis in full beam 233 | for i, hypo in enumerate(probs.tolist()): 234 | for j, _ in enumerate(hypo): 235 | new_cand = all_hyps[i].extend(token_id=inds[i,j], 236 | hidden_state=_h[i].unsqueeze(0), 237 | cell_state=_c[i].unsqueeze(0), 238 | log_prob= probs[i,j]) 239 | candidates.append(new_cand) 240 | # sort in descending order 241 | candidates = sorted(candidates, key=lambda x:x.survivability, reverse=True) 242 | new_beam, next_inp = [], [] 243 | next_h, next_c = [], [] 244 | #prune hypotheses and generate new beam 245 | for h in candidates: 246 | if h.full_prediction[-1] == self.stop_id: 247 | # weed out small sentences that likely have no meaning 248 | if len(h.full_prediction)>=self.min_length: 249 | results.append(h.full_prediction) 250 | else: 251 | new_beam.append(h) 252 | next_inp.append(h.full_prediction[-1]) 253 | next_h.append(h._h.data) 254 | next_c.append(h._c.data) 255 | if len(new_beam) >= self.beam_size: 256 | break 257 | assert len(new_beam) >= 1, 'Non-existent beam' 258 | return new_beam, torch.LongTensor([next_inp]), results, torch.cat(next_h, 0), torch.cat(next_c, 0) 259 | 260 | # Beam Search Decoding 261 | def decode(self, enc_states, enc_final_state, enc_mask, article_inds): 262 | _input = Variable(torch.LongTensor([[self.start_id]]).cuda(), volatile=True) 263 | init_state = enc_final_state[0].unsqueeze(0),enc_final_state[1].unsqueeze(0) 264 | decoded_outputs = [] 265 | # all_hyps --> list of current beam hypothesis. 
start with base initial hypothesis 266 | all_hyps = [Hypothesis([self.start_id], None, None, 0)] 267 | # start decoding 268 | for _step in range(self.max_decode_steps): 269 | # ater first step, input is of batch_size=curr_beam_size 270 | # curr_beam_size <= self.beam_size due to pruning of beams that have terminated 271 | # adjust enc_states and init_state accordingly 272 | curr_beam_size = _input.size(0) 273 | beam_enc_states = enc_states.expand(curr_beam_size, enc_states.size(1), enc_states.size(2)).contiguous().detach() 274 | beam_article_inds = article_inds.expand(curr_beam_size, article_inds.size(1)).detach() 275 | 276 | vocab_probs, next_h, next_c = self.decode_step(beam_enc_states, init_state, _input, enc_mask, beam_article_inds) 277 | 278 | # does bulk of the beam search 279 | # decoded_outputs --> list of all ouputs terminated with stop tokens and of minimal length 280 | all_hyps, decode_inds, decoded_outputs, init_h, init_c = self.getOverallTopk(vocab_probs, next_h, next_c, all_hyps, decoded_outputs) 281 | 282 | # convert OOV words to unk tokens for lookup 283 | decode_inds.masked_fill_((decode_inds > self.vocab_size), self.unk_id) 284 | decode_inds = decode_inds.t() 285 | _input = Variable(decode_inds.cuda(), volatile=True) 286 | init_state = (Variable(init_h.unsqueeze(0), volatile=True), Variable(init_c.unsqueeze(0), volatile=True)) 287 | 288 | 289 | 290 | non_terminal_output = [item.full_prediction for item in all_hyps] 291 | all_outputs = decoded_outputs + non_terminal_output 292 | return all_outputs 293 | 294 | 295 | class SummaryNet(Module): 296 | def __init__(self, input_size, hidden_size, vocab_size, wordEmbed, start_id, stop_id, unk_id, beam_size=4, max_decode=40, lmbda=1): 297 | super(SummaryNet, self).__init__() 298 | self.input_size = input_size 299 | self.hidden_size = hidden_size 300 | 301 | self.encoder = Encoder(self.input_size, self.hidden_size, wordEmbed) 302 | self.pointerDecoder = PointerAttentionDecoder(self.input_size, self.hidden_size, vocab_size, wordEmbed) 303 | self.pointerDecoder.setValues(start_id, stop_id, unk_id, beam_size, max_decode, lmbda) 304 | 305 | def forward(self, _input, max_article_oov, decode_flag=False): 306 | # set num article OOVs in decoder 307 | self.pointerDecoder.max_article_oov = max_article_oov 308 | # decode/eval code 309 | if decode_flag: 310 | enc_input, rev_enc_input, article_inds = _input 311 | enc_states, enc_hn, enc_cn, enc_mask = self.encoder(enc_input, rev_enc_input) 312 | model_summary = self.pointerDecoder(enc_states, (enc_hn, enc_cn), enc_mask, None, article_inds, targets=None, decode=True) 313 | return model_summary 314 | 315 | else: 316 | # train code 317 | enc_input, article_inds, rev_enc_input, dec_input, dec_target = _input 318 | enc_states, enc_hn, enc_cn, enc_mask = self.encoder(enc_input, rev_enc_input) 319 | 320 | total_loss = self.pointerDecoder(enc_states, (enc_hn, enc_cn), enc_mask, dec_input, article_inds, targets=dec_target) 321 | return total_loss 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import hashlib 4 | import struct 5 | import subprocess 6 | import collections 7 | import tensorflow as tf 8 | from tensorflow.core.example import example_pb2 9 | import cPickle as pickle 10 | import pdb 11 | dm_single_close_quote = u'\u2019' # unicode 12 | dm_double_close_quote = u'\u201d' 13 | END_TOKENS = 
['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence 14 | 15 | # We use these to separate the summary sentences in the .bin datafiles 16 | SENTENCE_START = '' 17 | SENTENCE_END = '' 18 | 19 | all_train_urls = "url_lists/all_train.txt" 20 | all_val_urls = "url_lists/all_val.txt" 21 | all_test_urls = "url_lists/all_test.txt" 22 | 23 | cnn_tokenized_stories_dir = "cnn_stories_tokenized" 24 | dm_tokenized_stories_dir = "dm_stories_tokenized" 25 | finished_files_dir = "finished_files" 26 | chunks_dir = os.path.join(finished_files_dir, "chunked") 27 | 28 | # These are the number of .story files we expect there to be in cnn_stories_dir and dm_stories_dir 29 | num_expected_cnn_stories = 92579 30 | num_expected_dm_stories = 219506 31 | 32 | VOCAB_SIZE = 50000 33 | CHUNK_SIZE = 1000 # num examples per chunk, for the chunked data 34 | 35 | 36 | def chunk_file(set_name): 37 | in_file = 'finished_files/%s.bin' % set_name 38 | reader = open(in_file, "rb") 39 | chunk = 0 40 | finished = False 41 | while not finished: 42 | chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (set_name, chunk)) # new chunk 43 | with open(chunk_fname, 'wb') as writer: 44 | for _ in range(CHUNK_SIZE): 45 | len_bytes = reader.read(8) 46 | if not len_bytes: 47 | finished = True 48 | break 49 | str_len = struct.unpack('q', len_bytes)[0] 50 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 51 | writer.write(struct.pack('q', str_len)) 52 | writer.write(struct.pack('%ds' % str_len, example_str)) 53 | chunk += 1 54 | 55 | 56 | def chunk_all(): 57 | # Make a dir to hold the chunks 58 | if not os.path.isdir(chunks_dir): 59 | os.mkdir(chunks_dir) 60 | # Chunk the data 61 | for set_name in ['train', 'val', 'test']: 62 | print "Splitting %s data into chunks..." % set_name 63 | chunk_file(set_name) 64 | print "Saved chunked data in %s" % chunks_dir 65 | 66 | 67 | def tokenize_stories(stories_dir, tokenized_stories_dir): 68 | """Maps a whole directory of .story files to a tokenized version using Stanford CoreNLP Tokenizer""" 69 | print "Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir) 70 | stories = os.listdir(stories_dir) 71 | # make IO list file 72 | print "Making list of files to tokenize..." 73 | with open("mapping.txt", "w") as f: 74 | for s in stories: 75 | f.write("%s \t %s\n" % (os.path.join(stories_dir, s), os.path.join(tokenized_stories_dir, s))) 76 | command = ['java', 'edu.stanford.nlp.process.PTBTokenizer', '-ioFileList', '-preserveLines', 'mapping.txt'] 77 | print "Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir) 78 | subprocess.call(command) 79 | print "Stanford CoreNLP Tokenizer has finished." 80 | os.remove("mapping.txt") 81 | 82 | # Check that the tokenized stories directory contains the same number of files as the original directory 83 | num_orig = len(os.listdir(stories_dir)) 84 | num_tokenized = len(os.listdir(tokenized_stories_dir)) 85 | if num_orig != num_tokenized: 86 | raise Exception("The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" 
% (tokenized_stories_dir, num_tokenized, stories_dir, num_orig)) 87 | print "Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir) 88 | 89 | 90 | def read_text_file(text_file): 91 | lines = [] 92 | with open(text_file, "r") as f: 93 | for line in f: 94 | lines.append(line.strip()) 95 | return lines 96 | 97 | 98 | def hashhex(s): 99 | """Returns a heximal formated SHA1 hash of the input string.""" 100 | h = hashlib.sha1() 101 | h.update(s) 102 | return h.hexdigest() 103 | 104 | 105 | def get_url_hashes(url_list): 106 | return [hashhex(url) for url in url_list] 107 | 108 | 109 | def fix_missing_period(line): 110 | """Adds a period to a line that is missing a period""" 111 | if "@highlight" in line: return line 112 | if line=="": return line 113 | if line[-1] in END_TOKENS: return line 114 | # print line[-1] 115 | return line + " ." 116 | 117 | 118 | def get_art_abs(story_file): 119 | lines = read_text_file(story_file) 120 | 121 | # Lowercase everything 122 | lines = [line.lower() for line in lines] 123 | 124 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences) 125 | lines = [fix_missing_period(line) for line in lines] 126 | 127 | # Separate out article and abstract sentences 128 | article_lines = [] 129 | highlights = [] 130 | next_is_highlight = False 131 | for idx,line in enumerate(lines): 132 | if line == "": 133 | continue # empty line 134 | elif line.startswith("@highlight"): 135 | next_is_highlight = True 136 | elif next_is_highlight: 137 | highlights.append(line) 138 | else: 139 | article_lines.append(line) 140 | 141 | # Make article into a single string 142 | article = ' '.join(article_lines) 143 | 144 | # Make abstract into a signle string, putting and tags around the sentences 145 | abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) 146 | 147 | return article, abstract 148 | 149 | 150 | def write_to_bin(url_file, out_file, makevocab=False): 151 | """Reads the tokenized .story files corresponding to the urls listed in the url_file and writes them to a out_file.""" 152 | print "Making bin file for URLs listed in %s..." % url_file 153 | url_list = read_text_file(url_file) 154 | url_hashes = get_url_hashes(url_list) 155 | story_fnames = [s+".story" for s in url_hashes] 156 | num_stories = len(story_fnames) 157 | 158 | if makevocab: 159 | vocab_counter = collections.Counter() 160 | 161 | data_file = {} 162 | for idx,s in enumerate(story_fnames): 163 | if idx % 1000 == 0: 164 | print "Writing story %i of %i; %.2f percent done" % (idx, num_stories, float(idx)*100.0/float(num_stories)) 165 | 166 | # Look in the tokenized story dirs to find the .story file corresponding to this url 167 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)): 168 | story_file = os.path.join(cnn_tokenized_stories_dir, s) 169 | elif os.path.isfile(os.path.join(dm_tokenized_stories_dir, s)): 170 | story_file = os.path.join(dm_tokenized_stories_dir, s) 171 | else: 172 | print "Error: Couldn't find tokenized story file %s in either tokenized story directories %s and %s. Was there an error during tokenization?" % (s, cnn_tokenized_stories_dir, dm_tokenized_stories_dir) 173 | # Check again if tokenized stories directories contain correct number of files 174 | print "Checking that the tokenized stories directories %s and %s contain correct number of files..." 
% (cnn_tokenized_stories_dir, dm_tokenized_stories_dir) 175 | check_num_stories(cnn_tokenized_stories_dir, num_expected_cnn_stories) 176 | check_num_stories(dm_tokenized_stories_dir, num_expected_dm_stories) 177 | raise Exception("Tokenized stories directories %s and %s contain correct number of files but story file %s found in neither." % (cnn_tokenized_stories_dir, dm_tokenized_stories_dir, s)) 178 | 179 | # Get the strings to write to .bin file 180 | article, abstract = get_art_abs(story_file) 181 | 182 | # Write to dict 183 | data_file[idx] = {'article':article, 'abstract':abstract} 184 | 185 | # Write to tf.Example 186 | # tf_example = example_pb2.Example() 187 | # tf_example.features.feature['article'].bytes_list.value.extend([article]) 188 | # tf_example.features.feature['abstract'].bytes_list.value.extend([abstract]) 189 | # tf_example_str = tf_example.SerializeToString() 190 | # str_len = len(tf_example_str) 191 | # writer.write(struct.pack('q', str_len)) 192 | # writer.write(struct.pack('%ds' % str_len, tf_example_str)) 193 | 194 | # Write the vocab to file, if applicable 195 | if makevocab: 196 | art_tokens = article.split(' ') 197 | abs_tokens = abstract.split(' ') 198 | abs_tokens = [t for t in abs_tokens if t not in [SENTENCE_START, SENTENCE_END]] # remove these tags from vocab 199 | tokens = art_tokens + abs_tokens 200 | tokens = [t.strip() for t in tokens] # strip 201 | tokens = [t for t in tokens if t!=""] # remove empty 202 | vocab_counter.update(tokens) 203 | 204 | with open(out_file, 'wb') as writer: 205 | pickle.dump(data_file, writer) 206 | print "Finished writing file : %s\n" % out_file 207 | 208 | # write vocab to file 209 | if makevocab: 210 | print "Writing vocab file..." 211 | with open(os.path.join(finished_files_dir, "vocab"), 'w') as writer: 212 | for word, count in vocab_counter.most_common(VOCAB_SIZE): 213 | writer.write(word + ' ' + str(count) + '\n') 214 | print "Finished writing vocab file" 215 | 216 | 217 | def check_num_stories(stories_dir, num_expected): 218 | num_stories = len(os.listdir(stories_dir)) 219 | if num_stories != num_expected: 220 | raise Exception("stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected)) 221 | 222 | 223 | if __name__ == '__main__': 224 | if len(sys.argv) != 3: 225 | print "USAGE: python make_datafiles.py " 226 | sys.exit() 227 | cnn_stories_dir = sys.argv[1] 228 | dm_stories_dir = sys.argv[2] 229 | 230 | # Check the stories directories contain the correct number of .story files 231 | check_num_stories(cnn_stories_dir, num_expected_cnn_stories) 232 | check_num_stories(dm_stories_dir, num_expected_dm_stories) 233 | 234 | # Create some new directories 235 | if not os.path.exists(cnn_tokenized_stories_dir): os.makedirs(cnn_tokenized_stories_dir) 236 | if not os.path.exists(dm_tokenized_stories_dir): os.makedirs(dm_tokenized_stories_dir) 237 | if not os.path.exists(finished_files_dir): os.makedirs(finished_files_dir) 238 | 239 | # Run stanford tokenizer on both stories dirs, outputting to tokenized stories directories 240 | #tokenize_stories(cnn_stories_dir, cnn_tokenized_stories_dir) 241 | #tokenize_stories(dm_stories_dir, dm_tokenized_stories_dir) 242 | 243 | # Read the tokenized stories, do a little postprocessing then write to bin files 244 | write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test.bin")) 245 | write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val.bin")) 246 | write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train.bin"), 
makevocab=True) 247 | 248 | # Chunk the data. This splits each of train.bin, val.bin and test.bin into smaller chunks, each containing e.g. 1000 examples, and saves them in finished_files/chunks 249 | #chunk_all() --------------------------------------------------------------------------------
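
As a quick sanity check of the output format, each `.bin` file written by `write_to_bin` is a pickled dict mapping an integer index to `{'article': ..., 'abstract': ...}` strings, which is exactly what `dataloader.py` expects. A minimal inspection sketch (assuming the default `finished_files` layout) could look like:

```python
# Sketch: inspect one sample from a .bin file written by write_to_bin().
import cPickle as pickle

with open('finished_files/test.bin', 'rb') as f:
    data = pickle.load(f)        # dict: index -> {'article': str, 'abstract': str}

print 'Number of samples :', len(data)
sample = data[0]
print 'ARTICLE  :', ' '.join(sample['article'].split(' ')[:50]), '...'
print 'ABSTRACT :', sample['abstract']
```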