├── README.md ├── conditional_bert_contextual_augmentation.ipynb ├── unsupervised_extractive_summarization_with_bert.ipynb └── unsupervised_lexical_simplification_with_bert.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Use-cases of Hugging Face's BERT 2 | 3 | These notebooks show use-cases of BERT using [pytorch-transformers](https://github.com/huggingface/pytorch-transformers). 4 | 5 | - [Unsupervised Lexical Simplification](https://github.com/marucha80t/use-cases_of_bert/blob/master/unsupervised_lexical_simplification_with_bert.ipynb) 6 | - [Conditional BERT Contextual Augmentation](https://colab.research.google.com/github/tkmaroon/use-cases-of-bert/blob/master/conditional_bert_contextual_augmentation.ipynb) 7 | - [Unsupervised Extractive Summarization](https://github.com/marucha80t/use-cases_of_bert/blob/master/unsupervised_extractive_summarization_with_bert.ipynb) 8 |
9 | 10 | 11 | ## References 12 | 13 | - [https://github.com/huggingface/pytorch-transformers](https://github.com/huggingface/pytorch-transformers) 14 | - [A Simple BERT-Based Approach for Lexical Simplification](https://arxiv.org/abs/1907.06226) 15 | - [Conditional BERT Contextual Augmentation](https://arxiv.org/abs/1812.06705) 16 | - [Simple Unsupervised Keyphrase Extraction using Sentence Embeddings](https://arxiv.org/abs/1801.04470) 17 | -------------------------------------------------------------------------------- /conditional_bert_contextual_augmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "lexical_substitution_with_bert.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "i7LF71CEeB0M", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "# Conditional BERT Contextual Augmentation\n", 36 | "## Overview\n", 37 | "This notebook performs lexical substitution using Hugging Face's BERT.\n", 38 | "It shows how to constrain substitution candidates on a sentiment label, following the paper *Conditional BERT Contextual Augmentation*.\n", 39 | "\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "DpxaNWMmdkO7", 46 | "colab_type": "text" 47 | }, 48 | "source": [ 49 | "## Settings" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "ZMX_w_VydE7j", 56 | "colab_type": "code", 57 | "cellView": "form", 58 | "outputId": "bafe7f61-7338-43f4-bf81-f0fbe43dc6e6", 59 | "colab": { 60 |
"base_uri": "https://localhost:8080/", 61 | "height": 102 62 | } 63 | }, 64 | "source": [ 65 | "#@title Setup environment\n", 66 | "!pip install --quiet pytorch-transformers\n", 67 | "!pip install --quiet pytorch-nlp\n", 68 | "!pip install --quiet tqdm" 69 | ], 70 | "execution_count": 2, 71 | "outputs": [ 72 | { 73 | "output_type": "stream", 74 | "text": [ 75 | "\u001b[?25l\r\u001b[K |█▉ | 10kB 31.2MB/s eta 0:00:01\r\u001b[K |███▊ | 20kB 6.7MB/s eta 0:00:01\r\u001b[K |█████▋ | 30kB 9.3MB/s eta 0:00:01\r\u001b[K |███████▍ | 40kB 6.4MB/s eta 0:00:01\r\u001b[K |█████████▎ | 51kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████▏ | 61kB 9.3MB/s eta 0:00:01\r\u001b[K |█████████████ | 71kB 10.6MB/s eta 0:00:01\r\u001b[K |██████████████▉ | 81kB 11.8MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 92kB 13.1MB/s eta 0:00:01\r\u001b[K |██████████████████▋ | 102kB 10.3MB/s eta 0:00:01\r\u001b[K |████████████████████▍ | 112kB 10.3MB/s eta 0:00:01\r\u001b[K |██████████████████████▎ | 122kB 10.3MB/s eta 0:00:01\r\u001b[K |████████████████████████▏ | 133kB 10.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 143kB 10.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████▉ | 153kB 10.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▊ | 163kB 10.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▋| 174kB 10.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 184kB 10.3MB/s \n", 76 | "\u001b[K |████████████████████████████████| 1.0MB 54.1MB/s \n", 77 | "\u001b[K |████████████████████████████████| 870kB 77.6MB/s \n", 78 | "\u001b[?25h Building wheel for sacremoses (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 79 | "\u001b[K |████████████████████████████████| 92kB 6.8MB/s \n", 80 | "\u001b[?25h" 81 | ], 82 | "name": "stdout" 83 | } 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "ckESRGmKdMrh", 90 | "colab_type": "code", 91 | "cellView": "form", 92 | "colab": {} 93 | }, 94 | "source": [ 95 | "#@title Setup common imports\n", 96 | "import random\n", 97 | "import math\n", 98 | "from collections import OrderedDict\n", 99 | "from tqdm import tqdm\n", 100 | "\n", 101 | "import torch\n", 102 | "from torch.nn import CrossEntropyLoss\n", 103 | "import torch.nn.functional as F\n", 104 | "\n", 105 | "from torchnlp.datasets import smt_dataset\n", 106 | "\n", 107 | "from pytorch_transformers import (\n", 108 | " BertConfig,\n", 109 | " BertTokenizer,\n", 110 | " BertForMaskedLM,\n", 111 | " BertForTokenClassification,\n", 112 | " AdamW,\n", 113 | " WarmupLinearSchedule,\n", 114 | ")\n", 115 | "\n", 116 | "import matplotlib.pyplot as plt\n", 117 | "% matplotlib inline" 118 | ], 119 | "execution_count": 0, 120 | "outputs": [] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": { 125 | "id": "fWT62eJMJYHF", 126 | "colab_type": "text" 127 | }, 128 | "source": [ 129 | "## Examples\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "C3Ju8s5k4dCs", 136 | "colab_type": "text" 137 | }, 138 | "source": [ 139 | "### Conditional BERT Contextual Augmentation\n" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "metadata": { 145 | "id": "84InK9iVi4K4", 146 | "colab_type": "code", 147 | "cellView": "both", 148 | "colab": {} 149 | }, 150 | "source": [ 151 | "#@title setup common functions\n", 152 | "class BatchIterator(object):\n", 153 | " def __init__(self, tokenizer, data, batchsize, device, shuffle=True, repeat=False):\n", 154 | " self.tokenizer = tokenizer\n", 155 | " self.pad_idx = self.tokenizer.convert_tokens_to_ids('[PAD]')\n", 156 | " self._numericalize(data)\n", 157 | 
" self.batchsize = batchsize\n", 158 | " self.device = device\n", 159 | " self.shuffle = shuffle\n", 160 | " self.repeat = repeat\n", 161 | " \n", 162 | " def __len__(self):\n", 163 | " return math.ceil(len(self.data)/self.batchsize)\n", 164 | "\n", 165 | " def __iter__(self):\n", 166 | " while True:\n", 167 | " self._init_batches()\n", 168 | " for batch in self.batches:\n", 169 | " yield batch.to(self.device)\n", 170 | " if not self.repeat:\n", 171 | " return\n", 172 | " \n", 173 | " def _numericalize(self, data):\n", 174 | " self.data = [self.tokenizer.encode(s) for s in data]\n", 175 | "\n", 176 | " def _init_batches(self):\n", 177 | " data = random.sample(self.data, len(self.data)) if self.shuffle else self.data\n", 178 | " self.batches = [self._padding(data[i:i+self.batchsize]) for i in range(0, len(self.data), self.batchsize)]\n", 179 | " \n", 180 | " def _padding(self, batch):\n", 181 | " maxlen = max([len(b) for b in batch])\n", 182 | " return torch.tensor([b + [self.pad_idx for _ in range(maxlen-len(b))] for b in batch]) \n", 183 | " \n", 184 | " \n", 185 | "class Perturbator(object):\n", 186 | " def __init__(self, tokenizer, vocab_range, sampling_rate=0.15, \\\n", 187 | " masking_ratio=0.8, replacing_ratio=0.1, unchanging_ratio=0.1):\n", 188 | " self.mask_idx = tokenizer.mask_token_id\n", 189 | " self.pad_idx = tokenizer.pad_token_id\n", 190 | " self.vocab_range = vocab_range\n", 191 | " self.sampling_rate = sampling_rate\n", 192 | " self.masking_ratio = masking_ratio\n", 193 | " self.replacing_ratio = replacing_ratio\n", 194 | " self.unchanging_ratio = unchanging_ratio\n", 195 | " assert (self.masking_ratio + self.replacing_ratio + self.unchanging_ratio) == 1.0, \\\n", 196 | " '`masking_ratio + replacing_ratio + unchanging_ratio` must be 1.0'\n", 197 | "\n", 198 | " def __call__(self, batch):\n", 199 | " device = batch.device\n", 200 | " bsz, slen = batch.size()\n", 201 | " batch = batch.to(torch.device('cpu'))\n", 202 | " sampler = (torch.rand((bsz, 
slen)).le(self.sampling_rate)) & batch.ne(self.pad_idx) # [PAD] tokens are not sampled\n", 203 | " sampler[:, 0] = 0 # [CLS] tokens are not sampled\n", 204 | " \n", 205 | " masked_lm_labels = torch.where(\n", 206 | " sampler,\n", 207 | " batch,\n", 208 | " torch.ones_like(batch) * -100\n", 209 | " )\n", 210 | "\n", 211 | " rnd = torch.rand((bsz, slen))\n", 212 | " batch = torch.where(\n", 213 | " (self.masking_ratio >= rnd) & sampler,\n", 214 | " torch.ones_like(batch) * self.mask_idx,\n", 215 | " batch\n", 216 | " )\n", 217 | " \n", 218 | " th = self.replacing_ratio + self.masking_ratio\n", 219 | " batch = torch.where(\n", 220 | " ((th >= rnd) & (rnd > self.masking_ratio) & sampler), \n", 221 | " torch.randint_like(batch, self.vocab_range[0], self.vocab_range[1]),\n", 222 | " batch,\n", 223 | " )\n", 224 | " return batch.to(device), masked_lm_labels.to(device)" 225 | ], 226 | "execution_count": 0, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "metadata": { 232 | "id": "VTV29b1g4lpf", 233 | "colab_type": "code", 234 | "cellView": "both", 235 | "outputId": "e77630f4-5d63-4a17-be4f-f80b0d1d5629", 236 | "colab": { 237 | "base_uri": "https://localhost:8080/", 238 | "height": 153 239 | } 240 | }, 241 | "source": [ 242 | "#@title Load training data and preprocessing\n", 243 | "\n", 244 | "train_smt = smt_dataset(train=True)\n", 245 | "\n", 246 | "print('----- DETAILS OF TRAINING DATA -----')\n", 247 | "print('* Data Summary')\n", 248 | "for i in range(3):\n", 249 | " print(train_smt[i]['label'] + '\\t' + train_smt[i]['text'][:100] + '...')\n", 250 | "print('...')\n", 251 | "print(f'* Train data size: {len(train_smt)}')\n", 252 | "\n", 253 | "unique_labels = set([data['label'] for data in train_smt])\n", 254 | "print(f'* class: {unique_labels}')\n", 255 | "\n", 256 | "train_data = ['[' + data['label'].upper() + '] ' + data['text'] for data in train_smt]\n" 257 | ], 258 | "execution_count": 6, 259 | "outputs": [ 260 | { 261 | "output_type": 
"stream", 262 | "text": [ 263 | "----- DETAILS OF TRAINING DATA -----\n", 264 | "* Data Summary\n", 265 | "positive\tThe Rock is destined to be the 21st Century 's new `` Conan '' and that he 's going to make a splash...\n", 266 | "positive\tThe gorgeously elaborate continuation of `` The Lord of the Rings '' trilogy is so huge that a colum...\n", 267 | "positive\tSinger\\/composer Bryan Adams contributes a slew of songs -- a few potential hits , a few more simply...\n", 268 | "...\n", 269 | "* Train data size: 8544\n", 270 | "* class: {'negative', 'positive', 'neutral'}\n" 271 | ], 272 | "name": "stdout" 273 | } 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "metadata": { 279 | "id": "SKdU3O3l5faT", 280 | "colab_type": "code", 281 | "cellView": "form", 282 | "colab": {} 283 | }, 284 | "source": [ 285 | "#@title Setup tokenizer and iterator\n", 286 | "\n", 287 | "# Set tokenizer \n", 288 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", 289 | "added_special_tokens = [ '[' + label.upper() + ']' for label in iter(unique_labels)]\n", 290 | "tokenizer.add_tokens(added_special_tokens)\n", 291 | "\n", 292 | "batch_size = 32#@param {type:\"integer\"}\n", 293 | "assert batch_size > 0, 'Please set `batch_size` value more than zero'\n", 294 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 295 | "train_iter = BatchIterator(tokenizer, train_data, batch_size, device)\n", 296 | "#train_iter = BatchIterator(tokenizer, small_train, batch_size, device)\n" 297 | ], 298 | "execution_count": 0, 299 | "outputs": [] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "metadata": { 304 | "id": "JahA0MmVibbg", 305 | "colab_type": "code", 306 | "cellView": "both", 307 | "outputId": "8bf68a8e-9073-401a-c3dc-4c5aeefa0b80", 308 | "colab": { 309 | "base_uri": "https://localhost:8080/", 310 | "height": 34 311 | } 312 | }, 313 | "source": [ 314 | "#@title Build a model\n", 315 | "config = BertConfig.from_pretrained('bert-base-uncased', 
vocab_size=len(tokenizer))\n", 316 | "model = BertForMaskedLM(config) # NOTE: only the config comes from 'bert-base-uncased'; the weights are freshly initialised, not loaded via from_pretrained\n", 317 | "model.to(device)\n", 318 | "\n", 319 | "# freeze bert\n", 320 | "# for param in model.bert.parameters():\n", 321 | "# param.requires_grad = False\n", 322 | "\n", 323 | "print('Built a model!')" 324 | ], 325 | "execution_count": 8, 326 | "outputs": [ 327 | { 328 | "output_type": "stream", 329 | "text": [ 330 | "Built a model!\n" 331 | ], 332 | "name": "stdout" 333 | } 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "metadata": { 339 | "id": "_HoYYzjNJVfK", 340 | "colab_type": "code", 341 | "cellView": "both", 342 | "outputId": "a3066600-f88a-4e84-c80c-ce5dc421e5f9", 343 | "colab": { 344 | "base_uri": "https://localhost:8080/", 345 | "height": 187 346 | } 347 | }, 348 | "source": [ 349 | "#@title Fine-tuning\n", 350 | "n_epochs = 10 #@param {type:\"integer\"}\n", 351 | "assert n_epochs > 0, 'Please set `n_epochs` value more than zero'\n", 352 | "\n", 353 | "learning_rate = 0.25 #@param\n", 354 | "num_total_steps = n_epochs * len(train_iter)\n", 355 | "num_warmup_steps = len(train_iter)\n", 356 | "\n", 357 | "# optimizer = AdamW(model.parameters(), lr=learning_rate, correct_bias=False)\n", 358 | "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n", 359 | "#scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)\n", 360 | "\n", 361 | "vocab_range = (max(tokenizer.all_special_ids) + 1, min(tokenizer.convert_tokens_to_ids(added_special_tokens))) # low bound is inclusive in torch.randint_like, so +1 keeps every special-token id out of the random replacements\n", 362 | "\n", 363 | "sampling_rate = 0.2 #@param {type:\"slider\", min:0.1, max:1.0, step:0.05}\n", 364 | "perturbator = Perturbator(tokenizer, vocab_range, sampling_rate)\n", 365 | "\n", 366 | "loss_fn = CrossEntropyLoss(ignore_index=-100)\n", 367 | "\n", 368 | "for epoch in range(1, n_epochs+1): \n", 369 | " with tqdm(train_iter, dynamic_ncols=True) as pbar:\n", 370 | " train_loss = 0.0\n", 371 | " for batch in pbar: \n", 372 | " bsz, slen = batch.size()\n", 373 | " srcs,
masked_lm_labels = perturbator(batch)\n", 374 | "\n", 375 | " outputs = model(srcs)[0].view(bsz*slen, -1)\n", 376 | " loss = loss_fn(outputs, masked_lm_labels.view(-1))\n", 377 | " pbar.set_description(f'epoch {str(epoch).zfill(3)}')\n", 378 | " progress_state = OrderedDict(\n", 379 | " loss=loss.item(),\n", 380 | " bsz=bsz,\n", 381 | " )\n", 382 | " pbar.set_postfix(progress_state)\n", 383 | " loss.backward()\n", 384 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n", 385 | "\n", 386 | " optimizer.step()\n", 387 | " # scheduler.step()\n", 388 | " optimizer.zero_grad() # reset gradients so they do not accumulate across batches" 389 | ], 390 | "execution_count": 10, 391 | "outputs": [ 392 | { 393 | "output_type": "stream", 394 | "text": [ 395 | "epoch 001: 100%|██████████| 267/267 [01:46<00:00, 2.64it/s, loss=7.96, bsz=32]\n", 396 | "epoch 002: 100%|██████████| 267/267 [01:46<00:00, 2.38it/s, loss=7.32, bsz=32]\n", 397 | "epoch 003: 100%|██████████| 267/267 [01:47<00:00, 2.44it/s, loss=7.38, bsz=32]\n", 398 | "epoch 004: 100%|██████████| 267/267 [01:47<00:00, 2.68it/s, loss=7.04, bsz=32]\n", 399 | "epoch 005: 100%|██████████| 267/267 [01:47<00:00, 2.60it/s, loss=6.6, bsz=32]\n", 400 | "epoch 006: 100%|██████████| 267/267 [01:46<00:00, 2.57it/s, loss=7.32, bsz=32]\n", 401 | "epoch 007: 100%|██████████| 267/267 [01:47<00:00, 2.56it/s, loss=7.19, bsz=32]\n", 402 | "epoch 008: 100%|██████████| 267/267 [01:47<00:00, 2.18it/s, loss=7.45, bsz=32]\n", 403 | "epoch 009: 100%|██████████| 267/267 [01:47<00:00, 2.54it/s, loss=6.8, bsz=32]\n", 404 | "epoch 010: 100%|██████████| 267/267 [01:46<00:00, 2.32it/s, loss=6.47, bsz=32]\n" 405 | ], 406 | "name": "stderr" 407 | } 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "metadata": { 413 | "id": "Oj3OF1hg_Ouw", 414 | "colab_type": "code", 415 | "cellView": "form", 416 | "outputId": "a00fe629-0c44-4e3b-d41f-d66ab5bd435f", 417 | "colab": { 418 | "base_uri": "https://localhost:8080/", 419 | "height": 34 420 | } 421 | }, 422 | "source": [ 423 | "#@title Select condition\n", 424 | "example =
\"The actors are fantastic .\"\n", 425 | "\n", 426 | "sentiment_label = '[NEGATIVE]' #@param [\"[NEUTRAL]\", \"[NEGATIVE]\", \"[POSITIVE]\"]\n", 427 | "conditional_example = sentiment_label + example\n", 428 | "tokenized_example = tokenizer.tokenize(conditional_example)\n", 429 | "print(tokenized_example)\n" 430 | ], 431 | "execution_count": 11, 432 | "outputs": [ 433 | { 434 | "output_type": "stream", 435 | "text": [ 436 | "['[NEGATIVE]', 'the', 'actors', 'are', 'fantastic', '.']\n" 437 | ], 438 | "name": "stdout" 439 | } 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "metadata": { 445 | "id": "GxcUMyPf_yo3", 446 | "colab_type": "code", 447 | "outputId": "93b7bbea-e8c2-4640-9ec4-299ae1babad3", 448 | "cellView": "form", 449 | "colab": { 450 | "base_uri": "https://localhost:8080/", 451 | "height": 51 452 | } 453 | }, 454 | "source": [ 455 | "#@title Output top10 candidates\n", 456 | "masked_position = 4 #@param {type:\"integer\"}\n", 457 | "tokenized_example[masked_position] = '[MASK]'\n", 458 | "print(f'Input: {tokenized_example}')\n", 459 | "\n", 460 | "input_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_example)]).to(device)\n", 461 | "outputs = model(input_tensor)[0]\n", 462 | "\n", 463 | "topk_score, topk_index = torch.topk(outputs[0, masked_position], 10)\n", 464 | "topk_tokens = tokenizer.convert_ids_to_tokens(topk_index.tolist())\n", 465 | "print(topk_tokens)" 466 | ], 467 | "execution_count": 12, 468 | "outputs": [ 469 | { 470 | "output_type": "stream", 471 | "text": [ 472 | "Input: ['[NEGATIVE]', 'the', 'actors', 'are', '[MASK]', '.']\n", 473 | "[',', '-', 'the', \"'\", 'of', 'a', 's', 'and', 'is', 'to']\n" 474 | ], 475 | "name": "stdout" 476 | } 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "id": "TNWIfSGPFFG8", 483 | "colab_type": "code", 484 | "outputId": "d1875231-7da2-4b4d-9102-41288fd596c5", 485 | "cellView": "form", 486 | "colab": { 487 | "base_uri": "https://localhost:8080/", 488 | 
"height": 276 489 | } 490 | }, 491 | "source": [ 492 | "#@title Visualize output probabilities\n", 493 | "plt.bar(topk_tokens, torch.softmax(topk_score, 0).tolist())\n", 494 | "plt.xticks(rotation=70)\n", 495 | "plt.ylabel('Probability')\n", 496 | "plt.show()" 497 | ], 498 | "execution_count": 13, 499 | "outputs": [ 500 | { 501 | "output_type": "display_data", 502 | "data": { 503 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEDCAYAAAAvNJM9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAU40lEQVR4nO3de7ydVX3n8c+XcBUEEeILSyhBiJdU\nKXYCdMaWXhQIUsFRKNHaItWhzpSpl9E21Xpp7KvSVlsvxUEULMIoUNROWqIUFJ2OFkwQhEKlRkQJ\nY9tIkHEUueU3fzxP7PFkkbOR8+x9OPm8X6/zOs9+Lmf9Aufs737WWnvtVBWSJE23w6QLkCTNTQaE\nJKnJgJAkNRkQkqQmA0KS1LTjpAuYLfvuu28tXrx40mVI0qPKtdde+62qWtg6Nm8CYvHixaxbt27S\nZUjSo0qSrz/UMbuYJElNBoQkqcmAkCQ1GRCSpCYDQpLUZEBIkpoMCElSkwEhSWoyICRJTfPmndSP\n1OKVlw3exm1nHj94G5I0W7yDkCQ1GRCSpCYDQpLUZEBIkpoMCElSkwEhSWoyICRJTQaEJKnJgJAk\nNRkQkqQmA0KS1GRASJKaDAhJUpMBIUlqMiAkSU0GhCSpyYCQJDUNGhBJlie5Jcn6JCsbx1+T5OYk\nNyT5VJIDpxw7NclX+q9Th6xTkrS1wQIiyQLgLOA4YCnwoiRLp512HbCsqg4FLgX+uL/28cCbgSOB\nI4A3J9l7qFolSVsb8g7iCGB9Vd1aVfcBFwEnTj2hqq6qqu/1D68GFvXbxwJXVNWmqroLuAJYPmCt\nkqRphgyI/YHbpzze0O97KC8DPvFwrk1yepJ1SdZt3LjxEZYrSZpqTgxSJ3kJsAz4k4dzXVWdU1XL\nqmrZwoULhylOkrZTQwbEHcABUx4v6vf9kCTPAd4AnFBV9z6cayVJwxkyINYCS5IclGRnYAWweuoJ\nSZ4JvI8uHP51yqHLgWOS7N0PTh/T75MkjcmOQ/3gqnogyRl0T+wLgPOq6qYkq4B1VbWarktpD+Av\nkwB8o6pOqKpNSd5KFzIAq6pq01C1SpK2NlhAAFTVGmDNtH1vmrL9nG1cex5w3nDVSZK2ZU4MUkuS\n5h4DQpLUZEBIkpoMCElSkwEhSWoyICRJTQaEJKnJgJAkNRkQkqQmA0KS1GRASJKaDAhJUpMBIUlq\nMiAkSU0GhCSpyYCQJDUZEJKkJgNCktRkQEiSmgwISVKTASFJajIgJElNBoQkqcmAkCQ1GRCSpCYD\nQpLUZEBIkpoMCElSkwEhSWoyICRJTQaEJKnJgJAkNRkQkqQmA0KS1GRASJKadpx0AYLFKy8bvI3b\nzjx+8DYkzS/eQUiSmgwISVLToAGRZHmSW5KsT7KycfyoJF9M8kCSk6YdezDJ9f3X6iHrlCRtbbAx\niCQLgLOAo4ENwNokq6vq5imnfQN4KfDaxo+4p6oOG6o+SdK2DTlIfQSwvqpuBUhyEXAi8IOAqKrb\n+mObB6xDkvQjGLKLaX/g9imPN/T7RrVrknVJrk7y/
NYJSU7vz1m3cePGR1KrJGmauTxIfWBVLQNe\nDLwzycHTT6iqc6pqWVUtW7hw4fgrlKR5bMiAuAM4YMrjRf2+kVTVHf33W4HPAM+czeIkSds2ZECs\nBZYkOSjJzsAKYKTZSEn2TrJLv70v8CymjF1IkoY3WEBU1QPAGcDlwD8Cl1TVTUlWJTkBIMnhSTYA\nJwPvS3JTf/nTgHVJvgRcBZw5bfaTJGlggy61UVVrgDXT9r1pyvZauq6n6dd9HnjGkLVJkrZtpDuI\nJB9LcnySuTyoLUmaRaM+4b+XbjbRV5KcmeQpA9YkSZoDRgqIqrqyqn4F+CngNuDKJJ9PclqSnYYs\nUJI0GSN3GSXZh25ZjJcD1wHvoguMKwapTJI0USMNUif5OPAU4ALgeVX1zf7QxUnWDVWcJGlyRp3F\n9P5+RtIPJNmlqu7t3+0sSZpnRu1i+oPGvr+fzUIkSXPLNu8gkuxHt8DebkmeCaQ/tCfwmIFrkyRN\n0ExdTMfSDUwvAv50yv7vAK8fqCZJ0hywzYCoqvOB85O8sKo+OqaaJElzwExdTC+pqguBxUleM/14\nVf1p4zJJ0jwwUxfT7v33PYYuRJI0t8zUxfS+/vvvj6ccSdJcMVMX07u3dbyqfmt2y5EkzRUzdTFd\nO5YqJElzziizmCRJ26GZupjeWVWvSvLXQE0/XlUnDFaZxmLxyssGb+O2M48fvA1Js2+mLqYL+u9v\nH7oQSdLcMlMX07X9988m2Rl4Kt2dxC1Vdd8Y6pMkTcioy30fD5wNfJVuPaaDkvxGVX1iyOIkSZMz\n6nLf7wB+oarWAyQ5GLgMMCAkaZ4adbnv72wJh96tdAv2SZLmqZlmMb2g31yXZA1wCd0YxMnA2oFr\nkyRN0ExdTM+bsv0vwM/12xuB3QapSJI0J8w0i+m0cRUiSZpbRp3FtCvwMuAngF237K+qXx+oLknS\nhI06SH0BsB/dJ8x9lu4T5hyklqR5bNSAOKSq3gh8t1+f6XjgyOHKkiRN2qgBcX///dtJng7sBTxh\nmJIkSXPBqG+UOyfJ3sAbgdV0nzD3xsGqkiRN3EgBUVUf6Dc/CzxpuHIkSXPFSF1MSfZJ8p4kX0xy\nbZJ3Jtln6OIkSZMz6hjERcC/Ai8ETgK+BVw8VFGSpMkbdQziiVX11imP/yDJKUMUJEmaG0a9g/jb\nJCuS7NB//TJw+ZCFSZIma6bF+r5DtzhfgFcBF/aHdgD+H/DaQauTJE3MTGsxPXZchUiS5pZRu5hI\nckKSt/dfvzTiNcuT3JJkfZKVjeNH9TOjHkhy0rRjpyb5Sv916qh1SpJmx6jTXM8EXgnc3H+9Msnb\nZrhmAXAWcBywFHhRkqXTTvsG8FLgw9OufTzwZrrlPI4A3ty/UU+SNCajzmJ6LnBYVW0GSHI+cB3w\nu9u45ghgfVXd2l9zEXAiXcAAUFW39cc2T7v2WOCKqtrUH78CWA58ZMR6JUmP0MhdTMDjpmzvNcL5\n+wO3T3m8od83ipGuTXJ6knVJ1m3cuHHEHy1JGsWodxBvA65LchXdjKajgK3GFMatqs4BzgFYtmxZ\nTbgcSZpXZgyIJAH+N/DTwOH97t+pqn+e4dI7gAOmPF7U7xvFHcDPT7v2MyNeK0maBTN2MVVVAWuq\n6ptVtbr/mikcANYCS5IclGRnYAXdSrCjuBw4Jsne/eD0MfjGPEkaq1HHIL6Y5PCZT/s3VfUAcAbd\nE/s/ApdU1U1JViU5ASDJ4Uk2ACcD70tyU3/tJuCtdCGzFli1ZcBakjQeo45BHAm8JMltwHfpxiGq\nqg7d1kVVtQZYM23fm6Zsr6XrPmpdex5w3oj1SZJm2agBceygVUiS5pyZ1mLaFXgFcAhwI3Bu33Uk\nSZrnZhqDOB9YRhcOxwHvGLwiSdKcMFMX09KqegZAknOBLwxfkiRpLpjpDuL+LRt2LUnS9mWmO4if\nTPJ/++0Au/WPt
8xi2nPQ6iRJEzPT50EsGFch2v4sXnnZ4G3cdubxg7chzVcPZ7E+SdJ2xICQJDUZ\nEJKkJgNCktRkQEiSmgwISVKTASFJajIgJElNoy73Lc0rvklPmpkBIY2Z4aRHC7uYJElNBoQkqcmA\nkCQ1GRCSpCYHqaXtiAPkeji8g5AkNRkQkqQmA0KS1GRASJKaDAhJUpOzmCSNhTOoHn0MCEnznuH0\no7GLSZLUZEBIkpoMCElSkwEhSWoyICRJTQaEJKnJaa6SNKBH8xRb7yAkSU2DBkSS5UluSbI+ycrG\n8V2SXNwfvybJ4n7/4iT3JLm+/zp7yDolSVsbrIspyQLgLOBoYAOwNsnqqrp5ymkvA+6qqkOSrAD+\nCDilP/bVqjpsqPokSds25B3EEcD6qrq1qu4DLgJOnHbOicD5/falwLOTZMCaJEkjGjIg9gdun/J4\nQ7+veU5VPQDcDezTHzsoyXVJPpvkZ1sNJDk9ybok6zZu3Di71UvSdm6uDlJ/E/jxqnom8Brgw0n2\nnH5SVZ1TVcuqatnChQvHXqQkzWdDBsQdwAFTHi/q9zXPSbIjsBdwZ1XdW1V3AlTVtcBXgScPWKsk\naZohA2ItsCTJQUl2BlYAq6edsxo4td8+Cfh0VVWShf0gN0meBCwBbh2wVknSNIPNYqqqB5KcAVwO\nLADOq6qbkqwC1lXVauBc4IIk64FNdCECcBSwKsn9wGbgFVW1aahaJUlbG/Sd1FW1Blgzbd+bpmx/\nHzi5cd1HgY8OWZskadvm6iC1JGnCDAhJUpMBIUlqMiAkSU0GhCSpyYCQJDUZEJKkJgNCktRkQEiS\nmgwISVKTASFJajIgJElNBoQkqcmAkCQ1GRCSpCYDQpLUZEBIkpoMCElSkwEhSWoyICRJTQaEJKnJ\ngJAkNRkQkqQmA0KS1GRASJKaDAhJUpMBIUlqMiAkSU0GhCSpyYCQJDUZEJKkJgNCktRkQEiSmgwI\nSVKTASFJajIgJElNBoQkqcmAkCQ1DRoQSZYnuSXJ+iQrG8d3SXJxf/yaJIunHPvdfv8tSY4dsk5J\n0tYGC4gkC4CzgOOApcCLkiyddtrLgLuq6hDgz4A/6q9dCqwAfgJYDry3/3mSpDEZ8g7iCGB9Vd1a\nVfcBFwEnTjvnROD8fvtS4NlJ0u+/qKruraqvAev7nydJGpNU1TA/ODkJWF5VL+8f/ypwZFWdMeWc\nf+jP2dA//ipwJPAW4OqqurDffy7wiaq6dFobpwOn9w+fAtwyyD+mbV/gW2Nsz7Zt27a3n/bH2faB\nVbWwdWDHMRUwiKo6BzhnEm0nWVdVy2zbtm17/rU96fYn/W/fYsgupjuAA6Y8XtTva56TZEdgL+DO\nEa+VJA1oyIBYCyxJclCSnekGnVdPO2c1cGq/fRLw6er6vFYDK/pZTgcBS4AvDFirJGmawbqYquqB\nJGcAlwMLgPOq6qYkq4B1VbUaOBe4IMl6YBNdiNCfdwlwM/AA8JtV9eBQtf6IJtK1Zdu2bdvbRfuT\n/rcDAw5SS5Ie3XwntSSpyYCQZtC/N0fa7hgQUkOSpyXZB6Dshx27JDsYzJNnQOhh207+cN8L/FiS\nI5LsN+lixinJ7kl+KcluE2o/VbV5SzAnWbCd/M7Nub8tA+JRKMlz+6nD4253JxjfK+ok70pyaZKj\nxtHelHYfC9xDN8vvD4F/6feP7Y83yW5JTkvygSQvT3Jokp3HVMMZwNFVdU+SJyd5c5JTxtDuFtcl\n+UySowGq6sGqqnQGec7a8nOT/FiSRZN6op5rd6sGxCPQv0dj3G0+FnhHv77VONt9LvD6JBcmecKY\nmn018CfAfxpzIN4LXAJcARxOt9DkE8f8x/tq4FC6kFoFnA18HHjeGNpeDrw7ySLgrXTLPjxrjHdS\nvwBcBVyaZEOSc5L8ZHU2D9z2fwcO7QNpSZLnj+v3PcmyJCuTvCDJnFh77lG91MY
kJXkccEWSa4E3\nVNX6MTX9U8B7xtFQf6tfSZ4GvIFuxd1TgLuS/DhwIPD5od6j0j8ZXNN/jU0fvn+R5ABgJ+A3gFVJ\nrgY+CVwyhoB+NvArwErgvwI3An9Ft6rAYJLsBXwb+FXg5+nm418C/D3wBOCfB25/p6q6K8lVdM9P\ndwMHA1cm2QS8t6reNdvtVtXmJAuBRVW1JslhdKtRfxnYDfjIbLcJP/Q3tgJ4Md2d6wHA45JcWVUf\nHKLdUXkH8SOqqm/TLWP+eaC50NVArqd7g+E4bLnN/nXgA3S/vF+qqvuBxcAr5uAbGB+RLV0LSfYA\nvgfsA/wP4C+ArwCvp1tQcsga9gFu7dt/BnBNVf0T3ZPVp4Zsu6ruBn4P2Ax8sqo+TPeiZHNV3TBk\n23379/eb7wauqKq3V9V/Bn4R+Drdcjyzakp30pHAhn5h0VfSBcSHgN+a7TYbXgy8p6pOobt7/Eu6\n1SQOGUPbD8k7iEegqu5LMpZX81PavHuMbW25nf8HuleVpwFn9vt+me4Ja77ZAXiQ7lX7U+leADwI\nPBH4WlVN/0yTWVdVd9J1q+0A/C1wYb/awKKqGnzF4n4lg5v7V7Y7A/+O8b0ooW/zc0y5W6qqG5Ns\nBGb9FfWWrsOq+pskzwBeQLd69IeTvBa4erbbnNp2kscA3wCO6Bfpuwv4qySvprubGFfvxFYMiEdo\nDH2ic8Gn6F5FHw5ck2QJ8NNs/fkej3pT7oieBfxeVV0PkOSTwPuTfKGq/teYatmc5M+B/wPsQfdq\ndiymPGnel+SDwP0zXDKbbd+X5ELgvCS/Rte9cyDw9Kq6fTbbSvJzdL/HbwB2r6q3TTm2EPgPdOMw\nQ/oZuhdgB9PdNXwf2A/YoaquGrjtbXKpDY2kn/J4NHAMsDvwzqr60mSrGka6Ty/8beDfA78P3Ng/\naV0PnFZV1020wO1A//9gR+DXgGOBK4HPVdWNA7S1E3AY3cD41cD5VXVBf+zAqvr6bLfZ/+z/SBcE\np9F1426i6166h64L90NVdeUQbY/KgNBD6rs4fgd4PF130tfoPsjpexMtbAz62WK/TdfltB/dTJ7v\n933EmmVJdujvmA6lGyDfk+4Jcw2wtqq+P6Y6ngy8CnghcAPw0qqa9Y8a6Mc9ltJ9zPJRdAH4IWBN\n/2LkXOAts33H9HAZEHpISV4E/CbdL+9edLM57gZuAi6Ya3O2Z1vfN3w43UD1/XRPVIPO4tleTZnN\ncwHdf+u/Bvan+1z6J9INmJ89xnoWAD8L3FBVmwZs56l0A/B30t0tHQxsAHatqp8Zqt1RGRDaSpLX\n0d1urwA+XlWfS7IL3WDlc4BvVtX7J1mj5p/+VfVbgD+uqu8meTxdOCwFvjxE99Jc0t+xPwlYBtwy\nF7oyDQj9kCR70g3YLQH2Br5DN1h7w5RzFsy36a2avCTPBz4GnF1V/2XS9ciA0DRTbvVX0nWv3EPX\nvXQ38HfAx6pq4yRr1PyV5GS6gdqnA58GPlhV/3PL7+Vkq9v+OM1V053c30WcALy8qm7u31X6FOC/\nATszpndya/6b8oLkMXQDw0+mmzkWundyfyTJ8ZOe7rm98p3U+oG+D/RmujcKPR34wyQvAf6pqi6m\nu5vwD1Wzactz0KvoxrgOAi4AXgc8DniO4TA5djFpK1NmVnyLbsrhEuB2YEFV/eIka9P8lORT/Nva\nU39HN1Pub4A/q6qzJlnb9swuJm2lqr5Mv4xGkkvpZlYcDgy+zIO2P421p95eVRuS3EA3xVoT4h2E\npDmh7+J8HXAc3fpDh1bVnFj2entlQEiaM5LsTjcGtgdw07jWvVKbASFJanIWkySpyYCQJDUZEJKk\nJgNCktRkQEiSmgwISVLT/wep0icpcMB2GQAAAABJRU5ErkJggg==\n", 504 | "text/plain": [ 505 | "
" 506 | ] 507 | }, 508 | "metadata": { 509 | "tags": [] 510 | } 511 | } 512 | ] 513 | } 514 | ] 515 | } -------------------------------------------------------------------------------- /unsupervised_extractive_summarization_with_bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "unsupervised_extractive_summarization_with_bert.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "ICvreRPWZHlM", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "# Unsupervised Extractive Summarization with BERT\n", 36 | "This notebook demonstrates EmbedRank, which is an unsupervised keyphrase extraction model [1]. 
Sentence embeddings obtained from Hugging Face's BERT are used to calculate each sentence's importance.\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "zzp224RYZedC", 43 | "colab_type": "text" 44 | }, 45 | "source": [ 46 | "## Settings" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "hHzpRs-jZOof", 53 | "colab_type": "code", 54 | "colab": { 55 | "base_uri": "https://localhost:8080/", 56 | "height": 105 57 | }, 58 | "cellView": "form", 59 | "outputId": "712e9fe1-f4c5-41f3-87da-e3d72e59bddb" 60 | }, 61 | "source": [ 62 | "#@title Setup Environment\n", 63 | "!pip install --quiet googletrans==2.4.0\n", 64 | "!pip install --quiet japanize-matplotlib==1.0.4\n", 65 | "!pip install --quiet pytorch_transformers\n", 66 | "!pip install --quiet mecab-python3\n", 67 | "!pip install --quiet https://github.com/megagonlabs/ginza/releases/download/v1.0.2/ja_ginza_nopn-1.0.2.tgz\n", 68 | "!pip install --quiet https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz\n", 69 | "!ln -s /usr/local/lib/python3.6/dist-packages/ja_ginza_nopn /usr/local/lib/python3.6/dist-packages/spacy/data/ja_ginza_nopn" 70 | ], 71 | "execution_count": 1, 72 | "outputs": [ 73 | { 74 | "output_type": "stream", 75 | "text": [ 76 | "\u001b[K |████████████████████████████████| 122.4MB 312kB/s \n", 77 | "\u001b[?25h Building wheel for ja-ginza-nopn (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 78 | "\u001b[K |████████████████████████████████| 11.1MB 1.7MB/s \n", 79 | "\u001b[?25h Building wheel for en-core-web-sm (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 80 | "ln: failed to create symbolic link '/usr/local/lib/python3.6/dist-packages/spacy/data/ja_ginza_nopn/ja_ginza_nopn': File exists\n" 81 | ], 82 | "name": "stdout" 83 | } 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "ukbf524JZjRg", 90 | "colab_type": "code", 91 | "colab": {}, 92 | "cellView": "form" 93 | }, 94 | "source": [ 95 | "#@title Setup common imports and functions\n", 96 | "from googletrans import Translator\n", 97 | "import torch\n", 98 | "from pytorch_transformers import (\n", 99 | " BertTokenizer,\n", 100 | " BertModel,\n", 101 | ")\n", 102 | "import numpy as np\n", 103 | "import matplotlib.pyplot as plt\n", 104 | "import japanize_matplotlib\n", 105 | "import seaborn as sns\n", 106 | "import spacy\n", 107 | "from IPython.display import HTML\n", 108 | "from sklearn import manifold\n", 109 | "from sklearn.metrics.pairwise import cosine_distances\n", 110 | "%matplotlib inline\n", 111 | "\n", 112 | "\n", 113 | "def ncossim(embs_1, embs_2, axis=0):\n", 114 | " sims = np.inner(embs_1, embs_2)\n", 115 | " std = np.std(sims, axis=axis)\n", 116 | " ex = np.mean((sims-np.min(sims, axis=axis))/np.max(sims, axis=axis), axis=axis)\n", 117 | " return 0.5 + (sims-ex)/std\n", 118 | "\n", 119 | "\n", 120 | "def mmr(doc_emb, cand_embs, key_embs):\n", 121 | " param = 0.5\n", 122 | " scores = param * ncossim(cand_embs, doc_emb, axis=0)\n", 123 | " if key_embs is not None:\n", 124 | " scores -= (1-param) * np.max(ncossim(cand_embs, key_embs), axis=1).reshape(scores.shape[0], -1)\n", 125 | " return scores\n", 126 | "\n", 127 | "\n", 128 | "def embedrank(doc_emb, sent_embs, n_keys):\n", 129 | " assert 0 < n_keys, 'Please `key_size` value set more than 0'\n", 130 | " assert n_keys < len(sent_embs), 'Please `key_size` value set lower than `#sentences`'\n", 131 | " sims = np.inner(doc_emb, sent_embs).reshape(-1)\n", 132 | " return np.argsort(-sims)[:n_keys]\n", 133 | "\n", 134 | "\n", 135 | "def 
embedrankpp(doc_emb, sent_embs, n_keys):\n", 136 | " assert 0 < n_keys, 'Please `key_size` value set more than 0'\n", 137 | " assert n_keys < len(sent_embs), 'Please `key_size` value set lower than `#sentences`'\n", 138 | " cand_idx = list(range(len(sent_embs)))\n", 139 | " key_idx = []\n", 140 | " while len(key_idx) < n_keys:\n", 141 | " cand_embs = sent_embs[cand_idx]\n", 142 | " key_embs = sent_embs[key_idx] if len(key_idx) > 0 else None\n", 143 | " scores = mmr(doc_emb, cand_embs, key_embs)\n", 144 | " key_idx.append(cand_idx[np.argmax(scores)])\n", 145 | " cand_idx.pop(np.argmax(scores))\n", 146 | " return key_idx" 147 | ], 148 | "execution_count": 0, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": { 154 | "id": "1lWmmj1raL4l", 155 | "colab_type": "text" 156 | }, 157 | "source": [ 158 | "## Extractive Summarization" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "metadata": { 164 | "id": "GVZPWVlsdnFn", 165 | "colab_type": "code", 166 | "colab": {} 167 | }, 168 | "source": [ 169 | "# Document from Wikipedia\n", 170 | "doc = \"\"\"\n", 171 | "自動要約(じどうようやく)は、コンピュータプログラムを用いて、文書からその要約を作成する処理である。\n", 172 | "作成される要約は、要約の対象となる文書の最も重要な要素のみを残しているべきであり、いわゆる情報のオーバーロードに伴い自動要約に対する関心も増している。\n", 173 | "首尾一貫した要約を作成するためには要約の長さや書き方のスタイル、文法などといった点が考慮されなければならない。\n", 174 | "自動要約の応用先の1つはGoogleなどの検索エンジンであるが、もちろん独立した1つの要約プログラムといったものもありうる。\n", 175 | "自動要約は、要約の目的や要約の対象とする文書の数、要約の方法などによっていくつかの種類に分類することができる。\n", 176 | "抽出的要約は、要約の対象となる文書に含まれる単語や句、文といった単位をうまく抽出し、それらを組み合わせることで要約を作成する。\n", 177 | "一方、生成的要約は、文書を一度何らかの中間表現(あるいは意味表現)に変換し、この中間表現を元に自然言語生成の技術を用いて要約を作成する。\n", 178 | "そのため、生成的要約によって作成された要約には元の文書に含まれていない表現が含まれることもありうる。\n", 179 | "生成的要約には、文書を中間表現に正確に変換すること(すなわち、精度の高い自然言語理解を実現すること)、そこから要約を生成するための自然言語生成器が必要になるといった問題が存在するため、もっぱら研究の焦点は抽出的要約にあてられている。\n", 180 | "\"\"\"" 181 | ], 182 | "execution_count": 0, 183 | "outputs": [] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "metadata": { 188 | "id": "8nSSOvQZuQZj", 189 | "colab_type": "code", 
190 | "cellView": "form", 191 | "outputId": "e84c58f0-7938-4ae2-edd8-7da05351129a", 192 | "colab": { 193 | "base_uri": "https://localhost:8080/", 194 | "height": 52 195 | } 196 | }, 197 | "source": [ 198 | "#@title Language detection and sentence segmentation\n", 199 | "translator = Translator()\n", 200 | "detected_lang = translator.detect(doc)\n", 201 | "\n", 202 | "assert detected_lang.lang in ['ja', 'en'], 'Please, input Japanese text or English text'\n", 203 | "if detected_lang.lang == 'ja':\n", 204 | " sentence_splitter = spacy.load('ja_ginza_nopn')\n", 205 | "elif detected_lang.lang == 'en':\n", 206 | " sentence_splitter = spacy.load('en_core_web_sm')\n", 207 | "# Join wrapped lines: Japanese has no inter-word spaces, but English words must stay separated\n", 208 | "sents = [str(s) for s in sentence_splitter(doc.replace('\\n', '' if detected_lang.lang == 'ja' else ' ').strip()).sents]\n", 209 | "print(f'Language: {detected_lang.lang}')\n", 210 | "print(f'#sentences: {len(sents)}')" 211 | ], 212 | "execution_count": 4, 213 | "outputs": [ 214 | { 215 | "output_type": "stream", 216 | "text": [ 217 | "Language: ja\n", 218 | "#sentences: 9\n" 219 | ], 220 | "name": "stdout" 221 | } 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "metadata": { 227 | "id": "no3oU6hfaNIG", 228 | "colab_type": "code", 229 | "colab": { 230 | "base_uri": "https://localhost:8080/", 231 | "height": 34 232 | }, 233 | "cellView": "form", 234 | "outputId": "76e6c541-56f8-49c2-b2a7-fbcade63d0a8" 235 | }, 236 | "source": [ 237 | "#@title Build a model\n", 238 | "tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')\n", 239 | "model = BertModel.from_pretrained('bert-base-multilingual-cased')\n", 240 | "rank_fn = embedrankpp" 241 | ], 242 | "execution_count": 5, 243 | "outputs": [ 244 | { 245 | "output_type": "stream", 246 | "text": [ 247 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False.
We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 248 | ], 249 | "name": "stderr" 250 | } 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "metadata": { 256 | "id": "kgRNVcBFcY9O", 257 | "colab_type": "code", 258 | "colab": {}, 259 | "cellView": "form" 260 | }, 261 | "source": [ 262 | "#@title Model run\n", 263 | "# Convert tokens into ids\n", 264 | "encoded_doc = torch.tensor(tokenizer.encode(doc)).unsqueeze(0)\n", 265 | "encoded_sents = [tokenizer.encode(s) for s in sents]\n", 266 | "pad_idx = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)\n", 267 | "maxlen = max([len(s) for s in encoded_sents])\n", 268 | "encoded_sents = torch.tensor([s + [pad_idx for _ in range(maxlen-len(s))] for s in encoded_sents])\n", 269 | "attention_mask = (encoded_sents != pad_idx).long()  # 1 for real tokens, 0 for [PAD]\n", 270 | "if torch.cuda.is_available():  # must be called: the bare function object is always truthy\n", 271 | " model.to('cuda')\n", 272 | " encoded_doc = encoded_doc.to('cuda')\n", 273 | " encoded_sents = encoded_sents.to('cuda')\n", 274 | " attention_mask = attention_mask.to('cuda')\n", 275 | "# Encode: mean-pool token vectors; for sentences, keep [PAD] out of both attention and the average\n", 276 | "doc_emb = torch.mean(model(encoded_doc)[0], dim=1).to('cpu').detach().numpy()\n", 277 | "sent_embs = (torch.sum(model(encoded_sents, attention_mask=attention_mask)[0] * attention_mask.unsqueeze(-1).float(), dim=1) / attention_mask.sum(dim=1, keepdim=True).float()).to('cpu').detach().numpy()\n", 278 | "\n", 279 | "# Ranking\n", 280 | "key_size = 3 #@param {type:\"integer\"}\n", 281 | "keys = rank_fn(doc_emb, sent_embs, key_size)" 282 | ], 283 | "execution_count": 0, 284 | "outputs": [] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "metadata": { 289 | "id": "b5HKVZd8s4ct", 290 | "colab_type": "code", 291 | "cellView": "both", 292 | "outputId": "a8ef6658-a0c7-491f-cf3d-a8535a0b2cb0", 293 | "colab": { 294 | "base_uri": "https://localhost:8080/", 295 | "height": 118 296 | } 297 | }, 298 | "source": [ 299 | "#@title Display\n", 300 | "display_sents = []\n", 301 | "for i, s in enumerate(sents):\n", 302 | " line = '' + s + '' if i in keys else s\n", 303 | " display_sents.append(line)\n", 304 | "HTML(''.join(display_sents))" 305 | ], 306 | "execution_count": 7, 307 | "outputs": [ 308 | { 309 | "output_type":
"execute_result", 310 | "data": { 311 | "text/html": [ 312 | "自動要約(じどうようやく)は、コンピュータプログラムを用いて、文書からその要約を作成する処理である。作成される要約は、要約の対象となる文書の最も重要な要素のみを残しているべきであり、いわゆる情報のオーバーロードに伴い自動要約に対する関心も増している。首尾一貫した要約を作成するためには要約の長さや書き方のスタイル、文法などといった点が考慮されなければならない。自動要約の応用先の1つはGoogleなどの検索エンジンであるが、もちろん独立した1つの要約プログラムといったものもありうる。自動要約は、要約の目的や要約の対象とする文書の数、要約の方法などによっていくつかの種類に分類することができる。抽出的要約は、要約の対象となる文書に含まれる単語や句、文といった単位をうまく抽出し、それらを組み合わせることで要約を作成する。一方、生成的要約は、文書を一度何らかの中間表現(あるいは意味表現)に変換し、この中間表現を元に自然言語生成の技術を用いて要約を作成する。そのため、生成的要約によって作成された要約には元の文書に含まれていない表現が含まれることもありうる。生成的要約には、文書を中間表現に正確に変換すること(すなわち、精度の高い自然言語理解を実現すること)、そこから要約を生成するための自然言語生成器が必要になるといった問題が存在するため、もっぱら研究の焦点は抽出的要約にあてられている。" 313 | ], 314 | "text/plain": [ 315 | "" 316 | ] 317 | }, 318 | "metadata": { 319 | "tags": [] 320 | }, 321 | "execution_count": 7 322 | } 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "metadata": { 328 | "id": "47H6fAx6qdRt", 329 | "colab_type": "code", 330 | "colab": { 331 | "base_uri": "https://localhost:8080/", 332 | "height": 462 333 | }, 334 | "cellView": "form", 335 | "outputId": "43b8e35b-cfab-466d-e41b-543da15296fb" 336 | }, 337 | "source": [ 338 | "#@title Visualize sentence embeddings\n", 339 | "# print sentences\n", 340 | "print('id' + '\\t' + 'sentence')\n", 341 | "for i, sent in enumerate(sents, 1):\n", 342 | " if len(sent) <= 50:  # only truncate (and append '...') when text is actually cut off\n", 343 | " print(str(i) + '\\t' + sent)\n", 344 | " else:\n", 345 | " print(str(i) + '\\t' + sent[:50] + '...')\n", 346 | "print('')\n", 347 | "\n", 348 | "# MDS\n", 349 | "mds = manifold.MDS(n_components=2, dissimilarity=\"precomputed\", random_state=0)  # fixed seed: reproducible layout\n", 350 | "embs = np.concatenate((doc_emb, sent_embs), 0)\n", 351 | "dist_matrix = cosine_distances(embs, embs)\n", 352 | "pns = mds.fit_transform(dist_matrix)\n", 353 | "fixed_pns = pns - pns[0]\n", 354 | "# plot\n", 355 | "keys_idx = [idx + 1 for idx in keys]\n", 356 | "other_idx = [idx for idx in range(1, len(sents)+1) if idx not in keys_idx]\n", 357 | "plt.scatter(fixed_pns[0,0], fixed_pns[0,1], color='green', marker='*', s=150,
label='document')\n", 358 | "plt.scatter(fixed_pns[keys_idx,0], fixed_pns[keys_idx, 1], color='blue', label='key sentences')\n", 359 | "plt.scatter(fixed_pns[other_idx,0], fixed_pns[other_idx, 1], color='white', edgecolors='black', label='other sentences')\n", 360 | "\n", 361 | "for i,(x,y) in enumerate(fixed_pns[1:], 1):\n", 362 | " plt.annotate(str(i), (x+0.002 ,y+0.002))\n", 363 | " \n", 364 | "plt.legend()\n", 365 | "plt.show()\n" 366 | ], 367 | "execution_count": 8, 368 | "outputs": [ 369 | { 370 | "output_type": "stream", 371 | "text": [ 372 | "id\tsentence\n", 373 | "1\t自動要約(じどうようやく)は、コンピュータプログラムを用いて、文書からその要約を作成する処理である。...\n", 374 | "2\t作成される要約は、要約の対象となる文書の最も重要な要素のみを残しているべきであり、いわゆる情報のオー...\n", 375 | "3\t首尾一貫した要約を作成するためには要約の長さや書き方のスタイル、文法などといった点が考慮されなければ...\n", 376 | "4\t自動要約の応用先の1つはGoogleなどの検索エンジンであるが、もちろん独立した1つの要約プログラム...\n", 377 | "5\t自動要約は、要約の目的や要約の対象とする文書の数、要約の方法などによっていくつかの種類に分類すること...\n", 378 | "6\t抽出的要約は、要約の対象となる文書に含まれる単語や句、文といった単位をうまく抽出し、それらを組み合わ...\n", 379 | "7\t一方、生成的要約は、文書を一度何らかの中間表現(あるいは意味表現)に変換し、この中間表現を元に自然言...\n", 380 | "8\tそのため、生成的要約によって作成された要約には元の文書に含まれていない表現が含まれることもありうる。...\n", 381 | "9\t生成的要約には、文書を中間表現に正確に変換すること(すなわち、精度の高い自然言語理解を実現すること)...\n", 382 | "\n" 383 | ], 384 | "name": "stdout" 385 | }, 386 | { 387 | "output_type": "display_data", 388 | "data": { 389 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl4VeW59/HvTQiTUKkhCAUJHBUH\nkFINKAYIOBxxQBwqlVoEjKa1l4gKeuTN5QSirRWO74tHARmliFppiVUErSQMgkCgqUOFVqtgKngC\nZRBCBsj9/pGdNITA3oEkO8n6fa5rXzv7WdO9V5JfVp691rPM3RERkWBoFO0CRESk9ij0RUQCRKEv\nIhIgCn0RkQBR6IuIBIhCX0QkQBT6IiIBotAXEQkQhb6ISIA0jnYBFbVp08Y7d+4c7TJEROqVjRs3\n7nT3+HDz1bnQ79y5M1lZWdEuQ0SkXjGzrZHMp+4dEZEAUeiLiASIQl9EJEDqXJ9+ZYqKisjJySE/\nPz/apUgVNGvWjI4dOxIbGxvtUkTqhaeeeorFixcTGxvLD37wA2bPnk2rVq2qdRv1IvRzcnJo1aoV\nnTt3xsyiXY5EwN3ZtWsXOTk5dOnSJdrliNR5H3/8Menp6axdu5aYmBjuv/9+pk2bxoMPPlit26kX\n3Tv5+fnExcVVOfALDxfWUEUSjpkRFxen/85EItSmTRuaNm3KoUOHADh8+DA9e/as9u3UiyN9oMqB\nn3sgl3OeP4ct92wh/pSwp65KDdB/ZSKRa9++Pffccw+//OUvOeuss/j+97/PFVdcUe3bqRdH+idi\n8ebF7M7fTfqW9GiXIiISVkZGBitXrmTWrFmMHz+ebt268dhjj1X7dhps6M/JnnPEc3V6/PHHmTZt\nWrWvt7r84x//4Jtvvol2GSJyHAsXLqR79+7ExMTQvXt35s+fT0FBQdn0wsJC/v73v1f7dutN905V\n7D64m43bNwKQ9U0We/L30LpZ6yhXVXsmTJjAyJEj+cEPfhDtUkSkEgsXLiQtLY1Zs2bRt29fVq9e\nzahRo+jUqRO9e/cmNjaW5s2bM3PmzGrfdoMI/dQ/pvLGX9/AcQAOFR8itlEshYcLiW0Uyxn/fQaN\nG5W8VcO45fxbmD54epW2MXbsWFavXl12CmK7du3Iyspi7NixmBmtWrXipZdeol27dqxbt44HHniA\n4uJizjjjDObOncvrr7/O5s2b+dWvfgVAx44dycnJITMzkylTptC4cWO2bNnChAkTmDFjBtu3b2fq\n1KkkJyezY8cOUlJS2L9/P9/73veYO3cucXFxnHvuufz0pz8lMzOTvXv3kp6eztdff83SpUvJzs7m\n1ltv5eGHH67enS0iJ23SpEnMmjWLgQMHAjBw4EDmzJnD6NGj+eSTT2p02w2ie2fcpeNoe0pb8ory\n2JO/h/2F+zlQdACAA0UH2F+4nz35e8gryqPtKW0Ze+nYKq1/yZIlfP7553z44Ye8/vrrfPfddwD8\n7Gc/Y+bMmWRmZnL77bdz3333ATB8+HDmzp3L2rVr+dnPfkZubu5x1/+Pf/yD1157jccff5y0tDTe\neustXnzxRaZMmVLy/saN4yc/+QkrVqzgzjvvZOLEiQAUFBRwwQUXsHz5cm644QZ+97vf0adPHwYN\nGsRzzz2nwBepoz777DP69u17RFvfvn357LPPanzbDSL0u8Z15S+/+At3XXgXLWJbVDpP88bNSb0w\nlY/u/oiucV2rtP5PP/2U5ORkzIyYmBh69erFzp07iY2N5eyzzwbgmmuuISsri507d9K0adOy9uuv\nv56EhITjrr979+7ExsYSHx9Pz549iY2N5fTTT2fv3r0AZGdn89JLLzFgwAAmT55MTk4OUHIu/NVX\nXw2UfPJfOr+I1G3nnXceq1evPqJt9erVnHfeeTW+7QYR+gBNG
zfl+Wue5+Gkh2nZpOUR01rGtmR8\n3/FMvWYqTWKaVHndPXr04P3336e4uJj8/HwyMjKIi4ujoKCArVtLBrZbunQpPXv2pE2bNhQWFrJl\nyxYAVq1axWeffcapp57Kt99+C8D69evZvn17lbb/6KOPkpmZyXvvvce4ceOOO7+ZUVioaxRE6qq0\ntDRSUlLIyMigqKiIjIwMUlJSSEtLq/FtN4g+/fLW5Kxhf+F+ABpbYw75IfYX7efDnA9PeJ1XXXUV\nmZmZ9O7dm9NOO41u3bphZvz2t7/l9ttvp1GjRpxyyinMmDEDgPnz5zNq1CjMjDZt2jBnzhwSEhJ4\n4YUXSE5O5kc/+hFnnnlmxNufMmUKqampPPnkkxQXF4c9jat///6MGTOGu+++m3vvvfeE37eI1Ixh\nw4YBMHr0aD777DPOO+88Jk2aVNZek8zda3wjVZGYmOgVx9Mv3Snh5BXlcdqvT6PgcAHNGzdn8DmD\n+eOWP3Lw0EGaxjRl93/tpnls85oqXSoR6fdOpC575plnWLx4MQcPHuRHP/oR06ZNo0mTqvca1CQz\n2+juieHmazDdOwDLPl9GweEC2rdsz4qRK3jtx6+xYuQK2rdsT8HhApZ9sSzaJYpIPbNz50727t3L\nBx98wJ///Gfy8vJIT6+/F302qNDftncbQ7sNZcs9W+jVoRcAvTr0YvM9mxnabShb90R0YxkRkTJt\n2rRh0qRJmBn79+9n3759dO/ePdplnbCI+vTNbCgwDogBMt19bLlpjYDfAJcArYC33X18aNoPgf8H\nNAVygdvdfXe1voNyxlwyhjGMOar9e02/x2s/fq2mNisidczIkSPZvHkzzZo1A+CBBx7g+uuvP6l1\n3nbbbbz77rs89NBDnHvuudVRZlSEDX0zSwAmAr2BfcCrZnazuy8KzXI28I27J5lZDLDSzHoBWcCr\nwDB3zzazXwITgNE18UZEREpt27aNzMzMstCvDgsWLCAvL4/hw4czb948Ro4cWW3rrk2RdO8MAha5\n+14v+dR3OnBD6UR33+Luk0MvTwMOA18BXYHd7p4dmjYTuLa6ChcROZY9e/bwi1/8gv79+3PPPfeQ\nl5cX8bIVx8R5+umnmTdvHgAtWrSga9eu7Nmzp6ZKr3GRhH4csKPc6+1A24ozmVkm8Akw091zKy7n\n7oUc4z8LM0s1sywzywp39aqISDiJiYlMnDiRlStXEh8fX3YVezilY+JMnTqV/Px8pk6dyvTp03n5\n5ZdJTEykX79+fPXVV9x11101/A5qkLsf9wGkAE+Wez0QePkY834fWAMMAM4EVpeb1hT4PNz2Lrro\nIq/or3/961FttenLL7/0iy++OKo1lDp06JCvWbMm2mVELNrfO5FPP/3UL7vssojm7datmy9fvvyI\ntuXLl3u3bt1qorRqBWR5mHx194iO9JcAN5pZ6Y0a7wDKzlcysyvM7LrQH5DdwFagtbt/AbQ0s9KP\nuYcD75zoH6eqWLAAOneGRo1KnhcsqI2t1o6vv/6a8ePHR7sMkTqhst/1gwcP8sgjj5Rdlf7OO+9w\n4YUXRrS+aI6JU1vChr67bweeouQD2nXAt+6+yMwyzawdkA0MN7P1ZrYW2AW8GVp8JPCSmX0AXA88\nWhNvorwFCyA1FbZuBfeS59TU6gv+e++9l2effRaAN998k4svvpikpCSefvpp3J0LLriAf/7zn6Fa\nFhx1f8vJkydzySWXMGDAADZt2gTAzJkz6d27N3369CnrO5w7dy6jRo1i8ODB/PCHP+Tpp58G4LHH\nHiM7O5sBAwawY8cO1q5dS1JSEv369WP06JLPyL/66iv69+/PyJEjufTSSxkyZAjFxcUAzJkzh169\nepGYmMjjjz8OwI4dO7j22mtJTk5m8ODB7Nq1C3fnpz/9Kf369WPIkCH861//qp4dKFJNjvW7/vvf\nN6dNmzb07t2b5ORkNm7cy
KOPRhY90RwTp9ZE8u9AbT5OtnsnIcG95EfgyEdCQsSrOEpp985TTz3l\nv/71r93dfffu3Z6QkOC7du1yd/cbb7zRN23a5C+++KI/+eST7u5+1VVX+d///vcj1pWUlOQ7d+70\n3bt3+969e33z5s3evXt3P3jwoBcVFXlSUpJv377d58yZ44mJiV5QUOAHDx709u3bl9WSnJxctr7/\n+I//KNvGfffd53/4wx/8yy+/9JYtW/rWrVvd3f2yyy7zTZs2+ZYtW7xHjx6el5fn7u7Tpk3zgoIC\nv+2223zevHnu7r548WIfM2aM79692/v16+dFRUW+bds2Ly4uPqF9p+4dqSk18bv+yiuveJcuXXz5\n8uVeWFjoy5cv9y5duvgrr7xSXWXXGCLs3mlwY+9s21a19kh9+umnHDhwoGx8nc8//5wDBw5w0003\nAbBv3z62bNnC7bffzoABAxg5ciSNGjXirLPOOmI98+fP51e/+hXFxcU8/PDDfPzxx+zZs4dBgwaV\nrefzzz8H4PLLLy+71LtRo6P/Kdu5cyf/+7//y5133glAXl4eHTp0oGfPnnTr1o1OnToB/x6B84sv\nvqB///40b14yFMXPf/5zoGQUz61btzJ79myKi4tp27YtrVu3ZuLEidx333106NCBhx56iJiYmJPb\niSLVqCZ+16M5Jk5taXCh36lTyb95lbWfjPPOO4+33nqLa6+9lqVLl3LmmWdyxhln8Pbbb3PKKaew\nefNmWrduTYsWLbj88su55557SE1NPWo9hw4d4je/+Q0rVqzg6aefJjU1la5du7Js2TIaN27Mpk2b\nOOuss8qCv6LyI2jGxcXRpUsXXnvtNU4//XRycnLIz88/5nvo0aMHTz75JHl5ebRo0YKFCxdy7bXX\n0qNHD0aNGsWVV15JQUEBf/7znzl8+DCdOnXi+eef54knnuDtt98+6YtbRKpTTf2uDxs2rEGFfEUN\nahgGgEmToEWFIfVbtChpPxmNGjWibdu2PP7444wYMYLWrVszYcIErrzySvr168cjjzxC06ZNgZIj\n6KysLAYPHnzUembPns2AAQMYO3YsgwcPLrv7Vd++fenXrx//8z//U7aeyrRv354DBw5w+eWXs2fP\nHqZPn84tt9xC//79SU1NJTY29pjLdu3alTFjxtCvXz8uueQS1q9fT6tWrZgyZUrZXbquuOIK9u/f\nz3fffcf48ePp378/f/rTn+jVq9fJ7UCRalbV3/U33niDoUOHlv0HHFiR9AHV5qM6Ttn87W9L+vXM\nSp5/+9sqLX7SXn75ZX/iiSdqd6N1lPr0pSZV5Xc9MzPTc3Nz/fTTT6+t8moVQe3TB7jttpJHNMyZ\nM4e5c+fW61H4ROqLqvyuJycn12wx9USD696JtlGjRrFixQpat24d7VJERI6i0BeRBqkhX6R5Mhpk\n946IBFvphVul46yVXrglOtIXkQYoLe3fgV8qL6+kPeh0pC8iDc7xLtwqLt5R+cSA0JH+SVq1alXZ\n1yNHjmTp0qVRrObfytclEjTHOhU/6KfoQwMN/Yo3QVi4cGGNbWv48OE1tu6TUVfrEqkNNXWRZkPQ\n4EK/spsgpKWlnXTwP/PMM1xyySX06dOHSaGfnMcee4wdO3YwYMAAsrNLbhD2/vvvM2TIEM4//3ze\nffddoPJRLAG6devG9OnTSUlJOWJbGRkZ9O7dm/79+zN37lyAKo2m+eKLL5bVtXTp0mNu/9xzz2XC\nhAlcdtllXHTRReTk5ACwbt06kpKS6NOnD0OHDiUvL4/i4mJ++ctfkpSURP/+/dm4cSNQ+aihItF2\n220wYwYkJIBZyfOMGSd2/c7EiRMZMGBAtdcYNZFcwVWbj5O9IrcmboLw/vvv+5VXXumHDh3yw4cP\n+3XXXefvvPOOu7snlBvSb8SIET569Gh3d1+5cqVff/317u6VjmLp7n7OOeeUtZc3btw4f+u
tt/zw\n4cP+9ddfu3vVRtOsWNextt+5c2f//e9/7+7uEyZM8ClTpri7+9lnn+1/+9vf3N09PT3dv/rqK3/p\npZd81KhR7u7+z3/+0/v06ePuR48aWpGuyJX6bMOGDT5q1KgjRratqwjqFbk1cROETZs2cdVVV5WN\nMjlo0CCysrLKRsYs75prrgH+PbIlVD6KJUB+fj433HDDUet47LHHeO6551iyZAl33nknzZo1q9Jo\nmhUda/vuztVXX122bE5ODjt37qRp06acffbZAGWDrGVnZ7Nhw4ayI55du3ZRWFh41KihIg3FwYMH\nuf/++1m0aBFDhw6NdjnVpsGFfulNEAYOHFjWdrI3QejZsyfPPfccDzzwAADvvvtu2T0yi4qKwi5f\n2SiWpUqHTi4vNzeX8ePHU1hYyH/+53+ycuXKKo2mCf8ejbNJkybH3X5Fbdq0obCwkC1btnDOOeew\natUq2rRpQ48ePTj11FPLurZWrFhBkyZNjho1dMqUKWH3hzRc27Zt495772Xfvn3ExMQwefJkevTo\nEe2yTsiDDz7ImDFjyg6SGooG16eflpZGSkoKGRkZFBUVkZGRQUpKCmkncYLuFVdcQVJSUtnjwgsv\n5LrrrgPg/PPPp1+/fmzevPmYy1c2iuXxbNiwgYEDB5KcnMyQIUMwsyqNpgklY/H37duXFStWVHn7\n8+fPZ9SoUSQlJfHss89y+umnk5KSwr59+8r2QVZWFnD0qKESbHfffTfPPPMMy5cv55VXXqFDhw7R\nLumELFu2jN27d/PjH/842qVUOyvpCqo7EhMTvTRQSpXezCBSCxcuZNKkSWXLpaWlNejxseuyqn7v\npP7asWMHw4YN46KLLmLdunVccMEFTJ48ueymPXVZxczo0KED7k6L0ClAq1ev5pprruHll1+OcqXH\nZmYb3T0x7IyRdPzX5qM6hlaWukPfu+BYt26dn3rqqf6Xv/zF3d3T0tL8kUceiXJV4UVyi8SG9EFu\ng+veEZHoaN26NT169Cjrw//JT35SdmpvXTZp0iRmzZrFwIEDiY2NZeDAgcyaNavs8yuAzMzM6BVY\nzSIKfTMbambrzWyjmU2uZPpoM/vQzNaa2Qtm1ijU/riZZZtZZuhxwkMeeR3rhpLw9D1r2CpeBLlh\nwwby8vL44osvgJJ+8Z49e0a5yvBq4oy/uizs2TtmlgBMBHoD+4BXzexmd18Umt4NGAwkufthM/sd\ncB3wJtAFGOrufzuZIps1a8auXbuIi4vDzE5mVVJL3J1du3bRrFmzaJciNaD0IshZs2bRt29fVq9e\nTUpKCj//+c+56667KCoqol27dsyaNSvapYZVE2f81WWRnLI5CFjk7nsBzGw6MApYBODun5rZ9e5+\nuNw6D4a+7gTcb2bnA9uA+919Z1WL7NixIzk5OeTm5lZ1UYmiZs2a0bFjx2iXITWgfJcIUNYlMnr0\naD755JMoV1c1pWf8VfwDNqmBjtkQSejHAeWHpdsOHHHiqrvnm1lr4AUg293fC03aAMx394/NbAQw\nFajyaTSxsbF06dKlqouJSA1pSF0ipWf2jR49uuzsnUmTJjXYM/4iCf1vKemmKdUu1FbGzLoDk4FH\n3X1dabu7P1Rutt8Bj1a2gVBffyqgO9WL1AMNrUtk2LBhDTbkK4rkg9wlwI1m1ir0+g6g7K7fZhYP\nPEdJ3/26cu1mZhPN7NRQ09VApSNyufsMd09098T4+PgTeR8iUotq4iJIqR1hj/TdfbuZPQWsNLNC\nYJW7LzKzTOBW4MeU/CeQXu5D1lfcfYaZfQJkmNl+YC9wV028CRGpXUHrEmlI6sUVuSIicnyRXpGr\ni7NERAJEoS8iEiAKfRE5Ya+//jp9+vShX79+ZXdZk7pNoS8iJ+Rf//pX2TDKq1atIiEhgZkzZ0a7\nLAmjwd1ERURqx2mnncbq1avLhto4dOhQvRhGOeh0pC8
iJ6xZs2bk5+czZswYDh48yB133BHtkiQM\nhb6InLCcnBxuvPFGBg0axLRp08ruIy11l0JfRCJScSjlefPmMXLkSGbMmMHVV18d7fIkQurTF5Gw\nKhtKediwYRQWFjJ8+PCy+S677DIefbTSIbakjtAVuSISVvfu3Zk6deoRA6xlZGTUy6GUG6pIr8hV\n6ItIWDExMeTn5xMbG1vWVlRURLNmzTh8+PBxlpTaomEYRBqYFStWMGDAgLLHmWeeyX333Vcr2y4d\nSrm8+jyUcpCpT1+knkhOTi67QXdxcTHJyck8+OCDtbLtoN1dqiFT6IvUQ/PmzeOKK66gQ4cOtbI9\nDaXccKhPX6SeOXToEImJiWRmZtK6detolyN1hPr0RRqoN954g6SkJAW+nBCFvkgdVfFiqIULFwIw\nffp0RowYEeXqpL5Sn75IHVTZxVApKSns3buXzZs306tXr2iXKPWU+vRF6iBdDCVVpYuzROoxXQwl\nVaUPckXqMV0MJTUlotA3s6Fmtt7MNprZ5EqmjzazD81srZm9YGaNQu0DQ23rzWy+mTWp7jcg0hCV\nXgyVkZFBUVERGRkZpKSkkJaWFu3SpJ4L+0GumSUAE4HewD7gVTO72d0XhaZ3AwYDSe5+2Mx+B1xn\nZsuBOUBfd88xs2eA0cBRfzRE5Ei6GEpqStg+fTP7OZDg7v8n9PoyYJS7Dy83TzN3zw99/QfgBUr+\nixjh7j8Ntf8HMM/d+x1ve+rTFxGpuurs048DdpR7vR1oW34Gd883s9Zm9gqQ7e7vRbJcuWJTzSzL\nzLJyc3MjKElERE5EJKH/LUeGdbtQWxkz6w68Bvxfd38i0uVKufsMd09098T4+PhIaxcRkSqKJPSX\nADeaWavQ6zuA9NKJZhYPPAcMdfd15Zb7ALjYzNqHXqeUX05ERGpf2A9y3X27mT0FrDSzQmCVuy8y\ns0zgVuDHQBcg3cxKF3vF3WeY2d3AW2ZWAHwOTKiJNyEiIpHRxVkiIg2ALs4SEZGjKPRFRAJEoS8i\nEiAKfRGRAFHoi4gEiEJfRCRAFPoiIgGi0BcRCRCFvohIgCj0RUQCRKEvIhIgCn0RkQBR6IuIBIhC\nX0QkQBT6IiIBotAXEQkQhb6ISIAo9EVEAkShLyISIAp9EZEAUeiLiARIRKFvZkPNbL2ZbTSzyZVM\nv9PMlpjZBxXaR5rZZjPLDD0era7CRUSk6sKGvpklABOBK4FEoKOZ3Vxhtq3Aw0BMhfYuwL3uPiD0\nmFANNYuIyAmK5Eh/ELDI3fe6uwPTgRvKz+Du7wH7Klm2M3Br6Cj/D2bW5WQLFhGRE9c4gnnigB3l\nXm8H2ka4/r8C69w908wGAAuASyvOZGapQCpAp06dIly1iIhUVSRH+t9yZMi3C7WF5e6/dvfM0NeZ\nQGczs0rmm+Huie6eGB8fH8mqRUTkBEQS+kuAG82sVej1HUB6JCs3s/8yszNCXycCX4e6iEREJArC\ndu+4+3YzewpYaWaFwCp3X2RmmcCt7r7jOItvABaZWQFQCAyvjqJFROTEWF078E5MTPSsrKxolyEi\nUq+Y2UZ3Tww3ny7OEhEJEIW+iEiAKPRFRAJEoS8iEiAKfRGRAFHoi4gEiEJfRCRAFPoiIgGi0BcR\nCRCFvohIgCj0RUQCRKEvIhIgCn0RkQBR6IuIBIhCX0QkQBT6IiIBotAXEQkQhb6ISIAo9EVEAkSh\nLyISIBGFvpkNNbP1ZrbRzCZXMv1OM1tiZh9UaP+hma0wsw/N7I9m9v3qKlxERKoubOibWQIwEbgS\nSAQ6mtnNFWbbCjwMxJRbzoBXgTHufgnwDjChmuoWEZETEMmR/iBgkbvvdXcHpgM3lJ/B3d8D9lVY\nriuw292zQ69nAteeZL0iInISIgn9OGBHudfbgbZVXc7dC4HGVapORESqVSSh/y1Hhny7UFuVljOz\npkBhZTOaWaqZZZl
ZVm5ubgSrFhGRExFJ6C8BbjSzVqHXdwDp4RZy9y+AlmbWPdQ0nJJ+/crmneHu\nie6eGB8fH0FJIiJyIsJ2t7j7djN7ClhpZoXAKndfZGaZwK3uvuM4i48EXjKzYmAXMKIaahYRkRNk\nJZ/N1h2JiYmelZUV7TJEROoVM9vo7onh5tPFWSIiAaLQFxEJEIW+iEiAKPRFRAJEoS8iEiAKfRGR\nAFHoi4gEiEJfRCRAFPoiIgGi0BcRCZAGG/ovvvgivXv3JjExkQkTdO8WERFooKG/ZcsWZs+ezapV\nq1i3bh3r169n+fLl0S5LRCTqGmTof/TRRyQlJdG0aVNiYmK46aabWLZsWbTLEhGJugYZ+j169GDF\nihXs3buXgoICFi1axHfffRftskREoq5BhP6CBdC5MzRqVPKclXUOY8eO5eqrr+bmm2+mV69edOrU\nKdpliohEXb0P/QULIDUVtm4F95Lnu+7K55tvLmTNmjWkp6eTnZ3NLbfcEu1SRUSirt7fqDwtDfLy\njmw7ePAwTzzxJIsXf0WzZs24++67OfPMM6NToIhIHVLvQ3/btspaT+HgwVdYs6a2qxERqdvqfffO\nsbrq1YUvInK0eh/6kyZBixZHtrVoUdIuIiJHqvehf9ttMGMGJCSAWcnzjBkl7SIicqSIQt/MhprZ\nejPbaGaTK5l+b2h6tpmNK9f+eKgtM/RIrc7iS912G3z1FRQXlzwr8EVEKhf2g1wzSwAmAr2BfcCr\nZnazuy8KTU8ChgF9Q4ssN7NMd88CugBD3f1vNVK9iIhUSSRH+oOARe6+190dmA7cUG76dcAcdy90\n90JgNjAkNK0TcL+ZrTCz+WbWpjqLFxGRqokk9OOAHeVebwfaRjh9A/CCuycDfwKmVrYBM0s1sywz\ny8rNzY20dhERqaJIQv9bjgz5dqG2sNPd/SF3/zjU/jtKuoiO4u4z3D3R3RPj4+MjrV1ERKooktBf\nAtxoZq1Cr+8A0stNTwduN7NYM4sBRgBvWomJZnZqaL6rgU3VVbiIiFRd2A9y3X27mT0FrDSzQmCV\nuy8ys0zgVnfPMrM3gfXAIeDV0Ie4mNknQIaZ7Qf2AnfV1BsREZHwrOSz2bojMTHRs7Kyol2GiEi9\nYmYb3T0x3Hz1/uIsERGJnEJfRCRAFPoiIgGi0BcRCRCFvohIgCj0RUQCRKEvIhIgCn0RkQBR6IuI\nBIhCX0QkQBT6IiIBotAXEQkQhb6ISIAo9EVEAkShLyISIAp9EZEAUeiLiASIQr8eKDxcGO0SRKSB\nUOjXcbkHcmn3bDtyD+RGuxQRaQAU+nXc4s2L2Z2/m/Qt6dEuRUQaAIV+HTcne84RzyIiJyOi0Dez\noWa23sw2mtnkSqbfG5qebWZ6qoZLAAAFeElEQVTjyrUPNLO1oWnzzaxJdRbf0O0+uJuN2zcCkPVN\nFnvy90S5IhGp7xqHm8HMEoCJQG9gH/Cqmd3s7otC05OAYUDf0CLLzSwT2AzMAfq6e46ZPQOMBo76\noyElUv+Yyht/fQPHAThUfIjYRrEUHi4ktlEsZ/z3GTRuVPItM4xbzr+F6YOnR7NkEalnIjnSHwQs\ncve97u7AdOCGctOvA+a4e6G7FwKzgSFAErDG3XNC802rsJxUMO7ScbQ9pS15RXnsyd/D/sL9HCg6\nAMCBogPsL9zPnvw95BXl0faUtoy9dGyUKxaR+iaS0I8DdpR7vR1oG8H0cMuVMbNUM8sys6zc3OCe\npdI1rit/+cVfuOvCu2gR26LSeZo3bk7qhal8dPdHdI3rWssVikh9F0nof8uRYd0u1BZuerjlyrj7\nDHdPdPfE+Pj4SOpusJo2bsrz1zzPw0kP07JJyyOmtYxtyfi+45l6zVSaxOjjERGpukhCfwlwo5m1\nCr2+Ayh//mA6cLuZxZpZDDACeBP4ALjYzNqH5kupsJwcx5qcNewv3A9AYyvpx99ft
J8Pcz6MZlki\nUs+FDX133w48Baw0s3XAt+6+yMwyzaydu2dREvLrgQ+BP7p7lrvnA3cDb5nZGqATMLXG3kkDkleU\nR8aXGUBJd85N599E88bNAXj/y/c5WHQwmuWJSD0W9uwdAHdfACyo0Dag3NfPAs9WstyfgItOrsTg\nWfb5MgoOF9C+ZXvSb02nV4debPjnBoa8OoTt+7ez7Itl3HCuPhMXkarTxVl10La92xjabShb7tlC\nrw69AOjVoReb79nM0G5D2bpna5QrFJH6ykrOwqw7EhMTPSsrK9pliIjUK2a20d0Tw82nI30RkQBR\n6IuIBIhCX0QkQBT6IiIBotAXEQkQhb6ISIAo9EVEAkShLyISIAp9EZEAUeiLiASIQl9EJEAU+iIi\nAaLQFxEJEIW+iEiAKPRFRAJEoS8iEiB17iYqZpYL6NZQR2oD7Ix2EXWU9k3ltF+OraHumwR3jw83\nU50LfTmamWVFckecINK+qZz2y7EFfd+oe0dEJEAU+iIiAaLQrx9mRLuAOkz7pnLaL8cW6H2jPn0R\nkQDRkb6ISIAo9OsQMxtqZuvNbKOZTa5k+r2h6dlmNi4aNUZDBPvlTjNbYmYfRKO+aIpg34w2sw/N\nbK2ZvWBmgfmdP96+MbNGZjbZzD4ws4/M7Olo1VnbAvMDUNeZWQIwEbgSSAQ6mtnN5aYnAcOAvkBv\n4AYza/CnnYXbLyFbgYeBmFouL6oi+JnpBgwGkty9DxAPXBeNWmtbBD83ZwPfuHsS8COgv5n1qv1K\na59Cv+4YBCxy971e8kHLdOCGctOvA+a4e6G7FwKzgSFRqLO2hdsvuPt7wL5oFBdlx9037v4pcL27\nHw41NQYO1n6ZURFu32xx99Kj/9OAw8BXtV5lFCj06444YEe519uBtlWY3lAF9X1HIuy+cfd8M2tt\nZq8A2aE/kEEQ0c+NmWUCnwAz3T23dkqLrsbRLkDKfAt0Kfe6Xait/PS2x5neUIXbL0EWdt+YWXdg\nMvCou6+rxdqiLaKfG3cfYGbfB942s23unllL9UWNjvTrjiXAjWbWKvT6DiC93PR04HYzizWzGGAE\n8GYt1xgN4fZLkB1335hZPPAcMDRggQ/h980VZnYdgLvvpuRzoda1XmUUKPTrCHffDjwFrDSzdcC3\n7r7IzDLNrJ27Z1ES8uuBD4E/htoatHD7JcrlRVUE++YnlBztpofaMs0sNZo115YI9k02MDx0ds9a\nYBfBOIjSxVkiIkGiI30RkQBR6IuIBIhCX0QkQBT6IiIBotAXEQkQhb6ISIAo9EVEAkShLyISIP8f\nMlp6fIFeFjgAAAAASUVORK5CYII=\n", 390 | "text/plain": [ 391 | "
" 392 | ] 393 | }, 394 | "metadata": { 395 | "tags": [] 396 | } 397 | } 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "metadata": { 403 | "id": "jlmXgAH9YlC9", 404 | "colab_type": "text" 405 | }, 406 | "source": [ 407 | "## References\n", 408 | "1. [Kamil Bennani-Smires, Claudiu Musat, Andreea Hossmann, Michael Baeriswyl, Martin Jaggi. Simple Unsupervised Keyphrase Extraction using Sentence Embeddings. CoNLL 2018, pages 221–229](https://www.aclweb.org/anthology/K18-1022)\n", 409 | "\n", 410 | "\n" 411 | ] 412 | } 413 | ] 414 | } -------------------------------------------------------------------------------- /unsupervised_lexical_simplification_with_bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "unsupervised_lexical_simplification_with_bert.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyON8SEl9gy4ZSbGuDcmQmwA", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "7pWD5xb4L_JV", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "# Unsupervised Lexical Simplification with BERT\n", 35 | "## Overview\n", 36 | "Lexical simplification aims to replace complex words in a given sentence with their simpler alternatives while preserving meaning. This notebook performs lexical simplification with BERT following \"*A Simple BERT-Based Approach for Lexical Simplification*\" [1]. The proposed method is very simple but effective.
\n", 37 | "\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "e1SFYlc3QsTw", 44 | "colab_type": "text" 45 | }, 46 | "source": [ 47 | "## Settings" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "vhYTBzUDQp8X", 54 | "colab_type": "code", 55 | "cellView": "both", 56 | "outputId": "5c155e7c-6cf5-4a79-96f0-18406147bdd6", 57 | "colab": { 58 | "base_uri": "https://localhost:8080/", 59 | "height": 102 60 | } 61 | }, 62 | "source": [ 63 | "#@title Setup environment\n", 64 | "!pip install --quiet pytorch-transformers==1.2.0\n", 65 | "!pip install --quiet pytorch-nlp\n", 66 | "!pip install --quiet tqdm" 67 | ], 68 | "execution_count": 0, 69 | "outputs": [ 70 | { 71 | "output_type": "stream", 72 | "text": [ 73 | "\u001b[K |████████████████████████████████| 184kB 2.8MB/s \n", 74 | "\u001b[K |████████████████████████████████| 1.0MB 8.2MB/s \n", 75 | "\u001b[K |████████████████████████████████| 870kB 20.2MB/s \n", 76 | "\u001b[?25h Building wheel for sacremoses (setup.py) ...
\u001b[?25l\u001b[?25hdone\n", 77 | "\u001b[K |████████████████████████████████| 92kB 2.9MB/s \n", 78 | "\u001b[?25h" 79 | ], 80 | "name": "stdout" 81 | } 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "metadata": { 87 | "id": "kMCjFnONQvPR", 88 | "colab_type": "code", 89 | "cellView": "both", 90 | "colab": {} 91 | }, 92 | "source": [ 93 | "#@title Setup common imports\n", 94 | "from collections import Counter\n", 95 | "from tqdm import tqdm\n", 96 | "import numpy as np\n", 97 | "import torch\n", 98 | "from pytorch_transformers import (\n", 99 | " BertTokenizer,\n", 100 | " BertForMaskedLM,\n", 101 | ")\n", 102 | "\n", 103 | "import matplotlib.pyplot as plt\n", 104 | "%matplotlib inline" 105 | ], 106 | "execution_count": 0, 107 | "outputs": [] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "A90nF_ctPVy_", 113 | "colab_type": "text" 114 | }, 115 | "source": [ 116 | "## Substitution Generation\n", 117 | "We obtain substitution candidates by feeding BERT a sentence pair (the original sentence and a copy of it), masking the complex word in the second sentence, and letting the masked language model predict it.\n", 118 | "\n", 119 | "\n" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "metadata": { 125 | "id": "Ape_pX4ARUKT", 126 | "colab_type": "code", 127 | "outputId": "1fc06ad9-3c7f-43f8-b13a-a8229e428dfd", 128 | "colab": { 129 | "base_uri": "https://localhost:8080/", 130 | "height": 68 131 | } 132 | }, 133 | "source": [ 134 | "# Build model\n", 135 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", 136 | "model = BertForMaskedLM.from_pretrained('bert-base-uncased')" 137 | ], 138 | "execution_count": 0, 139 | "outputs": [ 140 | { 141 | "output_type": "stream", 142 | "text": [ 143 | "100%|██████████| 231508/231508 [00:00<00:00, 5965161.35B/s]\n", 144 | "100%|██████████| 361/361 [00:00<00:00, 76719.89B/s]\n", 145 | "100%|██████████| 440473133/440473133 [00:07<00:00, 59185022.20B/s]\n" 146 | ], 147 | "name": "stderr" 148 | } 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "metadata": { 154 | "id": "3P-4wiBxPX7S", 155 | "colab_type": "code", 156 | "colab": {}
157 | }, 158 | "source": [ 159 | "# Model's input\n", 160 | "text = \"[CLS] the cat perched on the mat [SEP] the cat perched on the mat [SEP]\"\n", 161 | "masked_idx = 10" 162 | ], 163 | "execution_count": 0, 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "gWo9CFXIRzmw", 170 | "colab_type": "code", 171 | "colab": {} 172 | }, 173 | "source": [ 174 | "# Tokenize a text\n", 175 | "tokenized_text = tokenizer.tokenize(text)\n", 176 | "first_sep = tokenized_text.index('[SEP]')  # last position of segment A (the first sentence)\n", 177 | "# Mask a complex token which should be substituted\n", 178 | "complex_word = tokenized_text[masked_idx]\n", 179 | "tokenized_text[masked_idx] = '[MASK]'\n", 180 | "\n", 181 | "# Convert inputs to PyTorch tensors\n", 182 | "tokens_ids = tokenizer.convert_tokens_to_ids(tokenized_text)\n", 183 | "segments_ids = [0] * (first_sep + 1) + [1] * (len(tokenized_text) - first_sep - 1)  # derived from the text, so editing `text` keeps lengths in sync\n", 184 | "tokens_tensor = torch.tensor([tokens_ids])\n", 185 | "segments_tensors = torch.tensor([segments_ids])" 186 | ], 187 | "execution_count": 0, 188 | "outputs": [] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "metadata": { 193 | "id": "QsyoSBm3SEJi", 194 | "colab_type": "code", 195 | "colab": {} 196 | }, 197 | "source": [ 198 | "# Predict a masked token\n", 199 | "model.eval()\n", 200 | "if torch.cuda.is_available():\n", 201 | " tokens_tensor = tokens_tensor.to('cuda')\n", 202 | " segments_tensors = segments_tensors.to('cuda')\n", 203 | " model.to('cuda')\n", 204 | "\n", 205 | "with torch.no_grad():\n", 206 | " outputs = model(tokens_tensor, token_type_ids=segments_tensors)\n", 207 | " predictions = outputs[0]" 208 | ], 209 | "execution_count": 0, 210 | "outputs": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "metadata": { 215 | "id": "lSueDInYSSXO", 216 | "colab_type": "code", 217 | "outputId": "959cf97f-fc73-40b1-aa65-bd8a35f2b3c0", 218 | "colab": { 219 | "base_uri": "https://localhost:8080/", 220 | "height": 336 221 | } 222 | }, 223 | "source": [ 224 | "# Output top 10 of candidates\n", 225 | "topk_score,
topk_index = torch.topk(predictions[0, masked_idx], 10)\n", 226 | "topk_tokens = tokenizer.convert_ids_to_tokens(topk_index.tolist())\n", 227 | "print(f'Input: {\" \".join(tokenized_text)}')\n", 228 | "print(f'Top10: {topk_tokens}')\n", 229 | "\n", 230 | "# Visualize output probabilities\n", 231 | "plt.bar(topk_tokens, torch.softmax(topk_score, 0).tolist())\n", 232 | "plt.xticks(rotation=70)\n", 233 | "plt.ylabel('Probability')\n", 234 | "plt.show()" 235 | ], 236 | "execution_count": 0, 237 | "outputs": [ 238 | { 239 | "output_type": "stream", 240 | "text": [ 241 | "Input: [CLS] the cat perched on the mat [SEP] the cat [MASK] on the mat [SEP]\n", 242 | "Top10: ['perched', 'sat', 'landed', 'was', 'rested', 'stood', 'settled', 'hovered', 'sitting', 'crouched']\n" 243 | ], 244 | "name": "stdout" 245 | }, 246 | { 247 | "output_type": "display_data", 248 | "data": { 249 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEdCAYAAAABymAfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3deZhcZZXH8e8viSEg+y4ECPvqKBAB\nB2RRZJUwDsgmIIvGBQREUWAGVBxHdhWEYYIgm2yKMhECCAi4sAYQkG0IESQMS5BVEALkzB/nlrk0\n3ekOqbeK9P19nqefrq3rfaur6p53PVcRgZmZNdeQblfAzMy6y4HAzKzhHAjMzBrOgcDMrOEcCMzM\nGm5YtyswqxZddNEYNWpUt6thZjZHuf3225+JiMV6u69YIJB0JvAJ4OmIWKuX+wX8ENgGeAXYKyLu\n6O95R40axcSJE9tdXTOzQU3So33dV3Jo6Cxgq5ncvzWwcvUzFvivgnUxM7M+FAsEEfFb4NmZPGR7\n4JxINwMLSnpfqfqYmVnvujlZvDTwWO36lOq2t5E0VtJESROnTp3akcqZmTXFHLFqKCLGRcToiBi9\n2GK9znWYmdk71M1A8DiwTO36yOo2MzProG4GgvHAnkobAC9ExBNdrI+ZWSOVXD56AbApsKikKcA3\ngfcARMRpwARy6egkcvno3qXqYmZmfSsWCCJi137uD2C/UuWbmdnAzBGTxWZmVs4cl2Jidow69PLi\nZTxy9LbFyzAzayf3CMzMGs6BwMys4RwIzMwazoHAzKzhHAjMzBrOgcDMrOEcCMzMGs6BwMys4RwI\nzMwazoHAzKzhHAjMzBrOgcDMrOEcCMzMGs6BwMys4RwIzMwazoHAzKzhHAjMzBrOgcDMrOEcCMzM\nGs6BwMys4RwIzMwazoHAzKzhHAjMzBrOgcDMrOEcCMzMGs6BwMys4RwIzMwazoHAzKzhHAjMzBrO\ngcDMrOEcCMzMGs6BwMys4YoGAklbSXpQ0iRJh/Zy/7KSrpN0p6S7JW1Tsj5mZvZ2xQKBpKHAKcDW\nwBrArpLW6PGwfw
cujoi1gV2AU0vVx8zMeleyR7AeMCkiJkfENOBCYPsejwlg/uryAsD/FayPmZn1\nomQgWBp4rHZ9SnVb3beA3SVNASYAX+7tiSSNlTRR0sSpU6eWqKuZWWN1e7J4V+CsiBgJbAOcK+lt\ndYqIcRExOiJGL7bYYh2vpJnZYFYyEDwOLFO7PrK6rW5f4GKAiLgJGAEsWrBOZmbWQ8lAcBuwsqTl\nJQ0nJ4PH93jMX4CPAUhanQwEHvsxM+ugYoEgIt4A9geuAu4nVwfdK+koSWOqh30V+Jyku4ALgL0i\nIkrVyczM3m5YySePiAnkJHD9tiNrl+8DNixZBzMzm7luTxabmVmXORCYmTWcA4GZWcM5EJiZNZwD\ngZlZwzkQmJk1nAOBmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnDORCYmTWcA4GZ\nWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnD\nORCYmTWcA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnDORCYmTWcA4GZWcMNKBBI+oWkbSU5cJiZ\nDTIDPbCfCuwGPCTpaEmrDuSPJG0l6UFJkyQd2sdjdpJ0n6R7JZ0/wPqYmVmbDBvIgyLiGuAaSQsA\nu1aXHwNOB86LiNd7/o2kocApwMeBKcBtksZHxH21x6wMHAZsGBHPSVp8tl+RmZnNkgEP9UhaBNgL\n+CxwJ/BDYB3g6j7+ZD1gUkRMjohpwIXA9j0e8znglIh4DiAinp6l2puZ2Wwb6BzBL4HfAfMA20XE\nmIi4KCK+DMzbx58tDTxWuz6luq1uFWAVSX+QdLOkrfoof6ykiZImTp06dSBVNjOzARrQ0BBwekRM\nqN8gaa6IeC0iRs9m+SsDmwIjgd9Ken9EPF9/UESMA8YBjB49OmajPDMz62GgQ0P/0cttN/XzN48D\ny9Suj6xuq5sCjI+I1yPiz8D/koHBzMw6ZKY9AklLksM5c0taG1B11/zkMNHM3AasLGl5MgDsQq48\nqruUnHz+iaRFyaGiybP0CszMbLb0NzS0JTlBPBI4sXb7S8DhM/vDiHhD0v7AVcBQ4MyIuFfSUcDE\niBhf3beFpPuAN4FDIuKv7+iVmJnZOzLTQBARZwNnS9ohIi6Z1Sev5hUm9LjtyNrlAA6ufszMrAv6\nGxraPSLOA0ZJetvBOiJO7OXPzMxsDtLf0NB7q999LRE1M7M5XH9DQ/9d/f52Z6pjZmad1t/Q0Ekz\nuz8iDmhvdczMrNP6Gxq6vSO1MDOzrhnIqiEzMxvE+hsa+kFEHCTpV8DbUjtExJhiNTMzs47ob2jo\n3Or38aUrYmZm3dHf0NDt1e8bJA0HViN7Bg9WqaXNzGwON6Dso5K2BU4DHibzDS0v6fMRcUXJypmZ\nWXkDTUN9ArBZREwCkLQicDngQGBmNocbaBrql1pBoDKZTDxnZmZzuP5WDf1rdXGipAnAxeQcwafI\nNNNmZjaH629oaLva5aeATarLU4G5i9TIzMw6qr9VQ3t3qiJmZtYdA101NALYF1gTGNG6PSL2KVQv\nMzPrkIFOFp8LLEmesewG8oxlniw2MxsEBhoIVoqII4CXq/xD2wLrl6uWmZl1ykADwevV7+clrQUs\nACxepkpmZtZJA91QNk7SQsARwHjyjGVHFKuVmZl1zIACQUT8uLp4A7BCueqYmVmnDWhoSNIikk6W\ndIek2yX9QNIipStnZmblDXSO4ELgaWAHYEfgGeCiUpUyM7POGegcwfsi4ju16/8haecSFTIzs84a\naI/g15J2kTSk+tkJuKpkxczMrDP6Szr3EplkTsBBwHnVXUOAvwFfK1o7MzMrrr9cQ/N1qiJmZtYd\nA50jQNIYYOPq6vURcVmZKpmZWScNdPno0cCBwH3Vz4GSvleyYmZm1hkD7RFsA3ww
IqYDSDobuBM4\nrFTFzMysMwa6aghgwdrlBdpdETMz646B9gi+B9wp6TpyBdHGwKHFamVmZh3TbyCQJOD3wAbAh6qb\nvxERT5asmJmZdUa/Q0MREcCEiHgiIsZXPwMKApK2kvSgpEmS+uxBSNpBUkgaPQt1NzOzNhjoHMEd\nkj7U/8NmkDQUOAXYGlgD2FXSGr08bj5yRdIts/L8ZmbWHgMNBOsDN0t6WNLdku6RdHc/f7MeMCki\nJkfENDJx3fa9PO47wDHAqwOutZmZtc1AJ4u3fAfPvTTwWO36FHqc3lLSOsAyEXG5pEP6eiJJY4Gx\nAMsuu+w7qIqZmfWlv1xDI4AvACsB9wBnRMQb7ShY0hDgRGCv/h4bEeOAcQCjR4+OdpRvZmapv6Gh\ns4HRZBDYGjhhFp77cWCZ2vWR1W0t8wFrAddLeoRclTTeE8ZmZp3V39DQGhHxfgBJZwC3zsJz3was\nLGl5MgDsAuzWujMiXgAWbV2XdD3wtYiYOAtlmJnZbOqvR/B668KsDglVj9+fPG/B/cDFEXGvpKOq\nBHZmZvYu0F+P4AOSXqwuC5i7ui5yi8H8M/vjiJgATOhx25F9PHbTAdXYzMzaqr/zEQztVEXMzKw7\nZiXpnJmZDUIOBGZmDedAYGbWcA4EZmYN50BgZtZwDgRmZg3nQGBm1nAOBGZmDedAYGbWcA4EZmYN\n50BgZtZwDgRmZg3nQGBm1nAOBGZmDedAYGbWcA4EZmYN50BgZtZwDgRmZg3nQGBm1nAOBGZmDedA\nYGbWcA4EZmYN50BgZtZwDgRmZg3nQGBm1nAOBGZmDedAYGbWcA4EZmYN50BgZtZwDgRmZg3nQGBm\n1nAOBGZmDedAYGbWcEUDgaStJD0oaZKkQ3u5/2BJ90m6W9K1kpYrWR8zM3u7YoFA0lDgFGBrYA1g\nV0lr9HjYncDoiPgn4OfAsaXqY2ZmvSvZI1gPmBQRkyNiGnAhsH39ARFxXUS8Ul29GRhZsD5mZtaL\nkoFgaeCx2vUp1W192Re4orc7JI2VNFHSxKlTp7aximZm9q6YLJa0OzAaOK63+yNiXESMjojRiy22\nWGcrZ2Y2yA0r+NyPA8vUro+sbnsLSZsD/wZsEhGvFayPmZn1omSP4DZgZUnLSxoO7AKMrz9A0trA\nfwNjIuLpgnUxM7M+FAsEEfEGsD9wFXA/cHFE3CvpKEljqocdB8wL/EzSHyWN7+PpzMyskJJDQ0TE\nBGBCj9uOrF3evGT5ZmbWv3fFZLGZmXWPA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnDORCYmTWc\nA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOB\nmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZwzkQmJk1nAOBmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZ\nwzkQmJk1nAOBmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZww0r+eSStgJ+CAwFfhwRR/e4fy7gHGBd\n4K/AzhHxSMk6dcuoQy8vXsYjR29bvAwzG3yKBQJJQ4FTgI8DU4DbJI2PiPtqD9sXeC4iVpK0C3AM\nsHOpOjWVg5CZzUzJHsF6wKSImAwg6UJge6AeCLYHvlVd/jnwI0mKiChYL+ugbgeh0uU7ANpgoFLH\nXEk7AltFxGer63sA60fE/rXH/Kl6zJTq+sPVY57p8VxjgbHV1VWBB4tUuneLAs/0+yiX7bJdtst+\nd5e9XEQs1tsdRecI2iUixgHjulG2pIkRMdplu2yX7bIHS9k9lVw19DiwTO36yOq2Xh8jaRiwADlp\nbGZmHVIyENwGrCxpeUnDgV2A8T0eMx74THV5R+A3nh8wM+usYkNDEfGGpP2Bq8jlo2dGxL2SjgIm\nRsR44AzgXEmTgGfJYPFu05UhKZftsl22y+6U
YpPFZmY2Z/DOYjOzhnMgMDNrOAcCs4okVb+HtC6b\nNYEDQT/qB4QqbYYNXnNLGhER01ur1yQN7URQkLS6pIVLlzOnaFogljRfN8t3IOhHRISkDavLb8I/\nWoxF/ne1VulOkt7TrS+EpFUkbS5pZLXHo1Plzt/LbZ36nO4AvCLpV5K2gHzPa0GhZD32AY6V9BVJ\nH+zk/xxA0oclnSBpi97egw6Uv6qk0yWtC/m9q27vRBAe1s1GnqQPAUdK2qVqEHR8o69XDfVD0prA\n7cCTwLnkMtg/1+4fEhHT21TWMGApYDngjIhYpR3P+w7r8jVgQ+AO4C/AncATETG1cLnHA1dExLWS\nFoqI50qW10v5w4GzyTxYLwFXAt+PiD8WLvOfqp/NgSWBPwH3ANe28nWVJGkzYBNgcWAu8v2+Hbi1\n1QAqXP5ywBHAPwMvAFcDP6l/1wqUOSQipkvaGtiJXOr+x4h4oFSZfdTjI8BmwILA+4A/AvcDd7TS\n7xSvgwPBwEg6ADgEWIg8MJ5GflBfamMZS5NJ+LYik/PtA0yLiKmSPghsExH/2a7yBlCfDwDHAqOB\nG4GJ5Id0MvBgRExrc3mLA9cBHyC/EKdVZe8cEde3s6xeym4dFOYhU6dfSu5t2QP4NDAN2CAiHi5Y\nh28CS5Cfr6nAx4DFgKsj4thS5dbKXxw4DliJDEQALwMPAedHxAsdqMN3yHxiS1X1mEwG5nMi4u+F\nyjwa2JI8+L5E/u9vBG6MiOc7kQiz6m2OB6YDTwOvAPMBDwDnRsT/lSx/jsg11E21Fv8mwO7kh2Uj\n4HvADyTtEBG/bEdZEfG4pC+SB8B5yHM1PC7p8qrs+2b29+0iaVhEvAGMAp4g04BsQLaSzwbOB75c\noOiPAr+tNiN+jWyhnQp8Eri+QHl1rS/63sD7IqKVtvQmSXcCHywZBCo7ARtWB595gUfIxsftJQuV\nNLRq9W8OzBsRG0paAVgH+AbZ+PlxwfJVDcGuBWwZEetVt48A/hs4AbiBPCi2u8z1yO/zJuTnfQ3g\nq2RKnNWBE0oGgdr//lPAaxGxg6T3Ah8iG4VDycBQlANBP6pW4kbAKhFxQxUYfiFpCvnmXQszPliz\nU1b1HG9I+nJE/F3SgsA2wBbkgfCM2Xs1A1MFAcgDw/0R8QrwG+A3kp4GRhQaLvhfYH9JTwHHRcRJ\nkvYD5i5Q1lvU3rsbgfUkrVAbklmcbKUXUx30bgAOl3RCRDwFXCvpGODekmXX3svFyfeA6rVPlrQE\nsFBEvF6w/Nb/fjngeUmLAM9HxKuSjgWeafdwTa3M9wN3RcSLwN3A3dW8xNbAhpL+GhFntbPsHvVo\n/e+HA9MlLRgRzwPXSzqFbBi80fcztIcDwcD8kWwZ7h0RP6luWw0YWX2AaEMQaA1NbAeMkrQomW77\nhog4f3aeezaMBw6VNJn8kvyZnDc4uURhEXGHpC2BNSPi1moC70vAniXK66MOd0q6D7hV0h3AreRQ\n3RdKlVk1AF6tvvgHk8FgUeDvwFMR8WSpsnu4ErhI0kLAr8meyM7AiZ0oPCIul7QpeYKqcdVwyeGU\n7RH9DviKpJOAn0fEb8kgcAkZGFcvWHbd5WSv5Ehlev43yaHh73eicM8RDFC1cuiH5ETedcAi5BzB\nz2rdu9ktYxh54BlHTpxNILuok8gJy+KThr3U6dNk1/lNstv8ZETsVqCcFYHlybH4J4DngdfJFlHR\ns8vUgvBQYH0yR/xfyC/iEODS0pN2VdlLknMiw4H5AQG/qnoHHVENCW1JDk2MJucGjp75X7Wt7GHk\n//tgYAzwMNlDOTUiimUllrQsmedsVXLS9jpy6PNS4IiIuKVQuW8ZRajm5DYnJ41XJIdJTytR9tvq\n4kDQu9rBYV7yCzEyIs6TtDbZG7i81RtoY5k7kR+Eg8lVE/9Cjo++Bny+E13EWl0+QB74n69+3ktO\nGj7bzgny
WnmXk2OhW5Kv/XGyF3J1yYNAjzr8ghyGORxYKyLulzR3NUzX9gnDVgOiamQcSK7WeRh4\nlRyb7lhKdkm7kRP0I4GzyEnaV4Dp7X7dPcpt/Q+2AT5Bfte2i4inJL03Il4uUGbru70WORw1mRx+\nfI6cLH6RDMoHRsRX211+L/X5ItnzGA4cTw6FPV+63DrvI+hb68N/DrmK5UeS9oyIO8kWw0vVWGI7\nLVKV9ynguqoleDnwSieCQOv1SFod+C4ZiE6KiJuA35If0BJB4KPk6Nre5PkoLiFbhLuSPYRiquEH\nqiGJIWSv77YqCCwLfK86IJU4GLae8yjgV8C/AReRLcLDJQ0v8Bn7hx7v9+fIFStbRsRdZFBah+yV\nFFMFgSHk8NOxwNLA0pLmArarhqnaXeZ0SXOTE+BHkkOdnwE2JhdGTI+Iv5ANgiJq//s1gd2Aa4Bt\ngUeBNyRtpw7uJ3Ag6EO1omAkuZzvZHK10ITq7mOA5dt5cFAuW3yKXBl0C7C9pOOA/cgWcie0vvR7\nAheTX8yJ1W2fJFczlbAOGWg/TZ6TYjxwKLlvoe2Bpy5m7AHZEDiP/DLeVN22Bjlf0fZWaavs6n0f\nApwXEX+qhiEOIhsfKxVetth6v/cAziTHy++ublsb+Fa0aY9Mr4XPCHLbkw2N14E/R8QdwAjgK+Q8\nSTvLbB3zxgC3R8T65BDg/eTcwBHkMCgR8Vo7y+7DtmTjbzLw+2pS/sPAYZ0cAXAgmLm1yIP/PwOP\nRsQzVStxnXaN19c+mAcCH4mIZ8nu6V3AdsDp1YGxE1pfzKFk9/hb5MER8sN5f9sLzIPBB4G/kUND\na0v6JNlCnTizv22z08n/9zFAa0HAF8ilsiVNIwPPLZLGVC3g+YEVIqL0cuFWkLmHHBY6hhnLRLcD\n/lC08BlB7mHyRFaH18r/JPBYRLza5mJb37cPAMMlzRMRf4mI0yJiF+CAqhFY9NhYe+0PkPMBZ5G9\nQciVghN6+bNivGpoJiLiSkk7AF8nh0ogWwzj4S1rgGenjFaLawvgc1Uv5PvkiqF6L6S42msZR77m\n5YHHqrmLjcgvZ7vNT77OjwA/A04iV6o8S7aUiqmNT69Etjx/QR4QJ0i6h/ySFg0EkcuF/53sCa4L\nfAeYAhSfoK0djK4Evk2Oi79fudN1NPk+FKVcMy/ywLwXMLSal9uB/F+0VfX/HkEOg60NnCrpbjIY\n3xoRU6r5oGI9oR71GV99/qYDK0j6OTkf97lOlN/iyeJ+VMv49iI3Gk0Dfkmmf3i8XROIkhYgWwR3\nAZuSB+JLyA/nXhFxd59/3CaSfgQ8Rq6EelrS+sAB1d2vAb+OiAsLlT0cOIw8XekvyTQej5Qoq0e5\no8iezyXk5Oxl1e2LA8tFxG2Fym1tZpqH7A2tQI6NTyH3iwwjW8Mlh2XOIedjzoyIe6oD8h7AwsxY\nEfenmT3HbJbfmrDdj9yXcoIyz9AXyJ7hT9vdI5L0JWB8awWYpGXIVUKrk///RyPiiHaW2Uc9Wu//\ncPIY/JqkjwFrkqMBf4iI/y1djzr3CHpQtau2WlGwNDlWOZ5spc9Xn81v1/htRLxQtQp3AK6MiPOV\nOx5f71AQEDk+vD1wnaSbgQvJ3cxzFeiet8pdgOwO/4qcFH+DDD4flLRvFMxrVE0Wbkz2xFYDXq+W\nTj5TBcLtJD1SqA5DyHHo/cg0Ev9H9gjWA+6NgnmNan5MTpD+XLlJ8KfAL6NDS1VrQW5RcoKUiLid\nWku4wEqtvwJPSrqEXAl3WkScUw0DbUA1VKY25g/rTRUEhgLfBD5R/f9PjIiTSpXZH/cI+lANDUwi\nW8kihyoeIdd1P1OozHpLYV/gzYgofl7T2hDJaGAsmedlLjLPza+ByyI32rS73DXI3dIrkJPTl5GT\ntqMjYpt2l9dL+XMD+5MTdg+QPZ+7yAP1f0bE4oXLfwj4UGRKiVFkz3N1cp
y62AG554FO0s7kAoHV\nyGG6CyPivL7+vo31GEk2Ap4jV+/cB7wUBXcxV6uRViMbXR+vbr4MuCgiJpUqt1b+PBHxSjXcuh/Z\nC94N+Cy5lPUaYIfCiwTeXi8HghkkrRoRD0pamZy136e6vDqwMrm07PDIlAul6zKC7BF0IvNjq5t+\nLZnb53/I1tG+5J6GCRFxcKGy5yX3DmwEXBURV5Yop5dyW0H3n8lMm4uT8xQfJtfP3x0R5xYsfyFy\nCPB71SqZ1u23kckFS/aGRkTuZB4PnBwRV9fqtCewYER8u1T5VVmtydjNyR21G5I7128ErolCWUeV\na/YfiIjrqs/eBuQy5XkjohNzIp8je0AbkiukzqrdtwKwUUQUnRvrjYeG3mo/ZYqHyVS5ZSLiIeCh\nqpW+dCeCQFVukeGYPsqaXgWe18gvSWvZ2jhJ65CJv9pK0mLk7s3h5Cauh4FDJK3fgYNQKwisRi7V\n3D0iHq1aa0Vbwq2yI+I55Sa605Ub2R4H5iX3jJQMAvMBX1euX1+LWjK/qk5/oHCSu6qs6dWQ5N/I\nTVTHk0s6P01O3LctECjX43+GnIQdS06EExF/A66pembPVI8tNixU/e8/QDY45iJzGT1D9oSeiCq/\nU4my+62bewQzVB/MNcksmAeRk8PjyYmru7pZt06QtDu5euRM4AJyiOinEbFcgbJGkq3BN8ilipBp\nd+eJiE3bXV6PsltDYSeTE4THSzoI+Bo5jrxtlE8pcXJEfLlaobMe+Zm7h0y3XHTZqDKZ4YnkROlz\n5AbJH5NDoNdExFIFy271Pj9MTgwPB5aIiI/2fEwbyxxGHoCPq35fS36vz6/qcinZGPhbu8rspz6L\nku/5x8mJ+SfIock/lVqg0G+dHAhS7QO6ABkAppPdxk+SK3n+GhEf62IVO6JaLbQvOVxzLZni4YIO\nlDuMbLG91qnekKRzybmJZcm13IeT6cVvjIifFShvlaqc5YGPRsSO1RDJXFEo134vdWh9zg8j904s\nQx6QNyeXLF8TEcWSzNXKP5Ocf1qRHIo6RLmhkIj4aZvLbPUADyNXZg0hd++vTAb+JyLTPxedJK4t\nRNmH3EA4rVqUsgl5jLksIs4uVf5MRYR/MhgOqX6fDXy4x32LAetWl4d2u65tfM2thsAi5KTVWHLi\ndMlu161Dr39LctPaRHIH+RByWGTZQuWNIpcJT6vK3B5YuLrvS8DxhV/v/GTQ25icA+l5/8Kd+HyT\nPY+rqsuXkRs0Ifdx7NPmsobVLv+cXKoKOTSzKrlqbFR1W7HXXvuurUbOg8zdy2PmKv2/7+vHcwT8\no8UwvRquWDcyt05rqEhky+FWeMumq8FA5KTw18lxy9vI17qFpGfJLe/XdrF+pf2BbI1F5EqOncmk\nekXOPRC5N2IvSY+S68UPAo6W9Euy91k63fPS5OT/VsADyjPivRkRTyqTDG4REceVrECtR3CupN8C\ni0emHx9J9g7a3ftcTXn606XJlN6vQqaPkPQY2eh5pLqtE9/tPcjTj/5d0lxVPTYh//f/1oHye+UU\nE7xlP8AGVCfmkDS8un15sqXWsbwfnRIzusFLkCtVDiTXk99EDtMUP3F4N0gaIum7ZAqNC4H3KE/Y\nfg25cbBEmUOr34eR+1GOB/6VDL4fo7ahrZSIuB/4InAFuXfhXDIQ7QT8B9nzLaYWBNaOnJQfDzyh\nzL9/Ajlc0tYhsshNcZ8mV4KtKekvkn5Qvd97kkt263mPiqgdY+4hzzeyZMzIZfRJcmNj13iOoKZa\nV34y2RI+q1opdAw5bn2o2nTegXeD2rjpumQr7LCIuKR2/4LAy1FwTXen1SaJdyNbxb8Ddo2Ijyoz\ncK5Y+mAs6Tpy09Sb5OTlQ+RqrVMjotgpCVvvd3V5OTIQzEUeJDcgD1BnRuH0x9V37HpyUvyU6oC8\nLLlaqq0rZqrlmItExG2SPh4RV0vagN
y0uA25bPiQiJhYen6gVqeFyZ7fE+RKpXmrunyqVE90QPVq\neiCotVJWJb8QL5CtowWAm8nlbd+OiEfqX6bBoprA/C65fv4hspV2QTVcMOheL4Cki8mcUVuR68e/\nK+kr5DkI9i1YbtdSidSC4AHkXMHuwEHRoX0bPeqyKpld9imyJ1Rkuayk7cnJ4KHk//sa4KbWAb9q\nlRc/+1vtGLNiRDysTGGyE3mMWZhcvVR8ye5M6zgIv+ezpPYmfYccQ/xRdfsqwFIRcX1XK9ghVQ9g\nS3It9+ZkS/k33a1V+1Ut0k7ABsMAAAarSURBVN3IFNOfADaIXD//a/IscFcULn9NclfrtIg4WplK\n5OTIdMhFacYZ8D5DLhHeJzLP0BfJ9BKdOiVmq7X+JXIj31FRcFevpM+SS2WXJIPP1eTpZ+/uVA+/\nGnq6hRyiuoicLO/KnoHeND4QtEg6muyi/wB4cbAMAfWmFvwWJg/685MfzMeq+5cDpgym/0HtNR9M\n7hV5lAx6d5MbKxeKiO1m9hxtrEu3UomMIYchvk7uFt9ImfjuFjKtR5H8+7XXuyS5SucOco/KCPLz\ntxmwabTxjGy1HtC6wKcj4mDlTuITyPNfBDCmQz2C+rDcv5DDceuSQemHUSiZ46zwqiFAeb7czYD3\nkB+Qq5S7DZ8bTAfDXvyEPAgcCrwg6UFy8vTiwfa6a+O/y5Kb5H6j3EW7FPm+X9zBukT1e5qkn5An\nZOmEW8jhzwuZ8Xr3Bf5YKgjAWyZKVyEn43cgN1D9E7ko4e52BoEe9qnKgpwofy85R3F6p3pAVRBc\nPCKejohLgUuVmU9/yoz3o6sa3yOoTxJVY5djyd2H08jshJ06KUxHVeOUl0XEepJ+T05gHkyOHW8c\nXdrhWFL1mieSDaDDybmQTpyFqmtqrfF5ycbOUmSCs78x44Q0/9WpYcDWpqra9RHkHp4iqVskXU0O\ng61K7lw/h0x5fkFE/E+JMnupwzxkptHJZFK/ByIz3J5Kpvvu+nfNPQIYImlz8jR1D0TEV6s3bgw5\ns188LW0n1bqpmwLXS3o/mdzufklHAm+8Gz6YJVRfvg+T52LeGThQ0sPkGP0N3a1dMfWU10tUQyR3\nAp8nFwhsHYWy6cJbhuRWIXfzjpa0CHkynDOifNrr75DLRNcAtoqIF6u6FD37Wg8LkZPWK1f1+Luk\nJcj5qS91sB59amyPoDaGuDeZ5+URYL2I2FKZHvnJyNNGDlrVF3JJ8uTpNwDrk6k0DpjpHw4SyqRz\ne5LnAGhrWoN3m9qy1VfJ5YsPk8OgJ5UcIql9z8aRyeTOI/MLjSWHpL5fquyq/CHk6py/R2Zc3QnY\nKar0Hp1q4FUT9UPIhHdrkCMOf46I33Wi/P40ORC0uszXkK2lz5I5R06U9A1yaO/Y7tayvSRtROZ/\nPwb4caslWE1gbU5uIPthdPjsSFZWj2Wrm5HZZC8hUx3sXXLZalX+EHIYasPIcy8IWIlMef7V0uXX\n6jGUHApTRDxWMhDUji9LkfMie5CNzbPJ8z28qw68jd1ZXL1J72HGCdI/Qp6FDHJlxUR4S970OV5E\n/J4cK50G3CvpJkl7RMSlEbF/ROznIDD4RMQLwL+TiRSviIjzyXmwNzoQBFrn//0F1dnHqoPgo+RO\n5o593iLizYiY0lodV7g30DpufIWcl9mMXJyxP3Dhu+240tgeQYukbckUvM+QWRhXBPbrxLrubqvG\nSr9CbnG/D9g3Cp0QxLqvW8tWq7K3J3dSB5nY7xkyydrnO1F+t0i6gdy1f2PttouAH71bhoWg4YFA\n0l7kzr6NyZbyPORSrvMj4i4NopQSM1N1mTcC7hns8yKWVPgMeLWgsyjZEn6YTAG9EnlKxjOA/4lB\nlMKkN9Uc5GrkprmXqw2NN9OBc17MisYFgtoHdDS5lOw0cknX6uSk0okR8XI362g2WChPzbgZOR+w\nMr
k4YS7y2PPNbtatlNpKqR3JE/98gcx+ejuwILmrvFgqk3eiiYGgtYphP/INOV3Se8klXicAd0TE\nMd2tpdmcTdKVZG6fZcjJ0VZq9yXIXbUvVnNWg1LVy/5DRGxQXd+IXDF0PTApOnQ2tIFq3D6CWld4\nH+BlSTdUE6QvS/or1S7PpgwLmbVbdRA8k1yJtgGwo6RvkTu6nwImdLF6RdVWIq0JPF0tRX+wCnrv\n2sDXuB4B/GNN7xhywmxV4ElyDPM9EbFbN+tmNphIeh+ZUmJLsncwiZwovb6b9SqlNvR8CHku7j+T\n6TweBR6PiOe6WsE+NDIQ1NU+qDuSydceAMYN1g+qWbcM9g18tSCwELkU/QJy9/aKwPPkPoKzolxe\npXes8YGgpdrksiqZovdPg/GDambl1OYfv0ae9/qA6vaFyBxea0TEF7tayT40bo6gL9UmlwfIhFRm\nZrOkNqe4EXA8/OOUt89Jmgbc27XK9eNdtbvNzGwQuAL4pqTWeQ8gU2B3MtHdLPHQkJlZG1Wbxr5B\nJtdbhty0+mxE7NHVis2EA4GZWZtVwWA98hwIrwG3RMSL3a1V3xwIzMwaznMEZmYN50BgZtZwDgRm\nZg3nQGBm1nAOBGZmDedAYGbWcP8PZlvKKyrqFIUAAAAASUVORK5CYII=\n", 250 | "text/plain": [ 251 | "
" 252 | ] 253 | }, 254 | "metadata": { 255 | "tags": [] 256 | } 257 | } 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": { 263 | "id": "FVPajT4APaA1", 264 | "colab_type": "text" 265 | }, 266 | "source": [ 267 | "## Substitution Ranking\n", 268 | "In the previous section, we obtained substitution candidates for the complex word \"*perched*\". \n", 269 | "In this section, we rank the substitution candidates using the following four features:\n", 270 | "\n", 271 | "- Probability of BERT prediction;\n", 272 | "- Probability of n-gram language model;\n", 273 | "- Similarity obtained by FastText;\n", 274 | "- Word frequency;\n", 275 | "\n", 276 | "We choose the candidate with the best (i.e. lowest) average rank over all features as the simplification replacement.\n", 277 | "\n", 278 | "In this notebook, we use only the **BERT prediction probability** and the **word frequency** for the ranking step." 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "metadata": { 284 | "id": "8BNPz8jrUysR", 285 | "colab_type": "code", 286 | "cellView": "both", 287 | "outputId": "29019cfb-10e0-4d7f-db10-1731ee7b73c1", 288 | "colab": { 289 | "base_uri": "https://localhost:8080/", 290 | "height": 1000 291 | } 292 | }, 293 | "source": [ 294 | "#@title Download the Children's Book Test (CBT)\n", 295 | "!wget http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz\n", 296 | "!tar -xzf ./CBTest.tgz\n", 297 | "!cat ./CBTest/data/cbt_train.txt ./CBTest/data/cbt_valid.txt ./CBTest/data/cbt_test.txt > ./cbt_all.txt\n", 298 | "!wc ./cbt_all.txt" 299 | ], 300 | "execution_count": 0, 301 | "outputs": [ 302 | { 303 | "output_type": "stream", 304 | "text": [ 305 | "--2020-02-06 05:24:26-- http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz\n", 306 | "Resolving www.thespermwhale.com (www.thespermwhale.com)... 69.65.3.213\n", 307 | "Connecting to www.thespermwhale.com (www.thespermwhale.com)|69.65.3.213|:80... connected.\n", 308 | "HTTP request sent, awaiting response... 
200 OK\n", 309 | "Length: 120547669 (115M) [application/x-tar]\n", 310 | "Saving to: ‘CBTest.tgz’\n", 311 | "\n", 312 | "CBTest.tgz 100%[===================>] 114.96M 10.9MB/s in 11s \n", 313 | "\n", 314 | "2020-02-06 05:24:37 (10.5 MB/s) - ‘CBTest.tgz’ saved [120547669/120547669]\n", 315 | "\n", 316 | "tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'\n", 317 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 318 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 319 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 320 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 321 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 322 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 323 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 324 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 325 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 326 | "tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'\n", 327 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 328 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 329 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 330 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 331 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 332 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 333 | "tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'\n", 334 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 335 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 336 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 337 | "tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'\n", 338 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 339 | "tar: Ignoring 
unknown extended header keyword 'SCHILY.ino'\n", 340 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 341 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 342 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 343 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 344 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 345 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 346 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 347 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 348 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 349 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 350 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 351 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 352 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 353 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 354 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 355 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 356 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 357 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 358 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 359 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 360 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 361 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 362 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 363 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 364 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 365 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 366 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 367 | "tar: Ignoring unknown extended 
header keyword 'SCHILY.nlink'\n", 368 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 369 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 370 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 371 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 372 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 373 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 374 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 375 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 376 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 377 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 378 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 379 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 380 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 381 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 382 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 383 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 384 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 385 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 386 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 387 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 388 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 389 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 390 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 391 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 392 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 393 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 394 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 395 | "tar: Ignoring unknown extended header keyword 
'SCHILY.dev'\n", 396 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 397 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 398 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 399 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 400 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 401 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 402 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 403 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 404 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 405 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 406 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 407 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 408 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 409 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 410 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 411 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 412 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 413 | "tar: Ignoring unknown extended header keyword 'SCHILY.dev'\n", 414 | "tar: Ignoring unknown extended header keyword 'SCHILY.ino'\n", 415 | "tar: Ignoring unknown extended header keyword 'SCHILY.nlink'\n", 416 | " 292912 6135083 28453805 ./cbt_all.txt\n" 417 | ], 418 | "name": "stdout" 419 | } 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "metadata": { 425 | "id": "RHHukcMobWyz", 426 | "colab_type": "code", 427 | "cellView": "both", 428 | "outputId": "d60a095b-5c07-4fef-b71f-0a4cc5f8efe1", 429 | "colab": { 430 | "base_uri": "https://localhost:8080/", 431 | "height": 34 432 | } 433 | }, 434 | "source": [ 435 | "#@title Build a table of word frequency\n", 436 | "def count_lines(path):\n", 437 | " with open(path, 'r') as f:\n", 438 | " return sum(1 for _ 
in f)\n", 439 | "\n", 440 | "word_frequency = Counter()\n", 441 | "filepath = './cbt_all.txt'\n", 442 | "n_lines = count_lines(filepath)\n", 443 | "with open(filepath, 'r') as f:\n", 444 | " for line in tqdm(f, total=n_lines):\n", 445 | " if line.startswith(\"_BOOK_TITLE_\"):\n", 446 | " continue\n", 447 | " else:\n", 448 | " tokens = tokenizer.tokenize(line.rstrip())\n", 449 | " for token in tokens:\n", 450 | " word_frequency[token] += 1" 451 | ], 452 | "execution_count": 0, 453 | "outputs": [ 454 | { 455 | "output_type": "stream", 456 | "text": [ 457 | "100%|██████████| 292912/292912 [02:01<00:00, 2414.95it/s]\n" 458 | ], 459 | "name": "stderr" 460 | } 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "metadata": { 466 | "id": "R2zj_Ow9u4ch", 467 | "colab_type": "code", 468 | "colab": {} 469 | }, 470 | "source": [ 471 | "# substitution ranking: convert each feature's ordering into per-candidate ranks (0 = best)\n", 472 | "bert_rank = np.arange(len(topk_tokens))  # assumes topk_tokens is already ordered by BERT probability\n", 473 | "frequency_rank = np.argsort(np.argsort([-word_frequency[token] for token in topk_tokens], kind='stable'))  # argsort of argsort turns a sort order into ranks\n", 474 | "avg_rank = np.argsort((bert_rank + frequency_rank) / 2, kind='stable')  # candidate indices, best average rank first; ties keep BERT order\n", 475 | "\n", 476 | "# sort candidates by average rank and exclude the complex word itself\n", 477 | "candidates = [topk_tokens[i] for i in avg_rank if topk_tokens[i] != complex_word]" 478 | ], 479 | "execution_count": 0, 480 | "outputs": [] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "metadata": { 485 | "id": "sR1gcnLxwZeC", 486 | "colab_type": "code", 487 | "outputId": "633255c2-43dd-45c3-e44a-be50036f9e4a", 488 | "colab": { 489 | "base_uri": "https://localhost:8080/", 490 | "height": 34 491 | } 492 | }, 493 | "source": [ 494 | "tokenized_text[masked_idx] = candidates[0]\n", 495 | "print(\" \".join(tokenized_text))" 496 | ], 497 | "execution_count": 0, 498 | "outputs": [ 499 | { 500 | "output_type": "stream", 501 | "text": [ 502 | "[CLS] the cat perched on the mat [SEP] the cat sat on the mat [SEP]\n" 503 | ], 504 | "name": "stdout" 505 | } 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | 
"metadata": { 511 | "id": "hzOHx9T-QgDV", 512 | "colab_type": "text" 513 | }, 514 | "source": [ 515 | "## References\n", 516 | "\n", 517 | "[1] [A Simple BERT-Based Approach for Lexical Simplification\n", 518 | "](https://arxiv.org/abs/1907.06226) \n", 519 | "[2] [huggingface/transformers](https://github.com/huggingface/transformers)\n" 520 | ] 521 | } 522 | ] 523 | } --------------------------------------------------------------------------------