├── README.md └── Term_Extraction_Sequence_Classifier.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Term-Extraction-With-Language-Models 2 | 3 | ## Reference 4 | Lang, C., Wachowiak, L., Heinisch, B., & Gromann, D. (2021). Transforming Term Extraction: Transformer-Based Approaches to Multilingual Term Extraction Across Domains. In Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021. 5 | - [PDF](https://aclanthology.org/2021.findings-acl.316.pdf) 6 | - [Video Presentation](https://www.youtube.com/watch?v=JuBHSfFquCU) 7 | 8 | ## Description 9 | This repository contains the scripts used to fine-tune XLM-RoBERTa for the term extraction task on the ACTER dataset (https://github.com/AylaRT/ACTER) and the ACL RD-TEC 2.0 dataset (https://github.com/languagerecipes/acl-rd-tec-2.0). One model version is a token classifier that decides, for every token of an input sequence simultaneously, whether it is a term or the continuation of a term. The other model version is a sequence classifier that decides, for a given candidate term and a context in which it appears, whether the candidate is a term or not. 10 | 11 | ## Requirements 12 | * transformers v.4.2.2 13 | * torch v.1.7.0+cu101 14 | * sentencepiece v.0.1.95 15 | * scikit-learn v.0.24.1 16 | * nltk v.3.2.5 17 | * spacy v.2.2.4 18 | * sacremoses v.0.0.43 19 | * pandas v.1.1.5 20 | * numpy v.1.19.5 21 | 22 | ## Results 23 | 24 | ### F1 Scores on ACTER 25 | 26 | Training | Test | Sequence Classifier | Token Classifier 27 | ------------ | ------------- | -------------|------------- 28 | EN | EN | 45.2 | 58.3 29 | FR | EN | 44.7 | 44.2 30 | NL | EN | 35.9 | 58.3 31 | ALL | EN | 46.0 | 56.2 32 | | | | 33 | EN | FR | 48.1 | 57.6 34 | FR | FR | 46.0 | 52.9 35 | NL | FR | 40.0 | 54.5 36 | ALL | FR | 46.7 | 55.3 37 | | | | 38 | EN | NL | 58.0 | 69.8 39 | FR | NL | 56.1 | 61.4 40 | NL | NL | 48.5 | 69.6 41 | ALL | NL | 56.0 | 67.8 42 | 43 | ### F1 Scores on ACL RD-TEC 2.0 44 | Data Type | Token Classifier | 45 | ------------ | ------------- | 46 | Annotator 1 | 75.8 | 47 | Annotator 2 | 80.0 | 48 | 49 | ## Hyperparameters 50 | 51 | ### Sequence Classifier 52 | * optimizer: Adam 53 | * learning rate: 2e-5 54 | * batch size: 32 55 | * epochs: 4 56 | 57 | ### Token Classifier 58 | * optimizer: Adam 59 | * learning rate: 2e-5 60 | * batch size: 8 61 | * epochs: load the best checkpoint at the end, evaluating the model every 100 steps 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /Term_Extraction_Sequence_Classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "Term Extraction Sequence Classifier.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [ 10 | "9OsjSGOr0bSA" 11 | ], 12 | "authorship_tag": "ABX9TyMoG9pFKSz7mLYf/Ef0v5Bc", 13 | "include_colab_link": true 14 | }, 15 | "kernelspec": { 16 | "display_name": "Python 3", 17 | "name": "python3" 18 | }, 19 | "widgets": { 20 | "application/vnd.jupyter.widget-state+json": { 21 | "c4da41ffca2d4809a64ca7c3b4375bab": { 22 | "model_module": "@jupyter-widgets/controls", 23 | "model_name": "HBoxModel", 24 | "state": { 25 | "_dom_classes": [], 26 | "_model_module": "@jupyter-widgets/controls", 27 | "_model_module_version": "1.5.0", 28 | "_model_name": "HBoxModel", 29 | "_view_count": null, 30 | "_view_module": "@jupyter-widgets/controls", 31 | "_view_module_version": "1.5.0", 32 | "_view_name": "HBoxView", 33 | 
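As a concrete illustration of the sequence-classifier framing described in the README above: the notebook in this repository scores each candidate n-gram by prepending it to its context sentence, separated by ". ", and classifying the resulting pair. Below is a minimal sketch of that input format, assuming only the base (not yet fine-tuned) checkpoint, so the actual prediction is meaningless until the model has been trained:

```python
# Sketch of the candidate-plus-context framing used by the sequence classifier.
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
model = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=2)
model.eval()

candidate = "wind turbine"
context = "The wind turbine converts kinetic energy into electricity."

# The notebook joins the candidate and its context with ". " into one segment.
encoded = tokenizer.encode_plus(candidate + ". " + context,
                                max_length=64, padding="max_length",
                                truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = model(**encoded).logits
is_term = logits.argmax(dim=-1).item() == 1  # 1 = "Term", 0 = "Random"
print(is_term)
```

XLM-RoBERTa does not use `token_type_ids`, which is why candidate and context are joined into a single segment here and throughout the notebook.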
"box_style": "", 34 | "children": [ 35 | "IPY_MODEL_f8aa0656efa64e5385ec59a765939770", 36 | "IPY_MODEL_59dfeb1cd7f042eba3faa1ce8263eb0f" 37 | ], 38 | "layout": "IPY_MODEL_c0aa55048c3b41f097d1583b02dc3c45" 39 | } 40 | }, 41 | "f8aa0656efa64e5385ec59a765939770": { 42 | "model_module": "@jupyter-widgets/controls", 43 | "model_name": "FloatProgressModel", 44 | "state": { 45 | "_dom_classes": [], 46 | "_model_module": "@jupyter-widgets/controls", 47 | "_model_module_version": "1.5.0", 48 | "_model_name": "FloatProgressModel", 49 | "_view_count": null, 50 | "_view_module": "@jupyter-widgets/controls", 51 | "_view_module_version": "1.5.0", 52 | "_view_name": "ProgressView", 53 | "bar_style": "success", 54 | "description": "Downloading: 100%", 55 | "description_tooltip": null, 56 | "layout": "IPY_MODEL_dc06957d389e4995acbd22e23bdc8cef", 57 | "max": 5069051, 58 | "min": 0, 59 | "orientation": "horizontal", 60 | "style": "IPY_MODEL_53cf87e674334b27ad3a48422ec20030", 61 | "value": 5069051 62 | } 63 | }, 64 | "59dfeb1cd7f042eba3faa1ce8263eb0f": { 65 | "model_module": "@jupyter-widgets/controls", 66 | "model_name": "HTMLModel", 67 | "state": { 68 | "_dom_classes": [], 69 | "_model_module": "@jupyter-widgets/controls", 70 | "_model_module_version": "1.5.0", 71 | "_model_name": "HTMLModel", 72 | "_view_count": null, 73 | "_view_module": "@jupyter-widgets/controls", 74 | "_view_module_version": "1.5.0", 75 | "_view_name": "HTMLView", 76 | "description": "", 77 | "description_tooltip": null, 78 | "layout": "IPY_MODEL_9257c9f9130d4f47a05e3066eec6fffd", 79 | "placeholder": "​", 80 | "style": "IPY_MODEL_134178a4421b41de93598e8e0f08dcfb", 81 | "value": " 5.07M/5.07M [00:01<00:00, 2.72MB/s]" 82 | } 83 | }, 84 | "c0aa55048c3b41f097d1583b02dc3c45": { 85 | "model_module": "@jupyter-widgets/base", 86 | "model_name": "LayoutModel", 87 | "state": { 88 | "_model_module": "@jupyter-widgets/base", 89 | "_model_module_version": "1.2.0", 90 | "_model_name": "LayoutModel", 91 | "_view_count": null, 92 | "_view_module": "@jupyter-widgets/base", 93 | "_view_module_version": "1.2.0", 94 | "_view_name": "LayoutView", 95 | "align_content": null, 96 | "align_items": null, 97 | "align_self": null, 98 | "border": null, 99 | "bottom": null, 100 | "display": null, 101 | "flex": null, 102 | "flex_flow": null, 103 | "grid_area": null, 104 | "grid_auto_columns": null, 105 | "grid_auto_flow": null, 106 | "grid_auto_rows": null, 107 | "grid_column": null, 108 | "grid_gap": null, 109 | "grid_row": null, 110 | "grid_template_areas": null, 111 | "grid_template_columns": null, 112 | "grid_template_rows": null, 113 | "height": null, 114 | "justify_content": null, 115 | "justify_items": null, 116 | "left": null, 117 | "margin": null, 118 | "max_height": null, 119 | "max_width": null, 120 | "min_height": null, 121 | "min_width": null, 122 | "object_fit": null, 123 | "object_position": null, 124 | "order": null, 125 | "overflow": null, 126 | "overflow_x": null, 127 | "overflow_y": null, 128 | "padding": null, 129 | "right": null, 130 | "top": null, 131 | "visibility": null, 132 | "width": null 133 | } 134 | }, 135 | "dc06957d389e4995acbd22e23bdc8cef": { 136 | "model_module": "@jupyter-widgets/base", 137 | "model_name": "LayoutModel", 138 | "state": { 139 | "_model_module": "@jupyter-widgets/base", 140 | "_model_module_version": "1.2.0", 141 | "_model_name": "LayoutModel", 142 | "_view_count": null, 143 | "_view_module": "@jupyter-widgets/base", 144 | "_view_module_version": "1.2.0", 145 | "_view_name": "LayoutView", 146 | "align_content": null, 
147 | "align_items": null, 148 | "align_self": null, 149 | "border": null, 150 | "bottom": null, 151 | "display": null, 152 | "flex": null, 153 | "flex_flow": null, 154 | "grid_area": null, 155 | "grid_auto_columns": null, 156 | "grid_auto_flow": null, 157 | "grid_auto_rows": null, 158 | "grid_column": null, 159 | "grid_gap": null, 160 | "grid_row": null, 161 | "grid_template_areas": null, 162 | "grid_template_columns": null, 163 | "grid_template_rows": null, 164 | "height": null, 165 | "justify_content": null, 166 | "justify_items": null, 167 | "left": null, 168 | "margin": null, 169 | "max_height": null, 170 | "max_width": null, 171 | "min_height": null, 172 | "min_width": null, 173 | "object_fit": null, 174 | "object_position": null, 175 | "order": null, 176 | "overflow": null, 177 | "overflow_x": null, 178 | "overflow_y": null, 179 | "padding": null, 180 | "right": null, 181 | "top": null, 182 | "visibility": null, 183 | "width": null 184 | } 185 | }, 186 | "53cf87e674334b27ad3a48422ec20030": { 187 | "model_module": "@jupyter-widgets/controls", 188 | "model_name": "ProgressStyleModel", 189 | "state": { 190 | "_model_module": "@jupyter-widgets/controls", 191 | "_model_module_version": "1.5.0", 192 | "_model_name": "ProgressStyleModel", 193 | "_view_count": null, 194 | "_view_module": "@jupyter-widgets/base", 195 | "_view_module_version": "1.2.0", 196 | "_view_name": "StyleView", 197 | "bar_color": null, 198 | "description_width": "initial" 199 | } 200 | }, 201 | "9257c9f9130d4f47a05e3066eec6fffd": { 202 | "model_module": "@jupyter-widgets/base", 203 | "model_name": "LayoutModel", 204 | "state": { 205 | "_model_module": "@jupyter-widgets/base", 206 | "_model_module_version": "1.2.0", 207 | "_model_name": "LayoutModel", 208 | "_view_count": null, 209 | "_view_module": "@jupyter-widgets/base", 210 | "_view_module_version": "1.2.0", 211 | "_view_name": "LayoutView", 212 | "align_content": null, 213 | "align_items": null, 214 | "align_self": null, 215 | "border": null, 216 | "bottom": null, 217 | "display": null, 218 | "flex": null, 219 | "flex_flow": null, 220 | "grid_area": null, 221 | "grid_auto_columns": null, 222 | "grid_auto_flow": null, 223 | "grid_auto_rows": null, 224 | "grid_column": null, 225 | "grid_gap": null, 226 | "grid_row": null, 227 | "grid_template_areas": null, 228 | "grid_template_columns": null, 229 | "grid_template_rows": null, 230 | "height": null, 231 | "justify_content": null, 232 | "justify_items": null, 233 | "left": null, 234 | "margin": null, 235 | "max_height": null, 236 | "max_width": null, 237 | "min_height": null, 238 | "min_width": null, 239 | "object_fit": null, 240 | "object_position": null, 241 | "order": null, 242 | "overflow": null, 243 | "overflow_x": null, 244 | "overflow_y": null, 245 | "padding": null, 246 | "right": null, 247 | "top": null, 248 | "visibility": null, 249 | "width": null 250 | } 251 | }, 252 | "134178a4421b41de93598e8e0f08dcfb": { 253 | "model_module": "@jupyter-widgets/controls", 254 | "model_name": "DescriptionStyleModel", 255 | "state": { 256 | "_model_module": "@jupyter-widgets/controls", 257 | "_model_module_version": "1.5.0", 258 | "_model_name": "DescriptionStyleModel", 259 | "_view_count": null, 260 | "_view_module": "@jupyter-widgets/base", 261 | "_view_module_version": "1.2.0", 262 | "_view_name": "StyleView", 263 | "description_width": "" 264 | } 265 | } 266 | } 267 | } 268 | }, 269 | "cells": [ 270 | { 271 | "cell_type": "markdown", 272 | "metadata": { 273 | "id": "view-in-github", 274 | "colab_type": "text" 275 | }, 276 
| "source": [ 277 | "\"Open" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": { 283 | "id": "0yMTmZptEkHC" 284 | }, 285 | "source": [ 286 | "# Imports\n" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "metadata": { 292 | "id": "aaxWLY9GFE2W" 293 | }, 294 | "source": [ 295 | "!pip install transformers\n", 296 | "!pip install sacremoses\n", 297 | "!pip install sentencepiece" 298 | ], 299 | "execution_count": null, 300 | "outputs": [] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "metadata": { 305 | "id": "m9fYtB3_FHuK" 306 | }, 307 | "source": [ 308 | "#torch and tranformers for model and training\n", 309 | "import torch \n", 310 | "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n", 311 | "from torch.utils.data import TensorDataset\n", 312 | "from transformers import XLMRobertaTokenizer \n", 313 | "from transformers import XLMRobertaForSequenceClassification\n", 314 | "from transformers import AdamW \n", 315 | "from transformers import get_linear_schedule_with_warmup\n", 316 | "import sentencepiece\n", 317 | "\n", 318 | "#sklearn for evaluation\n", 319 | "from sklearn import preprocessing \n", 320 | "from sklearn.metrics import classification_report \n", 321 | "from sklearn.metrics import f1_score\n", 322 | "from sklearn.metrics import confusion_matrix\n", 323 | "from sklearn.model_selection import ParameterGrid \n", 324 | "from sklearn.model_selection import ParameterSampler \n", 325 | "from sklearn.utils.fixes import loguniform\n", 326 | "\n", 327 | "#nlp preprocessing\n", 328 | "from nltk import ngrams \n", 329 | "from spacy.pipeline import SentenceSegmenter\n", 330 | "from spacy.lang.en import English\n", 331 | "from spacy.pipeline import Sentencizer\n", 332 | "from sacremoses import MosesTokenizer, MosesDetokenizer\n", 333 | "\n", 334 | "\n", 335 | "#utilities\n", 336 | "import pandas as pd\n", 337 | "import glob, os\n", 338 | "import time\n", 339 | "import datetime\n", 340 | "import random\n", 341 | "import numpy as np\n", 342 | "import matplotlib.pyplot as plt\n", 343 | "% matplotlib inline\n", 344 | "import seaborn as sns\n", 345 | "import pickle" 346 | ], 347 | "execution_count": 3, 348 | "outputs": [] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "metadata": { 353 | "colab": { 354 | "base_uri": "https://localhost:8080/" 355 | }, 356 | "id": "kpY66eTVxQNH", 357 | "outputId": "ba5a7610-f5ce-44cc-b19c-c2d9db65909f" 358 | }, 359 | "source": [ 360 | "# connect to GPU \n", 361 | "device = torch.device('cuda')\n", 362 | "\n", 363 | "print('Connected to GPU:', torch.cuda.get_device_name(0))" 364 | ], 365 | "execution_count": null, 366 | "outputs": [ 367 | { 368 | "output_type": "stream", 369 | "text": [ 370 | "Connected to GPU: Tesla P100-PCIE-16GB\n" 371 | ], 372 | "name": "stdout" 373 | } 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": { 379 | "id": "3RPZ14sYHHUm" 380 | }, 381 | "source": [ 382 | "# Prepare Data" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "id": "TKqV3YfXHSNz" 389 | }, 390 | "source": [ 391 | "Training Data: corp, wind\n", 392 | "\n", 393 | "Valid: equi\n", 394 | "\n", 395 | "Test Data: htfl" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "metadata": { 401 | "id": "ERUBsPPOFfe1" 402 | }, 403 | "source": [ 404 | "#load terms\n", 405 | "\n", 406 | "#en\n", 407 | "df_corp_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/corp/annotations/corp_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 408 | 
"df_equi_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/equi/annotations/equi_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 409 | "df_htfl_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/htfl/annotations/htfl_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 410 | "df_wind_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/wind/annotations/wind_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 411 | "\n", 412 | "#fr\n", 413 | "df_corp_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/corp/annotations/corp_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 414 | "df_equi_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/equi/annotations/equi_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 415 | "df_htfl_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/htfl/annotations/htfl_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 416 | "df_wind_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/wind/annotations/wind_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 417 | "\n", 418 | "#nl\n", 419 | "df_corp_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/corp/annotations/corp_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 420 | "df_equi_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/equi/annotations/equi_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 421 | "df_htfl_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/htfl/annotations/htfl_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 422 | "df_wind_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/wind/annotations/wind_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n", 423 | "\n", 424 | "labels=[\"Random\", \"Term\"]" 425 | ], 426 | "execution_count": null, 427 | "outputs": [] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "metadata": { 432 | "colab": { 433 | "base_uri": "https://localhost:8080/", 434 | "height": 419 435 | }, 436 | "id": "tw11QcsHF8Gc", 437 | "outputId": "a16e39d2-4ca0-4127-b7b0-414a028ef98f" 438 | }, 439 | "source": [ 440 | "# example terms\n", 441 | "df_wind_terms_en" 442 | ], 443 | "execution_count": null, 444 | "outputs": [ 445 | { 446 | "output_type": "execute_result", 447 | "data": { 448 | "text/html": [ 449 | "
\n", 450 | "\n", 463 | "\n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | "
TermLabel
048/600Named_Entity
14energiaNamed_Entity
24energyNamed_Entity
3ab \"lietuvos energija\"Named_Entity
4ab lietuvos elektrineNamed_Entity
.........
1529zhiquanNamed_Entity
1530çetinkayaNamed_Entity
1531çeti̇nkayaNamed_Entity
1532çeşmeNamed_Entity
1533özgenNamed_Entity
\n", 529 | "

1534 rows × 2 columns

\n", 530 | "
" 531 | ], 532 | "text/plain": [ 533 | " Term Label\n", 534 | "0 48/600 Named_Entity\n", 535 | "1 4energia Named_Entity\n", 536 | "2 4energy Named_Entity\n", 537 | "3 ab \"lietuvos energija\" Named_Entity\n", 538 | "4 ab lietuvos elektrine Named_Entity\n", 539 | "... ... ...\n", 540 | "1529 zhiquan Named_Entity\n", 541 | "1530 çetinkaya Named_Entity\n", 542 | "1531 çeti̇nkaya Named_Entity\n", 543 | "1532 çeşme Named_Entity\n", 544 | "1533 özgen Named_Entity\n", 545 | "\n", 546 | "[1534 rows x 2 columns]" 547 | ] 548 | }, 549 | "metadata": { 550 | "tags": [] 551 | }, 552 | "execution_count": 8 553 | } 554 | ] 555 | }, 556 | { 557 | "cell_type": "markdown", 558 | "metadata": { 559 | "id": "sU7NMPaDvbWt" 560 | }, 561 | "source": [ 562 | "**Functions for preprocessing and creating of Training Data**" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "metadata": { 568 | "id": "3_stqlIDvZxA" 569 | }, 570 | "source": [ 571 | "#load all text files from folder into a string\n", 572 | "def load_text_corpus(path):\n", 573 | " text_data=\"\"\n", 574 | " print(glob.glob(path))\n", 575 | " for file in glob.glob(path+\"*.txt\"):\n", 576 | " print(file)\n", 577 | " with open(file) as f:\n", 578 | " temp_data = f.read()\n", 579 | " print(len(temp_data))\n", 580 | " text_data=text_data+\" \"+temp_data\n", 581 | " print(len(text_data))\n", 582 | " return text_data" 583 | ], 584 | "execution_count": null, 585 | "outputs": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "metadata": { 590 | "id": "4nXtHwAyPoK0" 591 | }, 592 | "source": [ 593 | "#split in sentences and tokenize\n", 594 | "def preprocess(text):\n", 595 | " #sentenize (from spacy)\n", 596 | " sentencizer = Sentencizer()\n", 597 | " nlp = English()\n", 598 | " nlp.add_pipe(sentencizer)\n", 599 | " doc = nlp(text)\n", 600 | "\n", 601 | " #tokenize\n", 602 | " sentence_list=[]\n", 603 | " mt = MosesTokenizer(lang='en')\n", 604 | " for s in doc.sents:\n", 605 | " tokenized_text = mt.tokenize(s, return_str=True)\n", 606 | " sentence_list.append((tokenized_text.split(), s)) #append tuple of tokens and original senteence\n", 607 | " return sentence_list\n" 608 | ], 609 | "execution_count": null, 610 | "outputs": [] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "metadata": { 615 | "id": "1qBA_KhoQkhB" 616 | }, 617 | "source": [ 618 | "#input is list of sentences and dataframe containing terms\n", 619 | "def create_training_data(sentence_list, df_terms, n):\n", 620 | "\n", 621 | " #create empty dataframe\n", 622 | " training_data = pd.DataFrame(columns=['n_gram', 'Context', 'Label', \"Termtype\"])\n", 623 | "\n", 624 | " md = MosesDetokenizer(lang='en')\n", 625 | "\n", 626 | "\n", 627 | " print(len(sentence_list))\n", 628 | " count=0\n", 629 | "\n", 630 | " for sen in sentence_list:\n", 631 | " count+=1\n", 632 | " if count%100==0:print(count)\n", 633 | "\n", 634 | " s=sen[0] #take first part of tuple, i.e. 
the tokens\n", 635 | "\n", 636 | " # 1-gram up to n-gram\n", 637 | " for i in range(1,n+1):\n", 638 | " #create n-grams of this sentence\n", 639 | " n_grams = ngrams(s, i)\n", 640 | "\n", 641 | " #look if n-grams are in the annotation dataset\n", 642 | " for n_gram in n_grams: \n", 643 | " n_gram=md.detokenize(n_gram) \n", 644 | " context=str(sen[1]).strip()\n", 645 | " #if yes add an entry to the training data\n", 646 | " if n_gram.lower() in df_terms.values:\n", 647 | " #append positive sample\n", 648 | " #get termtype like common term\n", 649 | " termtype=\"/\"#df_terms.loc[df_terms['Term'] == n_gram.lower()].iloc[0][\"Label\"]\n", 650 | " training_data = training_data.append({'n_gram': n_gram, 'Context': context, 'Label': 1, \"Termtype\":termtype}, ignore_index=True)\n", 651 | " else:\n", 652 | " #append negative sample\n", 653 | " training_data = training_data.append({'n_gram': n_gram, 'Context': context, 'Label': 0, \"Termtype\":\"None\"}, ignore_index=True)\n", 654 | "\n", 655 | " return training_data\n", 656 | "\n", 657 | " " 658 | ], 659 | "execution_count": null, 660 | "outputs": [] 661 | }, 662 | { 663 | "cell_type": "markdown", 664 | "metadata": { 665 | "id": "4HhBTwYl1-dy" 666 | }, 667 | "source": [ 668 | "**Create Training Data**" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "metadata": { 674 | "id": "UemCf-2xPrn1" 675 | }, 676 | "source": [ 677 | "# en \n", 678 | "#create trainings data for all corp texts\n", 679 | "corp_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/corp/texts/annotated/\") # load test\n", 680 | "corp_s_list=preprocess(corp_text_en) # preprocess\n", 681 | "train_data_corp_en=create_training_data(corp_s_list, df_corp_terms_en, 6) # create training data\n", 682 | "\n", 683 | "#create trainings data for all wind texts\n", 684 | "wind_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/wind/texts/annotated/\") # load test\n", 685 | "wind_s_list=preprocess(wind_text_en) # preprocess\n", 686 | "train_data_wind_en=create_training_data(wind_s_list, df_wind_terms_en, 6) # create training data\n", 687 | "\n", 688 | "#create trainings data for all equi texts\n", 689 | "equi_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/equi/texts/annotated/\") # load test\n", 690 | "equi_s_list=preprocess(equi_text_en) # preprocess\n", 691 | "train_data_equi_en=create_training_data(equi_s_list, df_equi_terms_en, 6) # create training data\n", 692 | "\n", 693 | "#create trainings data for all htfl texts\n", 694 | "htfl_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/htfl/texts/annotated/\") # load test\n", 695 | "htfl_s_list=preprocess(htfl_text_en) # preprocess\n", 696 | "train_data_htfl_en=create_training_data(htfl_s_list, df_htfl_terms_en, 6) # create training data " 697 | ], 698 | "execution_count": null, 699 | "outputs": [] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "metadata": { 704 | "id": "stFqQ_Sd2gAN" 705 | }, 706 | "source": [ 707 | "#fr\n", 708 | "corp_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/corp/texts/annotated/\") # load text\n", 709 | "corp_s_list=preprocess(corp_text_fr) # preprocess\n", 710 | "train_data_corp_fr=create_training_data(corp_s_list, df_corp_terms_fr, 6) # create training data\n", 711 | "\n", 712 | "wind_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/wind/texts/annotated/\") # load text\n", 713 | "wind_s_list=preprocess(wind_text_fr) # preprocess\n", 714 | "train_data_wind_fr=create_training_data(wind_s_list, df_wind_terms_fr, 6) # create training data\n", 715 | "\n", 716 | 
"equi_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/equi/texts/annotated/\") # load text\n", 717 | "equi_s_list=preprocess(equi_text_fr) # preprocess\n", 718 | "train_data_equi_fr=create_training_data(equi_s_list, df_equi_terms_fr, 6) # create training data\n", 719 | "\n", 720 | "htfl_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/htfl/texts/annotated/\") # load text\n", 721 | "htfl_s_list=preprocess(htfl_text_fr) # preprocess\n", 722 | "train_data_htfl_fr=create_training_data(htfl_s_list, df_htfl_terms_fr, 6) # create training data " 723 | ], 724 | "execution_count": null, 725 | "outputs": [] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "metadata": { 730 | "id": "z2PI4ngj2gKZ" 731 | }, 732 | "source": [ 733 | "#nl\n", 734 | "corp_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/corp/texts/annotated/\") # load text\n", 735 | "corp_s_list=preprocess(corp_text_nl) # preprocess\n", 736 | "train_data_corp_nl=create_training_data(corp_s_list, df_corp_terms_nl, 6) # create training data\n", 737 | "\n", 738 | "wind_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/wind/texts/annotated/\") # load text\n", 739 | "wind_s_list=preprocess(wind_text_nl) # preprocess\n", 740 | "train_data_wind_nl=create_training_data(wind_s_list, df_wind_terms_nl, 6) # create training data\n", 741 | "\n", 742 | "equi_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/equi/texts/annotated/\") # load text\n", 743 | "equi_s_list=preprocess(equi_text_nl) # preprocess\n", 744 | "train_data_equi_nl=create_training_data(equi_s_list, df_equi_terms_nl, 6) # create training data\n", 745 | "\n", 746 | "htfl_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/htfl/texts/annotated/\") # load text\n", 747 | "htfl_s_list=preprocess(htfl_text_nl) # preprocess\n", 748 | "train_data_htfl_nl=create_training_data(htfl_s_list, df_htfl_terms_nl, 6) # create training data " 749 | ], 750 | "execution_count": null, 751 | "outputs": [] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "metadata": { 756 | "colab": { 757 | "base_uri": "https://localhost:8080/" 758 | }, 759 | "id": "GXrT0L_DNCE_", 760 | "outputId": "6eed88af-a6eb-43e1-e18c-66c3fd432754" 761 | }, 762 | "source": [ 763 | "print(train_data_corp_en.groupby('Label').count())\n", 764 | "print(train_data_wind_en.groupby('Label').count())\n", 765 | "print(train_data_equi_en.groupby('Label').count())\n", 766 | "print(train_data_htfl_en.groupby('Label').count())" 767 | ], 768 | "execution_count": null, 769 | "outputs": [ 770 | { 771 | "output_type": "stream", 772 | "text": [ 773 | " n_gram Context Termtype\n", 774 | "Label \n", 775 | "0 274139 274139 274139\n", 776 | "1 8708 8708 8708\n", 777 | " n_gram Context Termtype\n", 778 | "Label \n", 779 | "0 311535 311535 311535\n", 780 | "1 10542 10542 10542\n", 781 | " n_gram Context Termtype\n", 782 | "Label \n", 783 | "0 298863 298863 298863\n", 784 | "1 13891 13891 13891\n", 785 | " n_gram Context Termtype\n", 786 | "Label \n", 787 | "0 290334 290334 290334\n", 788 | "1 14376 14376 14376\n" 789 | ], 790 | "name": "stdout" 791 | } 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "metadata": { 797 | "colab": { 798 | "base_uri": "https://localhost:8080/", 799 | "height": 419 800 | }, 801 | "id": "S4_Q9krEESA2", 802 | "outputId": "5fe80a12-619a-47f5-e7f1-cfd4a9a36d11" 803 | }, 804 | "source": [ 805 | "train_data_equi_en" 806 | ], 807 | "execution_count": null, 808 | "outputs": [ 809 | { 810 | "output_type": "execute_result", 811 | "data": { 812 | "text/html": [ 813 | "
\n", 814 | "\n", 827 | "\n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | "
n_gramContextLabelTermtype
0PirouettePirouette (dressage)\\n\\nA Pirouette is a Frenc...1Specific_Term
1(Pirouette (dressage)\\n\\nA Pirouette is a Frenc...0None
2dressagePirouette (dressage)\\n\\nA Pirouette is a Frenc...1Common_Term
3)Pirouette (dressage)\\n\\nA Pirouette is a Frenc...0None
4APirouette (dressage)\\n\\nA Pirouette is a Frenc...1Specific_Term
...............
312749about it when he's doneStop and let your horse think about it when he...0None
312750it when he's done somethingStop and let your horse think about it when he...0None
312751when he's done something rightStop and let your horse think about it when he...0None
312752he's done something right.Stop and let your horse think about it when he...0None
312753's done something right. \"Stop and let your horse think about it when he...0None
\n", 917 | "

312754 rows × 4 columns

\n", 918 | "
" 919 | ], 920 | "text/plain": [ 921 | " n_gram ... Termtype\n", 922 | "0 Pirouette ... Specific_Term\n", 923 | "1 ( ... None\n", 924 | "2 dressage ... Common_Term\n", 925 | "3 ) ... None\n", 926 | "4 A ... Specific_Term\n", 927 | "... ... ... ...\n", 928 | "312749 about it when he's done ... None\n", 929 | "312750 it when he's done something ... None\n", 930 | "312751 when he's done something right ... None\n", 931 | "312752 he's done something right. ... None\n", 932 | "312753 's done something right. \" ... None\n", 933 | "\n", 934 | "[312754 rows x 4 columns]" 935 | ] 936 | }, 937 | "metadata": { 938 | "tags": [] 939 | }, 940 | "execution_count": 24 941 | } 942 | ] 943 | }, 944 | { 945 | "cell_type": "markdown", 946 | "metadata": { 947 | "id": "WRYG7Q_sDnNw" 948 | }, 949 | "source": [ 950 | "**Undersample**" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "metadata": { 956 | "id": "qvHscNxQCmvJ" 957 | }, 958 | "source": [ 959 | "#undersample class 0 so the amount of trainingsample is the same as label 1 \n", 960 | "\n", 961 | "def undersample(train_data):\n", 962 | "# Class count\n", 963 | " print(\"Before\")\n", 964 | " print(train_data.Label.value_counts())\n", 965 | " count_class_0, count_class_1 = train_data.Label.value_counts()\n", 966 | "\n", 967 | " # Divide by class\n", 968 | " df_class_0 = train_data[train_data['Label'] == 0]\n", 969 | " df_class_1 = train_data[train_data['Label'] == 1]\n", 970 | "\n", 971 | " df_class_0_under = df_class_0.sample(count_class_1)\n", 972 | " df_test_under = pd.concat([df_class_0_under, df_class_1], axis=0)\n", 973 | "\n", 974 | " print(\"After\")\n", 975 | " print(df_test_under.Label.value_counts())\n", 976 | "\n", 977 | " return df_test_under" 978 | ], 979 | "execution_count": null, 980 | "outputs": [] 981 | }, 982 | { 983 | "cell_type": "code", 984 | "metadata": { 985 | "colab": { 986 | "base_uri": "https://localhost:8080/" 987 | }, 988 | "id": "-wi80YrXFHj4", 989 | "outputId": "270e8a84-7e17-4952-fad3-a9998718f99b" 990 | }, 991 | "source": [ 992 | "# undersample the trainingsdata\n", 993 | "\n", 994 | "#en\n", 995 | "train_data_corp_en=undersample(train_data_corp_en)\n", 996 | "\n", 997 | "train_data_wind_en=undersample(train_data_wind_en)\n", 998 | "\n", 999 | "\n", 1000 | "#fr\n", 1001 | "train_data_corp_fr=undersample(train_data_corp_fr)\n", 1002 | "\n", 1003 | "train_data_wind_fr=undersample(train_data_wind_fr)\n", 1004 | "\n", 1005 | "\n", 1006 | "#nl\n", 1007 | "train_data_corp_nl=undersample(train_data_corp_nl)\n", 1008 | "\n", 1009 | "train_data_wind_nl=undersample(train_data_wind_nl)" 1010 | ], 1011 | "execution_count": null, 1012 | "outputs": [ 1013 | { 1014 | "output_type": "stream", 1015 | "text": [ 1016 | "Before\n", 1017 | "0 274139\n", 1018 | "1 8708\n", 1019 | "Name: Label, dtype: int64\n", 1020 | "After\n", 1021 | "1 8708\n", 1022 | "0 8708\n", 1023 | "Name: Label, dtype: int64\n", 1024 | "Before\n", 1025 | "0 311535\n", 1026 | "1 10542\n", 1027 | "Name: Label, dtype: int64\n", 1028 | "After\n", 1029 | "1 10542\n", 1030 | "0 10542\n", 1031 | "Name: Label, dtype: int64\n", 1032 | "Before\n", 1033 | "0 325242\n", 1034 | "1 7443\n", 1035 | "Name: Label, dtype: int64\n", 1036 | "After\n", 1037 | "1 7443\n", 1038 | "0 7443\n", 1039 | "Name: Label, dtype: int64\n", 1040 | "Before\n", 1041 | "0 356805\n", 1042 | "1 9293\n", 1043 | "Name: Label, dtype: int64\n", 1044 | "After\n", 1045 | "1 9293\n", 1046 | "0 9293\n", 1047 | "Name: Label, dtype: int64\n", 1048 | "Before\n", 1049 | "0 283267\n", 1050 | "1 7071\n", 1051 | 
"Name: Label, dtype: int64\n", 1052 | "After\n", 1053 | "1 7071\n", 1054 | "0 7071\n", 1055 | "Name: Label, dtype: int64\n", 1056 | "Before\n", 1057 | "0 287361\n", 1058 | "1 5582\n", 1059 | "Name: Label, dtype: int64\n", 1060 | "After\n", 1061 | "1 5582\n", 1062 | "0 5582\n", 1063 | "Name: Label, dtype: int64\n" 1064 | ], 1065 | "name": "stdout" 1066 | } 1067 | ] 1068 | }, 1069 | { 1070 | "cell_type": "code", 1071 | "metadata": { 1072 | "colab": { 1073 | "base_uri": "https://localhost:8080/" 1074 | }, 1075 | "id": "VSy8hZggPQpf", 1076 | "outputId": "a8893623-b01f-4794-e20d-79e5d6a2136a" 1077 | }, 1078 | "source": [ 1079 | "#concat trainingsdata\n", 1080 | "trainings_data_df = pd.concat([train_data_corp_en, train_data_wind_en, train_data_corp_fr, train_data_wind_fr, train_data_corp_nl, train_data_wind_nl])\n", 1081 | "\n", 1082 | "valid_data_df = train_data_equi_en #pd.concat([train_data_equi_en, train_data_equi_fr, train_data_equi_nl ])\n", 1083 | "\n", 1084 | "test_data_df_en = train_data_htfl_en\n", 1085 | "test_data_df_fr = train_data_htfl_fr\n", 1086 | "test_data_df_nl = train_data_htfl_nl\n", 1087 | "\n", 1088 | "print(len(trainings_data_df))\n", 1089 | "print(len(valid_data_df))\n", 1090 | "print(len(test_data_df_en))\n", 1091 | "print(len(test_data_df_fr))\n", 1092 | "print(len(test_data_df_nl))" 1093 | ], 1094 | "execution_count": null, 1095 | "outputs": [ 1096 | { 1097 | "output_type": "stream", 1098 | "text": [ 1099 | "97278\n", 1100 | "312754\n", 1101 | "304710\n", 1102 | "303069\n", 1103 | "292615\n" 1104 | ], 1105 | "name": "stdout" 1106 | } 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "markdown", 1111 | "metadata": { 1112 | "id": "jKtVpCjIWPvO" 1113 | }, 1114 | "source": [ 1115 | "**Tokenizer**" 1116 | ] 1117 | }, 1118 | { 1119 | "cell_type": "code", 1120 | "metadata": { 1121 | "colab": { 1122 | "base_uri": "https://localhost:8080/", 1123 | "height": 66, 1124 | "referenced_widgets": [ 1125 | "c4da41ffca2d4809a64ca7c3b4375bab", 1126 | "f8aa0656efa64e5385ec59a765939770", 1127 | "59dfeb1cd7f042eba3faa1ce8263eb0f", 1128 | "c0aa55048c3b41f097d1583b02dc3c45", 1129 | "dc06957d389e4995acbd22e23bdc8cef", 1130 | "53cf87e674334b27ad3a48422ec20030", 1131 | "9257c9f9130d4f47a05e3066eec6fffd", 1132 | "134178a4421b41de93598e8e0f08dcfb" 1133 | ] 1134 | }, 1135 | "id": "pJjnroUuWOdg", 1136 | "outputId": "aad8ace8-3731-49da-d753-6649fb6ecd52" 1137 | }, 1138 | "source": [ 1139 | "xlmr_tokenizer = XLMRobertaTokenizer.from_pretrained(\"xlm-roberta-base\")" 1140 | ], 1141 | "execution_count": null, 1142 | "outputs": [ 1143 | { 1144 | "output_type": "display_data", 1145 | "data": { 1146 | "application/vnd.jupyter.widget-view+json": { 1147 | "model_id": "c4da41ffca2d4809a64ca7c3b4375bab", 1148 | "version_major": 2, 1149 | "version_minor": 0 1150 | }, 1151 | "text/plain": [ 1152 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5069051.0, style=ProgressStyle(descript…" 1153 | ] 1154 | }, 1155 | "metadata": { 1156 | "tags": [] 1157 | } 1158 | }, 1159 | { 1160 | "output_type": "stream", 1161 | "text": [ 1162 | "\n" 1163 | ], 1164 | "name": "stdout" 1165 | } 1166 | ] 1167 | }, 1168 | { 1169 | "cell_type": "code", 1170 | "metadata": { 1171 | "id": "9v7WbIW6WV8D" 1172 | }, 1173 | "source": [ 1174 | "def tokenizer_xlm(data, max_len):\n", 1175 | " labels_ = []\n", 1176 | " input_ids_ = []\n", 1177 | " attn_masks_ = []\n", 1178 | "\n", 1179 | " # for each datasample:\n", 1180 | " for index, row in data.iterrows():\n", 1181 | "\n", 1182 | " sentence = row['n_gram']+\". 
\"+row[\"Context\"]\n", 1183 | " #print(sentence)\n", 1184 | " \n", 1185 | " # create requiered input, i.e. ids and attention masks\n", 1186 | " encoded_dict = xlmr_tokenizer.encode_plus(sentence,\n", 1187 | " max_length=max_len, \n", 1188 | " padding='max_length',\n", 1189 | " truncation=True, \n", 1190 | " return_tensors='pt')\n", 1191 | "\n", 1192 | " # add encoded sample to lists\n", 1193 | " input_ids_.append(encoded_dict['input_ids'])\n", 1194 | " attn_masks_.append(encoded_dict['attention_mask'])\n", 1195 | " labels_.append(row['Label'])\n", 1196 | " \n", 1197 | " # Convert each Python list of Tensors into a 2D Tensor matrix.\n", 1198 | " input_ids_ = torch.cat(input_ids_, dim=0)\n", 1199 | " attn_masks_ = torch.cat(attn_masks_, dim=0)\n", 1200 | "\n", 1201 | " # labels to tensor\n", 1202 | " labels_ = torch.tensor(labels_)\n", 1203 | "\n", 1204 | " print('Encoder finished. {:,} examples.'.format(len(labels_)))\n", 1205 | " return input_ids_, attn_masks_, labels_" 1206 | ], 1207 | "execution_count": null, 1208 | "outputs": [] 1209 | }, 1210 | { 1211 | "cell_type": "code", 1212 | "metadata": { 1213 | "colab": { 1214 | "base_uri": "https://localhost:8080/" 1215 | }, 1216 | "id": "RcCexBG1ZuP_", 1217 | "outputId": "c2641990-2539-4ee3-b53e-aee04dfb052b" 1218 | }, 1219 | "source": [ 1220 | "#tokenize input for the different training/test sets\n", 1221 | "max_len=64\n", 1222 | "\n", 1223 | "input_ids_train, attn_masks_train, labels_all_train = tokenizer_xlm(trainings_data_df, max_len)\n", 1224 | "\n", 1225 | "input_ids_valid, attn_masks_valid, labels_all_valid = tokenizer_xlm(valid_data_df, max_len)\n", 1226 | "\n", 1227 | "input_ids_test_en, attn_masks_test_en, labels_test_en = tokenizer_xlm(test_data_df_en, max_len)\n", 1228 | "input_ids_test_fr, attn_masks_test_fr, labels_test_fr = tokenizer_xlm(test_data_df_fr, max_len)\n", 1229 | "input_ids_test_nl, attn_masks_test_nl, labels_test_nl = tokenizer_xlm(test_data_df_nl, max_len)" 1230 | ], 1231 | "execution_count": null, 1232 | "outputs": [ 1233 | { 1234 | "output_type": "stream", 1235 | "text": [ 1236 | "Encoder finished. 97,278 examples.\n", 1237 | "Encoder finished. 312,754 examples.\n", 1238 | "Encoder finished. 304,710 examples.\n", 1239 | "Encoder finished. 303,069 examples.\n", 1240 | "Encoder finished. 
292,615 examples.\n" 1241 | ], 1242 | "name": "stdout" 1243 | } 1244 | ] 1245 | }, 1246 | { 1247 | "cell_type": "code", 1248 | "metadata": { 1249 | "id": "nLCLiW9-Nkd-" 1250 | }, 1251 | "source": [ 1252 | "# create datasets\n", 1253 | "train_dataset = TensorDataset(input_ids_train, attn_masks_train, labels_all_train)\n", 1254 | "\n", 1255 | "valid_dataset = TensorDataset(input_ids_valid, attn_masks_valid, labels_all_valid)\n", 1256 | "\n", 1257 | "test_dataset_en = TensorDataset(input_ids_test_en, attn_masks_test_en, labels_test_en)\n", 1258 | "test_dataset_fr = TensorDataset(input_ids_test_fr, attn_masks_test_fr, labels_test_fr)\n", 1259 | "test_dataset_nl = TensorDataset(input_ids_test_nl, attn_masks_test_nl, labels_test_nl)" 1260 | ], 1261 | "execution_count": null, 1262 | "outputs": [] 1263 | }, 1264 | { 1265 | "cell_type": "code", 1266 | "metadata": { 1267 | "id": "Si-ng4T8Ny2O" 1268 | }, 1269 | "source": [ 1270 | "# create dataloaders\n", 1271 | "batch_size = 32\n", 1272 | "\n", 1273 | "train_dataloader = DataLoader(train_dataset, sampler = RandomSampler(train_dataset), batch_size = batch_size) #random sampling\n", 1274 | "valid_dataloader = DataLoader(valid_dataset, sampler = SequentialSampler(valid_dataset),batch_size = batch_size ) #sequential sampling\n", 1275 | "\n", 1276 | "test_dataloader_en = DataLoader(test_dataset_en, sampler = SequentialSampler(test_dataset_en),batch_size = batch_size ) #sequential sampling\n", 1277 | "test_dataloader_fr = DataLoader(test_dataset_fr, sampler = SequentialSampler(test_dataset_fr),batch_size = batch_size ) #sequential sampling\n", 1278 | "test_dataloader_nl = DataLoader(test_dataset_nl, sampler = SequentialSampler(test_dataset_nl),batch_size = batch_size ) #sequential sampling" 1279 | ], 1280 | "execution_count": null, 1281 | "outputs": [] 1282 | }, 1283 | { 1284 | "cell_type": "markdown", 1285 | "metadata": { 1286 | "id": "Hart2Y_ia5qD" 1287 | }, 1288 | "source": [ 1289 | "#Model" 1290 | ] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "metadata": { 1295 | "id": "sF72Sc2ur-ds" 1296 | }, 1297 | "source": [ 1298 | "def create_model(lr, eps, train_dataloader, epochs, device):\n", 1299 | " xlmr_model = XLMRobertaForSequenceClassification.from_pretrained(\"xlm-roberta-base\", num_labels=2)\n", 1300 | " desc = xlmr_model.to(device)\n", 1301 | " print('Connected to GPU:', torch.cuda.get_device_name(0))\n", 1302 | " optimizer = AdamW(xlmr_model.parameters(),\n", 1303 | " lr = lr, \n", 1304 | " eps = eps \n", 1305 | " )\n", 1306 | " total_steps = len(train_dataloader) * epochs\n", 1307 | " scheduler = get_linear_schedule_with_warmup(optimizer, \n", 1308 | " num_warmup_steps = 0, \n", 1309 | " num_training_steps = total_steps)\n", 1310 | " return xlmr_model, optimizer, scheduler" 1311 | ], 1312 | "execution_count": null, 1313 | "outputs": [] 1314 | }, 1315 | { 1316 | "cell_type": "code", 1317 | "metadata": { 1318 | "id": "R7acJSCUtHN6" 1319 | }, 1320 | "source": [ 1321 | "def format_time(elapsed):\n", 1322 | " '''\n", 1323 | " Takes a time in seconds and returns a string hh:mm:ss\n", 1324 | " '''\n", 1325 | " elapsed_rounded = int(round((elapsed)))\n", 1326 | " return str(datetime.timedelta(seconds=elapsed_rounded)) " 1327 | ], 1328 | "execution_count": null, 1329 | "outputs": [] 1330 | }, 1331 | { 1332 | "cell_type": "code", 1333 | "metadata": { 1334 | "id": "YsxS3wVltI5i" 1335 | }, 1336 | "source": [ 1337 | "def validate(validation_dataloader, validation_df, xlmr_model, verbose, print_cm): \n", 1338 | " \n", 1339 | " # put model in evaluation 
mode \n", 1340 | " xlmr_model.eval()\n", 1341 | "\n", 1342 | " #extract terms on the English equi validation corpus (hard-coded below) and compute scores\n", 1343 | " extracted_terms_equi_en=extract_terms(train_data_equi_en, xlmr_model)\n", 1344 | " extracted_terms_equi_en = set([item.lower() for item in extracted_terms_equi_en])\n", 1345 | " gold_set_equi_en=set(df_equi_terms_en[\"Term\"])\n", 1346 | " true_pos=extracted_terms_equi_en.intersection(gold_set_equi_en)\n", 1347 | " recall=len(true_pos)/len(gold_set_equi_en)\n", 1348 | " precision=len(true_pos)/len(extracted_terms_equi_en)\n", 1349 | " f1=2*(precision*recall)/(precision+recall)\n", 1350 | "\n", 1351 | " return recall, precision, f1" 1352 | ], 1353 | "execution_count": null, 1354 | "outputs": [] 1355 | }, 1356 | { 1357 | "cell_type": "code", 1358 | "metadata": { 1359 | "id": "UYBNMpiszm_h" 1360 | }, 1361 | "source": [ 1362 | "def extract_terms(validation_df, xlmr_model): \n", 1363 | " print(len(validation_df))\n", 1364 | " term_list=[]\n", 1365 | "\n", 1366 | " # put model in evaluation mode \n", 1367 | " xlmr_model.eval()\n", 1368 | "\n", 1369 | " for index, row in validation_df.iterrows():\n", 1370 | " sentence = row['n_gram']+\". \"+row[\"Context\"]\n", 1371 | " #the gold label is not needed for prediction; a dummy label is passed to the model below\n", 1372 | "\n", 1373 | " encoded_dict = xlmr_tokenizer.encode_plus(sentence, \n", 1374 | " max_length=max_len, \n", 1375 | " padding='max_length',\n", 1376 | " truncation=True, \n", 1377 | " return_tensors='pt') \n", 1378 | " input_id=encoded_dict['input_ids'].to(device)\n", 1379 | " attn_mask=encoded_dict['attention_mask'].to(device)\n", 1380 | " label=torch.tensor(0).to(device) \n", 1381 | "\n", 1382 | " with torch.no_grad(): \n", 1383 | " output = xlmr_model(input_id, \n", 1384 | " token_type_ids=None, \n", 1385 | " attention_mask=attn_mask,\n", 1386 | " labels=label)\n", 1387 | " loss=output.loss\n", 1388 | " logits=output.logits\n", 1389 | " \n", 1390 | " logits = logits.detach().cpu().numpy()\n", 1391 | " pred=labels[logits[0].argmax(axis=0)]\n", 1392 | " if pred==\"Term\":\n", 1393 | " term_list.append(row['n_gram'])\n", 1394 | "\n", 1395 | " return set(term_list)\n", 1396 | " " 1397 | ], 1398 | "execution_count": null, 1399 | "outputs": [] 1400 | }, 1401 | { 1402 | "cell_type": "code", 1403 | "metadata": { 1404 | "id": "zs7cOPFJtLUG" 1405 | }, 1406 | "source": [ 1407 | "def train_model(epochs, xlmr_model, train_dataloader, validation_dataloader, validation_df, random_seed, verbose, optimizer, scheduler):\n", 1408 | "\n", 1409 | " seed_val = random_seed\n", 1410 | "\n", 1411 | " random.seed(seed_val)\n", 1412 | " np.random.seed(seed_val)\n", 1413 | " torch.manual_seed(seed_val)\n", 1414 | " torch.cuda.manual_seed_all(seed_val)\n", 1415 | "\n", 1416 | " # mostly contains scores about how the training went for each epoch\n", 1417 | " training_stats = []\n", 1418 | "\n", 1419 | " # total training time\n", 1420 | " total_t0 = time.time()\n", 1421 | "\n", 1422 | " print('\\033[1m'+\"================ Model Training ================\"+'\\033[0m')\n", 1423 | "\n", 1424 | " # For each epoch...\n", 1425 | " for epoch_i in range(0, epochs):\n", 1426 | "\n", 1427 | " print(\"\")\n", 1428 | " print('\\033[1m'+'======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)+'\\033[0m')\n", 1429 | "\n", 1430 | " t0 = time.time()\n", 1431 | "\n", 1432 | " # summed training loss of the epoch\n", 1433 | " total_train_loss = 0\n", 1434 | "\n", 1435 | "\n", 1436 | " # the model is put into training mode because mechanisms like dropout behave differently at train and test time\n", 1437 | 
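" # The batch loop below is the standard fine-tuning pattern: zero the gradients, forward pass with labels (the sequence-classification head computes the cross-entropy loss itself), backward pass, clip the gradient norm to 1.0, optimizer step, scheduler step.\n", " # With num_warmup_steps=0 the linear schedule decays the learning rate from its initial value towards 0 over len(train_dataloader)*epochs steps.\n",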
xlmr_model.train()\n", 1438 | "\n", 1439 | " # iterate over batches\n", 1440 | " for step, batch in enumerate(train_dataloader):\n", 1441 | "\n", 1442 | " # unpack training batch and load it to the gpu (device) \n", 1443 | " b_input_ids = batch[0].to(device)\n", 1444 | " b_input_mask = batch[1].to(device)\n", 1445 | " b_labels = batch[2].to(device)\n", 1446 | "\n", 1447 | " # clear gradients before calculating new ones\n", 1448 | " xlmr_model.zero_grad() \n", 1449 | "\n", 1450 | " # forward pass with current batch\n", 1451 | " output = xlmr_model(b_input_ids, \n", 1452 | " token_type_ids=None, \n", 1453 | " attention_mask=b_input_mask, \n", 1454 | " labels=b_labels)\n", 1455 | " \n", 1456 | " loss=output.loss\n", 1457 | " logits=output.logits\n", 1458 | "\n", 1459 | " # add up the loss\n", 1460 | " total_train_loss += loss.item()\n", 1461 | "\n", 1462 | " # calculate new gradients\n", 1463 | " loss.backward()\n", 1464 | "\n", 1465 | " # gradient clipping (norm not bigger than 1.0)\n", 1466 | " torch.nn.utils.clip_grad_norm_(xlmr_model.parameters(), 1.0)\n", 1467 | "\n", 1468 | " # update the network's weights based on the gradients and the optimizer's parameters\n", 1469 | " optimizer.step()\n", 1470 | "\n", 1471 | " # lr update\n", 1472 | " scheduler.step()\n", 1473 | "\n", 1474 | " # avg loss over all batches\n", 1475 | " avg_train_loss = total_train_loss / len(train_dataloader) \n", 1476 | " \n", 1477 | " # training time of this epoch\n", 1478 | " training_time = format_time(time.time() - t0)\n", 1479 | "\n", 1480 | " print(\"\")\n", 1481 | " print(\" Average training loss: {0:.2f}\".format(avg_train_loss))\n", 1482 | " print(\" Training epoch took: {:}\".format(training_time))\n", 1483 | " \n", 1484 | " \n", 1485 | " # VALIDATION\n", 1486 | " print(\"evaluate\")\n", 1487 | " if epoch_i==epochs-1:print_cm=True #print the confusion matrix in the final epoch\n", 1488 | " else: print_cm=False\n", 1489 | " recall, precision, f1 = validate(validation_dataloader, validation_df, xlmr_model, verbose, print_cm) \n", 1490 | " \n", 1491 | "\n", 1492 | " #print('\\033[1m'+ \" Validation Loss All: {0:.2f}\".format(avg_val_loss) + '\\033[0m')\n", 1493 | "\n", 1494 | " training_stats.append(\n", 1495 | " {\n", 1496 | " 'epoch': epoch_i + 1,\n", 1497 | " 'Training Loss': avg_train_loss,\n", 1498 | " \"precision\": precision,\n", 1499 | " \"recall\": recall,\n", 1500 | " \"f1\": f1,\n", 1501 | " 'Training Time': training_time,\n", 1502 | " }\n", 1503 | " )\n", 1504 | "\n", 1505 | " print(\"Precision\", precision)\n", 1506 | " print(\"Recall\", recall)\n", 1507 | " print(\"F1\", f1)\n", 1508 | "\n", 1509 | " print(\"\\n\\nTraining complete!\")\n", 1510 | " print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()-total_t0)))\n", 1511 | " \n", 1512 | " return training_stats\n" 1513 | ], 1514 | "execution_count": null, 1515 | "outputs": [] 1516 | }, 1517 | { 1518 | "cell_type": "code", 1519 | "metadata": { 1520 | "id": "V1VTPA1anQ2w" 1521 | }, 1522 | "source": [ 1523 | "lr=2e-5\n", 1524 | "eps=1e-8\n", 1525 | "epochs=3\n", 1526 | "device = torch.device('cuda')\n", 1527 | "xlmr_model, optimizer, scheduler = create_model(lr=lr, eps=eps, train_dataloader=train_dataloader, epochs=epochs, device=device)" 1528 | ], 1529 | "execution_count": null, 1530 | "outputs": [] 1531 | }, 1532 | { 1533 | "cell_type": "code", 1534 | "metadata": { 1535 | "id": "56767talsn4M" 1536 | }, 1537 | "source": [ 1538 | "training_stats=train_model(epochs=epochs,\n", 1539 | " xlmr_model=xlmr_model,\n", 1540 | 
train_dataloader=train_dataloader,\n", 1541 | " validation_dataloader=valid_dataloader,\n", 1542 | " validation_df=train_data_htfl_en,\n", 1543 | " random_seed=42,\n", 1544 | " verbose=True,\n", 1545 | " optimizer=optimizer,\n", 1546 | " scheduler=scheduler)" 1547 | ], 1548 | "execution_count": null, 1549 | "outputs": [] 1550 | }, 1551 | { 1552 | "cell_type": "markdown", 1553 | "metadata": { 1554 | "id": "0-9PbANQp4Uj" 1555 | }, 1556 | "source": [ 1557 | "# Test Set Evaluation" 1558 | ] 1559 | }, 1560 | { 1561 | "cell_type": "code", 1562 | "metadata": { 1563 | "id": "kEniE8WRdjF3" 1564 | }, 1565 | "source": [ 1566 | "extracted_terms_htfl_en=extract_terms(train_data_htfl_en, xlmr_model)\n", 1567 | "extracted_terms_htfl_fr=extract_terms(train_data_htfl_fr, xlmr_model)\n", 1568 | "extracted_terms_htfl_nl=extract_terms(train_data_htfl_nl, xlmr_model)" 1569 | ], 1570 | "execution_count": null, 1571 | "outputs": [] 1572 | }, 1573 | { 1574 | "cell_type": "code", 1575 | "metadata": { 1576 | "id": "kB_M7qk9xbj5" 1577 | }, 1578 | "source": [ 1579 | "def computeTermEvalMetrics(extracted_terms, gold_df):\n", 1580 | " #make lower case cause gold standard is lower case\n", 1581 | " extracted_terms = set([item.lower() for item in extracted_terms])\n", 1582 | " gold_set=set(gold_df)\n", 1583 | " true_pos=extracted_terms.intersection(gold_set)\n", 1584 | " recall=len(true_pos)/len(gold_set)\n", 1585 | " precision=len(true_pos)/len(extracted_terms)\n", 1586 | "\n", 1587 | " print(\"Intersection\",len(true_pos))\n", 1588 | " print(\"Gold\",len(gold_set))\n", 1589 | " print(\"Extracted\",len(extracted_terms))\n", 1590 | " print(\"Recall:\", recall)\n", 1591 | " print(\"Precision:\", precision)\n", 1592 | " print(\"F1:\", 2*(precision*recall)/(precision+recall))" 1593 | ], 1594 | "execution_count": null, 1595 | "outputs": [] 1596 | }, 1597 | { 1598 | "cell_type": "code", 1599 | "metadata": { 1600 | "id": "cABUXjRY1ZHI" 1601 | }, 1602 | "source": [ 1603 | "computeTermEvalMetrics(extracted_terms_htfl_en, df_htfl_terms_en[\"Term\"])" 1604 | ], 1605 | "execution_count": null, 1606 | "outputs": [] 1607 | }, 1608 | { 1609 | "cell_type": "code", 1610 | "metadata": { 1611 | "id": "z1B0JipczW_7" 1612 | }, 1613 | "source": [ 1614 | "computeTermEvalMetrics(extracted_terms_htfl_fr, df_htfl_terms_fr[\"Term\"])" 1615 | ], 1616 | "execution_count": null, 1617 | "outputs": [] 1618 | }, 1619 | { 1620 | "cell_type": "code", 1621 | "metadata": { 1622 | "id": "KKTVrV0AzXEt" 1623 | }, 1624 | "source": [ 1625 | "computeTermEvalMetrics(extracted_terms_htfl_nl, df_htfl_terms_nl[\"Term\"])" 1626 | ], 1627 | "execution_count": null, 1628 | "outputs": [] 1629 | } 1630 | ] 1631 | } --------------------------------------------------------------------------------