├── README.md
└── Term_Extraction_Sequence_Classifier.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Term-Extraction-With-Language-Models
2 |
3 | ## Reference
4 | Lang, C., Wachowiak, L., Heinisch, B., & Gromann, D. (2021). Transforming Term Extraction: Transformer-Based Approaches to Multilingual Term Extraction Across Domains. In Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021.
5 | - [PDF](https://aclanthology.org/2021.findings-acl.316.pdf)
6 | - [Video Presentation](https://www.youtube.com/watch?v=JuBHSfFquCU)
7 |
8 | ## Description
9 | This repository contains the scripts used to fine-tune XLM-RoBERTa for the term extraction task on the ACTER dataset (https://github.com/AylaRT/ACTER) and the ACL RD-TEC 2.0 dataset (https://github.com/languagerecipes/acl-rd-tec-2.0). One model variant is a token classifier that decides, for every token of an input sequence simultaneously, whether it is a term or the continuation of a term. The other variant is a sequence classifier that decides, for a given candidate term and a context in which it appears, whether the candidate is a term.
10 |
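As a minimal sketch of the sequence-classifier input format (this mirrors the notebook, which prepends the candidate n-gram to its context sentence before encoding; the candidate and context strings below are made-up examples):

```python
from transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

candidate = "wind turbine"  # hypothetical candidate term
context = "The wind turbine converts kinetic energy from the wind into electricity."

# The candidate and its context are joined into a single sequence
# ("<candidate>. <context>"), as done in the notebook.
encoded = tokenizer.encode_plus(
    candidate + ". " + context,
    max_length=64,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
# encoded["input_ids"] and encoded["attention_mask"] are then fed to
# XLMRobertaForSequenceClassification with a binary term/no-term label.
```
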
11 | ## Requirements
12 | * transformers v.4.2.2
13 | * torch v.1.7.0+cu101
14 | * sentencepiece v.0.1.95
15 | * sklearn v.0.24.1
16 | * nltk v.3.2.5
17 | * spacy v.2.2.4
18 | * sacremoses v.0.0.43
19 | * pandas v.1.1.5
20 | * numpy v.1.19.5
21 |
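The pinned versions above can be installed in one go; note that `sklearn` is published on PyPI as `scikit-learn`, and the `+cu101` build of torch depends on your CUDA setup:

```
pip install transformers==4.2.2 torch==1.7.0 sentencepiece==0.1.95 scikit-learn==0.24.1 nltk==3.2.5 spacy==2.2.4 sacremoses==0.0.43 pandas==1.1.5 numpy==1.19.5
```
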
22 | ## Results
23 |
24 | ### F1 Scores on ACTER
25 |
26 | Training Language | Test Language | Sequence Classifier | Token Classifier
27 | ------------ | ------------- | -------------|-------------
28 | EN | EN | 45.2 | 58.3
29 | FR | EN | 44.7 | 44.2
30 | NL | EN | 35.9 | 58.3
31 | ALL | EN | 46.0 | 56.2
32 | | | |
33 | EN | FR | 48.1 | 57.6
34 | FR | FR | 46.0 | 52.9
35 | NL | FR | 40.0 | 54.5
36 | ALL | FR | 46.7 | 55.3
37 | | | |
38 | EN | NL | 58.0 | 69.8
39 | FR | NL | 56.1 | 61.4
40 | NL | NL | 48.5 | 69.6
41 | ALL | NL | 56.0 | 67.8
42 |
43 | ### F1 Scores on ACL RD-TEC 2.0
44 | Annotation Set | Token Classifier |
45 | ------------ | ------------- |
46 | Annotator 1 | 75.8 |
47 | Annotator 2 | 80.0 |
48 |
49 | ## Hyperparameters
50 |
51 | ### Sequence Classifier
52 | * optimizer: Adam
53 | * learning rate: 2e-5
54 | * batch size: 32
55 | * epochs: 4
56 |
57 | ### Token Classifier
58 | * optimizer: Adam
59 | * learning rate: 2e-5
60 | * batch size: 8
61 | * epochs: no fixed number; the model is evaluated every 100 steps and the best checkpoint is loaded at the end
62 |
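As a minimal sketch of how the settings above map onto the training setup used for the sequence classifier (this mirrors `create_model` in the notebook; the Adam epsilon of `1e-8` is the value used there, and `steps_per_epoch` is illustrative):

```python
from transformers import (
    XLMRobertaForSequenceClassification,
    AdamW,
    get_linear_schedule_with_warmup,
)

model = XLMRobertaForSequenceClassification.from_pretrained(
    "xlm-roberta-base", num_labels=2
)

# AdamW (the notebook's optimizer) with the learning rate listed above
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

# linear decay without warmup over all training steps
epochs, steps_per_epoch = 4, 1000  # steps_per_epoch depends on data and batch size
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=epochs * steps_per_epoch,
)
```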
--------------------------------------------------------------------------------
/Term_Extraction_Sequence_Classifier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "Term Extraction Sequence Classifier.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": [
10 | "9OsjSGOr0bSA"
11 | ],
12 | "authorship_tag": "ABX9TyMoG9pFKSz7mLYf/Ef0v5Bc",
13 | "include_colab_link": true
14 | },
15 | "kernelspec": {
16 | "display_name": "Python 3",
17 | "name": "python3"
 18 |     }
268 | },
269 | "cells": [
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {
283 | "id": "0yMTmZptEkHC"
284 | },
285 | "source": [
286 | "# Imports\n"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "metadata": {
292 | "id": "aaxWLY9GFE2W"
293 | },
294 | "source": [
295 | "!pip install transformers\n",
296 | "!pip install sacremoses\n",
297 | "!pip install sentencepiece"
298 | ],
299 | "execution_count": null,
300 | "outputs": []
301 | },
302 | {
303 | "cell_type": "code",
304 | "metadata": {
305 | "id": "m9fYtB3_FHuK"
306 | },
307 | "source": [
 308 |         "# torch and transformers for model and training\n",
309 | "import torch \n",
310 | "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
311 | "from torch.utils.data import TensorDataset\n",
312 | "from transformers import XLMRobertaTokenizer \n",
313 | "from transformers import XLMRobertaForSequenceClassification\n",
314 | "from transformers import AdamW \n",
315 | "from transformers import get_linear_schedule_with_warmup\n",
316 | "import sentencepiece\n",
317 | "\n",
318 | "#sklearn for evaluation\n",
319 | "from sklearn import preprocessing \n",
320 | "from sklearn.metrics import classification_report \n",
321 | "from sklearn.metrics import f1_score\n",
322 | "from sklearn.metrics import confusion_matrix\n",
323 | "from sklearn.model_selection import ParameterGrid \n",
324 | "from sklearn.model_selection import ParameterSampler \n",
325 | "from sklearn.utils.fixes import loguniform\n",
326 | "\n",
327 | "#nlp preprocessing\n",
328 | "from nltk import ngrams \n",
329 | "from spacy.pipeline import SentenceSegmenter\n",
330 | "from spacy.lang.en import English\n",
331 | "from spacy.pipeline import Sentencizer\n",
332 | "from sacremoses import MosesTokenizer, MosesDetokenizer\n",
333 | "\n",
334 | "\n",
335 | "#utilities\n",
336 | "import pandas as pd\n",
337 | "import glob, os\n",
338 | "import time\n",
339 | "import datetime\n",
340 | "import random\n",
341 | "import numpy as np\n",
342 | "import matplotlib.pyplot as plt\n",
 343 |         "%matplotlib inline\n",
344 | "import seaborn as sns\n",
345 | "import pickle"
346 | ],
347 | "execution_count": 3,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "code",
352 | "metadata": {
353 | "colab": {
354 | "base_uri": "https://localhost:8080/"
355 | },
356 | "id": "kpY66eTVxQNH",
357 | "outputId": "ba5a7610-f5ce-44cc-b19c-c2d9db65909f"
358 | },
359 | "source": [
360 | "# connect to GPU \n",
361 | "device = torch.device('cuda')\n",
362 | "\n",
363 | "print('Connected to GPU:', torch.cuda.get_device_name(0))"
364 | ],
365 | "execution_count": null,
366 | "outputs": [
367 | {
368 | "output_type": "stream",
369 | "text": [
370 | "Connected to GPU: Tesla P100-PCIE-16GB\n"
371 | ],
372 | "name": "stdout"
373 | }
374 | ]
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "metadata": {
379 | "id": "3RPZ14sYHHUm"
380 | },
381 | "source": [
382 | "# Prepare Data"
383 | ]
384 | },
385 | {
386 | "cell_type": "markdown",
387 | "metadata": {
388 | "id": "TKqV3YfXHSNz"
389 | },
390 | "source": [
391 | "Training Data: corp, wind\n",
392 | "\n",
 393 |         "Validation Data: equi\n",
394 | "\n",
395 | "Test Data: htfl"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "metadata": {
401 | "id": "ERUBsPPOFfe1"
402 | },
403 | "source": [
404 | "#load terms\n",
405 | "\n",
406 | "#en\n",
407 | "df_corp_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/corp/annotations/corp_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
408 | "df_equi_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/equi/annotations/equi_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
409 | "df_htfl_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/htfl/annotations/htfl_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
410 | "df_wind_terms_en=pd.read_csv('ACTER-master/ACTER-master/en/wind/annotations/wind_en_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
411 | "\n",
412 | "#fr\n",
413 | "df_corp_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/corp/annotations/corp_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
414 | "df_equi_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/equi/annotations/equi_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
415 | "df_htfl_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/htfl/annotations/htfl_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
416 | "df_wind_terms_fr=pd.read_csv('ACTER-master/ACTER-master/fr/wind/annotations/wind_fr_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
417 | "\n",
418 | "#nl\n",
419 | "df_corp_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/corp/annotations/corp_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
420 | "df_equi_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/equi/annotations/equi_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
421 | "df_htfl_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/htfl/annotations/htfl_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
422 | "df_wind_terms_nl=pd.read_csv('ACTER-master/ACTER-master/nl/wind/annotations/wind_nl_terms_nes.ann', delimiter=\"\\t\", names=[\"Term\", \"Label\"]) \n",
423 | "\n",
424 | "labels=[\"Random\", \"Term\"]"
425 | ],
426 | "execution_count": null,
427 | "outputs": []
428 | },
429 | {
430 | "cell_type": "code",
431 | "metadata": {
432 | "colab": {
433 | "base_uri": "https://localhost:8080/",
434 | "height": 419
435 | },
436 | "id": "tw11QcsHF8Gc",
437 | "outputId": "a16e39d2-4ca0-4127-b7b0-414a028ef98f"
438 | },
439 | "source": [
440 | "# example terms\n",
441 | "df_wind_terms_en"
442 | ],
443 | "execution_count": null,
444 | "outputs": [
445 | {
446 | "output_type": "execute_result",
447 | "data": {
532 | "text/plain": [
533 | " Term Label\n",
534 | "0 48/600 Named_Entity\n",
535 | "1 4energia Named_Entity\n",
536 | "2 4energy Named_Entity\n",
537 | "3 ab \"lietuvos energija\" Named_Entity\n",
538 | "4 ab lietuvos elektrine Named_Entity\n",
539 | "... ... ...\n",
540 | "1529 zhiquan Named_Entity\n",
541 | "1530 çetinkaya Named_Entity\n",
542 | "1531 çeti̇nkaya Named_Entity\n",
543 | "1532 çeşme Named_Entity\n",
544 | "1533 özgen Named_Entity\n",
545 | "\n",
546 | "[1534 rows x 2 columns]"
547 | ]
548 | },
549 | "metadata": {
550 | "tags": []
551 | },
552 | "execution_count": 8
553 | }
554 | ]
555 | },
556 | {
557 | "cell_type": "markdown",
558 | "metadata": {
559 | "id": "sU7NMPaDvbWt"
560 | },
561 | "source": [
 562 |         "**Functions for preprocessing and creating the training data**"
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "metadata": {
568 | "id": "3_stqlIDvZxA"
569 | },
570 | "source": [
571 | "#load all text files from folder into a string\n",
572 | "def load_text_corpus(path):\n",
573 | " text_data=\"\"\n",
574 | " print(glob.glob(path))\n",
575 | " for file in glob.glob(path+\"*.txt\"):\n",
576 | " print(file)\n",
577 | " with open(file) as f:\n",
578 | " temp_data = f.read()\n",
579 | " print(len(temp_data))\n",
580 | " text_data=text_data+\" \"+temp_data\n",
581 | " print(len(text_data))\n",
582 | " return text_data"
583 | ],
584 | "execution_count": null,
585 | "outputs": []
586 | },
587 | {
588 | "cell_type": "code",
589 | "metadata": {
590 | "id": "4nXtHwAyPoK0"
591 | },
592 | "source": [
593 | "#split in sentences and tokenize\n",
594 | "def preprocess(text):\n",
595 | " #sentenize (from spacy)\n",
596 | " sentencizer = Sentencizer()\n",
597 | " nlp = English()\n",
598 | " nlp.add_pipe(sentencizer)\n",
599 | " doc = nlp(text)\n",
600 | "\n",
601 | " #tokenize\n",
602 | " sentence_list=[]\n",
603 | " mt = MosesTokenizer(lang='en')\n",
604 | " for s in doc.sents:\n",
605 | " tokenized_text = mt.tokenize(s, return_str=True)\n",
 606 |         "        sentence_list.append((tokenized_text.split(), s)) # append tuple of tokens and the original sentence\n",
607 | " return sentence_list\n"
608 | ],
609 | "execution_count": null,
610 | "outputs": []
611 | },
612 | {
613 | "cell_type": "code",
614 | "metadata": {
615 | "id": "1qBA_KhoQkhB"
616 | },
617 | "source": [
618 | "#input is list of sentences and dataframe containing terms\n",
619 | "def create_training_data(sentence_list, df_terms, n):\n",
620 | "\n",
621 | " #create empty dataframe\n",
622 | " training_data = pd.DataFrame(columns=['n_gram', 'Context', 'Label', \"Termtype\"])\n",
623 | "\n",
624 | " md = MosesDetokenizer(lang='en')\n",
625 | "\n",
626 | "\n",
627 | " print(len(sentence_list))\n",
628 | " count=0\n",
629 | "\n",
630 | " for sen in sentence_list:\n",
631 | " count+=1\n",
632 | " if count%100==0:print(count)\n",
633 | "\n",
634 | " s=sen[0] #take first part of tuple, i.e. the tokens\n",
635 | "\n",
636 | " # 1-gram up to n-gram\n",
637 | " for i in range(1,n+1):\n",
638 | " #create n-grams of this sentence\n",
639 | " n_grams = ngrams(s, i)\n",
640 | "\n",
641 | " #look if n-grams are in the annotation dataset\n",
642 | " for n_gram in n_grams: \n",
643 | " n_gram=md.detokenize(n_gram) \n",
644 | " context=str(sen[1]).strip()\n",
645 | " #if yes add an entry to the training data\n",
646 | " if n_gram.lower() in df_terms.values:\n",
647 | " #append positive sample\n",
648 | " #get termtype like common term\n",
649 | " termtype=\"/\"#df_terms.loc[df_terms['Term'] == n_gram.lower()].iloc[0][\"Label\"]\n",
650 | " training_data = training_data.append({'n_gram': n_gram, 'Context': context, 'Label': 1, \"Termtype\":termtype}, ignore_index=True)\n",
651 | " else:\n",
652 | " #append negative sample\n",
653 | " training_data = training_data.append({'n_gram': n_gram, 'Context': context, 'Label': 0, \"Termtype\":\"None\"}, ignore_index=True)\n",
654 | "\n",
655 | " return training_data\n",
656 | "\n",
657 | " "
658 | ],
659 | "execution_count": null,
660 | "outputs": []
661 | },
662 | {
663 | "cell_type": "markdown",
664 | "metadata": {
665 | "id": "4HhBTwYl1-dy"
666 | },
667 | "source": [
668 | "**Create Training Data**"
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "metadata": {
674 | "id": "UemCf-2xPrn1"
675 | },
676 | "source": [
677 | "# en \n",
 678 |         "# create training data for all corp texts\n",
 679 |         "corp_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/corp/texts/annotated/\") # load text\n",
 680 |         "corp_s_list=preprocess(corp_text_en) # preprocess\n",
 681 |         "train_data_corp_en=create_training_data(corp_s_list, df_corp_terms_en, 6) # create training data\n",
 682 |         "\n",
 683 |         "# create training data for all wind texts\n",
 684 |         "wind_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/wind/texts/annotated/\") # load text\n",
 685 |         "wind_s_list=preprocess(wind_text_en) # preprocess\n",
 686 |         "train_data_wind_en=create_training_data(wind_s_list, df_wind_terms_en, 6) # create training data\n",
 687 |         "\n",
 688 |         "# create training data for all equi texts\n",
 689 |         "equi_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/equi/texts/annotated/\") # load text\n",
 690 |         "equi_s_list=preprocess(equi_text_en) # preprocess\n",
 691 |         "train_data_equi_en=create_training_data(equi_s_list, df_equi_terms_en, 6) # create training data\n",
 692 |         "\n",
 693 |         "# create training data for all htfl texts\n",
 694 |         "htfl_text_en=load_text_corpus(\"ACTER-master/ACTER-master/en/htfl/texts/annotated/\") # load text\n",
 695 |         "htfl_s_list=preprocess(htfl_text_en) # preprocess\n",
 696 |         "train_data_htfl_en=create_training_data(htfl_s_list, df_htfl_terms_en, 6) # create training data "
697 | ],
698 | "execution_count": null,
699 | "outputs": []
700 | },
701 | {
702 | "cell_type": "code",
703 | "metadata": {
704 | "id": "stFqQ_Sd2gAN"
705 | },
706 | "source": [
707 | "#fr\n",
708 | "corp_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/corp/texts/annotated/\") # load text\n",
709 | "corp_s_list=preprocess(corp_text_fr) # preprocess\n",
710 | "train_data_corp_fr=create_training_data(corp_s_list, df_corp_terms_fr, 6) # create training data\n",
711 | "\n",
712 | "wind_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/wind/texts/annotated/\") # load text\n",
713 | "wind_s_list=preprocess(wind_text_fr) # preprocess\n",
714 | "train_data_wind_fr=create_training_data(wind_s_list, df_wind_terms_fr, 6) # create training data\n",
715 | "\n",
716 | "equi_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/equi/texts/annotated/\") # load text\n",
717 | "equi_s_list=preprocess(equi_text_fr) # preprocess\n",
718 | "train_data_equi_fr=create_training_data(equi_s_list, df_equi_terms_fr, 6) # create training data\n",
719 | "\n",
720 | "htfl_text_fr=load_text_corpus(\"ACTER-master/ACTER-master/fr/htfl/texts/annotated/\") # load text\n",
721 | "htfl_s_list=preprocess(htfl_text_fr) # preprocess\n",
722 | "train_data_htfl_fr=create_training_data(htfl_s_list, df_htfl_terms_fr, 6) # create training data "
723 | ],
724 | "execution_count": null,
725 | "outputs": []
726 | },
727 | {
728 | "cell_type": "code",
729 | "metadata": {
730 | "id": "z2PI4ngj2gKZ"
731 | },
732 | "source": [
733 | "#nl\n",
734 | "corp_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/corp/texts/annotated/\") # load text\n",
735 | "corp_s_list=preprocess(corp_text_nl) # preprocess\n",
736 | "train_data_corp_nl=create_training_data(corp_s_list, df_corp_terms_nl, 6) # create training data\n",
737 | "\n",
738 | "wind_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/wind/texts/annotated/\") # load text\n",
739 | "wind_s_list=preprocess(wind_text_nl) # preprocess\n",
740 | "train_data_wind_nl=create_training_data(wind_s_list, df_wind_terms_nl, 6) # create training data\n",
741 | "\n",
742 | "equi_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/equi/texts/annotated/\") # load text\n",
743 | "equi_s_list=preprocess(equi_text_nl) # preprocess\n",
744 | "train_data_equi_nl=create_training_data(equi_s_list, df_equi_terms_nl, 6) # create training data\n",
745 | "\n",
746 | "htfl_text_nl=load_text_corpus(\"ACTER-master/ACTER-master/nl/htfl/texts/annotated/\") # load text\n",
747 | "htfl_s_list=preprocess(htfl_text_nl) # preprocess\n",
748 | "train_data_htfl_nl=create_training_data(htfl_s_list, df_htfl_terms_nl, 6) # create training data "
749 | ],
750 | "execution_count": null,
751 | "outputs": []
752 | },
753 | {
754 | "cell_type": "code",
755 | "metadata": {
756 | "colab": {
757 | "base_uri": "https://localhost:8080/"
758 | },
759 | "id": "GXrT0L_DNCE_",
760 | "outputId": "6eed88af-a6eb-43e1-e18c-66c3fd432754"
761 | },
762 | "source": [
763 | "print(train_data_corp_en.groupby('Label').count())\n",
764 | "print(train_data_wind_en.groupby('Label').count())\n",
765 | "print(train_data_equi_en.groupby('Label').count())\n",
766 | "print(train_data_htfl_en.groupby('Label').count())"
767 | ],
768 | "execution_count": null,
769 | "outputs": [
770 | {
771 | "output_type": "stream",
772 | "text": [
773 | " n_gram Context Termtype\n",
774 | "Label \n",
775 | "0 274139 274139 274139\n",
776 | "1 8708 8708 8708\n",
777 | " n_gram Context Termtype\n",
778 | "Label \n",
779 | "0 311535 311535 311535\n",
780 | "1 10542 10542 10542\n",
781 | " n_gram Context Termtype\n",
782 | "Label \n",
783 | "0 298863 298863 298863\n",
784 | "1 13891 13891 13891\n",
785 | " n_gram Context Termtype\n",
786 | "Label \n",
787 | "0 290334 290334 290334\n",
788 | "1 14376 14376 14376\n"
789 | ],
790 | "name": "stdout"
791 | }
792 | ]
793 | },
794 | {
795 | "cell_type": "code",
796 | "metadata": {
797 | "colab": {
798 | "base_uri": "https://localhost:8080/",
799 | "height": 419
800 | },
801 | "id": "S4_Q9krEESA2",
802 | "outputId": "5fe80a12-619a-47f5-e7f1-cfd4a9a36d11"
803 | },
804 | "source": [
805 | "train_data_equi_en"
806 | ],
807 | "execution_count": null,
808 | "outputs": [
809 | {
810 | "output_type": "execute_result",
811 | "data": {
920 | "text/plain": [
921 | " n_gram ... Termtype\n",
922 | "0 Pirouette ... Specific_Term\n",
923 | "1 ( ... None\n",
924 | "2 dressage ... Common_Term\n",
925 | "3 ) ... None\n",
926 | "4 A ... Specific_Term\n",
927 | "... ... ... ...\n",
928 | "312749 about it when he's done ... None\n",
929 | "312750 it when he's done something ... None\n",
930 | "312751 when he's done something right ... None\n",
931 | "312752 he's done something right. ... None\n",
932 | "312753 's done something right. \" ... None\n",
933 | "\n",
934 | "[312754 rows x 4 columns]"
935 | ]
936 | },
937 | "metadata": {
938 | "tags": []
939 | },
940 | "execution_count": 24
941 | }
942 | ]
943 | },
944 | {
945 | "cell_type": "markdown",
946 | "metadata": {
947 | "id": "WRYG7Q_sDnNw"
948 | },
949 | "source": [
950 | "**Undersample**"
951 | ]
952 | },
953 | {
954 | "cell_type": "code",
955 | "metadata": {
956 | "id": "qvHscNxQCmvJ"
957 | },
958 | "source": [
 959 |         "# undersample class 0 so that it has as many training samples as class 1\n",
960 | "\n",
961 | "def undersample(train_data):\n",
962 | "# Class count\n",
963 | " print(\"Before\")\n",
964 | " print(train_data.Label.value_counts())\n",
965 | " count_class_0, count_class_1 = train_data.Label.value_counts()\n",
966 | "\n",
967 | " # Divide by class\n",
968 | " df_class_0 = train_data[train_data['Label'] == 0]\n",
969 | " df_class_1 = train_data[train_data['Label'] == 1]\n",
970 | "\n",
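        "    # randomly draw as many negative (label 0) samples as there are positives\n",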
971 | " df_class_0_under = df_class_0.sample(count_class_1)\n",
972 | " df_test_under = pd.concat([df_class_0_under, df_class_1], axis=0)\n",
973 | "\n",
974 | " print(\"After\")\n",
975 | " print(df_test_under.Label.value_counts())\n",
976 | "\n",
977 | " return df_test_under"
978 | ],
979 | "execution_count": null,
980 | "outputs": []
981 | },
982 | {
983 | "cell_type": "code",
984 | "metadata": {
985 | "colab": {
986 | "base_uri": "https://localhost:8080/"
987 | },
988 | "id": "-wi80YrXFHj4",
989 | "outputId": "270e8a84-7e17-4952-fad3-a9998718f99b"
990 | },
991 | "source": [
 992 |         "# undersample the training data\n",
993 | "\n",
994 | "#en\n",
995 | "train_data_corp_en=undersample(train_data_corp_en)\n",
996 | "\n",
997 | "train_data_wind_en=undersample(train_data_wind_en)\n",
998 | "\n",
999 | "\n",
1000 | "#fr\n",
1001 | "train_data_corp_fr=undersample(train_data_corp_fr)\n",
1002 | "\n",
1003 | "train_data_wind_fr=undersample(train_data_wind_fr)\n",
1004 | "\n",
1005 | "\n",
1006 | "#nl\n",
1007 | "train_data_corp_nl=undersample(train_data_corp_nl)\n",
1008 | "\n",
1009 | "train_data_wind_nl=undersample(train_data_wind_nl)"
1010 | ],
1011 | "execution_count": null,
1012 | "outputs": [
1013 | {
1014 | "output_type": "stream",
1015 | "text": [
1016 | "Before\n",
1017 | "0 274139\n",
1018 | "1 8708\n",
1019 | "Name: Label, dtype: int64\n",
1020 | "After\n",
1021 | "1 8708\n",
1022 | "0 8708\n",
1023 | "Name: Label, dtype: int64\n",
1024 | "Before\n",
1025 | "0 311535\n",
1026 | "1 10542\n",
1027 | "Name: Label, dtype: int64\n",
1028 | "After\n",
1029 | "1 10542\n",
1030 | "0 10542\n",
1031 | "Name: Label, dtype: int64\n",
1032 | "Before\n",
1033 | "0 325242\n",
1034 | "1 7443\n",
1035 | "Name: Label, dtype: int64\n",
1036 | "After\n",
1037 | "1 7443\n",
1038 | "0 7443\n",
1039 | "Name: Label, dtype: int64\n",
1040 | "Before\n",
1041 | "0 356805\n",
1042 | "1 9293\n",
1043 | "Name: Label, dtype: int64\n",
1044 | "After\n",
1045 | "1 9293\n",
1046 | "0 9293\n",
1047 | "Name: Label, dtype: int64\n",
1048 | "Before\n",
1049 | "0 283267\n",
1050 | "1 7071\n",
1051 | "Name: Label, dtype: int64\n",
1052 | "After\n",
1053 | "1 7071\n",
1054 | "0 7071\n",
1055 | "Name: Label, dtype: int64\n",
1056 | "Before\n",
1057 | "0 287361\n",
1058 | "1 5582\n",
1059 | "Name: Label, dtype: int64\n",
1060 | "After\n",
1061 | "1 5582\n",
1062 | "0 5582\n",
1063 | "Name: Label, dtype: int64\n"
1064 | ],
1065 | "name": "stdout"
1066 | }
1067 | ]
1068 | },
1069 | {
1070 | "cell_type": "code",
1071 | "metadata": {
1072 | "colab": {
1073 | "base_uri": "https://localhost:8080/"
1074 | },
1075 | "id": "VSy8hZggPQpf",
1076 | "outputId": "a8893623-b01f-4794-e20d-79e5d6a2136a"
1077 | },
1078 | "source": [
 1079 |         "# concatenate the training data\n",
1080 | "trainings_data_df = pd.concat([train_data_corp_en, train_data_wind_en, train_data_corp_fr, train_data_wind_fr, train_data_corp_nl, train_data_wind_nl])\n",
1081 | "\n",
1082 | "valid_data_df = train_data_equi_en #pd.concat([train_data_equi_en, train_data_equi_fr, train_data_equi_nl ])\n",
1083 | "\n",
1084 | "test_data_df_en = train_data_htfl_en\n",
1085 | "test_data_df_fr = train_data_htfl_fr\n",
1086 | "test_data_df_nl = train_data_htfl_nl\n",
1087 | "\n",
1088 | "print(len(trainings_data_df))\n",
1089 | "print(len(valid_data_df))\n",
1090 | "print(len(test_data_df_en))\n",
1091 | "print(len(test_data_df_fr))\n",
1092 | "print(len(test_data_df_nl))"
1093 | ],
1094 | "execution_count": null,
1095 | "outputs": [
1096 | {
1097 | "output_type": "stream",
1098 | "text": [
1099 | "97278\n",
1100 | "312754\n",
1101 | "304710\n",
1102 | "303069\n",
1103 | "292615\n"
1104 | ],
1105 | "name": "stdout"
1106 | }
1107 | ]
1108 | },
1109 | {
1110 | "cell_type": "markdown",
1111 | "metadata": {
1112 | "id": "jKtVpCjIWPvO"
1113 | },
1114 | "source": [
1115 | "**Tokenizer**"
1116 | ]
1117 | },
1118 | {
1119 | "cell_type": "code",
1120 | "metadata": {
1121 | "colab": {
1122 | "base_uri": "https://localhost:8080/",
 1123 |           "height": 66
1134 | },
1135 | "id": "pJjnroUuWOdg",
1136 | "outputId": "aad8ace8-3731-49da-d753-6649fb6ecd52"
1137 | },
1138 | "source": [
1139 | "xlmr_tokenizer = XLMRobertaTokenizer.from_pretrained(\"xlm-roberta-base\")"
1140 | ],
1141 | "execution_count": null,
1142 | "outputs": [
1159 | {
1160 | "output_type": "stream",
1161 | "text": [
1162 | "\n"
1163 | ],
1164 | "name": "stdout"
1165 | }
1166 | ]
1167 | },
1168 | {
1169 | "cell_type": "code",
1170 | "metadata": {
1171 | "id": "9v7WbIW6WV8D"
1172 | },
1173 | "source": [
1174 | "def tokenizer_xlm(data, max_len):\n",
1175 | " labels_ = []\n",
1176 | " input_ids_ = []\n",
1177 | " attn_masks_ = []\n",
1178 | "\n",
1179 | " # for each datasample:\n",
1180 | " for index, row in data.iterrows():\n",
1181 | "\n",
1182 | " sentence = row['n_gram']+\". \"+row[\"Context\"]\n",
1183 | " #print(sentence)\n",
1184 | " \n",
 1185 |         "    # create required input, i.e. ids and attention masks\n",
1186 | " encoded_dict = xlmr_tokenizer.encode_plus(sentence,\n",
1187 | " max_length=max_len, \n",
1188 | " padding='max_length',\n",
1189 | " truncation=True, \n",
1190 | " return_tensors='pt')\n",
1191 | "\n",
1192 | " # add encoded sample to lists\n",
1193 | " input_ids_.append(encoded_dict['input_ids'])\n",
1194 | " attn_masks_.append(encoded_dict['attention_mask'])\n",
1195 | " labels_.append(row['Label'])\n",
1196 | " \n",
1197 | " # Convert each Python list of Tensors into a 2D Tensor matrix.\n",
1198 | " input_ids_ = torch.cat(input_ids_, dim=0)\n",
1199 | " attn_masks_ = torch.cat(attn_masks_, dim=0)\n",
1200 | "\n",
1201 | " # labels to tensor\n",
1202 | " labels_ = torch.tensor(labels_)\n",
1203 | "\n",
1204 | " print('Encoder finished. {:,} examples.'.format(len(labels_)))\n",
1205 | " return input_ids_, attn_masks_, labels_"
1206 | ],
1207 | "execution_count": null,
1208 | "outputs": []
1209 | },
1210 | {
1211 | "cell_type": "code",
1212 | "metadata": {
1213 | "colab": {
1214 | "base_uri": "https://localhost:8080/"
1215 | },
1216 | "id": "RcCexBG1ZuP_",
1217 | "outputId": "c2641990-2539-4ee3-b53e-aee04dfb052b"
1218 | },
1219 | "source": [
1220 | "#tokenize input for the different training/test sets\n",
1221 | "max_len=64\n",
1222 | "\n",
1223 | "input_ids_train, attn_masks_train, labels_all_train = tokenizer_xlm(trainings_data_df, max_len)\n",
1224 | "\n",
1225 | "input_ids_valid, attn_masks_valid, labels_all_valid = tokenizer_xlm(valid_data_df, max_len)\n",
1226 | "\n",
1227 | "input_ids_test_en, attn_masks_test_en, labels_test_en = tokenizer_xlm(test_data_df_en, max_len)\n",
1228 | "input_ids_test_fr, attn_masks_test_fr, labels_test_fr = tokenizer_xlm(test_data_df_fr, max_len)\n",
1229 | "input_ids_test_nl, attn_masks_test_nl, labels_test_nl = tokenizer_xlm(test_data_df_nl, max_len)"
1230 | ],
1231 | "execution_count": null,
1232 | "outputs": [
1233 | {
1234 | "output_type": "stream",
1235 | "text": [
1236 | "Encoder finished. 97,278 examples.\n",
1237 | "Encoder finished. 312,754 examples.\n",
1238 | "Encoder finished. 304,710 examples.\n",
1239 | "Encoder finished. 303,069 examples.\n",
1240 | "Encoder finished. 292,615 examples.\n"
1241 | ],
1242 | "name": "stdout"
1243 | }
1244 | ]
1245 | },
1246 | {
1247 | "cell_type": "code",
1248 | "metadata": {
1249 | "id": "nLCLiW9-Nkd-"
1250 | },
1251 | "source": [
1252 | "# create datasets\n",
1253 | "train_dataset = TensorDataset(input_ids_train, attn_masks_train, labels_all_train)\n",
1254 | "\n",
1255 | "valid_dataset = TensorDataset(input_ids_valid, attn_masks_valid, labels_all_valid)\n",
1256 | "\n",
1257 | "test_dataset_en = TensorDataset(input_ids_test_en, attn_masks_test_en, labels_test_en)\n",
1258 | "test_dataset_fr = TensorDataset(input_ids_test_fr, attn_masks_test_fr, labels_test_fr)\n",
1259 | "test_dataset_nl = TensorDataset(input_ids_test_nl, attn_masks_test_nl, labels_test_nl)"
1260 | ],
1261 | "execution_count": null,
1262 | "outputs": []
1263 | },
1264 | {
1265 | "cell_type": "code",
1266 | "metadata": {
1267 | "id": "Si-ng4T8Ny2O"
1268 | },
1269 | "source": [
1270 | "# create dataloaders\n",
1271 | "batch_size = 32\n",
1272 | "\n",
1273 | "train_dataloader = DataLoader(train_dataset, sampler = RandomSampler(train_dataset), batch_size = batch_size) #random sampling\n",
1274 | "valid_dataloader = DataLoader(valid_dataset, sampler = SequentialSampler(valid_dataset),batch_size = batch_size ) #sequential sampling\n",
1275 | "\n",
1276 | "test_dataloader_en = DataLoader(test_dataset_en, sampler = SequentialSampler(test_dataset_en),batch_size = batch_size ) #sequential sampling\n",
1277 | "test_dataloader_fr = DataLoader(test_dataset_fr, sampler = SequentialSampler(test_dataset_fr),batch_size = batch_size ) #sequential sampling\n",
1278 | "test_dataloader_nl = DataLoader(test_dataset_nl, sampler = SequentialSampler(test_dataset_nl),batch_size = batch_size ) #sequential sampling"
1279 | ],
1280 | "execution_count": null,
1281 | "outputs": []
1282 | },
1283 | {
1284 | "cell_type": "markdown",
1285 | "metadata": {
1286 | "id": "Hart2Y_ia5qD"
1287 | },
1288 | "source": [
1289 | "#Model"
1290 | ]
1291 | },
1292 | {
1293 | "cell_type": "code",
1294 | "metadata": {
1295 | "id": "sF72Sc2ur-ds"
1296 | },
1297 | "source": [
1298 | "def create_model(lr, eps, train_dataloader, epochs, device):\n",
1299 | " xlmr_model = XLMRobertaForSequenceClassification.from_pretrained(\"xlm-roberta-base\", num_labels=2)\n",
1300 | " desc = xlmr_model.to(device)\n",
1301 | " print('Connected to GPU:', torch.cuda.get_device_name(0))\n",
1302 | " optimizer = AdamW(xlmr_model.parameters(),\n",
1303 | " lr = lr, \n",
1304 | " eps = eps \n",
1305 | " )\n",
1306 | " total_steps = len(train_dataloader) * epochs\n",
1307 | " scheduler = get_linear_schedule_with_warmup(optimizer, \n",
1308 | " num_warmup_steps = 0, \n",
1309 | " num_training_steps = total_steps)\n",
1310 | " return xlmr_model, optimizer, scheduler"
1311 | ],
1312 | "execution_count": null,
1313 | "outputs": []
1314 | },
1315 | {
1316 | "cell_type": "code",
1317 | "metadata": {
1318 | "id": "R7acJSCUtHN6"
1319 | },
1320 | "source": [
1321 | "def format_time(elapsed):\n",
1322 | " '''\n",
1323 | " Takes a time in seconds and returns a string hh:mm:ss\n",
1324 | " '''\n",
1325 | " elapsed_rounded = int(round((elapsed)))\n",
1326 | " return str(datetime.timedelta(seconds=elapsed_rounded)) "
1327 | ],
1328 | "execution_count": null,
1329 | "outputs": []
1330 | },
1331 | {
1332 | "cell_type": "code",
1333 | "metadata": {
1334 | "id": "YsxS3wVltI5i"
1335 | },
1336 | "source": [
1337 | "def validate(validation_dataloader, validation_df, xlmr_model, verbose, print_cm): \n",
1338 | " \n",
1339 | " # put model in evaluation mode \n",
1340 | " xlmr_model.eval()\n",
1341 | "\n",
 1342 |         "    # extract terms and compute scores\n",
 1343 |         "    extracted_terms_equi_en=extract_terms(train_data_equi_en, xlmr_model)\n",
 1344 |         "    extracted_terms_equi_en = set([item.lower() for item in extracted_terms_equi_en])\n",
1345 | " gold_set_equi_en=set(df_equi_terms_en[\"Term\"])\n",
1346 | " true_pos=extracted_terms_equi_en.intersection(gold_set_equi_en)\n",
1347 | " recall=len(true_pos)/len(gold_set_equi_en)\n",
1348 | " precision=len(true_pos)/len(extracted_terms_equi_en)\n",
1349 | " f1=2*(precision*recall)/(precision+recall)\n",
1350 | "\n",
1351 | " return recall, precision, f1"
1352 | ],
1353 | "execution_count": null,
1354 | "outputs": []
1355 | },
1356 | {
1357 | "cell_type": "code",
1358 | "metadata": {
1359 | "id": "UYBNMpiszm_h"
1360 | },
1361 | "source": [
1362 | "def extract_terms(validation_df, xlmr_model): \n",
1363 | " print(len(validation_df))\n",
1364 | " term_list=[]\n",
1365 | "\n",
1366 | " # put model in evaluation mode \n",
1367 | " xlmr_model.eval()\n",
1368 | "\n",
1369 | " for index, row in validation_df.iterrows():\n",
1370 | " sentence = row['n_gram']+\". \"+row[\"Context\"]\n",
1372 | "\n",
1373 | " encoded_dict = xlmr_tokenizer.encode_plus(sentence, \n",
1374 | " max_length=max_len, \n",
1375 | " padding='max_length',\n",
1376 | " truncation=True, \n",
1377 | " return_tensors='pt') \n",
1378 | " input_id=encoded_dict['input_ids'].to(device)\n",
1379 | " attn_mask=encoded_dict['attention_mask'].to(device)\n",
1380 | " label=torch.tensor(0).to(device) \n",
1381 | "\n",
1382 | " with torch.no_grad(): \n",
1383 | " output = xlmr_model(input_id, \n",
1384 | " token_type_ids=None, \n",
1385 | " attention_mask=attn_mask,\n",
1386 | " labels=label)\n",
1387 | " loss=output.loss\n",
1388 | " logits=output.logits\n",
1389 | " \n",
1390 | " logits = logits.detach().cpu().numpy()\n",
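        "        # map the argmax of the two logits onto the global labels list ([\"Random\", \"Term\"])\n",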
1391 | " pred=labels[logits[0].argmax(axis=0)]\n",
1392 | " if pred==\"Term\":\n",
1393 | " term_list.append(row['n_gram'])\n",
1394 | "\n",
1395 | " return set(term_list)\n",
1396 | " "
1397 | ],
1398 | "execution_count": null,
1399 | "outputs": []
1400 | },
1401 | {
1402 | "cell_type": "code",
1403 | "metadata": {
1404 | "id": "zs7cOPFJtLUG"
1405 | },
1406 | "source": [
1407 | "def train_model(epochs, xlmr_model, train_dataloader, validation_dataloader, validation_df, random_seed, verbose, optimizer, scheduler):\n",
1408 | "\n",
1409 | " seed_val = random_seed\n",
1410 | "\n",
1411 | " random.seed(seed_val)\n",
1412 | " np.random.seed(seed_val)\n",
1413 | " torch.manual_seed(seed_val)\n",
1414 | " torch.cuda.manual_seed_all(seed_val)\n",
1415 | "\n",
1416 | " # mostly contains scores about how the training went for each epoch\n",
1417 | " training_stats = []\n",
1418 | "\n",
1419 | " # total training time\n",
1420 | " total_t0 = time.time()\n",
1421 | "\n",
1422 | " print('\\033[1m'+\"================ Model Training ================\"+'\\033[0m')\n",
1423 | "\n",
1424 | " # For each epoch...\n",
1425 | " for epoch_i in range(0, epochs):\n",
1426 | "\n",
1427 | " print(\"\")\n",
1428 | " print('\\033[1m'+'======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)+'\\033[0m')\n",
1429 | "\n",
1430 | " t0 = time.time()\n",
1431 | "\n",
1432 | " # summed training loss of the epoch\n",
1433 | " total_train_loss = 0\n",
1434 | "\n",
1435 | "\n",
 1436 |         "        # put the model into training mode, since mechanisms like dropout behave differently at train and test time\n",
1437 | " xlmr_model.train()\n",
1438 | "\n",
 1439 |         "        # iterate over batches\n",
1440 | " for step, batch in enumerate(train_dataloader):\n",
1441 | "\n",
 1442 |         "            # unpack training batch and load it to the gpu (device)\n",
1443 | " b_input_ids = batch[0].to(device)\n",
1444 | " b_input_mask = batch[1].to(device)\n",
1445 | " b_labels = batch[2].to(device)\n",
1446 | "\n",
1447 | " # clear gradients before calculating new ones\n",
1448 | " xlmr_model.zero_grad() \n",
1449 | "\n",
1450 | " # forward pass with current batch\n",
1451 | " output = xlmr_model(b_input_ids, \n",
1452 | " token_type_ids=None, \n",
1453 | " attention_mask=b_input_mask, \n",
1454 | " labels=b_labels)\n",
1455 | " \n",
1456 | " loss=output.loss\n",
1457 | " logits=output.logits\n",
1458 | "\n",
1459 | " # add up the loss\n",
1460 | " total_train_loss += loss.item()\n",
1461 | "\n",
1462 | " # calculate new gradients\n",
1463 | " loss.backward()\n",
1464 | "\n",
 1465 |         "            # gradient clipping (clip the gradient norm to at most 1.0)\n",
1466 | " torch.nn.utils.clip_grad_norm_(xlmr_model.parameters(), 1.0)\n",
1467 | "\n",
 1468 |         "            # update the network's weights based on the gradients and the optimizer's parameters\n",
1469 | " optimizer.step()\n",
1470 | "\n",
1471 | " # lr update\n",
1472 | " scheduler.step()\n",
1473 | "\n",
1474 | " # avg loss over all batches\n",
1475 | " avg_train_loss = total_train_loss / len(train_dataloader) \n",
1476 | " \n",
1477 | " # training time of this epoch\n",
1478 | " training_time = format_time(time.time() - t0)\n",
1479 | "\n",
1480 | " print(\"\")\n",
1481 | " print(\" Average training loss: {0:.2f}\".format(avg_train_loss))\n",
1482 | " print(\" Training epoch took: {:}\".format(training_time))\n",
1483 | " \n",
1484 | " \n",
1485 | " # VALIDATION\n",
1486 | " print(\"evaluate\")\n",
 1487 |         "        if epoch_i==epochs-1: print_cm=True # print the confusion matrix in the final epoch\n",
1488 | " else: print_cm=False\n",
1489 | " recall, precision, f1 = validate(validation_dataloader, validation_df, xlmr_model, verbose, print_cm) \n",
1490 | " \n",
1491 | "\n",
1492 | " #print('\\033[1m'+ \" Validation Loss All: {0:.2f}\".format(avg_val_loss) + '\\033[0m')\n",
1493 | "\n",
1494 | " training_stats.append(\n",
1495 | " {\n",
1496 | " 'epoch': epoch_i + 1,\n",
1497 | " 'Training Loss': avg_train_loss,\n",
1498 | " \"precision\": precision,\n",
1499 | " \"recall\": recall,\n",
1500 | " \"f1\": f1,\n",
1501 | " 'Training Time': training_time,\n",
1502 | " }\n",
1503 | " )\n",
1504 | "\n",
 1505 |         "    print(\"Precision\", precision)\n",
1506 | " print(\"Recall\", recall)\n",
1507 | " print(\"F1\", f1)\n",
1508 | "\n",
1509 | " print(\"\\n\\nTraining complete!\")\n",
1510 | " print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()-total_t0)))\n",
1511 | " \n",
1512 | " return training_stats\n"
1513 | ],
1514 | "execution_count": null,
1515 | "outputs": []
1516 | },
1517 | {
1518 | "cell_type": "code",
1519 | "metadata": {
1520 | "id": "V1VTPA1anQ2w"
1521 | },
1522 | "source": [
1523 | "lr=2e-5\n",
1524 | "eps=1e-8\n",
1525 | "epochs=3\n",
1526 | "device = torch.device('cuda')\n",
1527 | "xlmr_model, optimizer, scheduler = create_model(lr=lr, eps=eps, train_dataloader=train_dataloader, epochs=epochs, device=device)"
1528 | ],
1529 | "execution_count": null,
1530 | "outputs": []
1531 | },
1532 | {
1533 | "cell_type": "code",
1534 | "metadata": {
1535 | "id": "56767talsn4M"
1536 | },
1537 | "source": [
1538 | "training_stats=train_model(epochs=epochs,\n",
1539 | " xlmr_model=xlmr_model,\n",
1540 | " train_dataloader=train_dataloader,\n",
1541 | " validation_dataloader=valid_dataloader,\n",
1542 | " validation_df=train_data_htfl_en,\n",
1543 | " random_seed=42,\n",
1544 | " verbose=True,\n",
1545 | " optimizer=optimizer,\n",
1546 | " scheduler=scheduler)"
1547 | ],
1548 | "execution_count": null,
1549 | "outputs": []
1550 | },
1551 | {
1552 | "cell_type": "markdown",
1553 | "metadata": {
1554 | "id": "0-9PbANQp4Uj"
1555 | },
1556 | "source": [
1557 | "# Test Set Evaluation"
1558 | ]
1559 | },
1560 | {
1561 | "cell_type": "code",
1562 | "metadata": {
1563 | "id": "kEniE8WRdjF3"
1564 | },
1565 | "source": [
1566 | "extracted_terms_htfl_en=extract_terms(train_data_htfl_en, xlmr_model)\n",
1567 | "extracted_terms_htfl_fr=extract_terms(train_data_htfl_fr, xlmr_model)\n",
1568 | "extracted_terms_htfl_nl=extract_terms(train_data_htfl_nl, xlmr_model)"
1569 | ],
1570 | "execution_count": null,
1571 | "outputs": []
1572 | },
1573 | {
1574 | "cell_type": "code",
1575 | "metadata": {
1576 | "id": "kB_M7qk9xbj5"
1577 | },
1578 | "source": [
1579 | "def computeTermEvalMetrics(extracted_terms, gold_df):\n",
 1580 |         "    # make lowercase, because the gold standard is lowercase\n",
1581 | " extracted_terms = set([item.lower() for item in extracted_terms])\n",
1582 | " gold_set=set(gold_df)\n",
1583 | " true_pos=extracted_terms.intersection(gold_set)\n",
1584 | " recall=len(true_pos)/len(gold_set)\n",
1585 | " precision=len(true_pos)/len(extracted_terms)\n",
1586 | "\n",
1587 | " print(\"Intersection\",len(true_pos))\n",
1588 | " print(\"Gold\",len(gold_set))\n",
1589 | " print(\"Extracted\",len(extracted_terms))\n",
1590 | " print(\"Recall:\", recall)\n",
1591 | " print(\"Precision:\", precision)\n",
1592 | " print(\"F1:\", 2*(precision*recall)/(precision+recall))"
1593 | ],
1594 | "execution_count": null,
1595 | "outputs": []
1596 | },
1597 | {
1598 | "cell_type": "code",
1599 | "metadata": {
1600 | "id": "cABUXjRY1ZHI"
1601 | },
1602 | "source": [
1603 | "computeTermEvalMetrics(extracted_terms_htfl_en, df_htfl_terms_en[\"Term\"])"
1604 | ],
1605 | "execution_count": null,
1606 | "outputs": []
1607 | },
1608 | {
1609 | "cell_type": "code",
1610 | "metadata": {
1611 | "id": "z1B0JipczW_7"
1612 | },
1613 | "source": [
1614 | "computeTermEvalMetrics(extracted_terms_htfl_fr, df_htfl_terms_fr[\"Term\"])"
1615 | ],
1616 | "execution_count": null,
1617 | "outputs": []
1618 | },
1619 | {
1620 | "cell_type": "code",
1621 | "metadata": {
1622 | "id": "KKTVrV0AzXEt"
1623 | },
1624 | "source": [
1625 | "computeTermEvalMetrics(extracted_terms_htfl_nl, df_htfl_terms_nl[\"Term\"])"
1626 | ],
1627 | "execution_count": null,
1628 | "outputs": []
1629 | }
1630 | ]
1631 | }
--------------------------------------------------------------------------------