├── README.md
├── checkpoint_models
│   └── README.md
├── data
│   └── README.md
├── mtl_learning.py
├── mtl_testing.py
├── pickles
│   ├── README.md
│   ├── ent_data.pkl
│   └── summ_data.pkl
└── reloads
    └── README.md

/README.md:
--------------------------------------------------------------------------------
1 | #### Information and Instructions about the code
2 | 
3 | - To run the code, Python 3.6 and PyTorch 0.2 are needed (the scripts also import numpy, nltk, and texttable)
4 | - The training code is in `mtl_learning.py`, and the testing code is in `mtl_testing.py`
5 | - The code in `mtl_learning.py` is largely self-documented
6 | - For training, the summarization and entailment datasets should be present as pickle files (`summ_data.pkl` and `ent_data.pkl`, each a list of (source, target) sentence pairs) in the `pickles/` directory, and the GloVe embeddings file (`glove.6B.300d.txt`) should be present in the `data/` directory
7 | - To reload saved models and continue training, place them in the `reloads/` directory and call `train()` with `pt_reload=True`
8 | - When all the necessary files and data are present, simply run `python mtl_learning.py` for training and `python mtl_testing.py` for testing (after pointing the checkpoint paths at the bottom of `mtl_testing.py` to your saved models)
9 | - The log output of both training and testing is appended to a file named `output.txt`
10 | 
11 | #### Credits
12 | 
13 | The project is inspired by Pasunuru et al.'s (2017) work on ["Towards Improving Abstractive Summarization via Entailment Generation"](http://www.aclweb.org/anthology/W17-4504) and by Sean Robertson's [tutorial](http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) on seq2seq translation.
14 | 
--------------------------------------------------------------------------------
/checkpoint_models/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quettabit/pytorch_mtl/086c08c5c9f0fae9479765f3f68900347686e0db/checkpoint_models/README.md
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quettabit/pytorch_mtl/086c08c5c9f0fae9479765f3f68900347686e0db/data/README.md
--------------------------------------------------------------------------------
/mtl_learning.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import glob
3 | import re
4 | import random
5 | import string
6 | import unicodedata
7 | from io import open
8 | 
9 | import numpy as np
10 | import pickle
11 | import torch
12 | import torch.nn as nn
13 | 
14 | from nltk import tokenize
15 | from torch import optim
16 | from torch.autograd import Variable
17 | import torch.nn.functional as F
18 | 
19 | USE_CUDA = False
20 | MAX_LENGTH = 52
21 | SPLIT_RATIOS = {'train': 80, 'validation': 10, 'test': 10}
22 | BATCH_SIZE = 32
23 | SOS_TOKEN = 0
24 | EOS_TOKEN = 1
25 | 
26 | class MetaData:
27 |     def __init__(self):
28 |         self.word_to_index = {"SOS": 0, "EOS": 1}
29 |         self.word_to_count = {}
30 |         self.index_to_word = {0: "SOS", 1: "EOS"}
31 |         self.num_words = 2  # Count SOS and EOS
32 |         self.max_len = -1
33 | 
34 |     def add_sentence(self, sentence):
35 |         for word in sentence.split(' '):
36 |             self.add_word(word)
37 | 
38 |     def add_word(self, word):
39 |         if word not in self.word_to_index:
40 |             self.word_to_index[word] = self.num_words
41 |             self.word_to_count[word] = 1
42 |             self.index_to_word[self.num_words] = word
43 |             self.num_words += 1
44 |         else:
45 |             self.word_to_count[word] += 1
46 | 
47 | 
48 | class EncoderRNN(nn.Module):
49 |     '''
50 |     RNN GRU Encoder
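    Embeds one input token at a time (the embedding matrix is initialised
    from the GloVe-based numpy matrix built below, with random vectors for
    out-of-vocabulary words), projects the embedding down to the hidden
    size with a linear layer, and runs it through a single-layer GRU.
    forward() handles a single time step and returns (output, hidden).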
''' 52 | def __init__(self, vocab_size, embedding_size, 53 | hidden_size, word_embeddings): 54 | super(EncoderRNN, self).__init__() 55 | self.vocab_size = vocab_size 56 | self.embedding_size = embedding_size 57 | self.hidden_size = hidden_size 58 | 59 | 60 | self.embedding = nn.Embedding(vocab_size, embedding_size) 61 | self.embedding.weight.data.copy_(torch.from_numpy(word_embeddings)) 62 | self.linear = nn.Linear(self.embedding_size, self.hidden_size) 63 | self.gru = nn.GRU(hidden_size, hidden_size) 64 | 65 | def forward(self, input, hidden): 66 | embedded = self.embedding(input).view(1, 1, -1) 67 | output = self.linear(embedded) 68 | output, hidden = self.gru(output, hidden) 69 | return output, hidden 70 | 71 | def init_hidden(self): 72 | result = Variable(torch.zeros(1, 1, self.hidden_size)) 73 | if USE_CUDA: 74 | return result.cuda() 75 | else: 76 | return result 77 | 78 | 79 | class AttnDecoderRNN(nn.Module): 80 | ''' 81 | RNN GRU Decoder with Attention 82 | ''' 83 | def __init__(self, vocab_size, embedding_size, 84 | hidden_size, word_embeddings, dropout_p=0.1, 85 | max_length=MAX_LENGTH): 86 | super(AttnDecoderRNN, self).__init__() 87 | self.embedding_size = embedding_size 88 | self.hidden_size = hidden_size 89 | self.vocab_size = vocab_size 90 | self.dropout_p = dropout_p 91 | self.max_length = max_length 92 | 93 | self.embedding = nn.Embedding(self.vocab_size, self.embedding_size) 94 | self.embedding.weight.data.copy_(torch.from_numpy(word_embeddings)) 95 | self.linear = nn.Linear(self.embedding_size, self.hidden_size) 96 | self.attn = nn.Linear(self.hidden_size * 2, self.max_length) 97 | self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size) 98 | self.dropout = nn.Dropout(self.dropout_p) 99 | self.gru = nn.GRU(self.hidden_size, self.hidden_size) 100 | self.out = nn.Linear(self.hidden_size, self.vocab_size) 101 | 102 | def forward(self, input, hidden, encoder_outputs): 103 | embedded = self.embedding(input).view(1, 1, -1) 104 | embedded = self.dropout(embedded) 105 | embedded = self.linear(embedded) 106 | 107 | attn_weights = F.softmax(self.attn(torch.cat((embedded[0], 108 | hidden[0]), 1))) 109 | attn_applied = torch.bmm(attn_weights.unsqueeze(0), 110 | encoder_outputs.unsqueeze(0)) 111 | 112 | output = torch.cat((embedded[0], attn_applied[0]), 1) 113 | output = self.attn_combine(output).unsqueeze(0) 114 | output = F.relu(output) 115 | 116 | output, hidden = self.gru(output, hidden) 117 | 118 | output = F.log_softmax(self.out(output[0])) 119 | 120 | return output, hidden, attn_weights 121 | 122 | def init_hidden(self): 123 | result = Variable(torch.zeros(1, 1, self.hidden_size)) 124 | if USE_CUDA: 125 | return result.cuda() 126 | else: 127 | return result 128 | 129 | def print_msg(msg): 130 | current_time = datetime.datetime.now() 131 | msg = "{:%D:%H:%M:%S} ---- ".format(current_time) + msg + "\n" 132 | with open('output.txt', 'a') as f: 133 | f.write(msg) 134 | 135 | ''' 136 | load summarization and entailment datasets from pkl file 137 | ''' 138 | summ_data = pickle.load(open('pickles/summ_data.pkl','rb')) 139 | ent_data = pickle.load(open('pickles/ent_data.pkl','rb')) 140 | print_msg('datasets loaded ..') 141 | 142 | random.shuffle(summ_data) 143 | random.shuffle(ent_data) 144 | print_msg('datasets shuffled ..') 145 | 146 | ''' 147 | load GloVe word embeddings 148 | ''' 149 | word_embeddings = {} 150 | with open('data/glove.6B.300d.txt', 'r') as f: 151 | for line in f: 152 | splits = line.split(' ') 153 | word = splits[0] 154 | embeds = splits[1:len(splits)] 
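        # each line of glove.6B.300d.txt is a word followed by its
        # space-separated vector components (300 floats); keep them
        # as a list of floats keyed by the word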
155 | embeds = [float(embed) for embed in embeds] 156 | word_embeddings[word] = embeds 157 | print_msg('embeddings loaded ..') 158 | 159 | ''' 160 | create unified vocabulary out of summarization dataset + entailment dataset 161 | ''' 162 | word_embedding_keys = set(list(word_embeddings.keys())) 163 | meta_data = MetaData() 164 | 165 | for pair in summ_data: 166 | meta_data.add_sentence(pair[0]) 167 | meta_data.add_sentence(pair[1]) 168 | 169 | for pair in ent_data: 170 | meta_data.add_sentence(pair[0]) 171 | meta_data.add_sentence(pair[1]) 172 | 173 | print_msg('meta data created ..') 174 | 175 | pickle.dump(meta_data, open('pickles/meta_data.pkl', 'wb')) 176 | 177 | print_msg('meta_data pickled ..') 178 | 179 | vocab_size = meta_data.num_words 180 | embedding_size = 300 181 | 182 | ''' 183 | merge embeddings - glove embedding if present; else a normal distribution 184 | ''' 185 | np_embeddings = np.ndarray(shape=(vocab_size, embedding_size)) 186 | for index in range(vocab_size): 187 | word = meta_data.index_to_word[index] 188 | if word in word_embedding_keys: 189 | np_embeddings[index] = word_embeddings[word] 190 | else: 191 | np_embeddings[index] = np.random.normal(0, 1, embedding_size) 192 | 193 | print_msg('numpy embedding matrix created ..') 194 | 195 | 196 | ''' 197 | helper functions to create pytorch autograd.Variables out of indexes 198 | in vocab mapped from/to input/output strings 199 | ''' 200 | def indexes_from_sentence(meta_data, data): 201 | return [meta_data.word_to_index[word] for word in data.split(' ')] 202 | 203 | 204 | def variable_from_sentence(meta_data, data): 205 | indexes = indexes_from_sentence(meta_data, data) 206 | indexes.append(EOS_TOKEN) 207 | result = Variable(torch.LongTensor(indexes).view(-1, 1)) 208 | if USE_CUDA: 209 | return result.cuda() 210 | else: 211 | return result 212 | 213 | 214 | def variables_from_data(data, meta_data): 215 | input_variable = variable_from_sentence(meta_data, data[0]) 216 | target_variable = variable_from_sentence(meta_data, data[1]) 217 | return (input_variable, target_variable) 218 | 219 | 220 | train_summ_data_len = int(len(summ_data)*SPLIT_RATIOS['train']/100) 221 | valid_summ_data_len = int(len(summ_data)*SPLIT_RATIOS['validation']/100) 222 | 223 | 224 | train_summ_data = summ_data[:train_summ_data_len] 225 | valid_summ_data = summ_data[train_summ_data_len: 226 | (train_summ_data_len + valid_summ_data_len)] 227 | test_summ_data = summ_data[(train_summ_data_len + valid_summ_data_len):] 228 | 229 | print_msg('train/valid/test datasets created ..') 230 | 231 | pickle.dump(train_summ_data, open('pickles/train_summ_data.pkl', 'wb')) 232 | pickle.dump(valid_summ_data, open('pickles/valid_summ_data.pkl', 'wb')) 233 | pickle.dump(test_summ_data, open('pickles/test_summ_data.pkl', 'wb')) 234 | 235 | print_msg('train/valid/test datasets pickled ..') 236 | 237 | train_summ_batches = [train_summ_data[x:x+BATCH_SIZE] 238 | for x in range(0, len(train_summ_data), BATCH_SIZE)] 239 | 240 | 241 | train_ent_batches = [ent_data[x:x+BATCH_SIZE] 242 | for x in range(0, len(ent_data), BATCH_SIZE)] 243 | 244 | print_msg('batch datasets created ..') 245 | 246 | 247 | def checkpoint(summ_encoder, ent_encoder, decoder, 248 | valid_loss): 249 | ''' 250 | saves the encoder and decoder objects 251 | ''' 252 | current_time = datetime.datetime.now() 253 | timestamp = "{:%D_%H_%M_%S}".format(current_time).replace('/','_') 254 | loss = str(valid_loss).split('.')[0] 255 | torch.save(summ_encoder, 256 | "checkpoint_models/summ_encoder_%s_%s" % (timestamp, 
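               # checkpoint filenames follow <model>_<MM_DD_YY_HH_MM_SS>_<integer part of
               # the validation loss>, which is the naming pattern the hard-coded
               # torch.load() calls in mtl_testing.py expect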
loss)) 257 | print_msg('summ_encoder model saved ..') 258 | torch.save(ent_encoder, 259 | "checkpoint_models/ent_encoder_%s_%s" % (timestamp, loss)) 260 | print_msg('ent_encoder model saved ..') 261 | torch.save(decoder, 262 | "checkpoint_models/decoder_%s_%s" % (timestamp, loss)) 263 | print_msg('decoder model saved ..') 264 | 265 | 266 | def evaluate(encoder, decoder, validation_set, 267 | meta_data, max_length=MAX_LENGTH): 268 | ''' 269 | evaluates the performance of the model on the validation set 270 | ''' 271 | 272 | eval_pairs = [variables_from_data(sample, meta_data) 273 | for sample in validation_set] 274 | criterion = nn.NLLLoss() 275 | avg_loss = 0 276 | 277 | for _, pair in enumerate(eval_pairs): 278 | 279 | input_variable = pair[0] 280 | target_variable = pair[1] 281 | input_length = input_variable.size()[0] 282 | target_length = target_variable.size()[0] 283 | encoder_hidden = encoder.init_hidden() 284 | 285 | encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size)) 286 | encoder_outputs = encoder_outputs.cuda() if USE_CUDA \ 287 | else encoder_outputs 288 | 289 | for ei in range(input_length): 290 | encoder_output, encoder_hidden = encoder(input_variable[ei], 291 | encoder_hidden) 292 | encoder_outputs[ei] = encoder_outputs[ei] + encoder_output[0][0] 293 | 294 | decoder_input = Variable(torch.LongTensor([[SOS_TOKEN]])) 295 | decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input 296 | 297 | decoder_hidden = encoder_hidden 298 | 299 | 300 | loss = 0 301 | 302 | for di in range(target_length): 303 | decoder_output,\ 304 | decoder_hidden,\ 305 | decoder_attention = decoder(decoder_input, decoder_hidden, 306 | encoder_outputs) 307 | topv, topi = decoder_output.data.topk(1) 308 | ni = topi[0][0] 309 | 310 | decoder_input = Variable(torch.LongTensor([[ni]])) 311 | decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input 312 | 313 | loss += criterion(decoder_output, target_variable[di]) 314 | if ni == EOS_TOKEN: 315 | break 316 | 317 | avg_loss += loss.data[0] / target_length 318 | 319 | 320 | return (avg_loss / len(eval_pairs)) 321 | 322 | 323 | def train_sample(input_variable, target_variable, encoder, 324 | decoder, encoder_optimizer, decoder_optimizer, 325 | criterion, teacher_forcing_ratio, max_length=MAX_LENGTH): 326 | ''' 327 | trains a single data sample 328 | ''' 329 | 330 | encoder_hidden = encoder.init_hidden() 331 | 332 | input_length = input_variable.size()[0] 333 | target_length = target_variable.size()[0] 334 | 335 | encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size)) 336 | encoder_outputs = encoder_outputs.cuda() if USE_CUDA else encoder_outputs 337 | 338 | loss = 0 339 | 340 | for ei in range(input_length): 341 | encoder_output, encoder_hidden = encoder(input_variable[ei], 342 | encoder_hidden) 343 | encoder_outputs[ei] = encoder_output[0][0] 344 | 345 | decoder_input = Variable(torch.LongTensor([[SOS_TOKEN]])) 346 | decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input 347 | 348 | decoder_hidden = encoder_hidden 349 | 350 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio \ 351 | else False 352 | 353 | if use_teacher_forcing: 354 | 355 | for di in range(target_length): 356 | decoder_output,\ 357 | decoder_hidden,\ 358 | decoder_attention = decoder(decoder_input, decoder_hidden, 359 | encoder_outputs) 360 | loss += criterion(decoder_output, target_variable[di]) 361 | decoder_input = target_variable[di] 362 | 363 | else: 364 | 365 | for di in range(target_length): 366 | 
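            # without teacher forcing: greedily take the decoder's own top
            # prediction as the next input, and stop early once EOS is emitted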
decoder_output,\ 367 | decoder_hidden,\ 368 | decoder_attention = decoder(decoder_input, decoder_hidden, 369 | encoder_outputs) 370 | topv, topi = decoder_output.data.topk(1) 371 | ni = topi[0][0] 372 | 373 | decoder_input = Variable(torch.LongTensor([[ni]])) 374 | decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input 375 | 376 | loss += criterion(decoder_output, target_variable[di]) 377 | if ni == EOS_TOKEN: 378 | break 379 | 380 | loss.backward() 381 | 382 | 383 | 384 | return loss.data[0] / target_length 385 | 386 | 387 | def train_batch(batch, meta_data, encoder, 388 | decoder, encoder_optimizer, 389 | decoder_optimizer, criterion): 390 | ''' 391 | trains a batch of data samples 392 | ''' 393 | encoder_optimizer.zero_grad() 394 | decoder_optimizer.zero_grad() 395 | teacher_forcing_ratio = 0.5 396 | loss = 0 397 | training_pairs = [variables_from_data(sample, meta_data) 398 | for sample in batch] 399 | 400 | for pair in training_pairs: 401 | 402 | loss += train_sample(pair[0], pair[1], 403 | encoder, decoder, 404 | encoder_optimizer, 405 | decoder_optimizer, 406 | criterion, 407 | teacher_forcing_ratio) 408 | 409 | 410 | encoder_optimizer.step() 411 | decoder_optimizer.step() 412 | 413 | return (loss / len(batch)) 414 | 415 | 416 | def train(train_summ_batches, valid_summ_data, train_ent_batches, 417 | np_embeddings, meta_data, num_epochs, 418 | pt_reload=False, switch=1, print_every=1, 419 | validate_every=1, learning_rate=0.005): 420 | 421 | ent_batches_len = 10 422 | hidden_size = 512 423 | 424 | summ_encoder = None 425 | ent_encoder = None 426 | decoder = None 427 | 428 | if pt_reload: 429 | summ_encoder = torch.load('reloads/summ_encoder.pt') 430 | ent_encoder = torch.load('reloads/ent_encoder.pt') 431 | decoder = torch.load('reloads/decoder.pt') 432 | print_msg('reloading of saved models done ..') 433 | else: 434 | summ_encoder = EncoderRNN(vocab_size, embedding_size, 435 | hidden_size, np_embeddings) 436 | ent_encoder = EncoderRNN(vocab_size, embedding_size, 437 | hidden_size, np_embeddings) 438 | decoder = AttnDecoderRNN(vocab_size, embedding_size, 439 | hidden_size, np_embeddings) 440 | 441 | 442 | if USE_CUDA: 443 | summ_encoder = summ_encoder.cuda() 444 | ent_encoder = ent_encoder.cuda() 445 | decoder = decoder.cuda() 446 | 447 | 448 | summ_encoder_optimizer = optim.Adam(summ_encoder.parameters(), 449 | lr=learning_rate) 450 | ent_encoder_optimizer = optim.Adam(ent_encoder.parameters(), 451 | lr=learning_rate) 452 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate) 453 | 454 | criterion = nn.NLLLoss() 455 | 456 | avg_train_loss_val = 0.0 457 | avg_train_loss_num = 0 458 | 459 | for epoch in range(num_epochs): 460 | for sbi, summ_batch in enumerate(train_summ_batches): 461 | 462 | try: 463 | 464 | loss = train_batch(summ_batch, meta_data, 465 | summ_encoder.train(), 466 | decoder.train(), 467 | summ_encoder_optimizer, 468 | decoder_optimizer, 469 | criterion) 470 | 471 | 472 | avg_train_loss_val += loss 473 | avg_train_loss_num += 1 474 | 475 | if (sbi+1) % print_every == 0: 476 | print_msg("epoch %s, batch %s, avg. loss so far is %s .." 
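                              # the reported figure is the running average of the per-batch
                              # training loss accumulated since the start of the run; the
                              # counters are never reset between epochs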
                              %
477 |                           (epoch, sbi,
478 |                            avg_train_loss_val/avg_train_loss_num ))
479 | 
480 | 
481 |                 if (sbi+1) % switch == 0:
482 | 
483 |                     sample_ent_batches = random.sample(train_ent_batches,
484 |                                                        ent_batches_len)
485 | 
486 |                     for ent_batch in sample_ent_batches:
487 | 
488 |                         train_batch(ent_batch, meta_data,
489 |                                     ent_encoder.train(),
490 |                                     decoder.train(),
491 |                                     ent_encoder_optimizer,
492 |                                     decoder_optimizer,
493 |                                     criterion)
494 | 
495 |                     print_msg('batch mixing done ..')
496 | 
497 |                 if (sbi+1) % validate_every == 0:
498 | 
499 |                     avg_valid_loss = evaluate(summ_encoder.eval(),
500 |                                               decoder.eval(), valid_summ_data,
501 |                                               meta_data)
502 | 
503 |                     print_msg("avg. validation loss is %s .." %
504 |                               (avg_valid_loss))
505 | 
506 |                     checkpoint(summ_encoder, ent_encoder, decoder,
507 |                                avg_valid_loss)
508 | 
509 |             except Exception as e:
510 |                 print_msg("Exception at epoch %s, batch %s" % (epoch, sbi))
511 |                 print_msg(str(e))
512 | 
513 |         print_msg("epoch %s done, avg. train loss is %s .." %
514 |                   (epoch, (avg_train_loss_val/avg_train_loss_num)))
515 | 
516 |     avg_valid_loss = evaluate(summ_encoder.eval(), decoder.eval(),
517 |                               valid_summ_data, meta_data)
518 | 
519 |     print_msg("avg. validation loss is %s .." % (avg_valid_loss))
520 | 
521 | 
522 | if __name__ == '__main__':
523 |     print_msg('training started ..')
524 | 
525 |     train(train_summ_batches,
526 |           valid_summ_data,
527 |           train_ent_batches,
528 |           np_embeddings, meta_data, 75, False)
529 | 
530 |     print_msg('training done ..')
--------------------------------------------------------------------------------
/mtl_testing.py:
--------------------------------------------------------------------------------
1 | import random
2 | from io import open
3 | 
4 | import pickle
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | import texttable as tt
9 | 
10 | from mtl_learning import MetaData, EncoderRNN, AttnDecoderRNN
11 | from mtl_learning import print_msg, variables_from_data
12 | 
13 | USE_CUDA = False
14 | MAX_LENGTH = 52
15 | SPLIT_RATIOS = {'train': 80, 'validation': 10, 'test': 10}
16 | BATCH_SIZE = 32
17 | SOS_TOKEN = 0
18 | EOS_TOKEN = 1
19 | 
20 | def test(encoder, decoder, test_set,
21 |          meta_data, max_length=MAX_LENGTH, printEvery=25):
22 | 
23 |     test_pairs = [variables_from_data(sample, meta_data) for sample in test_set]
24 |     criterion = nn.NLLLoss()
25 |     avg_loss = 0
26 |     nw_output = []
27 | 
28 |     for i, pair in enumerate(test_pairs):
29 | 
30 |         if (i+1) % printEvery == 0:
31 |             print_msg("%s samples tested .."
                      % (i+1))
32 | 
33 |         input_variable = pair[0]
34 |         target_variable = pair[1]
35 |         input_length = input_variable.size()[0]
36 |         target_length = target_variable.size()[0]
37 |         encoder_hidden = encoder.init_hidden()
38 | 
39 |         encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size))
40 |         encoder_outputs = encoder_outputs.cuda() if USE_CUDA else encoder_outputs
41 | 
42 |         for ei in range(input_length):
43 |             encoder_output, encoder_hidden = encoder(input_variable[ei],
44 |                                                      encoder_hidden)
45 |             encoder_outputs[ei] = encoder_outputs[ei] + encoder_output[0][0]
46 | 
47 |         decoder_input = Variable(torch.LongTensor([[SOS_TOKEN]]))
48 |         decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input
49 | 
50 |         decoder_hidden = encoder_hidden
51 | 
52 | 
53 |         loss = 0
54 |         decoded_words = []
55 | 
56 |         for di in range(target_length):
57 |             decoder_output,\
58 |             decoder_hidden,\
59 |             decoder_attention = decoder(
60 |                 decoder_input, decoder_hidden, encoder_outputs)
61 |             topv, topi = decoder_output.data.topk(1)
62 |             ni = topi[0][0]
63 | 
64 |             loss += criterion(decoder_output, target_variable[di])
65 |             if ni == EOS_TOKEN:
66 |                 break
67 |             else:
68 |                 decoded_words.append(meta_data.index_to_word[ni])
69 | 
70 |             decoder_input = Variable(torch.LongTensor([[ni]]))
71 |             decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input
72 | 
73 | 
74 |         nw_output.append((test_set[i][0], test_set[i][1], ' '.join(decoded_words)))
75 | 
76 | 
77 |         avg_loss += loss.data[0] / target_length
78 | 
79 |     test_loss = avg_loss / len(test_pairs)
80 |     return test_loss, nw_output
81 | 
82 | def rouge1_F1(outputs):
83 |     avg_r1_f1 = 0
84 |     r1_f1_scores = []
85 |     for i, output in enumerate(outputs):
86 |         ref_tokens = output[1].split(' ')
87 |         sys_tokens = output[2].split(' ')
88 |         tt_dict = {}
89 |         for token in ref_tokens:
90 |             if token in tt_dict:
91 |                 tt_dict[token] = tt_dict[token] + 1
92 |             else:
93 |                 tt_dict[token] = 1  # first occurrence of a reference token counts once
94 |         num_overlaps = 0
95 |         for token in sys_tokens:
96 |             if token in tt_dict and tt_dict[token] > 0:
97 |                 tt_dict[token] = tt_dict[token] - 1
98 |                 num_overlaps += 1
99 |         r1_recall = num_overlaps / len(ref_tokens)
100 |         r1_precision = num_overlaps / len(sys_tokens)
101 |         try:
102 |             r1_f1 = 2*(r1_precision * r1_recall) / (r1_precision + r1_recall)
103 |         except ZeroDivisionError:
104 |             r1_f1 = 0
105 |         r1_f1_scores.append((r1_f1, i))
106 |         avg_r1_f1 += r1_f1
107 |     return (avg_r1_f1 / len(outputs)), r1_f1_scores
108 | 
109 | 
110 | def print_sample(outputs):
111 |     tab = tt.Texttable()
112 |     headings = ['Text','Reference Summary','System Summary']
113 |     tab.header(headings)
114 |     for output in outputs:
115 |         tab.add_row(list(output))
116 |     s = tab.draw()
117 |     print_msg(s)
118 | 
119 | 
120 | 
121 | if __name__ == '__main__':
122 | 
123 |     encoder = torch.load('checkpoint_models/summ_encoder_06_10_18_22_37_14_1')
124 |     decoder = torch.load('checkpoint_models/decoder_06_10_18_22_37_14_1')
125 |     meta_data = pickle.load(open('pickles/meta_data.pkl', 'rb'))
126 |     test_data = pickle.load(open('pickles/test_summ_data.pkl', 'rb'))
127 |     print_msg('models, metadata, and data loaded ..')
128 |     print_msg('testing started ..')
129 |     loss, nw_outputs = test(encoder.eval(), decoder.eval(),
130 |                             test_data, meta_data)
131 |     print_msg('testing done ..')
132 |     print_msg("The negative log likelihood loss for the test data is %s .." %
133 |               (loss))
134 |     avg_r1_f1, r1_f1_scores = rouge1_F1(nw_outputs)
135 |     print_msg("The average ROUGE-1 F1 measure is %s .."
% (avg_r1_f1)) 136 | sample_outputs = random.sample(nw_outputs, 10) 137 | print_msg("Here are some of the sample results .. ") 138 | print_sample(sample_outputs) 139 | -------------------------------------------------------------------------------- /pickles/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quettabit/pytorch_mtl/086c08c5c9f0fae9479765f3f68900347686e0db/pickles/README.md -------------------------------------------------------------------------------- /pickles/ent_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quettabit/pytorch_mtl/086c08c5c9f0fae9479765f3f68900347686e0db/pickles/ent_data.pkl -------------------------------------------------------------------------------- /pickles/summ_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quettabit/pytorch_mtl/086c08c5c9f0fae9479765f3f68900347686e0db/pickles/summ_data.pkl -------------------------------------------------------------------------------- /reloads/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quettabit/pytorch_mtl/086c08c5c9f0fae9479765f3f68900347686e0db/reloads/README.md --------------------------------------------------------------------------------