├── Abstractive_Text_Summarization.ipynb ├── Assets ├── Loss_graph.jpeg └── model_architecture.png └── README.md /Abstractive_Text_Summarization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Pgen_Pointer_Generator.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "code", 19 | "metadata": { 20 | "id": "QYoIaaQHTCaT", 21 | "colab_type": "code", 22 | "colab": {} 23 | }, 24 | "source": [ 25 | "import pandas as pd\n", 26 | "import numpy as np\n", 27 | "import re\n", 28 | "import torch\n", 29 | "import torch.nn as nn\n", 30 | "import random\n", 31 | "import torch.nn.functional as F\n", 32 | "from gensim.models import FastText\n", 33 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 34 | ], 35 | "execution_count": 0, 36 | "outputs": [] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "2kjYbmDGDg-Y", 42 | "colab_type": "code", 43 | "colab": {} 44 | }, 45 | "source": [ 46 | "from google.colab import drive\n", 47 | "drive.mount('/content/drive')" 48 | ], 49 | "execution_count": 0, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "WbowAMISDjHJ", 56 | "colab_type": "code", 57 | "colab": {} 58 | }, 59 | "source": [ 60 | "data = pd.read_csv('/content/drive/My Drive/Amazon_Food/Reviews.csv')\n", 61 | "len(data)" 62 | ], 63 | "execution_count": 0, 64 | "outputs": [] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "id": "S7lqLjyLLDXa", 70 | "colab_type": "code", 71 | "colab": {} 72 | }, 73 | "source": [ 74 | "x = data['Text']  # review bodies: the encoder's source text\n", 75 | "y = data['Summary']  # review summaries: the decoder's targets" 76 | ], 77 | "execution_count": 0, 78 | "outputs": [] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": {
83 | "id": "9BlfwPHtNXI9", 84 | "colab_type": "code", 85 | "colab": {} 86 | }, 87 | "source": [ 88 | "def clean(text):\n", 89 | "    \"\"\"Lower-case text, expand common contractions ('s -> is, n't -> not, ...) and replace every non-alphanumeric character with a space.\"\"\"\n", 90 | "    text = str(text).lower()\n", 91 | "    # r'\\t...' replacements insert a tab (re.sub expands \\t), which the final pass turns into a space\n", 92 | "    text = re.sub(r'\\'s',r'\\tis',text)\n", 93 | "    text = re.sub(r'\\'ll',r'\\twill',text)\n", 94 | "    text = re.sub(r'\\'m',r'\\tam',text)\n", 95 | "    text = re.sub(r'\\'re',r'\\tare',text)\n", 96 | "    text = re.sub(r'\\'d',r'\\twould',text)\n", 97 | "    text = re.sub(r'n\\'t',r'\\tnot',text)\n", 98 | "    # punctuation is non-alphanumeric as well, so this single pass also removes it\n", 99 | "    return re.sub('[^a-zA-Z0-9]',' ',text)" 100 | ], 101 | "execution_count": 0, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "969v-EnfDlun", 108 | "colab_type": "code", 109 | "colab": {} 110 | }, 111 | "source": [ 112 | "cleaned_source = list(map(clean,x))\n", 113 | "cleaned_summary = list(map(clean,y))\n", 114 | "\n", 115 | "# NOTE(review): this pads with spaces only, so summaries get no start/end tokens although decoding starts from the '' entry - confirm the intended markers\n", 116 | "cleaned_summary = [\" \" + summary + \" \" for summary in cleaned_summary]\n", 117 | "\n", 118 | "print(cleaned_source[11])\n", 119 | "print(cleaned_summary[11])" 120 | ], 121 | "execution_count": 0, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "metadata": { 127 | "id": "0RUaNE15DpNK", 128 | "colab_type": "code", 129 | "colab": {} 130 | }, 131 | "source": [ 132 | "min_source_length = float('inf')\n", 133 | "max_source_length = 0\n", 134 | "min_target_length = float('inf')\n", 135 | "max_target_length = 0\n", 136 | "\n", 137 | "for source, summary in zip(cleaned_source, cleaned_summary):\n", 138 | "    min_source_length = min(min_source_length, len(source.split()))\n", 139 | "    min_target_length = min(min_target_length, len(summary.split()))\n", 140 | "    max_source_length = max(max_source_length, len(source.split()))\n", 141 | "    max_target_length = max(max_target_length, len(summary.split()))\n", 142 | "\n", 143 | "print(\"Minimum source length is: \",min_source_length)\n", 144
| "print(\"Minimum target length is: \",min_target_length)\n", 145 | "print(\"Maximum source length is: \",max_source_length)\n", 146 | "print(\"Maximum target length is: \",max_target_length)" 147 | ], 148 | "execution_count": 0, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "metadata": { 154 | "id": "2bgI2vw5DsYw", 155 | "colab_type": "code", 156 | "colab": {} 157 | }, 158 | "source": [ 159 | "max_source_length = 50  # longest review kept, in words; also the encoder-output / attention length passed to the model\n", 160 | "max_summary_length = 15  # longest summary kept, in words\n", 161 | "MAX_PAIRS = 30000  # review/summary pairs kept for the vocabularies and training\n", 162 | "\n", 163 | "new_source = []\n", 164 | "new_summary = []\n", 165 | "for source, summary in zip(cleaned_source, cleaned_summary):\n", 166 | "    if len(source.split()) <= max_source_length and len(summary.split()) <= max_summary_length:\n", 167 | "        new_source.append(source)\n", 168 | "        new_summary.append(summary)\n", 169 | "\n", 170 | "print(len(new_source))\n", 171 | "print(len(new_summary))\n", 172 | "\n", 173 | "new_source = new_source[:MAX_PAIRS]\n", 174 | "new_summary = new_summary[:MAX_PAIRS]" 175 | ], 176 | "execution_count": 0, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "id": "WkhEAcafDbFW", 183 | "colab_type": "code", 184 | "colab": {} 185 | }, 186 | "source": [ 187 | "sentences = new_source + new_summary\n", 188 | "# token lists for gensim's FastText (only consumed by the commented-out training cell below)\n", 189 | "sent_ted = []\n", 190 | "for sent in sentences:\n", 191 | "    sent_ted.append(sent.split())\n", 192 | "\n", 193 | "print(sent_ted[0])" 194 | ], 195 | "execution_count": 0, 196 | "outputs": [] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "metadata": { 201 | "id": "ir5TQ3vGZxNe", 202 | "colab_type": "code", 203 | "colab": {} 204 | }, 205 | "source": [ 206 | "# Embedding training, kept for reference (gensim 3.x keywords: gensim 4 renamed size/iter to vector_size/epochs)\n", 207 | "# model_ted = FastText(sent_ted, size=128, window=3, min_count=1, workers=4,sg=1, iter=1500)\n", 208 | "\n", 209 | "# import pickle\n", 210 | "# pickle.dump(model_ted, open('128_emb_1lakhdata.pkl', 'wb'))" 211 | ], 212 | "execution_count": 0, 213 | "outputs": [] 214 | }, 215 | { 216
| "cell_type": "code", 217 | "metadata": { 218 | "id": "lYIUn8RXDX15", 219 | "colab_type": "code", 220 | "colab": {} 221 | }, 222 | "source": [ 223 | "import pickle  # only unpickle files you created yourself: pickle.load can run arbitrary code\n", 224 | "with open('/content/drive/My Drive/128_emb.pkl', 'rb') as emb_file: model_ted = pickle.load(emb_file)\n", 225 | "weights = model_ted.wv  # word -> vector lookup used by Encoder and AttentionDecoder\n", 226 | "print(model_ted.wv.most_similar(\"milk\"))" 227 | ], 228 | "execution_count": 0, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "metadata": { 234 | "id": "otf9wqD83OxL", 235 | "colab_type": "code", 236 | "colab": {} 237 | }, 238 | "source": [ 239 | "from collections import OrderedDict\n", 240 | "\n", 241 | "word2Index_enc = {}\n", 242 | "word2Index_dec = {}\n", 243 | "word2Index_dec_big = {}\n", 244 | "\n", 245 | "ind2Word_enc = {}\n", 246 | "ind2Word_dec = {}\n", 247 | "ind2Word_dec_big = {}\n", 248 | "\n", 249 | "word2PsuInd_dec = {}  # frequent summary word -> contiguous index into the decoder's output layer\n", 250 | "psuInd2Word_dec = {}\n", 251 | "\n", 252 | "encoder_paragraph = sorted(set(' '.join(new_source).split()))  # sorted, so indices are the same on every run\n", 253 | "\n", 254 | "decoder_paragraph_list = ' '.join(new_summary).split()\n", 255 | "decoder_dict = OrderedDict()  # word -> count; first-seen order decides the *_dec_big indices\n", 256 | "for word in decoder_paragraph_list:\n", 257 | "    decoder_dict[word] = decoder_dict.get(word, 0) + 1\n", 258 | "\n", 259 | "# summary words seen fewer times than this stay out of the decoder's output vocabulary\n", 260 | "MIN_DEC_COUNT = 3\n", 261 | "\n", 262 | "ind2Word_enc[0] = ''\n", 263 | "ind2Word_dec[0] = ''\n", 264 | "word2Index_enc[''] = 0\n", 265 | "word2Index_dec[''] = 0\n", 266 | "ind2Word_dec_big[0] = ''\n", 267 | "word2Index_dec_big[''] = 0\n", 268 | "word2PsuInd_dec[''] = 0\n", 269 | "psuInd2Word_dec[0] = ''\n", 270 | "\n", 271 | "dec_index = 1\n", 272 | "for (decoder_dict_word, decoder_dict_number) in decoder_dict.items():\n", 273 | "    word2Index_dec_big[decoder_dict_word] = dec_index\n", 274 | "    ind2Word_dec_big[dec_index] = decoder_dict_word\n", 275 | "    if decoder_dict_number >= MIN_DEC_COUNT:\n", 276 | "        word2Index_dec[decoder_dict_word] = dec_index\n", 277 | "        ind2Word_dec[dec_index] = decoder_dict_word\n",
278 | "        pseudo_index = len(word2PsuInd_dec)\n", 279 | "        word2PsuInd_dec[decoder_dict_word] = pseudo_index\n", 280 | "        psuInd2Word_dec[pseudo_index] = decoder_dict_word\n", 281 | "    dec_index += 1\n", 282 | "\n", 283 | "# encoder words get indices 1..N; index 0 is the '' entry (split() never yields '' or ' ')\n", 284 | "for enc_index, word in enumerate(encoder_paragraph, start=1):\n", 285 | "    word2Index_enc[word] = enc_index\n", 286 | "    ind2Word_enc[enc_index] = word\n", 287 | "\n", 288 | "assert len(word2Index_enc) == len(encoder_paragraph) + 1\n" 289 | ], 290 | "execution_count": 0, 291 | "outputs": [] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "metadata": { 296 | "id": "r9BR3MZbR5D5", 297 | "colab_type": "code", 298 | "colab": {} 299 | }, 300 | "source": [ 301 | "encoder_input = [[word2Index_enc[word] for word in sentence.split() if word in word2Index_enc] for sentence in new_source]\n", 302 | "decoder_input = [[word2Index_dec_big[word] for word in sentence.split() if word in word2Index_dec_big] for sentence in new_summary]" 303 | ], 304 | "execution_count": 0, 305 | "outputs": [] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "metadata": { 310 | "id": "m1Xrg0kzCW5D", 311 | "colab_type": "code", 312 | "colab": {} 313 | }, 314 | "source": [ 315 | "encoder_tensor = [torch.tensor(li,dtype=torch.long,device=device).view(-1, 1) for li in encoder_input]\n", 316 | "decoder_tensor = [torch.tensor(li,dtype=torch.long,device=device).view(-1, 1) for li in decoder_input]" 317 | ], 318 | "execution_count": 0, 319 | "outputs": [] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "metadata": { 324 | "id": "ErsV--5vShLH", 325 | "colab_type": "code", 326 | "colab": {} 327 | }, 328 | "source": [ 329 | "class Encoder(nn.Module):\n", 330 | "    \"\"\"GRU encoder that reads one source word per call as its fixed FastText vector (no trainable embedding).\"\"\"\n", 331 | "    def __init__(self,input_vocab_size,hidden_size,num_layers=1,bidirectional=False):\n", 332 | "        super(Encoder,self).__init__()\n", 333 | "        self.bidirectional = bidirectional  # NOTE(review): stored only; the GRU below is always unidirectional\n", 334 | "        self.num_layers = num_layers\n", 335 | "        self.hidden_size = hidden_size  # also the GRU input size, so it must equal the FastText vector size\n", 336 | "        self.input_vocab_size = input_vocab_size\n", 337 | "        self.gru_layer = nn.GRU(input_size = self.hidden_size,hidden_size = self.hidden_size,num_layers = self.num_layers)\n",
338 | "\n", 339 | "    def forward(self,input_,prev_hidden_state):\n", 340 | "        input_word = ind2Word_enc[input_.data.tolist()[0]]\n", 341 | "        embedded_outputs = torch.tensor(weights[input_word], device = device).view(1,1,-1)\n", 342 | "        return self.gru_layer(embedded_outputs,prev_hidden_state)  # (output of shape (1, 1, hidden_size), new hidden state)\n", 343 | "\n", 344 | "    def init_hidden(self):\n", 345 | "        return torch.zeros(self.num_layers,1,self.hidden_size,device=device)  # one zero state per GRU layer" 346 | ], 347 | "execution_count": 0, 348 | "outputs": [] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "metadata": { 353 | "id": "I29Rw8CA7esY", 354 | "colab_type": "code", 355 | "colab": {} 356 | }, 357 | "source": [ 358 | "class AttentionDecoder(nn.Module):  # attention GRU decoder that also emits the pointer-generator gate p_gen\n", 359 | "    def __init__(self,output_vocab_size,hidden_size,max_length_encoder,dropout_value,num_layers=1):\n", 360 | "        super(AttentionDecoder,self).__init__()\n", 361 | "        self.hidden_size = hidden_size\n", 362 | "        self.num_layers = num_layers  # NOTE(review): stored only; the GRU below has a single layer\n", 363 | "        self.output_vocab_size = output_vocab_size\n", 364 | "        self.dropout_p = dropout_value\n", 365 | "        self.max_length_encoder = max_length_encoder  # number of encoder slots the attention scores\n", 366 | "        self.embedding_layer = nn.Embedding(self.output_vocab_size,self.hidden_size)  # unused: forward embeds with FastText\n", 367 | "        self.attention_layer = nn.Linear(self.hidden_size*2,self.max_length_encoder)\n", 368 | "        self.attention_combine = nn.Linear(self.hidden_size*2,self.hidden_size)\n", 369 | "        # p_gen gate: one scalar score each for the previous state, the input and the context, mixed by linear_pgen\n", 370 | "        self.s_layer = nn.Linear(self.hidden_size, 1)\n", 371 | "        self.x_layer = nn.Linear(self.hidden_size, 1)\n", 372 | "        self.context_layer = nn.Linear(self.hidden_size, 1)\n", 373 | "        self.linear_pgen = nn.Linear(3, 1)\n", 374 | "\n", 375 | "        self.gru_layer = nn.GRU(self.hidden_size,self.hidden_size)\n", 376 | "        self.output_layer = nn.Linear(self.hidden_size,self.output_vocab_size)\n", 377 | "        self.dropout_layer = nn.Dropout(self.dropout_p)\n", 378 | "\n",
379 | "    def forward(self,input_,prev_hidden_state,encoder_output,prev_unk_word):\n", 380 | "        \"\"\"One decoding step: returns (log-probs over the decoder vocabulary, new hidden state, attention weights, p_gen).\"\"\"\n", 381 | "        input_word = ind2Word_dec_big[input_.data.tolist()[0]]\n", 382 | "        # index 0 ('') is the start input and the input after a copied source word: embed prev_unk_word instead\n", 383 | "        lookup_word = prev_unk_word if input_word == '' else input_word\n", 384 | "        embedded_outputs = torch.tensor(weights[lookup_word], device = device).view(1,1,-1)\n", 385 | "\n", 386 | "        embeddings_dropout = self.dropout_layer(embedded_outputs)\n", 387 | "        attention_layer_output = self.attention_layer(torch.cat((embeddings_dropout[0],prev_hidden_state[0]),1))\n", 388 | "        attention_weights = nn.functional.softmax(attention_layer_output,dim=1)  # (1, max_length_encoder)\n", 389 | "        attention_applied = torch.bmm(attention_weights.unsqueeze(0),encoder_output.unsqueeze(0))  # context vector, (1, 1, hidden_size)\n", 390 | "        attention_combine_logits = self.attention_combine(torch.cat((embeddings_dropout[0],attention_applied[0]),1)).unsqueeze(0)  # the GRU expects a sequence dimension\n", 391 | "        attention_combine_relu = nn.functional.relu(attention_combine_logits)\n", 392 | "\n", 393 | "        # p_gen = sigmoid(linear_pgen([score(previous state), score(input), score(context)]))\n", 394 | "        s_output = self.s_layer(prev_hidden_state[0])\n", 395 | "        x_output = self.x_layer(embeddings_dropout[0])\n", 396 | "        context_weights = self.context_layer(attention_applied)\n", 397 | "        sx = torch.cat((s_output[0],x_output[0]),0)\n", 398 | "        sxc = torch.cat((sx,context_weights[0][0]),0)\n", 399 | "        linear_pgen = self.linear_pgen(sxc)\n", 400 | "        pgen = torch.sigmoid(linear_pgen)  # shape (1,)\n", 401 | "\n", 402 | "        # the GRU consumes the attention-combined input; its output is projected onto the vocabulary\n", 403 | "        output,hidden = self.gru_layer(attention_combine_relu,prev_hidden_state)\n", 404 | "        output_logits = self.output_layer(output)\n", 405 | "        output_softmax = nn.functional.log_softmax(output_logits[0],dim=1)  # (1, output_vocab_size)\n", 406 | "        return output_softmax,hidden,attention_weights,pgen\n", 407 | "\n", 408 | "    def init_hidden(self):\n", 409 | "        return torch.zeros(1,1,self.hidden_size,device=device)" 410 | ], 411 | "execution_count": 0, 412 | "outputs": [] 413 | }, 414 | {
415 | "cell_type": "code", 416 | "metadata": { 417 | "id": "Y40789vQM8gc", 418 | "colab_type": "code", 419 | "colab": {} 420 | }, 421 | "source": [ 422 | "teacher_forcing_ratio = 0.5\n", 423 | "\n", 424 | "\n", 425 | "def _source_copy_tables(input_indices, vocab_size):\n", 426 | "    \"\"\"Build the copy-mechanism lookup tables for one source sequence.\n", 427 | "\n", 428 | "    input_indices: encoder indices of the source words, in order.\n", 429 | "    vocab_size: number of entries in the decoder's output distribution.\n", 430 | "    Returns (extended_vocab, reverse_extended_vocab, duplicate_words, p_duplicate):\n", 431 | "    extended index -> word, word -> extended index, source position -> pseudo index of each\n", 432 | "    in-vocabulary source word, and a (len(input_indices), vocab_size) 0/1 matrix of those positions.\n", 433 | "    Out-of-vocabulary source words get extended indices vocab_size, vocab_size + 1, ... in source order.\n", 434 | "    \"\"\"\n", 435 | "    extended_vocab = psuInd2Word_dec.copy()\n", 436 | "    reverse_extended_vocab = word2PsuInd_dec.copy()\n", 437 | "    duplicate_words = {}\n", 438 | "    extend_key = vocab_size\n", 439 | "    for position, input_index in enumerate(input_indices):\n", 440 | "        input_word = ind2Word_enc[input_index]\n", 441 | "        if input_word in word2Index_dec:\n", 442 | "            duplicate_words[position] = word2PsuInd_dec[input_word]\n", 443 | "        else:\n", 444 | "            extended_vocab[extend_key] = input_word\n", 445 | "            reverse_extended_vocab[input_word] = extend_key\n", 446 | "            extend_key += 1\n", 447 | "    p_duplicate = torch.zeros(len(input_indices), vocab_size, device=device)\n", 448 | "    for position, pseudo_index in duplicate_words.items():\n", 449 | "        p_duplicate[position, pseudo_index] = 1\n", 450 | "    return extended_vocab, reverse_extended_vocab, duplicate_words, p_duplicate\n", 451 | "\n", 452 | "\n", 453 | "def _extended_distribution(decoder_output, decoder_attention, pgen, p_duplicate, duplicate_words, input_length):\n", 454 | "    \"\"\"Mix generation and copying into one probability row over the extended vocabulary.\n", 455 | "\n", 456 | "    P(w) = p_gen * P_vocab(w) + (1 - p_gen) * (attention on the source positions holding w),\n", 457 | "    as in See et al. (2017). The result has shape (1, vocab_size + number of out-of-vocabulary\n", 458 | "    source positions); the extra columns follow source order, matching _source_copy_tables.\n", 459 | "    \"\"\"\n", 460 | "    attention = decoder_attention.squeeze(0)[0:input_length].unsqueeze(0)\n", 461 | "    copy_weight = 1 - pgen\n", 462 | "    distribution = torch.exp(decoder_output) * pgen + torch.mm(attention, p_duplicate) * copy_weight\n", 463 | "    oov_positions = [i for i in range(input_length) if i not in duplicate_words]\n", 464 | "    if oov_positions:\n", 465 | "        oov_probs = attention[0, oov_positions] * copy_weight\n", 466 | "        distribution = torch.cat((distribution, oov_probs.unsqueeze(0)), 1)\n", 467 | "    return distribution\n", 468 | "\n", 469 | "\n", 470 | "def _next_decoder_input(idx, extended_vocab):\n", 471 | "    \"\"\"Turn a greedy prediction (an extended index) into the decoder's next input.\n", 472 | "\n", 473 | "    Vocabulary predictions are pseudo indices, but the decoder looks its input up in\n", 474 | "    word2Index_dec_big's index space, so they are mapped through the predicted word. A copied\n", 475 | "    source word is fed as index 0 and returned so the decoder embeds it.\n", 476 | "    Returns (decoder input tensor of shape (1,), copied word or None).\n", 477 | "    \"\"\"\n", 478 | "    word = extended_vocab[idx]\n", 479 | "    if idx in psuInd2Word_dec:\n", 480 | "        return torch.tensor([word2Index_dec_big[word]], dtype=torch.long, device=device), None\n", 481 | "    return torch.tensor([0], dtype=torch.long, device=device), word\n", 482 | "\n", 483 | "\n", 484 | "def train(encoder, decoder, input_tensor, target_tensor, encoder_optimizer, decoder_optimizer, criterion, max_length, iters):\n", 485 | "    \"\"\"Run one SGD step of the pointer-generator model on a single (source, summary) pair.\n", 486 | "\n", 487 | "    input_tensor: (source_len, 1) encoder indices; target_tensor: (summary_len, 1) word2Index_dec_big indices.\n", 488 | "    criterion is accepted for interface compatibility but unused: the loss is the negative log of the\n", 489 | "    extended-vocabulary probability of every target word the model can produce. Gradients are clipped\n", 490 | "    to norm 0.4 once iters > 20000. Returns the summed loss divided by the summary length.\n", 491 | "    \"\"\"\n", 492 | "    encoder_optimizer.zero_grad()\n", 493 | "    decoder_optimizer.zero_grad()\n", 494 | "\n", 495 | "    prev_unk_word = ''\n", 496 | "    encoder_hidden = encoder.init_hidden()\n", 497 | "    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device = device)\n", 498 | "\n", 499 | "    input_length = input_tensor.size(0)\n", 500 | "    output_length = target_tensor.size(0)\n", 501 | "\n", 502 | "    for encoder_index in range(input_length):\n", 503 | "        encoder_output,encoder_hidden = encoder(input_tensor[encoder_index], encoder_hidden)\n", 504 | "        encoder_outputs[encoder_index] = encoder_output[0,0]\n", 505 | "\n", 506 | "    decoder_input = torch.tensor([word2Index_dec['']],device=device)\n", 507 | "    decoder_hidden = encoder_hidden\n", 508 | "    use_teacher_forcing = random.random() < teacher_forcing_ratio\n", 509 | "\n", 510 | "    extended_vocab, reverse_extended_vocab, duplicate_words, p_duplicate = _source_copy_tables(\n", 511 | "        input_tensor.view(-1).tolist(), decoder.output_vocab_size)\n", 512 | "\n", 513 | "    loss = torch.zeros((), device=device)\n", 514 | "    for decoder_index in range(output_length):\n", 515 | "        decoder_output,decoder_hidden,decoder_attention,pgen = decoder(decoder_input,decoder_hidden,encoder_outputs,prev_unk_word)\n", 516 | "        P_over_extended_vocab = _extended_distribution(\n", 517 | "            decoder_output, decoder_attention, pgen, p_duplicate, duplicate_words, input_length)\n", 518 | "\n", 519 | "        # a target that is neither in the vocabulary nor in the source cannot be produced: no loss term\n", 520 | "        target_word = ind2Word_dec_big.get(target_tensor[decoder_index].item())\n", 521 | "        if target_word in reverse_extended_vocab:\n", 522 | "            loss = loss - torch.log(P_over_extended_vocab[0][reverse_extended_vocab[target_word]] + 1e-12)\n", 523 | "\n", 524 | "        if use_teacher_forcing:\n", 525 | "            decoder_input = target_tensor[decoder_index]\n", 526 | "            continue\n", 527 | "        idx = torch.topk(P_over_extended_vocab, k=1, dim=1)[1].item()\n", 528 | "        decoder_input, copied_word = _next_decoder_input(idx, extended_vocab)\n", 529 | "        if copied_word is not None:\n", 530 | "            prev_unk_word = copied_word\n", 531 | "        if idx == word2PsuInd_dec['']:\n", 532 | "            break\n", 533 | "\n", 534 | "    # Backpropagate once through the summed loss: a backward call after every step on the running\n", 535 | "    # sum would push the gradient of each early step through again at every later step.\n", 536 | "    if loss.requires_grad:\n", 537 | "        loss.backward()\n", 538 | "\n", 539 | "    if iters > 20000:\n", 540 | "        torch.nn.utils.clip_grad_norm_(encoder.parameters(),0.4)\n", 541 | "        torch.nn.utils.clip_grad_norm_(decoder.parameters(),0.4)\n", 542 | "\n", 543 | "    encoder_optimizer.step()\n", 544 | "    decoder_optimizer.step()\n", 545 | "\n", 546 | "    return loss.item() / max(output_length, 1)"
547 | ], 548 | "execution_count": 0, 549 | "outputs": [] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "metadata": { 554 | "id": "ocM_vNF6KTWW", 555 | "colab_type": "code", 556 | "colab": {} 557 | }, 558 | "source": [ 559 | "import time\n", 560 | "import math\n", 561 | "\n", 562 | "\n", 563 | "def asMinutes(s):\n", 564 | "    \"\"\"Format a duration in seconds as 'Xm Ys'.\"\"\"\n", 565 | "    m = math.floor(s / 60)\n", 566 | "    s -= m * 60\n", 567 | "    return '%dm %ds' % (m, s)\n", 568 | "\n", 569 | "\n", 570 | "def timeSince(since, percent):\n", 571 | "    \"\"\"Return 'elapsed (- estimated remaining)' for a job started at `since` that is `percent` (a 0-1 fraction) done.\"\"\"\n", 572 | "    s = time.time() - since\n", 573 | "    if percent == 0:\n", 574 | "        return '%s (- ?)' % asMinutes(s)  # no progress yet, so no estimate\n", 575 | "    es = s / percent\n", 576 | "    return '%s (- %s)' % (asMinutes(s), asMinutes(es - s))" 577 | ], 578 | "execution_count": 0, 579 | "outputs": [] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "metadata": { 584 | "id": "CQnTrJ0SLSYq", 585 | "colab_type": "code", 586 | "colab": {} 587 | }, 588 | "source": [ 589 | "SEED = 0  # fixes pair sampling, teacher forcing and the model initialisation in the cells below\n", 590 | "random.seed(SEED)\n", 591 | "np.random.seed(SEED)\n", 592 | "torch.manual_seed(SEED)\n", 593 | "arr = np.arange(len(encoder_tensor))\n", 594 | "np.random.shuffle(arr)\n", 595 | "len(arr)" 596 | ], 597 | "execution_count": 0, 598 | "outputs": [] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "metadata": { 603 | "id": "pBToTQxGLXgI", 604 | "colab_type": "code", 605 | "colab": {} 606 | }, 607 | "source": [
608 | "# Dictionary for creating loss graph\n", 609 | "loss_graph = {}\n", 610 | "\n", 611 | "def train_Iters(encoder,decoder,n_iters,print_every=50, plot_every=100,learning_rate = 0.03):\n", 612 | "    \"\"\"Run n_iters single-pair SGD steps on pairs drawn at random (with replacement) from `pairs`.\n", 613 | "\n", 614 | "    Every print_every iterations prints the iteration, the percentage of n_iters done and the\n", 615 | "    window's average loss, and stores that average in loss_graph[iteration] (except at iteration 0).\n", 616 | "    \"\"\"\n", 617 | "    plot_losses = []\n", 618 | "    print_loss_total = 0  # Reset every print_every\n", 619 | "    plot_loss_total = 0  # Reset every plot_every\n", 620 | "\n", 621 | "    encoder_optimizer = torch.optim.SGD(encoder.parameters(), lr=learning_rate)\n", 622 | "    decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=learning_rate)\n", 623 | "    training_pairs = [random.choice(pairs) for _ in range(n_iters)]\n", 624 | "\n", 625 | "    criterion = nn.NLLLoss()  # forwarded to train(), which computes its own loss\n", 626 | "    for iters in range(n_iters):\n", 627 | "        input_indices, target_indices = training_pairs[iters]\n", 628 | "        input_tensor = torch.tensor(input_indices, dtype=torch.long, device = device).view(-1, 1)\n", 629 | "        target_tensor = torch.tensor(target_indices, dtype=torch.long, device = device).view(-1, 1)\n", 630 | "\n", 631 | "        loss = train(encoder,decoder,input_tensor,target_tensor,encoder_optimizer,decoder_optimizer,criterion,max_source_length, iters)\n", 632 | "        print_loss_total += loss\n", 633 | "        plot_loss_total += loss\n", 634 | "\n", 635 | "        if iters % print_every == 0:\n", 636 | "            print_loss_avg = print_loss_total / print_every\n", 637 | "            print_loss_total = 0\n", 638 | "            print('%d (%d%%) %.4f' % (iters, iters / n_iters * 100, print_loss_avg))\n", 639 | "            if iters > 0:\n", 640 | "                loss_graph[iters] = print_loss_avg\n", 641 | "\n", 642 | "        if iters % plot_every == 0:\n", 643 | "            plot_loss_avg = plot_loss_total / plot_every\n", 644 | "            plot_losses.append(plot_loss_avg)\n", 645 | "            plot_loss_total = 0" 646 | ], 647 | "execution_count": 0, 648 | "outputs": [] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "metadata": { 653 | "id": "mDEs3lSzLcrt", 654 | "colab_type": "code", 655 | "colab": {} 656 | }, 657 | "source": [
658 | "import matplotlib.pyplot as plt\n", 659 | "import matplotlib.ticker as ticker\n", 660 | "# Keep the notebook's inline backend: plt.switch_backend('agg') would stop plt.show() from displaying figures.\n", 661 | "\n", 662 | "\n", 663 | "def showPlot(points):\n", 664 | "    \"\"\"Plot a sequence of loss values with y-axis ticks every 0.2 and return the figure.\"\"\"\n", 665 | "    fig, ax = plt.subplots()\n", 666 | "    # this locator puts ticks at regular intervals\n", 667 | "    loc = ticker.MultipleLocator(base=0.2)\n", 668 | "    ax.yaxis.set_major_locator(loc)\n", 669 | "    ax.plot(points)\n", 670 | "    return fig" 671 | ], 672 | "execution_count": 0, 673 | "outputs": [] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "metadata": { 678 | "id": "Jgo_fbnFLgxE", 679 | "colab_type": "code", 680 | "colab": {} 681 | }, 682 | "source": [ 683 | "# each pair: [encoder indices of a review, word2Index_dec_big indices of its summary]\n", 684 | "pairs = [[enc, dec] for enc, dec in zip(encoder_input, decoder_input)]" 685 | ], 686 | "execution_count": 0, 687 | "outputs": [] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "metadata": { 692 | "id": "uAJAoFoJDLrm", 693 | "colab_type": "code", 694 | "colab": {} 695 | }, 696 | "source": [ 697 | "hidden_size = 128  # must equal the FastText vector size: the GRUs take hidden_size-wide inputs\n", 698 | "rnn_encoder = Encoder(len(word2Index_enc.keys()),hidden_size).to(device=device)\n", 699 | "rnn_decoder = AttentionDecoder(len(word2Index_dec.keys()),hidden_size,max_source_length,0.2).to(device=device)\n", 700 | "\n", 701 | "train_Iters(rnn_encoder,rnn_decoder,100000)" 702 | ], 703 | "execution_count": 0, 704 | "outputs": [] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "metadata": { 709 | "id": "5IbiBvMaDO3k", 710 | "colab_type": "code", 711 | "colab": {} 712 | }, 713 | "source": [ 714 | "import matplotlib.pyplot as plt\n", 715 | "\n", 716 | "iters = list(loss_graph.keys())\n", 717 | "loss_val = list(loss_graph.values())\n", 718 | "plt.plot(iters, loss_val)\n", 719 | "plt.ylim(0,8)\n", 720 | "plt.xlim(0, max(iters, default=1))  # span the recorded iterations instead of a fixed 77350\n", 721 | "plt.title('Training loss')\n", 722 | "plt.xlabel('Iterations')\n", 723 | "plt.ylabel('Loss')\n", 724 | "plt.show()" 725 | ], 726 | "execution_count": 0, 727 | "outputs": [] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "metadata": { 732 | "id": "kHA9RnrZgk6j", 733 | "colab_type": "code", 734 | "colab": {} 735 | }, 736 | "source": [
737 | "def evaluateRandomly(encoder, decoder, n=10):\n", 738 | "    \"\"\"Print n random (source, predicted summary, reference summary) triples drawn from `pairs`.\"\"\"\n", 739 | "    for i in range(n):\n", 740 | "        pair = random.choice(pairs)\n", 741 | "        pair1 = torch.tensor(pair[0],dtype=torch.long,device=device)\n", 742 | "        pair2 = pair[1]\n", 743 | "        output_words, attentions = evaluate(encoder, decoder, pair1)\n", 744 | "        output_sentence = ' '.join(output_words)\n", 745 | "        output_list = ' '.join(ind2Word_dec_big[word] for word in pair2)\n", 746 | "        input_sentence = ' '.join(ind2Word_enc[element.item()] for element in pair1.flatten())\n", 747 | "        print(\"Sentence is \",input_sentence)\n", 748 | "        print('<',output_sentence)\n", 749 | "        print('=',output_list)" 750 | ], 751 | "execution_count": 0, 752 | "outputs": [] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "metadata": { 757 | "id": "ObARxnAyoUTA", 758 | "colab_type": "code", 759 | "colab": {} 760 | }, 761 | "source": [ 762 | "def evaluate(encoder, decoder, encoder_tensor, max_length=max_source_length):\n", 763 | "    \"\"\"Greedily decode a summary for one source sequence with the pointer-generator model.\n", 764 | "\n", 765 | "    encoder_tensor is a 1-D tensor of encoder indices. Decoding stops after max_length steps or when\n", 766 | "    the '' entry (pseudo index 0) is predicted. Dropout is switched off while decoding and the\n", 767 | "    models' previous train/eval modes are restored afterwards.\n", 768 | "    Returns (decoded words, attention rows of the decoded steps).\n", 769 | "    \"\"\"\n", 770 | "    was_training = (encoder.training, decoder.training)\n", 771 | "    encoder.eval()\n", 772 | "    decoder.eval()\n", 773 | "    try:\n", 774 | "        with torch.no_grad():\n", 775 | "            input_tensor = encoder_tensor\n", 776 | "            input_length = input_tensor.size(0)\n", 777 | "            encoder_hidden = encoder.init_hidden()\n", 778 | "\n", 779 | "            prev_unk_word = ''\n", 780 | "\n", 781 | "            encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n", 782 | "\n", 783 | "            for ei in range(input_length):\n", 784 | "                encoder_output, encoder_hidden = encoder(input_tensor[ei].unsqueeze(0),\n", 785 | "                                                         encoder_hidden)\n", 786 | "                encoder_outputs[ei] += encoder_output[0, 0]\n", 787 | "\n", 788 | "            extended_vocab, _, duplicate_words, p_duplicate = _source_copy_tables(\n", 789 | "                input_tensor.view(-1).tolist(), decoder.output_vocab_size)\n", 790 | "\n", 791 | "            decoder_input = torch.tensor([word2Index_dec['']], device=device) # SOS\n", 792 | "\n", 793 | "            decoder_hidden = encoder_hidden\n", 794 | "\n", 795 | "            decoded_words = []\n", 796 | "            decoder_attentions = torch.zeros(max_length, max_length)\n", 797 | "\n", 798 | "            di = -1  # keeps the final slice valid when max_length == 0\n", 799 | "            for di in range(max_length):\n", 800 | "                decoder_output, decoder_hidden, decoder_attention, pgen = decoder(\n", 801 | "                    decoder_input, decoder_hidden, encoder_outputs, prev_unk_word)\n", 802 | "                decoder_attentions[di] = decoder_attention.data\n", 803 | "\n", 804 | "                P_over_extended_vocab = _extended_distribution(\n", 805 | "                    decoder_output, decoder_attention, pgen, p_duplicate, duplicate_words, input_length)\n", 806 | "\n", 807 | "                idx = torch.topk(P_over_extended_vocab, k=1, dim=1)[1].item()\n", 808 | "                decoded_words.append(extended_vocab[idx])\n", 809 | "                decoder_input, copied_word = _next_decoder_input(idx, extended_vocab)\n", 810 | "                if copied_word is not None:\n", 811 | "                    prev_unk_word = copied_word\n", 812 | "                if idx == word2PsuInd_dec['']:\n", 813 | "                    decoded_words.append('')\n", 814 | "                    break\n", 815 | "    finally:\n", 816 | "        encoder.train(was_training[0])\n", 817 | "        decoder.train(was_training[1])\n", 818 | "\n", 819 | "    return decoded_words, decoder_attentions[:di + 1]"
820 | ], 821 | "execution_count": 0, 822 | "outputs": [] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "metadata": { 827 | "id": "oYYHCcO04vV5", 828 | "colab_type": "code", 829 | "colab": {} 830 | }, 831 | "source": [ 832 | "evaluateRandomly(rnn_encoder, rnn_decoder)" 833 | ], 834 | "execution_count": 0, 835 | "outputs": [] 836 | } 837 | ] 838 | } -------------------------------------------------------------------------------- /Assets/Loss_graph.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Developer-Zer0/Get-To-The-Point-Summarization-with-Pointer-Generator-Networks/1a6f4a0fdc2ef779df7ae4b30245832b34667d42/Assets/Loss_graph.jpeg -------------------------------------------------------------------------------- /Assets/model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Developer-Zer0/Get-To-The-Point-Summarization-with-Pointer-Generator-Networks/1a6f4a0fdc2ef779df7ae4b30245832b34667d42/Assets/model_architecture.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Get To The Point: Summarization with Pointer-Generator Networks 2 | Pytorch implementation of [Get To The Point: Summarization with Pointer-Generator Networks
(2017)](https://arxiv.org/pdf/1704.04368.pdf) by Abigail See et al. 3 | 4 | ## Model Description 5 | * LSTM based Sequence-to-Sequence model for Abstractive Summarization 6 | * Pointer mechanism for handling Out of Vocabulary (OOV) words [See et al. (2017)](https://arxiv.org/pdf/1704.04368.pdf) 7 | * FastText used for creating embeddings over the dataset . 8 | 9 | ## Model Architecture 10 | 11 |

12 | 13 |

14 | 15 | ## Prerequisites 16 | * PyTorch 17 | * gensim 18 | * Python 3 19 | 20 | ## Data 21 | Dataset used: [Kaggle: Amazon Fine Food Reviews](https://www.kaggle.com/snap/amazon-fine-food-reviews) 22 | 23 | 24 | ## Examples 25 | ```Sentence:``` once more amazon was great the product is good for kids even though it has a little bit more sugar than needed
26 | ```Predicted Summary:``` good as expected
27 | ```Actual Summary:``` as expected
28 | 29 | ```Sentence:``` this is an excellent tea for a breakfast tea or for the afternoon or evening it has a wonderful mellow flavor in the morning i like to brew it with earl grey to create a nice smooth blend it is a great way to start the day
30 | ```Predicted Summary:``` great product tea
31 | ```Actual Summary:``` wonderful anytime tea
32 | 33 | ```Sentence:``` very tasty this mix makes a moist yummy cake we make our own cream cheese frosting with the chocolate cake and it is a winner
34 | ```Predicted Summary:``` tasty
35 | ```Actual Summary:``` versatile and yummy
36 | 37 | ```Sentence:``` ca not complain taste good and had quick delivery it was my first time trying this tea out i usually drink the peppermint one but this gave me energy and sustained me throughout the day
38 | ```Predicted Summary:``` very good good
39 | ```Actual Summary:``` very good
40 | 41 | ```Sentence:``` energy bites f ing rip your face off with molten lava energy br br infinity energy to the f ing max i ate a whole box once and threw a car at a baby br br f ing rave br br seriously these are great
42 | ```Predicted Summary:``` love it
43 | ```Actual Summary:``` maximum rave power
44 | 45 | ```Sentence:``` looks good from the package but does not taste good at all hard crunchy freeze dried flavor br save your money guys
46 | ```Predicted Summary:``` yuck
47 | ```Actual Summary:``` yuck
48 | 49 | ```Sentence:``` this is the best dog food their is because everything is very digestible and when the dog does digest it all digested because they use all the nutrition from it so it is healther and the best out their far as i am think
50 | ```Predicted Summary:``` best dog food
51 | ```Actual Summary:``` natures logic venisen
52 | 53 | ```Sentence:``` great coffee br excellent service br best way to buy k cups br stock up before coffee prices go up again
54 | ```Predicted Summary:``` great coffee
55 | ```Actual Summary:``` great coffee excellent service
56 | 57 | ```Sentence:``` wonderful flavor would purchase this blend of coffee again light flavor not bitter at all and price was great the best i found anywhere
58 | ```Predicted Summary:``` great flavor
59 | ```Actual Summary:``` wolfgang puck k cup breakfast in bed
60 | 61 | ```Sentence:``` no wonder they were so cheap they do not work very well in my machine because of the biodegradable packaging i would not buy them again
62 | ```Predicted Summary:``` disappointed
63 | ```Actual Summary:``` ethical coffee nespresso capsules
64 | 65 | ```Sentence:``` i bought these from a large chain pet store after reading the reviews i checked the bag made in china i threw the whole bag away i wish i would have read the reviews first
66 | ```Predicted Summary:``` do not buy
67 | ```Actual Summary:``` do not buy
68 | 69 | ```Sentence:``` i love these gums they are not as cloyingly sweet as american gummies and have a lot more fruit flavor the only problem is that i ca not eat just one
70 | ```Predicted Summary:``` a addictive
71 | ```Actual Summary:``` danger highly addictive
72 | 73 | ```Sentence:``` the pepper plant habanero extra hot california style hot pepper sauce 10 oz has great flavor as all the pepper plants do i just love it it is a bit pricey but worth it
74 | ```Predicted Summary:``` great seasoning
75 | ```Actual Summary:``` wonderful love it
76 | 77 | ------------------------------------------ 78 | 79 |

Loss for 75k iterations


80 | 81 |

82 | 83 |

84 | 85 | 86 | ------------------------------------------ 87 | ## Contributors: 88 | 89 | - [Ankur Chemburkar](https://github.com/Developer-Zer0) 90 | - [Talha Chafekar](https://github.com/talha1503) 91 | 92 | ------------------------------------------ 93 | ## References 94 | * [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368) 95 | --------------------------------------------------------------------------------